mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
Support conv3d in PatchEmbed.
This commit is contained in:
parent
558b7d8b22
commit
e83063bf24
@ -71,45 +71,33 @@ class PatchEmbed(nn.Module):
|
|||||||
strict_img_size: bool = True,
|
strict_img_size: bool = True,
|
||||||
dynamic_img_pad: bool = True,
|
dynamic_img_pad: bool = True,
|
||||||
padding_mode='circular',
|
padding_mode='circular',
|
||||||
|
conv3d=False,
|
||||||
dtype=None,
|
dtype=None,
|
||||||
device=None,
|
device=None,
|
||||||
operations=None,
|
operations=None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.patch_size = (patch_size, patch_size)
|
try:
|
||||||
|
len(patch_size)
|
||||||
|
self.patch_size = patch_size
|
||||||
|
except:
|
||||||
|
if conv3d:
|
||||||
|
self.patch_size = (patch_size, patch_size, patch_size)
|
||||||
|
else:
|
||||||
|
self.patch_size = (patch_size, patch_size)
|
||||||
self.padding_mode = padding_mode
|
self.padding_mode = padding_mode
|
||||||
if img_size is not None:
|
|
||||||
self.img_size = (img_size, img_size)
|
|
||||||
self.grid_size = tuple([s // p for s, p in zip(self.img_size, self.patch_size)])
|
|
||||||
self.num_patches = self.grid_size[0] * self.grid_size[1]
|
|
||||||
else:
|
|
||||||
self.img_size = None
|
|
||||||
self.grid_size = None
|
|
||||||
self.num_patches = None
|
|
||||||
|
|
||||||
# flatten spatial dim and transpose to channels last, kept for bwd compat
|
# flatten spatial dim and transpose to channels last, kept for bwd compat
|
||||||
self.flatten = flatten
|
self.flatten = flatten
|
||||||
self.strict_img_size = strict_img_size
|
self.strict_img_size = strict_img_size
|
||||||
self.dynamic_img_pad = dynamic_img_pad
|
self.dynamic_img_pad = dynamic_img_pad
|
||||||
|
if conv3d:
|
||||||
self.proj = operations.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias, dtype=dtype, device=device)
|
self.proj = operations.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias, dtype=dtype, device=device)
|
||||||
|
else:
|
||||||
|
self.proj = operations.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias, dtype=dtype, device=device)
|
||||||
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
|
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
# B, C, H, W = x.shape
|
|
||||||
# if self.img_size is not None:
|
|
||||||
# if self.strict_img_size:
|
|
||||||
# _assert(H == self.img_size[0], f"Input height ({H}) doesn't match model ({self.img_size[0]}).")
|
|
||||||
# _assert(W == self.img_size[1], f"Input width ({W}) doesn't match model ({self.img_size[1]}).")
|
|
||||||
# elif not self.dynamic_img_pad:
|
|
||||||
# _assert(
|
|
||||||
# H % self.patch_size[0] == 0,
|
|
||||||
# f"Input height ({H}) should be divisible by patch size ({self.patch_size[0]})."
|
|
||||||
# )
|
|
||||||
# _assert(
|
|
||||||
# W % self.patch_size[1] == 0,
|
|
||||||
# f"Input width ({W}) should be divisible by patch size ({self.patch_size[1]})."
|
|
||||||
# )
|
|
||||||
if self.dynamic_img_pad:
|
if self.dynamic_img_pad:
|
||||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size, padding_mode=self.padding_mode)
|
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size, padding_mode=self.padding_mode)
|
||||||
x = self.proj(x)
|
x = self.proj(x)
|
||||||
|
Loading…
Reference in New Issue
Block a user