diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py index 11cec0ed..3ce3c2e7 100644 --- a/comfy/ldm/modules/diffusionmodules/openaimodel.py +++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py @@ -632,7 +632,9 @@ class UNetModel(nn.Module): transformer_options["block"] = ("middle", 0) h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options) if control is not None and 'middle' in control and len(control['middle']) > 0: - h += control['middle'].pop() + ctrl = control['middle'].pop() + if ctrl is not None: + h += ctrl for id, module in enumerate(self.output_blocks): transformer_options["block"] = ("output", id) diff --git a/comfy/sd.py b/comfy/sd.py index 3493b1a7..09eab505 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -742,6 +742,7 @@ class ControlBase: device = model_management.get_torch_device() self.device = device self.previous_controlnet = None + self.global_average_pooling = False def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(1.0, 0.0)): self.cond_hint_original = cond_hint @@ -777,6 +778,51 @@ class ControlBase: c.strength = self.strength c.timestep_percent_range = self.timestep_percent_range + def control_merge(self, control_input, control_output, control_prev, output_dtype): + out = {'input':[], 'middle':[], 'output': []} + + if control_input is not None: + for i in range(len(control_input)): + key = 'input' + x = control_input[i] + if x is not None: + x *= self.strength + if x.dtype != output_dtype: + x = x.to(output_dtype) + out[key].insert(0, x) + + if control_output is not None: + for i in range(len(control_output)): + if i == (len(control_output) - 1): + key = 'middle' + index = 0 + else: + key = 'output' + index = i + x = control_output[i] + if x is not None: + if self.global_average_pooling: + x = torch.mean(x, dim=(2, 3), keepdim=True).repeat(1, 1, x.shape[2], x.shape[3]) + + x *= self.strength + if x.dtype != output_dtype: + x = x.to(output_dtype) + + out[key].append(x) + if control_prev is not None: + for x in ['input', 'middle', 'output']: + o = out[x] + for i in range(len(control_prev[x])): + prev_val = control_prev[x][i] + if i >= len(o): + o.append(prev_val) + elif prev_val is not None: + if o[i] is None: + o[i] = prev_val + else: + o[i] += prev_val + return out + class ControlNet(ControlBase): def __init__(self, control_model, global_average_pooling=False, device=None): super().__init__(device) @@ -811,32 +857,7 @@ class ControlNet(ControlBase): if y is not None: y = y.to(self.control_model.dtype) control = self.control_model(x=x_noisy.to(self.control_model.dtype), hint=self.cond_hint, timesteps=t, context=context.to(self.control_model.dtype), y=y) - - out = {'middle':[], 'output': []} - - for i in range(len(control)): - if i == (len(control) - 1): - key = 'middle' - index = 0 - else: - key = 'output' - index = i - x = control[i] - if self.global_average_pooling: - x = torch.mean(x, dim=(2, 3), keepdim=True).repeat(1, 1, x.shape[2], x.shape[3]) - - x *= self.strength - if x.dtype != output_dtype: - x = x.to(output_dtype) - - if control_prev is not None and key in control_prev: - prev = control_prev[key][index] - if prev is not None: - x += prev - out[key].append(x) - if control_prev is not None and 'input' in control_prev: - out['input'] = control_prev['input'] - return out + return self.control_merge(None, control, control_prev, output_dtype) def copy(self): c = ControlNet(self.control_model, global_average_pooling=self.global_average_pooling) @@ -1101,37 +1122,13 @@ class T2IAdapter(ControlBase): if x_noisy.shape[0] != self.cond_hint.shape[0]: self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number) if self.control_input is None: + self.t2i_model.to(x_noisy.dtype) self.t2i_model.to(self.device) - self.control_input = self.t2i_model(self.cond_hint) + self.control_input = self.t2i_model(self.cond_hint.to(x_noisy.dtype)) self.t2i_model.cpu() - output_dtype = x_noisy.dtype - out = {'input':[]} - - for i in range(len(self.control_input)): - key = 'input' - x = self.control_input[i] * self.strength - if x.dtype != output_dtype: - x = x.to(output_dtype) - - if control_prev is not None and key in control_prev: - index = len(control_prev[key]) - i * 3 - 3 - prev = control_prev[key][index] - if prev is not None: - x += prev - out[key].insert(0, None) - out[key].insert(0, None) - out[key].insert(0, x) - - if control_prev is not None and 'input' in control_prev: - for i in range(len(out['input'])): - if out['input'][i] is None: - out['input'][i] = control_prev['input'][i] - if control_prev is not None and 'middle' in control_prev: - out['middle'] = control_prev['middle'] - if control_prev is not None and 'output' in control_prev: - out['output'] = control_prev['output'] - return out + control_input = list(map(lambda a: None if a is None else a.clone(), self.control_input)) + return self.control_merge(control_input, None, control_prev, x_noisy.dtype) def copy(self): c = T2IAdapter(self.t2i_model, self.channels_in) diff --git a/comfy/t2i_adapter/adapter.py b/comfy/t2i_adapter/adapter.py index 87e3d859..3647c4cf 100644 --- a/comfy/t2i_adapter/adapter.py +++ b/comfy/t2i_adapter/adapter.py @@ -128,6 +128,8 @@ class Adapter(nn.Module): for j in range(self.nums_rb): idx = i * self.nums_rb + j x = self.body[idx](x) + features.append(None) + features.append(None) features.append(x) return features @@ -259,6 +261,8 @@ class Adapter_light(nn.Module): features = [] for i in range(len(self.channels)): x = self.body[i](x) + features.append(None) + features.append(None) features.append(x) return features