update to latest comfy

commit c827cba127
parent 3f7db39fda
Author: drunkplato
Date: 2025-03-06 12:02:34 +00:00

4 changed files with 14 additions and 49 deletions

comfy/ldm/models/autoencoder.py

@@ -1,4 +1,3 @@
-<<<<<<< HEAD
 import logging
 import math
 import torch
@@ -8,15 +7,6 @@ from typing import Any, Dict, Tuple, Union
 from comfy.ldm.modules.distributions.distributions import DiagonalGaussianDistribution
 from comfy.ldm.util import get_obj_from_str, instantiate_from_config
-=======
-import torch
-from contextlib import contextmanager
-from typing import Any, Dict, List, Optional, Tuple, Union
-from comfy.ldm.modules.distributions.distributions import DiagonalGaussianDistribution
-from comfy.ldm.util import instantiate_from_config
->>>>>>> 0e1536b4 (logic to upload images from this server)
 from comfy.ldm.modules.ema import LitEma
 import comfy.ops
@@ -64,11 +54,7 @@ class AbstractAutoencoder(torch.nn.Module):
         if self.use_ema:
             self.model_ema = LitEma(self, decay=ema_decay)
-<<<<<<< HEAD
             logging.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
-=======
-            logpy.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
->>>>>>> 0e1536b4 (logic to upload images from this server)

     def get_input(self, batch) -> Any:
         raise NotImplementedError()
@@ -84,22 +70,14 @@ class AbstractAutoencoder(torch.nn.Module):
             self.model_ema.store(self.parameters())
             self.model_ema.copy_to(self)
             if context is not None:
-<<<<<<< HEAD
                 logging.info(f"{context}: Switched to EMA weights")
-=======
-                logpy.info(f"{context}: Switched to EMA weights")
->>>>>>> 0e1536b4 (logic to upload images from this server)
         try:
             yield None
         finally:
             if self.use_ema:
                 self.model_ema.restore(self.parameters())
                 if context is not None:
-<<<<<<< HEAD
                     logging.info(f"{context}: Restored training weights")
-=======
-                    logpy.info(f"{context}: Restored training weights")
->>>>>>> 0e1536b4 (logic to upload images from this server)

     def encode(self, *args, **kwargs) -> torch.Tensor:
         raise NotImplementedError("encode()-method of abstract base class called")
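
Every conflict in this file is resolved the same way: the HEAD side survives, keeping the stdlib logging calls where the discarded branch used a logpy logger. The hunk above is the standard EMA-scope pattern: stash the live weights, copy in the EMA shadow, and restore on exit even if the body raises. A self-contained sketch of that pattern, written as a free function for illustration rather than the AbstractAutoencoder method it is in the file:

    import logging
    from contextlib import contextmanager

    @contextmanager
    def ema_scope(model, context=None):
        # Swap the EMA shadow weights in for the duration of the block.
        if model.use_ema:
            model.model_ema.store(model.parameters())
            model.model_ema.copy_to(model)
            if context is not None:
                logging.info(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            # Restore the live training weights even if the body raised.
            if model.use_ema:
                model.model_ema.restore(model.parameters())
                if context is not None:
                    logging.info(f"{context}: Restored training weights")
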
@@ -108,11 +86,7 @@ class AbstractAutoencoder(torch.nn.Module):
         raise NotImplementedError("decode()-method of abstract base class called")

     def instantiate_optimizer_from_config(self, params, lr, cfg):
-<<<<<<< HEAD
         logging.info(f"loading >>> {cfg['target']} <<< optimizer from config")
-=======
-        logpy.info(f"loading >>> {cfg['target']} <<< optimizer from config")
->>>>>>> 0e1536b4 (logic to upload images from this server)
         return get_obj_from_str(cfg["target"])(
             params, lr=lr, **cfg.get("params", dict())
         )
@@ -140,11 +114,7 @@ class AutoencodingEngine(AbstractAutoencoder):
         self.encoder: torch.nn.Module = instantiate_from_config(encoder_config)
         self.decoder: torch.nn.Module = instantiate_from_config(decoder_config)
-<<<<<<< HEAD
         self.regularization = instantiate_from_config(
-=======
-        self.regularization: AbstractRegularizer = instantiate_from_config(
->>>>>>> 0e1536b4 (logic to upload images from this server)
             regularizer_config
         )
@@ -192,7 +162,6 @@ class AutoencodingEngineLegacy(AutoencodingEngine):
             },
             **kwargs,
         )
-<<<<<<< HEAD
         if ddconfig.get("conv3d", False):
             conv_op = comfy.ops.disable_weight_init.Conv3d
@@ -200,19 +169,12 @@ class AutoencodingEngineLegacy(AutoencodingEngine):
             conv_op = comfy.ops.disable_weight_init.Conv2d

         self.quant_conv = conv_op(
-=======
-        self.quant_conv = comfy.ops.disable_weight_init.Conv2d(
->>>>>>> 0e1536b4 (logic to upload images from this server)
             (1 + ddconfig["double_z"]) * ddconfig["z_channels"],
             (1 + ddconfig["double_z"]) * embed_dim,
             1,
         )
-<<<<<<< HEAD
         self.post_quant_conv = conv_op(embed_dim, ddconfig["z_channels"], 1)
-=======
-        self.post_quant_conv = comfy.ops.disable_weight_init.Conv2d(embed_dim, ddconfig["z_channels"], 1)
->>>>>>> 0e1536b4 (logic to upload images from this server)
         self.embed_dim = embed_dim

     def get_autoencoder_params(self) -> list:
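
The resolution keeps the conv3d-aware HEAD branch, so quant_conv and post_quant_conv are built from a conv_op that is Conv3d for video VAEs and Conv2d otherwise; the discarded branch hard-coded Conv2d. A minimal sketch of the channel arithmetic, using plain torch.nn in place of comfy.ops.disable_weight_init and a hypothetical make_quant_convs helper:

    import torch

    def make_quant_convs(ddconfig: dict, embed_dim: int):
        # Hypothetical helper mirroring the resolved constructor logic;
        # plain torch.nn stands in for comfy.ops.disable_weight_init here.
        conv_op = torch.nn.Conv3d if ddconfig.get("conv3d", False) else torch.nn.Conv2d
        factor = 1 + int(ddconfig["double_z"])  # channels double when double_z is set
        quant_conv = conv_op(factor * ddconfig["z_channels"], factor * embed_dim, 1)
        post_quant_conv = conv_op(embed_dim, ddconfig["z_channels"], 1)
        return quant_conv, post_quant_conv

    # Example: a double_z image VAE with 4 latent channels maps 8 -> 8 channels.
    quant, post_quant = make_quant_convs({"conv3d": False, "double_z": True, "z_channels": 4}, 4)
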

main.py (23 changes)

@@ -161,9 +161,9 @@ def cuda_malloc_warning():
         logging.warning("\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n")

-def prompt_worker(q, server, memedeck_worker):
+def prompt_worker(q, server_instance, memedeck_worker):
     current_time: float = 0.0
-    e = execution.PromptExecutor(server, memedeck_worker, lru_size=args.cache_lru)
+    e = execution.PromptExecutor(server_instance, memedeck_worker, lru_size=args.cache_lru)
     # threading.Thread(target=memedeck_worker.start, daemon=True, args=(q, execution.validate_prompt)).start()
     last_gc_collect = 0
@@ -181,6 +181,8 @@ def prompt_worker(q, server, memedeck_worker):
             execution_start_time = time.perf_counter()
             prompt_id = item[1]
+            server_instance.last_prompt_id = prompt_id
+            print(item[2])
             e.execute(item[2], prompt_id, item[3], item[4])
             need_gc = True
@@ -228,7 +230,7 @@ async def run(server_instance, memedeck_worker, address='', port=8188, verbose=True, call_on_start=None):
     )

-def hijack_progress(server_instance):
+def hijack_progress(server_instance, memedeck_worker):
     def hook(value, total, preview_image):
         comfy.model_management.throw_exception_if_processing_interrupted()
         progress = {"value": value, "max": total, "prompt_id": server_instance.last_prompt_id, "node": server_instance.last_node_id}
@@ -312,9 +314,9 @@ def start_comfyui(asyncio_loop=None):
     cuda_malloc_warning()

     prompt_server.add_routes()
-    hijack_progress(server, memedeck_worker)
+    hijack_progress(prompt_server, memedeck_worker)

-    threading.Thread(target=prompt_worker, daemon=True, args=(q, server, memedeck_worker)).start()
+    threading.Thread(target=prompt_worker, daemon=True, args=(q, prompt_server, memedeck_worker)).start()
     threading.Thread(target=memedeck_worker.start, daemon=True, args=(q, execution.validate_prompt)).start()

     # set logging level to info
@@ -357,20 +359,21 @@ def start_comfyui(asyncio_loop=None):
     async def start_all():
         await prompt_server.setup()
-        await run(prompt_server, address=args.listen, port=args.port, verbose=not args.dont_print_server, call_on_start=call_on_start)
-        await run(server, memedeck_worker, address=args.listen, port=args.port, verbose=not args.dont_print_server, call_on_start=call_on_start)
+        # await run(prompt_server, address=args.listen, port=args.port, verbose=not args.dont_print_server, call_on_start=call_on_start)
+        await run(prompt_server, memedeck_worker, address=args.listen, port=args.port, verbose=not args.dont_print_server, call_on_start=call_on_start)

     # Returning these so that other code can integrate with the ComfyUI loop and server
-    return asyncio_loop, prompt_server, start_all
+    return asyncio_loop, prompt_server, memedeck_worker, start_all

 if __name__ == "__main__":
     # Running directly, just start ComfyUI.
     logging.info("ComfyUI version: {}".format(comfyui_version.__version__))

-    event_loop, _, start_all_func = start_comfyui()
+    event_loop, server, memedeck_worker, start_all_func = start_comfyui()
     try:
         event_loop.run_until_complete(start_all_func())
-        # event_loop.run_until_complete(run(server, memedeck_worker, address=args.listen, port=args.port, verbose=not args.dont_print_server, call_on_start=call_on_start))
+        # event_loop.run_until_complete(start_all_func())
+        # event_loop.run_until_complete(run(server, memedeck_worker, address=args.listen, port=args.port, verbose=not args.dont_print_server, call_on_start=call_on_start))
     except KeyboardInterrupt:
         logging.info("\nStopped server")

MemedeckWorker file (path not shown in capture)

@@ -70,7 +70,7 @@ class MemedeckWorker:
         self.azure_storage = MemedeckAzureStorage()

-        if self.queue_name == 'video-gen-queue':
+        if self.queue_name == 'video-gen-queue' or self.queue_name == 'scene-gen-queue':
             print(f"[memedeck]: video gen only mode enabled")
             self.video_gen_only = True
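
The gate now treats 'scene-gen-queue' the same as 'video-gen-queue'. If more queue names join this mode, a set membership test keeps the condition flat; a small sketch with hypothetical names, not part of MemedeckWorker:

    # Hypothetical constant and helper; equivalent to the chained `or` above.
    VIDEO_GEN_ONLY_QUEUES = {"video-gen-queue", "scene-gen-queue"}

    def is_video_gen_only(queue_name: str) -> bool:
        return queue_name in VIDEO_GEN_ONLY_QUEUES

    assert is_video_gen_only("scene-gen-queue")
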