Mirror of https://github.com/comfyanonymous/ComfyUI.git
Sampling code refactor to make it easier to add more conds.
This commit is contained in:
parent 5c65da312a
commit 3fce8881ca
@@ -9,9 +9,58 @@ import math
 from comfy import model_base
 import comfy.utils

 def lcm(a, b): #TODO: eventually replace by math.lcm (added in python3.9)
     return abs(a*b) // math.gcd(a, b)

+class CONDRegular:
+    def __init__(self, cond):
+        self.cond = cond
+
+    def can_concat(self, other):
+        if self.cond.shape != other.cond.shape:
+            return False
+        return True
+
+    def concat(self, others):
+        conds = [self.cond]
+        for x in others:
+            conds.append(x.cond)
+        return torch.cat(conds)
+
+
+class CONDCrossAttn:
+    def __init__(self, cond):
+        self.cond = cond
+
+    def can_concat(self, other):
+        s1 = self.cond.shape
+        s2 = other.cond.shape
+        if s1 != s2:
+            if s1[0] != s2[0] or s1[2] != s2[2]: #these 2 cases should not happen
+                return False
+
+            mult_min = lcm(s1[1], s2[1])
+            diff = mult_min // min(s1[1], s2[1])
+            if diff > 4: #arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much
+                return False
+        return True
+
+    def concat(self, others):
+        conds = [self.cond]
+        crossattn_max_len = self.cond.shape[1]
+        for x in others:
+            c = x.cond
+            crossattn_max_len = lcm(crossattn_max_len, c.shape[1])
+            conds.append(c)
+
+        out = []
+        for c in conds:
+            if c.shape[1] < crossattn_max_len:
+                c = c.repeat(1, crossattn_max_len // c.shape[1], 1) #padding with repeat doesn't change result
+            out.append(c)
+        return torch.cat(out)
+
+
 #The main sampling function shared by all the samplers
 #Returns predicted noise
 def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None):
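This first hunk is the core of the refactor: every conditioning entry is wrapped in a small object that knows how to answer can_concat and how to concat itself with others, so the batching code further down no longer needs per-key special cases. As a rough illustration of why this "makes it easier to add more conds", here is a minimal sketch, not part of the commit, of a new cond type that would plug into the same interface (the class name CONDConstant is hypothetical):

class CONDConstant:
    # hypothetical cond type: a value shared by the whole batch rather than a per-sample tensor
    def __init__(self, cond):
        self.cond = cond

    def can_concat(self, other):
        # two entries can only be batched together if they carry the same value
        return self.cond == other.cond

    def concat(self, others):
        # nothing to pad or stack; the shared value is returned unchanged
        return self.cond

Anything stored in the conditionning dict as such an object would be handled by the generic cond_equal_size and cond_cat changes in the hunks below without touching the sampler again.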
@@ -67,7 +116,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None):
                     mult[:,:,:,area[1] - 1 - t:area[1] - t] *= ((1.0/rr) * (t + 1))

             conditionning = {}
-            conditionning['c_crossattn'] = cond[0]
+            conditionning['c_crossattn'] = CONDCrossAttn(cond[0])

             if 'concat' in cond[1]:
                 cond_concat_in = cond[1]['concat']
@@ -76,10 +125,10 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None):
                     for x in cond_concat_in:
                         cr = x[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
                         cropped.append(cr)
-                    conditionning['c_concat'] = torch.cat(cropped, dim=1)
+                    conditionning['c_concat'] = CONDRegular(torch.cat(cropped, dim=1))

             if adm_cond is not None:
-                conditionning['c_adm'] = adm_cond
+                conditionning['c_adm'] = CONDRegular(adm_cond)

             control = None
             if 'control' in cond[1]:
@@ -105,22 +154,8 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None):
                 return True
             if c1.keys() != c2.keys():
                 return False
-            if 'c_crossattn' in c1:
-                s1 = c1['c_crossattn'].shape
-                s2 = c2['c_crossattn'].shape
-                if s1 != s2:
-                    if s1[0] != s2[0] or s1[2] != s2[2]: #these 2 cases should not happen
-                        return False
-
-                    mult_min = lcm(s1[1], s2[1])
-                    diff = mult_min // min(s1[1], s2[1])
-                    if diff > 4: #arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much
-                        return False
-            if 'c_concat' in c1:
-                if c1['c_concat'].shape != c2['c_concat'].shape:
-                    return False
-            if 'c_adm' in c1:
-                if c1['c_adm'].shape != c2['c_adm'].shape:
-                    return False
+            for k in c1:
+                if not c1[k].can_concat(c2[k]):
+                    return False
             return True

@@ -149,31 +184,19 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None):
             c_concat = []
             c_adm = []
             crossattn_max_len = 0
-            for x in c_list:
-                if 'c_crossattn' in x:
-                    c = x['c_crossattn']
-                    if crossattn_max_len == 0:
-                        crossattn_max_len = c.shape[1]
-                    else:
-                        crossattn_max_len = lcm(crossattn_max_len, c.shape[1])
-                    c_crossattn.append(c)
-                if 'c_concat' in x:
-                    c_concat.append(x['c_concat'])
-                if 'c_adm' in x:
-                    c_adm.append(x['c_adm'])
-            out = {}
-            c_crossattn_out = []
-            for c in c_crossattn:
-                if c.shape[1] < crossattn_max_len:
-                    c = c.repeat(1, crossattn_max_len // c.shape[1], 1) #padding with repeat doesn't change result
-                c_crossattn_out.append(c)
-
-            if len(c_crossattn_out) > 0:
-                out['c_crossattn'] = torch.cat(c_crossattn_out)
-            if len(c_concat) > 0:
-                out['c_concat'] = torch.cat(c_concat)
-            if len(c_adm) > 0:
-                out['c_adm'] = torch.cat(c_adm)
+
+            temp = {}
+            for x in c_list:
+                for k in x:
+                    cur = temp.get(k, [])
+                    cur.append(x[k])
+                    temp[k] = cur
+
+            out = {}
+            for k in temp:
+                conds = temp[k]
+                out[k] = conds[0].concat(conds[1:])
+
             return out

         def calc_cond_uncond_batch(model_function, cond, uncond, x_in, timestep, max_total_area, model_options):
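Taken together, the new cond_cat simply groups entries by key and delegates to each wrapper's concat. The following is a small standalone sketch of the resulting padding behaviour, assuming the CONDCrossAttn class and lcm helper from the first hunk are importable in the current scope; the tensor shapes are made-up 77- and 154-token CLIP-style conds:

import torch

# two conditioning dicts like the ones built by get_area_and_mult,
# whose cross-attention conds have different token lengths
c1 = {'c_crossattn': CONDCrossAttn(torch.randn(1, 77, 768))}
c2 = {'c_crossattn': CONDCrossAttn(torch.randn(1, 154, 768))}

# the per-key check now used by cond_equal_size
assert all(c1[k].can_concat(c2[k]) for k in c1)

# the grouping done by the new cond_cat: collect per key, then let the wrapper concatenate
temp = {}
for x in [c1, c2]:
    for k in x:
        temp.setdefault(k, []).append(x[k])
out = {k: v[0].concat(v[1:]) for k, v in temp.items()}

print(out['c_crossattn'].shape)  # torch.Size([2, 154, 768]); the 77-token cond was repeat-padded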