diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py
index 846703d5..8f4d99f5 100644
--- a/comfy/ldm/flux/model.py
+++ b/comfy/ldm/flux/model.py
@@ -195,20 +195,50 @@ class Flux(nn.Module):
         img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)
         return img

-    def forward(self, x, timestep, context, y=None, guidance=None, control=None, transformer_options={}, **kwargs):
+    def process_img(self, x, index=0, h_offset=0, w_offset=0):
         bs, c, h, w = x.shape
         patch_size = self.patch_size
         x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))

         img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
-
         h_len = ((h + (patch_size // 2)) // patch_size)
         w_len = ((w + (patch_size // 2)) // patch_size)
+
+        h_offset = ((h_offset + (patch_size // 2)) // patch_size)
+        w_offset = ((w_offset + (patch_size // 2)) // patch_size)
+
         img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
-        img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
-        img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
-        img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
+        img_ids[:, :, 0] = img_ids[:, :, 1] + index
+        img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
+        img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
+        return img, repeat(img_ids, "h w c -> b (h w) c", b=bs)
+
+    def forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None, control=None, transformer_options={}, **kwargs):
+        bs, c, h_orig, w_orig = x.shape
+        patch_size = self.patch_size
+
+        h_len = ((h_orig + (patch_size // 2)) // patch_size)
+        w_len = ((w_orig + (patch_size // 2)) // patch_size)
+        img, img_ids = self.process_img(x)
+        img_tokens = img.shape[1]
+        if ref_latents is not None:
+            h = 0
+            w = 0
+            for ref in ref_latents:
+                h_offset = 0
+                w_offset = 0
+                if ref.shape[-2] + h > ref.shape[-1] + w:
+                    w_offset = w
+                else:
+                    h_offset = h
+
+                kontext, kontext_ids = self.process_img(ref, index=1, h_offset=h_offset, w_offset=w_offset)
+                img = torch.cat([img, kontext], dim=1)
+                img_ids = torch.cat([img_ids, kontext_ids], dim=1)
+                h = max(h, ref.shape[-2] + h_offset)
+                w = max(w, ref.shape[-1] + w_offset)

         txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
         out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
-        return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h,:w]
+        out = out[:, :img_tokens]
+        return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h_orig,:w_orig]
diff --git a/comfy/model_base.py b/comfy/model_base.py
index 12b0f3dc..fcdfde37 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -816,6 +816,7 @@ class PixArt(BaseModel):
 class Flux(BaseModel):
     def __init__(self, model_config, model_type=ModelType.FLUX, device=None, unet_model=comfy.ldm.flux.model.Flux):
         super().__init__(model_config, model_type, device=device, unet_model=unet_model)
+        self.memory_usage_factor_conds = ("kontext",)

     def concat_cond(self, **kwargs):
         try:
@@ -876,8 +877,23 @@ class Flux(BaseModel):
         guidance = kwargs.get("guidance", 3.5)
         if guidance is not None:
             out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
+
+        ref_latents = kwargs.get("reference_latents", None)
+        if ref_latents is not None:
+            latents = []
+            for lat in ref_latents:
+                latents.append(self.process_latent_in(lat))
+            out['ref_latents'] = comfy.conds.CONDList(latents)
         return out

+    def extra_conds_shapes(self, **kwargs):
+        out = {}
+        ref_latents = kwargs.get("reference_latents", None)
+        if ref_latents is not None:
+            out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16])
+        return out
+
+
 class GenmoMochi(BaseModel):
     def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
         super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.genmo.joint_model.asymm_models_joint.AsymmDiTJoint)
diff --git a/comfy_extras/nodes_flux.py b/comfy_extras/nodes_flux.py
index ad6c15f3..8a8a1769 100644
--- a/comfy_extras/nodes_flux.py
+++ b/comfy_extras/nodes_flux.py
@@ -1,4 +1,5 @@
 import node_helpers
+import comfy.utils

 class CLIPTextEncodeFlux:
     @classmethod
@@ -56,8 +57,52 @@ class FluxDisableGuidance:
         return (c, )


+PREFERED_KONTEXT_RESOLUTIONS = [
+    (672, 1568),
+    (688, 1504),
+    (720, 1456),
+    (752, 1392),
+    (800, 1328),
+    (832, 1248),
+    (880, 1184),
+    (944, 1104),
+    (1024, 1024),
+    (1104, 944),
+    (1184, 880),
+    (1248, 832),
+    (1328, 800),
+    (1392, 752),
+    (1456, 720),
+    (1504, 688),
+    (1568, 672),
+]
+
+
+class FluxKontextImageScale:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {"image": ("IMAGE", ),
+                             },
+                }
+
+    RETURN_TYPES = ("IMAGE",)
+    FUNCTION = "scale"
+
+    CATEGORY = "advanced/conditioning/flux"
+    DESCRIPTION = "This node resizes the image to one that is more optimal for flux kontext."
+
+    def scale(self, image):
+        width = image.shape[2]
+        height = image.shape[1]
+        aspect_ratio = width / height
+        _, width, height = min((abs(aspect_ratio - w / h), w, h) for w, h in PREFERED_KONTEXT_RESOLUTIONS)
+        image = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "lanczos", "center").movedim(1, -1)
+        return (image, )
+
+
 NODE_CLASS_MAPPINGS = {
     "CLIPTextEncodeFlux": CLIPTextEncodeFlux,
     "FluxGuidance": FluxGuidance,
     "FluxDisableGuidance": FluxDisableGuidance,
+    "FluxKontextImageScale": FluxKontextImageScale,
 }
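
Note on the position-id scheme this patch introduces: `process_img` builds a 3-channel id per latent patch for the rotary embedding, where channel 0 distinguishes the main image (`index=0`) from reference latents (`index=1`) and channels 1/2 carry row/column coordinates shifted by `h_offset`/`w_offset`, so successive references in `forward` are tiled against the running canvas extent instead of reusing its coordinates. Below is a minimal standalone sketch of that layout; `make_ids` is a hypothetical helper that mirrors the construction in `process_img` (minus the patch-size rounding) and is not part of the patch itself.

import torch

# Hypothetical illustration of the id layout built by process_img() above.
def make_ids(h_len, w_len, index=0, h_offset=0, w_offset=0):
    # (h_len, w_len, 3): channel 0 tags the sequence (0 = main image,
    # 1 = reference latent); channels 1/2 are row/column positions
    # shifted by the chosen offsets.
    ids = torch.zeros((h_len, w_len, 3))
    ids[:, :, 0] = index
    ids[:, :, 1] = torch.arange(h_offset, h_offset + h_len).unsqueeze(1)
    ids[:, :, 2] = torch.arange(w_offset, w_offset + w_len).unsqueeze(0)
    return ids.reshape(-1, 3)

# A 2x2 main grid at the origin plus a 2x3 reference offset along the
# width, the way forward() would place a second reference beside the
# running canvas. (The first reference keeps offset 0 and is separated
# from the main image only by its index channel.)
main_ids = make_ids(2, 2)
ref_ids = make_ids(2, 3, index=1, w_offset=2)
print(torch.cat([main_ids, ref_ids]))

Since `forward` appends the reference tokens after the `img_tokens` main-image tokens, the `out = out[:, :img_tokens]` slice drops them again before un-patchifying, so only the main canvas is decoded. For `FluxKontextImageScale`, the `min()` over `PREFERED_KONTEXT_RESOLUTIONS` selects the entry whose aspect ratio is nearest the input's: a 1920x1080 image (ratio ~1.78) lands on 1392x752 (ratio ~1.85) and is Lanczos-resampled to that size.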