mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
Add a way to set different conditioning for the controlnet.
This commit is contained in:
parent
fd73b5ee3a
commit
25a4805e51
@ -166,7 +166,7 @@ class ControlNet(ControlBase):
|
|||||||
if x_noisy.shape[0] != self.cond_hint.shape[0]:
|
if x_noisy.shape[0] != self.cond_hint.shape[0]:
|
||||||
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
|
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
|
||||||
|
|
||||||
context = cond['c_crossattn']
|
context = cond.get('crossattn_controlnet', cond['c_crossattn'])
|
||||||
y = cond.get('y', None)
|
y = cond.get('y', None)
|
||||||
if y is not None:
|
if y is not None:
|
||||||
y = y.to(dtype)
|
y = y.to(dtype)
|
||||||
|
@ -153,6 +153,10 @@ class BaseModel(torch.nn.Module):
|
|||||||
if cross_attn is not None:
|
if cross_attn is not None:
|
||||||
out['c_crossattn'] = comfy.conds.CONDCrossAttn(cross_attn)
|
out['c_crossattn'] = comfy.conds.CONDCrossAttn(cross_attn)
|
||||||
|
|
||||||
|
cross_attn_cnet = kwargs.get("cross_attn_controlnet", None)
|
||||||
|
if cross_attn_cnet is not None:
|
||||||
|
out['crossattn_controlnet'] = comfy.conds.CONDCrossAttn(cross_attn_cnet)
|
||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def load_model_weights(self, sd, unet_prefix=""):
|
def load_model_weights(self, sd, unet_prefix=""):
|
||||||
|
Loading…
Reference in New Issue
Block a user