From a68bbafddb93d6f111af3d1fc22d3e38ce8186de Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Sat, 19 Oct 2024 23:47:42 -0400
Subject: [PATCH] Support diffusion models with scaled fp8 weights.

---
 comfy/model_base.py            |  7 ++++-
 comfy/model_detection.py       |  9 ++++--
 comfy/model_patcher.py         | 35 ++++++++++++++++++---
 comfy/ops.py                   | 56 ++++++++++++++++++++++++++--------
 comfy/sd.py                    |  8 ++++-
 comfy/supported_models_base.py |  1 +
 comfy/utils.py                 |  2 +-
 7 files changed, 95 insertions(+), 23 deletions(-)

diff --git a/comfy/model_base.py b/comfy/model_base.py
index a98fee1d..7e86e76d 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -96,7 +96,8 @@ class BaseModel(torch.nn.Module):
 
         if not unet_config.get("disable_unet_model_creation", False):
             if model_config.custom_operations is None:
-                operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=model_config.optimizations.get("fp8", False))
+                fp8 = model_config.optimizations.get("fp8", model_config.scaled_fp8)
+                operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8)
             else:
                 operations = model_config.custom_operations
             self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
@@ -244,6 +245,10 @@ class BaseModel(torch.nn.Module):
                 extra_sds.append(self.model_config.process_clip_vision_state_dict_for_saving(clip_vision_state_dict))
 
         unet_state_dict = self.diffusion_model.state_dict()
+
+        if self.model_config.scaled_fp8:
+            unet_state_dict["scaled_fp8"] = torch.tensor([])
+
         unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
 
         if self.model_type == ModelType.V_PREDICTION:
diff --git a/comfy/model_detection.py b/comfy/model_detection.py
index 1edbcda4..3f720bce 100644
--- a/comfy/model_detection.py
+++ b/comfy/model_detection.py
@@ -286,9 +286,12 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
             return None
     model_config = model_config_from_unet_config(unet_config, state_dict)
     if model_config is None and use_base_if_no_match:
-        return comfy.supported_models_base.BASE(unet_config)
-    else:
-        return model_config
+        model_config = comfy.supported_models_base.BASE(unet_config)
+
+    if "{}scaled_fp8".format(unet_key_prefix) in state_dict:
+        model_config.scaled_fp8 = True
+
+    return model_config
 
 def unet_prefix_from_state_dict(state_dict):
     candidates = ["model.diffusion_model.", #ldm/sgm models
diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index 2ba30633..87ebe54c 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -317,7 +317,26 @@ class ModelPatcher:
         if key not in self.patches:
             return
 
-        weight = comfy.utils.get_attr(self.model, key)
+        set_func = None
+        convert_func = None
+        op_keys = key.rsplit('.', 1)
+        if len(op_keys) < 2:
+            weight = comfy.utils.get_attr(self.model, key)
+        else:
+            op = comfy.utils.get_attr(self.model, op_keys[0])
+            try:
+                set_func = getattr(op, "set_{}".format(op_keys[1]))
+            except AttributeError:
+                pass
+
+            try:
+                convert_func = getattr(op, "convert_{}".format(op_keys[1]))
+            except AttributeError:
+                pass
+
+            weight = getattr(op, op_keys[1])
+            if convert_func is not None:
+                weight = comfy.utils.get_attr(self.model, key)
 
         inplace_update = self.weight_inplace_update or inplace_update
 
@@ -328,12 +347,18 @@ class ModelPatcher:
             temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
         else:
             temp_weight = weight.to(torch.float32, copy=True)
+        if convert_func is not None:
+            temp_weight = convert_func(temp_weight)
+
         out_weight = comfy.lora.calculate_weight(self.patches[key], temp_weight, key)
-        out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=string_to_seed(key))
-        if inplace_update:
-            comfy.utils.copy_to_param(self.model, key, out_weight)
+        if set_func is None:
+            out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=string_to_seed(key))
+            if inplace_update:
+                comfy.utils.copy_to_param(self.model, key, out_weight)
+            else:
+                comfy.utils.set_attr_param(self.model, key, out_weight)
         else:
-            comfy.utils.set_attr_param(self.model, key, out_weight)
+            set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key))
 
     def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
         mem_counter = 0
diff --git a/comfy/ops.py b/comfy/ops.py
index a8bfe1ea..3f8271ea 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -19,6 +19,7 @@
 import torch
 import comfy.model_management
 from comfy.cli_args import args
+import comfy.float
 
 cast_to = comfy.model_management.cast_to #TODO: remove once no more references
 
@@ -250,18 +251,18 @@ def fp8_linear(self, input):
         return None
 
     if len(input.shape) == 3:
-        inn = input.reshape(-1, input.shape[2]).to(dtype)
         w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input.dtype)
         w = w.t()
 
         scale_weight = self.scale_weight
         scale_input = self.scale_input
         if scale_weight is None:
-            scale_weight = torch.ones((1), device=input.device, dtype=torch.float32)
-            if scale_input is None:
-                scale_input = scale_weight
+            scale_weight = torch.ones((), device=input.device, dtype=torch.float32)
 
         if scale_input is None:
-            scale_input = torch.ones((1), device=input.device, dtype=torch.float32)
+            scale_input = torch.ones((), device=input.device, dtype=torch.float32)
+            inn = input.reshape(-1, input.shape[2]).to(dtype)
+        else:
+            inn = (input * (1.0 / scale_input).to(input.dtype)).reshape(-1, input.shape[2]).to(dtype)
 
         if bias is not None:
             o = torch._scaled_mm(inn, w, out_dtype=input.dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight)
@@ -289,15 +290,46 @@ class fp8_ops(manual_cast):
             weight, bias = cast_bias_weight(self, input)
             return torch.nn.functional.linear(input, weight, bias)
 
+def scaled_fp8_ops(fp8_matrix_mult=False):
+    class scaled_fp8_op(manual_cast):
+        class Linear(manual_cast.Linear):
+            def reset_parameters(self):
+                if not hasattr(self, 'scale_weight'):
+                    self.scale_weight = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False)
+
+                if not hasattr(self, 'scale_input'):
+                    self.scale_input = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False)
+                return None
 
-def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False):
-    if comfy.model_management.supports_fp8_compute(load_device):
-        if (fp8_optimizations or args.fast) and not disable_fast_fp8:
-            return fp8_ops
+            def forward_comfy_cast_weights(self, input):
+                if fp8_matrix_mult:
+                    out = fp8_linear(self, input)
+                    if out is not None:
+                        return out
+
+                weight, bias = cast_bias_weight(self, input)
+                return torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias)
+
+            def convert_weight(self, weight):
+                return weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype)
+
+            def set_weight(self, weight, inplace_update=False, seed=None, **kwargs):
+                weight = comfy.float.stochastic_rounding(weight / self.scale_weight.to(device=weight.device, dtype=weight.dtype), self.weight.dtype, seed=seed)
+                if inplace_update:
+                    self.weight.data.copy_(weight)
+                else:
+                    self.weight = torch.nn.Parameter(weight, requires_grad=False)
+
+    return scaled_fp8_op
+
+def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=False):
+    fp8_compute = comfy.model_management.supports_fp8_compute(load_device)
+    if scaled_fp8:
+        return scaled_fp8_ops(fp8_matrix_mult=fp8_compute)
+
+    if fp8_compute and (fp8_optimizations or args.fast) and not disable_fast_fp8:
+        return fp8_ops
 
     if compute_dtype is None or weight_dtype == compute_dtype:
         return disable_weight_init
-    if args.fast and not disable_fast_fp8:
-        if comfy.model_management.supports_fp8_compute(load_device):
-            return fp8_ops
+
     return manual_cast
diff --git a/comfy/sd.py b/comfy/sd.py
index 67b4ff0c..9f552f20 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -649,6 +649,8 @@ def load_diffusion_model_state_dict(sd, model_options={}): #load unet in diffuse
         sd = temp_sd
 
     parameters = comfy.utils.calculate_parameters(sd)
+    weight_dtype = comfy.utils.weight_dtype(sd)
+
     load_device = model_management.get_torch_device()
     model_config = model_detection.model_config_from_unet(sd, "")
 
@@ -675,8 +677,12 @@ def load_diffusion_model_state_dict(sd, model_options={}): #load unet in diffuse
                 logging.warning("{} {}".format(diffusers_keys[k], k))
 
     offload_device = model_management.unet_offload_device()
+    unet_weight_dtype = list(model_config.supported_inference_dtypes)
+    if weight_dtype is not None:
+        unet_weight_dtype.append(weight_dtype)
+
     if dtype is None:
-        unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=model_config.supported_inference_dtypes)
+        unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype)
     else:
         unet_dtype = dtype
 
diff --git a/comfy/supported_models_base.py b/comfy/supported_models_base.py
index e2b70e69..0e69d16a 100644
--- a/comfy/supported_models_base.py
+++ b/comfy/supported_models_base.py
@@ -49,6 +49,7 @@ class BASE:
 
     manual_cast_dtype = None
     custom_operations = None
+    scaled_fp8 = False
     optimizations = {"fp8": False}
 
     @classmethod
diff --git a/comfy/utils.py b/comfy/utils.py
index 78f8314f..7cef9044 100644
--- a/comfy/utils.py
+++ b/comfy/utils.py
@@ -68,7 +68,7 @@ def weight_dtype(sd, prefix=""):
     for k in sd.keys():
         if k.startswith(prefix):
            w = sd[k]
-            dtypes[w.dtype] = dtypes.get(w.dtype, 0) + 1
+            dtypes[w.dtype] = dtypes.get(w.dtype, 0) + w.numel()
 
     if len(dtypes) == 0:
         return None
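
Background on the format this patch introduces: a "scaled fp8" checkpoint stores each Linear weight as an fp8 tensor plus a per-tensor float32 scale_weight, and marks itself with an empty "scaled_fp8" tensor in the state dict. Dequantization is a single multiply, which is what convert_weight above does. A minimal standalone sketch of the idea in plain PyTorch (the helper names are illustrative, not ComfyUI API; assumes a PyTorch build with float8_e4m3fn support):

    import torch

    def quantize_scaled_fp8(weight_fp32):
        # One scale for the whole tensor, chosen so values fit the fp8 e4m3 range.
        scale = weight_fp32.abs().max() / torch.finfo(torch.float8_e4m3fn).max
        weight_fp8 = (weight_fp32 / scale).to(torch.float8_e4m3fn)
        return weight_fp8, scale

    def dequantize_scaled_fp8(weight_fp8, scale, dtype=torch.float32):
        # Mirrors convert_weight in the patch: weight * scale_weight.
        return weight_fp8.to(dtype) * scale.to(dtype)

    w = torch.randn(128, 64)
    w8, s = quantize_scaled_fp8(w)
    print((dequantize_scaled_fp8(w8, s) - w).abs().max())  # small quantization error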
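
The ModelPatcher changes exist so LoRA patches still apply to these layers: patch_weight_to_device now looks for set_/convert_ hooks on the owning op, dequantizes via convert_func, computes the patched weight in fp32, and hands the result to set_weight, which rescales and requantizes. A toy version of that round-trip, with plain nearest rounding standing in for comfy.float.stochastic_rounding:

    import torch

    scale = torch.tensor(0.07, dtype=torch.float32)
    w8 = (torch.randn(16, 16) / scale).to(torch.float8_e4m3fn)   # stored fp8 weight

    w32 = w8.to(torch.float32) * scale           # convert_weight: dequantize
    w32 += 0.01 * torch.randn(16, 16)            # stand-in for comfy.lora.calculate_weight
    w8 = (w32 / scale).to(torch.float8_e4m3fn)   # set_weight: rescale, requantize

The real set_weight uses stochastic rounding because nearest rounding at fp8 precision tends to erase small LoRA deltas; randomizing the rounding direction preserves them in expectation.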
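
When fp8 compute is available, scaled_fp8_ops is built with fp8_matrix_mult=True and routes through fp8_linear, which now folds 1/scale_input into the activations before the fp8 cast and passes both per-tensor scales to torch._scaled_mm. A condensed sketch of that call pattern (torch._scaled_mm is a private API that needs fp8-capable hardware, roughly compute capability 8.9+, and its signature has shifted across PyTorch releases):

    import torch

    def scaled_fp8_linear(x, w8, scale_w, scale_x, bias=None):
        # x: (tokens, in_features) in fp16/bf16; w8: (out, in) fp8 weight;
        # scale_w/scale_x: 0-dim float32 tensors, as in the patched fp8_linear.
        xq = (x * (1.0 / scale_x).to(x.dtype)).to(torch.float8_e4m3fn)
        # _scaled_mm wants the second operand column-major, hence the .t().
        return torch._scaled_mm(xq, w8.t(), out_dtype=x.dtype, bias=bias,
                                scale_a=scale_x, scale_b=scale_w)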
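
Two smaller pieces tie the loading path together: model_detection flips model_config.scaled_fp8 whenever the prefixed "scaled_fp8" marker key is present, and comfy.utils.weight_dtype now counts elements rather than tensors, so the dtype that dominates by parameter count decides the load dtype. Roughly, with the usual UNet key prefix:

    import torch

    prefix = "model.diffusion_model."
    sd = {
        prefix + "scaled_fp8": torch.tensor([]),  # empty marker tensor
        prefix + "some.weight": torch.zeros(64, 64, dtype=torch.float8_e4m3fn),
    }
    print("{}scaled_fp8".format(prefix) in sd)  # True -> model_config.scaled_fp8 = True

    # dtype histogram weighted by numel, as in the utils.py change:
    dtypes = {}
    for k, w in sd.items():
        dtypes[w.dtype] = dtypes.get(w.dtype, 0) + w.numel()
    print(max(dtypes, key=dtypes.get))  # torch.float8_e4m3fn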