diff --git a/comfy/conds.py b/comfy/conds.py index 920e2548..2af2a43a 100644 --- a/comfy/conds.py +++ b/comfy/conds.py @@ -86,3 +86,45 @@ class CONDConstant(CONDRegular): def size(self): return [1] + + +class CONDList(CONDRegular): + def __init__(self, cond): + self.cond = cond + + def process_cond(self, batch_size, device, **kwargs): + out = [] + for c in self.cond: + out.append(comfy.utils.repeat_to_batch_size(c, batch_size).to(device)) + + return self._copy_with(out) + + def can_concat(self, other): + if len(self.cond) != len(other.cond): + return False + for i in range(len(self.cond)): + if self.cond[i].shape != other.cond[i].shape: + return False + + return True + + def concat(self, others): + out = [] + for i in range(len(self.cond)): + o = [self.cond[i]] + for x in others: + o.append(x.cond[i]) + out.append(torch.cat(o)) + + return out + + def size(self): # hackish implementation to make the mem estimation work + o = 0 + c = 1 + for c in self.cond: + size = c.size() + o += math.prod(size) + if len(size) > 1: + c = size[1] + + return [1, c, o // c] diff --git a/comfy/model_base.py b/comfy/model_base.py index 8ed12427..638b0409 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -168,6 +168,11 @@ class BaseModel(torch.nn.Module): if hasattr(extra, "dtype"): if extra.dtype != torch.int and extra.dtype != torch.long: extra = extra.to(dtype) + if isinstance(extra, list): + ex = [] + for ext in extra: + ex.append(ext.to(dtype)) + extra = ex extra_conds[o] = extra t = self.process_timestep(t, x=x, **extra_conds)