Mirror of https://github.com/comfyanonymous/ComfyUI.git
Support weight padding on diff weight patch (#4576)
commit 6bbdcd28ae
parent ab130001a8
@@ -16,6 +16,7 @@
     along with this program. If not, see <https://www.gnu.org/licenses/>.
 """
 
+from __future__ import annotations
 import comfy.utils
 import comfy.model_management
 import comfy.model_base
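A side note on the one-line change above, not part of the commit itself: the new from __future__ import annotations import is most likely there because the helper added further down annotates a parameter as list[int], and without PEP 563's lazy annotations that built-in-generic syntax fails at function-definition time on Python 3.8. A minimal sketch (the pad_to stub is hypothetical; only the import's effect is the point):

from __future__ import annotations  # PEP 563: annotations become lazily evaluated strings

def pad_to(new_shape: list[int]) -> None:
    # Without the future import, Python 3.8 raises
    # "TypeError: 'type' object is not subscriptable" when this def is executed,
    # because list[int] would then be evaluated eagerly.
    ...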
@@ -347,6 +348,39 @@ def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediat
         weight[:] = weight_calc
     return weight
 
+def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Tensor:
+    """
+    Pad a tensor to a new shape with zeros.
+
+    Args:
+        tensor (torch.Tensor): The original tensor to be padded.
+        new_shape (List[int]): The desired shape of the padded tensor.
+
+    Returns:
+        torch.Tensor: A new tensor padded with zeros to the specified shape.
+
+    Note:
+        Raises a ValueError if the new shape is smaller than the original
+        tensor in any dimension, or has a different number of dimensions.
+    """
+    if any([new_shape[i] < tensor.shape[i] for i in range(len(new_shape))]):
+        raise ValueError("The new shape must be larger than the original tensor in all dimensions")
+
+    if len(new_shape) != len(tensor.shape):
+        raise ValueError("The new shape must have the same number of dimensions as the original tensor")
+
+    # Create a new tensor filled with zeros
+    padded_tensor = torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device)
+
+    # Create slicing tuples for both tensors
+    orig_slices = tuple(slice(0, dim) for dim in tensor.shape)
+    new_slices = tuple(slice(0, dim) for dim in tensor.shape)
+
+    # Copy the original tensor into the new tensor
+    padded_tensor[new_slices] = tensor[orig_slices]
+
+    return padded_tensor
+
 def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
     for p in patches:
         strength = p[0]
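To make the new helper concrete, here is a small usage sketch. It is not part of the commit; the numbers are made up, and it assumes pad_tensor_to_shape as defined above is in scope. A 2x3 weight is zero-padded out to 4x5, keeping the original values in the leading corner of each dimension:

import torch

w = torch.arange(6, dtype=torch.float32).reshape(2, 3)
padded = pad_tensor_to_shape(w, [4, 5])

assert padded.shape == (4, 5)
assert torch.equal(padded[:2, :3], w)   # original values are preserved
assert padded[2:, :].abs().sum() == 0   # the padded region is all zeros
# pad_tensor_to_shape(w, [1, 3]) would raise ValueError, since every target
# dimension must be at least as large as the original one.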
@@ -375,12 +409,18 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
             v = v[1]
 
         if patch_type == "diff":
-            w1 = v[0]
+            diff: torch.Tensor = v[0]
+            # An extra flag to pad the weight if the diff's shape is larger than the weight
+            do_pad_weight = len(v) > 1 and v[1]['pad_weight']
+
+            if do_pad_weight and diff.shape != weight.shape:
+                logging.info("Pad weight {} from {} to shape: {}".format(key, weight.shape, diff.shape))
+                weight = pad_tensor_to_shape(weight, diff.shape)
+
             if strength != 0.0:
-                if w1.shape != weight.shape:
-                    logging.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
+                if diff.shape != weight.shape:
+                    logging.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, diff.shape, weight.shape))
                 else:
-                    weight += function(strength * comfy.model_management.cast_to_device(w1, weight.device, weight.dtype))
+                    weight += function(strength * comfy.model_management.cast_to_device(diff, weight.device, weight.dtype))
         elif patch_type == "lora": #lora/locon
             mat1 = comfy.model_management.cast_to_device(v[0], weight.device, intermediate_dtype)
             mat2 = comfy.model_management.cast_to_device(v[1], weight.device, intermediate_dtype)
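Putting the pieces together: with this commit, the value handled by the "diff" branch may carry an options dict after the diff tensor, and when pad_weight is set and the diff is larger than the weight, the weight is zero-padded to the diff's shape before the add. Below is a standalone sketch of that logic, not a call into ComfyUI itself; the (diff, options) layout follows the code above, while how such a patch is registered upstream is outside this diff and may differ. It assumes pad_tensor_to_shape from the hunk above is in scope:

import logging
import torch

strength = 1.0
weight = torch.zeros(4, 4)
diff = torch.ones(6, 4)                    # larger than the weight in dim 0
v = (diff, {"pad_weight": True})           # value seen by the "diff" branch

do_pad_weight = len(v) > 1 and v[1]['pad_weight']
if do_pad_weight and diff.shape != weight.shape:
    logging.info("Pad weight from %s to shape: %s", weight.shape, diff.shape)
    weight = pad_tensor_to_shape(weight, diff.shape)   # weight is now 6x4, zero-padded

if strength != 0.0 and diff.shape == weight.shape:
    weight += strength * diff              # simplified: no dtype/device cast here

print(weight.shape)                        # torch.Size([6, 4])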