patch the right keys

This commit is contained in:
Raphael Walker 2024-12-05 15:57:15 +01:00
parent 67bd8d67c2
commit e716b529f7

View File

@ -119,10 +119,15 @@ class ReduxApplyWithAttnMask:
for i, block in enumerate(m.model.diffusion_model.double_blocks):
# is there already a patch there?
# if so, the attnwrapper can chain off it
previous = previous_patches.get(("double_blocks", i), block)
previous = previous_patches.get(("double_block", i), block)
wrapper = _ReduxAttnWrapper(previous, token_counts, bias=attn_bias, is_first=i==0)
# I think this properly clones things?
m.set_model_patch_replace(wrapper, "dit", "double_blocks", i)
m.set_model_patch_replace(wrapper, "dit", "double_block", i)
for i, block in enumerate(m.model.diffusion_model.single_blocks):
previous = previous_patches.get(("single_block", i), block)
wrapper = _ReduxAttnWrapper(previous, token_counts, bias=attn_bias)
m.set_model_patch_replace(wrapper, "dit", "single_block", i)
else:
m = model
return (m, c)