mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-04-20 03:13:30 +00:00
update to latest comfy
This commit is contained in:
parent
3f7db39fda
commit
c827cba127
Binary file not shown.
@ -1,4 +1,3 @@
|
||||
<<<<<<< HEAD
|
||||
import logging
|
||||
import math
|
||||
import torch
|
||||
@ -8,15 +7,6 @@ from typing import Any, Dict, Tuple, Union
|
||||
from comfy.ldm.modules.distributions.distributions import DiagonalGaussianDistribution
|
||||
|
||||
from comfy.ldm.util import get_obj_from_str, instantiate_from_config
|
||||
=======
|
||||
import torch
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from comfy.ldm.modules.distributions.distributions import DiagonalGaussianDistribution
|
||||
|
||||
from comfy.ldm.util import instantiate_from_config
|
||||
>>>>>>> 0e1536b4 (logic to upload images from this server)
|
||||
from comfy.ldm.modules.ema import LitEma
|
||||
import comfy.ops
|
||||
|
||||
@ -64,11 +54,7 @@ class AbstractAutoencoder(torch.nn.Module):
|
||||
|
||||
if self.use_ema:
|
||||
self.model_ema = LitEma(self, decay=ema_decay)
|
||||
<<<<<<< HEAD
|
||||
logging.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
|
||||
=======
|
||||
logpy.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
|
||||
>>>>>>> 0e1536b4 (logic to upload images from this server)
|
||||
|
||||
def get_input(self, batch) -> Any:
|
||||
raise NotImplementedError()
|
||||
@ -84,22 +70,14 @@ class AbstractAutoencoder(torch.nn.Module):
|
||||
self.model_ema.store(self.parameters())
|
||||
self.model_ema.copy_to(self)
|
||||
if context is not None:
|
||||
<<<<<<< HEAD
|
||||
logging.info(f"{context}: Switched to EMA weights")
|
||||
=======
|
||||
logpy.info(f"{context}: Switched to EMA weights")
|
||||
>>>>>>> 0e1536b4 (logic to upload images from this server)
|
||||
try:
|
||||
yield None
|
||||
finally:
|
||||
if self.use_ema:
|
||||
self.model_ema.restore(self.parameters())
|
||||
if context is not None:
|
||||
<<<<<<< HEAD
|
||||
logging.info(f"{context}: Restored training weights")
|
||||
=======
|
||||
logpy.info(f"{context}: Restored training weights")
|
||||
>>>>>>> 0e1536b4 (logic to upload images from this server)
|
||||
|
||||
def encode(self, *args, **kwargs) -> torch.Tensor:
|
||||
raise NotImplementedError("encode()-method of abstract base class called")
|
||||
@ -108,11 +86,7 @@ class AbstractAutoencoder(torch.nn.Module):
|
||||
raise NotImplementedError("decode()-method of abstract base class called")
|
||||
|
||||
def instantiate_optimizer_from_config(self, params, lr, cfg):
|
||||
<<<<<<< HEAD
|
||||
logging.info(f"loading >>> {cfg['target']} <<< optimizer from config")
|
||||
=======
|
||||
logpy.info(f"loading >>> {cfg['target']} <<< optimizer from config")
|
||||
>>>>>>> 0e1536b4 (logic to upload images from this server)
|
||||
return get_obj_from_str(cfg["target"])(
|
||||
params, lr=lr, **cfg.get("params", dict())
|
||||
)
|
||||
@ -140,11 +114,7 @@ class AutoencodingEngine(AbstractAutoencoder):
|
||||
|
||||
self.encoder: torch.nn.Module = instantiate_from_config(encoder_config)
|
||||
self.decoder: torch.nn.Module = instantiate_from_config(decoder_config)
|
||||
<<<<<<< HEAD
|
||||
self.regularization = instantiate_from_config(
|
||||
=======
|
||||
self.regularization: AbstractRegularizer = instantiate_from_config(
|
||||
>>>>>>> 0e1536b4 (logic to upload images from this server)
|
||||
regularizer_config
|
||||
)
|
||||
|
||||
@ -192,7 +162,6 @@ class AutoencodingEngineLegacy(AutoencodingEngine):
|
||||
},
|
||||
**kwargs,
|
||||
)
|
||||
<<<<<<< HEAD
|
||||
|
||||
if ddconfig.get("conv3d", False):
|
||||
conv_op = comfy.ops.disable_weight_init.Conv3d
|
||||
@ -200,19 +169,12 @@ class AutoencodingEngineLegacy(AutoencodingEngine):
|
||||
conv_op = comfy.ops.disable_weight_init.Conv2d
|
||||
|
||||
self.quant_conv = conv_op(
|
||||
=======
|
||||
self.quant_conv = comfy.ops.disable_weight_init.Conv2d(
|
||||
>>>>>>> 0e1536b4 (logic to upload images from this server)
|
||||
(1 + ddconfig["double_z"]) * ddconfig["z_channels"],
|
||||
(1 + ddconfig["double_z"]) * embed_dim,
|
||||
1,
|
||||
)
|
||||
<<<<<<< HEAD
|
||||
|
||||
self.post_quant_conv = conv_op(embed_dim, ddconfig["z_channels"], 1)
|
||||
=======
|
||||
self.post_quant_conv = comfy.ops.disable_weight_init.Conv2d(embed_dim, ddconfig["z_channels"], 1)
|
||||
>>>>>>> 0e1536b4 (logic to upload images from this server)
|
||||
self.embed_dim = embed_dim
|
||||
|
||||
def get_autoencoder_params(self) -> list:
|
||||
|
23
main.py
23
main.py
@ -161,9 +161,9 @@ def cuda_malloc_warning():
|
||||
logging.warning("\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n")
|
||||
|
||||
|
||||
def prompt_worker(q, server, memedeck_worker):
|
||||
def prompt_worker(q, server_instance, memedeck_worker):
|
||||
current_time: float = 0.0
|
||||
e = execution.PromptExecutor(server, memedeck_worker, lru_size=args.cache_lru)
|
||||
e = execution.PromptExecutor(server_instance, memedeck_worker, lru_size=args.cache_lru)
|
||||
|
||||
# threading.Thread(target=memedeck_worker.start, daemon=True, args=(q, execution.validate_prompt)).start()
|
||||
last_gc_collect = 0
|
||||
@ -181,6 +181,8 @@ def prompt_worker(q, server, memedeck_worker):
|
||||
execution_start_time = time.perf_counter()
|
||||
prompt_id = item[1]
|
||||
server_instance.last_prompt_id = prompt_id
|
||||
|
||||
print(item[2])
|
||||
|
||||
e.execute(item[2], prompt_id, item[3], item[4])
|
||||
need_gc = True
|
||||
@ -228,7 +230,7 @@ async def run(server_instance, memedeck_worker, address='', port=8188, verbose=T
|
||||
)
|
||||
|
||||
|
||||
def hijack_progress(server_instance):
|
||||
def hijack_progress(server_instance, memedeck_worker):
|
||||
def hook(value, total, preview_image):
|
||||
comfy.model_management.throw_exception_if_processing_interrupted()
|
||||
progress = {"value": value, "max": total, "prompt_id": server_instance.last_prompt_id, "node": server_instance.last_node_id}
|
||||
@ -312,9 +314,9 @@ def start_comfyui(asyncio_loop=None):
|
||||
cuda_malloc_warning()
|
||||
|
||||
prompt_server.add_routes()
|
||||
hijack_progress(server, memedeck_worker)
|
||||
hijack_progress(prompt_server, memedeck_worker)
|
||||
|
||||
threading.Thread(target=prompt_worker, daemon=True, args=(q, server, memedeck_worker)).start()
|
||||
threading.Thread(target=prompt_worker, daemon=True, args=(q, prompt_server, memedeck_worker)).start()
|
||||
threading.Thread(target=memedeck_worker.start, daemon=True, args=(q, execution.validate_prompt)).start()
|
||||
# set logging level to info
|
||||
|
||||
@ -357,20 +359,21 @@ def start_comfyui(asyncio_loop=None):
|
||||
|
||||
async def start_all():
|
||||
await prompt_server.setup()
|
||||
await run(prompt_server, address=args.listen, port=args.port, verbose=not args.dont_print_server, call_on_start=call_on_start)
|
||||
await run(server, memedeck_worker, address=args.listen, port=args.port, verbose=not args.dont_print_server, call_on_start=call_on_start)
|
||||
# await run(prompt_server, address=args.listen, port=args.port, verbose=not args.dont_print_server, call_on_start=call_on_start)
|
||||
await run(prompt_server, memedeck_worker, address=args.listen, port=args.port, verbose=not args.dont_print_server, call_on_start=call_on_start)
|
||||
|
||||
# Returning these so that other code can integrate with the ComfyUI loop and server
|
||||
return asyncio_loop, prompt_server, start_all
|
||||
|
||||
return asyncio_loop, prompt_server, memedeck_worker, start_all
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Running directly, just start ComfyUI.
|
||||
logging.info("ComfyUI version: {}".format(comfyui_version.__version__))
|
||||
event_loop, _, start_all_func = start_comfyui()
|
||||
event_loop, server, memedeck_worker, start_all_func = start_comfyui()
|
||||
try:
|
||||
event_loop.run_until_complete(start_all_func())
|
||||
# event_loop.run_until_complete(run(server, memedeck_worker, address=args.listen, port=args.port, verbose=not args.dont_print_server, call_on_start=call_on_start))
|
||||
# event_loop.run_until_complete(start_all_func())
|
||||
# event_loop.run_until_complete(run(server, memedeck_worker, address=args.listen, port=args.port, verbose=not args.dont_print_server, call_on_start=call_on_start))
|
||||
|
||||
except KeyboardInterrupt:
|
||||
logging.info("\nStopped server")
|
||||
|
@ -70,7 +70,7 @@ class MemedeckWorker:
|
||||
|
||||
self.azure_storage = MemedeckAzureStorage()
|
||||
|
||||
if self.queue_name == 'video-gen-queue':
|
||||
if self.queue_name == 'video-gen-queue' or self.queue_name == 'scene-gen-queue':
|
||||
print(f"[memedeck]: video gen only mode enabled")
|
||||
self.video_gen_only = True
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user