Add API polling to check the execution progress status of the current task

This commit is contained in:
hmwl 2024-11-09 16:31:46 +08:00
parent 6ee066a14f
commit beb66a7f33
2 changed files with 39 additions and 0 deletions

View File

@ -23,6 +23,22 @@ class ExecutionResult(Enum):
FAILURE = 1
PENDING = 2
# Add global variables to track the number of nodes
total_nodes = 0
executed_nodes = 0
# Add a function to clear statistical information
def reset_node_counts():
global total_nodes, executed_nodes
total_nodes = 0
executed_nodes = 0
# Add a function to count the total number of nodes
def count_total_nodes(task_graph):
global total_nodes
total_nodes = len(task_graph)
#end
class DuplicateNodeError(Exception):
pass
@ -39,6 +55,10 @@ class IsChangedCache:
node = self.dynprompt.get_node(node_id)
class_type = node["class_type"]
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
#Initialization statistics <start>
# print("任务包含的节点",node_id)
reset_node_counts()
#Initialization statistics <end>
if not hasattr(class_def, "IS_CHANGED"):
self.is_changed[node_id] = False
return self.is_changed[node_id]
@ -249,6 +269,10 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
inputs = dynprompt.get_node(unique_id)['inputs']
class_type = dynprompt.get_node(unique_id)['class_type']
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
#execute node count <start>
global executed_nodes
executed_nodes+=1
#execute node count <end>
if caches.outputs.get(unique_id) is not None:
if server.client_id is not None:
cached_output = caches.ui.get(unique_id) or {}
@ -479,6 +503,11 @@ class PromptExecutor:
for node_id in prompt:
if self.caches.outputs.get(node_id) is not None:
cached_nodes.append(node_id)
#Count the number of task nodes <start>
else:
global total_nodes
total_nodes+=1
#Count the number of task nodes <end>
comfy.model_management.cleanup_models(keep_clone_weights_loaded=True)
self.add_message("execution_cached",

View File

@ -182,6 +182,16 @@ class PromptServer():
self.on_prompt_handlers = []
#add api get progress status <start>
async def get_node_status(request):
return web.json_response({
"total_nodes": execution.total_nodes,
"executed_nodes": execution.executed_nodes
})
routes.get('/node_status')(get_node_status)
self.app.add_routes(routes)
#add api get progress status <end>
@routes.get('/ws')
async def websocket_handler(request):
ws = web.WebSocketResponse()