From 7c0f255de16b78e54e0c051e9f7e1e46c7422c6c Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 31 Oct 2023 22:14:32 -0400 Subject: [PATCH] Clean up percent start/end and make controlnets work with sigmas. --- comfy/controlnet.py | 14 +++++++++++++- comfy/model_base.py | 5 ++++- comfy/samplers.py | 16 +++++++++------- 3 files changed, 26 insertions(+), 9 deletions(-) diff --git a/comfy/controlnet.py b/comfy/controlnet.py index 2a88dd019..098681582 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -132,6 +132,7 @@ class ControlNet(ControlBase): self.control_model = control_model self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device()) self.global_average_pooling = global_average_pooling + self.model_sampling_current = None def get_control(self, x_noisy, t, cond, batched_number): control_prev = None @@ -159,7 +160,10 @@ class ControlNet(ControlBase): y = cond.get('y', None) if y is not None: y = y.to(self.control_model.dtype) - control = self.control_model(x=x_noisy.to(self.control_model.dtype), hint=self.cond_hint, timesteps=t, context=context.to(self.control_model.dtype), y=y) + timestep = self.model_sampling_current.timestep(t) + x_noisy = self.model_sampling_current.calculate_input(t, x_noisy) + + control = self.control_model(x=x_noisy.to(self.control_model.dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(self.control_model.dtype), y=y) return self.control_merge(None, control, control_prev, output_dtype) def copy(self): @@ -172,6 +176,14 @@ class ControlNet(ControlBase): out.append(self.control_model_wrapped) return out + def pre_run(self, model, percent_to_timestep_function): + super().pre_run(model, percent_to_timestep_function) + self.model_sampling_current = model.model_sampling + + def cleanup(self): + self.model_sampling_current = None + super().cleanup() + class ControlLoraOps: class Linear(torch.nn.Module): def __init__(self, in_features: int, out_features: int, bias: bool = True, diff --git a/comfy/model_base.py b/comfy/model_base.py index b8d04a2c8..84cf9829d 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -82,6 +82,9 @@ class ModelSamplingDiscrete(torch.nn.Module): log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx] return log_sigma.exp() + def percent_to_sigma(self, percent): + return self.sigma(torch.tensor(percent * 999.0)) + def model_sampling(model_config, model_type): if model_type == ModelType.EPS: c = EPS @@ -126,7 +129,7 @@ class BaseModel(torch.nn.Module): context = c_crossattn dtype = self.get_dtype() xc = xc.to(dtype) - t = self.model_sampling.timestep(t).to(dtype) + t = self.model_sampling.timestep(t).float() context = context.to(dtype) extra_conds = {} for o in kwargs: diff --git a/comfy/samplers.py b/comfy/samplers.py index e10e02c41..a74c8a1b8 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -415,15 +415,16 @@ def create_cond_with_same_area_if_none(conds, c): conds += [out] def calculate_start_end_timesteps(model, conds): + s = model.model_sampling for t in range(len(conds)): x = conds[t] timestep_start = None timestep_end = None if 'start_percent' in x: - timestep_start = model.sigma_to_t(model.t_to_sigma(torch.tensor(x['start_percent'] * 999.0))) + timestep_start = s.percent_to_sigma(x['start_percent']) if 'end_percent' in x: - timestep_end = model.sigma_to_t(model.t_to_sigma(torch.tensor(x['end_percent'] * 999.0))) + timestep_end = s.percent_to_sigma(x['end_percent']) if (timestep_start is not None) or (timestep_end is not None): n = x.copy() @@ -434,14 +435,15 @@ def calculate_start_end_timesteps(model, conds): conds[t] = n def pre_run_control(model, conds): + s = model.model_sampling for t in range(len(conds)): x = conds[t] timestep_start = None timestep_end = None - percent_to_timestep_function = lambda a: model.sigma_to_t(model.t_to_sigma(torch.tensor(a) * 999.0)) + percent_to_timestep_function = lambda a: s.percent_to_sigma(a) if 'control' in x: - x['control'].pre_run(model.inner_model.inner_model, percent_to_timestep_function) + x['control'].pre_run(model, percent_to_timestep_function) def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func): cond_cnets = [] @@ -571,8 +573,8 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model model_wrap = wrap_model(model) - calculate_start_end_timesteps(model_wrap, negative) - calculate_start_end_timesteps(model_wrap, positive) + calculate_start_end_timesteps(model, negative) + calculate_start_end_timesteps(model, positive) #make sure each cond area has an opposite one with the same area for c in positive: @@ -580,7 +582,7 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model for c in negative: create_cond_with_same_area_if_none(positive, c) - pre_run_control(model_wrap, negative + positive) + pre_run_control(model, negative + positive) apply_empty_x_to_equal_area(list(filter(lambda c: c.get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x]) apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x])