From 69694f40b3033c7a8969a764369b00c586ffcf3a Mon Sep 17 00:00:00 2001 From: contentis Date: Mon, 4 Nov 2024 11:59:28 -0800 Subject: [PATCH] fix dynamic shape export (#5490) --- comfy/ldm/flux/model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py index b7d8a692..233d7839 100644 --- a/comfy/ldm/flux/model.py +++ b/comfy/ldm/flux/model.py @@ -151,8 +151,8 @@ class Flux(nn.Module): h_len = ((h + (patch_size // 2)) // patch_size) w_len = ((w + (patch_size // 2)) // patch_size) img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype) - img_ids[:, :, 1] = torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1) - img_ids[:, :, 2] = torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0) + img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1) + img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0) img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs) txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)