mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
Remove autocast from controlnet code.
This commit is contained in:
parent
0d7b0a4dc7
commit
d08e53de2e
@ -279,7 +279,7 @@ class ControlNet(nn.Module):
|
||||
return TimestepEmbedSequential(zero_module(operations.conv_nd(self.dims, channels, channels, 1, padding=0)))
|
||||
|
||||
def forward(self, x, hint, timesteps, context, y=None, **kwargs):
|
||||
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
|
||||
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(self.dtype)
|
||||
emb = self.time_embed(t_emb)
|
||||
|
||||
guided_hint = self.input_hint_block(hint, emb, context)
|
||||
@ -287,9 +287,6 @@ class ControlNet(nn.Module):
|
||||
outs = []
|
||||
|
||||
hs = []
|
||||
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
|
||||
emb = self.time_embed(t_emb)
|
||||
|
||||
if self.num_classes is not None:
|
||||
assert y.shape[0] == x.shape[0]
|
||||
emb = emb + self.label_emb(y)
|
||||
|
16
comfy/sd.py
16
comfy/sd.py
@ -798,17 +798,14 @@ class ControlNet(ControlBase):
|
||||
if x_noisy.shape[0] != self.cond_hint.shape[0]:
|
||||
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
|
||||
|
||||
if self.control_model.dtype == torch.float16:
|
||||
precision_scope = torch.autocast
|
||||
else:
|
||||
precision_scope = contextlib.nullcontext
|
||||
|
||||
with precision_scope(model_management.get_autocast_device(self.device)):
|
||||
context = torch.cat(cond['c_crossattn'], 1)
|
||||
y = cond.get('c_adm', None)
|
||||
control = self.control_model(x=x_noisy, hint=self.cond_hint, timesteps=t, context=context, y=y)
|
||||
if y is not None:
|
||||
y = y.to(self.control_model.dtype)
|
||||
control = self.control_model(x=x_noisy.to(self.control_model.dtype), hint=self.cond_hint, timesteps=t, context=context.to(self.control_model.dtype), y=y)
|
||||
|
||||
out = {'middle':[], 'output': []}
|
||||
autocast_enabled = torch.is_autocast_enabled()
|
||||
|
||||
for i in range(len(control)):
|
||||
if i == (len(control) - 1):
|
||||
@ -822,7 +819,7 @@ class ControlNet(ControlBase):
|
||||
x = torch.mean(x, dim=(2, 3), keepdim=True).repeat(1, 1, x.shape[2], x.shape[3])
|
||||
|
||||
x *= self.strength
|
||||
if x.dtype != output_dtype and not autocast_enabled:
|
||||
if x.dtype != output_dtype:
|
||||
x = x.to(output_dtype)
|
||||
|
||||
if control_prev is not None and key in control_prev:
|
||||
@ -1098,11 +1095,10 @@ class T2IAdapter(ControlBase):
|
||||
output_dtype = x_noisy.dtype
|
||||
out = {'input':[]}
|
||||
|
||||
autocast_enabled = torch.is_autocast_enabled()
|
||||
for i in range(len(self.control_input)):
|
||||
key = 'input'
|
||||
x = self.control_input[i] * self.strength
|
||||
if x.dtype != output_dtype and not autocast_enabled:
|
||||
if x.dtype != output_dtype:
|
||||
x = x.to(output_dtype)
|
||||
|
||||
if control_prev is not None and key in control_prev:
|
||||
|
Loading…
Reference in New Issue
Block a user