diff --git a/comfy/controlnet.py b/comfy/controlnet.py index 82170431e..d9d990a71 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -166,7 +166,7 @@ class ControlNet(ControlBase): if x_noisy.shape[0] != self.cond_hint.shape[0]: self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number) - context = cond['c_crossattn'] + context = cond.get('crossattn_controlnet', cond['c_crossattn']) y = cond.get('y', None) if y is not None: y = y.to(dtype) diff --git a/comfy/model_base.py b/comfy/model_base.py index 8a843a98c..aafb88e05 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -153,6 +153,10 @@ class BaseModel(torch.nn.Module): if cross_attn is not None: out['c_crossattn'] = comfy.conds.CONDCrossAttn(cross_attn) + cross_attn_cnet = kwargs.get("cross_attn_controlnet", None) + if cross_attn_cnet is not None: + out['crossattn_controlnet'] = comfy.conds.CONDCrossAttn(cross_attn_cnet) + return out def load_model_weights(self, sd, unet_prefix=""):