Clean up percent start/end and make controlnets work with sigmas.

This commit is contained in:
comfyanonymous 2023-10-31 22:14:32 -04:00
parent a268a574fa
commit 7c0f255de1
3 changed files with 26 additions and 9 deletions

View File

@ -132,6 +132,7 @@ class ControlNet(ControlBase):
self.control_model = control_model self.control_model = control_model
self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device()) self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
self.global_average_pooling = global_average_pooling self.global_average_pooling = global_average_pooling
self.model_sampling_current = None
def get_control(self, x_noisy, t, cond, batched_number): def get_control(self, x_noisy, t, cond, batched_number):
control_prev = None control_prev = None
@ -159,7 +160,10 @@ class ControlNet(ControlBase):
y = cond.get('y', None) y = cond.get('y', None)
if y is not None: if y is not None:
y = y.to(self.control_model.dtype) y = y.to(self.control_model.dtype)
control = self.control_model(x=x_noisy.to(self.control_model.dtype), hint=self.cond_hint, timesteps=t, context=context.to(self.control_model.dtype), y=y) timestep = self.model_sampling_current.timestep(t)
x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
control = self.control_model(x=x_noisy.to(self.control_model.dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(self.control_model.dtype), y=y)
return self.control_merge(None, control, control_prev, output_dtype) return self.control_merge(None, control, control_prev, output_dtype)
def copy(self): def copy(self):
@ -172,6 +176,14 @@ class ControlNet(ControlBase):
out.append(self.control_model_wrapped) out.append(self.control_model_wrapped)
return out return out
def pre_run(self, model, percent_to_timestep_function):
super().pre_run(model, percent_to_timestep_function)
self.model_sampling_current = model.model_sampling
def cleanup(self):
self.model_sampling_current = None
super().cleanup()
class ControlLoraOps: class ControlLoraOps:
class Linear(torch.nn.Module): class Linear(torch.nn.Module):
def __init__(self, in_features: int, out_features: int, bias: bool = True, def __init__(self, in_features: int, out_features: int, bias: bool = True,

View File

@ -82,6 +82,9 @@ class ModelSamplingDiscrete(torch.nn.Module):
log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx] log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
return log_sigma.exp() return log_sigma.exp()
def percent_to_sigma(self, percent):
return self.sigma(torch.tensor(percent * 999.0))
def model_sampling(model_config, model_type): def model_sampling(model_config, model_type):
if model_type == ModelType.EPS: if model_type == ModelType.EPS:
c = EPS c = EPS
@ -126,7 +129,7 @@ class BaseModel(torch.nn.Module):
context = c_crossattn context = c_crossattn
dtype = self.get_dtype() dtype = self.get_dtype()
xc = xc.to(dtype) xc = xc.to(dtype)
t = self.model_sampling.timestep(t).to(dtype) t = self.model_sampling.timestep(t).float()
context = context.to(dtype) context = context.to(dtype)
extra_conds = {} extra_conds = {}
for o in kwargs: for o in kwargs:

View File

@ -415,15 +415,16 @@ def create_cond_with_same_area_if_none(conds, c):
conds += [out] conds += [out]
def calculate_start_end_timesteps(model, conds): def calculate_start_end_timesteps(model, conds):
s = model.model_sampling
for t in range(len(conds)): for t in range(len(conds)):
x = conds[t] x = conds[t]
timestep_start = None timestep_start = None
timestep_end = None timestep_end = None
if 'start_percent' in x: if 'start_percent' in x:
timestep_start = model.sigma_to_t(model.t_to_sigma(torch.tensor(x['start_percent'] * 999.0))) timestep_start = s.percent_to_sigma(x['start_percent'])
if 'end_percent' in x: if 'end_percent' in x:
timestep_end = model.sigma_to_t(model.t_to_sigma(torch.tensor(x['end_percent'] * 999.0))) timestep_end = s.percent_to_sigma(x['end_percent'])
if (timestep_start is not None) or (timestep_end is not None): if (timestep_start is not None) or (timestep_end is not None):
n = x.copy() n = x.copy()
@ -434,14 +435,15 @@ def calculate_start_end_timesteps(model, conds):
conds[t] = n conds[t] = n
def pre_run_control(model, conds): def pre_run_control(model, conds):
s = model.model_sampling
for t in range(len(conds)): for t in range(len(conds)):
x = conds[t] x = conds[t]
timestep_start = None timestep_start = None
timestep_end = None timestep_end = None
percent_to_timestep_function = lambda a: model.sigma_to_t(model.t_to_sigma(torch.tensor(a) * 999.0)) percent_to_timestep_function = lambda a: s.percent_to_sigma(a)
if 'control' in x: if 'control' in x:
x['control'].pre_run(model.inner_model.inner_model, percent_to_timestep_function) x['control'].pre_run(model, percent_to_timestep_function)
def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func): def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func):
cond_cnets = [] cond_cnets = []
@ -571,8 +573,8 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
model_wrap = wrap_model(model) model_wrap = wrap_model(model)
calculate_start_end_timesteps(model_wrap, negative) calculate_start_end_timesteps(model, negative)
calculate_start_end_timesteps(model_wrap, positive) calculate_start_end_timesteps(model, positive)
#make sure each cond area has an opposite one with the same area #make sure each cond area has an opposite one with the same area
for c in positive: for c in positive:
@ -580,7 +582,7 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
for c in negative: for c in negative:
create_cond_with_same_area_if_none(positive, c) create_cond_with_same_area_if_none(positive, c)
pre_run_control(model_wrap, negative + positive) pre_run_control(model, negative + positive)
apply_empty_x_to_equal_area(list(filter(lambda c: c.get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x]) apply_empty_x_to_equal_area(list(filter(lambda c: c.get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x])
apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x]) apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x])