mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-11 02:15:17 +00:00
Controlnet/t2iadapter cleanup.
This commit is contained in:
parent
763b0cf024
commit
cf5ae46928
@ -632,7 +632,9 @@ class UNetModel(nn.Module):
|
||||
transformer_options["block"] = ("middle", 0)
|
||||
h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options)
|
||||
if control is not None and 'middle' in control and len(control['middle']) > 0:
|
||||
h += control['middle'].pop()
|
||||
ctrl = control['middle'].pop()
|
||||
if ctrl is not None:
|
||||
h += ctrl
|
||||
|
||||
for id, module in enumerate(self.output_blocks):
|
||||
transformer_options["block"] = ("output", id)
|
||||
|
105
comfy/sd.py
105
comfy/sd.py
@ -742,6 +742,7 @@ class ControlBase:
|
||||
device = model_management.get_torch_device()
|
||||
self.device = device
|
||||
self.previous_controlnet = None
|
||||
self.global_average_pooling = False
|
||||
|
||||
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(1.0, 0.0)):
|
||||
self.cond_hint_original = cond_hint
|
||||
@ -777,6 +778,51 @@ class ControlBase:
|
||||
c.strength = self.strength
|
||||
c.timestep_percent_range = self.timestep_percent_range
|
||||
|
||||
def control_merge(self, control_input, control_output, control_prev, output_dtype):
|
||||
out = {'input':[], 'middle':[], 'output': []}
|
||||
|
||||
if control_input is not None:
|
||||
for i in range(len(control_input)):
|
||||
key = 'input'
|
||||
x = control_input[i]
|
||||
if x is not None:
|
||||
x *= self.strength
|
||||
if x.dtype != output_dtype:
|
||||
x = x.to(output_dtype)
|
||||
out[key].insert(0, x)
|
||||
|
||||
if control_output is not None:
|
||||
for i in range(len(control_output)):
|
||||
if i == (len(control_output) - 1):
|
||||
key = 'middle'
|
||||
index = 0
|
||||
else:
|
||||
key = 'output'
|
||||
index = i
|
||||
x = control_output[i]
|
||||
if x is not None:
|
||||
if self.global_average_pooling:
|
||||
x = torch.mean(x, dim=(2, 3), keepdim=True).repeat(1, 1, x.shape[2], x.shape[3])
|
||||
|
||||
x *= self.strength
|
||||
if x.dtype != output_dtype:
|
||||
x = x.to(output_dtype)
|
||||
|
||||
out[key].append(x)
|
||||
if control_prev is not None:
|
||||
for x in ['input', 'middle', 'output']:
|
||||
o = out[x]
|
||||
for i in range(len(control_prev[x])):
|
||||
prev_val = control_prev[x][i]
|
||||
if i >= len(o):
|
||||
o.append(prev_val)
|
||||
elif prev_val is not None:
|
||||
if o[i] is None:
|
||||
o[i] = prev_val
|
||||
else:
|
||||
o[i] += prev_val
|
||||
return out
|
||||
|
||||
class ControlNet(ControlBase):
|
||||
def __init__(self, control_model, global_average_pooling=False, device=None):
|
||||
super().__init__(device)
|
||||
@ -811,32 +857,7 @@ class ControlNet(ControlBase):
|
||||
if y is not None:
|
||||
y = y.to(self.control_model.dtype)
|
||||
control = self.control_model(x=x_noisy.to(self.control_model.dtype), hint=self.cond_hint, timesteps=t, context=context.to(self.control_model.dtype), y=y)
|
||||
|
||||
out = {'middle':[], 'output': []}
|
||||
|
||||
for i in range(len(control)):
|
||||
if i == (len(control) - 1):
|
||||
key = 'middle'
|
||||
index = 0
|
||||
else:
|
||||
key = 'output'
|
||||
index = i
|
||||
x = control[i]
|
||||
if self.global_average_pooling:
|
||||
x = torch.mean(x, dim=(2, 3), keepdim=True).repeat(1, 1, x.shape[2], x.shape[3])
|
||||
|
||||
x *= self.strength
|
||||
if x.dtype != output_dtype:
|
||||
x = x.to(output_dtype)
|
||||
|
||||
if control_prev is not None and key in control_prev:
|
||||
prev = control_prev[key][index]
|
||||
if prev is not None:
|
||||
x += prev
|
||||
out[key].append(x)
|
||||
if control_prev is not None and 'input' in control_prev:
|
||||
out['input'] = control_prev['input']
|
||||
return out
|
||||
return self.control_merge(None, control, control_prev, output_dtype)
|
||||
|
||||
def copy(self):
|
||||
c = ControlNet(self.control_model, global_average_pooling=self.global_average_pooling)
|
||||
@ -1101,37 +1122,13 @@ class T2IAdapter(ControlBase):
|
||||
if x_noisy.shape[0] != self.cond_hint.shape[0]:
|
||||
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
|
||||
if self.control_input is None:
|
||||
self.t2i_model.to(x_noisy.dtype)
|
||||
self.t2i_model.to(self.device)
|
||||
self.control_input = self.t2i_model(self.cond_hint)
|
||||
self.control_input = self.t2i_model(self.cond_hint.to(x_noisy.dtype))
|
||||
self.t2i_model.cpu()
|
||||
|
||||
output_dtype = x_noisy.dtype
|
||||
out = {'input':[]}
|
||||
|
||||
for i in range(len(self.control_input)):
|
||||
key = 'input'
|
||||
x = self.control_input[i] * self.strength
|
||||
if x.dtype != output_dtype:
|
||||
x = x.to(output_dtype)
|
||||
|
||||
if control_prev is not None and key in control_prev:
|
||||
index = len(control_prev[key]) - i * 3 - 3
|
||||
prev = control_prev[key][index]
|
||||
if prev is not None:
|
||||
x += prev
|
||||
out[key].insert(0, None)
|
||||
out[key].insert(0, None)
|
||||
out[key].insert(0, x)
|
||||
|
||||
if control_prev is not None and 'input' in control_prev:
|
||||
for i in range(len(out['input'])):
|
||||
if out['input'][i] is None:
|
||||
out['input'][i] = control_prev['input'][i]
|
||||
if control_prev is not None and 'middle' in control_prev:
|
||||
out['middle'] = control_prev['middle']
|
||||
if control_prev is not None and 'output' in control_prev:
|
||||
out['output'] = control_prev['output']
|
||||
return out
|
||||
control_input = list(map(lambda a: None if a is None else a.clone(), self.control_input))
|
||||
return self.control_merge(control_input, None, control_prev, x_noisy.dtype)
|
||||
|
||||
def copy(self):
|
||||
c = T2IAdapter(self.t2i_model, self.channels_in)
|
||||
|
@ -128,6 +128,8 @@ class Adapter(nn.Module):
|
||||
for j in range(self.nums_rb):
|
||||
idx = i * self.nums_rb + j
|
||||
x = self.body[idx](x)
|
||||
features.append(None)
|
||||
features.append(None)
|
||||
features.append(x)
|
||||
|
||||
return features
|
||||
@ -259,6 +261,8 @@ class Adapter_light(nn.Module):
|
||||
features = []
|
||||
for i in range(len(self.channels)):
|
||||
x = self.body[i](x)
|
||||
features.append(None)
|
||||
features.append(None)
|
||||
features.append(x)
|
||||
|
||||
return features
|
||||
|
Loading…
Reference in New Issue
Block a user