mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-07-09 08:47:17 +08:00
attn masks can be done using replace patches instead of a separate dict
This commit is contained in:
parent
338e1573a9
commit
581a4c9032
@ -102,7 +102,6 @@ class Flux(nn.Module):
|
|||||||
transformer_options={},
|
transformer_options={},
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
patches_replace = transformer_options.get("patches_replace", {})
|
patches_replace = transformer_options.get("patches_replace", {})
|
||||||
attn_masks = transformer_options.get("attn_masks", {})
|
|
||||||
if img.ndim != 3 or txt.ndim != 3:
|
if img.ndim != 3 or txt.ndim != 3:
|
||||||
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
||||||
|
|
||||||
@ -122,18 +121,17 @@ class Flux(nn.Module):
|
|||||||
|
|
||||||
blocks_replace = patches_replace.get("dit", {})
|
blocks_replace = patches_replace.get("dit", {})
|
||||||
for i, block in enumerate(self.double_blocks):
|
for i, block in enumerate(self.double_blocks):
|
||||||
mask = attn_masks.get(("double_block", i), None)
|
|
||||||
if ("double_block", i) in blocks_replace:
|
if ("double_block", i) in blocks_replace:
|
||||||
def block_wrap(args):
|
def block_wrap(args):
|
||||||
out = {}
|
out = {}
|
||||||
out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], mask=args["mask"])
|
out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"])
|
||||||
return out
|
return out
|
||||||
|
|
||||||
out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, "mask": mask}, {"original_block": block_wrap})
|
out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe}, {"original_block": block_wrap})
|
||||||
txt = out["txt"]
|
txt = out["txt"]
|
||||||
img = out["img"]
|
img = out["img"]
|
||||||
else:
|
else:
|
||||||
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, mask=mask)
|
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
|
||||||
|
|
||||||
if control is not None: # Controlnet
|
if control is not None: # Controlnet
|
||||||
control_i = control.get("input")
|
control_i = control.get("input")
|
||||||
@ -145,17 +143,16 @@ class Flux(nn.Module):
|
|||||||
img = torch.cat((txt, img), 1)
|
img = torch.cat((txt, img), 1)
|
||||||
|
|
||||||
for i, block in enumerate(self.single_blocks):
|
for i, block in enumerate(self.single_blocks):
|
||||||
mask = attn_masks.get(("single_block", i), None)
|
|
||||||
if ("single_block", i) in blocks_replace:
|
if ("single_block", i) in blocks_replace:
|
||||||
def block_wrap(args):
|
def block_wrap(args):
|
||||||
out = {}
|
out = {}
|
||||||
out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], mask=args["mask"])
|
out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"])
|
||||||
return out
|
return out
|
||||||
|
|
||||||
out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, "mask": mask}, {"original_block": block_wrap})
|
out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe}, {"original_block": block_wrap})
|
||||||
img = out["img"]
|
img = out["img"]
|
||||||
else:
|
else:
|
||||||
img = block(img, vec=vec, pe=pe, mask=mask)
|
img = block(img, vec=vec, pe=pe)
|
||||||
|
|
||||||
if control is not None: # Controlnet
|
if control is not None: # Controlnet
|
||||||
control_o = control.get("output")
|
control_o = control.get("output")
|
||||||
|
@ -1,4 +1,8 @@
|
|||||||
import node_helpers
|
import node_helpers
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from comfy.ldm.flux.layers import SingleStreamBlock, DoubleStreamBlock
|
||||||
|
from comfy.model_patcher import ModelPatcher
|
||||||
|
|
||||||
class CLIPTextEncodeFlux:
|
class CLIPTextEncodeFlux:
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -37,8 +41,94 @@ class FluxGuidance:
|
|||||||
c = node_helpers.conditioning_set_values(conditioning, {"guidance": guidance})
|
c = node_helpers.conditioning_set_values(conditioning, {"guidance": guidance})
|
||||||
return (c, )
|
return (c, )
|
||||||
|
|
||||||
|
class _ReduxAttnWrapper:
|
||||||
|
def __init__(self, previous, token_counts, bias=0.0, is_first=False):
|
||||||
|
self.previous = previous
|
||||||
|
self.token_counts = token_counts
|
||||||
|
self.bias = bias
|
||||||
|
self.is_first = is_first
|
||||||
|
|
||||||
|
def __call__(self, args, extra_args):
|
||||||
|
# args: {"img": img, <"txt": txt>, "vec": vec, "pe": pe}
|
||||||
|
if self.is_first:
|
||||||
|
self.token_counts["img"] = args["img"].shape[1]
|
||||||
|
|
||||||
|
# determine the total number of tokens in the mask, depending on whether we're wrapping a single block or a double one
|
||||||
|
total_tokens = args["img"].shape[1]
|
||||||
|
if "txt" in args:
|
||||||
|
total_tokens += args["txt"].shape[1]
|
||||||
|
# create the mask (or bias map)
|
||||||
|
mask = extra_args.get("attn_mask", torch.zeros((total_tokens, total_tokens), device=args["img"].device, dtype=args["img"].dtype))
|
||||||
|
# if this wrapper was called by another ReduxAttnWrapper, compute the range of tokens that correspond to our image
|
||||||
|
redux_end = extra_args.get("redux_end", -self.token_counts["img"])
|
||||||
|
redux_start = redux_end - self.token_counts["redux"]
|
||||||
|
# modify the mask
|
||||||
|
# first 256 tokens are the text prompt
|
||||||
|
mask[:256, redux_start:redux_end] = self.bias
|
||||||
|
# last 'img' tokens are the image being generated
|
||||||
|
mask[-self.token_counts["img"]:, redux_start:redux_end] = self.bias
|
||||||
|
match self.previous:
|
||||||
|
case DoubleStreamBlock():
|
||||||
|
x, c = self.previous(img=args["img"], txt=args["txt"],vec=args["vec"], pe=args["pe"], attn_mask=mask)
|
||||||
|
return {"img": x, "txt": c}
|
||||||
|
case SingleStreamBlock():
|
||||||
|
x = self.previous(img=args["img"], vec=args["vec"], pe=args["pe"], attn_mask=mask)
|
||||||
|
return {"img": x}
|
||||||
|
case _ReduxAttnWrapper():
|
||||||
|
# pass along the mask, and tell the next redux what its part of the mask is
|
||||||
|
extra_args["mask"] = mask
|
||||||
|
extra_args["redux_end"] = redux_start
|
||||||
|
return self.previous(args, extra_args)
|
||||||
|
case _:
|
||||||
|
print(f"Can't wrap {repr(self.previous)} with mask.")
|
||||||
|
return self.previous(args, extra_args)
|
||||||
|
|
||||||
|
class ReduxApplyWithAttnMask:
    """Node that appends style-model tokens to the conditioning and, when a
    non-zero ``attn_bias`` is given, patches a clone of the model so that its
    blocks are run through ``_ReduxAttnWrapper``.

    Returns ``(model, conditioning)``, matching ``RETURN_TYPES``.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            # the model input must be typed MODEL, not CONDITIONING
            "model": ("MODEL", ),
            "conditioning": ("CONDITIONING", ),
            "style_model": ("STYLE_MODEL", ),
            "clip_vision_output": ("CLIP_VISION_OUTPUT", ),
            "attn_bias": ("FLOAT", {"default": 0.0, "min": -10.0, "max": 10.0, "step": 0.01}),
        }}

    RETURN_TYPES = ("MODEL", "CONDITIONING")
    FUNCTION = "apply_stylemodel"

    CATEGORY = "conditioning/style_model"

    def apply_stylemodel(self, model: "ModelPatcher", clip_vision_output, style_model, conditioning, attn_bias):
        """Append the style condition to every conditioning entry and
        optionally install attention-bias patches.

        model:              patched only via a clone; the input is returned
                            as-is when ``attn_bias`` is 0.0.
        clip_vision_output: passed to ``style_model.get_cond``.
        style_model:        object providing ``get_cond``.
        conditioning:       iterable of ``[tensor, dict]`` pairs.
        attn_bias:          bias handed to each ``_ReduxAttnWrapper``.
        """
        # collapse the first two dims of the style condition, then re-add a
        # leading dim of size 1 so it can be concatenated along dim 1
        cond = style_model.get_cond(clip_vision_output).flatten(start_dim=0, end_dim=1).unsqueeze(dim=0)

        c = [[torch.cat((t[0], cond), dim=1), t[1].copy()] for t in conditioning]

        if attn_bias == 0.0:
            # nothing to bias: leave the model untouched
            return (model, c)

        # shared between all wrappers; "img" is filled in at run time by the
        # wrapper on double block 0
        token_counts = {
            "redux": cond.shape[1],
            "img": None,
        }

        m = model.clone()
        # patch the model
        previous_patches = m.model_options["transformer_options"].get("patches_replace", {}).get("dit", {})
        diffusion_model = m.model.diffusion_model

        # Keys must match the lookup done by the Flux forward pass, which
        # uses ("double_block", i) and ("single_block", i).
        for i, block in enumerate(diffusion_model.double_blocks):
            # is there already a patch there?
            # if so, the attnwrapper can chain off it
            previous = previous_patches.get(("double_block", i), block)
            wrapper = _ReduxAttnWrapper(previous, token_counts, bias=attn_bias, is_first=i == 0)
            # NOTE(review): assumes set_model_patch_replace copies the option
            # dicts on the clone rather than mutating the source model's --
            # confirm against ModelPatcher.
            m.set_model_patch_replace(wrapper, "dit", "double_block", i)

        # single blocks see txt and img concatenated in args["img"], so they
        # must never be the wrapper that records the image token count
        for i, block in enumerate(diffusion_model.single_blocks):
            previous = previous_patches.get(("single_block", i), block)
            wrapper = _ReduxAttnWrapper(previous, token_counts, bias=attn_bias, is_first=False)
            m.set_model_patch_replace(wrapper, "dit", "single_block", i)

        # order matches RETURN_TYPES: (MODEL, CONDITIONING)
        return (m, c)
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
"CLIPTextEncodeFlux": CLIPTextEncodeFlux,
|
"CLIPTextEncodeFlux": CLIPTextEncodeFlux,
|
||||||
"FluxGuidance": FluxGuidance,
|
"FluxGuidance": FluxGuidance,
|
||||||
|
"ReduxWithAttnMask": ReduxApplyWithAttnMask
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user