mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-03-14 21:47:07 +00:00
Clean up percent start/end and make controlnets work with sigmas.
This commit is contained in:
parent
a268a574fa
commit
7c0f255de1
@ -132,6 +132,7 @@ class ControlNet(ControlBase):
|
||||
self.control_model = control_model
|
||||
self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
|
||||
self.global_average_pooling = global_average_pooling
|
||||
self.model_sampling_current = None
|
||||
|
||||
def get_control(self, x_noisy, t, cond, batched_number):
|
||||
control_prev = None
|
||||
@ -159,7 +160,10 @@ class ControlNet(ControlBase):
|
||||
y = cond.get('y', None)
|
||||
if y is not None:
|
||||
y = y.to(self.control_model.dtype)
|
||||
control = self.control_model(x=x_noisy.to(self.control_model.dtype), hint=self.cond_hint, timesteps=t, context=context.to(self.control_model.dtype), y=y)
|
||||
timestep = self.model_sampling_current.timestep(t)
|
||||
x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
|
||||
|
||||
control = self.control_model(x=x_noisy.to(self.control_model.dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(self.control_model.dtype), y=y)
|
||||
return self.control_merge(None, control, control_prev, output_dtype)
|
||||
|
||||
def copy(self):
|
||||
@ -172,6 +176,14 @@ class ControlNet(ControlBase):
|
||||
out.append(self.control_model_wrapped)
|
||||
return out
|
||||
|
||||
def pre_run(self, model, percent_to_timestep_function):
|
||||
super().pre_run(model, percent_to_timestep_function)
|
||||
self.model_sampling_current = model.model_sampling
|
||||
|
||||
def cleanup(self):
|
||||
self.model_sampling_current = None
|
||||
super().cleanup()
|
||||
|
||||
class ControlLoraOps:
|
||||
class Linear(torch.nn.Module):
|
||||
def __init__(self, in_features: int, out_features: int, bias: bool = True,
|
||||
|
@ -82,6 +82,9 @@ class ModelSamplingDiscrete(torch.nn.Module):
|
||||
log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
|
||||
return log_sigma.exp()
|
||||
|
||||
def percent_to_sigma(self, percent):
|
||||
return self.sigma(torch.tensor(percent * 999.0))
|
||||
|
||||
def model_sampling(model_config, model_type):
|
||||
if model_type == ModelType.EPS:
|
||||
c = EPS
|
||||
@ -126,7 +129,7 @@ class BaseModel(torch.nn.Module):
|
||||
context = c_crossattn
|
||||
dtype = self.get_dtype()
|
||||
xc = xc.to(dtype)
|
||||
t = self.model_sampling.timestep(t).to(dtype)
|
||||
t = self.model_sampling.timestep(t).float()
|
||||
context = context.to(dtype)
|
||||
extra_conds = {}
|
||||
for o in kwargs:
|
||||
|
@ -415,15 +415,16 @@ def create_cond_with_same_area_if_none(conds, c):
|
||||
conds += [out]
|
||||
|
||||
def calculate_start_end_timesteps(model, conds):
|
||||
s = model.model_sampling
|
||||
for t in range(len(conds)):
|
||||
x = conds[t]
|
||||
|
||||
timestep_start = None
|
||||
timestep_end = None
|
||||
if 'start_percent' in x:
|
||||
timestep_start = model.sigma_to_t(model.t_to_sigma(torch.tensor(x['start_percent'] * 999.0)))
|
||||
timestep_start = s.percent_to_sigma(x['start_percent'])
|
||||
if 'end_percent' in x:
|
||||
timestep_end = model.sigma_to_t(model.t_to_sigma(torch.tensor(x['end_percent'] * 999.0)))
|
||||
timestep_end = s.percent_to_sigma(x['end_percent'])
|
||||
|
||||
if (timestep_start is not None) or (timestep_end is not None):
|
||||
n = x.copy()
|
||||
@ -434,14 +435,15 @@ def calculate_start_end_timesteps(model, conds):
|
||||
conds[t] = n
|
||||
|
||||
def pre_run_control(model, conds):
|
||||
s = model.model_sampling
|
||||
for t in range(len(conds)):
|
||||
x = conds[t]
|
||||
|
||||
timestep_start = None
|
||||
timestep_end = None
|
||||
percent_to_timestep_function = lambda a: model.sigma_to_t(model.t_to_sigma(torch.tensor(a) * 999.0))
|
||||
percent_to_timestep_function = lambda a: s.percent_to_sigma(a)
|
||||
if 'control' in x:
|
||||
x['control'].pre_run(model.inner_model.inner_model, percent_to_timestep_function)
|
||||
x['control'].pre_run(model, percent_to_timestep_function)
|
||||
|
||||
def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func):
|
||||
cond_cnets = []
|
||||
@ -571,8 +573,8 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
|
||||
|
||||
model_wrap = wrap_model(model)
|
||||
|
||||
calculate_start_end_timesteps(model_wrap, negative)
|
||||
calculate_start_end_timesteps(model_wrap, positive)
|
||||
calculate_start_end_timesteps(model, negative)
|
||||
calculate_start_end_timesteps(model, positive)
|
||||
|
||||
#make sure each cond area has an opposite one with the same area
|
||||
for c in positive:
|
||||
@ -580,7 +582,7 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
|
||||
for c in negative:
|
||||
create_cond_with_same_area_if_none(positive, c)
|
||||
|
||||
pre_run_control(model_wrap, negative + positive)
|
||||
pre_run_control(model, negative + positive)
|
||||
|
||||
apply_empty_x_to_equal_area(list(filter(lambda c: c.get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x])
|
||||
apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x])
|
||||
|
Loading…
Reference in New Issue
Block a user