diff --git a/comfy/cldm/cldm.py b/comfy/cldm/cldm.py
index 00373a790..5eee5a51c 100644
--- a/comfy/cldm/cldm.py
+++ b/comfy/cldm/cldm.py
@@ -141,24 +141,24 @@ class ControlNet(nn.Module):
                 )
             ]
         )
-        self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels, operations=operations)])
+        self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels, operations=operations, dtype=self.dtype, device=device)])
 
         self.input_hint_block = TimestepEmbedSequential(
-            operations.conv_nd(dims, hint_channels, 16, 3, padding=1),
+            operations.conv_nd(dims, hint_channels, 16, 3, padding=1, dtype=self.dtype, device=device),
             nn.SiLU(),
-            operations.conv_nd(dims, 16, 16, 3, padding=1),
+            operations.conv_nd(dims, 16, 16, 3, padding=1, dtype=self.dtype, device=device),
             nn.SiLU(),
-            operations.conv_nd(dims, 16, 32, 3, padding=1, stride=2),
+            operations.conv_nd(dims, 16, 32, 3, padding=1, stride=2, dtype=self.dtype, device=device),
             nn.SiLU(),
-            operations.conv_nd(dims, 32, 32, 3, padding=1),
+            operations.conv_nd(dims, 32, 32, 3, padding=1, dtype=self.dtype, device=device),
             nn.SiLU(),
-            operations.conv_nd(dims, 32, 96, 3, padding=1, stride=2),
+            operations.conv_nd(dims, 32, 96, 3, padding=1, stride=2, dtype=self.dtype, device=device),
             nn.SiLU(),
-            operations.conv_nd(dims, 96, 96, 3, padding=1),
+            operations.conv_nd(dims, 96, 96, 3, padding=1, dtype=self.dtype, device=device),
             nn.SiLU(),
-            operations.conv_nd(dims, 96, 256, 3, padding=1, stride=2),
+            operations.conv_nd(dims, 96, 256, 3, padding=1, stride=2, dtype=self.dtype, device=device),
             nn.SiLU(),
-            zero_module(operations.conv_nd(dims, 256, model_channels, 3, padding=1))
+            operations.conv_nd(dims, 256, model_channels, 3, padding=1, dtype=self.dtype, device=device)
         )
 
         self._feature_size = model_channels
@@ -206,7 +206,7 @@ class ControlNet(nn.Module):
                             )
                     )
                 self.input_blocks.append(TimestepEmbedSequential(*layers))
-                self.zero_convs.append(self.make_zero_conv(ch, operations=operations))
+                self.zero_convs.append(self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device))
                 self._feature_size += ch
                 input_block_chans.append(ch)
             if level != len(channel_mult) - 1:
@@ -234,7 +234,7 @@ class ControlNet(nn.Module):
                 )
                 ch = out_ch
                 input_block_chans.append(ch)
-                self.zero_convs.append(self.make_zero_conv(ch, operations=operations))
+                self.zero_convs.append(self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device))
                 ds *= 2
                 self._feature_size += ch
@@ -276,11 +276,11 @@ class ControlNet(nn.Module):
                      operations=operations
             )]
         self.middle_block = TimestepEmbedSequential(*mid_block)
-        self.middle_block_out = self.make_zero_conv(ch, operations=operations)
+        self.middle_block_out = self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device)
         self._feature_size += ch
 
-    def make_zero_conv(self, channels, operations=None):
-        return TimestepEmbedSequential(zero_module(operations.conv_nd(self.dims, channels, channels, 1, padding=0)))
+    def make_zero_conv(self, channels, operations=None, dtype=None, device=None):
+        return TimestepEmbedSequential(operations.conv_nd(self.dims, channels, channels, 1, padding=0, dtype=dtype, device=device))
 
     def forward(self, x, hint, timesteps, context, y=None, **kwargs):
         t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
diff --git a/comfy/controlnet.py b/comfy/controlnet.py
index 3212ac8c4..110b5c7c2 100644
--- a/comfy/controlnet.py
+++ b/comfy/controlnet.py
@@ -36,13 +36,13 @@ class ControlBase:
         self.cond_hint = None
         self.strength = 1.0
         self.timestep_percent_range = (0.0, 1.0)
+        self.global_average_pooling = False
         self.timestep_range = None
 
         if device is None:
             device = comfy.model_management.get_torch_device()
         self.device = device
         self.previous_controlnet = None
-        self.global_average_pooling = False
 
     def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0)):
         self.cond_hint_original = cond_hint
@@ -77,6 +77,7 @@ class ControlBase:
         c.cond_hint_original = self.cond_hint_original
         c.strength = self.strength
         c.timestep_percent_range = self.timestep_percent_range
+        c.global_average_pooling = self.global_average_pooling
 
     def inference_memory_requirements(self, dtype):
         if self.previous_controlnet is not None:
@@ -129,12 +130,14 @@ class ControlBase:
         return out
 
 class ControlNet(ControlBase):
-    def __init__(self, control_model, global_average_pooling=False, device=None):
+    def __init__(self, control_model, global_average_pooling=False, device=None, load_device=None, manual_cast_dtype=None):
         super().__init__(device)
         self.control_model = control_model
