def resize_image_to(tensor, target_latent_tensor, batched_number):
    """Resize a hint image batch and match its batch size to a latent batch.

    The image is first rescaled with ``utils.common_upscale`` to 8x the last
    two dimensions of ``target_latent_tensor`` (``shape[3] * 8`` and
    ``shape[2] * 8``), using ``'nearest-exact'`` interpolation and
    ``"center"`` cropping. Its batch dimension is then trimmed or tiled so
    it lines up with the target batch.

    Args:
        tensor: Hint image batch; dimension 0 is treated as the batch axis.
            NOTE(review): presumably laid out as (batch, channels, height,
            width) -- confirm against ``utils.common_upscale``.
        target_latent_tensor: Tensor whose ``shape[0]`` gives the target
            batch size and whose ``shape[2]``/``shape[3]`` give the spatial
            size to scale from.
        batched_number: Number of equally sized groups stacked along the
            batch axis of ``target_latent_tensor``. Must be non-zero.
            NOTE(review): the caller passes ``len(cond_or_uncond)`` -- the
            cond/uncond chunks evaluated in one forward pass.

    Returns:
        The resized tensor. A single-image batch is returned without any
        batch expansion; otherwise the batch is cut or cyclically repeated
        to ``target_batch_size // batched_number`` entries and, unless that
        already equals the target batch size, repeated ``batched_number``
        times along dimension 0.
    """
    tensor = utils.common_upscale(
        tensor,
        target_latent_tensor.shape[3] * 8,
        target_latent_tensor.shape[2] * 8,
        'nearest-exact',
        "center",
    )
    target_batch_size = target_latent_tensor.shape[0]

    # A single hint is returned as-is; no per-sample copies are made here.
    if tensor.shape[0] == 1:
        return tensor

    # Number of hint images needed for one cond/uncond group.
    per_batch = target_batch_size // batched_number
    tensor = tensor[:per_batch]

    hint_count = tensor.shape[0]
    if per_batch > hint_count:
        # Too few hints: cycle through them until one group is filled,
        # appending a partial pass for any remainder.
        repeats, remainder = divmod(per_batch, hint_count)
        tensor = torch.cat([tensor] * repeats + [tensor[:remainder]], dim=0)

    if tensor.shape[0] == target_batch_size:
        return tensor
    # Duplicate the per-group hints once for every group in the batch.
    return torch.cat([tensor] * batched_number, dim=0)
ControlNet: if self.cond_hint is not None: del self.cond_hint self.cond_hint = None - self.cond_hint = utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(self.control_model.dtype).to(self.device) + self.cond_hint = resize_image_to(self.cond_hint_original, x_noisy, batched_number).to(self.control_model.dtype).to(self.device) if self.control_model.dtype == torch.float16: precision_scope = torch.autocast @@ -516,7 +538,7 @@ class T2IAdapter: self.cond_hint_original = None self.cond_hint = None - def get_control(self, x_noisy, t, cond_txt): + def get_control(self, x_noisy, t, cond_txt, batched_number): control_prev = None if self.previous_controlnet is not None: control_prev = self.previous_controlnet.get_control(x_noisy, t, cond_txt) @@ -525,7 +547,7 @@ class T2IAdapter: if self.cond_hint is not None: del self.cond_hint self.cond_hint = None - self.cond_hint = utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").float().to(self.device) + self.cond_hint = resize_image_to(self.cond_hint_original, x_noisy, batched_number).float().to(self.device) if self.channels_in == 1 and self.cond_hint.shape[1] > 1: self.cond_hint = torch.mean(self.cond_hint, 1, keepdim=True) self.t2i_model.to(self.device)