mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-04-14 23:53:30 +00:00
Add API polling to check the execution progress status of the current task
This commit is contained in:
parent
6ee066a14f
commit
beb66a7f33
29
execution.py
29
execution.py
@ -23,6 +23,22 @@ class ExecutionResult(Enum):
|
|||||||
FAILURE = 1
|
FAILURE = 1
|
||||||
PENDING = 2
|
PENDING = 2
|
||||||
|
|
||||||
|
# Add global variables to track the number of nodes
|
||||||
|
total_nodes = 0
|
||||||
|
executed_nodes = 0
|
||||||
|
|
||||||
|
# Add a function to clear statistical information
|
||||||
|
def reset_node_counts():
|
||||||
|
global total_nodes, executed_nodes
|
||||||
|
total_nodes = 0
|
||||||
|
executed_nodes = 0
|
||||||
|
|
||||||
|
# Add a function to count the total number of nodes
|
||||||
|
def count_total_nodes(task_graph):
|
||||||
|
global total_nodes
|
||||||
|
total_nodes = len(task_graph)
|
||||||
|
#end
|
||||||
|
|
||||||
class DuplicateNodeError(Exception):
|
class DuplicateNodeError(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -39,6 +55,10 @@ class IsChangedCache:
|
|||||||
node = self.dynprompt.get_node(node_id)
|
node = self.dynprompt.get_node(node_id)
|
||||||
class_type = node["class_type"]
|
class_type = node["class_type"]
|
||||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||||
|
#Initialization statistics <start>
|
||||||
|
# print("任务包含的节点",node_id)
|
||||||
|
reset_node_counts()
|
||||||
|
#Initialization statistics <end>
|
||||||
if not hasattr(class_def, "IS_CHANGED"):
|
if not hasattr(class_def, "IS_CHANGED"):
|
||||||
self.is_changed[node_id] = False
|
self.is_changed[node_id] = False
|
||||||
return self.is_changed[node_id]
|
return self.is_changed[node_id]
|
||||||
@ -249,6 +269,10 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
|
|||||||
inputs = dynprompt.get_node(unique_id)['inputs']
|
inputs = dynprompt.get_node(unique_id)['inputs']
|
||||||
class_type = dynprompt.get_node(unique_id)['class_type']
|
class_type = dynprompt.get_node(unique_id)['class_type']
|
||||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||||
|
#execute node count <start>
|
||||||
|
global executed_nodes
|
||||||
|
executed_nodes+=1
|
||||||
|
#execute node count <end>
|
||||||
if caches.outputs.get(unique_id) is not None:
|
if caches.outputs.get(unique_id) is not None:
|
||||||
if server.client_id is not None:
|
if server.client_id is not None:
|
||||||
cached_output = caches.ui.get(unique_id) or {}
|
cached_output = caches.ui.get(unique_id) or {}
|
||||||
@ -479,6 +503,11 @@ class PromptExecutor:
|
|||||||
for node_id in prompt:
|
for node_id in prompt:
|
||||||
if self.caches.outputs.get(node_id) is not None:
|
if self.caches.outputs.get(node_id) is not None:
|
||||||
cached_nodes.append(node_id)
|
cached_nodes.append(node_id)
|
||||||
|
#Count the number of task nodes <start>
|
||||||
|
else:
|
||||||
|
global total_nodes
|
||||||
|
total_nodes+=1
|
||||||
|
#Count the number of task nodes <end>
|
||||||
|
|
||||||
comfy.model_management.cleanup_models(keep_clone_weights_loaded=True)
|
comfy.model_management.cleanup_models(keep_clone_weights_loaded=True)
|
||||||
self.add_message("execution_cached",
|
self.add_message("execution_cached",
|
||||||
|
10
server.py
10
server.py
@ -182,6 +182,16 @@ class PromptServer():
|
|||||||
|
|
||||||
self.on_prompt_handlers = []
|
self.on_prompt_handlers = []
|
||||||
|
|
||||||
|
#add api get progress status <start>
|
||||||
|
async def get_node_status(request):
|
||||||
|
return web.json_response({
|
||||||
|
"total_nodes": execution.total_nodes,
|
||||||
|
"executed_nodes": execution.executed_nodes
|
||||||
|
})
|
||||||
|
routes.get('/node_status')(get_node_status)
|
||||||
|
self.app.add_routes(routes)
|
||||||
|
#add api get progress status <end>
|
||||||
|
|
||||||
@routes.get('/ws')
|
@routes.get('/ws')
|
||||||
async def websocket_handler(request):
|
async def websocket_handler(request):
|
||||||
ws = web.WebSocketResponse()
|
ws = web.WebSocketResponse()
|
||||||
|
Loading…
Reference in New Issue
Block a user