From d40ac012bc7eaa68a9f2cfa9a85a0c221d2f2b14 Mon Sep 17 00:00:00 2001 From: meimeilook Date: Fri, 14 Mar 2025 15:39:21 +0800 Subject: [PATCH 1/4] 1.use getattr speedup 2x times in load node info when you opening browser at first time. 2.enable gzip in return response, make json files 10x smaller --- server.py | 67 ++++++++++++++++++++++++++++++------------------------- 1 file changed, 37 insertions(+), 30 deletions(-) diff --git a/server.py b/server.py index 76a99167..19568622 100644 --- a/server.py +++ b/server.py @@ -551,35 +551,30 @@ class PromptServer(): @routes.get("/prompt") async def get_prompt(request): return web.json_response(self.get_queue_info()) - + # use getattr speedup 2x times in load node info def node_info(node_class): obj_class = nodes.NODE_CLASS_MAPPINGS[node_class] - info = {} - info['input'] = obj_class.INPUT_TYPES() - info['input_order'] = {key: list(value.keys()) for (key, value) in obj_class.INPUT_TYPES().items()} - info['output'] = obj_class.RETURN_TYPES - info['output_is_list'] = obj_class.OUTPUT_IS_LIST if hasattr(obj_class, 'OUTPUT_IS_LIST') else [False] * len(obj_class.RETURN_TYPES) - info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output'] - info['name'] = node_class - info['display_name'] = nodes.NODE_DISPLAY_NAME_MAPPINGS[node_class] if node_class in nodes.NODE_DISPLAY_NAME_MAPPINGS.keys() else node_class - info['description'] = obj_class.DESCRIPTION if hasattr(obj_class,'DESCRIPTION') else '' - info['python_module'] = getattr(obj_class, "RELATIVE_PYTHON_MODULE", "nodes") - info['category'] = 'sd' - if hasattr(obj_class, 'OUTPUT_NODE') and obj_class.OUTPUT_NODE == True: - info['output_node'] = True - else: - info['output_node'] = False - - if hasattr(obj_class, 'CATEGORY'): - info['category'] = obj_class.CATEGORY - + input_types = obj_class.INPUT_TYPES() + + info = { + 'input': input_types, + 'input_order': {key: list(value.keys()) for key, value in input_types.items()}, + 'output': obj_class.RETURN_TYPES, + 'output_is_list': getattr(obj_class, 'OUTPUT_IS_LIST', [False] * len(obj_class.RETURN_TYPES)), + 'output_name': getattr(obj_class, 'RETURN_NAMES', obj_class.RETURN_TYPES), + 'name': node_class, + 'display_name': nodes.NODE_DISPLAY_NAME_MAPPINGS.get(node_class, node_class), + 'description': getattr(obj_class, 'DESCRIPTION', ''), + 'python_module': getattr(obj_class, "RELATIVE_PYTHON_MODULE", "nodes"), + 'category': getattr(obj_class, 'CATEGORY', 'sd'), + 'output_node': hasattr(obj_class, 'OUTPUT_NODE') and obj_class.OUTPUT_NODE, + 'deprecated': getattr(obj_class, "DEPRECATED", False), + 'experimental': getattr(obj_class, "EXPERIMENTAL", False) + } + if hasattr(obj_class, 'OUTPUT_TOOLTIPS'): info['output_tooltips'] = obj_class.OUTPUT_TOOLTIPS - - if getattr(obj_class, "DEPRECATED", False): - info['deprecated'] = True - if getattr(obj_class, "EXPERIMENTAL", False): - info['experimental'] = True + return info @routes.get("/object_info") @@ -592,7 +587,9 @@ class PromptServer(): except Exception: logging.error(f"[ERROR] An error occurred while retrieving information for the '{x}' node.") logging.error(traceback.format_exc()) - return web.json_response(out) + response = web.json_response(out) + response.enable_compression() + return response @routes.get("/object_info/{node_class}") async def get_object_info_node(request): @@ -600,19 +597,27 @@ class PromptServer(): out = {} if (node_class is not None) and (node_class in nodes.NODE_CLASS_MAPPINGS): out[node_class] = node_info(node_class) - return web.json_response(out) + response = web.json_response(out) + response.enable_compression() + return response @routes.get("/history") async def get_history(request): max_items = request.rel_url.query.get("max_items", None) if max_items is not None: max_items = int(max_items) - return web.json_response(self.prompt_queue.get_history(max_items=max_items)) + history_data = self.prompt_queue.get_history(max_items=max_items) + response = web.json_response(history_data) + response.enable_compression() + return response @routes.get("/history/{prompt_id}") async def get_history_prompt_id(request): prompt_id = request.match_info.get("prompt_id", None) - return web.json_response(self.prompt_queue.get_history(prompt_id=prompt_id)) + history_prompt_id_data = self.prompt_queue.get_history(prompt_id=prompt_id) + response = web.json_response(history_prompt_id_data) + response.enable_compression() + return response @routes.get("/queue") async def get_queue(request): @@ -620,7 +625,9 @@ class PromptServer(): current_queue = self.prompt_queue.get_current_queue() queue_info['queue_running'] = current_queue[0] queue_info['queue_pending'] = current_queue[1] - return web.json_response(queue_info) + response = web.json_response(queue_info) + response.enable_compression() + return response @routes.post("/prompt") async def post_prompt(request): From 4b2b24906a288e82c28dc0fe9b1b1d30ccbeefc5 Mon Sep 17 00:00:00 2001 From: meimeilook Date: Sat, 15 Mar 2025 14:08:57 +0800 Subject: [PATCH 2/4] Put node_info caching in memory,and now 30x speedup when loading the page. --- server.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/server.py b/server.py index 19568622..a3d4b94a 100644 --- a/server.py +++ b/server.py @@ -21,6 +21,7 @@ from io import BytesIO import aiohttp from aiohttp import web import logging +from functools import lru_cache import mimetypes from comfy.cli_args import args @@ -552,6 +553,7 @@ class PromptServer(): async def get_prompt(request): return web.json_response(self.get_queue_info()) # use getattr speedup 2x times in load node info + @lru_cache(maxsize=None) def node_info(node_class): obj_class = nodes.NODE_CLASS_MAPPINGS[node_class] input_types = obj_class.INPUT_TYPES() @@ -587,6 +589,11 @@ class PromptServer(): except Exception: logging.error(f"[ERROR] An error occurred while retrieving information for the '{x}' node.") logging.error(traceback.format_exc()) + + # Debug node_info in the current memory cache + #cache_stats = node_info.cache_info() + #print(f"node_info Cache Hits: {cache_stats.hits}, Misses: {cache_stats.misses}, Current Memory Cache Size: {cache_stats.currsize}") + response = web.json_response(out) response.enable_compression() return response From 01ceafa498479b8966cc71851374b797c1680c9e Mon Sep 17 00:00:00 2001 From: meimeilook Date: Sat, 15 Mar 2025 17:45:35 +0800 Subject: [PATCH 3/4] Given that gzip compression parameters already exist in cli args, revert the hardcoded code for this --- server.py | 26 +++++++------------------- 1 file changed, 7 insertions(+), 19 deletions(-) diff --git a/server.py b/server.py index a3d4b94a..c06f5acd 100644 --- a/server.py +++ b/server.py @@ -589,42 +589,32 @@ class PromptServer(): except Exception: logging.error(f"[ERROR] An error occurred while retrieving information for the '{x}' node.") logging.error(traceback.format_exc()) - + # Debug node_info in the current memory cache #cache_stats = node_info.cache_info() #print(f"node_info Cache Hits: {cache_stats.hits}, Misses: {cache_stats.misses}, Current Memory Cache Size: {cache_stats.currsize}") - response = web.json_response(out) - response.enable_compression() - return response - + return web.json_response(out) + @routes.get("/object_info/{node_class}") async def get_object_info_node(request): node_class = request.match_info.get("node_class", None) out = {} if (node_class is not None) and (node_class in nodes.NODE_CLASS_MAPPINGS): out[node_class] = node_info(node_class) - response = web.json_response(out) - response.enable_compression() - return response + return web.json_response(out) @routes.get("/history") async def get_history(request): max_items = request.rel_url.query.get("max_items", None) if max_items is not None: max_items = int(max_items) - history_data = self.prompt_queue.get_history(max_items=max_items) - response = web.json_response(history_data) - response.enable_compression() - return response + return web.json_response(self.prompt_queue.get_history(max_items=max_items)) @routes.get("/history/{prompt_id}") async def get_history_prompt_id(request): prompt_id = request.match_info.get("prompt_id", None) - history_prompt_id_data = self.prompt_queue.get_history(prompt_id=prompt_id) - response = web.json_response(history_prompt_id_data) - response.enable_compression() - return response + return web.json_response(self.prompt_queue.get_history(prompt_id=prompt_id)) @routes.get("/queue") async def get_queue(request): @@ -632,9 +622,7 @@ class PromptServer(): current_queue = self.prompt_queue.get_current_queue() queue_info['queue_running'] = current_queue[0] queue_info['queue_pending'] = current_queue[1] - response = web.json_response(queue_info) - response.enable_compression() - return response + return web.json_response(queue_info) @routes.post("/prompt") async def post_prompt(request): From e7a23f95f31d1b16ccc741a4eb33d963d4c984e3 Mon Sep 17 00:00:00 2001 From: meimeilook Date: Mon, 24 Mar 2025 12:52:36 +0800 Subject: [PATCH 4/4] Cancel memory cache; it will not refresh when the refresh button is clicked --- server.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/server.py b/server.py index c06f5acd..2703564d 100644 --- a/server.py +++ b/server.py @@ -21,7 +21,6 @@ from io import BytesIO import aiohttp from aiohttp import web import logging -from functools import lru_cache import mimetypes from comfy.cli_args import args @@ -552,8 +551,7 @@ class PromptServer(): @routes.get("/prompt") async def get_prompt(request): return web.json_response(self.get_queue_info()) - # use getattr speedup 2x times in load node info - @lru_cache(maxsize=None) + def node_info(node_class): obj_class = nodes.NODE_CLASS_MAPPINGS[node_class] input_types = obj_class.INPUT_TYPES() @@ -590,10 +588,6 @@ class PromptServer(): logging.error(f"[ERROR] An error occurred while retrieving information for the '{x}' node.") logging.error(traceback.format_exc()) - # Debug node_info in the current memory cache - #cache_stats = node_info.cache_info() - #print(f"node_info Cache Hits: {cache_stats.hits}, Misses: {cache_stats.misses}, Current Memory Cache Size: {cache_stats.currsize}") - return web.json_response(out) @routes.get("/object_info/{node_class}")