diff --git a/comfy/cldm/cldm.py b/comfy/cldm/cldm.py index 5201b3c2..25148313 100644 --- a/comfy/cldm/cldm.py +++ b/comfy/cldm/cldm.py @@ -279,7 +279,7 @@ class ControlNet(nn.Module): return TimestepEmbedSequential(zero_module(operations.conv_nd(self.dims, channels, channels, 1, padding=0))) def forward(self, x, hint, timesteps, context, y=None, **kwargs): - t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(self.dtype) emb = self.time_embed(t_emb) guided_hint = self.input_hint_block(hint, emb, context) @@ -287,9 +287,6 @@ class ControlNet(nn.Module): outs = [] hs = [] - t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) - emb = self.time_embed(t_emb) - if self.num_classes is not None: assert y.shape[0] == x.shape[0] emb = emb + self.label_emb(y) diff --git a/comfy/sd.py b/comfy/sd.py index dc5daffa..85806e70 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -798,17 +798,14 @@ class ControlNet(ControlBase): if x_noisy.shape[0] != self.cond_hint.shape[0]: self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number) - if self.control_model.dtype == torch.float16: - precision_scope = torch.autocast - else: - precision_scope = contextlib.nullcontext - with precision_scope(model_management.get_autocast_device(self.device)): - context = torch.cat(cond['c_crossattn'], 1) - y = cond.get('c_adm', None) - control = self.control_model(x=x_noisy, hint=self.cond_hint, timesteps=t, context=context, y=y) + context = torch.cat(cond['c_crossattn'], 1) + y = cond.get('c_adm', None) + if y is not None: + y = y.to(self.control_model.dtype) + control = self.control_model(x=x_noisy.to(self.control_model.dtype), hint=self.cond_hint, timesteps=t, context=context.to(self.control_model.dtype), y=y) + out = {'middle':[], 'output': []} - autocast_enabled = torch.is_autocast_enabled() for i in range(len(control)): if i == (len(control) - 1): @@ -822,7 +819,7 @@ class ControlNet(ControlBase): x = torch.mean(x, dim=(2, 3), keepdim=True).repeat(1, 1, x.shape[2], x.shape[3]) x *= self.strength - if x.dtype != output_dtype and not autocast_enabled: + if x.dtype != output_dtype: x = x.to(output_dtype) if control_prev is not None and key in control_prev: @@ -1098,11 +1095,10 @@ class T2IAdapter(ControlBase): output_dtype = x_noisy.dtype out = {'input':[]} - autocast_enabled = torch.is_autocast_enabled() for i in range(len(self.control_input)): key = 'input' x = self.control_input[i] * self.strength - if x.dtype != output_dtype and not autocast_enabled: + if x.dtype != output_dtype: x = x.to(output_dtype) if control_prev is not None and key in control_prev: