From c14429940f6f9491c77250eb15cad3746e350753 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 17 Apr 2025 12:04:48 -0400 Subject: [PATCH] Support loading WAN FLF model. --- comfy/ldm/wan/model.py | 13 +++++++++++-- comfy/model_detection.py | 3 +++ 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index d64e73a8..8907f70a 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -251,7 +251,7 @@ class Head(nn.Module): class MLPProj(torch.nn.Module): - def __init__(self, in_dim, out_dim, operation_settings={}): + def __init__(self, in_dim, out_dim, flf_pos_embed_token_number=None, operation_settings={}): super().__init__() self.proj = torch.nn.Sequential( @@ -259,7 +259,15 @@ class MLPProj(torch.nn.Module): torch.nn.GELU(), operation_settings.get("operations").Linear(in_dim, out_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")), operation_settings.get("operations").LayerNorm(out_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))) + if flf_pos_embed_token_number is not None: + self.emb_pos = nn.Parameter(torch.empty((1, flf_pos_embed_token_number, in_dim), device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))) + else: + self.emb_pos = None + def forward(self, image_embeds): + if self.emb_pos is not None: + image_embeds = image_embeds[:, :self.emb_pos.shape[1]] + comfy.model_management.cast_to(self.emb_pos[:, :image_embeds.shape[1]], dtype=image_embeds.dtype, device=image_embeds.device) + clip_extra_context_tokens = self.proj(image_embeds) return clip_extra_context_tokens @@ -285,6 +293,7 @@ class WanModel(torch.nn.Module): qk_norm=True, cross_attn_norm=True, eps=1e-6, + flf_pos_embed_token_number=None, image_model=None, device=None, dtype=None, @@ -374,7 +383,7 @@ class WanModel(torch.nn.Module): self.rope_embedder = EmbedND(dim=d, theta=10000.0, axes_dim=[d - 4 * (d // 6), 2 * (d // 6), 2 * (d // 6)]) if model_type == 'i2v': - self.img_emb = MLPProj(1280, dim, operation_settings=operation_settings) + self.img_emb = MLPProj(1280, dim, flf_pos_embed_token_number=flf_pos_embed_token_number, operation_settings=operation_settings) else: self.img_emb = None diff --git a/comfy/model_detection.py b/comfy/model_detection.py index a4da1afc..6499bf23 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -321,6 +321,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["model_type"] = "i2v" else: dit_config["model_type"] = "t2v" + flf_weight = state_dict.get('{}img_emb.emb_pos'.format(key_prefix)) + if flf_weight is not None: + dit_config["flf_pos_embed_token_number"] = flf_weight.shape[1] return dit_config if '{}latent_in.weight'.format(key_prefix) in state_dict_keys: # Hunyuan 3D