replace match with if, elif, else

This commit is contained in:
Raphael Walker 2024-12-06 13:18:20 +01:00
parent 7739f5f8d9
commit c0ac4d81e7

View File

@ -67,21 +67,21 @@ class _ReduxAttnWrapper:
mask[:256, redux_start:redux_end] = self.bias
# last 'img' tokens are the image being generated
mask[-self.token_counts["img"]:, redux_start:redux_end] = self.bias
match self.previous:
case DoubleStreamBlock():
x, c = self.previous(img=args["img"], txt=args["txt"],vec=args["vec"], pe=args["pe"], attn_mask=mask)
return {"img": x, "txt": c}
case SingleStreamBlock():
x = self.previous(args["img"], vec=args["vec"], pe=args["pe"], attn_mask=mask)
return {"img": x}
case _ReduxAttnWrapper():
# pass along the mask, and tell the next redux what its part of the mask is
extra_args["attn_mask"] = mask
extra_args["redux_end"] = redux_start
return self.previous(args, extra_args)
case _:
print(f"Can't wrap {repr(self.previous)} with mask.")
return self.previous(args, extra_args)
# nice case for a match statement
if isinstance(self.previous, DoubleStreamBlock):
x, c = self.previous(img=args["img"], txt=args["txt"],vec=args["vec"], pe=args["pe"], attn_mask=mask)
return {"img": x, "txt": c}
elif isinstance(self.previous, SingleStreamBlock):
x = self.previous(args["img"], vec=args["vec"], pe=args["pe"], attn_mask=mask)
return {"img": x}
elif isinstance(self.previous, _ReduxAttnWrapper):
# pass along the mask, and tell the next redux what its part of the mask is
extra_args["attn_mask"] = mask
extra_args["redux_end"] = redux_start
return self.previous(args, extra_args)
else:
print(f"Can't wrap {repr(self.previous)} with mask.")
return self.previous(args, extra_args)
class ReduxApplyWithAttnMask:
@classmethod