mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-04-16 08:33:29 +00:00
Fix import error
This commit is contained in:
parent
c792fad88b
commit
726fdfcaa0
@ -1,6 +1,10 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
|
import comfy.model_management
|
||||||
|
|
||||||
|
|
||||||
class WeightAdapterBase:
|
class WeightAdapterBase:
|
||||||
name: str
|
name: str
|
||||||
@ -8,7 +12,7 @@ class WeightAdapterBase:
|
|||||||
weights: list[torch.Tensor]
|
weights: list[torch.Tensor]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def load(cls, x: str, lora: dict[str, torch.Tensor]) -> "WeightAdapterBase" | None:
|
def load(cls, x: str, lora: dict[str, torch.Tensor]) -> Optional["WeightAdapterBase"]:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def to_train(self) -> "WeightAdapterTrainBase":
|
def to_train(self) -> "WeightAdapterTrainBase":
|
||||||
@ -33,3 +37,58 @@ class WeightAdapterTrainBase(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
# [TODO] Collaborate with LoRA training PR #7032
|
# [TODO] Collaborate with LoRA training PR #7032
|
||||||
|
|
||||||
|
|
||||||
|
def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function):
|
||||||
|
dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, intermediate_dtype)
|
||||||
|
lora_diff *= alpha
|
||||||
|
weight_calc = weight + function(lora_diff).type(weight.dtype)
|
||||||
|
weight_norm = (
|
||||||
|
weight_calc.transpose(0, 1)
|
||||||
|
.reshape(weight_calc.shape[1], -1)
|
||||||
|
.norm(dim=1, keepdim=True)
|
||||||
|
.reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1))
|
||||||
|
.transpose(0, 1)
|
||||||
|
)
|
||||||
|
|
||||||
|
weight_calc *= (dora_scale / weight_norm).type(weight.dtype)
|
||||||
|
if strength != 1.0:
|
||||||
|
weight_calc -= weight
|
||||||
|
weight += strength * (weight_calc)
|
||||||
|
else:
|
||||||
|
weight[:] = weight_calc
|
||||||
|
return weight
|
||||||
|
|
||||||
|
|
||||||
|
def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Pad a tensor to a new shape with zeros.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tensor (torch.Tensor): The original tensor to be padded.
|
||||||
|
new_shape (List[int]): The desired shape of the padded tensor.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: A new tensor padded with zeros to the specified shape.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
If the new shape is smaller than the original tensor in any dimension,
|
||||||
|
the original tensor will be truncated in that dimension.
|
||||||
|
"""
|
||||||
|
if any([new_shape[i] < tensor.shape[i] for i in range(len(new_shape))]):
|
||||||
|
raise ValueError("The new shape must be larger than the original tensor in all dimensions")
|
||||||
|
|
||||||
|
if len(new_shape) != len(tensor.shape):
|
||||||
|
raise ValueError("The new shape must have the same number of dimensions as the original tensor")
|
||||||
|
|
||||||
|
# Create a new tensor filled with zeros
|
||||||
|
padded_tensor = torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device)
|
||||||
|
|
||||||
|
# Create slicing tuples for both tensors
|
||||||
|
orig_slices = tuple(slice(0, dim) for dim in tensor.shape)
|
||||||
|
new_slices = tuple(slice(0, dim) for dim in tensor.shape)
|
||||||
|
|
||||||
|
# Copy the original tensor into the new tensor
|
||||||
|
padded_tensor[new_slices] = tensor[orig_slices]
|
||||||
|
|
||||||
|
return padded_tensor
|
||||||
|
@ -1,11 +1,10 @@
|
|||||||
import logging
|
import logging
|
||||||
import torch
|
from optparse import Option
|
||||||
import comfy.utils
|
from typing import Optional
|
||||||
import comfy.model_management
|
|
||||||
import comfy.model_base
|
|
||||||
from comfy.lora import weight_decompose, pad_tensor_to_shape
|
|
||||||
|
|
||||||
from .base import WeightAdapterBase
|
import torch
|
||||||
|
import comfy.model_management
|
||||||
|
from .base import WeightAdapterBase, weight_decompose, pad_tensor_to_shape
|
||||||
|
|
||||||
|
|
||||||
class LoRAAdapter(WeightAdapterBase):
|
class LoRAAdapter(WeightAdapterBase):
|
||||||
@ -23,7 +22,7 @@ class LoRAAdapter(WeightAdapterBase):
|
|||||||
alpha: float,
|
alpha: float,
|
||||||
dora_scale: torch.Tensor,
|
dora_scale: torch.Tensor,
|
||||||
loaded_keys: set[str] = None,
|
loaded_keys: set[str] = None,
|
||||||
) -> "LoRAAdapter" | None:
|
) -> Optional["LoRAAdapter"]:
|
||||||
if loaded_keys is None:
|
if loaded_keys is None:
|
||||||
loaded_keys = set()
|
loaded_keys = set()
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user