From 1b3d65bd84c8026dea234643861491279886218c Mon Sep 17 00:00:00 2001
From: realazthat <realazthat@gmail.com>
Date: Thu, 11 Jan 2024 08:38:18 -0500
Subject: [PATCH] Add error, status to /history endpoint

---
 execution.py | 70 +++++++++++++++++++++++++++++++++-------------------
 main.py      |  7 +++++-
 2 files changed, 51 insertions(+), 26 deletions(-)

diff --git a/execution.py b/execution.py
index 260a0897..554e8dc1 100644
--- a/execution.py
+++ b/execution.py
@@ -8,6 +8,7 @@ import heapq
 import traceback
 import gc
 import inspect
+from typing import List, Literal, NamedTuple, Optional
 
 import torch
 import nodes
@@ -275,8 +276,15 @@ class PromptExecutor:
         self.outputs = {}
         self.object_storage = {}
         self.outputs_ui = {}
+        self.status_notes = []
+        self.success = True
         self.old_prompt = {}
 
+    def add_note(self, event, data, broadcast: bool):
+        self.status_notes.append((event, data))
+        if self.server.client_id is not None or broadcast:
+            self.server.send_sync(event, data, self.server.client_id)
+
     def handle_execution_error(self, prompt_id, prompt, current_outputs, executed, error, ex):
         node_id = error["node_id"]
         class_type = prompt[node_id]["class_type"]
@@ -290,23 +298,22 @@ class PromptExecutor:
                 "node_type": class_type,
                 "executed": list(executed),
             }
-            self.server.send_sync("execution_interrupted", mes, self.server.client_id)
+            self.add_note("execution_interrupted", mes, broadcast=True)
         else:
-            if self.server.client_id is not None:
-                mes = {
-                    "prompt_id": prompt_id,
-                    "node_id": node_id,
-                    "node_type": class_type,
-                    "executed": list(executed),
-
-                    "exception_message": error["exception_message"],
-                    "exception_type": error["exception_type"],
-                    "traceback": error["traceback"],
-                    "current_inputs": error["current_inputs"],
-                    "current_outputs": error["current_outputs"],
-                }
-                self.server.send_sync("execution_error", mes, self.server.client_id)
+            mes = {
+                "prompt_id": prompt_id,
+                "node_id": node_id,
+                "node_type": class_type,
+                "executed": list(executed),
 
+                "exception_message": error["exception_message"],
+                "exception_type": error["exception_type"],
+                "traceback": error["traceback"],
+                "current_inputs": error["current_inputs"],
+                "current_outputs": error["current_outputs"],
+            }
+            self.add_note("execution_error", mes, broadcast=False)
+        
         # Next, remove the subsequent outputs since they will not be executed
         to_delete = []
         for o in self.outputs:
@@ -327,8 +334,7 @@ class PromptExecutor:
         else:
             self.server.client_id = None
 
-        if self.server.client_id is not None:
-            self.server.send_sync("execution_start", { "prompt_id": prompt_id}, self.server.client_id)
+        self.add_note("execution_start", { "prompt_id": prompt_id}, broadcast=False)
 
         with torch.inference_mode():
             #delete cached outputs if nodes don't exist for them
@@ -361,8 +367,9 @@ class PromptExecutor:
                     del d
 
             comfy.model_management.cleanup_models()
-            if self.server.client_id is not None:
-                self.server.send_sync("execution_cached", { "nodes": list(current_outputs) , "prompt_id": prompt_id}, self.server.client_id)
+            self.add_note("execution_cached",
+                          { "nodes": list(current_outputs) , "prompt_id": prompt_id},
+                          broadcast=False)
             executed = set()
             output_node_id = None
             to_execute = []
@@ -378,8 +385,8 @@ class PromptExecutor:
                 # This call shouldn't raise anything if there's an error deep in
                 # the actual SD code, instead it will report the node where the
                 # error was raised
-                success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id, self.outputs_ui, self.object_storage)
-                if success is not True:
+                self.success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id, self.outputs_ui, self.object_storage)
+                if self.success is not True:
                     self.handle_execution_error(prompt_id, prompt, current_outputs, executed, error, ex)
                     break
 
@@ -731,14 +738,27 @@ class PromptQueue:
             self.server.queue_updated()
             return (item, i)
 
-    def task_done(self, item_id, outputs):
+    class ExecutionStatus(NamedTuple):
+        status_str: Literal['success', 'error']
+        completed: bool
+        notes: List[str]
+
+    def task_done(self, item_id, outputs,
+                  status: Optional['PromptQueue.ExecutionStatus']):
         with self.mutex:
             prompt = self.currently_running.pop(item_id)
             if len(self.history) > MAXIMUM_HISTORY_SIZE:
                 self.history.pop(next(iter(self.history)))
-            self.history[prompt[1]] = { "prompt": prompt, "outputs": {} }
-            for o in outputs:
-                self.history[prompt[1]]["outputs"][o] = outputs[o]
+
+            status_dict: dict|None = None
+            if status is not None:
+                status_dict = copy.deepcopy(status._asdict())
+
+            self.history[prompt[1]] = {
+                "prompt": prompt,
+                "outputs": copy.deepcopy(outputs),
+                'status': status_dict,
+            }
             self.server.queue_updated()
 
     def get_current_queue(self):
diff --git a/main.py b/main.py
index 45d5d41b..a40ad2a4 100644
--- a/main.py
+++ b/main.py
@@ -110,7 +110,12 @@ def prompt_worker(q, server):
 
             e.execute(item[2], prompt_id, item[3], item[4])
             need_gc = True
-            q.task_done(item_id, e.outputs_ui)
+            q.task_done(item_id,
+                        e.outputs_ui,
+                        status=execution.PromptQueue.ExecutionStatus(
+                            status_str='success' if e.success else 'error',
+                            completed=e.success,
+                            notes=e.status_notes))
             if server.client_id is not None:
                 server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, server.client_id)