1.use getattr speedup 2x times in load node info when you opening browser at first time.

2.enable gzip in return response, make json files 10x smaller
This commit is contained in:
meimeilook 2025-03-14 15:39:21 +08:00
parent 35504e2f93
commit d40ac012bc

View File

@ -551,35 +551,30 @@ class PromptServer():
@routes.get("/prompt") @routes.get("/prompt")
async def get_prompt(request): async def get_prompt(request):
return web.json_response(self.get_queue_info()) return web.json_response(self.get_queue_info())
# use getattr speedup 2x times in load node info
def node_info(node_class): def node_info(node_class):
obj_class = nodes.NODE_CLASS_MAPPINGS[node_class] obj_class = nodes.NODE_CLASS_MAPPINGS[node_class]
info = {} input_types = obj_class.INPUT_TYPES()
info['input'] = obj_class.INPUT_TYPES()
info['input_order'] = {key: list(value.keys()) for (key, value) in obj_class.INPUT_TYPES().items()}
info['output'] = obj_class.RETURN_TYPES
info['output_is_list'] = obj_class.OUTPUT_IS_LIST if hasattr(obj_class, 'OUTPUT_IS_LIST') else [False] * len(obj_class.RETURN_TYPES)
info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output']
info['name'] = node_class
info['display_name'] = nodes.NODE_DISPLAY_NAME_MAPPINGS[node_class] if node_class in nodes.NODE_DISPLAY_NAME_MAPPINGS.keys() else node_class
info['description'] = obj_class.DESCRIPTION if hasattr(obj_class,'DESCRIPTION') else ''
info['python_module'] = getattr(obj_class, "RELATIVE_PYTHON_MODULE", "nodes")
info['category'] = 'sd'
if hasattr(obj_class, 'OUTPUT_NODE') and obj_class.OUTPUT_NODE == True:
info['output_node'] = True
else:
info['output_node'] = False
if hasattr(obj_class, 'CATEGORY'): info = {
info['category'] = obj_class.CATEGORY 'input': input_types,
'input_order': {key: list(value.keys()) for key, value in input_types.items()},
'output': obj_class.RETURN_TYPES,
'output_is_list': getattr(obj_class, 'OUTPUT_IS_LIST', [False] * len(obj_class.RETURN_TYPES)),
'output_name': getattr(obj_class, 'RETURN_NAMES', obj_class.RETURN_TYPES),
'name': node_class,
'display_name': nodes.NODE_DISPLAY_NAME_MAPPINGS.get(node_class, node_class),
'description': getattr(obj_class, 'DESCRIPTION', ''),
'python_module': getattr(obj_class, "RELATIVE_PYTHON_MODULE", "nodes"),
'category': getattr(obj_class, 'CATEGORY', 'sd'),
'output_node': hasattr(obj_class, 'OUTPUT_NODE') and obj_class.OUTPUT_NODE,
'deprecated': getattr(obj_class, "DEPRECATED", False),
'experimental': getattr(obj_class, "EXPERIMENTAL", False)
}
if hasattr(obj_class, 'OUTPUT_TOOLTIPS'): if hasattr(obj_class, 'OUTPUT_TOOLTIPS'):
info['output_tooltips'] = obj_class.OUTPUT_TOOLTIPS info['output_tooltips'] = obj_class.OUTPUT_TOOLTIPS
if getattr(obj_class, "DEPRECATED", False):
info['deprecated'] = True
if getattr(obj_class, "EXPERIMENTAL", False):
info['experimental'] = True
return info return info
@routes.get("/object_info") @routes.get("/object_info")
@ -592,7 +587,9 @@ class PromptServer():
except Exception: except Exception:
logging.error(f"[ERROR] An error occurred while retrieving information for the '{x}' node.") logging.error(f"[ERROR] An error occurred while retrieving information for the '{x}' node.")
logging.error(traceback.format_exc()) logging.error(traceback.format_exc())
return web.json_response(out) response = web.json_response(out)
response.enable_compression()
return response
@routes.get("/object_info/{node_class}") @routes.get("/object_info/{node_class}")
async def get_object_info_node(request): async def get_object_info_node(request):
@ -600,19 +597,27 @@ class PromptServer():
out = {} out = {}
if (node_class is not None) and (node_class in nodes.NODE_CLASS_MAPPINGS): if (node_class is not None) and (node_class in nodes.NODE_CLASS_MAPPINGS):
out[node_class] = node_info(node_class) out[node_class] = node_info(node_class)
return web.json_response(out) response = web.json_response(out)
response.enable_compression()
return response
@routes.get("/history") @routes.get("/history")
async def get_history(request): async def get_history(request):
max_items = request.rel_url.query.get("max_items", None) max_items = request.rel_url.query.get("max_items", None)
if max_items is not None: if max_items is not None:
max_items = int(max_items) max_items = int(max_items)
return web.json_response(self.prompt_queue.get_history(max_items=max_items)) history_data = self.prompt_queue.get_history(max_items=max_items)
response = web.json_response(history_data)
response.enable_compression()
return response
@routes.get("/history/{prompt_id}") @routes.get("/history/{prompt_id}")
async def get_history_prompt_id(request): async def get_history_prompt_id(request):
prompt_id = request.match_info.get("prompt_id", None) prompt_id = request.match_info.get("prompt_id", None)
return web.json_response(self.prompt_queue.get_history(prompt_id=prompt_id)) history_prompt_id_data = self.prompt_queue.get_history(prompt_id=prompt_id)
response = web.json_response(history_prompt_id_data)
response.enable_compression()
return response
@routes.get("/queue") @routes.get("/queue")
async def get_queue(request): async def get_queue(request):
@ -620,7 +625,9 @@ class PromptServer():
current_queue = self.prompt_queue.get_current_queue() current_queue = self.prompt_queue.get_current_queue()
queue_info['queue_running'] = current_queue[0] queue_info['queue_running'] = current_queue[0]
queue_info['queue_pending'] = current_queue[1] queue_info['queue_pending'] = current_queue[1]
return web.json_response(queue_info) response = web.json_response(queue_info)
response.enable_compression()
return response
@routes.post("/prompt") @routes.post("/prompt")
async def post_prompt(request): async def post_prompt(request):