mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-03-15 05:57:20 +00:00
Clean up percent start/end and make controlnets work with sigmas.
This commit is contained in:
parent
a268a574fa
commit
7c0f255de1
@ -132,6 +132,7 @@ class ControlNet(ControlBase):
|
|||||||
self.control_model = control_model
|
self.control_model = control_model
|
||||||
self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
|
self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
|
||||||
self.global_average_pooling = global_average_pooling
|
self.global_average_pooling = global_average_pooling
|
||||||
|
self.model_sampling_current = None
|
||||||
|
|
||||||
def get_control(self, x_noisy, t, cond, batched_number):
|
def get_control(self, x_noisy, t, cond, batched_number):
|
||||||
control_prev = None
|
control_prev = None
|
||||||
@ -159,7 +160,10 @@ class ControlNet(ControlBase):
|
|||||||
y = cond.get('y', None)
|
y = cond.get('y', None)
|
||||||
if y is not None:
|
if y is not None:
|
||||||
y = y.to(self.control_model.dtype)
|
y = y.to(self.control_model.dtype)
|
||||||
control = self.control_model(x=x_noisy.to(self.control_model.dtype), hint=self.cond_hint, timesteps=t, context=context.to(self.control_model.dtype), y=y)
|
timestep = self.model_sampling_current.timestep(t)
|
||||||
|
x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
|
||||||
|
|
||||||
|
control = self.control_model(x=x_noisy.to(self.control_model.dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(self.control_model.dtype), y=y)
|
||||||
return self.control_merge(None, control, control_prev, output_dtype)
|
return self.control_merge(None, control, control_prev, output_dtype)
|
||||||
|
|
||||||
def copy(self):
|
def copy(self):
|
||||||
@ -172,6 +176,14 @@ class ControlNet(ControlBase):
|
|||||||
out.append(self.control_model_wrapped)
|
out.append(self.control_model_wrapped)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
def pre_run(self, model, percent_to_timestep_function):
|
||||||
|
super().pre_run(model, percent_to_timestep_function)
|
||||||
|
self.model_sampling_current = model.model_sampling
|
||||||
|
|
||||||
|
def cleanup(self):
|
||||||
|
self.model_sampling_current = None
|
||||||
|
super().cleanup()
|
||||||
|
|
||||||
class ControlLoraOps:
|
class ControlLoraOps:
|
||||||
class Linear(torch.nn.Module):
|
class Linear(torch.nn.Module):
|
||||||
def __init__(self, in_features: int, out_features: int, bias: bool = True,
|
def __init__(self, in_features: int, out_features: int, bias: bool = True,
|
||||||
|
@ -82,6 +82,9 @@ class ModelSamplingDiscrete(torch.nn.Module):
|
|||||||
log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
|
log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
|
||||||
return log_sigma.exp()
|
return log_sigma.exp()
|
||||||
|
|
||||||
|
def percent_to_sigma(self, percent):
|
||||||
|
return self.sigma(torch.tensor(percent * 999.0))
|
||||||
|
|
||||||
def model_sampling(model_config, model_type):
|
def model_sampling(model_config, model_type):
|
||||||
if model_type == ModelType.EPS:
|
if model_type == ModelType.EPS:
|
||||||
c = EPS
|
c = EPS
|
||||||
@ -126,7 +129,7 @@ class BaseModel(torch.nn.Module):
|
|||||||
context = c_crossattn
|
context = c_crossattn
|
||||||
dtype = self.get_dtype()
|
dtype = self.get_dtype()
|
||||||
xc = xc.to(dtype)
|
xc = xc.to(dtype)
|
||||||
t = self.model_sampling.timestep(t).to(dtype)
|
t = self.model_sampling.timestep(t).float()
|
||||||
context = context.to(dtype)
|
context = context.to(dtype)
|
||||||
extra_conds = {}
|
extra_conds = {}
|
||||||
for o in kwargs:
|
for o in kwargs:
|
||||||
|
@ -415,15 +415,16 @@ def create_cond_with_same_area_if_none(conds, c):
|
|||||||
conds += [out]
|
conds += [out]
|
||||||
|
|
||||||
def calculate_start_end_timesteps(model, conds):
|
def calculate_start_end_timesteps(model, conds):
|
||||||
|
s = model.model_sampling
|
||||||
for t in range(len(conds)):
|
for t in range(len(conds)):
|
||||||
x = conds[t]
|
x = conds[t]
|
||||||
|
|
||||||
timestep_start = None
|
timestep_start = None
|
||||||
timestep_end = None
|
timestep_end = None
|
||||||
if 'start_percent' in x:
|
if 'start_percent' in x:
|
||||||
timestep_start = model.sigma_to_t(model.t_to_sigma(torch.tensor(x['start_percent'] * 999.0)))
|
timestep_start = s.percent_to_sigma(x['start_percent'])
|
||||||
if 'end_percent' in x:
|
if 'end_percent' in x:
|
||||||
timestep_end = model.sigma_to_t(model.t_to_sigma(torch.tensor(x['end_percent'] * 999.0)))
|
timestep_end = s.percent_to_sigma(x['end_percent'])
|
||||||
|
|
||||||
if (timestep_start is not None) or (timestep_end is not None):
|
if (timestep_start is not None) or (timestep_end is not None):
|
||||||
n = x.copy()
|
n = x.copy()
|
||||||
@ -434,14 +435,15 @@ def calculate_start_end_timesteps(model, conds):
|
|||||||
conds[t] = n
|
conds[t] = n
|
||||||
|
|
||||||
def pre_run_control(model, conds):
|
def pre_run_control(model, conds):
|
||||||
|
s = model.model_sampling
|
||||||
for t in range(len(conds)):
|
for t in range(len(conds)):
|
||||||
x = conds[t]
|
x = conds[t]
|
||||||
|
|
||||||
timestep_start = None
|
timestep_start = None
|
||||||
timestep_end = None
|
timestep_end = None
|
||||||
percent_to_timestep_function = lambda a: model.sigma_to_t(model.t_to_sigma(torch.tensor(a) * 999.0))
|
percent_to_timestep_function = lambda a: s.percent_to_sigma(a)
|
||||||
if 'control' in x:
|
if 'control' in x:
|
||||||
x['control'].pre_run(model.inner_model.inner_model, percent_to_timestep_function)
|
x['control'].pre_run(model, percent_to_timestep_function)
|
||||||
|
|
||||||
def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func):
|
def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func):
|
||||||
cond_cnets = []
|
cond_cnets = []
|
||||||
@ -571,8 +573,8 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
|
|||||||
|
|
||||||
model_wrap = wrap_model(model)
|
model_wrap = wrap_model(model)
|
||||||
|
|
||||||
calculate_start_end_timesteps(model_wrap, negative)
|
calculate_start_end_timesteps(model, negative)
|
||||||
calculate_start_end_timesteps(model_wrap, positive)
|
calculate_start_end_timesteps(model, positive)
|
||||||
|
|
||||||
#make sure each cond area has an opposite one with the same area
|
#make sure each cond area has an opposite one with the same area
|
||||||
for c in positive:
|
for c in positive:
|
||||||
@ -580,7 +582,7 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
|
|||||||
for c in negative:
|
for c in negative:
|
||||||
create_cond_with_same_area_if_none(positive, c)
|
create_cond_with_same_area_if_none(positive, c)
|
||||||
|
|
||||||
pre_run_control(model_wrap, negative + positive)
|
pre_run_control(model, negative + positive)
|
||||||
|
|
||||||
apply_empty_x_to_equal_area(list(filter(lambda c: c.get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x])
|
apply_empty_x_to_equal_area(list(filter(lambda c: c.get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x])
|
||||||
apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x])
|
apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x])
|
||||||
|
Loading…
Reference in New Issue
Block a user