From 0270a0b41cef69726694b189f37942a04d762c8a Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 26 Feb 2025 16:59:26 -0500 Subject: [PATCH] Reduce artifacts on Wan by doing the patch embedding in fp32. --- comfy/ldm/wan/model.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index c67a65b24..dbe84b8bb 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -18,7 +18,7 @@ def sinusoidal_embedding_1d(dim, position): # preprocess assert dim % 2 == 0 half = dim // 2 - position = position.type(torch.float64) + position = position.type(torch.float32) # calculation sinusoid = torch.outer( @@ -353,7 +353,7 @@ class WanModel(torch.nn.Module): # embeddings self.patch_embedding = operations.Conv3d( - in_dim, dim, kernel_size=patch_size, stride=patch_size, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + in_dim, dim, kernel_size=patch_size, stride=patch_size, device=operation_settings.get("device"), dtype=torch.float32) self.text_embedding = nn.Sequential( operations.Linear(text_dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")), nn.GELU(approximate='tanh'), operations.Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))) @@ -411,7 +411,7 @@ class WanModel(torch.nn.Module): List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8] """ # embeddings - x = self.patch_embedding(x) + x = self.patch_embedding(x.float()).to(x.dtype) grid_sizes = x.shape[2:] x = x.flatten(2).transpose(1, 2)