From 8cf1daa108400f2e29188fa0b4404d6ebc83b864 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 18 Dec 2023 12:54:23 -0500 Subject: [PATCH] Fix SDXL area composition sometimes not using the right pooled output. --- comfy/model_base.py | 10 ++++++++++ comfy/samplers.py | 7 ++++--- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/comfy/model_base.py b/comfy/model_base.py index c80848b2..f2a6f984 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -126,9 +126,15 @@ class BaseModel(torch.nn.Module): cond_concat.append(blank_inpaint_image_like(noise)) data = torch.cat(cond_concat, dim=1) out['c_concat'] = comfy.conds.CONDNoiseShape(data) + adm = self.encode_adm(**kwargs) if adm is not None: out['y'] = comfy.conds.CONDRegular(adm) + + cross_attn = kwargs.get("cross_attn", None) + if cross_attn is not None: + out['c_crossattn'] = comfy.conds.CONDCrossAttn(cross_attn) + return out def load_model_weights(self, sd, unet_prefix=""): @@ -322,6 +328,10 @@ class SVD_img2vid(BaseModel): out['c_concat'] = comfy.conds.CONDNoiseShape(latent_image) + cross_attn = kwargs.get("cross_attn", None) + if cross_attn is not None: + out['c_crossattn'] = comfy.conds.CONDCrossAttn(cross_attn) + if "time_conditioning" in kwargs: out["time_context"] = comfy.conds.CONDCrossAttn(kwargs["time_conditioning"]) diff --git a/comfy/samplers.py b/comfy/samplers.py index 35c9ccf0..18bd75ef 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -599,6 +599,10 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model calculate_start_end_timesteps(model, negative) calculate_start_end_timesteps(model, positive) + if hasattr(model, 'extra_conds'): + positive = encode_model_conds(model.extra_conds, positive, noise, device, "positive", latent_image=latent_image, denoise_mask=denoise_mask) + negative = encode_model_conds(model.extra_conds, negative, noise, device, "negative", latent_image=latent_image, denoise_mask=denoise_mask) + #make sure each cond area has an opposite one with the same area for c in positive: create_cond_with_same_area_if_none(negative, c) @@ -613,9 +617,6 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model if latent_image is not None: latent_image = model.process_latent_in(latent_image) - if hasattr(model, 'extra_conds'): - positive = encode_model_conds(model.extra_conds, positive, noise, device, "positive", latent_image=latent_image, denoise_mask=denoise_mask) - negative = encode_model_conds(model.extra_conds, negative, noise, device, "negative", latent_image=latent_image, denoise_mask=denoise_mask) extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": model_options, "seed":seed}