Compare commits

...

37 Commits

Author SHA1 Message Date
comfyanonymous
22ad513c72 Refactor node cache code to more easily add other types of cache. 2025-04-11 07:16:52 -04:00
Chargeuk
ed945a1790
Dependency Aware Node Caching for low RAM/VRAM machines (#7509)
* add dependency aware cache that removed a cached node as soon as all of its decendents have executed. This allows users with lower RAM to run workflows they would otherwise not be able to run. The downside is that every workflow will fully run each time even if no nodes have changed.

* remove test code

* tidy code
2025-04-11 06:55:51 -04:00
Chenlei Hu
f9207c6936
Update frontend to 1.15 (#7564) 2025-04-11 06:46:20 -04:00
Christian Byrne
8ad7477647
dont cache templates index (#7569) 2025-04-11 06:06:53 -04:00
Chenlei Hu
98bdca4cb2
Deprecate InputTypeOptions.defaultInput (#7551)
* Deprecate InputTypeOptions.defaultInput

* nit

* nit
2025-04-10 06:57:06 -04:00
comfyanonymous
a26da20a76 Fix custom nodes not importing when path contains a dot. 2025-04-10 03:37:52 -04:00
Jedrzej Kosinski
e346d8584e
Add prepare_sampling wrapper allowing custom nodes to more accurately report noise_shape (#7500) 2025-04-09 09:43:35 -04:00
comfyanonymous
ab31b64412 Make "surface net" the default in the VoxelToMesh node. 2025-04-09 09:42:08 -04:00
thot experiment
fe29739c68
add VoxelToMesh node w/ surfacenet meshing (#7446)
* add VoxelToMesh node w/ surfacenet meshing

could delete the VoxelToMeshBasic node now probably?

* fix ruff
2025-04-09 09:41:03 -04:00
Chenlei Hu
e8345a9b7b
Align /prompt response schema (#7423) 2025-04-09 09:10:36 -04:00
comfyanonymous
8c6b9f4481
Prevent custom nodes from accidentally overwriting global modules. (#7167)
* Prevent custom nodes from accidentally overwriting global modules.

* Improve.
2025-04-09 09:08:57 -04:00
Christian Byrne
cc7e023a4a
handle palette mode in loadimage node (#7539) 2025-04-09 09:07:07 -04:00
comfyanonymous
2f7d8159c3 Show the user an error when the controlnet file is invalid. 2025-04-08 08:11:59 -04:00
comfyanonymous
70d7242e57 Support the wan fun reward loras. 2025-04-07 05:01:47 -04:00
comfyanonymous
49b732afd5 Show a proper error to the user when a vision model file is invalid. 2025-04-06 22:43:56 -04:00
comfyanonymous
3bfe4e5276 Support 512 siglip model. 2025-04-05 07:01:01 -04:00
Raphael Walker
89e4ea0175
Add activations_shape info in UNet models (#7482)
* Add activations_shape info in UNet models

* activations_shape should be a list
2025-04-04 21:27:54 -04:00
comfyanonymous
3a100b9a55 Disable partial offloading of audio VAE. 2025-04-04 21:24:56 -04:00
comfyanonymous
721253cb05 Fix problem. 2025-04-03 20:57:59 -04:00
comfyanonymous
3d2e3a6f29 Fix alpha image issue in more nodes. 2025-04-02 19:32:49 -04:00
BiologicalExplosion
2222cf67fd
MLU memory optimization (#7470)
Co-authored-by: huzhan <huzhan@cambricon.com>
2025-04-02 19:24:04 -04:00
comfyanonymous
ab5413351e Fix comment.
This function does not support quads.
2025-04-01 14:09:31 -04:00
Laurent Erignoux
2b71aab299
User missing (#7439)
* Ensuring a 401 error is returned when user data is not found in multi-user context.

* Returning a 401 error when provided comfy-user does not exists on server side.
2025-04-01 13:53:52 -04:00
BVH
301e26b131
Add option to store TE in bf16 (#7461) 2025-04-01 13:48:53 -04:00
comfyanonymous
548457bac4 Fix alpha channel mismatch on destination in ImageCompositeMasked 2025-03-31 20:59:12 -04:00
comfyanonymous
0b4584c741 Fix latent composite node not working when source has alpha. 2025-03-30 21:47:05 -04:00
comfyanonymous
a3100c8452 Remove useless code. 2025-03-29 20:12:56 -04:00
Michael Kupchick
832fc02330
ltxv: fix preprocessing exception when compression is 0. (#7431) 2025-03-29 20:03:02 -04:00
comfyanonymous
2d17d8910c Don't error if wan concat image has extra channels. 2025-03-28 08:49:29 -04:00
Chenlei Hu
a40fcfc2d5
Update frontend to 1.14.6 (#7416)
Cherry-pick the fix: https://github.com/Comfy-Org/ComfyUI_frontend/pull/3252
2025-03-28 02:27:01 -04:00
comfyanonymous
0a1f8869c9 Add WanFunInpaintToVideo node for the Wan fun inpaint models. 2025-03-27 11:13:27 -04:00
comfyanonymous
3661c833bc Support the WAN 2.1 fun control models.
Use the new WanFunControlToVideo node.
2025-03-26 19:54:54 -04:00
comfyanonymous
84fdaf7b0e Add CFGZeroStar node.
Works on all models that use a negative prompt but is meant for rectified
flow models.
2025-03-26 05:09:52 -04:00
comfyanonymous
8edc1f44c1 Support more float8 types. 2025-03-25 05:23:49 -04:00
comfyanonymous
eade1551bb Add Hunyuan3D to readme. 2025-03-24 07:14:32 -04:00
comfyanonymous
581a9991ff Add model merging node for WAN 2.1 2025-03-23 08:06:36 -04:00
comfyanonymous
e471c726e5 Fallback to pytorch attention if sage attention fails. 2025-03-22 15:45:56 -04:00
28 changed files with 771 additions and 62 deletions

View File

@ -69,6 +69,8 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
- [Hunyuan Video](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_video/)
- [Nvidia Cosmos](https://comfyanonymous.github.io/ComfyUI_examples/cosmos/)
- [Wan 2.1](https://comfyanonymous.github.io/ComfyUI_examples/wan/)
- 3D Models
- [Hunyuan3D 2.0](https://docs.comfy.org/tutorials/3d/hunyuan3D-2)
- [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
- Asynchronous Queue system
- Many optimizations: Only re-executes the parts of the workflow that changes between executions.

View File

@ -9,8 +9,14 @@ class AppSettings():
self.user_manager = user_manager
def get_settings(self, request):
file = self.user_manager.get_request_user_filepath(
request, "comfy.settings.json")
try:
file = self.user_manager.get_request_user_filepath(
request,
"comfy.settings.json"
)
except KeyError as e:
logging.error("User settings not found.")
raise web.HTTPUnauthorized() from e
if os.path.isfile(file):
try:
with open(file) as f:

View File

@ -79,6 +79,7 @@ fpte_group.add_argument("--fp8_e4m3fn-text-enc", action="store_true", help="Stor
fpte_group.add_argument("--fp8_e5m2-text-enc", action="store_true", help="Store text encoder weights in fp8 (e5m2 variant).")
fpte_group.add_argument("--fp16-text-enc", action="store_true", help="Store text encoder weights in fp16.")
fpte_group.add_argument("--fp32-text-enc", action="store_true", help="Store text encoder weights in fp32.")
fpte_group.add_argument("--bf16-text-enc", action="store_true", help="Store text encoder weights in bf16.")
parser.add_argument("--force-channels-last", action="store_true", help="Force channels last format when inferencing the models.")
@ -100,6 +101,7 @@ parser.add_argument("--preview-size", type=int, default=512, help="Sets the maxi
cache_group = parser.add_mutually_exclusive_group()
cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.")
attn_group = parser.add_mutually_exclusive_group()
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")

View File

@ -110,9 +110,13 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
elif "vision_model.encoder.layers.30.layer_norm1.weight" in sd:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_h.json")
elif "vision_model.encoder.layers.22.layer_norm1.weight" in sd:
embed_shape = sd["vision_model.embeddings.position_embedding.weight"].shape[0]
if sd["vision_model.encoder.layers.0.layer_norm1.weight"].shape[0] == 1152:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_384.json")
elif sd["vision_model.embeddings.position_embedding.weight"].shape[0] == 577:
if embed_shape == 729:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_384.json")
elif embed_shape == 1024:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_512.json")
elif embed_shape == 577:
if "multi_modal_projector.linear_1.bias" in sd:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336_llava.json")
else:

View File

@ -0,0 +1,13 @@
{
"num_channels": 3,
"hidden_act": "gelu_pytorch_tanh",
"hidden_size": 1152,
"image_size": 512,
"intermediate_size": 4304,
"model_type": "siglip_vision_model",
"num_attention_heads": 16,
"num_hidden_layers": 27,
"patch_size": 16,
"image_mean": [0.5, 0.5, 0.5],
"image_std": [0.5, 0.5, 0.5]
}

View File

@ -102,9 +102,13 @@ class InputTypeOptions(TypedDict):
default: bool | str | float | int | list | tuple
"""The default value of the widget"""
defaultInput: bool
"""Defaults to an input slot rather than a widget"""
"""@deprecated in v1.16 frontend. v1.16 frontend allows input socket and widget to co-exist.
- defaultInput on required inputs should be dropped.
- defaultInput on optional inputs should be replaced with forceInput.
Ref: https://github.com/Comfy-Org/ComfyUI_frontend/pull/3364
"""
forceInput: bool
"""`defaultInput` and also don't allow converting to a widget"""
"""Forces the input to be an input slot rather than a widget even a widget is available for the input type."""
lazy: bool
"""Declares that this input uses lazy evaluation"""
rawLink: bool

View File

@ -471,7 +471,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
if skip_reshape:
b, _, _, dim_head = q.shape
tensor_layout="HND"
tensor_layout = "HND"
else:
b, _, dim_head = q.shape
dim_head //= heads
@ -479,7 +479,7 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
lambda t: t.view(b, -1, heads, dim_head),
(q, k, v),
)
tensor_layout="NHD"
tensor_layout = "NHD"
if mask is not None:
# add a batch dimension if there isn't already one
@ -489,7 +489,17 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
if mask.ndim == 3:
mask = mask.unsqueeze(1)
out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)
try:
out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)
except Exception as e:
logging.error("Error running sage attention: {}, using pytorch attention instead.".format(e))
if tensor_layout == "NHD":
q, k, v = map(
lambda t: t.transpose(1, 2),
(q, k, v),
)
return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=True, skip_output_reshape=skip_output_reshape)
if tensor_layout == "HND":
if not skip_output_reshape:
out = (
@ -837,6 +847,7 @@ class SpatialTransformer(nn.Module):
if not isinstance(context, list):
context = [context] * len(self.transformer_blocks)
b, c, h, w = x.shape
transformer_options["activations_shape"] = list(x.shape)
x_in = x
x = self.norm(x)
if not self.use_linear:
@ -952,6 +963,7 @@ class SpatialVideoTransformer(SpatialTransformer):
transformer_options={}
) -> torch.Tensor:
_, _, h, w = x.shape
transformer_options["activations_shape"] = list(x.shape)
x_in = x
spatial_context = None
if exists(context):

View File

@ -1,4 +1,5 @@
import torch
import comfy.utils
def convert_lora_bfl_control(sd): #BFL loras for Flux
@ -11,7 +12,13 @@ def convert_lora_bfl_control(sd): #BFL loras for Flux
return sd_out
def convert_lora_wan_fun(sd): #Wan Fun loras
return comfy.utils.state_dict_prefix_replace(sd, {"lora_unet__": "lora_unet_"})
def convert_lora(sd):
if "img_in.lora_A.weight" in sd and "single_blocks.0.norm.key_norm.scale" in sd:
return convert_lora_bfl_control(sd)
if "lora_unet__blocks_0_cross_attn_k.lora_down.weight" in sd:
return convert_lora_wan_fun(sd)
return sd

View File

@ -992,31 +992,41 @@ class WAN21(BaseModel):
def concat_cond(self, **kwargs):
noise = kwargs.get("noise", None)
if self.diffusion_model.patch_embedding.weight.shape[1] == noise.shape[1]:
extra_channels = self.diffusion_model.patch_embedding.weight.shape[1] - noise.shape[1]
if extra_channels == 0:
return None
image = kwargs.get("concat_latent_image", None)
device = kwargs["device"]
if image is None:
image = torch.zeros_like(noise)
shape_image = list(noise.shape)
shape_image[1] = extra_channels
image = torch.zeros(shape_image, dtype=noise.dtype, layout=noise.layout, device=noise.device)
else:
image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
for i in range(0, image.shape[1], 16):
image[:, i: i + 16] = self.process_latent_in(image[:, i: i + 16])
image = utils.resize_to_batch_size(image, noise.shape[0])
image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
image = self.process_latent_in(image)
image = utils.resize_to_batch_size(image, noise.shape[0])
if not self.image_to_video:
if not self.image_to_video or extra_channels == image.shape[1]:
return image
if image.shape[1] > (extra_channels - 4):
image = image[:, :(extra_channels - 4)]
mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
if mask is None:
mask = torch.zeros_like(noise)[:, :4]
else:
mask = 1.0 - torch.mean(mask, dim=1, keepdim=True)
if mask.shape[1] != 4:
mask = torch.mean(mask, dim=1, keepdim=True)
mask = 1.0 - mask
mask = utils.common_upscale(mask.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
if mask.shape[-3] < noise.shape[-3]:
mask = torch.nn.functional.pad(mask, (0, 0, 0, 0, 0, noise.shape[-3] - mask.shape[-3]), mode='constant', value=0)
mask = mask.repeat(1, 4, 1, 1, 1)
if mask.shape[1] == 1:
mask = mask.repeat(1, 4, 1, 1, 1)
mask = utils.resize_to_batch_size(mask, noise.shape[0])
return torch.cat((mask, image), dim=1)

View File

@ -46,6 +46,32 @@ cpu_state = CPUState.GPU
total_vram = 0
def get_supported_float8_types():
float8_types = []
try:
float8_types.append(torch.float8_e4m3fn)
except:
pass
try:
float8_types.append(torch.float8_e4m3fnuz)
except:
pass
try:
float8_types.append(torch.float8_e5m2)
except:
pass
try:
float8_types.append(torch.float8_e5m2fnuz)
except:
pass
try:
float8_types.append(torch.float8_e8m0fnu)
except:
pass
return float8_types
FLOAT8_TYPES = get_supported_float8_types()
xpu_available = False
torch_version = ""
try:
@ -701,11 +727,8 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, tor
return torch.float8_e5m2
fp8_dtype = None
try:
if weight_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
fp8_dtype = weight_dtype
except:
pass
if weight_dtype in FLOAT8_TYPES:
fp8_dtype = weight_dtype
if fp8_dtype is not None:
if supports_fp8_compute(device): #if fp8 compute is supported the casting is most likely not expensive
@ -800,6 +823,8 @@ def text_encoder_dtype(device=None):
return torch.float8_e5m2
elif args.fp16_text_enc:
return torch.float16
elif args.bf16_text_enc:
return torch.bfloat16
elif args.fp32_text_enc:
return torch.float32
@ -1212,6 +1237,8 @@ def soft_empty_cache(force=False):
torch.xpu.empty_cache()
elif is_ascend_npu():
torch.npu.empty_cache()
elif is_mlu():
torch.mlu.empty_cache()
elif torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.ipc_collect()

View File

@ -48,6 +48,7 @@ def get_all_callbacks(call_type: str, transformer_options: dict, is_model_option
class WrappersMP:
OUTER_SAMPLE = "outer_sample"
PREPARE_SAMPLING = "prepare_sampling"
SAMPLER_SAMPLE = "sampler_sample"
CALC_COND_BATCH = "calc_cond_batch"
APPLY_MODEL = "apply_model"

View File

@ -106,6 +106,13 @@ def cleanup_additional_models(models):
def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None):
executor = comfy.patcher_extension.WrapperExecutor.new_executor(
_prepare_sampling,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.PREPARE_SAMPLING, model_options, is_model_options=True)
)
return executor.execute(model, noise_shape, conds, model_options=model_options)
def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None):
real_model: BaseModel = None
models, inference_memory = get_additional_models(conds, model.model_dtype())
models += get_additional_models_from_model_options(model_options)

View File

@ -265,6 +265,7 @@ class VAE:
self.process_input = lambda image: image * 2.0 - 1.0
self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
self.working_dtypes = [torch.bfloat16, torch.float32]
self.disable_offload = False
self.downscale_index_formula = None
self.upscale_index_formula = None
@ -337,6 +338,7 @@ class VAE:
self.process_output = lambda audio: audio
self.process_input = lambda audio: audio
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
self.disable_offload = True
elif "blocks.2.blocks.3.stack.5.weight" in sd or "decoder.blocks.2.blocks.3.stack.5.weight" in sd or "layers.4.layers.1.attn_block.attn.qkv.weight" in sd or "encoder.layers.4.layers.1.attn_block.attn.qkv.weight" in sd: #genmo mochi vae
if "blocks.2.blocks.3.stack.5.weight" in sd:
sd = comfy.utils.state_dict_prefix_replace(sd, {"": "decoder."})
@ -515,7 +517,7 @@ class VAE:
pixel_samples = None
try:
memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
model_management.load_models_gpu([self.patcher], memory_required=memory_used)
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
free_memory = model_management.get_free_memory(self.device)
batch_number = int(free_memory / memory_used)
batch_number = max(1, batch_number)
@ -544,7 +546,7 @@ class VAE:
def decode_tiled(self, samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None):
self.throw_exception_if_invalid()
memory_used = self.memory_used_decode(samples.shape, self.vae_dtype) #TODO: calculate mem required for tile
model_management.load_models_gpu([self.patcher], memory_required=memory_used)
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
dims = samples.ndim - 2
args = {}
if tile_x is not None:
@ -578,7 +580,7 @@ class VAE:
pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
try:
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
model_management.load_models_gpu([self.patcher], memory_required=memory_used)
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
free_memory = model_management.get_free_memory(self.device)
batch_number = int(free_memory / max(1, memory_used))
batch_number = max(1, batch_number)
@ -612,7 +614,7 @@ class VAE:
pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype) # TODO: calculate mem required for tile
model_management.load_models_gpu([self.patcher], memory_required=memory_used)
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
args = {}
if tile_x is not None:

View File

@ -969,12 +969,24 @@ class WAN21_I2V(WAN21_T2V):
unet_config = {
"image_model": "wan2.1",
"model_type": "i2v",
"in_dim": 36,
}
def get_model(self, state_dict, prefix="", device=None):
out = model_base.WAN21(self, image_to_video=True, device=device)
return out
class WAN21_FunControl2V(WAN21_T2V):
unet_config = {
"image_model": "wan2.1",
"model_type": "i2v",
"in_dim": 48,
}
def get_model(self, state_dict, prefix="", device=None):
out = model_base.WAN21(self, image_to_video=False, device=device)
return out
class Hunyuan3Dv2(supported_models_base.BASE):
unet_config = {
"image_model": "hunyuan3d2",
@ -1013,6 +1025,6 @@ class Hunyuan3Dv2mini(Hunyuan3Dv2):
latent_format = latent_formats.Hunyuan3Dv2mini
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, Hunyuan3Dv2mini, Hunyuan3Dv2]
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, Hunyuan3Dv2mini, Hunyuan3Dv2]
models += [SVD_img2vid]

View File

@ -316,3 +316,156 @@ class LRUCache(BasicCache):
self.children[cache_key].append(self.cache_key_set.get_data_key(child_id))
return self
class DependencyAwareCache(BasicCache):
"""
A cache implementation that tracks dependencies between nodes and manages
their execution and caching accordingly. It extends the BasicCache class.
Nodes are removed from this cache once all of their descendants have been
executed.
"""
def __init__(self, key_class):
"""
Initialize the DependencyAwareCache.
Args:
key_class: The class used for generating cache keys.
"""
super().__init__(key_class)
self.descendants = {} # Maps node_id -> set of descendant node_ids
self.ancestors = {} # Maps node_id -> set of ancestor node_ids
self.executed_nodes = set() # Tracks nodes that have been executed
def set_prompt(self, dynprompt, node_ids, is_changed_cache):
"""
Clear the entire cache and rebuild the dependency graph.
Args:
dynprompt: The dynamic prompt object containing node information.
node_ids: List of node IDs to initialize the cache for.
is_changed_cache: Flag indicating if the cache has changed.
"""
# Clear all existing cache data
self.cache.clear()
self.subcaches.clear()
self.descendants.clear()
self.ancestors.clear()
self.executed_nodes.clear()
# Call the parent method to initialize the cache with the new prompt
super().set_prompt(dynprompt, node_ids, is_changed_cache)
# Rebuild the dependency graph
self._build_dependency_graph(dynprompt, node_ids)
def _build_dependency_graph(self, dynprompt, node_ids):
"""
Build the dependency graph for all nodes.
Args:
dynprompt: The dynamic prompt object containing node information.
node_ids: List of node IDs to build the graph for.
"""
self.descendants.clear()
self.ancestors.clear()
for node_id in node_ids:
self.descendants[node_id] = set()
self.ancestors[node_id] = set()
for node_id in node_ids:
inputs = dynprompt.get_node(node_id)["inputs"]
for input_data in inputs.values():
if is_link(input_data): # Check if the input is a link to another node
ancestor_id = input_data[0]
self.descendants[ancestor_id].add(node_id)
self.ancestors[node_id].add(ancestor_id)
def set(self, node_id, value):
"""
Mark a node as executed and store its value in the cache.
Args:
node_id: The ID of the node to store.
value: The value to store for the node.
"""
self._set_immediate(node_id, value)
self.executed_nodes.add(node_id)
self._cleanup_ancestors(node_id)
def get(self, node_id):
"""
Retrieve the cached value for a node.
Args:
node_id: The ID of the node to retrieve.
Returns:
The cached value for the node.
"""
return self._get_immediate(node_id)
def ensure_subcache_for(self, node_id, children_ids):
"""
Ensure a subcache exists for a node and update dependencies.
Args:
node_id: The ID of the parent node.
children_ids: List of child node IDs to associate with the parent node.
Returns:
The subcache object for the node.
"""
subcache = super()._ensure_subcache(node_id, children_ids)
for child_id in children_ids:
self.descendants[node_id].add(child_id)
self.ancestors[child_id].add(node_id)
return subcache
def _cleanup_ancestors(self, node_id):
"""
Check if ancestors of a node can be removed from the cache.
Args:
node_id: The ID of the node whose ancestors are to be checked.
"""
for ancestor_id in self.ancestors.get(node_id, []):
if ancestor_id in self.executed_nodes:
# Remove ancestor if all its descendants have been executed
if all(descendant in self.executed_nodes for descendant in self.descendants[ancestor_id]):
self._remove_node(ancestor_id)
def _remove_node(self, node_id):
"""
Remove a node from the cache.
Args:
node_id: The ID of the node to remove.
"""
cache_key = self.cache_key_set.get_data_key(node_id)
if cache_key in self.cache:
del self.cache[cache_key]
subcache_key = self.cache_key_set.get_subcache_key(node_id)
if subcache_key in self.subcaches:
del self.subcaches[subcache_key]
def clean_unused(self):
"""
Clean up unused nodes. This is a no-op for this cache implementation.
"""
pass
def recursive_debug_dump(self):
"""
Dump the cache and dependency graph for debugging.
Returns:
A list containing the cache state and dependency graph.
"""
result = super().recursive_debug_dump()
result.append({
"descendants": self.descendants,
"ancestors": self.ancestors,
"executed_nodes": list(self.executed_nodes),
})
return result

45
comfy_extras/nodes_cfg.py Normal file
View File

@ -0,0 +1,45 @@
import torch
# https://github.com/WeichenFan/CFG-Zero-star
def optimized_scale(positive, negative):
positive_flat = positive.reshape(positive.shape[0], -1)
negative_flat = negative.reshape(negative.shape[0], -1)
# Calculate dot production
dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
# Squared norm of uncondition
squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8
# st_star = v_cond^T * v_uncond / ||v_uncond||^2
st_star = dot_product / squared_norm
return st_star.reshape([positive.shape[0]] + [1] * (positive.ndim - 1))
class CFGZeroStar:
@classmethod
def INPUT_TYPES(s):
return {"required": {"model": ("MODEL",),
}}
RETURN_TYPES = ("MODEL",)
RETURN_NAMES = ("patched_model",)
FUNCTION = "patch"
CATEGORY = "advanced/guidance"
def patch(self, model):
m = model.clone()
def cfg_zero_star(args):
guidance_scale = args['cond_scale']
x = args['input']
cond_p = args['cond_denoised']
uncond_p = args['uncond_denoised']
out = args["denoised"]
alpha = optimized_scale(x - cond_p, x - uncond_p)
return out + uncond_p * (alpha - 1.0) + guidance_scale * uncond_p * (1.0 - alpha)
m.set_model_sampler_post_cfg_function(cfg_zero_star)
return (m, )
NODE_CLASS_MAPPINGS = {
"CFGZeroStar": CFGZeroStar
}

View File

@ -209,6 +209,196 @@ def voxel_to_mesh(voxels, threshold=0.5, device=None):
vertices = torch.fliplr(vertices)
return vertices, faces
def voxel_to_mesh_surfnet(voxels, threshold=0.5, device=None):
if device is None:
device = torch.device("cpu")
voxels = voxels.to(device)
D, H, W = voxels.shape
padded = torch.nn.functional.pad(voxels, (1, 1, 1, 1, 1, 1), 'constant', 0)
z, y, x = torch.meshgrid(
torch.arange(D, device=device),
torch.arange(H, device=device),
torch.arange(W, device=device),
indexing='ij'
)
cell_positions = torch.stack([z.flatten(), y.flatten(), x.flatten()], dim=1)
corner_offsets = torch.tensor([
[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0],
[0, 0, 1], [1, 0, 1], [0, 1, 1], [1, 1, 1]
], device=device)
corner_values = torch.zeros((cell_positions.shape[0], 8), device=device)
for c, (dz, dy, dx) in enumerate(corner_offsets):
corner_values[:, c] = padded[
cell_positions[:, 0] + dz,
cell_positions[:, 1] + dy,
cell_positions[:, 2] + dx
]
corner_signs = corner_values > threshold
has_inside = torch.any(corner_signs, dim=1)
has_outside = torch.any(~corner_signs, dim=1)
contains_surface = has_inside & has_outside
active_cells = cell_positions[contains_surface]
active_signs = corner_signs[contains_surface]
active_values = corner_values[contains_surface]
if active_cells.shape[0] == 0:
return torch.zeros((0, 3), device=device), torch.zeros((0, 3), dtype=torch.long, device=device)
edges = torch.tensor([
[0, 1], [0, 2], [0, 4], [1, 3],
[1, 5], [2, 3], [2, 6], [3, 7],
[4, 5], [4, 6], [5, 7], [6, 7]
], device=device)
cell_vertices = {}
progress = comfy.utils.ProgressBar(100)
for edge_idx, (e1, e2) in enumerate(edges):
progress.update(1)
crossing = active_signs[:, e1] != active_signs[:, e2]
if not crossing.any():
continue
cell_indices = torch.nonzero(crossing, as_tuple=True)[0]
v1 = active_values[cell_indices, e1]
v2 = active_values[cell_indices, e2]
t = torch.zeros_like(v1, device=device)
denom = v2 - v1
valid = denom != 0
t[valid] = (threshold - v1[valid]) / denom[valid]
t[~valid] = 0.5
p1 = corner_offsets[e1].float()
p2 = corner_offsets[e2].float()
intersection = p1.unsqueeze(0) + t.unsqueeze(1) * (p2.unsqueeze(0) - p1.unsqueeze(0))
for i, point in zip(cell_indices.tolist(), intersection):
if i not in cell_vertices:
cell_vertices[i] = []
cell_vertices[i].append(point)
# Calculate the final vertices as the average of intersection points for each cell
vertices = []
vertex_lookup = {}
vert_progress_mod = round(len(cell_vertices)/50)
for i, points in cell_vertices.items():
if not i % vert_progress_mod:
progress.update(1)
if points:
vertex = torch.stack(points).mean(dim=0)
vertex = vertex + active_cells[i].float()
vertex_lookup[tuple(active_cells[i].tolist())] = len(vertices)
vertices.append(vertex)
if not vertices:
return torch.zeros((0, 3), device=device), torch.zeros((0, 3), dtype=torch.long, device=device)
final_vertices = torch.stack(vertices)
inside_corners_mask = active_signs
outside_corners_mask = ~active_signs
inside_counts = inside_corners_mask.sum(dim=1, keepdim=True).float()
outside_counts = outside_corners_mask.sum(dim=1, keepdim=True).float()
inside_pos = torch.zeros((active_cells.shape[0], 3), device=device)
outside_pos = torch.zeros((active_cells.shape[0], 3), device=device)
for i in range(8):
mask_inside = inside_corners_mask[:, i].unsqueeze(1)
mask_outside = outside_corners_mask[:, i].unsqueeze(1)
inside_pos += corner_offsets[i].float().unsqueeze(0) * mask_inside
outside_pos += corner_offsets[i].float().unsqueeze(0) * mask_outside
inside_pos /= inside_counts
outside_pos /= outside_counts
gradients = inside_pos - outside_pos
pos_dirs = torch.tensor([
[1, 0, 0],
[0, 1, 0],
[0, 0, 1]
], device=device)
cross_products = [
torch.linalg.cross(pos_dirs[i].float(), pos_dirs[j].float())
for i in range(3) for j in range(i+1, 3)
]
faces = []
all_keys = set(vertex_lookup.keys())
face_progress_mod = round(len(active_cells)/38*3)
for pair_idx, (i, j) in enumerate([(0,1), (0,2), (1,2)]):
dir_i = pos_dirs[i]
dir_j = pos_dirs[j]
cross_product = cross_products[pair_idx]
ni_positions = active_cells + dir_i
nj_positions = active_cells + dir_j
diag_positions = active_cells + dir_i + dir_j
alignments = torch.matmul(gradients, cross_product)
valid_quads = []
quad_indices = []
for idx, active_cell in enumerate(active_cells):
if not idx % face_progress_mod:
progress.update(1)
cell_key = tuple(active_cell.tolist())
ni_key = tuple(ni_positions[idx].tolist())
nj_key = tuple(nj_positions[idx].tolist())
diag_key = tuple(diag_positions[idx].tolist())
if cell_key in all_keys and ni_key in all_keys and nj_key in all_keys and diag_key in all_keys:
v0 = vertex_lookup[cell_key]
v1 = vertex_lookup[ni_key]
v2 = vertex_lookup[nj_key]
v3 = vertex_lookup[diag_key]
valid_quads.append((v0, v1, v2, v3))
quad_indices.append(idx)
for q_idx, (v0, v1, v2, v3) in enumerate(valid_quads):
cell_idx = quad_indices[q_idx]
if alignments[cell_idx] > 0:
faces.append(torch.tensor([v0, v1, v3], device=device, dtype=torch.long))
faces.append(torch.tensor([v0, v3, v2], device=device, dtype=torch.long))
else:
faces.append(torch.tensor([v0, v3, v1], device=device, dtype=torch.long))
faces.append(torch.tensor([v0, v2, v3], device=device, dtype=torch.long))
if faces:
faces = torch.stack(faces)
else:
faces = torch.zeros((0, 3), dtype=torch.long, device=device)
v_min = 0
v_max = max(D, H, W)
final_vertices = final_vertices - (v_min + v_max) / 2
scale = (v_max - v_min) / 2
if scale > 0:
final_vertices = final_vertices / scale
final_vertices = torch.fliplr(final_vertices)
return final_vertices, faces
class MESH:
def __init__(self, vertices, faces):
@ -237,6 +427,34 @@ class VoxelToMeshBasic:
return (MESH(torch.stack(vertices), torch.stack(faces)), )
class VoxelToMesh:
@classmethod
def INPUT_TYPES(s):
return {"required": {"voxel": ("VOXEL", ),
"algorithm": (["surface net", "basic"], ),
"threshold": ("FLOAT", {"default": 0.6, "min": -1.0, "max": 1.0, "step": 0.01}),
}}
RETURN_TYPES = ("MESH",)
FUNCTION = "decode"
CATEGORY = "3d"
def decode(self, voxel, algorithm, threshold):
vertices = []
faces = []
if algorithm == "basic":
mesh_function = voxel_to_mesh
elif algorithm == "surface net":
mesh_function = voxel_to_mesh_surfnet
for x in voxel.data:
v, f = mesh_function(x, threshold=threshold, device=None)
vertices.append(v)
faces.append(f)
return (MESH(torch.stack(vertices), torch.stack(faces)), )
def save_glb(vertices, faces, filepath, metadata=None):
"""
@ -244,7 +462,7 @@ def save_glb(vertices, faces, filepath, metadata=None):
Parameters:
vertices: torch.Tensor of shape (N, 3) - The vertex coordinates
faces: torch.Tensor of shape (M, 4) or (M, 3) - The face indices (quad or triangle faces)
faces: torch.Tensor of shape (M, 3) - The face indices (triangle faces)
filepath: str - Output filepath (should end with .glb)
"""
@ -411,5 +629,6 @@ NODE_CLASS_MAPPINGS = {
"Hunyuan3Dv2ConditioningMultiView": Hunyuan3Dv2ConditioningMultiView,
"VAEDecodeHunyuan3D": VAEDecodeHunyuan3D,
"VoxelToMeshBasic": VoxelToMeshBasic,
"VoxelToMesh": VoxelToMesh,
"SaveGLB": SaveGLB,
}

View File

@ -446,10 +446,9 @@ class LTXVPreprocess:
CATEGORY = "image"
def preprocess(self, image, img_compression):
if img_compression > 0:
output_images = []
for i in range(image.shape[0]):
output_images.append(preprocess(image[i], img_compression))
output_images = []
for i in range(image.shape[0]):
output_images.append(preprocess(image[i], img_compression))
return (torch.stack(output_images),)

View File

@ -2,6 +2,7 @@ import numpy as np
import scipy.ndimage
import torch
import comfy.utils
import node_helpers
from nodes import MAX_RESOLUTION
@ -87,6 +88,7 @@ class ImageCompositeMasked:
CATEGORY = "image"
def composite(self, destination, source, x, y, resize_source, mask = None):
destination, source = node_helpers.image_alpha_fix(destination, source)
destination = destination.clone().movedim(-1, 1)
output = composite(destination, source.movedim(-1, 1), x, y, mask, 1, resize_source).movedim(1, -1)
return (output,)

View File

@ -244,6 +244,30 @@ class ModelMergeCosmos14B(comfy_extras.nodes_model_merging.ModelMergeBlocks):
return {"required": arg_dict}
class ModelMergeWAN2_1(comfy_extras.nodes_model_merging.ModelMergeBlocks):
CATEGORY = "advanced/model_merging/model_specific"
DESCRIPTION = "1.3B model has 30 blocks, 14B model has 40 blocks. Image to video model has the extra img_emb."
@classmethod
def INPUT_TYPES(s):
arg_dict = { "model1": ("MODEL",),
"model2": ("MODEL",)}
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
arg_dict["patch_embedding."] = argument
arg_dict["time_embedding."] = argument
arg_dict["time_projection."] = argument
arg_dict["text_embedding."] = argument
arg_dict["img_emb."] = argument
for i in range(40):
arg_dict["blocks.{}.".format(i)] = argument
arg_dict["head."] = argument
return {"required": arg_dict}
NODE_CLASS_MAPPINGS = {
"ModelMergeSD1": ModelMergeSD1,
"ModelMergeSD2": ModelMergeSD1, #SD1 and SD2 have the same blocks
@ -256,4 +280,5 @@ NODE_CLASS_MAPPINGS = {
"ModelMergeLTXV": ModelMergeLTXV,
"ModelMergeCosmos7B": ModelMergeCosmos7B,
"ModelMergeCosmos14B": ModelMergeCosmos14B,
"ModelMergeWAN2_1": ModelMergeWAN2_1,
}

View File

@ -6,7 +6,7 @@ import math
import comfy.utils
import comfy.model_management
import node_helpers
class Blend:
def __init__(self):
@ -34,6 +34,7 @@ class Blend:
CATEGORY = "image/postprocessing"
def blend_images(self, image1: torch.Tensor, image2: torch.Tensor, blend_factor: float, blend_mode: str):
image1, image2 = node_helpers.image_alpha_fix(image1, image2)
image2 = image2.to(image1.device)
if image1.shape != image2.shape:
image2 = image2.permute(0, 3, 1, 2)

View File

@ -3,6 +3,7 @@ import node_helpers
import torch
import comfy.model_management
import comfy.utils
import comfy.latent_formats
class WanImageToVideo:
@ -49,6 +50,110 @@ class WanImageToVideo:
return (positive, negative, out_latent)
class WanFunControlToVideo:
@classmethod
def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"vae": ("VAE", ),
"width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
},
"optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT", ),
"start_image": ("IMAGE", ),
"control_video": ("IMAGE", ),
}}
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative", "latent")
FUNCTION = "encode"
CATEGORY = "conditioning/video_models"
def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, clip_vision_output=None, control_video=None):
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
concat_latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
concat_latent = comfy.latent_formats.Wan21().process_out(concat_latent)
concat_latent = concat_latent.repeat(1, 2, 1, 1, 1)
if start_image is not None:
start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
concat_latent_image = vae.encode(start_image[:, :, :, :3])
concat_latent[:,16:,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]]
if control_video is not None:
control_video = comfy.utils.common_upscale(control_video[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
concat_latent_image = vae.encode(control_video[:, :, :, :3])
concat_latent[:,:16,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]]
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent})
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent})
if clip_vision_output is not None:
positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})
out_latent = {}
out_latent["samples"] = latent
return (positive, negative, out_latent)
class WanFunInpaintToVideo:
@classmethod
def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"vae": ("VAE", ),
"width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
},
"optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT", ),
"start_image": ("IMAGE", ),
"end_image": ("IMAGE", ),
}}
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative", "latent")
FUNCTION = "encode"
CATEGORY = "conditioning/video_models"
def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, end_image=None, clip_vision_output=None):
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
if start_image is not None:
start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
if end_image is not None:
end_image = comfy.utils.common_upscale(end_image[-length:].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
image = torch.ones((length, height, width, 3)) * 0.5
mask = torch.ones((1, 1, latent.shape[2] * 4, latent.shape[-2], latent.shape[-1]))
if start_image is not None:
image[:start_image.shape[0]] = start_image
mask[:, :, :start_image.shape[0] + 3] = 0.0
if end_image is not None:
image[-end_image.shape[0]:] = end_image
mask[:, :, -end_image.shape[0]:] = 0.0
concat_latent_image = vae.encode(image[:, :, :, :3])
mask = mask.view(1, mask.shape[2] // 4, 4, mask.shape[3], mask.shape[4]).transpose(1, 2)
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
if clip_vision_output is not None:
positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})
out_latent = {}
out_latent["samples"] = latent
return (positive, negative, out_latent)
NODE_CLASS_MAPPINGS = {
"WanImageToVideo": WanImageToVideo,
"WanFunControlToVideo": WanFunControlToVideo,
"WanFunInpaintToVideo": WanFunInpaintToVideo,
}

