rebased onto "master"

This commit is contained in:
Alexander Piskun 2025-01-11 22:14:26 +03:00
parent d83ab7d1d6
commit 92f831eb24
2 changed files with 14 additions and 13 deletions

View File

@ -217,19 +217,19 @@ class GeneralDIT(nn.Module):
raise ValueError(f"Unknown pos_emb_cls {self.pos_emb_cls}")
logging.debug(f"Building positional embedding with {self.pos_emb_cls} class, impl {cls_type}")
kwargs = dict(
model_channels=self.model_channels,
len_h=self.max_img_h // self.patch_spatial,
len_w=self.max_img_w // self.patch_spatial,
len_t=self.max_frames // self.patch_temporal,
is_learnable=self.pos_emb_learnable,
interpolation=self.pos_emb_interpolation,
head_dim=self.model_channels // self.num_heads,
h_extrapolation_ratio=self.rope_h_extrapolation_ratio,
w_extrapolation_ratio=self.rope_w_extrapolation_ratio,
t_extrapolation_ratio=self.rope_t_extrapolation_ratio,
device=device,
)
kwargs = {
"model_channels": self.model_channels,
"len_h": self.max_img_h // self.patch_spatial,
"len_w": self.max_img_w // self.patch_spatial,
"len_t": self.max_frames // self.patch_temporal,
"is_learnable": self.pos_emb_learnable,
"interpolation": self.pos_emb_interpolation,
"head_dim": self.model_channels // self.num_heads,
"h_extrapolation_ratio": self.rope_h_extrapolation_ratio,
"w_extrapolation_ratio": self.rope_w_extrapolation_ratio,
"t_extrapolation_ratio": self.rope_t_extrapolation_ratio,
"device": device,
}
self.pos_embedder = cls_type(
**kwargs,
)

View File

@ -19,5 +19,6 @@ lint.select = [
# The "F" series in Ruff stands for "Pyflakes" rules, which catch various Python syntax errors and undefined names.
# See all rules here: https://docs.astral.sh/ruff/rules/#pyflakes-f
"F",
"C408", # unnecessary dict(), list() or tuple() calls that can be rewritten as empty literals.
]
exclude = ["*.ipynb"]