Support conv3d in PatchEmbed.

This commit is contained in:
comfyanonymous 2024-12-14 05:29:30 -05:00
parent 558b7d8b22
commit e83063bf24

View File

@ -71,45 +71,33 @@ class PatchEmbed(nn.Module):
strict_img_size: bool = True, strict_img_size: bool = True,
dynamic_img_pad: bool = True, dynamic_img_pad: bool = True,
padding_mode='circular', padding_mode='circular',
conv3d=False,
dtype=None, dtype=None,
device=None, device=None,
operations=None, operations=None,
): ):
super().__init__() super().__init__()
self.patch_size = (patch_size, patch_size) try:
len(patch_size)
self.patch_size = patch_size
except:
if conv3d:
self.patch_size = (patch_size, patch_size, patch_size)
else:
self.patch_size = (patch_size, patch_size)
self.padding_mode = padding_mode self.padding_mode = padding_mode
if img_size is not None:
self.img_size = (img_size, img_size)
self.grid_size = tuple([s // p for s, p in zip(self.img_size, self.patch_size)])
self.num_patches = self.grid_size[0] * self.grid_size[1]
else:
self.img_size = None
self.grid_size = None
self.num_patches = None
# flatten spatial dim and transpose to channels last, kept for bwd compat # flatten spatial dim and transpose to channels last, kept for bwd compat
self.flatten = flatten self.flatten = flatten
self.strict_img_size = strict_img_size self.strict_img_size = strict_img_size
self.dynamic_img_pad = dynamic_img_pad self.dynamic_img_pad = dynamic_img_pad
if conv3d:
self.proj = operations.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias, dtype=dtype, device=device) self.proj = operations.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias, dtype=dtype, device=device)
else:
self.proj = operations.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias, dtype=dtype, device=device)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x): def forward(self, x):
# B, C, H, W = x.shape
# if self.img_size is not None:
# if self.strict_img_size:
# _assert(H == self.img_size[0], f"Input height ({H}) doesn't match model ({self.img_size[0]}).")
# _assert(W == self.img_size[1], f"Input width ({W}) doesn't match model ({self.img_size[1]}).")
# elif not self.dynamic_img_pad:
# _assert(
# H % self.patch_size[0] == 0,
# f"Input height ({H}) should be divisible by patch size ({self.patch_size[0]})."
# )
# _assert(
# W % self.patch_size[1] == 0,
# f"Input width ({W}) should be divisible by patch size ({self.patch_size[1]})."
# )
if self.dynamic_img_pad: if self.dynamic_img_pad:
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size, padding_mode=self.padding_mode) x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size, padding_mode=self.padding_mode)
x = self.proj(x) x = self.proj(x)