diff --git a/server.py b/server.py index 76a99167..19568622 100644 --- a/server.py +++ b/server.py @@ -551,35 +551,30 @@ class PromptServer(): @routes.get("/prompt") async def get_prompt(request): return web.json_response(self.get_queue_info()) - + # use getattr speedup 2x times in load node info def node_info(node_class): obj_class = nodes.NODE_CLASS_MAPPINGS[node_class] - info = {} - info['input'] = obj_class.INPUT_TYPES() - info['input_order'] = {key: list(value.keys()) for (key, value) in obj_class.INPUT_TYPES().items()} - info['output'] = obj_class.RETURN_TYPES - info['output_is_list'] = obj_class.OUTPUT_IS_LIST if hasattr(obj_class, 'OUTPUT_IS_LIST') else [False] * len(obj_class.RETURN_TYPES) - info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output'] - info['name'] = node_class - info['display_name'] = nodes.NODE_DISPLAY_NAME_MAPPINGS[node_class] if node_class in nodes.NODE_DISPLAY_NAME_MAPPINGS.keys() else node_class - info['description'] = obj_class.DESCRIPTION if hasattr(obj_class,'DESCRIPTION') else '' - info['python_module'] = getattr(obj_class, "RELATIVE_PYTHON_MODULE", "nodes") - info['category'] = 'sd' - if hasattr(obj_class, 'OUTPUT_NODE') and obj_class.OUTPUT_NODE == True: - info['output_node'] = True - else: - info['output_node'] = False - - if hasattr(obj_class, 'CATEGORY'): - info['category'] = obj_class.CATEGORY - + input_types = obj_class.INPUT_TYPES() + + info = { + 'input': input_types, + 'input_order': {key: list(value.keys()) for key, value in input_types.items()}, + 'output': obj_class.RETURN_TYPES, + 'output_is_list': getattr(obj_class, 'OUTPUT_IS_LIST', [False] * len(obj_class.RETURN_TYPES)), + 'output_name': getattr(obj_class, 'RETURN_NAMES', obj_class.RETURN_TYPES), + 'name': node_class, + 'display_name': nodes.NODE_DISPLAY_NAME_MAPPINGS.get(node_class, node_class), + 'description': getattr(obj_class, 'DESCRIPTION', ''), + 'python_module': getattr(obj_class, "RELATIVE_PYTHON_MODULE", "nodes"), + 'category': getattr(obj_class, 'CATEGORY', 'sd'), + 'output_node': hasattr(obj_class, 'OUTPUT_NODE') and obj_class.OUTPUT_NODE, + 'deprecated': getattr(obj_class, "DEPRECATED", False), + 'experimental': getattr(obj_class, "EXPERIMENTAL", False) + } + if hasattr(obj_class, 'OUTPUT_TOOLTIPS'): info['output_tooltips'] = obj_class.OUTPUT_TOOLTIPS - - if getattr(obj_class, "DEPRECATED", False): - info['deprecated'] = True - if getattr(obj_class, "EXPERIMENTAL", False): - info['experimental'] = True + return info @routes.get("/object_info") @@ -592,7 +587,9 @@ class PromptServer(): except Exception: logging.error(f"[ERROR] An error occurred while retrieving information for the '{x}' node.") logging.error(traceback.format_exc()) - return web.json_response(out) + response = web.json_response(out) + response.enable_compression() + return response @routes.get("/object_info/{node_class}") async def get_object_info_node(request): @@ -600,19 +597,27 @@ class PromptServer(): out = {} if (node_class is not None) and (node_class in nodes.NODE_CLASS_MAPPINGS): out[node_class] = node_info(node_class) - return web.json_response(out) + response = web.json_response(out) + response.enable_compression() + return response @routes.get("/history") async def get_history(request): max_items = request.rel_url.query.get("max_items", None) if max_items is not None: max_items = int(max_items) - return web.json_response(self.prompt_queue.get_history(max_items=max_items)) + history_data = self.prompt_queue.get_history(max_items=max_items) + response = web.json_response(history_data) + response.enable_compression() + return response @routes.get("/history/{prompt_id}") async def get_history_prompt_id(request): prompt_id = request.match_info.get("prompt_id", None) - return web.json_response(self.prompt_queue.get_history(prompt_id=prompt_id)) + history_prompt_id_data = self.prompt_queue.get_history(prompt_id=prompt_id) + response = web.json_response(history_prompt_id_data) + response.enable_compression() + return response @routes.get("/queue") async def get_queue(request): @@ -620,7 +625,9 @@ class PromptServer(): current_queue = self.prompt_queue.get_current_queue() queue_info['queue_running'] = current_queue[0] queue_info['queue_pending'] = current_queue[1] - return web.json_response(queue_info) + response = web.json_response(queue_info) + response.enable_compression() + return response @routes.post("/prompt") async def post_prompt(request):