View File

@ -15,7 +15,7 @@ import nodes
import comfy.model_management
from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker
from comfy_execution.graph_utils import is_link, GraphBuilder
from comfy_execution.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID
from comfy_execution.caching import HierarchicalCache, LRUCache, DependencyAwareCache, CacheKeySetInputSignature, CacheKeySetID
from comfy_execution.validation import validate_node_input
class ExecutionResult(Enum):
@ -59,20 +59,27 @@ class IsChangedCache:
self.is_changed[node_id] = node["is_changed"]
return self.is_changed[node_id]
class CacheSet:
def __init__(self, lru_size=None):
if lru_size is None or lru_size == 0:
self.init_classic_cache()
else:
self.init_lru_cache(lru_size)
self.all = [self.outputs, self.ui, self.objects]
# Useful for those with ample RAM/VRAM -- allows experimenting without
# blowing away the cache every time
def init_lru_cache(self, cache_size):
self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
self.ui = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
self.objects = HierarchicalCache(CacheKeySetID)
class CacheType(Enum):
CLASSIC = 0
LRU = 1
DEPENDENCY_AWARE = 2
class CacheSet:
def __init__(self, cache_type=None, cache_size=None):
if cache_type == CacheType.DEPENDENCY_AWARE:
self.init_dependency_aware_cache()
logging.info("Disabling intermediate node cache.")
elif cache_type == CacheType.LRU:
if cache_size is None:
cache_size = 0
self.init_lru_cache(cache_size)
logging.info("Using LRU cache")
else:
self.init_classic_cache()
self.all = [self.outputs, self.ui, self.objects]
# Performs like the old cache -- dump data ASAP
def init_classic_cache(self):
@ -80,6 +87,17 @@ class CacheSet:
self.ui = HierarchicalCache(CacheKeySetInputSignature)
self.objects = HierarchicalCache(CacheKeySetID)
def init_lru_cache(self, cache_size):
self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
self.ui = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
self.objects = HierarchicalCache(CacheKeySetID)
# only hold cached items while the decendents have not executed
def init_dependency_aware_cache(self):
self.outputs = DependencyAwareCache(CacheKeySetInputSignature)
self.ui = DependencyAwareCache(CacheKeySetInputSignature)
self.objects = DependencyAwareCache(CacheKeySetID)
def recursive_debug_dump(self):
result = {
"outputs": self.outputs.recursive_debug_dump(),
@ -414,13 +432,14 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
return (ExecutionResult.SUCCESS, None, None)
class PromptExecutor:
def __init__(self, server, lru_size=None):
self.lru_size = lru_size
def __init__(self, server, cache_type=False, cache_size=None):
self.cache_size = cache_size
self.cache_type = cache_type
self.server = server
self.reset()
def reset(self):
self.caches = CacheSet(self.lru_size)
self.caches = CacheSet(cache_type=self.cache_type, cache_size=self.cache_size)
self.status_messages = []
self.success = True
@ -775,7 +794,7 @@ def validate_prompt(prompt):
"details": f"Node ID '#{x}'",
"extra_info": {}
}
return (False, error, [], [])
return (False, error, [], {})
class_type = prompt[x]['class_type']
class_ = nodes.NODE_CLASS_MAPPINGS.get(class_type, None)
@ -786,7 +805,7 @@ def validate_prompt(prompt):
"details": f"Node ID '#{x}'",
"extra_info": {}
}
return (False, error, [], [])
return (False, error, [], {})
if hasattr(class_, 'OUTPUT_NODE') and class_.OUTPUT_NODE is True:
outputs.add(x)
@ -798,7 +817,7 @@ def validate_prompt(prompt):
"details": "",
"extra_info": {}
}
return (False, error, [], [])
return (False, error, [], {})
good_outputs = set()
errors = []

