mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-04-13 14:13:29 +00:00
Add API polling to check the execution progress status of the current task
This commit is contained in:
parent
6ee066a14f
commit
beb66a7f33
29
execution.py
29
execution.py
@ -23,6 +23,22 @@ class ExecutionResult(Enum):
|
||||
FAILURE = 1
|
||||
PENDING = 2
|
||||
|
||||
# Add global variables to track the number of nodes
|
||||
total_nodes = 0
|
||||
executed_nodes = 0
|
||||
|
||||
# Add a function to clear statistical information
|
||||
def reset_node_counts():
|
||||
global total_nodes, executed_nodes
|
||||
total_nodes = 0
|
||||
executed_nodes = 0
|
||||
|
||||
# Add a function to count the total number of nodes
|
||||
def count_total_nodes(task_graph):
|
||||
global total_nodes
|
||||
total_nodes = len(task_graph)
|
||||
#end
|
||||
|
||||
class DuplicateNodeError(Exception):
|
||||
pass
|
||||
|
||||
@ -39,6 +55,10 @@ class IsChangedCache:
|
||||
node = self.dynprompt.get_node(node_id)
|
||||
class_type = node["class_type"]
|
||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||
#Initialization statistics <start>
|
||||
# print("任务包含的节点",node_id)
|
||||
reset_node_counts()
|
||||
#Initialization statistics <end>
|
||||
if not hasattr(class_def, "IS_CHANGED"):
|
||||
self.is_changed[node_id] = False
|
||||
return self.is_changed[node_id]
|
||||
@ -249,6 +269,10 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
|
||||
inputs = dynprompt.get_node(unique_id)['inputs']
|
||||
class_type = dynprompt.get_node(unique_id)['class_type']
|
||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||
#execute node count <start>
|
||||
global executed_nodes
|
||||
executed_nodes+=1
|
||||
#execute node count <end>
|
||||
if caches.outputs.get(unique_id) is not None:
|
||||
if server.client_id is not None:
|
||||
cached_output = caches.ui.get(unique_id) or {}
|
||||
@ -479,6 +503,11 @@ class PromptExecutor:
|
||||
for node_id in prompt:
|
||||
if self.caches.outputs.get(node_id) is not None:
|
||||
cached_nodes.append(node_id)
|
||||
#Count the number of task nodes <start>
|
||||
else:
|
||||
global total_nodes
|
||||
total_nodes+=1
|
||||
#Count the number of task nodes <end>
|
||||
|
||||
comfy.model_management.cleanup_models(keep_clone_weights_loaded=True)
|
||||
self.add_message("execution_cached",
|
||||
|
10
server.py
10
server.py
@ -182,6 +182,16 @@ class PromptServer():
|
||||
|
||||
self.on_prompt_handlers = []
|
||||
|
||||
#add api get progress status <start>
|
||||
async def get_node_status(request):
|
||||
return web.json_response({
|
||||
"total_nodes": execution.total_nodes,
|
||||
"executed_nodes": execution.executed_nodes
|
||||
})
|
||||
routes.get('/node_status')(get_node_status)
|
||||
self.app.add_routes(routes)
|
||||
#add api get progress status <end>
|
||||
|
||||
@routes.get('/ws')
|
||||
async def websocket_handler(request):
|
||||
ws = web.WebSocketResponse()
|
||||
|
Loading…
Reference in New Issue
Block a user