From 25a4805e519ea97110651ae5bb1d7c0e6644b26f Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 9 Feb 2024 14:13:31 -0500 Subject: [PATCH] Add a way to set different conditioning for the controlnet. --- comfy/controlnet.py | 2 +- comfy/model_base.py | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/comfy/controlnet.py b/comfy/controlnet.py index 82170431e..d9d990a71 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -166,7 +166,7 @@ class ControlNet(ControlBase): if x_noisy.shape[0] != self.cond_hint.shape[0]: self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number) - context = cond['c_crossattn'] + context = cond.get('crossattn_controlnet', cond['c_crossattn']) y = cond.get('y', None) if y is not None: y = y.to(dtype) diff --git a/comfy/model_base.py b/comfy/model_base.py index 8a843a98c..aafb88e05 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -153,6 +153,10 @@ class BaseModel(torch.nn.Module): if cross_attn is not None: out['c_crossattn'] = comfy.conds.CONDCrossAttn(cross_attn) + cross_attn_cnet = kwargs.get("cross_attn_controlnet", None) + if cross_attn_cnet is not None: + out['crossattn_controlnet'] = comfy.conds.CONDCrossAttn(cross_attn_cnet) + return out def load_model_weights(self, sd, unet_prefix=""):