View File

@ -156,7 +156,13 @@ def cuda_malloc_warning():
def prompt_worker(q, server_instance):
current_time: float = 0.0
e = execution.PromptExecutor(server_instance, lru_size=args.cache_lru)
cache_type = execution.CacheType.CLASSIC
if args.cache_lru > 0:
cache_type = execution.CacheType.LRU
elif args.cache_none:
cache_type = execution.CacheType.DEPENDENCY_AWARE
e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_size=args.cache_lru)
last_gc_collect = 0
need_gc = False
gc_collect_interval = 10.0

View File

@ -44,3 +44,11 @@ def string_to_torch_dtype(string):
return torch.float16
if string == "bf16":
return torch.bfloat16
def image_alpha_fix(destination, source):
if destination.shape[-1] < source.shape[-1]:
source = source[...,:destination.shape[-1]]
elif destination.shape[-1] > source.shape[-1]:
destination = torch.nn.functional.pad(destination, (0, 1))
destination[..., -1] = 1.0
return destination, source

View File

@ -786,6 +786,8 @@ class ControlNetLoader:
def load_controlnet(self, control_net_name):
controlnet_path = folder_paths.get_full_path_or_raise("controlnet", control_net_name)
controlnet = comfy.controlnet.load_controlnet(controlnet_path)
if controlnet is None:
raise RuntimeError("ERROR: controlnet file is invalid and does not contain a valid controlnet model.")
return (controlnet,)
class DiffControlNetLoader:
@ -1006,6 +1008,8 @@ class CLIPVisionLoader:
def load_clip(self, clip_name):
clip_path = folder_paths.get_full_path_or_raise("clip_vision", clip_name)
clip_vision = comfy.clip_vision.load(clip_path)
if clip_vision is None:
raise RuntimeError("ERROR: clip vision file is invalid and does not contain a valid vision model.")
return (clip_vision,)
class CLIPVisionEncode:
@ -1688,6 +1692,9 @@ class LoadImage:
if 'A' in i.getbands():
mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0
mask = 1. - torch.from_numpy(mask)
elif i.mode == 'P' and 'transparency' in i.info:
mask = np.array(i.convert('RGBA').getchannel('A')).astype(np.float32) / 255.0
mask = 1. - torch.from_numpy(mask)
else:
mask = torch.zeros((64,64), dtype=torch.float32, device="cpu")
output_images.append(image)
@ -2123,21 +2130,25 @@ def get_module_name(module_path: str) -> str:
def load_custom_node(module_path: str, ignore=set(), module_parent="custom_nodes") -> bool:
module_name = os.path.basename(module_path)
module_name = get_module_name(module_path)
if os.path.isfile(module_path):
sp = os.path.splitext(module_path)
module_name = sp[0]
sys_module_name = module_name
elif os.path.isdir(module_path):
sys_module_name = module_path.replace(".", "_x_")
try:
logging.debug("Trying to load custom node {}".format(module_path))
if os.path.isfile(module_path):
module_spec = importlib.util.spec_from_file_location(module_name, module_path)
module_spec = importlib.util.spec_from_file_location(sys_module_name, module_path)
module_dir = os.path.split(module_path)[0]
else:
module_spec = importlib.util.spec_from_file_location(module_name, os.path.join(module_path, "__init__.py"))
module_spec = importlib.util.spec_from_file_location(sys_module_name, os.path.join(module_path, "__init__.py"))
module_dir = module_path
module = importlib.util.module_from_spec(module_spec)
sys.modules[module_name] = module
sys.modules[sys_module_name] = module
module_spec.loader.exec_module(module)
LOADED_MODULE_DIRS[module_name] = os.path.abspath(module_dir)
@ -2267,6 +2278,7 @@ def init_builtin_extra_nodes():
"nodes_lotus.py",
"nodes_hunyuan3d.py",
"nodes_primitive.py",
"nodes_cfg.py",
]
import_failed = []

View File

@ -1,4 +1,4 @@
comfyui-frontend-package==1.14.5
comfyui-frontend-package==1.15.13
torch
torchsde
torchvision

View File

@ -48,7 +48,7 @@ async def send_socket_catch_exception(function, message):
@web.middleware
async def cache_control(request: web.Request, handler):
response: web.Response = await handler(request)
if request.path.endswith('.js') or request.path.endswith('.css'):
if request.path.endswith('.js') or request.path.endswith('.css') or request.path.endswith('index.json'):
response.headers.setdefault('Cache-Control', 'no-cache')
return response
@ -657,7 +657,13 @@ class PromptServer():
logging.warning("invalid prompt: {}".format(valid[1]))
return web.json_response({"error": valid[1], "node_errors": valid[3]}, status=400)
else:
return web.json_response({"error": "no prompt", "node_errors": []}, status=400)
error = {
"type": "no_prompt",
"message": "No prompt provided",
"details": "No prompt provided",
"extra_info": {}
}
return web.json_response({"error": error, "node_errors": {}}, status=400)
@routes.post("/queue")
async def post_queue(request):