Mirror of https://github.com/comfyanonymous/ComfyUI.git

ONNX tracing fixes.

commit 3b71f84b50
parent 0a6b008117
comfy/ldm/aura/mmdit.py

@@ -9,6 +9,7 @@ import torch.nn.functional as F
 from comfy.ldm.modules.attention import optimized_attention
 import comfy.ops
+import comfy.ldm.common_dit

 def modulate(x, shift, scale):
     return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
@@ -407,10 +408,7 @@ class MMDiT(nn.Module):

     def patchify(self, x):
         B, C, H, W = x.size()
-        pad_h = (self.patch_size - H % self.patch_size) % self.patch_size
-        pad_w = (self.patch_size - W % self.patch_size) % self.patch_size
-
-        x = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h), mode='circular')
+        x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
         x = x.view(
             B,
             C,
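The removed block above shows the rounding idiom this commit centralizes: (p - d % p) % p is the padding needed to round a dimension d up to the next multiple of p, and the outer % p makes it zero when d already divides evenly. A quick standalone check (values are illustrative):

    p = 2
    for d in (30, 31, 32, 33):
        pad = (p - d % p) % p
        print(d, "->", d + pad)  # 30 -> 30, 31 -> 32, 32 -> 32, 33 -> 34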
comfy/ldm/common_dit.py (new file, +8)

@@ -0,0 +1,8 @@
+import torch
+
+def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
+    if padding_mode == "circular" and torch.jit.is_tracing() or torch.jit.is_scripting():
+        padding_mode = "reflect"
+    pad_h = (patch_size[0] - img.shape[-2] % patch_size[0]) % patch_size[0]
+    pad_w = (patch_size[1] - img.shape[-1] % patch_size[1]) % patch_size[1]
+    return torch.nn.functional.pad(img, (0, pad_w, 0, pad_h), mode=padding_mode)
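This new helper is the core of the commit: circular padding has historically been problematic for the ONNX exporter, so the helper swaps it for reflect padding whenever TorchScript tracing or scripting is active. Note that Python's operator precedence parses the condition as (padding_mode == "circular" and torch.jit.is_tracing()) or torch.jit.is_scripting(), so scripting triggers the fallback regardless of the requested mode. A minimal usage sketch, assuming a ComfyUI checkout is importable; shapes are illustrative:

    import torch
    import comfy.ldm.common_dit

    x = torch.randn(1, 16, 31, 30)                         # H=31 is not a multiple of 2
    y = comfy.ldm.common_dit.pad_to_patch_size(x, (2, 2))  # pads eagerly with mode="circular"
    print(y.shape)                                         # torch.Size([1, 16, 32, 30])

Run under torch.jit.trace, the same call would pad with "reflect" instead, which is what keeps the traced graph exportable.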
comfy/ldm/flux/model.py

@@ -15,6 +15,7 @@ from .layers import (
 )

 from einops import rearrange, repeat
+import comfy.ldm.common_dit

 @dataclass
 class FluxParams:
@@ -42,7 +43,7 @@ class Flux(nn.Module):
         self.dtype = dtype
         params = FluxParams(**kwargs)
         self.params = params
-        self.in_channels = params.in_channels
+        self.in_channels = params.in_channels * 2 * 2
         self.out_channels = self.in_channels
         if params.hidden_size % params.num_heads != 0:
             raise ValueError(
@@ -125,10 +126,7 @@ class Flux(nn.Module):
     def forward(self, x, timestep, context, y, guidance, **kwargs):
         bs, c, h, w = x.shape
         patch_size = 2
-        pad_h = (patch_size - h % 2) % patch_size
-        pad_w = (patch_size - w % 2) % patch_size
-
-        x = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h), mode='circular')
+        x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))

         img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
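The rearrange in the context line folds each 2x2 pixel block into the channel dimension, which is why Flux.__init__ above now reports in_channels as params.in_channels * 2 * 2 (16 * 4 = 64). A standalone sketch of that reshape, with illustrative sizes:

    import torch
    from einops import rearrange

    patch_size = 2
    x = torch.randn(1, 16, 64, 64)  # padded latent: 16 channels, 64x64
    img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
    print(img.shape)                # torch.Size([1, 1024, 64]): 16 * 2 * 2 features per token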
comfy/ldm/modules/diffusionmodules/mmdit.py

@@ -9,6 +9,7 @@ from .. import attention
 from einops import rearrange, repeat
 from .util import timestep_embedding
 import comfy.ops
+import comfy.ldm.common_dit

 def default(x, y):
     if x is not None:
@@ -111,9 +112,7 @@ class PatchEmbed(nn.Module):
         #     f"Input width ({W}) should be divisible by patch size ({self.patch_size[1]})."
         # )
         if self.dynamic_img_pad:
-            pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0]
-            pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1]
-            x = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h), mode=self.padding_mode)
+            x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size, padding_mode=self.padding_mode)
         x = self.proj(x)
         if self.flatten:
             x = x.flatten(2).transpose(1, 2)  # NCHW -> NLC
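Unlike the two call sites above, this one forwards the module's configured self.padding_mode instead of hard-coding circular, so in eager mode the helper pads exactly as the removed lines did. For reference, the NCHW -> NLC flatten in the context lines reshapes as follows (sizes are hypothetical):

    import torch

    x = torch.randn(2, 1536, 32, 32)  # stand-in for the output of self.proj: N, C, H, W
    x = x.flatten(2).transpose(1, 2)  # NCHW -> NLC
    print(x.shape)                    # torch.Size([2, 1024, 1536])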
comfy/model_detection.py

@@ -131,7 +131,7 @@ def detect_unet_config(state_dict, key_prefix):
     if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys: #Flux
         dit_config = {}
         dit_config["image_model"] = "flux"
-        dit_config["in_channels"] = 64
+        dit_config["in_channels"] = 16
         dit_config["vec_in_dim"] = 768
         dit_config["context_in_dim"] = 4096
         dit_config["hidden_size"] = 3072
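This is the other half of the in_channels change in comfy/ldm/flux/model.py: the detected config now carries the raw latent channel count and Flux derives the patchified width itself, so the two edits cancel out:

    params_in_channels = 16                         # dit_config["in_channels"] after this commit
    model_in_channels = params_in_channels * 2 * 2  # computed in Flux.__init__ after this commit
    assert model_in_channels == 64                  # the value the config hard-coded before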