mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
Remove autocast from controlnet code.
This commit is contained in:
parent
0d7b0a4dc7
commit
d08e53de2e
@ -279,7 +279,7 @@ class ControlNet(nn.Module):
|
|||||||
return TimestepEmbedSequential(zero_module(operations.conv_nd(self.dims, channels, channels, 1, padding=0)))
|
return TimestepEmbedSequential(zero_module(operations.conv_nd(self.dims, channels, channels, 1, padding=0)))
|
||||||
|
|
||||||
def forward(self, x, hint, timesteps, context, y=None, **kwargs):
|
def forward(self, x, hint, timesteps, context, y=None, **kwargs):
|
||||||
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
|
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(self.dtype)
|
||||||
emb = self.time_embed(t_emb)
|
emb = self.time_embed(t_emb)
|
||||||
|
|
||||||
guided_hint = self.input_hint_block(hint, emb, context)
|
guided_hint = self.input_hint_block(hint, emb, context)
|
||||||
@ -287,9 +287,6 @@ class ControlNet(nn.Module):
|
|||||||
outs = []
|
outs = []
|
||||||
|
|
||||||
hs = []
|
hs = []
|
||||||
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
|
|
||||||
emb = self.time_embed(t_emb)
|
|
||||||
|
|
||||||
if self.num_classes is not None:
|
if self.num_classes is not None:
|
||||||
assert y.shape[0] == x.shape[0]
|
assert y.shape[0] == x.shape[0]
|
||||||
emb = emb + self.label_emb(y)
|
emb = emb + self.label_emb(y)
|
||||||
|
20
comfy/sd.py
20
comfy/sd.py
@ -798,17 +798,14 @@ class ControlNet(ControlBase):
|
|||||||
if x_noisy.shape[0] != self.cond_hint.shape[0]:
|
if x_noisy.shape[0] != self.cond_hint.shape[0]:
|
||||||
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
|
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
|
||||||
|
|
||||||
if self.control_model.dtype == torch.float16:
|
|
||||||
precision_scope = torch.autocast
|
|
||||||
else:
|
|
||||||
precision_scope = contextlib.nullcontext
|
|
||||||
|
|
||||||
with precision_scope(model_management.get_autocast_device(self.device)):
|
context = torch.cat(cond['c_crossattn'], 1)
|
||||||
context = torch.cat(cond['c_crossattn'], 1)
|
y = cond.get('c_adm', None)
|
||||||
y = cond.get('c_adm', None)
|
if y is not None:
|
||||||
control = self.control_model(x=x_noisy, hint=self.cond_hint, timesteps=t, context=context, y=y)
|
y = y.to(self.control_model.dtype)
|
||||||
|
control = self.control_model(x=x_noisy.to(self.control_model.dtype), hint=self.cond_hint, timesteps=t, context=context.to(self.control_model.dtype), y=y)
|
||||||
|
|
||||||
out = {'middle':[], 'output': []}
|
out = {'middle':[], 'output': []}
|
||||||
autocast_enabled = torch.is_autocast_enabled()
|
|
||||||
|
|
||||||
for i in range(len(control)):
|
for i in range(len(control)):
|
||||||
if i == (len(control) - 1):
|
if i == (len(control) - 1):
|
||||||
@ -822,7 +819,7 @@ class ControlNet(ControlBase):
|
|||||||
x = torch.mean(x, dim=(2, 3), keepdim=True).repeat(1, 1, x.shape[2], x.shape[3])
|
x = torch.mean(x, dim=(2, 3), keepdim=True).repeat(1, 1, x.shape[2], x.shape[3])
|
||||||
|
|
||||||
x *= self.strength
|
x *= self.strength
|
||||||
if x.dtype != output_dtype and not autocast_enabled:
|
if x.dtype != output_dtype:
|
||||||
x = x.to(output_dtype)
|
x = x.to(output_dtype)
|
||||||
|
|
||||||
if control_prev is not None and key in control_prev:
|
if control_prev is not None and key in control_prev:
|
||||||
@ -1098,11 +1095,10 @@ class T2IAdapter(ControlBase):
|
|||||||
output_dtype = x_noisy.dtype
|
output_dtype = x_noisy.dtype
|
||||||
out = {'input':[]}
|
out = {'input':[]}
|
||||||
|
|
||||||
autocast_enabled = torch.is_autocast_enabled()
|
|
||||||
for i in range(len(self.control_input)):
|
for i in range(len(self.control_input)):
|
||||||
key = 'input'
|
key = 'input'
|
||||||
x = self.control_input[i] * self.strength
|
x = self.control_input[i] * self.strength
|
||||||
if x.dtype != output_dtype and not autocast_enabled:
|
if x.dtype != output_dtype:
|
||||||
x = x.to(output_dtype)
|
x = x.to(output_dtype)
|
||||||
|
|
||||||
if control_prev is not None and key in control_prev:
|
if control_prev is not None and key in control_prev:
|
||||||
|
Loading…
Reference in New Issue
Block a user