mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-11 02:15:17 +00:00
Controlnet/t2iadapter cleanup.
This commit is contained in:
parent
763b0cf024
commit
cf5ae46928
@ -632,7 +632,9 @@ class UNetModel(nn.Module):
|
|||||||
transformer_options["block"] = ("middle", 0)
|
transformer_options["block"] = ("middle", 0)
|
||||||
h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options)
|
h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options)
|
||||||
if control is not None and 'middle' in control and len(control['middle']) > 0:
|
if control is not None and 'middle' in control and len(control['middle']) > 0:
|
||||||
h += control['middle'].pop()
|
ctrl = control['middle'].pop()
|
||||||
|
if ctrl is not None:
|
||||||
|
h += ctrl
|
||||||
|
|
||||||
for id, module in enumerate(self.output_blocks):
|
for id, module in enumerate(self.output_blocks):
|
||||||
transformer_options["block"] = ("output", id)
|
transformer_options["block"] = ("output", id)
|
||||||
|
105
comfy/sd.py
105
comfy/sd.py
@ -742,6 +742,7 @@ class ControlBase:
|
|||||||
device = model_management.get_torch_device()
|
device = model_management.get_torch_device()
|
||||||
self.device = device
|
self.device = device
|
||||||
self.previous_controlnet = None
|
self.previous_controlnet = None
|
||||||
|
self.global_average_pooling = False
|
||||||
|
|
||||||
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(1.0, 0.0)):
|
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(1.0, 0.0)):
|
||||||
self.cond_hint_original = cond_hint
|
self.cond_hint_original = cond_hint
|
||||||
@ -777,6 +778,51 @@ class ControlBase:
|
|||||||
c.strength = self.strength
|
c.strength = self.strength
|
||||||
c.timestep_percent_range = self.timestep_percent_range
|
c.timestep_percent_range = self.timestep_percent_range
|
||||||
|
|
||||||
|
def control_merge(self, control_input, control_output, control_prev, output_dtype):
|
||||||
|
out = {'input':[], 'middle':[], 'output': []}
|
||||||
|
|
||||||
|
if control_input is not None:
|
||||||
|
for i in range(len(control_input)):
|
||||||
|
key = 'input'
|
||||||
|
x = control_input[i]
|
||||||
|
if x is not None:
|
||||||
|
x *= self.strength
|
||||||
|
if x.dtype != output_dtype:
|
||||||
|
x = x.to(output_dtype)
|
||||||
|
out[key].insert(0, x)
|
||||||
|
|
||||||
|
if control_output is not None:
|
||||||
|
for i in range(len(control_output)):
|
||||||
|
if i == (len(control_output) - 1):
|
||||||
|
key = 'middle'
|
||||||
|
index = 0
|
||||||
|
else:
|
||||||
|
key = 'output'
|
||||||
|
index = i
|
||||||
|
x = control_output[i]
|
||||||
|
if x is not None:
|
||||||
|
if self.global_average_pooling:
|
||||||
|
x = torch.mean(x, dim=(2, 3), keepdim=True).repeat(1, 1, x.shape[2], x.shape[3])
|
||||||
|
|
||||||
|
x *= self.strength
|
||||||
|
if x.dtype != output_dtype:
|
||||||
|
x = x.to(output_dtype)
|
||||||
|
|
||||||
|
out[key].append(x)
|
||||||
|
if control_prev is not None:
|
||||||
|
for x in ['input', 'middle', 'output']:
|
||||||
|
o = out[x]
|
||||||
|
for i in range(len(control_prev[x])):
|
||||||
|
prev_val = control_prev[x][i]
|
||||||
|
if i >= len(o):
|
||||||
|
o.append(prev_val)
|
||||||
|
elif prev_val is not None:
|
||||||
|
if o[i] is None:
|
||||||
|
o[i] = prev_val
|
||||||
|
else:
|
||||||
|
o[i] += prev_val
|
||||||
|
return out
|
||||||
|
|
||||||
class ControlNet(ControlBase):
|
class ControlNet(ControlBase):
|
||||||
def __init__(self, control_model, global_average_pooling=False, device=None):
|
def __init__(self, control_model, global_average_pooling=False, device=None):
|
||||||
super().__init__(device)
|
super().__init__(device)
|
||||||
@ -811,32 +857,7 @@ class ControlNet(ControlBase):
|
|||||||
if y is not None:
|
if y is not None:
|
||||||
y = y.to(self.control_model.dtype)
|
y = y.to(self.control_model.dtype)
|
||||||
control = self.control_model(x=x_noisy.to(self.control_model.dtype), hint=self.cond_hint, timesteps=t, context=context.to(self.control_model.dtype), y=y)
|
control = self.control_model(x=x_noisy.to(self.control_model.dtype), hint=self.cond_hint, timesteps=t, context=context.to(self.control_model.dtype), y=y)
|
||||||
|
return self.control_merge(None, control, control_prev, output_dtype)
|
||||||
out = {'middle':[], 'output': []}
|
|
||||||
|
|
||||||
for i in range(len(control)):
|
|
||||||
if i == (len(control) - 1):
|
|
||||||
key = 'middle'
|
|
||||||
index = 0
|
|
||||||
else:
|
|
||||||
key = 'output'
|
|
||||||
index = i
|
|
||||||
x = control[i]
|
|
||||||
if self.global_average_pooling:
|
|
||||||
x = torch.mean(x, dim=(2, 3), keepdim=True).repeat(1, 1, x.shape[2], x.shape[3])
|
|
||||||
|
|
||||||
x *= self.strength
|
|
||||||
if x.dtype != output_dtype:
|
|
||||||
x = x.to(output_dtype)
|
|
||||||
|
|
||||||
if control_prev is not None and key in control_prev:
|
|
||||||
prev = control_prev[key][index]
|
|
||||||
if prev is not None:
|
|
||||||
x += prev
|
|
||||||
out[key].append(x)
|
|
||||||
if control_prev is not None and 'input' in control_prev:
|
|
||||||
out['input'] = control_prev['input']
|
|
||||||
return out
|
|
||||||
|
|
||||||
def copy(self):
|
def copy(self):
|
||||||
c = ControlNet(self.control_model, global_average_pooling=self.global_average_pooling)
|
c = ControlNet(self.control_model, global_average_pooling=self.global_average_pooling)
|
||||||
@ -1101,37 +1122,13 @@ class T2IAdapter(ControlBase):
|
|||||||
if x_noisy.shape[0] != self.cond_hint.shape[0]:
|
if x_noisy.shape[0] != self.cond_hint.shape[0]:
|
||||||
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
|
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
|
||||||
if self.control_input is None:
|
if self.control_input is None:
|
||||||
|
self.t2i_model.to(x_noisy.dtype)
|
||||||
self.t2i_model.to(self.device)
|
self.t2i_model.to(self.device)
|
||||||
self.control_input = self.t2i_model(self.cond_hint)
|
self.control_input = self.t2i_model(self.cond_hint.to(x_noisy.dtype))
|
||||||
self.t2i_model.cpu()
|
self.t2i_model.cpu()
|
||||||
|
|
||||||
output_dtype = x_noisy.dtype
|
control_input = list(map(lambda a: None if a is None else a.clone(), self.control_input))
|
||||||
out = {'input':[]}
|
return self.control_merge(control_input, None, control_prev, x_noisy.dtype)
|
||||||
|
|
||||||
for i in range(len(self.control_input)):
|
|
||||||
key = 'input'
|
|
||||||
x = self.control_input[i] * self.strength
|
|
||||||
if x.dtype != output_dtype:
|
|
||||||
x = x.to(output_dtype)
|
|
||||||
|
|
||||||
if control_prev is not None and key in control_prev:
|
|
||||||
index = len(control_prev[key]) - i * 3 - 3
|
|
||||||
prev = control_prev[key][index]
|
|
||||||
if prev is not None:
|
|
||||||
x += prev
|
|
||||||
out[key].insert(0, None)
|
|
||||||
out[key].insert(0, None)
|
|
||||||
out[key].insert(0, x)
|
|
||||||
|
|
||||||
if control_prev is not None and 'input' in control_prev:
|
|
||||||
for i in range(len(out['input'])):
|
|
||||||
if out['input'][i] is None:
|
|
||||||
out['input'][i] = control_prev['input'][i]
|
|
||||||
if control_prev is not None and 'middle' in control_prev:
|
|
||||||
out['middle'] = control_prev['middle']
|
|
||||||
if control_prev is not None and 'output' in control_prev:
|
|
||||||
out['output'] = control_prev['output']
|
|
||||||
return out
|
|
||||||
|
|
||||||
def copy(self):
|
def copy(self):
|
||||||
c = T2IAdapter(self.t2i_model, self.channels_in)
|
c = T2IAdapter(self.t2i_model, self.channels_in)
|
||||||
|
@ -128,6 +128,8 @@ class Adapter(nn.Module):
|
|||||||
for j in range(self.nums_rb):
|
for j in range(self.nums_rb):
|
||||||
idx = i * self.nums_rb + j
|
idx = i * self.nums_rb + j
|
||||||
x = self.body[idx](x)
|
x = self.body[idx](x)
|
||||||
|
features.append(None)
|
||||||
|
features.append(None)
|
||||||
features.append(x)
|
features.append(x)
|
||||||
|
|
||||||
return features
|
return features
|
||||||
@ -259,6 +261,8 @@ class Adapter_light(nn.Module):
|
|||||||
features = []
|
features = []
|
||||||
for i in range(len(self.channels)):
|
for i in range(len(self.channels)):
|
||||||
x = self.body[i](x)
|
x = self.body[i](x)
|
||||||
|
features.append(None)
|
||||||
|
features.append(None)
|
||||||
features.append(x)
|
features.append(x)
|
||||||
|
|
||||||
return features
|
return features
|
||||||
|
Loading…
Reference in New Issue
Block a user