Controlnet/t2iadapter cleanup.

This commit is contained in:
comfyanonymous 2023-08-21 23:20:49 -04:00
parent 763b0cf024
commit cf5ae46928
3 changed files with 58 additions and 55 deletions

View File

@ -632,7 +632,9 @@ class UNetModel(nn.Module):
transformer_options["block"] = ("middle", 0) transformer_options["block"] = ("middle", 0)
h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options) h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options)
if control is not None and 'middle' in control and len(control['middle']) > 0: if control is not None and 'middle' in control and len(control['middle']) > 0:
h += control['middle'].pop() ctrl = control['middle'].pop()
if ctrl is not None:
h += ctrl
for id, module in enumerate(self.output_blocks): for id, module in enumerate(self.output_blocks):
transformer_options["block"] = ("output", id) transformer_options["block"] = ("output", id)

View File

@ -742,6 +742,7 @@ class ControlBase:
device = model_management.get_torch_device() device = model_management.get_torch_device()
self.device = device self.device = device
self.previous_controlnet = None self.previous_controlnet = None
self.global_average_pooling = False
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(1.0, 0.0)): def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(1.0, 0.0)):
self.cond_hint_original = cond_hint self.cond_hint_original = cond_hint
@ -777,6 +778,51 @@ class ControlBase:
c.strength = self.strength c.strength = self.strength
c.timestep_percent_range = self.timestep_percent_range c.timestep_percent_range = self.timestep_percent_range
def control_merge(self, control_input, control_output, control_prev, output_dtype):
out = {'input':[], 'middle':[], 'output': []}
if control_input is not None:
for i in range(len(control_input)):
key = 'input'
x = control_input[i]
if x is not None:
x *= self.strength
if x.dtype != output_dtype:
x = x.to(output_dtype)
out[key].insert(0, x)
if control_output is not None:
for i in range(len(control_output)):
if i == (len(control_output) - 1):
key = 'middle'
index = 0
else:
key = 'output'
index = i
x = control_output[i]
if x is not None:
if self.global_average_pooling:
x = torch.mean(x, dim=(2, 3), keepdim=True).repeat(1, 1, x.shape[2], x.shape[3])
x *= self.strength
if x.dtype != output_dtype:
x = x.to(output_dtype)
out[key].append(x)
if control_prev is not None:
for x in ['input', 'middle', 'output']:
o = out[x]
for i in range(len(control_prev[x])):
prev_val = control_prev[x][i]
if i >= len(o):
o.append(prev_val)
elif prev_val is not None:
if o[i] is None:
o[i] = prev_val
else:
o[i] += prev_val
return out
class ControlNet(ControlBase): class ControlNet(ControlBase):
def __init__(self, control_model, global_average_pooling=False, device=None): def __init__(self, control_model, global_average_pooling=False, device=None):
super().__init__(device) super().__init__(device)
@ -811,32 +857,7 @@ class ControlNet(ControlBase):
if y is not None: if y is not None:
y = y.to(self.control_model.dtype) y = y.to(self.control_model.dtype)
control = self.control_model(x=x_noisy.to(self.control_model.dtype), hint=self.cond_hint, timesteps=t, context=context.to(self.control_model.dtype), y=y) control = self.control_model(x=x_noisy.to(self.control_model.dtype), hint=self.cond_hint, timesteps=t, context=context.to(self.control_model.dtype), y=y)
return self.control_merge(None, control, control_prev, output_dtype)
out = {'middle':[], 'output': []}
for i in range(len(control)):
if i == (len(control) - 1):
key = 'middle'
index = 0
else:
key = 'output'
index = i
x = control[i]
if self.global_average_pooling:
x = torch.mean(x, dim=(2, 3), keepdim=True).repeat(1, 1, x.shape[2], x.shape[3])
x *= self.strength
if x.dtype != output_dtype:
x = x.to(output_dtype)
if control_prev is not None and key in control_prev:
prev = control_prev[key][index]
if prev is not None:
x += prev
out[key].append(x)
if control_prev is not None and 'input' in control_prev:
out['input'] = control_prev['input']
return out
def copy(self): def copy(self):
c = ControlNet(self.control_model, global_average_pooling=self.global_average_pooling) c = ControlNet(self.control_model, global_average_pooling=self.global_average_pooling)
@ -1101,37 +1122,13 @@ class T2IAdapter(ControlBase):
if x_noisy.shape[0] != self.cond_hint.shape[0]: if x_noisy.shape[0] != self.cond_hint.shape[0]:
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number) self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
if self.control_input is None: if self.control_input is None:
self.t2i_model.to(x_noisy.dtype)
self.t2i_model.to(self.device) self.t2i_model.to(self.device)
self.control_input = self.t2i_model(self.cond_hint) self.control_input = self.t2i_model(self.cond_hint.to(x_noisy.dtype))
self.t2i_model.cpu() self.t2i_model.cpu()
output_dtype = x_noisy.dtype control_input = list(map(lambda a: None if a is None else a.clone(), self.control_input))
out = {'input':[]} return self.control_merge(control_input, None, control_prev, x_noisy.dtype)
for i in range(len(self.control_input)):
key = 'input'
x = self.control_input[i] * self.strength
if x.dtype != output_dtype:
x = x.to(output_dtype)
if control_prev is not None and key in control_prev:
index = len(control_prev[key]) - i * 3 - 3
prev = control_prev[key][index]
if prev is not None:
x += prev
out[key].insert(0, None)
out[key].insert(0, None)
out[key].insert(0, x)
if control_prev is not None and 'input' in control_prev:
for i in range(len(out['input'])):
if out['input'][i] is None:
out['input'][i] = control_prev['input'][i]
if control_prev is not None and 'middle' in control_prev:
out['middle'] = control_prev['middle']
if control_prev is not None and 'output' in control_prev:
out['output'] = control_prev['output']
return out
def copy(self): def copy(self):
c = T2IAdapter(self.t2i_model, self.channels_in) c = T2IAdapter(self.t2i_model, self.channels_in)

View File

@ -128,6 +128,8 @@ class Adapter(nn.Module):
for j in range(self.nums_rb): for j in range(self.nums_rb):
idx = i * self.nums_rb + j idx = i * self.nums_rb + j
x = self.body[idx](x) x = self.body[idx](x)
features.append(None)
features.append(None)
features.append(x) features.append(x)
return features return features
@ -259,6 +261,8 @@ class Adapter_light(nn.Module):
features = [] features = []
for i in range(len(self.channels)): for i in range(len(self.channels)):
x = self.body[i](x) x = self.body[i](x)
features.append(None)
features.append(None)
features.append(x) features.append(x)
return features return features