diff --git a/comfy/ops.py b/comfy/ops.py index 5fef7cee..d7596634 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -20,31 +20,36 @@ import torch import comfy.model_management from comfy.cli_args import args -def cast_to(weight, dtype=None, device=None, non_blocking=False): - if (dtype is None or weight.dtype == dtype) and (device is None or weight.device == device): +def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=True): + if not copy and (dtype is None or weight.dtype == dtype) and (device is None or weight.device == device): return weight r = torch.empty_like(weight, dtype=dtype, device=device) r.copy_(weight, non_blocking=non_blocking) return r -def cast_to_input(weight, input, non_blocking=False): - return cast_to(weight, input.dtype, input.device, non_blocking=non_blocking) +def cast_to_input(weight, input, non_blocking=False, copy=True): + return cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy) -def cast_bias_weight(s, input=None, dtype=None, device=None): +def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None): if input is not None: if dtype is None: dtype = input.dtype + if bias_dtype is None: + bias_dtype = dtype if device is None: device = input.device bias = None non_blocking = comfy.model_management.device_supports_non_blocking(device) if s.bias is not None: - bias = cast_to(s.bias, dtype, device, non_blocking=non_blocking) - if s.bias_function is not None: + has_function = s.bias_function is not None + bias = cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function) + if has_function: bias = s.bias_function(bias) - weight = cast_to(s.weight, dtype, device, non_blocking=non_blocking) - if s.weight_function is not None: + + has_function = s.weight_function is not None + weight = cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function) + if has_function: weight = s.weight_function(weight) return weight, bias @@ -252,7 +257,8 @@ def fp8_linear(self, input): if len(input.shape) == 3: inn = input.reshape(-1, input.shape[2]).to(dtype) non_blocking = comfy.model_management.device_supports_non_blocking(input.device) - w = cast_to(self.weight, device=input.device, non_blocking=non_blocking).t() + w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input.dtype) + w = w.t() scale_weight = self.scale_weight scale_input = self.scale_input @@ -263,8 +269,8 @@ def fp8_linear(self, input): if scale_input is None: scale_input = torch.ones((1), device=input.device, dtype=torch.float32) - if self.bias is not None: - o = torch._scaled_mm(inn, w, out_dtype=input.dtype, bias=cast_to_input(self.bias, input, non_blocking=non_blocking), scale_a=scale_input, scale_b=scale_weight) + if bias is not None: + o = torch._scaled_mm(inn, w, out_dtype=input.dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight) else: o = torch._scaled_mm(inn, w, out_dtype=input.dtype, scale_a=scale_input, scale_b=scale_weight)