From 6fb4cc0179805245f136fe5d44bf8cf6de2c83bb Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Wed, 2 Apr 2025 09:21:17 +0800 Subject: [PATCH 01/10] Weight Adapter Scheme --- comfy/weight_adapter/__init__.py | 13 +++++++++++ comfy/weight_adapter/base.py | 37 ++++++++++++++++++++++++++++++++ 2 files changed, 50 insertions(+) create mode 100644 comfy/weight_adapter/__init__.py create mode 100644 comfy/weight_adapter/base.py diff --git a/comfy/weight_adapter/__init__.py b/comfy/weight_adapter/__init__.py new file mode 100644 index 000000000..a021cfd00 --- /dev/null +++ b/comfy/weight_adapter/__init__.py @@ -0,0 +1,13 @@ +from .base import WeightAdapterBase +from .lora import LoRAAdapter +from .loha import LoHaAdapter +from .lokr import LoKrAdapter +from .glora import GLoRAAdapter + + +adapters: list[type[WeightAdapterBase]] = [ + LoRAAdapter, + LoHaAdapter, + LoKrAdapter, + GLoRAAdapter, +] \ No newline at end of file diff --git a/comfy/weight_adapter/base.py b/comfy/weight_adapter/base.py new file mode 100644 index 000000000..2ba2850eb --- /dev/null +++ b/comfy/weight_adapter/base.py @@ -0,0 +1,37 @@ +from typing import Optional +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class WeightAdapterBase: + name: str + loaded_keys: set[str] + weights: list[torch.Tensor] + + @classmethod + def load(cls, x: str, lora: dict[str, torch.Tensor]) -> "WeightAdapterBase" | None: + raise NotImplementedError + + def to_train(self) -> "WeightAdapterTrainBase": + raise NotImplementedError + + def calculate_weight( + self, + weight, + key, + strength, + strength_model, + offset, + function, + intermediate_dtype=torch.float32, + original_weight=None, + ): + raise NotImplementedError + + +class WeightAdapterTrainBase(nn.Module): + def __init__(self): + super().__init__() + + # [TODO] Collaborate with LoRA training PR #7032 From 4774c3244eaa93c0f089c85e6967be7d76f93342 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Wed, 2 Apr 2025 09:21:39 +0800 Subject: [PATCH 02/10] Initial impl LoRA load/calculate_weight LoHa/LoKr/GLoRA load --- comfy/weight_adapter/glora.py | 54 +++++++++++++ comfy/weight_adapter/loha.py | 65 +++++++++++++++ comfy/weight_adapter/lokr.py | 89 +++++++++++++++++++++ comfy/weight_adapter/lora.py | 144 ++++++++++++++++++++++++++++++++++ 4 files changed, 352 insertions(+) create mode 100644 comfy/weight_adapter/glora.py create mode 100644 comfy/weight_adapter/loha.py create mode 100644 comfy/weight_adapter/lokr.py create mode 100644 comfy/weight_adapter/lora.py diff --git a/comfy/weight_adapter/glora.py b/comfy/weight_adapter/glora.py new file mode 100644 index 000000000..bdb9220e5 --- /dev/null +++ b/comfy/weight_adapter/glora.py @@ -0,0 +1,54 @@ +import logging +import torch +import comfy.utils +import comfy.model_management +import comfy.model_base +from comfy.lora import weight_decompose, pad_tensor_to_shape + +from .base import WeightAdapterBase + + +class GLoRAAdapter(WeightAdapterBase): + name = "glora" + + def __init__(self, loaded_keys, weights): + self.loaded_keys = loaded_keys + self.weights = weights + + @classmethod + def load( + cls, + x: str, + lora: dict[str, torch.Tensor], + alpha: float, + dora_scale: torch.Tensor, + loaded_keys: set[str] = None, + ) -> "GLoRAAdapter" | None: + if loaded_keys is None: + loaded_keys = set() + a1_name = "{}.a1.weight".format(x) + a2_name = "{}.a2.weight".format(x) + b1_name = "{}.b1.weight".format(x) + b2_name = 
"{}.b2.weight".format(x) + if a1_name in lora: + weights = ("glora", (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha, dora_scale)) + loaded_keys.add(a1_name) + loaded_keys.add(a2_name) + loaded_keys.add(b1_name) + loaded_keys.add(b2_name) + return cls(loaded_keys, weights) + else: + return None + + def calculate_weight( + self, + weight, + key, + strength, + strength_model, + offset, + function, + intermediate_dtype=torch.float32, + original_weight=None, + ): + pass diff --git a/comfy/weight_adapter/loha.py b/comfy/weight_adapter/loha.py new file mode 100644 index 000000000..b3267bc0a --- /dev/null +++ b/comfy/weight_adapter/loha.py @@ -0,0 +1,65 @@ +import logging +import torch +import comfy.utils +import comfy.model_management +import comfy.model_base +from comfy.lora import weight_decompose, pad_tensor_to_shape + +from .base import WeightAdapterBase + + +class LoHaAdapter(WeightAdapterBase): + name = "loha" + + def __init__(self, loaded_keys, weights): + self.loaded_keys = loaded_keys + self.weights = weights + + @classmethod + def load( + cls, + x: str, + lora: dict[str, torch.Tensor], + alpha: float, + dora_scale: torch.Tensor, + loaded_keys: set[str] = None, + ) -> "LoHaAdapter" | None: + if loaded_keys is None: + loaded_keys = set() + + hada_w1_a_name = "{}.hada_w1_a".format(x) + hada_w1_b_name = "{}.hada_w1_b".format(x) + hada_w2_a_name = "{}.hada_w2_a".format(x) + hada_w2_b_name = "{}.hada_w2_b".format(x) + hada_t1_name = "{}.hada_t1".format(x) + hada_t2_name = "{}.hada_t2".format(x) + if hada_w1_a_name in lora.keys(): + hada_t1 = None + hada_t2 = None + if hada_t1_name in lora.keys(): + hada_t1 = lora[hada_t1_name] + hada_t2 = lora[hada_t2_name] + loaded_keys.add(hada_t1_name) + loaded_keys.add(hada_t2_name) + + weights = ("loha", (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2, dora_scale)) + loaded_keys.add(hada_w1_a_name) + loaded_keys.add(hada_w1_b_name) + loaded_keys.add(hada_w2_a_name) + loaded_keys.add(hada_w2_b_name) + return cls(loaded_keys, weights) + else: + return None + + def calculate_weight( + self, + weight, + key, + strength, + strength_model, + offset, + function, + intermediate_dtype=torch.float32, + original_weight=None, + ): + pass diff --git a/comfy/weight_adapter/lokr.py b/comfy/weight_adapter/lokr.py new file mode 100644 index 000000000..206c80ae9 --- /dev/null +++ b/comfy/weight_adapter/lokr.py @@ -0,0 +1,89 @@ +import logging +import torch +import comfy.utils +import comfy.model_management +import comfy.model_base +from comfy.lora import weight_decompose, pad_tensor_to_shape + +from .base import WeightAdapterBase + + +class LoKrAdapter(WeightAdapterBase): + name = "lokr" + + def __init__(self, loaded_keys, weights): + self.loaded_keys = loaded_keys + self.weights = weights + + @classmethod + def load( + cls, + x: str, + lora: dict[str, torch.Tensor], + alpha: float, + dora_scale: torch.Tensor, + loaded_keys: set[str] = None, + ) -> "LoKrAdapter" | None: + if loaded_keys is None: + loaded_keys = set() + lokr_w1_name = "{}.lokr_w1".format(x) + lokr_w2_name = "{}.lokr_w2".format(x) + lokr_w1_a_name = "{}.lokr_w1_a".format(x) + lokr_w1_b_name = "{}.lokr_w1_b".format(x) + lokr_t2_name = "{}.lokr_t2".format(x) + lokr_w2_a_name = "{}.lokr_w2_a".format(x) + lokr_w2_b_name = "{}.lokr_w2_b".format(x) + + lokr_w1 = None + if lokr_w1_name in lora.keys(): + lokr_w1 = lora[lokr_w1_name] + loaded_keys.add(lokr_w1_name) + + lokr_w2 = None + if lokr_w2_name in lora.keys(): + lokr_w2 = 
lora[lokr_w2_name] + loaded_keys.add(lokr_w2_name) + + lokr_w1_a = None + if lokr_w1_a_name in lora.keys(): + lokr_w1_a = lora[lokr_w1_a_name] + loaded_keys.add(lokr_w1_a_name) + + lokr_w1_b = None + if lokr_w1_b_name in lora.keys(): + lokr_w1_b = lora[lokr_w1_b_name] + loaded_keys.add(lokr_w1_b_name) + + lokr_w2_a = None + if lokr_w2_a_name in lora.keys(): + lokr_w2_a = lora[lokr_w2_a_name] + loaded_keys.add(lokr_w2_a_name) + + lokr_w2_b = None + if lokr_w2_b_name in lora.keys(): + lokr_w2_b = lora[lokr_w2_b_name] + loaded_keys.add(lokr_w2_b_name) + + lokr_t2 = None + if lokr_t2_name in lora.keys(): + lokr_t2 = lora[lokr_t2_name] + loaded_keys.add(lokr_t2_name) + + if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None): + weights = ("lokr", (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale)) + return cls(loaded_keys, weights) + else: + return None + + def calculate_weight( + self, + weight, + key, + strength, + strength_model, + offset, + function, + intermediate_dtype=torch.float32, + original_weight=None, + ): + pass diff --git a/comfy/weight_adapter/lora.py b/comfy/weight_adapter/lora.py new file mode 100644 index 000000000..b79bfd820 --- /dev/null +++ b/comfy/weight_adapter/lora.py @@ -0,0 +1,144 @@ +import logging +import torch +import comfy.utils +import comfy.model_management +import comfy.model_base +from comfy.lora import weight_decompose, pad_tensor_to_shape + +from .base import WeightAdapterBase + + +class LoRAAdapter(WeightAdapterBase): + name = "lora" + + def __init__(self, loaded_keys, weights): + self.loaded_keys = loaded_keys + self.weights = weights + + @classmethod + def load( + cls, + x: str, + lora: dict[str, torch.Tensor], + alpha: float, + dora_scale: torch.Tensor, + loaded_keys: set[str] = None, + ) -> "LoRAAdapter" | None: + if loaded_keys is None: + loaded_keys = set() + + reshape_name = "{}.reshape_weight".format(x) + regular_lora = "{}.lora_up.weight".format(x) + diffusers_lora = "{}_lora.up.weight".format(x) + diffusers2_lora = "{}.lora_B.weight".format(x) + diffusers3_lora = "{}.lora.up.weight".format(x) + mochi_lora = "{}.lora_B".format(x) + transformers_lora = "{}.lora_linear_layer.up.weight".format(x) + A_name = None + + if regular_lora in lora.keys(): + A_name = regular_lora + B_name = "{}.lora_down.weight".format(x) + mid_name = "{}.lora_mid.weight".format(x) + elif diffusers_lora in lora.keys(): + A_name = diffusers_lora + B_name = "{}_lora.down.weight".format(x) + mid_name = None + elif diffusers2_lora in lora.keys(): + A_name = diffusers2_lora + B_name = "{}.lora_A.weight".format(x) + mid_name = None + elif diffusers3_lora in lora.keys(): + A_name = diffusers3_lora + B_name = "{}.lora.down.weight".format(x) + mid_name = None + elif mochi_lora in lora.keys(): + A_name = mochi_lora + B_name = "{}.lora_A".format(x) + mid_name = None + elif transformers_lora in lora.keys(): + A_name = transformers_lora + B_name = "{}.lora_linear_layer.down.weight".format(x) + mid_name = None + + if A_name is not None: + mid = None + if mid_name is not None and mid_name in lora.keys(): + mid = lora[mid_name] + loaded_keys.add(mid_name) + reshape = None + if reshape_name in lora.keys(): + try: + reshape = lora[reshape_name].tolist() + loaded_keys.add(reshape_name) + except: + pass + weights = (lora[A_name], lora[B_name], alpha, mid, dora_scale, reshape) + loaded_keys.add(A_name) + loaded_keys.add(B_name) + return cls(loaded_keys, weights) + else: + return None + + def calculate_weight( + 
self, + weight, + key, + strength, + strength_model, + offset, + function, + intermediate_dtype=torch.float32, + original_weight=None, + ): + v = self.weights[1] + mat1 = comfy.model_management.cast_to_device( + v[0], weight.device, intermediate_dtype + ) + mat2 = comfy.model_management.cast_to_device( + v[1], weight.device, intermediate_dtype + ) + dora_scale = v[4] + reshape = v[5] + + if reshape is not None: + weight = pad_tensor_to_shape(weight, reshape) + + if v[2] is not None: + alpha = v[2] / mat2.shape[0] + else: + alpha = 1.0 + + if v[3] is not None: + # locon mid weights, hopefully the math is fine because I didn't properly test it + mat3 = comfy.model_management.cast_to_device( + v[3], weight.device, intermediate_dtype + ) + final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]] + mat2 = ( + torch.mm( + mat2.transpose(0, 1).flatten(start_dim=1), + mat3.transpose(0, 1).flatten(start_dim=1), + ) + .reshape(final_shape) + .transpose(0, 1) + ) + try: + lora_diff = torch.mm( + mat1.flatten(start_dim=1), mat2.flatten(start_dim=1) + ).reshape(weight.shape) + if dora_scale is not None: + weight = weight_decompose( + dora_scale, + weight, + lora_diff, + alpha, + strength, + intermediate_dtype, + function, + ) + else: + weight += function(((strength * alpha) * lora_diff).type(weight.dtype)) + except Exception as e: + logging.error("ERROR {} {} {}".format(self.name, key, e)) + return weight From c40686eb429c27af4183339c14d5a3d0b5438839 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Wed, 2 Apr 2025 09:22:05 +0800 Subject: [PATCH 03/10] Utilize new weight adapter in lora.py For calculate weight I implement a fallback mechnism temporary for dev --- comfy/lora.py | 151 ++++++-------------------------------------------- 1 file changed, 18 insertions(+), 133 deletions(-) diff --git a/comfy/lora.py b/comfy/lora.py index bc9f3022a..ab053e7d3 100644 --- a/comfy/lora.py +++ b/comfy/lora.py @@ -20,6 +20,7 @@ from __future__ import annotations import comfy.utils import comfy.model_management import comfy.model_base +import comfy.weight_adapter as weight_adapter import logging import torch @@ -49,139 +50,12 @@ def load_lora(lora, to_load, log_missing=True): dora_scale = lora[dora_scale_name] loaded_keys.add(dora_scale_name) - reshape_name = "{}.reshape_weight".format(x) - reshape = None - if reshape_name in lora.keys(): - try: - reshape = lora[reshape_name].tolist() - loaded_keys.add(reshape_name) - except: - pass - - regular_lora = "{}.lora_up.weight".format(x) - diffusers_lora = "{}_lora.up.weight".format(x) - diffusers2_lora = "{}.lora_B.weight".format(x) - diffusers3_lora = "{}.lora.up.weight".format(x) - mochi_lora = "{}.lora_B".format(x) - transformers_lora = "{}.lora_linear_layer.up.weight".format(x) - A_name = None - - if regular_lora in lora.keys(): - A_name = regular_lora - B_name = "{}.lora_down.weight".format(x) - mid_name = "{}.lora_mid.weight".format(x) - elif diffusers_lora in lora.keys(): - A_name = diffusers_lora - B_name = "{}_lora.down.weight".format(x) - mid_name = None - elif diffusers2_lora in lora.keys(): - A_name = diffusers2_lora - B_name = "{}.lora_A.weight".format(x) - mid_name = None - elif diffusers3_lora in lora.keys(): - A_name = diffusers3_lora - B_name = "{}.lora.down.weight".format(x) - mid_name = None - elif mochi_lora in lora.keys(): - A_name = mochi_lora - B_name = "{}.lora_A".format(x) - mid_name = None - elif transformers_lora in lora.keys(): - A_name = transformers_lora - B_name 
="{}.lora_linear_layer.down.weight".format(x) - mid_name = None - - if A_name is not None: - mid = None - if mid_name is not None and mid_name in lora.keys(): - mid = lora[mid_name] - loaded_keys.add(mid_name) - patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid, dora_scale, reshape)) - loaded_keys.add(A_name) - loaded_keys.add(B_name) - - - ######## loha - hada_w1_a_name = "{}.hada_w1_a".format(x) - hada_w1_b_name = "{}.hada_w1_b".format(x) - hada_w2_a_name = "{}.hada_w2_a".format(x) - hada_w2_b_name = "{}.hada_w2_b".format(x) - hada_t1_name = "{}.hada_t1".format(x) - hada_t2_name = "{}.hada_t2".format(x) - if hada_w1_a_name in lora.keys(): - hada_t1 = None - hada_t2 = None - if hada_t1_name in lora.keys(): - hada_t1 = lora[hada_t1_name] - hada_t2 = lora[hada_t2_name] - loaded_keys.add(hada_t1_name) - loaded_keys.add(hada_t2_name) - - patch_dict[to_load[x]] = ("loha", (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2, dora_scale)) - loaded_keys.add(hada_w1_a_name) - loaded_keys.add(hada_w1_b_name) - loaded_keys.add(hada_w2_a_name) - loaded_keys.add(hada_w2_b_name) - - - ######## lokr - lokr_w1_name = "{}.lokr_w1".format(x) - lokr_w2_name = "{}.lokr_w2".format(x) - lokr_w1_a_name = "{}.lokr_w1_a".format(x) - lokr_w1_b_name = "{}.lokr_w1_b".format(x) - lokr_t2_name = "{}.lokr_t2".format(x) - lokr_w2_a_name = "{}.lokr_w2_a".format(x) - lokr_w2_b_name = "{}.lokr_w2_b".format(x) - - lokr_w1 = None - if lokr_w1_name in lora.keys(): - lokr_w1 = lora[lokr_w1_name] - loaded_keys.add(lokr_w1_name) - - lokr_w2 = None - if lokr_w2_name in lora.keys(): - lokr_w2 = lora[lokr_w2_name] - loaded_keys.add(lokr_w2_name) - - lokr_w1_a = None - if lokr_w1_a_name in lora.keys(): - lokr_w1_a = lora[lokr_w1_a_name] - loaded_keys.add(lokr_w1_a_name) - - lokr_w1_b = None - if lokr_w1_b_name in lora.keys(): - lokr_w1_b = lora[lokr_w1_b_name] - loaded_keys.add(lokr_w1_b_name) - - lokr_w2_a = None - if lokr_w2_a_name in lora.keys(): - lokr_w2_a = lora[lokr_w2_a_name] - loaded_keys.add(lokr_w2_a_name) - - lokr_w2_b = None - if lokr_w2_b_name in lora.keys(): - lokr_w2_b = lora[lokr_w2_b_name] - loaded_keys.add(lokr_w2_b_name) - - lokr_t2 = None - if lokr_t2_name in lora.keys(): - lokr_t2 = lora[lokr_t2_name] - loaded_keys.add(lokr_t2_name) - - if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None): - patch_dict[to_load[x]] = ("lokr", (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale)) - - #glora - a1_name = "{}.a1.weight".format(x) - a2_name = "{}.a2.weight".format(x) - b1_name = "{}.b1.weight".format(x) - b2_name = "{}.b2.weight".format(x) - if a1_name in lora: - patch_dict[to_load[x]] = ("glora", (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha, dora_scale)) - loaded_keys.add(a1_name) - loaded_keys.add(a2_name) - loaded_keys.add(b1_name) - loaded_keys.add(b2_name) + for adapter_cls in weight_adapter.adapters: + adapter = adapter_cls.load(x, lora, alpha, dora_scale, loaded_keys) + if adapter is not None: + patch_dict[to_load[x]] = adapter + loaded_keys.update(adapter.loaded_keys) + continue w_norm_name = "{}.w_norm".format(x) b_norm_name = "{}.b_norm".format(x) @@ -482,6 +356,17 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, ori if isinstance(v, list): v = (calculate_weight(v[1:], v[0][1](comfy.model_management.cast_to_device(v[0][0], weight.device, intermediate_dtype, copy=True), inplace=True), 
key, intermediate_dtype=intermediate_dtype), ) + if isinstance(v, weight_adapter.WeightAdapterBase): + output = v.calculate_weight(weight, key, strength, strength_model, offset, function, intermediate_dtype, original_weights) + if output is not None: + weight = output + if old_weight is not None: + weight = old_weight + continue + else: + #Fallback when calculate_weight haven't implemented + v = v.weights + if len(v) == 1: patch_type = "diff" elif len(v) == 2: From 84317474fd6e6d0717cb28c1cc0e2569b1ea730e Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Wed, 2 Apr 2025 09:31:24 +0800 Subject: [PATCH 04/10] lint --- comfy/weight_adapter/__init__.py | 2 +- comfy/weight_adapter/base.py | 2 -- comfy/weight_adapter/glora.py | 5 ----- comfy/weight_adapter/loha.py | 5 ----- comfy/weight_adapter/lokr.py | 5 ----- 5 files changed, 1 insertion(+), 18 deletions(-) diff --git a/comfy/weight_adapter/__init__.py b/comfy/weight_adapter/__init__.py index a021cfd00..e6cd805b6 100644 --- a/comfy/weight_adapter/__init__.py +++ b/comfy/weight_adapter/__init__.py @@ -10,4 +10,4 @@ adapters: list[type[WeightAdapterBase]] = [ LoHaAdapter, LoKrAdapter, GLoRAAdapter, -] \ No newline at end of file +] diff --git a/comfy/weight_adapter/base.py b/comfy/weight_adapter/base.py index 2ba2850eb..9a7aa9de8 100644 --- a/comfy/weight_adapter/base.py +++ b/comfy/weight_adapter/base.py @@ -1,7 +1,5 @@ -from typing import Optional import torch import torch.nn as nn -import torch.nn.functional as F class WeightAdapterBase: diff --git a/comfy/weight_adapter/glora.py b/comfy/weight_adapter/glora.py index bdb9220e5..89574f9fc 100644 --- a/comfy/weight_adapter/glora.py +++ b/comfy/weight_adapter/glora.py @@ -1,9 +1,4 @@ -import logging import torch -import comfy.utils -import comfy.model_management -import comfy.model_base -from comfy.lora import weight_decompose, pad_tensor_to_shape from .base import WeightAdapterBase diff --git a/comfy/weight_adapter/loha.py b/comfy/weight_adapter/loha.py index b3267bc0a..c0303925d 100644 --- a/comfy/weight_adapter/loha.py +++ b/comfy/weight_adapter/loha.py @@ -1,9 +1,4 @@ -import logging import torch -import comfy.utils -import comfy.model_management -import comfy.model_base -from comfy.lora import weight_decompose, pad_tensor_to_shape from .base import WeightAdapterBase diff --git a/comfy/weight_adapter/lokr.py b/comfy/weight_adapter/lokr.py index 206c80ae9..85bad7ec7 100644 --- a/comfy/weight_adapter/lokr.py +++ b/comfy/weight_adapter/lokr.py @@ -1,9 +1,4 @@ -import logging import torch -import comfy.utils -import comfy.model_management -import comfy.model_base -from comfy.lora import weight_decompose, pad_tensor_to_shape from .base import WeightAdapterBase From 88d9168df07a44674fd89fe3891ecc2dcc1965df Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Tue, 8 Apr 2025 18:38:44 +0800 Subject: [PATCH 05/10] Sync (#1) * Allow disabling pe in flux code for some other models. * Initial Hunyuan3Dv2 implementation. Supports the multiview, mini, turbo models and VAEs. * Fix orientation of hunyuan 3d model. * A few fixes for the hunyuan3d models. * Update frontend to 1.13 (#7331) * Add backend primitive nodes (#7328) * Add backend primitive nodes * Add control after generate to int primitive * Nodes to convert images to YUV and back. Can be used to convert an image to black and white. 
* Update frontend to 1.14 (#7343) * Native LotusD Implementation (#7125) * draft pass at a native comfy implementation of Lotus-D depth and normal est * fix model_sampling kludges * fix ruff --------- Co-authored-by: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> * Automatically set the right sampling type for lotus. * support output normal and lineart once (#7290) * [nit] Format error strings (#7345) * ComfyUI version v0.3.27 * Fallback to pytorch attention if sage attention fails. * Add model merging node for WAN 2.1 * Add Hunyuan3D to readme. * Support more float8 types. * Add CFGZeroStar node. Works on all models that use a negative prompt but is meant for rectified flow models. * Support the WAN 2.1 fun control models. Use the new WanFunControlToVideo node. * Add WanFunInpaintToVideo node for the Wan fun inpaint models. * Update frontend to 1.14.6 (#7416) Cherry-pick the fix: https://github.com/Comfy-Org/ComfyUI_frontend/pull/3252 * Don't error if wan concat image has extra channels. * ltxv: fix preprocessing exception when compression is 0. (#7431) * Remove useless code. * Fix latent composite node not working when source has alpha. * Fix alpha channel mismatch on destination in ImageCompositeMasked * Add option to store TE in bf16 (#7461) * User missing (#7439) * Ensuring a 401 error is returned when user data is not found in multi-user context. * Returning a 401 error when provided comfy-user does not exists on server side. * Fix comment. This function does not support quads. * MLU memory optimization (#7470) Co-authored-by: huzhan * Fix alpha image issue in more nodes. * Fix problem. * Disable partial offloading of audio VAE. * Add activations_shape info in UNet models (#7482) * Add activations_shape info in UNet models * activations_shape should be a list * Support 512 siglip model. * Show a proper error to the user when a vision model file is invalid. * Support the wan fun reward loras. 
--------- Co-authored-by: comfyanonymous Co-authored-by: Chenlei Hu Co-authored-by: thot experiment <94414189+thot-experiment@users.noreply.github.com> Co-authored-by: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Co-authored-by: Terry Jia Co-authored-by: Michael Kupchick Co-authored-by: BVH <82035780+bvhari@users.noreply.github.com> Co-authored-by: Laurent Erignoux Co-authored-by: BiologicalExplosion <49753622+BiologicalExplosion@users.noreply.github.com> Co-authored-by: huzhan Co-authored-by: Raphael Walker --- README.md | 2 + app/app_settings.py | 10 +- app/frontend_management.py | 53 +- comfy/cli_args.py | 1 + comfy/clip_vision.py | 8 +- comfy/clip_vision_siglip_512.json | 13 + comfy/latent_formats.py | 10 + comfy/ldm/flux/math.py | 9 +- comfy/ldm/flux/model.py | 7 +- comfy/ldm/hunyuan3d/model.py | 135 ++++ comfy/ldm/hunyuan3d/vae.py | 587 ++++++++++++++++++ comfy/ldm/modules/attention.py | 18 +- comfy/lora_convert.py | 7 + comfy/model_base.py | 61 +- comfy/model_detection.py | 24 +- comfy/model_management.py | 37 +- comfy/model_sampling.py | 9 + comfy/sd.py | 26 +- comfy/supported_models.py | 68 +- comfy_extras/nodes_cfg.py | 45 ++ comfy_extras/nodes_hunyuan3d.py | 415 +++++++++++++ comfy_extras/nodes_load_3d.py | 18 +- comfy_extras/nodes_lotus.py | 29 + comfy_extras/nodes_lt.py | 7 +- comfy_extras/nodes_mask.py | 2 + comfy_extras/nodes_model_advanced.py | 10 +- .../nodes_model_merging_model_specific.py | 25 + comfy_extras/nodes_morphology.py | 38 ++ comfy_extras/nodes_post_processing.py | 3 +- comfy_extras/nodes_primitive.py | 79 +++ comfy_extras/nodes_wan.py | 105 ++++ comfyui_version.py | 2 +- node_helpers.py | 8 + nodes.py | 6 + pyproject.toml | 2 +- requirements.txt | 2 +- 36 files changed, 1818 insertions(+), 63 deletions(-) create mode 100644 comfy/clip_vision_siglip_512.json create mode 100644 comfy/ldm/hunyuan3d/model.py create mode 100644 comfy/ldm/hunyuan3d/vae.py create mode 100644 comfy_extras/nodes_cfg.py create mode 100644 comfy_extras/nodes_hunyuan3d.py create mode 100644 comfy_extras/nodes_lotus.py create mode 100644 comfy_extras/nodes_primitive.py diff --git a/README.md b/README.md index a807ea9d6..a99aca0e7 100644 --- a/README.md +++ b/README.md @@ -69,6 +69,8 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith - [Hunyuan Video](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_video/) - [Nvidia Cosmos](https://comfyanonymous.github.io/ComfyUI_examples/cosmos/) - [Wan 2.1](https://comfyanonymous.github.io/ComfyUI_examples/wan/) +- 3D Models + - [Hunyuan3D 2.0](https://docs.comfy.org/tutorials/3d/hunyuan3D-2) - [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/) - Asynchronous Queue system - Many optimizations: Only re-executes the parts of the workflow that changes between executions. 
diff --git a/app/app_settings.py b/app/app_settings.py index a545df92e..c7ac73bf6 100644 --- a/app/app_settings.py +++ b/app/app_settings.py @@ -9,8 +9,14 @@ class AppSettings(): self.user_manager = user_manager def get_settings(self, request): - file = self.user_manager.get_request_user_filepath( - request, "comfy.settings.json") + try: + file = self.user_manager.get_request_user_filepath( + request, + "comfy.settings.json" + ) + except KeyError as e: + logging.error("User settings not found.") + raise web.HTTPUnauthorized() from e if os.path.isfile(file): try: with open(file) as f: diff --git a/app/frontend_management.py b/app/frontend_management.py index b4ba994d1..c56ea86e0 100644 --- a/app/frontend_management.py +++ b/app/frontend_management.py @@ -22,13 +22,21 @@ import app.logger # The path to the requirements.txt file req_path = Path(__file__).parents[1] / "requirements.txt" + def frontend_install_warning_message(): """The warning message to display when the frontend version is not up to date.""" extra = "" if sys.flags.no_user_site: extra = "-s " - return f"Please install the updated requirements.txt file by running:\n{sys.executable} {extra}-m pip install -r {req_path}\n\nThis error is happening because the ComfyUI frontend is no longer shipped as part of the main repo but as a pip package instead.\n\nIf you are on the portable package you can run: update\\update_comfyui.bat to solve this problem" + return f""" +Please install the updated requirements.txt file by running: +{sys.executable} {extra}-m pip install -r {req_path} + +This error is happening because the ComfyUI frontend is no longer shipped as part of the main repo but as a pip package instead. + +If you are on the portable package you can run: update\\update_comfyui.bat to solve this problem +""".strip() def check_frontend_version(): @@ -43,7 +51,17 @@ def check_frontend_version(): with open(req_path, "r", encoding="utf-8") as f: required_frontend = parse_version(f.readline().split("=")[-1]) if frontend_version < required_frontend: - app.logger.log_startup_warning("________________________________________________________________________\nWARNING WARNING WARNING WARNING WARNING\n\nInstalled frontend version {} is lower than the recommended version {}.\n\n{}\n________________________________________________________________________".format('.'.join(map(str, frontend_version)), '.'.join(map(str, required_frontend)), frontend_install_warning_message())) + app.logger.log_startup_warning( + f""" +________________________________________________________________________ +WARNING WARNING WARNING WARNING WARNING + +Installed frontend version {".".join(map(str, frontend_version))} is lower than the recommended version {".".join(map(str, required_frontend))}. + +{frontend_install_warning_message()} +________________________________________________________________________ +""".strip() + ) else: logging.info("ComfyUI frontend version: {}".format(frontend_version_str)) except Exception as e: @@ -150,9 +168,20 @@ class FrontendManager: def default_frontend_path(cls) -> str: try: import comfyui_frontend_package + return str(importlib.resources.files(comfyui_frontend_package) / "static") except ImportError: - logging.error(f"\n\n********** ERROR ***********\n\ncomfyui-frontend-package is not installed. {frontend_install_warning_message()}\n********** ERROR **********\n") + logging.error( + f""" +********** ERROR *********** + +comfyui-frontend-package is not installed. 
+ +{frontend_install_warning_message()} + +********** ERROR *********** +""".strip() + ) sys.exit(-1) @classmethod @@ -175,7 +204,9 @@ class FrontendManager: return match_result.group(1), match_result.group(2), match_result.group(3) @classmethod - def init_frontend_unsafe(cls, version_string: str, provider: Optional[FrontEndProvider] = None) -> str: + def init_frontend_unsafe( + cls, version_string: str, provider: Optional[FrontEndProvider] = None + ) -> str: """ Initializes the frontend for the specified version. @@ -197,12 +228,20 @@ class FrontendManager: repo_owner, repo_name, version = cls.parse_version_string(version_string) if version.startswith("v"): - expected_path = str(Path(cls.CUSTOM_FRONTENDS_ROOT) / f"{repo_owner}_{repo_name}" / version.lstrip("v")) + expected_path = str( + Path(cls.CUSTOM_FRONTENDS_ROOT) + / f"{repo_owner}_{repo_name}" + / version.lstrip("v") + ) if os.path.exists(expected_path): - logging.info(f"Using existing copy of specific frontend version tag: {repo_owner}/{repo_name}@{version}") + logging.info( + f"Using existing copy of specific frontend version tag: {repo_owner}/{repo_name}@{version}" + ) return expected_path - logging.info(f"Initializing frontend: {repo_owner}/{repo_name}@{version}, requesting version details from GitHub...") + logging.info( + f"Initializing frontend: {repo_owner}/{repo_name}@{version}, requesting version details from GitHub..." + ) provider = provider or FrontEndProvider(repo_owner, repo_name) release = provider.get_release(version) diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 91c1fe705..62079e6a7 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -79,6 +79,7 @@ fpte_group.add_argument("--fp8_e4m3fn-text-enc", action="store_true", help="Stor fpte_group.add_argument("--fp8_e5m2-text-enc", action="store_true", help="Store text encoder weights in fp8 (e5m2 variant).") fpte_group.add_argument("--fp16-text-enc", action="store_true", help="Store text encoder weights in fp16.") fpte_group.add_argument("--fp32-text-enc", action="store_true", help="Store text encoder weights in fp32.") +fpte_group.add_argument("--bf16-text-enc", action="store_true", help="Store text encoder weights in bf16.") parser.add_argument("--force-channels-last", action="store_true", help="Force channels last format when inferencing the models.") diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py index 87d32a66e..11bc57789 100644 --- a/comfy/clip_vision.py +++ b/comfy/clip_vision.py @@ -110,9 +110,13 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False): elif "vision_model.encoder.layers.30.layer_norm1.weight" in sd: json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_h.json") elif "vision_model.encoder.layers.22.layer_norm1.weight" in sd: + embed_shape = sd["vision_model.embeddings.position_embedding.weight"].shape[0] if sd["vision_model.encoder.layers.0.layer_norm1.weight"].shape[0] == 1152: - json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_384.json") - elif sd["vision_model.embeddings.position_embedding.weight"].shape[0] == 577: + if embed_shape == 729: + json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_384.json") + elif embed_shape == 1024: + json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_512.json") + elif embed_shape == 577: if "multi_modal_projector.linear_1.bias" in sd: json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), 
"clip_vision_config_vitl_336_llava.json") else: diff --git a/comfy/clip_vision_siglip_512.json b/comfy/clip_vision_siglip_512.json new file mode 100644 index 000000000..7fb93ce15 --- /dev/null +++ b/comfy/clip_vision_siglip_512.json @@ -0,0 +1,13 @@ +{ + "num_channels": 3, + "hidden_act": "gelu_pytorch_tanh", + "hidden_size": 1152, + "image_size": 512, + "intermediate_size": 4304, + "model_type": "siglip_vision_model", + "num_attention_heads": 16, + "num_hidden_layers": 27, + "patch_size": 16, + "image_mean": [0.5, 0.5, 0.5], + "image_std": [0.5, 0.5, 0.5] +} diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index 622c1df54..556c39512 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -456,3 +456,13 @@ class Wan21(LatentFormat): latents_mean = self.latents_mean.to(latent.device, latent.dtype) latents_std = self.latents_std.to(latent.device, latent.dtype) return latent * latents_std / self.scale_factor + latents_mean + +class Hunyuan3Dv2(LatentFormat): + latent_channels = 64 + latent_dimensions = 1 + scale_factor = 0.9990943042622529 + +class Hunyuan3Dv2mini(LatentFormat): + latent_channels = 64 + latent_dimensions = 1 + scale_factor = 1.0188137142395404 diff --git a/comfy/ldm/flux/math.py b/comfy/ldm/flux/math.py index c0cbd2914..3e0978176 100644 --- a/comfy/ldm/flux/math.py +++ b/comfy/ldm/flux/math.py @@ -10,10 +10,11 @@ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor: q_shape = q.shape k_shape = k.shape - q = q.to(dtype=pe.dtype).reshape(*q.shape[:-1], -1, 1, 2) - k = k.to(dtype=pe.dtype).reshape(*k.shape[:-1], -1, 1, 2) - q = (pe[..., 0] * q[..., 0] + pe[..., 1] * q[..., 1]).reshape(*q_shape).type_as(v) - k = (pe[..., 0] * k[..., 0] + pe[..., 1] * k[..., 1]).reshape(*k_shape).type_as(v) + if pe is not None: + q = q.to(dtype=pe.dtype).reshape(*q.shape[:-1], -1, 1, 2) + k = k.to(dtype=pe.dtype).reshape(*k.shape[:-1], -1, 1, 2) + q = (pe[..., 0] * q[..., 0] + pe[..., 1] * q[..., 1]).reshape(*q_shape).type_as(v) + k = (pe[..., 0] * k[..., 0] + pe[..., 1] * k[..., 1]).reshape(*k_shape).type_as(v) heads = q.shape[1] x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask) diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py index cc34f7585..ef4ba4106 100644 --- a/comfy/ldm/flux/model.py +++ b/comfy/ldm/flux/model.py @@ -115,8 +115,11 @@ class Flux(nn.Module): vec = vec + self.vector_in(y[:,:self.params.vec_in_dim]) txt = self.txt_in(txt) - ids = torch.cat((txt_ids, img_ids), dim=1) - pe = self.pe_embedder(ids) + if img_ids is not None: + ids = torch.cat((txt_ids, img_ids), dim=1) + pe = self.pe_embedder(ids) + else: + pe = None blocks_replace = patches_replace.get("dit", {}) for i, block in enumerate(self.double_blocks): diff --git a/comfy/ldm/hunyuan3d/model.py b/comfy/ldm/hunyuan3d/model.py new file mode 100644 index 000000000..4e18358f0 --- /dev/null +++ b/comfy/ldm/hunyuan3d/model.py @@ -0,0 +1,135 @@ +import torch +from torch import nn +from comfy.ldm.flux.layers import ( + DoubleStreamBlock, + LastLayer, + MLPEmbedder, + SingleStreamBlock, + timestep_embedding, +) + + +class Hunyuan3Dv2(nn.Module): + def __init__( + self, + in_channels=64, + context_in_dim=1536, + hidden_size=1024, + mlp_ratio=4.0, + num_heads=16, + depth=16, + depth_single_blocks=32, + qkv_bias=True, + guidance_embed=False, + image_model=None, + dtype=None, + device=None, + operations=None + ): + super().__init__() + self.dtype = dtype + + if hidden_size % num_heads != 0: + raise ValueError( + f"Hidden size {hidden_size} must be 
divisible by num_heads {num_heads}" + ) + + self.max_period = 1000 # While reimplementing the model I noticed that they messed up. This 1000 value was meant to be the time_factor but they set the max_period instead + self.latent_in = operations.Linear(in_channels, hidden_size, bias=True, dtype=dtype, device=device) + self.time_in = MLPEmbedder(in_dim=256, hidden_dim=hidden_size, dtype=dtype, device=device, operations=operations) + self.guidance_in = ( + MLPEmbedder(in_dim=256, hidden_dim=hidden_size, dtype=dtype, device=device, operations=operations) if guidance_embed else None + ) + self.cond_in = operations.Linear(context_in_dim, hidden_size, dtype=dtype, device=device) + self.double_blocks = nn.ModuleList( + [ + DoubleStreamBlock( + hidden_size, + num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + dtype=dtype, device=device, operations=operations + ) + for _ in range(depth) + ] + ) + self.single_blocks = nn.ModuleList( + [ + SingleStreamBlock( + hidden_size, + num_heads, + mlp_ratio=mlp_ratio, + dtype=dtype, device=device, operations=operations + ) + for _ in range(depth_single_blocks) + ] + ) + self.final_layer = LastLayer(hidden_size, 1, in_channels, dtype=dtype, device=device, operations=operations) + + def forward(self, x, timestep, context, guidance=None, transformer_options={}, **kwargs): + x = x.movedim(-1, -2) + timestep = 1.0 - timestep + txt = context + img = self.latent_in(x) + + vec = self.time_in(timestep_embedding(timestep, 256, self.max_period).to(dtype=img.dtype)) + if self.guidance_in is not None: + if guidance is not None: + vec = vec + self.guidance_in(timestep_embedding(guidance, 256, self.max_period).to(img.dtype)) + + txt = self.cond_in(txt) + pe = None + attn_mask = None + + patches_replace = transformer_options.get("patches_replace", {}) + blocks_replace = patches_replace.get("dit", {}) + for i, block in enumerate(self.double_blocks): + if ("double_block", i) in blocks_replace: + def block_wrap(args): + out = {} + out["img"], out["txt"] = block(img=args["img"], + txt=args["txt"], + vec=args["vec"], + pe=args["pe"], + attn_mask=args.get("attn_mask")) + return out + + out = blocks_replace[("double_block", i)]({"img": img, + "txt": txt, + "vec": vec, + "pe": pe, + "attn_mask": attn_mask}, + {"original_block": block_wrap}) + txt = out["txt"] + img = out["img"] + else: + img, txt = block(img=img, + txt=txt, + vec=vec, + pe=pe, + attn_mask=attn_mask) + + img = torch.cat((txt, img), 1) + + for i, block in enumerate(self.single_blocks): + if ("single_block", i) in blocks_replace: + def block_wrap(args): + out = {} + out["img"] = block(args["img"], + vec=args["vec"], + pe=args["pe"], + attn_mask=args.get("attn_mask")) + return out + + out = blocks_replace[("single_block", i)]({"img": img, + "vec": vec, + "pe": pe, + "attn_mask": attn_mask}, + {"original_block": block_wrap}) + img = out["img"] + else: + img = block(img, vec=vec, pe=pe, attn_mask=attn_mask) + + img = img[:, txt.shape[1]:, ...] + img = self.final_layer(img, vec) + return img.movedim(-2, -1) * (-1.0) diff --git a/comfy/ldm/hunyuan3d/vae.py b/comfy/ldm/hunyuan3d/vae.py new file mode 100644 index 000000000..5eb2c6548 --- /dev/null +++ b/comfy/ldm/hunyuan3d/vae.py @@ -0,0 +1,587 @@ +# Original: https://github.com/Tencent/Hunyuan3D-2/blob/main/hy3dgen/shapegen/models/autoencoders/model.py +# Since the header on their VAE source file was a bit confusing we asked for permission to use this code from tencent under the GPL license used in ComfyUI. 
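+#
+# Rough shape of the decode path implemented below (ShapeVAE.decode):
+#   latents -> post_kl linear -> Transformer (a stack of self-attention ResidualAttentionBlocks)
+#   -> VanillaVolumeDecoder, which samples a dense (octree_resolution + 1)^3 grid of xyz points
+#      inside `bounds` and evaluates the cross-attention geo_decoder on them in chunks of
+#      `num_chunks` query points
+#   -> a dense grid of occupancy logits for downstream mesh extraction.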
+ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +from typing import Union, Tuple, List, Callable, Optional + +import numpy as np +from einops import repeat, rearrange +from tqdm import tqdm +import logging + +import comfy.ops +ops = comfy.ops.disable_weight_init + +def generate_dense_grid_points( + bbox_min: np.ndarray, + bbox_max: np.ndarray, + octree_resolution: int, + indexing: str = "ij", +): + length = bbox_max - bbox_min + num_cells = octree_resolution + + x = np.linspace(bbox_min[0], bbox_max[0], int(num_cells) + 1, dtype=np.float32) + y = np.linspace(bbox_min[1], bbox_max[1], int(num_cells) + 1, dtype=np.float32) + z = np.linspace(bbox_min[2], bbox_max[2], int(num_cells) + 1, dtype=np.float32) + [xs, ys, zs] = np.meshgrid(x, y, z, indexing=indexing) + xyz = np.stack((xs, ys, zs), axis=-1) + grid_size = [int(num_cells) + 1, int(num_cells) + 1, int(num_cells) + 1] + + return xyz, grid_size, length + + +class VanillaVolumeDecoder: + @torch.no_grad() + def __call__( + self, + latents: torch.FloatTensor, + geo_decoder: Callable, + bounds: Union[Tuple[float], List[float], float] = 1.01, + num_chunks: int = 10000, + octree_resolution: int = None, + enable_pbar: bool = True, + **kwargs, + ): + device = latents.device + dtype = latents.dtype + batch_size = latents.shape[0] + + # 1. generate query points + if isinstance(bounds, float): + bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds] + + bbox_min, bbox_max = np.array(bounds[0:3]), np.array(bounds[3:6]) + xyz_samples, grid_size, length = generate_dense_grid_points( + bbox_min=bbox_min, + bbox_max=bbox_max, + octree_resolution=octree_resolution, + indexing="ij" + ) + xyz_samples = torch.from_numpy(xyz_samples).to(device, dtype=dtype).contiguous().reshape(-1, 3) + + # 2. latents to 3d volume + batch_logits = [] + for start in tqdm(range(0, xyz_samples.shape[0], num_chunks), desc="Volume Decoding", + disable=not enable_pbar): + chunk_queries = xyz_samples[start: start + num_chunks, :] + chunk_queries = repeat(chunk_queries, "p c -> b p c", b=batch_size) + logits = geo_decoder(queries=chunk_queries, latents=latents) + batch_logits.append(logits) + + grid_logits = torch.cat(batch_logits, dim=1) + grid_logits = grid_logits.view((batch_size, *grid_size)).float() + + return grid_logits + + +class FourierEmbedder(nn.Module): + """The sin/cosine positional embedding. Given an input tensor `x` of shape [n_batch, ..., c_dim], it converts + each feature dimension of `x[..., i]` into: + [ + sin(x[..., i]), + sin(f_1*x[..., i]), + sin(f_2*x[..., i]), + ... + sin(f_N * x[..., i]), + cos(x[..., i]), + cos(f_1*x[..., i]), + cos(f_2*x[..., i]), + ... + cos(f_N * x[..., i]), + x[..., i] # only present if include_input is True. + ], here f_i is the frequency. + + Denote the space is [0 / num_freqs, 1 / num_freqs, 2 / num_freqs, 3 / num_freqs, ..., (num_freqs - 1) / num_freqs]. + If logspace is True, then the frequency f_i is [2^(0 / num_freqs), ..., 2^(i / num_freqs), ...]; + Otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1)]. + + Args: + num_freqs (int): the number of frequencies, default is 6; + logspace (bool): If logspace is True, then the frequency f_i is [..., 2^(i / num_freqs), ...], + otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1)]; + input_dim (int): the input dimension, default is 3; + include_input (bool): include the input tensor or not, default is True. 
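+
+    Example:
+        With the defaults used by ShapeVAE in this file (num_freqs=8, input_dim=3,
+        include_input=True), out_dim = 3 * (2 * 8 + 1) = 51, i.e. each xyz query point
+        is lifted to a 51-dim feature before CrossAttentionDecoder.query_proj projects
+        it to `width`.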
+ + Attributes: + frequencies (torch.Tensor): If logspace is True, then the frequency f_i is [..., 2^(i / num_freqs), ...], + otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1); + + out_dim (int): the embedding size, if include_input is True, it is input_dim * (num_freqs * 2 + 1), + otherwise, it is input_dim * num_freqs * 2. + + """ + + def __init__(self, + num_freqs: int = 6, + logspace: bool = True, + input_dim: int = 3, + include_input: bool = True, + include_pi: bool = True) -> None: + + """The initialization""" + + super().__init__() + + if logspace: + frequencies = 2.0 ** torch.arange( + num_freqs, + dtype=torch.float32 + ) + else: + frequencies = torch.linspace( + 1.0, + 2.0 ** (num_freqs - 1), + num_freqs, + dtype=torch.float32 + ) + + if include_pi: + frequencies *= torch.pi + + self.register_buffer("frequencies", frequencies, persistent=False) + self.include_input = include_input + self.num_freqs = num_freqs + + self.out_dim = self.get_dims(input_dim) + + def get_dims(self, input_dim): + temp = 1 if self.include_input or self.num_freqs == 0 else 0 + out_dim = input_dim * (self.num_freqs * 2 + temp) + + return out_dim + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ Forward process. + + Args: + x: tensor of shape [..., dim] + + Returns: + embedding: an embedding of `x` of shape [..., dim * (num_freqs * 2 + temp)] + where temp is 1 if include_input is True and 0 otherwise. + """ + + if self.num_freqs > 0: + embed = (x[..., None].contiguous() * self.frequencies.to(device=x.device, dtype=x.dtype)).view(*x.shape[:-1], -1) + if self.include_input: + return torch.cat((x, embed.sin(), embed.cos()), dim=-1) + else: + return torch.cat((embed.sin(), embed.cos()), dim=-1) + else: + return x + + +class CrossAttentionProcessor: + def __call__(self, attn, q, k, v): + out = F.scaled_dot_product_attention(q, k, v) + return out + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + """ + + def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + self.scale_by_keep = scale_by_keep + + def forward(self, x): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for + changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use + 'survival rate' as the argument. + + """ + if self.drop_prob == 0. 
or not self.training: + return x + keep_prob = 1 - self.drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0 and self.scale_by_keep: + random_tensor.div_(keep_prob) + return x * random_tensor + + def extra_repr(self): + return f'drop_prob={round(self.drop_prob, 3):0.3f}' + + +class MLP(nn.Module): + def __init__( + self, *, + width: int, + expand_ratio: int = 4, + output_width: int = None, + drop_path_rate: float = 0.0 + ): + super().__init__() + self.width = width + self.c_fc = ops.Linear(width, width * expand_ratio) + self.c_proj = ops.Linear(width * expand_ratio, output_width if output_width is not None else width) + self.gelu = nn.GELU() + self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity() + + def forward(self, x): + return self.drop_path(self.c_proj(self.gelu(self.c_fc(x)))) + + +class QKVMultiheadCrossAttention(nn.Module): + def __init__( + self, + *, + heads: int, + width=None, + qk_norm=False, + norm_layer=ops.LayerNorm + ): + super().__init__() + self.heads = heads + self.q_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity() + self.k_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity() + + self.attn_processor = CrossAttentionProcessor() + + def forward(self, q, kv): + _, n_ctx, _ = q.shape + bs, n_data, width = kv.shape + attn_ch = width // self.heads // 2 + q = q.view(bs, n_ctx, self.heads, -1) + kv = kv.view(bs, n_data, self.heads, -1) + k, v = torch.split(kv, attn_ch, dim=-1) + + q = self.q_norm(q) + k = self.k_norm(k) + q, k, v = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.heads), (q, k, v)) + out = self.attn_processor(self, q, k, v) + out = out.transpose(1, 2).reshape(bs, n_ctx, -1) + return out + + +class MultiheadCrossAttention(nn.Module): + def __init__( + self, + *, + width: int, + heads: int, + qkv_bias: bool = True, + data_width: Optional[int] = None, + norm_layer=ops.LayerNorm, + qk_norm: bool = False, + kv_cache: bool = False, + ): + super().__init__() + self.width = width + self.heads = heads + self.data_width = width if data_width is None else data_width + self.c_q = ops.Linear(width, width, bias=qkv_bias) + self.c_kv = ops.Linear(self.data_width, width * 2, bias=qkv_bias) + self.c_proj = ops.Linear(width, width) + self.attention = QKVMultiheadCrossAttention( + heads=heads, + width=width, + norm_layer=norm_layer, + qk_norm=qk_norm + ) + self.kv_cache = kv_cache + self.data = None + + def forward(self, x, data): + x = self.c_q(x) + if self.kv_cache: + if self.data is None: + self.data = self.c_kv(data) + logging.info('Save kv cache,this should be called only once for one mesh') + data = self.data + else: + data = self.c_kv(data) + x = self.attention(x, data) + x = self.c_proj(x) + return x + + +class ResidualCrossAttentionBlock(nn.Module): + def __init__( + self, + *, + width: int, + heads: int, + mlp_expand_ratio: int = 4, + data_width: Optional[int] = None, + qkv_bias: bool = True, + norm_layer=ops.LayerNorm, + qk_norm: bool = False + ): + super().__init__() + + if data_width is None: + data_width = width + + self.attn = MultiheadCrossAttention( + width=width, + heads=heads, + data_width=data_width, + qkv_bias=qkv_bias, + norm_layer=norm_layer, + qk_norm=qk_norm + ) + self.ln_1 = norm_layer(width, elementwise_affine=True, eps=1e-6) + self.ln_2 = norm_layer(data_width, elementwise_affine=True, 
eps=1e-6) + self.ln_3 = norm_layer(width, elementwise_affine=True, eps=1e-6) + self.mlp = MLP(width=width, expand_ratio=mlp_expand_ratio) + + def forward(self, x: torch.Tensor, data: torch.Tensor): + x = x + self.attn(self.ln_1(x), self.ln_2(data)) + x = x + self.mlp(self.ln_3(x)) + return x + + +class QKVMultiheadAttention(nn.Module): + def __init__( + self, + *, + heads: int, + width=None, + qk_norm=False, + norm_layer=ops.LayerNorm + ): + super().__init__() + self.heads = heads + self.q_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity() + self.k_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity() + + def forward(self, qkv): + bs, n_ctx, width = qkv.shape + attn_ch = width // self.heads // 3 + qkv = qkv.view(bs, n_ctx, self.heads, -1) + q, k, v = torch.split(qkv, attn_ch, dim=-1) + + q = self.q_norm(q) + k = self.k_norm(k) + + q, k, v = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.heads), (q, k, v)) + out = F.scaled_dot_product_attention(q, k, v).transpose(1, 2).reshape(bs, n_ctx, -1) + return out + + +class MultiheadAttention(nn.Module): + def __init__( + self, + *, + width: int, + heads: int, + qkv_bias: bool, + norm_layer=ops.LayerNorm, + qk_norm: bool = False, + drop_path_rate: float = 0.0 + ): + super().__init__() + self.width = width + self.heads = heads + self.c_qkv = ops.Linear(width, width * 3, bias=qkv_bias) + self.c_proj = ops.Linear(width, width) + self.attention = QKVMultiheadAttention( + heads=heads, + width=width, + norm_layer=norm_layer, + qk_norm=qk_norm + ) + self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity() + + def forward(self, x): + x = self.c_qkv(x) + x = self.attention(x) + x = self.drop_path(self.c_proj(x)) + return x + + +class ResidualAttentionBlock(nn.Module): + def __init__( + self, + *, + width: int, + heads: int, + qkv_bias: bool = True, + norm_layer=ops.LayerNorm, + qk_norm: bool = False, + drop_path_rate: float = 0.0, + ): + super().__init__() + self.attn = MultiheadAttention( + width=width, + heads=heads, + qkv_bias=qkv_bias, + norm_layer=norm_layer, + qk_norm=qk_norm, + drop_path_rate=drop_path_rate + ) + self.ln_1 = norm_layer(width, elementwise_affine=True, eps=1e-6) + self.mlp = MLP(width=width, drop_path_rate=drop_path_rate) + self.ln_2 = norm_layer(width, elementwise_affine=True, eps=1e-6) + + def forward(self, x: torch.Tensor): + x = x + self.attn(self.ln_1(x)) + x = x + self.mlp(self.ln_2(x)) + return x + + +class Transformer(nn.Module): + def __init__( + self, + *, + width: int, + layers: int, + heads: int, + qkv_bias: bool = True, + norm_layer=ops.LayerNorm, + qk_norm: bool = False, + drop_path_rate: float = 0.0 + ): + super().__init__() + self.width = width + self.layers = layers + self.resblocks = nn.ModuleList( + [ + ResidualAttentionBlock( + width=width, + heads=heads, + qkv_bias=qkv_bias, + norm_layer=norm_layer, + qk_norm=qk_norm, + drop_path_rate=drop_path_rate + ) + for _ in range(layers) + ] + ) + + def forward(self, x: torch.Tensor): + for block in self.resblocks: + x = block(x) + return x + + +class CrossAttentionDecoder(nn.Module): + + def __init__( + self, + *, + out_channels: int, + fourier_embedder: FourierEmbedder, + width: int, + heads: int, + mlp_expand_ratio: int = 4, + downsample_ratio: int = 1, + enable_ln_post: bool = True, + qkv_bias: bool = True, + qk_norm: bool = False, + label_type: str = "binary" + ): + super().__init__() + + self.enable_ln_post = enable_ln_post + 
self.fourier_embedder = fourier_embedder + self.downsample_ratio = downsample_ratio + self.query_proj = ops.Linear(self.fourier_embedder.out_dim, width) + if self.downsample_ratio != 1: + self.latents_proj = ops.Linear(width * downsample_ratio, width) + if self.enable_ln_post == False: + qk_norm = False + self.cross_attn_decoder = ResidualCrossAttentionBlock( + width=width, + mlp_expand_ratio=mlp_expand_ratio, + heads=heads, + qkv_bias=qkv_bias, + qk_norm=qk_norm + ) + + if self.enable_ln_post: + self.ln_post = ops.LayerNorm(width) + self.output_proj = ops.Linear(width, out_channels) + self.label_type = label_type + self.count = 0 + + def forward(self, queries=None, query_embeddings=None, latents=None): + if query_embeddings is None: + query_embeddings = self.query_proj(self.fourier_embedder(queries).to(latents.dtype)) + self.count += query_embeddings.shape[1] + if self.downsample_ratio != 1: + latents = self.latents_proj(latents) + x = self.cross_attn_decoder(query_embeddings, latents) + if self.enable_ln_post: + x = self.ln_post(x) + occ = self.output_proj(x) + return occ + + +class ShapeVAE(nn.Module): + def __init__( + self, + *, + embed_dim: int, + width: int, + heads: int, + num_decoder_layers: int, + geo_decoder_downsample_ratio: int = 1, + geo_decoder_mlp_expand_ratio: int = 4, + geo_decoder_ln_post: bool = True, + num_freqs: int = 8, + include_pi: bool = True, + qkv_bias: bool = True, + qk_norm: bool = False, + label_type: str = "binary", + drop_path_rate: float = 0.0, + scale_factor: float = 1.0, + ): + super().__init__() + self.geo_decoder_ln_post = geo_decoder_ln_post + + self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi) + + self.post_kl = ops.Linear(embed_dim, width) + + self.transformer = Transformer( + width=width, + layers=num_decoder_layers, + heads=heads, + qkv_bias=qkv_bias, + qk_norm=qk_norm, + drop_path_rate=drop_path_rate + ) + + self.geo_decoder = CrossAttentionDecoder( + fourier_embedder=self.fourier_embedder, + out_channels=1, + mlp_expand_ratio=geo_decoder_mlp_expand_ratio, + downsample_ratio=geo_decoder_downsample_ratio, + enable_ln_post=self.geo_decoder_ln_post, + width=width // geo_decoder_downsample_ratio, + heads=heads // geo_decoder_downsample_ratio, + qkv_bias=qkv_bias, + qk_norm=qk_norm, + label_type=label_type, + ) + + self.volume_decoder = VanillaVolumeDecoder() + self.scale_factor = scale_factor + + def decode(self, latents, **kwargs): + latents = self.post_kl(latents.movedim(-2, -1)) + latents = self.transformer(latents) + + bounds = kwargs.get("bounds", 1.01) + num_chunks = kwargs.get("num_chunks", 8000) + octree_resolution = kwargs.get("octree_resolution", 256) + enable_pbar = kwargs.get("enable_pbar", True) + + grid_logits = self.volume_decoder(latents, self.geo_decoder, bounds=bounds, num_chunks=num_chunks, octree_resolution=octree_resolution, enable_pbar=enable_pbar) + return grid_logits.movedim(-2, -1) + + def encode(self, x): + return None diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 7908d1313..45f9e311e 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -471,7 +471,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False): if skip_reshape: b, _, _, dim_head = q.shape - tensor_layout="HND" + tensor_layout = "HND" else: b, _, dim_head = q.shape dim_head //= heads @@ -479,7 +479,7 @@ def attention_sage(q, 
k, v, heads, mask=None, attn_precision=None, skip_reshape= lambda t: t.view(b, -1, heads, dim_head), (q, k, v), ) - tensor_layout="NHD" + tensor_layout = "NHD" if mask is not None: # add a batch dimension if there isn't already one @@ -489,7 +489,17 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape= if mask.ndim == 3: mask = mask.unsqueeze(1) - out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout) + try: + out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout) + except Exception as e: + logging.error("Error running sage attention: {}, using pytorch attention instead.".format(e)) + if tensor_layout == "NHD": + q, k, v = map( + lambda t: t.transpose(1, 2), + (q, k, v), + ) + return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=True, skip_output_reshape=skip_output_reshape) + if tensor_layout == "HND": if not skip_output_reshape: out = ( @@ -837,6 +847,7 @@ class SpatialTransformer(nn.Module): if not isinstance(context, list): context = [context] * len(self.transformer_blocks) b, c, h, w = x.shape + transformer_options["activations_shape"] = list(x.shape) x_in = x x = self.norm(x) if not self.use_linear: @@ -952,6 +963,7 @@ class SpatialVideoTransformer(SpatialTransformer): transformer_options={} ) -> torch.Tensor: _, _, h, w = x.shape + transformer_options["activations_shape"] = list(x.shape) x_in = x spatial_context = None if exists(context): diff --git a/comfy/lora_convert.py b/comfy/lora_convert.py index 05032c690..3e00b63db 100644 --- a/comfy/lora_convert.py +++ b/comfy/lora_convert.py @@ -1,4 +1,5 @@ import torch +import comfy.utils def convert_lora_bfl_control(sd): #BFL loras for Flux @@ -11,7 +12,13 @@ def convert_lora_bfl_control(sd): #BFL loras for Flux return sd_out +def convert_lora_wan_fun(sd): #Wan Fun loras + return comfy.utils.state_dict_prefix_replace(sd, {"lora_unet__": "lora_unet_"}) + + def convert_lora(sd): if "img_in.lora_A.weight" in sd and "single_blocks.0.norm.key_norm.scale" in sd: return convert_lora_bfl_control(sd) + if "lora_unet__blocks_0_cross_attn_k.lora_down.weight" in sd: + return convert_lora_wan_fun(sd) return sd diff --git a/comfy/model_base.py b/comfy/model_base.py index 976702b60..6bc627ae3 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -36,6 +36,7 @@ import comfy.ldm.hunyuan_video.model import comfy.ldm.cosmos.model import comfy.ldm.lumina.model import comfy.ldm.wan.model +import comfy.ldm.hunyuan3d.model import comfy.model_management import comfy.patcher_extension @@ -58,6 +59,7 @@ class ModelType(Enum): FLOW = 6 V_PREDICTION_CONTINUOUS = 7 FLUX = 8 + IMG_TO_IMG = 9 from comfy.model_sampling import EPS, V_PREDICTION, EDM, ModelSamplingDiscrete, ModelSamplingContinuousEDM, StableCascadeSampling, ModelSamplingContinuousV @@ -88,6 +90,8 @@ def model_sampling(model_config, model_type): elif model_type == ModelType.FLUX: c = comfy.model_sampling.CONST s = comfy.model_sampling.ModelSamplingFlux + elif model_type == ModelType.IMG_TO_IMG: + c = comfy.model_sampling.IMG_TO_IMG class ModelSampling(s, c): pass @@ -139,6 +143,7 @@ class BaseModel(torch.nn.Module): def _apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs): sigma = t xc = self.model_sampling.calculate_input(sigma, x) + if c_concat is not None: xc = torch.cat([xc] + [c_concat], dim=1) @@ -600,6 +605,19 @@ class SDXL_instructpix2pix(IP2P, SDXL): else: self.process_ip2p_image_in = lambda image: image #diffusers ip2p +class 
Lotus(BaseModel): + def extra_conds(self, **kwargs): + out = {} + cross_attn = kwargs.get("cross_attn", None) + out['c_crossattn'] = comfy.conds.CONDCrossAttn(cross_attn) + device = kwargs["device"] + task_emb = torch.tensor([1, 0]).float().to(device) + task_emb = torch.cat([torch.sin(task_emb), torch.cos(task_emb)]).unsqueeze(0) + out['y'] = comfy.conds.CONDRegular(task_emb) + return out + + def __init__(self, model_config, model_type=ModelType.IMG_TO_IMG, device=None): + super().__init__(model_config, model_type, device=device) class StableCascade_C(BaseModel): def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None): @@ -974,31 +992,41 @@ class WAN21(BaseModel): def concat_cond(self, **kwargs): noise = kwargs.get("noise", None) - if self.diffusion_model.patch_embedding.weight.shape[1] == noise.shape[1]: + extra_channels = self.diffusion_model.patch_embedding.weight.shape[1] - noise.shape[1] + if extra_channels == 0: return None image = kwargs.get("concat_latent_image", None) device = kwargs["device"] if image is None: - image = torch.zeros_like(noise) + shape_image = list(noise.shape) + shape_image[1] = extra_channels + image = torch.zeros(shape_image, dtype=noise.dtype, layout=noise.layout, device=noise.device) + else: + image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center") + for i in range(0, image.shape[1], 16): + image[:, i: i + 16] = self.process_latent_in(image[:, i: i + 16]) + image = utils.resize_to_batch_size(image, noise.shape[0]) - image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center") - image = self.process_latent_in(image) - image = utils.resize_to_batch_size(image, noise.shape[0]) - - if not self.image_to_video: + if not self.image_to_video or extra_channels == image.shape[1]: return image + if image.shape[1] > (extra_channels - 4): + image = image[:, :(extra_channels - 4)] + mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None)) if mask is None: mask = torch.zeros_like(noise)[:, :4] else: - mask = 1.0 - torch.mean(mask, dim=1, keepdim=True) + if mask.shape[1] != 4: + mask = torch.mean(mask, dim=1, keepdim=True) + mask = 1.0 - mask mask = utils.common_upscale(mask.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center") if mask.shape[-3] < noise.shape[-3]: mask = torch.nn.functional.pad(mask, (0, 0, 0, 0, 0, noise.shape[-3] - mask.shape[-3]), mode='constant', value=0) - mask = mask.repeat(1, 4, 1, 1, 1) + if mask.shape[1] == 1: + mask = mask.repeat(1, 4, 1, 1, 1) mask = utils.resize_to_batch_size(mask, noise.shape[0]) return torch.cat((mask, image), dim=1) @@ -1013,3 +1041,18 @@ class WAN21(BaseModel): if clip_vision_output is not None: out['clip_fea'] = comfy.conds.CONDRegular(clip_vision_output.penultimate_hidden_states) return out + +class Hunyuan3Dv2(BaseModel): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): + super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan3d.model.Hunyuan3Dv2) + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + cross_attn = kwargs.get("cross_attn", None) + if cross_attn is not None: + out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) + + guidance = kwargs.get("guidance", 5.0) + if guidance is not None: + out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance])) + return out diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 403da5855..4217f5831 100644 --- 
a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -154,7 +154,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["guidance_embed"] = len(guidance_keys) > 0 return dit_config - if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys: #Flux + if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and '{}img_in.weight'.format(key_prefix) in state_dict_keys: #Flux dit_config = {} dit_config["image_model"] = "flux" dit_config["in_channels"] = 16 @@ -323,6 +323,21 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["model_type"] = "t2v" return dit_config + if '{}latent_in.weight'.format(key_prefix) in state_dict_keys: # Hunyuan 3D + in_shape = state_dict['{}latent_in.weight'.format(key_prefix)].shape + dit_config = {} + dit_config["image_model"] = "hunyuan3d2" + dit_config["in_channels"] = in_shape[1] + dit_config["context_in_dim"] = state_dict['{}cond_in.weight'.format(key_prefix)].shape[1] + dit_config["hidden_size"] = in_shape[0] + dit_config["mlp_ratio"] = 4.0 + dit_config["num_heads"] = 16 + dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.') + dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.') + dit_config["qkv_bias"] = True + dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys + return dit_config + if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys: return None @@ -667,8 +682,13 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None): 'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0], 'use_temporal_attention': False, 'use_temporal_resblock': False} + LotusD = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'adm_in_channels': 4, + 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0], + 'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024, 'num_heads': 8, + 'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0], + 'use_temporal_attention': False, 'use_temporal_resblock': False} - supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega, KOALA_700M, KOALA_1B, SD09_XS, SD_XS, SDXL_diffusers_ip2p, SD15_diffusers_inpaint] + supported_models = [LotusD, SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega, KOALA_700M, KOALA_1B, SD09_XS, SD_XS, SDXL_diffusers_ip2p, SD15_diffusers_inpaint] for unet_config in supported_models: matches = True diff --git a/comfy/model_management.py b/comfy/model_management.py index 2a9b022be..19e6c8dff 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -46,6 +46,32 @@ cpu_state = CPUState.GPU total_vram = 0 +def get_supported_float8_types(): + float8_types = [] + try: + float8_types.append(torch.float8_e4m3fn) + except: + pass + try: + float8_types.append(torch.float8_e4m3fnuz) + except: + pass + try: + float8_types.append(torch.float8_e5m2) + except: + pass + try: + float8_types.append(torch.float8_e5m2fnuz) + except: + pass + try: + float8_types.append(torch.float8_e8m0fnu) + except: + pass + return float8_types + 
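+# Probed once at import time: older torch builds do not expose all of these float8
+# dtypes, so each attribute is checked in a try/except and missing ones are skipped.
+# unet_dtype() below then tests weight dtypes against this cached list instead of
+# wrapping each comparison in its own try/except.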
+FLOAT8_TYPES = get_supported_float8_types() + xpu_available = False torch_version = "" try: @@ -701,11 +727,8 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, tor return torch.float8_e5m2 fp8_dtype = None - try: - if weight_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]: - fp8_dtype = weight_dtype - except: - pass + if weight_dtype in FLOAT8_TYPES: + fp8_dtype = weight_dtype if fp8_dtype is not None: if supports_fp8_compute(device): #if fp8 compute is supported the casting is most likely not expensive @@ -800,6 +823,8 @@ def text_encoder_dtype(device=None): return torch.float8_e5m2 elif args.fp16_text_enc: return torch.float16 + elif args.bf16_text_enc: + return torch.bfloat16 elif args.fp32_text_enc: return torch.float32 @@ -1212,6 +1237,8 @@ def soft_empty_cache(force=False): torch.xpu.empty_cache() elif is_ascend_npu(): torch.npu.empty_cache() + elif is_mlu(): + torch.mlu.empty_cache() elif torch.cuda.is_available(): torch.cuda.empty_cache() torch.cuda.ipc_collect() diff --git a/comfy/model_sampling.py b/comfy/model_sampling.py index ff27b09a8..b79af1e92 100644 --- a/comfy/model_sampling.py +++ b/comfy/model_sampling.py @@ -69,6 +69,15 @@ class CONST: sigma = sigma.view(sigma.shape[:1] + (1,) * (latent.ndim - 1)) return latent / (1.0 - sigma) +class X0(EPS): + def calculate_denoised(self, sigma, model_output, model_input): + return model_output + +class IMG_TO_IMG(X0): + def calculate_input(self, sigma, noise): + return noise + + class ModelSamplingDiscrete(torch.nn.Module): def __init__(self, model_config=None, zsnr=None): super().__init__() diff --git a/comfy/sd.py b/comfy/sd.py index 3d72a04d6..4d3aef3e1 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -14,6 +14,7 @@ import comfy.ldm.genmo.vae.model import comfy.ldm.lightricks.vae.causal_video_autoencoder import comfy.ldm.cosmos.vae import comfy.ldm.wan.vae +import comfy.ldm.hunyuan3d.vae import yaml import math @@ -264,6 +265,7 @@ class VAE: self.process_input = lambda image: image * 2.0 - 1.0 self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0) self.working_dtypes = [torch.bfloat16, torch.float32] + self.disable_offload = False self.downscale_index_formula = None self.upscale_index_formula = None @@ -336,6 +338,7 @@ class VAE: self.process_output = lambda audio: audio self.process_input = lambda audio: audio self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32] + self.disable_offload = True elif "blocks.2.blocks.3.stack.5.weight" in sd or "decoder.blocks.2.blocks.3.stack.5.weight" in sd or "layers.4.layers.1.attn_block.attn.qkv.weight" in sd or "encoder.layers.4.layers.1.attn_block.attn.qkv.weight" in sd: #genmo mochi vae if "blocks.2.blocks.3.stack.5.weight" in sd: sd = comfy.utils.state_dict_prefix_replace(sd, {"": "decoder."}) @@ -412,6 +415,17 @@ class VAE: self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32] self.memory_used_encode = lambda shape, dtype: 6000 * shape[3] * shape[4] * model_management.dtype_size(dtype) self.memory_used_decode = lambda shape, dtype: 7000 * shape[3] * shape[4] * (8 * 8) * model_management.dtype_size(dtype) + elif "geo_decoder.cross_attn_decoder.ln_1.bias" in sd: + self.latent_dim = 1 + ln_post = "geo_decoder.ln_post.weight" in sd + inner_size = sd["geo_decoder.output_proj.weight"].shape[1] + downsample_ratio = sd["post_kl.weight"].shape[0] // inner_size + mlp_expand = sd["geo_decoder.cross_attn_decoder.mlp.c_fc.weight"].shape[0] // inner_size + self.memory_used_encode = lambda shape, dtype: (1000 * 
shape[2]) * model_management.dtype_size(dtype) # TODO + self.memory_used_decode = lambda shape, dtype: (1024 * 1024 * 1024 * 2.0) * model_management.dtype_size(dtype) # TODO + ddconfig = {"embed_dim": 64, "num_freqs": 8, "include_pi": False, "heads": 16, "width": 1024, "num_decoder_layers": 16, "qkv_bias": False, "qk_norm": True, "geo_decoder_mlp_expand_ratio": mlp_expand, "geo_decoder_downsample_ratio": downsample_ratio, "geo_decoder_ln_post": ln_post} + self.first_stage_model = comfy.ldm.hunyuan3d.vae.ShapeVAE(**ddconfig) + self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32] else: logging.warning("WARNING: No VAE weights detected, VAE not initalized.") self.first_stage_model = None @@ -498,19 +512,19 @@ class VAE: encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float() return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, index_formulas=self.downscale_index_formula, output_device=self.output_device) - def decode(self, samples_in): + def decode(self, samples_in, vae_options={}): self.throw_exception_if_invalid() pixel_samples = None try: memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype) - model_management.load_models_gpu([self.patcher], memory_required=memory_used) + model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload) free_memory = model_management.get_free_memory(self.device) batch_number = int(free_memory / memory_used) batch_number = max(1, batch_number) for x in range(0, samples_in.shape[0], batch_number): samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device) - out = self.process_output(self.first_stage_model.decode(samples).to(self.output_device).float()) + out = self.process_output(self.first_stage_model.decode(samples, **vae_options).to(self.output_device).float()) if pixel_samples is None: pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device) pixel_samples[x:x+batch_number] = out @@ -532,7 +546,7 @@ class VAE: def decode_tiled(self, samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None): self.throw_exception_if_invalid() memory_used = self.memory_used_decode(samples.shape, self.vae_dtype) #TODO: calculate mem required for tile - model_management.load_models_gpu([self.patcher], memory_required=memory_used) + model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload) dims = samples.ndim - 2 args = {} if tile_x is not None: @@ -566,7 +580,7 @@ class VAE: pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0) try: memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype) - model_management.load_models_gpu([self.patcher], memory_required=memory_used) + model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload) free_memory = model_management.get_free_memory(self.device) batch_number = int(free_memory / max(1, memory_used)) batch_number = max(1, batch_number) @@ -600,7 +614,7 @@ class VAE: pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0) memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype) # TODO: calculate mem required for tile - model_management.load_models_gpu([self.patcher], memory_required=memory_used) + 
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload) args = {} if tile_x is not None: diff --git a/comfy/supported_models.py b/comfy/supported_models.py index b4d7bfe20..2a6a61560 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -506,6 +506,22 @@ class SDXL_instructpix2pix(SDXL): def get_model(self, state_dict, prefix="", device=None): return model_base.SDXL_instructpix2pix(self, model_type=self.model_type(state_dict, prefix), device=device) +class LotusD(SD20): + unet_config = { + "model_channels": 320, + "use_linear_in_transformer": True, + "use_temporal_attention": False, + "adm_in_channels": 4, + "in_channels": 4, + } + + unet_extra_config = { + "num_classes": 'sequential' + } + + def get_model(self, state_dict, prefix="", device=None): + return model_base.Lotus(self, device=device) + class SD3(supported_models_base.BASE): unet_config = { "in_channels": 16, @@ -953,12 +969,62 @@ class WAN21_I2V(WAN21_T2V): unet_config = { "image_model": "wan2.1", "model_type": "i2v", + "in_dim": 36, } def get_model(self, state_dict, prefix="", device=None): out = model_base.WAN21(self, image_to_video=True, device=device) return out -models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V] +class WAN21_FunControl2V(WAN21_T2V): + unet_config = { + "image_model": "wan2.1", + "model_type": "i2v", + "in_dim": 48, + } + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.WAN21(self, image_to_video=False, device=device) + return out + +class Hunyuan3Dv2(supported_models_base.BASE): + unet_config = { + "image_model": "hunyuan3d2", + } + + unet_extra_config = {} + + sampling_settings = { + "multiplier": 1.0, + "shift": 1.0, + } + + memory_usage_factor = 3.5 + + clip_vision_prefix = "conditioner.main_image_encoder.model." 
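+    # Checkpoint keys under this prefix hold the Hunyuan3D shape VAE weights.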
+ vae_key_prefix = ["vae."] + + latent_format = latent_formats.Hunyuan3Dv2 + + def process_unet_state_dict_for_saving(self, state_dict): + replace_prefix = {"": "model."} + return utils.state_dict_prefix_replace(state_dict, replace_prefix) + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.Hunyuan3Dv2(self, device=device) + return out + + def clip_target(self, state_dict={}): + return None + +class Hunyuan3Dv2mini(Hunyuan3Dv2): + unet_config = { + "image_model": "hunyuan3d2", + "depth": 8, + } + + latent_format = latent_formats.Hunyuan3Dv2mini + +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, Hunyuan3Dv2mini, Hunyuan3Dv2] models += [SVD_img2vid] diff --git a/comfy_extras/nodes_cfg.py b/comfy_extras/nodes_cfg.py new file mode 100644 index 000000000..1fb686644 --- /dev/null +++ b/comfy_extras/nodes_cfg.py @@ -0,0 +1,45 @@ +import torch + +# https://github.com/WeichenFan/CFG-Zero-star +def optimized_scale(positive, negative): + positive_flat = positive.reshape(positive.shape[0], -1) + negative_flat = negative.reshape(negative.shape[0], -1) + + # Calculate dot production + dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True) + + # Squared norm of uncondition + squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8 + + # st_star = v_cond^T * v_uncond / ||v_uncond||^2 + st_star = dot_product / squared_norm + + return st_star.reshape([positive.shape[0]] + [1] * (positive.ndim - 1)) + +class CFGZeroStar: + @classmethod + def INPUT_TYPES(s): + return {"required": {"model": ("MODEL",), + }} + RETURN_TYPES = ("MODEL",) + RETURN_NAMES = ("patched_model",) + FUNCTION = "patch" + CATEGORY = "advanced/guidance" + + def patch(self, model): + m = model.clone() + def cfg_zero_star(args): + guidance_scale = args['cond_scale'] + x = args['input'] + cond_p = args['cond_denoised'] + uncond_p = args['uncond_denoised'] + out = args["denoised"] + alpha = optimized_scale(x - cond_p, x - uncond_p) + + return out + uncond_p * (alpha - 1.0) + guidance_scale * uncond_p * (1.0 - alpha) + m.set_model_sampler_post_cfg_function(cfg_zero_star) + return (m, ) + +NODE_CLASS_MAPPINGS = { + "CFGZeroStar": CFGZeroStar +} diff --git a/comfy_extras/nodes_hunyuan3d.py b/comfy_extras/nodes_hunyuan3d.py new file mode 100644 index 000000000..5adc6b654 --- /dev/null +++ b/comfy_extras/nodes_hunyuan3d.py @@ -0,0 +1,415 @@ +import torch +import os +import json +import struct +import numpy as np +from comfy.ldm.modules.diffusionmodules.mmdit import get_1d_sincos_pos_embed_from_grid_torch +import folder_paths +import comfy.model_management +from comfy.cli_args import args + + +class EmptyLatentHunyuan3Dv2: + @classmethod + def INPUT_TYPES(s): + return {"required": {"resolution": ("INT", {"default": 3072, "min": 1, "max": 8192}), + "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "The number of latent images in the batch."}), + }} + RETURN_TYPES = ("LATENT",) + FUNCTION = "generate" + + CATEGORY = "latent/3d" + + def generate(self, resolution, batch_size): + latent = torch.zeros([batch_size, 64, 
resolution], device=comfy.model_management.intermediate_device()) + return ({"samples": latent, "type": "hunyuan3dv2"}, ) + + +class Hunyuan3Dv2Conditioning: + @classmethod + def INPUT_TYPES(s): + return {"required": {"clip_vision_output": ("CLIP_VISION_OUTPUT",), + }} + + RETURN_TYPES = ("CONDITIONING", "CONDITIONING") + RETURN_NAMES = ("positive", "negative") + + FUNCTION = "encode" + + CATEGORY = "conditioning/video_models" + + def encode(self, clip_vision_output): + embeds = clip_vision_output.last_hidden_state + positive = [[embeds, {}]] + negative = [[torch.zeros_like(embeds), {}]] + return (positive, negative) + + +class Hunyuan3Dv2ConditioningMultiView: + @classmethod + def INPUT_TYPES(s): + return {"required": {}, + "optional": {"front": ("CLIP_VISION_OUTPUT",), + "left": ("CLIP_VISION_OUTPUT",), + "back": ("CLIP_VISION_OUTPUT",), + "right": ("CLIP_VISION_OUTPUT",), }} + + RETURN_TYPES = ("CONDITIONING", "CONDITIONING") + RETURN_NAMES = ("positive", "negative") + + FUNCTION = "encode" + + CATEGORY = "conditioning/video_models" + + def encode(self, front=None, left=None, back=None, right=None): + all_embeds = [front, left, back, right] + out = [] + pos_embeds = None + for i, e in enumerate(all_embeds): + if e is not None: + if pos_embeds is None: + pos_embeds = get_1d_sincos_pos_embed_from_grid_torch(e.last_hidden_state.shape[-1], torch.arange(4)) + out.append(e.last_hidden_state + pos_embeds[i].reshape(1, 1, -1)) + + embeds = torch.cat(out, dim=1) + positive = [[embeds, {}]] + negative = [[torch.zeros_like(embeds), {}]] + return (positive, negative) + + +class VOXEL: + def __init__(self, data): + self.data = data + + +class VAEDecodeHunyuan3D: + @classmethod + def INPUT_TYPES(s): + return {"required": {"samples": ("LATENT", ), + "vae": ("VAE", ), + "num_chunks": ("INT", {"default": 8000, "min": 1000, "max": 500000}), + "octree_resolution": ("INT", {"default": 256, "min": 16, "max": 512}), + }} + RETURN_TYPES = ("VOXEL",) + FUNCTION = "decode" + + CATEGORY = "latent/3d" + + def decode(self, vae, samples, num_chunks, octree_resolution): + voxels = VOXEL(vae.decode(samples["samples"], vae_options={"num_chunks": num_chunks, "octree_resolution": octree_resolution})) + return (voxels, ) + + +def voxel_to_mesh(voxels, threshold=0.5, device=None): + if device is None: + device = torch.device("cpu") + voxels = voxels.to(device) + + binary = (voxels > threshold).float() + padded = torch.nn.functional.pad(binary, (1, 1, 1, 1, 1, 1), 'constant', 0) + + D, H, W = binary.shape + + neighbors = torch.tensor([ + [0, 0, 1], + [0, 0, -1], + [0, 1, 0], + [0, -1, 0], + [1, 0, 0], + [-1, 0, 0] + ], device=device) + + z, y, x = torch.meshgrid( + torch.arange(D, device=device), + torch.arange(H, device=device), + torch.arange(W, device=device), + indexing='ij' + ) + voxel_indices = torch.stack([z.flatten(), y.flatten(), x.flatten()], dim=1) + + solid_mask = binary.flatten() > 0 + solid_indices = voxel_indices[solid_mask] + + corner_offsets = [ + torch.tensor([ + [0, 0, 1], [0, 1, 1], [1, 1, 1], [1, 0, 1] + ], device=device), + torch.tensor([ + [0, 0, 0], [1, 0, 0], [1, 1, 0], [0, 1, 0] + ], device=device), + torch.tensor([ + [0, 1, 0], [1, 1, 0], [1, 1, 1], [0, 1, 1] + ], device=device), + torch.tensor([ + [0, 0, 0], [0, 0, 1], [1, 0, 1], [1, 0, 0] + ], device=device), + torch.tensor([ + [1, 0, 1], [1, 1, 1], [1, 1, 0], [1, 0, 0] + ], device=device), + torch.tensor([ + [0, 1, 0], [0, 1, 1], [0, 0, 1], [0, 0, 0] + ], device=device) + ] + + all_vertices = [] + all_indices = [] + + vertex_count = 0 + + for 
face_idx, offset in enumerate(neighbors): + neighbor_indices = solid_indices + offset + + padded_indices = neighbor_indices + 1 + + is_exposed = padded[ + padded_indices[:, 0], + padded_indices[:, 1], + padded_indices[:, 2] + ] == 0 + + if not is_exposed.any(): + continue + + exposed_indices = solid_indices[is_exposed] + + corners = corner_offsets[face_idx].unsqueeze(0) + + face_vertices = exposed_indices.unsqueeze(1) + corners + + all_vertices.append(face_vertices.reshape(-1, 3)) + + num_faces = exposed_indices.shape[0] + face_indices = torch.arange( + vertex_count, + vertex_count + 4 * num_faces, + device=device + ).reshape(-1, 4) + + all_indices.append(torch.stack([face_indices[:, 0], face_indices[:, 1], face_indices[:, 2]], dim=1)) + all_indices.append(torch.stack([face_indices[:, 0], face_indices[:, 2], face_indices[:, 3]], dim=1)) + + vertex_count += 4 * num_faces + + if len(all_vertices) > 0: + vertices = torch.cat(all_vertices, dim=0) + faces = torch.cat(all_indices, dim=0) + else: + vertices = torch.zeros((1, 3)) + faces = torch.zeros((1, 3)) + + v_min = 0 + v_max = max(voxels.shape) + + vertices = vertices - (v_min + v_max) / 2 + + scale = (v_max - v_min) / 2 + if scale > 0: + vertices = vertices / scale + + vertices = torch.fliplr(vertices) + return vertices, faces + + +class MESH: + def __init__(self, vertices, faces): + self.vertices = vertices + self.faces = faces + + +class VoxelToMeshBasic: + @classmethod + def INPUT_TYPES(s): + return {"required": {"voxel": ("VOXEL", ), + "threshold": ("FLOAT", {"default": 0.6, "min": -1.0, "max": 1.0, "step": 0.01}), + }} + RETURN_TYPES = ("MESH",) + FUNCTION = "decode" + + CATEGORY = "3d" + + def decode(self, voxel, threshold): + vertices = [] + faces = [] + for x in voxel.data: + v, f = voxel_to_mesh(x, threshold=threshold, device=None) + vertices.append(v) + faces.append(f) + + return (MESH(torch.stack(vertices), torch.stack(faces)), ) + + +def save_glb(vertices, faces, filepath, metadata=None): + """ + Save PyTorch tensor vertices and faces as a GLB file without external dependencies. 
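+    The mesh is written as a single binary glTF 2.0 container: a JSON chunk describing
+    one triangle primitive (a POSITION attribute plus indices) followed by a binary
+    chunk holding the 4-byte-padded vertex and index buffers.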
+ + Parameters: + vertices: torch.Tensor of shape (N, 3) - The vertex coordinates + faces: torch.Tensor of shape (M, 3) - The face indices (triangle faces) + filepath: str - Output filepath (should end with .glb) + """ + + # Convert tensors to numpy arrays + vertices_np = vertices.cpu().numpy().astype(np.float32) + faces_np = faces.cpu().numpy().astype(np.uint32) + + vertices_buffer = vertices_np.tobytes() + indices_buffer = faces_np.tobytes() + + def pad_to_4_bytes(buffer): + padding_length = (4 - (len(buffer) % 4)) % 4 + return buffer + b'\x00' * padding_length + + vertices_buffer_padded = pad_to_4_bytes(vertices_buffer) + indices_buffer_padded = pad_to_4_bytes(indices_buffer) + + buffer_data = vertices_buffer_padded + indices_buffer_padded + + vertices_byte_length = len(vertices_buffer) + vertices_byte_offset = 0 + indices_byte_length = len(indices_buffer) + indices_byte_offset = len(vertices_buffer_padded) + + gltf = { + "asset": {"version": "2.0", "generator": "ComfyUI"}, + "buffers": [ + { + "byteLength": len(buffer_data) + } + ], + "bufferViews": [ + { + "buffer": 0, + "byteOffset": vertices_byte_offset, + "byteLength": vertices_byte_length, + "target": 34962 # ARRAY_BUFFER + }, + { + "buffer": 0, + "byteOffset": indices_byte_offset, + "byteLength": indices_byte_length, + "target": 34963 # ELEMENT_ARRAY_BUFFER + } + ], + "accessors": [ + { + "bufferView": 0, + "byteOffset": 0, + "componentType": 5126, # FLOAT + "count": len(vertices_np), + "type": "VEC3", + "max": vertices_np.max(axis=0).tolist(), + "min": vertices_np.min(axis=0).tolist() + }, + { + "bufferView": 1, + "byteOffset": 0, + "componentType": 5125, # UNSIGNED_INT + "count": faces_np.size, + "type": "SCALAR" + } + ], + "meshes": [ + { + "primitives": [ + { + "attributes": { + "POSITION": 0 + }, + "indices": 1, + "mode": 4 # TRIANGLES + } + ] + } + ], + "nodes": [ + { + "mesh": 0 + } + ], + "scenes": [ + { + "nodes": [0] + } + ], + "scene": 0 + } + + if metadata is not None: + gltf["asset"]["extras"] = metadata + + # Convert the JSON to bytes + gltf_json = json.dumps(gltf).encode('utf8') + + def pad_json_to_4_bytes(buffer): + padding_length = (4 - (len(buffer) % 4)) % 4 + return buffer + b' ' * padding_length + + gltf_json_padded = pad_json_to_4_bytes(gltf_json) + + # Create the GLB header + # Magic glTF + glb_header = struct.pack('<4sII', b'glTF', 2, 12 + 8 + len(gltf_json_padded) + 8 + len(buffer_data)) + + # Create JSON chunk header (chunk type 0) + json_chunk_header = struct.pack(' 0: - output_images = [] - for i in range(image.shape[0]): - output_images.append(preprocess(image[i], img_compression)) + output_images = [] + for i in range(image.shape[0]): + output_images.append(preprocess(image[i], img_compression)) return (torch.stack(output_images),) diff --git a/comfy_extras/nodes_mask.py b/comfy_extras/nodes_mask.py index 63fd13b9a..13d2b4bab 100644 --- a/comfy_extras/nodes_mask.py +++ b/comfy_extras/nodes_mask.py @@ -2,6 +2,7 @@ import numpy as np import scipy.ndimage import torch import comfy.utils +import node_helpers from nodes import MAX_RESOLUTION @@ -87,6 +88,7 @@ class ImageCompositeMasked: CATEGORY = "image" def composite(self, destination, source, x, y, resize_source, mask = None): + destination, source = node_helpers.image_alpha_fix(destination, source) destination = destination.clone().movedim(-1, 1) output = composite(destination, source.movedim(-1, 1), x, y, mask, 1, resize_source).movedim(1, -1) return (output,) diff --git a/comfy_extras/nodes_model_advanced.py b/comfy_extras/nodes_model_advanced.py 
index ceac5654b..71a652ffa 100644 --- a/comfy_extras/nodes_model_advanced.py +++ b/comfy_extras/nodes_model_advanced.py @@ -20,10 +20,6 @@ class LCM(comfy.model_sampling.EPS): return c_out * x0 + c_skip * model_input -class X0(comfy.model_sampling.EPS): - def calculate_denoised(self, sigma, model_output, model_input): - return model_output - class ModelSamplingDiscreteDistilled(comfy.model_sampling.ModelSamplingDiscrete): original_timesteps = 50 @@ -56,7 +52,7 @@ class ModelSamplingDiscrete: @classmethod def INPUT_TYPES(s): return {"required": { "model": ("MODEL",), - "sampling": (["eps", "v_prediction", "lcm", "x0"],), + "sampling": (["eps", "v_prediction", "lcm", "x0", "img_to_img"],), "zsnr": ("BOOLEAN", {"default": False}), }} @@ -77,7 +73,9 @@ class ModelSamplingDiscrete: sampling_type = LCM sampling_base = ModelSamplingDiscreteDistilled elif sampling == "x0": - sampling_type = X0 + sampling_type = comfy.model_sampling.X0 + elif sampling == "img_to_img": + sampling_type = comfy.model_sampling.IMG_TO_IMG class ModelSamplingAdvanced(sampling_base, sampling_type): pass diff --git a/comfy_extras/nodes_model_merging_model_specific.py b/comfy_extras/nodes_model_merging_model_specific.py index 3e37f70d4..dc3411947 100644 --- a/comfy_extras/nodes_model_merging_model_specific.py +++ b/comfy_extras/nodes_model_merging_model_specific.py @@ -244,6 +244,30 @@ class ModelMergeCosmos14B(comfy_extras.nodes_model_merging.ModelMergeBlocks): return {"required": arg_dict} +class ModelMergeWAN2_1(comfy_extras.nodes_model_merging.ModelMergeBlocks): + CATEGORY = "advanced/model_merging/model_specific" + DESCRIPTION = "1.3B model has 30 blocks, 14B model has 40 blocks. Image to video model has the extra img_emb." + + @classmethod + def INPUT_TYPES(s): + arg_dict = { "model1": ("MODEL",), + "model2": ("MODEL",)} + + argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}) + + arg_dict["patch_embedding."] = argument + arg_dict["time_embedding."] = argument + arg_dict["time_projection."] = argument + arg_dict["text_embedding."] = argument + arg_dict["img_emb."] = argument + + for i in range(40): + arg_dict["blocks.{}.".format(i)] = argument + + arg_dict["head."] = argument + + return {"required": arg_dict} + NODE_CLASS_MAPPINGS = { "ModelMergeSD1": ModelMergeSD1, "ModelMergeSD2": ModelMergeSD1, #SD1 and SD2 have the same blocks @@ -256,4 +280,5 @@ NODE_CLASS_MAPPINGS = { "ModelMergeLTXV": ModelMergeLTXV, "ModelMergeCosmos7B": ModelMergeCosmos7B, "ModelMergeCosmos14B": ModelMergeCosmos14B, + "ModelMergeWAN2_1": ModelMergeWAN2_1, } diff --git a/comfy_extras/nodes_morphology.py b/comfy_extras/nodes_morphology.py index b1372b8ce..075b26c40 100644 --- a/comfy_extras/nodes_morphology.py +++ b/comfy_extras/nodes_morphology.py @@ -2,6 +2,7 @@ import torch import comfy.model_management from kornia.morphology import dilation, erosion, opening, closing, gradient, top_hat, bottom_hat +import kornia.color class Morphology: @@ -40,8 +41,45 @@ class Morphology: img_out = output.to(comfy.model_management.intermediate_device()).movedim(1, -1) return (img_out,) + +class ImageRGBToYUV: + @classmethod + def INPUT_TYPES(s): + return {"required": { "image": ("IMAGE",), + }} + + RETURN_TYPES = ("IMAGE", "IMAGE", "IMAGE") + RETURN_NAMES = ("Y", "U", "V") + FUNCTION = "execute" + + CATEGORY = "image/batch" + + def execute(self, image): + out = kornia.color.rgb_to_ycbcr(image.movedim(-1, 1)).movedim(1, -1) + return (out[..., 0:1].expand_as(image), out[..., 1:2].expand_as(image), out[..., 2:3].expand_as(image)) + 
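+# Inverse of ImageRGBToYUV: the channel-replicated Y/U/V images are collapsed back to
+# single channels (mean over the channel dimension) before the YCbCr -> RGB conversion,
+# so the two nodes round-trip cleanly.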
+class ImageYUVToRGB: + @classmethod + def INPUT_TYPES(s): + return {"required": {"Y": ("IMAGE",), + "U": ("IMAGE",), + "V": ("IMAGE",), + }} + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "execute" + + CATEGORY = "image/batch" + + def execute(self, Y, U, V): + image = torch.cat([torch.mean(Y, dim=-1, keepdim=True), torch.mean(U, dim=-1, keepdim=True), torch.mean(V, dim=-1, keepdim=True)], dim=-1) + out = kornia.color.ycbcr_to_rgb(image.movedim(-1, 1)).movedim(1, -1) + return (out,) + NODE_CLASS_MAPPINGS = { "Morphology": Morphology, + "ImageRGBToYUV": ImageRGBToYUV, + "ImageYUVToRGB": ImageYUVToRGB, } NODE_DISPLAY_NAME_MAPPINGS = { diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py index 68f6ef51e..5b9542015 100644 --- a/comfy_extras/nodes_post_processing.py +++ b/comfy_extras/nodes_post_processing.py @@ -6,7 +6,7 @@ import math import comfy.utils import comfy.model_management - +import node_helpers class Blend: def __init__(self): @@ -34,6 +34,7 @@ class Blend: CATEGORY = "image/postprocessing" def blend_images(self, image1: torch.Tensor, image2: torch.Tensor, blend_factor: float, blend_mode: str): + image1, image2 = node_helpers.image_alpha_fix(image1, image2) image2 = image2.to(image1.device) if image1.shape != image2.shape: image2 = image2.permute(0, 3, 1, 2) diff --git a/comfy_extras/nodes_primitive.py b/comfy_extras/nodes_primitive.py new file mode 100644 index 000000000..b770104fb --- /dev/null +++ b/comfy_extras/nodes_primitive.py @@ -0,0 +1,79 @@ +# Primitive nodes that are evaluated at backend. +from __future__ import annotations + +from comfy.comfy_types.node_typing import ComfyNodeABC, InputTypeDict, IO + + +class String(ComfyNodeABC): + @classmethod + def INPUT_TYPES(cls) -> InputTypeDict: + return { + "required": {"value": (IO.STRING, {})}, + } + + RETURN_TYPES = (IO.STRING,) + FUNCTION = "execute" + CATEGORY = "utils/primitive" + + def execute(self, value: str) -> tuple[str]: + return (value,) + + +class Int(ComfyNodeABC): + @classmethod + def INPUT_TYPES(cls) -> InputTypeDict: + return { + "required": {"value": (IO.INT, {"control_after_generate": True})}, + } + + RETURN_TYPES = (IO.INT,) + FUNCTION = "execute" + CATEGORY = "utils/primitive" + + def execute(self, value: int) -> tuple[int]: + return (value,) + + +class Float(ComfyNodeABC): + @classmethod + def INPUT_TYPES(cls) -> InputTypeDict: + return { + "required": {"value": (IO.FLOAT, {})}, + } + + RETURN_TYPES = (IO.FLOAT,) + FUNCTION = "execute" + CATEGORY = "utils/primitive" + + def execute(self, value: float) -> tuple[float]: + return (value,) + + +class Boolean(ComfyNodeABC): + @classmethod + def INPUT_TYPES(cls) -> InputTypeDict: + return { + "required": {"value": (IO.BOOLEAN, {})}, + } + + RETURN_TYPES = (IO.BOOLEAN,) + FUNCTION = "execute" + CATEGORY = "utils/primitive" + + def execute(self, value: bool) -> tuple[bool]: + return (value,) + + +NODE_CLASS_MAPPINGS = { + "PrimitiveString": String, + "PrimitiveInt": Int, + "PrimitiveFloat": Float, + "PrimitiveBoolean": Boolean, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "PrimitiveString": "String", + "PrimitiveInt": "Int", + "PrimitiveFloat": "Float", + "PrimitiveBoolean": "Boolean", +} diff --git a/comfy_extras/nodes_wan.py b/comfy_extras/nodes_wan.py index dc30eb546..2d0f31ac8 100644 --- a/comfy_extras/nodes_wan.py +++ b/comfy_extras/nodes_wan.py @@ -3,6 +3,7 @@ import node_helpers import torch import comfy.model_management import comfy.utils +import comfy.latent_formats class WanImageToVideo: @@ -49,6 +50,110 @@ class 
WanImageToVideo: return (positive, negative, out_latent) +class WanFunControlToVideo: + @classmethod + def INPUT_TYPES(s): + return {"required": {"positive": ("CONDITIONING", ), + "negative": ("CONDITIONING", ), + "vae": ("VAE", ), + "width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), + "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), + "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}), + "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), + }, + "optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT", ), + "start_image": ("IMAGE", ), + "control_video": ("IMAGE", ), + }} + + RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") + RETURN_NAMES = ("positive", "negative", "latent") + FUNCTION = "encode" + + CATEGORY = "conditioning/video_models" + + def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, clip_vision_output=None, control_video=None): + latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) + concat_latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) + concat_latent = comfy.latent_formats.Wan21().process_out(concat_latent) + concat_latent = concat_latent.repeat(1, 2, 1, 1, 1) + + if start_image is not None: + start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + concat_latent_image = vae.encode(start_image[:, :, :, :3]) + concat_latent[:,16:,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]] + + if control_video is not None: + control_video = comfy.utils.common_upscale(control_video[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + concat_latent_image = vae.encode(control_video[:, :, :, :3]) + concat_latent[:,:16,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]] + + positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent}) + negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent}) + + if clip_vision_output is not None: + positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output}) + negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output}) + + out_latent = {} + out_latent["samples"] = latent + return (positive, negative, out_latent) + +class WanFunInpaintToVideo: + @classmethod + def INPUT_TYPES(s): + return {"required": {"positive": ("CONDITIONING", ), + "negative": ("CONDITIONING", ), + "vae": ("VAE", ), + "width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), + "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), + "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}), + "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), + }, + "optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT", ), + "start_image": ("IMAGE", ), + "end_image": ("IMAGE", ), + }} + + RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") + RETURN_NAMES = ("positive", "negative", "latent") + FUNCTION = "encode" + + CATEGORY = "conditioning/video_models" + + def encode(self, positive, negative, vae, width, height, length, 
batch_size, start_image=None, end_image=None, clip_vision_output=None): + latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) + if start_image is not None: + start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + if end_image is not None: + end_image = comfy.utils.common_upscale(end_image[-length:].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + + image = torch.ones((length, height, width, 3)) * 0.5 + mask = torch.ones((1, 1, latent.shape[2] * 4, latent.shape[-2], latent.shape[-1])) + + if start_image is not None: + image[:start_image.shape[0]] = start_image + mask[:, :, :start_image.shape[0] + 3] = 0.0 + + if end_image is not None: + image[-end_image.shape[0]:] = end_image + mask[:, :, -end_image.shape[0]:] = 0.0 + + concat_latent_image = vae.encode(image[:, :, :, :3]) + mask = mask.view(1, mask.shape[2] // 4, 4, mask.shape[3], mask.shape[4]).transpose(1, 2) + positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": mask}) + negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": mask}) + + if clip_vision_output is not None: + positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output}) + negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output}) + + out_latent = {} + out_latent["samples"] = latent + return (positive, negative, out_latent) + NODE_CLASS_MAPPINGS = { "WanImageToVideo": WanImageToVideo, + "WanFunControlToVideo": WanFunControlToVideo, + "WanFunInpaintToVideo": WanFunInpaintToVideo, } diff --git a/comfyui_version.py b/comfyui_version.py index b5e6fbead..705622529 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. 
-__version__ = "0.3.26" +__version__ = "0.3.27" diff --git a/node_helpers.py b/node_helpers.py index 48da3b099..c3e1a14ca 100644 --- a/node_helpers.py +++ b/node_helpers.py @@ -44,3 +44,11 @@ def string_to_torch_dtype(string): return torch.float16 if string == "bf16": return torch.bfloat16 + +def image_alpha_fix(destination, source): + if destination.shape[-1] < source.shape[-1]: + source = source[...,:destination.shape[-1]] + elif destination.shape[-1] > source.shape[-1]: + destination = torch.nn.functional.pad(destination, (0, 1)) + destination[..., -1] = 1.0 + return destination, source diff --git a/nodes.py b/nodes.py index 71d1b8dd7..218f93256 100644 --- a/nodes.py +++ b/nodes.py @@ -1006,6 +1006,8 @@ class CLIPVisionLoader: def load_clip(self, clip_name): clip_path = folder_paths.get_full_path_or_raise("clip_vision", clip_name) clip_vision = comfy.clip_vision.load(clip_path) + if clip_vision is None: + raise RuntimeError("ERROR: clip vision file is invalid and does not contain a valid vision model.") return (clip_vision,) class CLIPVisionEncode: @@ -2264,6 +2266,10 @@ def init_builtin_extra_nodes(): "nodes_video.py", "nodes_lumina2.py", "nodes_wan.py", + "nodes_lotus.py", + "nodes_hunyuan3d.py", + "nodes_primitive.py", + "nodes_cfg.py", ] import_failed = [] diff --git a/pyproject.toml b/pyproject.toml index f13fed8dc..db9e776cd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.3.26" +version = "0.3.27" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.9" diff --git a/requirements.txt b/requirements.txt index 70689bc99..806fbc751 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -comfyui-frontend-package==1.12.14 +comfyui-frontend-package==1.14.6 torch torchsde torchvision From 726fdfcaa09ac639b44b0dd8af6dd446ebb25fe5 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Tue, 8 Apr 2025 18:46:43 +0800 Subject: [PATCH 06/10] Fix import error --- comfy/weight_adapter/base.py | 61 +++++++++++++++++++++++++++++++++++- comfy/weight_adapter/lora.py | 13 ++++---- 2 files changed, 66 insertions(+), 8 deletions(-) diff --git a/comfy/weight_adapter/base.py b/comfy/weight_adapter/base.py index 9a7aa9de8..54af3babe 100644 --- a/comfy/weight_adapter/base.py +++ b/comfy/weight_adapter/base.py @@ -1,6 +1,10 @@ +from typing import Optional + import torch import torch.nn as nn +import comfy.model_management + class WeightAdapterBase: name: str @@ -8,7 +12,7 @@ class WeightAdapterBase: weights: list[torch.Tensor] @classmethod - def load(cls, x: str, lora: dict[str, torch.Tensor]) -> "WeightAdapterBase" | None: + def load(cls, x: str, lora: dict[str, torch.Tensor]) -> Optional["WeightAdapterBase"]: raise NotImplementedError def to_train(self) -> "WeightAdapterTrainBase": @@ -33,3 +37,58 @@ class WeightAdapterTrainBase(nn.Module): super().__init__() # [TODO] Collaborate with LoRA training PR #7032 + + +def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function): + dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, intermediate_dtype) + lora_diff *= alpha + weight_calc = weight + function(lora_diff).type(weight.dtype) + weight_norm = ( + weight_calc.transpose(0, 1) + .reshape(weight_calc.shape[1], -1) + .norm(dim=1, keepdim=True) + .reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1)) + .transpose(0, 1) + ) + + weight_calc *= (dora_scale / weight_norm).type(weight.dtype) + if strength 
!= 1.0: + weight_calc -= weight + weight += strength * (weight_calc) + else: + weight[:] = weight_calc + return weight + + +def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Tensor: + """ + Pad a tensor to a new shape with zeros. + + Args: + tensor (torch.Tensor): The original tensor to be padded. + new_shape (List[int]): The desired shape of the padded tensor. + + Returns: + torch.Tensor: A new tensor padded with zeros to the specified shape. + + Note: + If the new shape is smaller than the original tensor in any dimension, + the original tensor will be truncated in that dimension. + """ + if any([new_shape[i] < tensor.shape[i] for i in range(len(new_shape))]): + raise ValueError("The new shape must be larger than the original tensor in all dimensions") + + if len(new_shape) != len(tensor.shape): + raise ValueError("The new shape must have the same number of dimensions as the original tensor") + + # Create a new tensor filled with zeros + padded_tensor = torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device) + + # Create slicing tuples for both tensors + orig_slices = tuple(slice(0, dim) for dim in tensor.shape) + new_slices = tuple(slice(0, dim) for dim in tensor.shape) + + # Copy the original tensor into the new tensor + padded_tensor[new_slices] = tensor[orig_slices] + + return padded_tensor diff --git a/comfy/weight_adapter/lora.py b/comfy/weight_adapter/lora.py index b79bfd820..9e00e9e65 100644 --- a/comfy/weight_adapter/lora.py +++ b/comfy/weight_adapter/lora.py @@ -1,11 +1,10 @@ import logging -import torch -import comfy.utils -import comfy.model_management -import comfy.model_base -from comfy.lora import weight_decompose, pad_tensor_to_shape +from optparse import Option +from typing import Optional -from .base import WeightAdapterBase +import torch +import comfy.model_management +from .base import WeightAdapterBase, weight_decompose, pad_tensor_to_shape class LoRAAdapter(WeightAdapterBase): @@ -23,7 +22,7 @@ class LoRAAdapter(WeightAdapterBase): alpha: float, dora_scale: torch.Tensor, loaded_keys: set[str] = None, - ) -> "LoRAAdapter" | None: + ) -> Optional["LoRAAdapter"]: if loaded_keys is None: loaded_keys = set() From a220e5ca80d1f886c241a3af28be84e92e762f6f Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Tue, 8 Apr 2025 18:46:53 +0800 Subject: [PATCH 07/10] Fix typing syntax error --- comfy/weight_adapter/glora.py | 5 +++-- comfy/weight_adapter/loha.py | 5 +++-- comfy/weight_adapter/lokr.py | 5 +++-- 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/comfy/weight_adapter/glora.py b/comfy/weight_adapter/glora.py index 89574f9fc..206fc6e28 100644 --- a/comfy/weight_adapter/glora.py +++ b/comfy/weight_adapter/glora.py @@ -1,5 +1,6 @@ -import torch +from typing import Optional +import torch from .base import WeightAdapterBase @@ -18,7 +19,7 @@ class GLoRAAdapter(WeightAdapterBase): alpha: float, dora_scale: torch.Tensor, loaded_keys: set[str] = None, - ) -> "GLoRAAdapter" | None: + ) -> Optional["GLoRAAdapter"]: if loaded_keys is None: loaded_keys = set() a1_name = "{}.a1.weight".format(x) diff --git a/comfy/weight_adapter/loha.py b/comfy/weight_adapter/loha.py index c0303925d..7a6cab243 100644 --- a/comfy/weight_adapter/loha.py +++ b/comfy/weight_adapter/loha.py @@ -1,5 +1,6 @@ -import torch +from typing import Optional +import torch from .base import WeightAdapterBase @@ -18,7 +19,7 @@ class LoHaAdapter(WeightAdapterBase): alpha: float, dora_scale: torch.Tensor, loaded_keys: 
set[str] = None, - ) -> "LoHaAdapter" | None: + ) -> Optional["LoHaAdapter"]: if loaded_keys is None: loaded_keys = set() diff --git a/comfy/weight_adapter/lokr.py b/comfy/weight_adapter/lokr.py index 85bad7ec7..da1b15519 100644 --- a/comfy/weight_adapter/lokr.py +++ b/comfy/weight_adapter/lokr.py @@ -1,5 +1,6 @@ -import torch +from typing import Optional +import torch from .base import WeightAdapterBase @@ -18,7 +19,7 @@ class LoKrAdapter(WeightAdapterBase): alpha: float, dora_scale: torch.Tensor, loaded_keys: set[str] = None, - ) -> "LoKrAdapter" | None: + ) -> Optional["LoKrAdapter"]: if loaded_keys is None: loaded_keys = set() lokr_w1_name = "{}.lokr_w1".format(x) From ff050275abf2606c224fbb853e0cf6588725c107 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Tue, 8 Apr 2025 18:48:58 +0800 Subject: [PATCH 08/10] Use correct v list --- comfy/weight_adapter/lora.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/weight_adapter/lora.py b/comfy/weight_adapter/lora.py index 9e00e9e65..89f56a45f 100644 --- a/comfy/weight_adapter/lora.py +++ b/comfy/weight_adapter/lora.py @@ -90,7 +90,7 @@ class LoRAAdapter(WeightAdapterBase): intermediate_dtype=torch.float32, original_weight=None, ): - v = self.weights[1] + v = self.weights mat1 = comfy.model_management.cast_to_device( v[0], weight.device, intermediate_dtype ) From 889f94773a11ef870ef5b14780c56394e5079e2b Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Tue, 8 Apr 2025 22:01:43 +0800 Subject: [PATCH 09/10] Remove unused import --- comfy/weight_adapter/lora.py | 1 - 1 file changed, 1 deletion(-) diff --git a/comfy/weight_adapter/lora.py b/comfy/weight_adapter/lora.py index 89f56a45f..b2e623924 100644 --- a/comfy/weight_adapter/lora.py +++ b/comfy/weight_adapter/lora.py @@ -1,5 +1,4 @@ import logging -from optparse import Option from typing import Optional import torch From e8f3bc5ab766ba0149ac8ba8f67c328f8bb3da7d Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Wed, 9 Apr 2025 09:16:52 +0800 Subject: [PATCH 10/10] Finalize the modularized weight adapter impl * LoRA/LoHa/LoKr/GLoRA working well * Removed TONS of code in lora.py --- comfy/lora.py | 180 +--------------------------------- comfy/weight_adapter/glora.py | 49 ++++++++- comfy/weight_adapter/loha.py | 45 ++++++++- comfy/weight_adapter/lokr.py | 54 +++++++++- 4 files changed, 143 insertions(+), 185 deletions(-) diff --git a/comfy/lora.py b/comfy/lora.py index ab053e7d3..8760a21fb 100644 --- a/comfy/lora.py +++ b/comfy/lora.py @@ -282,26 +282,6 @@ def model_lora_keys_unet(model, key_map={}): return key_map -def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function): - dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, intermediate_dtype) - lora_diff *= alpha - weight_calc = weight + function(lora_diff).type(weight.dtype) - weight_norm = ( - weight_calc.transpose(0, 1) - .reshape(weight_calc.shape[1], -1) - .norm(dim=1, keepdim=True) - .reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1)) - .transpose(0, 1) - ) - - weight_calc *= (dora_scale / weight_norm).type(weight.dtype) - if strength != 1.0: - weight_calc -= weight - weight += strength * (weight_calc) - else: - weight[:] = weight_calc - return weight - def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Tensor: """ Pad a tensor to a new shape with 
zeros. @@ -358,14 +338,13 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, ori if isinstance(v, weight_adapter.WeightAdapterBase): output = v.calculate_weight(weight, key, strength, strength_model, offset, function, intermediate_dtype, original_weights) - if output is not None: + if output is None: + logging.warning("Calculate Weight Failed: {} {}".format(v.name, key)) + else: weight = output if old_weight is not None: weight = old_weight - continue - else: - #Fallback when calculate_weight haven't implemented - v = v.weights + continue if len(v) == 1: patch_type = "diff" @@ -393,157 +372,6 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, ori diff_weight = comfy.model_management.cast_to_device(target_weight, weight.device, intermediate_dtype) - \ comfy.model_management.cast_to_device(original_weights[key][0][0], weight.device, intermediate_dtype) weight += function(strength * comfy.model_management.cast_to_device(diff_weight, weight.device, weight.dtype)) - elif patch_type == "lora": #lora/locon - mat1 = comfy.model_management.cast_to_device(v[0], weight.device, intermediate_dtype) - mat2 = comfy.model_management.cast_to_device(v[1], weight.device, intermediate_dtype) - dora_scale = v[4] - reshape = v[5] - - if reshape is not None: - weight = pad_tensor_to_shape(weight, reshape) - - if v[2] is not None: - alpha = v[2] / mat2.shape[0] - else: - alpha = 1.0 - - if v[3] is not None: - #locon mid weights, hopefully the math is fine because I didn't properly test it - mat3 = comfy.model_management.cast_to_device(v[3], weight.device, intermediate_dtype) - final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]] - mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1) - try: - lora_diff = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)).reshape(weight.shape) - if dora_scale is not None: - weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function) - else: - weight += function(((strength * alpha) * lora_diff).type(weight.dtype)) - except Exception as e: - logging.error("ERROR {} {} {}".format(patch_type, key, e)) - elif patch_type == "lokr": - w1 = v[0] - w2 = v[1] - w1_a = v[3] - w1_b = v[4] - w2_a = v[5] - w2_b = v[6] - t2 = v[7] - dora_scale = v[8] - dim = None - - if w1 is None: - dim = w1_b.shape[0] - w1 = torch.mm(comfy.model_management.cast_to_device(w1_a, weight.device, intermediate_dtype), - comfy.model_management.cast_to_device(w1_b, weight.device, intermediate_dtype)) - else: - w1 = comfy.model_management.cast_to_device(w1, weight.device, intermediate_dtype) - - if w2 is None: - dim = w2_b.shape[0] - if t2 is None: - w2 = torch.mm(comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype), - comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype)) - else: - w2 = torch.einsum('i j k l, j r, i p -> p r k l', - comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype), - comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype), - comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype)) - else: - w2 = comfy.model_management.cast_to_device(w2, weight.device, intermediate_dtype) - - if len(w2.shape) == 4: - w1 = w1.unsqueeze(2).unsqueeze(2) - if v[2] is not None and dim is not None: - alpha = v[2] / dim - else: - alpha = 1.0 - - try: - lora_diff = torch.kron(w1, 
w2).reshape(weight.shape) - if dora_scale is not None: - weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function) - else: - weight += function(((strength * alpha) * lora_diff).type(weight.dtype)) - except Exception as e: - logging.error("ERROR {} {} {}".format(patch_type, key, e)) - elif patch_type == "loha": - w1a = v[0] - w1b = v[1] - if v[2] is not None: - alpha = v[2] / w1b.shape[0] - else: - alpha = 1.0 - - w2a = v[3] - w2b = v[4] - dora_scale = v[7] - if v[5] is not None: #cp decomposition - t1 = v[5] - t2 = v[6] - m1 = torch.einsum('i j k l, j r, i p -> p r k l', - comfy.model_management.cast_to_device(t1, weight.device, intermediate_dtype), - comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype), - comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype)) - - m2 = torch.einsum('i j k l, j r, i p -> p r k l', - comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype), - comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype), - comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype)) - else: - m1 = torch.mm(comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype), - comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype)) - m2 = torch.mm(comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype), - comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype)) - - try: - lora_diff = (m1 * m2).reshape(weight.shape) - if dora_scale is not None: - weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function) - else: - weight += function(((strength * alpha) * lora_diff).type(weight.dtype)) - except Exception as e: - logging.error("ERROR {} {} {}".format(patch_type, key, e)) - elif patch_type == "glora": - dora_scale = v[5] - - old_glora = False - if v[3].shape[1] == v[2].shape[0] == v[0].shape[0] == v[1].shape[1]: - rank = v[0].shape[0] - old_glora = True - - if v[3].shape[0] == v[2].shape[1] == v[0].shape[1] == v[1].shape[0]: - if old_glora and v[1].shape[0] == weight.shape[0] and weight.shape[0] == weight.shape[1]: - pass - else: - old_glora = False - rank = v[1].shape[0] - - a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, intermediate_dtype) - a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, intermediate_dtype) - b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, intermediate_dtype) - b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, intermediate_dtype) - - if v[4] is not None: - alpha = v[4] / rank - else: - alpha = 1.0 - - try: - if old_glora: - lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1).to(dtype=intermediate_dtype), a2), a1)).reshape(weight.shape) #old lycoris glora - else: - if weight.dim() > 2: - lora_diff = torch.einsum("o i ..., i j -> o j ...", torch.einsum("o i ..., i j -> o j ...", weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape) - else: - lora_diff = torch.mm(torch.mm(weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape) - lora_diff += torch.mm(b1, b2).reshape(weight.shape) - - if dora_scale is not None: - weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function) - else: - weight += function(((strength * alpha) * lora_diff).type(weight.dtype)) - 
except Exception as e: - logging.error("ERROR {} {} {}".format(patch_type, key, e)) else: logging.warning("patch type not recognized {} {}".format(patch_type, key)) diff --git a/comfy/weight_adapter/glora.py b/comfy/weight_adapter/glora.py index 206fc6e28..939abbba5 100644 --- a/comfy/weight_adapter/glora.py +++ b/comfy/weight_adapter/glora.py @@ -1,7 +1,9 @@ +import logging from typing import Optional import torch -from .base import WeightAdapterBase +import comfy.model_management +from .base import WeightAdapterBase, weight_decompose class GLoRAAdapter(WeightAdapterBase): @@ -27,7 +29,7 @@ class GLoRAAdapter(WeightAdapterBase): b1_name = "{}.b1.weight".format(x) b2_name = "{}.b2.weight".format(x) if a1_name in lora: - weights = ("glora", (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha, dora_scale)) + weights = (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha, dora_scale) loaded_keys.add(a1_name) loaded_keys.add(a2_name) loaded_keys.add(b1_name) @@ -47,4 +49,45 @@ class GLoRAAdapter(WeightAdapterBase): intermediate_dtype=torch.float32, original_weight=None, ): - pass + v = self.weights + dora_scale = v[5] + + old_glora = False + if v[3].shape[1] == v[2].shape[0] == v[0].shape[0] == v[1].shape[1]: + rank = v[0].shape[0] + old_glora = True + + if v[3].shape[0] == v[2].shape[1] == v[0].shape[1] == v[1].shape[0]: + if old_glora and v[1].shape[0] == weight.shape[0] and weight.shape[0] == weight.shape[1]: + pass + else: + old_glora = False + rank = v[1].shape[0] + + a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, intermediate_dtype) + a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, intermediate_dtype) + b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, intermediate_dtype) + b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, intermediate_dtype) + + if v[4] is not None: + alpha = v[4] / rank + else: + alpha = 1.0 + + try: + if old_glora: + lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1).to(dtype=intermediate_dtype), a2), a1)).reshape(weight.shape) #old lycoris glora + else: + if weight.dim() > 2: + lora_diff = torch.einsum("o i ..., i j -> o j ...", torch.einsum("o i ..., i j -> o j ...", weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape) + else: + lora_diff = torch.mm(torch.mm(weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape) + lora_diff += torch.mm(b1, b2).reshape(weight.shape) + + if dora_scale is not None: + weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function) + else: + weight += function(((strength * alpha) * lora_diff).type(weight.dtype)) + except Exception as e: + logging.error("ERROR {} {} {}".format(self.name, key, e)) + return weight diff --git a/comfy/weight_adapter/loha.py b/comfy/weight_adapter/loha.py index 7a6cab243..ce79abad5 100644 --- a/comfy/weight_adapter/loha.py +++ b/comfy/weight_adapter/loha.py @@ -1,7 +1,9 @@ +import logging from typing import Optional import torch -from .base import WeightAdapterBase +import comfy.model_management +from .base import WeightAdapterBase, weight_decompose class LoHaAdapter(WeightAdapterBase): @@ -38,7 +40,7 @@ class LoHaAdapter(WeightAdapterBase): loaded_keys.add(hada_t1_name) loaded_keys.add(hada_t2_name) - weights = ("loha", (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2, dora_scale)) 
+ weights = (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2, dora_scale) loaded_keys.add(hada_w1_a_name) loaded_keys.add(hada_w1_b_name) loaded_keys.add(hada_w2_a_name) @@ -58,4 +60,41 @@ class LoHaAdapter(WeightAdapterBase): intermediate_dtype=torch.float32, original_weight=None, ): - pass + v = self.weights + w1a = v[0] + w1b = v[1] + if v[2] is not None: + alpha = v[2] / w1b.shape[0] + else: + alpha = 1.0 + + w2a = v[3] + w2b = v[4] + dora_scale = v[7] + if v[5] is not None: #cp decomposition + t1 = v[5] + t2 = v[6] + m1 = torch.einsum('i j k l, j r, i p -> p r k l', + comfy.model_management.cast_to_device(t1, weight.device, intermediate_dtype), + comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype), + comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype)) + + m2 = torch.einsum('i j k l, j r, i p -> p r k l', + comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype), + comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype), + comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype)) + else: + m1 = torch.mm(comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype), + comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype)) + m2 = torch.mm(comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype), + comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype)) + + try: + lora_diff = (m1 * m2).reshape(weight.shape) + if dora_scale is not None: + weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function) + else: + weight += function(((strength * alpha) * lora_diff).type(weight.dtype)) + except Exception as e: + logging.error("ERROR {} {} {}".format(self.name, key, e)) + return weight diff --git a/comfy/weight_adapter/lokr.py b/comfy/weight_adapter/lokr.py index da1b15519..51233db2d 100644 --- a/comfy/weight_adapter/lokr.py +++ b/comfy/weight_adapter/lokr.py @@ -1,7 +1,9 @@ +import logging from typing import Optional import torch -from .base import WeightAdapterBase +import comfy.model_management +from .base import WeightAdapterBase, weight_decompose class LoKrAdapter(WeightAdapterBase): @@ -66,7 +68,7 @@ class LoKrAdapter(WeightAdapterBase): loaded_keys.add(lokr_t2_name) if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None): - weights = ("lokr", (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale)) + weights = (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale) return cls(loaded_keys, weights) else: return None @@ -82,4 +84,50 @@ class LoKrAdapter(WeightAdapterBase): intermediate_dtype=torch.float32, original_weight=None, ): - pass + v = self.weights + w1 = v[0] + w2 = v[1] + w1_a = v[3] + w1_b = v[4] + w2_a = v[5] + w2_b = v[6] + t2 = v[7] + dora_scale = v[8] + dim = None + + if w1 is None: + dim = w1_b.shape[0] + w1 = torch.mm(comfy.model_management.cast_to_device(w1_a, weight.device, intermediate_dtype), + comfy.model_management.cast_to_device(w1_b, weight.device, intermediate_dtype)) + else: + w1 = comfy.model_management.cast_to_device(w1, weight.device, intermediate_dtype) + + if w2 is None: + dim = w2_b.shape[0] + if t2 is None: + w2 = torch.mm(comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype), + 
comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype)) + else: + w2 = torch.einsum('i j k l, j r, i p -> p r k l', + comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype), + comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype), + comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype)) + else: + w2 = comfy.model_management.cast_to_device(w2, weight.device, intermediate_dtype) + + if len(w2.shape) == 4: + w1 = w1.unsqueeze(2).unsqueeze(2) + if v[2] is not None and dim is not None: + alpha = v[2] / dim + else: + alpha = 1.0 + + try: + lora_diff = torch.kron(w1, w2).reshape(weight.shape) + if dora_scale is not None: + weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function) + else: + weight += function(((strength * alpha) * lora_diff).type(weight.dtype)) + except Exception as e: + logging.error("ERROR {} {} {}".format(self.name, key, e)) + return weight
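
For reference, a minimal sketch of how the finalized adapter registry can be driven, assuming the `adapters` list from comfy/weight_adapter/__init__.py and the load(x, lora, alpha, dora_scale, loaded_keys) signature used by the concrete adapters in this series; the key prefix and the commented call site below are hypothetical placeholders, not part of the patches:

import torch

from comfy.weight_adapter import adapters


def load_adapter_for_key(x: str, lora_sd: dict[str, torch.Tensor], alpha: float = 1.0):
    """Try each registered adapter class for the key prefix `x`.

    Returns the loaded adapter (or None) and the set of state-dict keys it consumed.
    """
    loaded_keys: set[str] = set()
    for adapter_cls in adapters:
        # Each load() is expected to return None when the keys for its format are absent.
        adapter = adapter_cls.load(x, lora_sd, alpha, None, loaded_keys)
        if adapter is not None:
            return adapter, loaded_keys
    return None, loaded_keys


# Hypothetical usage: patch one weight in place with full strength.
# adapter, used = load_adapter_for_key("lora_unet_some_block", lora_sd)
# if adapter is not None:
#     weight = adapter.calculate_weight(weight, "lora_unet_some_block",
#                                       strength=1.0, strength_model=1.0,
#                                       offset=None, function=lambda a: a)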