-        self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
+        self.load_device = load_device
+        self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device())
         self.global_average_pooling = global_average_pooling
         self.model_sampling_current = None
+        self.manual_cast_dtype = manual_cast_dtype
 
     def get_control(self, x_noisy, t, cond, batched_number):
         control_prev = None
@@ -149,11 +152,8 @@ class ControlNet(ControlBase):
                     return None
 
         dtype = self.control_model.dtype
-        if comfy.model_management.supports_dtype(self.device, dtype):
-            precision_scope = lambda a: contextlib.nullcontext(a)
-        else:
-            precision_scope = torch.autocast
-            dtype = torch.float32
+        if self.manual_cast_dtype is not None:
+            dtype = self.manual_cast_dtype
 
         output_dtype = x_noisy.dtype
         if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]:
@@ -171,12 +171,11 @@ class ControlNet(ControlBase):
         timestep = self.model_sampling_current.timestep(t)
         x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
 
-        with precision_scope(comfy.model_management.get_autocast_device(self.device)):
-            control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y)
+        control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y)
         return self.control_merge(None, control, control_prev, output_dtype)
 
     def copy(self):
-        c = ControlNet(self.control_model, global_average_pooling=self.global_average_pooling)
+        c = ControlNet(self.control_model, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
         self.copy_to(c)
         return c
 
@@ -207,10 +206,11 @@ class ControlLoraOps:
             self.bias = None
 
         def forward(self, input):
+            weight, bias = comfy.ops.cast_bias_weight(self, input)
             if self.up is not None:
-                return torch.nn.functional.linear(input, self.weight.to(dtype=input.dtype, device=input.device) + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias)
+                return torch.nn.functional.linear(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias)
             else:
-                return torch.nn.functional.linear(input, self.weight.to(dtype=input.dtype, device=input.device), self.bias)
+                return torch.nn.functional.linear(input, weight, bias)
 
     class Conv2d(torch.nn.Module):
         def __init__(
@@ -246,10 +246,11 @@ class ControlLoraOps:
 
 
         def forward(self, input):
+            weight, bias = comfy.ops.cast_bias_weight(self, input)
             if self.up is not None:
-                return torch.nn.functional.conv2d(input, self.weight.to(dtype=input.dtype, device=input.device) + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias, self.stride, self.padding, self.dilation, self.groups)
+                return torch.nn.functional.conv2d(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias, self.stride, self.padding, self.dilation, self.groups)
             else:
-                return torch.nn.functional.conv2d(input, self.weight.to(dtype=input.dtype, device=input.device), self.bias, self.stride, self.padding, self.dilation, self.groups)
+                return torch.nn.functional.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)
 
 
 class ControlLora(ControlNet):
@@ -263,12 +264,19 @@ class ControlLora(ControlNet):
         controlnet_config = model.model_config.unet_config.copy()
         controlnet_config.pop("out_channels")
         controlnet_config["hint_channels"] = self.control_weights["input_hint_block.0.weight"].shape[1]
-        class control_lora_ops(ControlLoraOps, comfy.ops.disable_weight_init):
-            pass
-        controlnet_config["operations"] = control_lora_ops
-        self.control_model = comfy.cldm.cldm.ControlNet(**controlnet_config)
+        self.manual_cast_dtype = model.manual_cast_dtype
         dtype = model.get_dtype()
-        self.control_model.to(dtype)
+        if self.manual_cast_dtype is None:
+            class control_lora_ops(ControlLoraOps, comfy.ops.disable_weight_init):
+                pass
+        else:
+            class control_lora_ops(ControlLoraOps, comfy.ops.manual_cast):
+                pass
+            dtype = self.manual_cast_dtype
+
+        controlnet_config["operations"] = control_lora_ops
+        controlnet_config["dtype"] = dtype
+        self.control_model = comfy.cldm.cldm.ControlNet(**controlnet_config)
         self.control_model.to(comfy.model_management.get_torch_device())
         diffusion_model = model.diffusion_model
         sd = diffusion_model.state_dict()
@@ -372,6 +380,10 @@ def load_controlnet(ckpt_path, model=None):
     if controlnet_config is None:
         unet_dtype = comfy.model_management.unet_dtype()
         controlnet_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, unet_dtype, True).unet_config
+        load_device = comfy.model_management.get_torch_device()
+        manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
+        if manual_cast_dtype is not None:
+            controlnet_config["operations"] = comfy.ops.manual_cast
     controlnet_config.pop("out_channels")
     controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
     control_model = comfy.cldm.cldm.ControlNet(**controlnet_config)
@@ -400,14 +412,12 @@ def load_controlnet(ckpt_path, model=None):
     missing, unexpected = control_model.load_state_dict(controlnet_data, strict=False)
     print(missing, unexpected)
 
-    control_model = control_model.to(unet_dtype)
-
     global_average_pooling = False
     filename = os.path.splitext(ckpt_path)[0]
    if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling
         global_average_pooling = True
 
-    control = ControlNet(control_model, global_average_pooling=global_average_pooling)
+    control = ControlNet(control_model, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
     return control
 
 class T2IAdapter(ControlBase):