diff --git a/execution.py b/execution.py index 9a5e27771..1aff9276c 100644 --- a/execution.py +++ b/execution.py @@ -23,6 +23,22 @@ class ExecutionResult(Enum): FAILURE = 1 PENDING = 2 +# Add global variables to track the number of nodes +total_nodes = 0 +executed_nodes = 0 + +# Add a function to clear statistical information +def reset_node_counts(): + global total_nodes, executed_nodes + total_nodes = 0 + executed_nodes = 0 + +# Add a function to count the total number of nodes +def count_total_nodes(task_graph): + global total_nodes + total_nodes = len(task_graph) +#end + class DuplicateNodeError(Exception): pass @@ -39,6 +55,9 @@ class IsChangedCache: node = self.dynprompt.get_node(node_id) class_type = node["class_type"] class_def = nodes.NODE_CLASS_MAPPINGS[class_type] + #Initialization statistics + reset_node_counts() + #Initialization statistics if not hasattr(class_def, "IS_CHANGED"): self.is_changed[node_id] = False return self.is_changed[node_id] @@ -271,6 +290,10 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp inputs = dynprompt.get_node(unique_id)['inputs'] class_type = dynprompt.get_node(unique_id)['class_type'] class_def = nodes.NODE_CLASS_MAPPINGS[class_type] + #execute node count + global executed_nodes + executed_nodes+=1 + #execute node count if caches.outputs.get(unique_id) is not None: if server.client_id is not None: cached_output = caches.ui.get(unique_id) or {} @@ -502,6 +525,11 @@ class PromptExecutor: for node_id in prompt: if self.caches.outputs.get(node_id) is not None: cached_nodes.append(node_id) + #Count the number of task nodes + else: + global total_nodes + total_nodes+=1 + #Count the number of task nodes comfy.model_management.cleanup_models_gc() self.add_message("execution_cached", diff --git a/server.py b/server.py index 62667ce18..1977b3a66 100644 --- a/server.py +++ b/server.py @@ -189,6 +189,16 @@ class PromptServer(): self.on_prompt_handlers = [] + #add api get progress status + async def get_node_status(request): + return web.json_response({ + "total_nodes": execution.total_nodes, + "executed_nodes": execution.executed_nodes + }) + routes.get('/node_status')(get_node_status) + self.app.add_routes(routes) + #add api get progress status + @routes.get('/ws') async def websocket_handler(request): ws = web.WebSocketResponse()