From f9f9fafacedeeb329c8c61d30be52fcf6e92f33b Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Sun, 20 Oct 2024 06:24:31 -0400
Subject: [PATCH] Fixed model merging issue with scaled fp8.

---
 comfy/lora.py          |  2 +-
 comfy/model_patcher.py | 59 +++++++++++++++++++++++-------------------
 comfy/ops.py           |  8 ++++--
 3 files changed, 40 insertions(+), 29 deletions(-)

diff --git a/comfy/lora.py b/comfy/lora.py
index 80057cdd..81cd1696 100644
--- a/comfy/lora.py
+++ b/comfy/lora.py
@@ -415,7 +415,7 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
             weight *= strength_model
 
         if isinstance(v, list):
-            v = (calculate_weight(v[1:], comfy.model_management.cast_to_device(v[0], weight.device, intermediate_dtype, copy=True), key, intermediate_dtype=intermediate_dtype), )
+            v = (calculate_weight(v[1:], v[0][1](comfy.model_management.cast_to_device(v[0][0], weight.device, intermediate_dtype, copy=True), inplace=True), key, intermediate_dtype=intermediate_dtype), )
 
         if len(v) == 1:
             patch_type = "diff"
diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index 87ebe54c..3bba217a 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -94,6 +94,31 @@ class LowVramPatch:
             return comfy.float.stochastic_rounding(comfy.lora.calculate_weight(self.patches[self.key], weight.to(intermediate_dtype), self.key, intermediate_dtype=intermediate_dtype), weight.dtype, seed=string_to_seed(self.key))
         return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=intermediate_dtype)
 
+
+def get_key_weight(model, key):
+    set_func = None
+    convert_func = None
+    op_keys = key.rsplit('.', 1)
+    if len(op_keys) < 2:
+        weight = comfy.utils.get_attr(model, key)
+    else:
+        op = comfy.utils.get_attr(model, op_keys[0])
+        try:
+            set_func = getattr(op, "set_{}".format(op_keys[1]))
+        except AttributeError:
+            pass
+
+        try:
+            convert_func = getattr(op, "convert_{}".format(op_keys[1]))
+        except AttributeError:
+            pass
+
+        weight = getattr(op, op_keys[1])
+        if convert_func is not None:
+            weight = comfy.utils.get_attr(model, key)
+
+    return weight, set_func, convert_func
+
 class ModelPatcher:
     def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False):
         self.size = size
@@ -294,14 +319,16 @@ class ModelPatcher:
                 if not k.startswith(filter_prefix):
                     continue
             bk = self.backup.get(k, None)
+            weight, set_func, convert_func = get_key_weight(self.model, k)
             if bk is not None:
                 weight = bk.weight
-            else:
-                weight = model_sd[k]
+            if convert_func is None:
+                convert_func = lambda a, **kwargs: a
+
             if k in self.patches:
-                p[k] = [weight] + self.patches[k]
+                p[k] = [(weight, convert_func)] + self.patches[k]
             else:
-                p[k] = (weight,)
+                p[k] = [(weight, convert_func)]
         return p
 
     def model_state_dict(self, filter_prefix=None):
@@ -317,27 +344,7 @@ class ModelPatcher:
         if key not in self.patches:
             return
 
-        set_func = None
-        convert_func = None
-        op_keys = key.rsplit('.', 1)
-        if len(op_keys) < 2:
-            weight = comfy.utils.get_attr(self.model, key)
-        else:
-            op = comfy.utils.get_attr(self.model, op_keys[0])
-            try:
-                set_func = getattr(op, "set_{}".format(op_keys[1]))
-            except AttributeError:
-                pass
-
-            try:
-                convert_func = getattr(op, "convert_{}".format(op_keys[1]))
-            except AttributeError:
-                pass
-
-            weight = getattr(op, op_keys[1])
-            if convert_func is not None:
-                weight = comfy.utils.get_attr(self.model, key)
-
+        weight, set_func, convert_func = get_key_weight(self.model, key)
         inplace_update = self.weight_inplace_update or inplace_update
 
         if key not in self.backup:
@@ -348,7 +355,7 @@ class ModelPatcher:
         else:
             temp_weight = weight.to(torch.float32, copy=True)
         if convert_func is not None:
-            temp_weight = convert_func(temp_weight)
+            temp_weight = convert_func(temp_weight, inplace=True)
 
         out_weight = comfy.lora.calculate_weight(self.patches[key], temp_weight, key)
         if set_func is None:
diff --git a/comfy/ops.py b/comfy/ops.py
index 3f8271ea..c07cd908 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -309,8 +309,12 @@ def scaled_fp8_ops(fp8_matrix_mult=False):
                 weight, bias = cast_bias_weight(self, input)
                 return torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias)
 
-            def convert_weight(self, weight):
-                return weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype)
+            def convert_weight(self, weight, inplace=False, **kwargs):
+                if inplace:
+                    weight *= self.scale_weight.to(device=weight.device, dtype=weight.dtype)
+                    return weight
+                else:
+                    return weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype)
 
             def set_weight(self, weight, inplace_update=False, seed=None, **kwargs):
                 weight = comfy.float.stochastic_rounding(weight / self.scale_weight.to(device=weight.device, dtype=weight.dtype), self.weight.dtype, seed=seed)
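
The sketch below is a minimal, standalone illustration of the convert -> merge -> set round-trip that get_key_weight, patch_weight_to_device and the scaled fp8 convert_weight/set_weight hooks in this patch rely on. It is not ComfyUI code: the fp8 storage dtype and stochastic rounding are replaced by a plain float32 scale, and ToyScaledLinear, temp and merged are hypothetical names used only for this example.

    import torch
    import torch.nn as nn

    class ToyScaledLinear(nn.Module):
        # Stores its weight pre-divided by scale_weight, like the scaled fp8 op,
        # but keeps everything in float32 so the sketch runs anywhere.
        def __init__(self, out_features, in_features, scale=2.0):
            super().__init__()
            self.scale_weight = torch.tensor(scale)
            self.weight = nn.Parameter(torch.randn(out_features, in_features) / scale)

        def convert_weight(self, weight, inplace=False, **kwargs):
            # De-scale back to the "real" weight; the inplace path avoids
            # allocating a second full-size temporary during merging.
            if inplace:
                weight *= self.scale_weight.to(device=weight.device, dtype=weight.dtype)
                return weight
            return weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype)

        def set_weight(self, weight, inplace_update=False, **kwargs):
            # Re-scale the merged weight before storing it again.
            self.weight.data.copy_(weight / self.scale_weight.to(device=weight.device, dtype=weight.dtype))

    op = ToyScaledLinear(4, 4)

    # Probe for the optional hooks the same way get_key_weight does for a
    # "<module>.weight" key: a missing hook simply stays None.
    convert_func = getattr(op, "convert_weight", None)
    set_func = getattr(op, "set_weight", None)

    # Merge path: float32 copy -> in-place convert -> apply patch -> store back.
    temp = op.weight.to(torch.float32, copy=True)
    if convert_func is not None:
        temp = convert_func(temp, inplace=True)
    merged = temp + 0.1 * torch.randn_like(temp)  # stand-in for a LoRA diff
    if set_func is not None:
        set_func(merged)

The inplace=True flag is safe at both call sites touched by the patch because each one already works on a throwaway copy (cast_to_device(..., copy=True) in calculate_weight and weight.to(torch.float32, copy=True) in patch_weight_to_device), so scaling in place saves one full-size temporary per merged weight.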