mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-03-15 05:57:20 +00:00
Reduce artifacts on Wan by doing the patch embedding in fp32.
This commit is contained in:
parent
26c7baf789
commit
0270a0b41c
@ -18,7 +18,7 @@ def sinusoidal_embedding_1d(dim, position):
|
||||
# preprocess
|
||||
assert dim % 2 == 0
|
||||
half = dim // 2
|
||||
position = position.type(torch.float64)
|
||||
position = position.type(torch.float32)
|
||||
|
||||
# calculation
|
||||
sinusoid = torch.outer(
|
||||
@ -353,7 +353,7 @@ class WanModel(torch.nn.Module):
|
||||
|
||||
# embeddings
|
||||
self.patch_embedding = operations.Conv3d(
|
||||
in_dim, dim, kernel_size=patch_size, stride=patch_size, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
in_dim, dim, kernel_size=patch_size, stride=patch_size, device=operation_settings.get("device"), dtype=torch.float32)
|
||||
self.text_embedding = nn.Sequential(
|
||||
operations.Linear(text_dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")), nn.GELU(approximate='tanh'),
|
||||
operations.Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")))
|
||||
@ -411,7 +411,7 @@ class WanModel(torch.nn.Module):
|
||||
List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
|
||||
"""
|
||||
# embeddings
|
||||
x = self.patch_embedding(x)
|
||||
x = self.patch_embedding(x.float()).to(x.dtype)
|
||||
grid_sizes = x.shape[2:]
|
||||
x = x.flatten(2).transpose(1, 2)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user