From beb66a7f337dca22f9e512e412f04d16ebeb822d Mon Sep 17 00:00:00 2001 From: hmwl <402245847@qq.com> Date: Sat, 9 Nov 2024 16:31:46 +0800 Subject: [PATCH 1/2] Add API polling to check the execution progress status of the current task --- execution.py | 29 +++++++++++++++++++++++++++++ server.py | 10 ++++++++++ 2 files changed, 39 insertions(+) diff --git a/execution.py b/execution.py index 6c386341b..f23bdd221 100644 --- a/execution.py +++ b/execution.py @@ -23,6 +23,22 @@ class ExecutionResult(Enum): FAILURE = 1 PENDING = 2 +# Add global variables to track the number of nodes +total_nodes = 0 +executed_nodes = 0 + +# Add a function to clear statistical information +def reset_node_counts(): + global total_nodes, executed_nodes + total_nodes = 0 + executed_nodes = 0 + +# Add a function to count the total number of nodes +def count_total_nodes(task_graph): + global total_nodes + total_nodes = len(task_graph) +#end + class DuplicateNodeError(Exception): pass @@ -39,6 +55,10 @@ class IsChangedCache: node = self.dynprompt.get_node(node_id) class_type = node["class_type"] class_def = nodes.NODE_CLASS_MAPPINGS[class_type] + #Initialization statistics + # print("任务包含的节点",node_id) + reset_node_counts() + #Initialization statistics if not hasattr(class_def, "IS_CHANGED"): self.is_changed[node_id] = False return self.is_changed[node_id] @@ -249,6 +269,10 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp inputs = dynprompt.get_node(unique_id)['inputs'] class_type = dynprompt.get_node(unique_id)['class_type'] class_def = nodes.NODE_CLASS_MAPPINGS[class_type] + #execute node count + global executed_nodes + executed_nodes+=1 + #execute node count if caches.outputs.get(unique_id) is not None: if server.client_id is not None: cached_output = caches.ui.get(unique_id) or {} @@ -479,6 +503,11 @@ class PromptExecutor: for node_id in prompt: if self.caches.outputs.get(node_id) is not None: cached_nodes.append(node_id) + #Count the number of task nodes + else: + global total_nodes + total_nodes+=1 + #Count the number of task nodes comfy.model_management.cleanup_models(keep_clone_weights_loaded=True) self.add_message("execution_cached", diff --git a/server.py b/server.py index e663095bc..5aa2e7902 100644 --- a/server.py +++ b/server.py @@ -182,6 +182,16 @@ class PromptServer(): self.on_prompt_handlers = [] + #add api get progress status + async def get_node_status(request): + return web.json_response({ + "total_nodes": execution.total_nodes, + "executed_nodes": execution.executed_nodes + }) + routes.get('/node_status')(get_node_status) + self.app.add_routes(routes) + #add api get progress status + @routes.get('/ws') async def websocket_handler(request): ws = web.WebSocketResponse() From 84104f608311b75277a1140686b2a183719ea7b7 Mon Sep 17 00:00:00 2001 From: hmwl <402245847@qq.com> Date: Sat, 9 Nov 2024 16:33:32 +0800 Subject: [PATCH 2/2] Add API polling to check the execution progress status of the current task --- execution.py | 1 - 1 file changed, 1 deletion(-) diff --git a/execution.py b/execution.py index f23bdd221..c38e00af8 100644 --- a/execution.py +++ b/execution.py @@ -56,7 +56,6 @@ class IsChangedCache: class_type = node["class_type"] class_def = nodes.NODE_CLASS_MAPPINGS[class_type] #Initialization statistics - # print("任务包含的节点",node_id) reset_node_counts() #Initialization statistics if not hasattr(class_def, "IS_CHANGED"):