Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-04-13 15:03:33 +00:00)

Commit: 2c2481955d "reduce code duplication"
Parent: fc978a7ad8
--- a/comfy/ldm/chroma/layers.py
+++ b/comfy/ldm/chroma/layers.py
@@ -5,112 +5,26 @@ import torch
 from torch import Tensor, nn

 from .math import attention, rope
-import comfy.ops
+from comfy.ldm.flux.layers import (
+    MLPEmbedder,
+    RMSNorm,
+    QKNorm,
+    SelfAttention,
+    ModulationOut,
+)
 import comfy.ldm.common_dit


-class EmbedND(nn.Module):
-    def __init__(self, dim: int, theta: int, axes_dim: list):
-        super().__init__()
-        self.dim = dim
-        self.theta = theta
-        self.axes_dim = axes_dim
-
-    def forward(self, ids: Tensor) -> Tensor:
-        n_axes = ids.shape[-1]
-        emb = torch.cat(
-            [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
-            dim=-3,
-        )
-
-        return emb.unsqueeze(1)
-
-
-def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
-    """
-    Create sinusoidal timestep embeddings.
-    :param t: a 1-D Tensor of N indices, one per batch element.
-              These may be fractional.
-    :param dim: the dimension of the output.
-    :param max_period: controls the minimum frequency of the embeddings.
-    :return: an (N, D) Tensor of positional embeddings.
-    """
-    t = time_factor * t
-    half = dim // 2
-    freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half)
-
-    args = t[:, None].float() * freqs[None]
-    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
-    if dim % 2:
-        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
-    if torch.is_floating_point(t):
-        embedding = embedding.to(t)
-    return embedding
-
-class MLPEmbedder(nn.Module):
-    def __init__(self, in_dim: int, hidden_dim: int, dtype=None, device=None, operations=None):
-        super().__init__()
-        self.in_layer = operations.Linear(in_dim, hidden_dim, bias=True, dtype=dtype, device=device)
-        self.silu = nn.SiLU()
-        self.out_layer = operations.Linear(hidden_dim, hidden_dim, bias=True, dtype=dtype, device=device)
-
-    def forward(self, x: Tensor) -> Tensor:
-        return self.out_layer(self.silu(self.in_layer(x)))
-
-
-class RMSNorm(torch.nn.Module):
-    def __init__(self, dim: int, dtype=None, device=None, operations=None):
-        super().__init__()
-        self.scale = nn.Parameter(torch.empty((dim), dtype=dtype, device=device))
-
-    def forward(self, x: Tensor):
-        return comfy.ldm.common_dit.rms_norm(x, self.scale, 1e-6)
-
-
-class QKNorm(torch.nn.Module):
-    def __init__(self, dim: int, dtype=None, device=None, operations=None):
-        super().__init__()
-        self.query_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations)
-        self.key_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations)
-
-    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple:
-        q = self.query_norm(q)
-        k = self.key_norm(k)
-        return q.to(v), k.to(v)
-
-
-class SelfAttention(nn.Module):
-    def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False, dtype=None, device=None, operations=None):
-        super().__init__()
-        self.num_heads = num_heads
-        head_dim = dim // num_heads
-
-        self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
-        self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations)
-        self.proj = operations.Linear(dim, dim, dtype=dtype, device=device)
-
-
-@dataclass
-class ModulationOut:
-    shift: Tensor
-    scale: Tensor
-    gate: Tensor
-
-
-class Modulation(nn.Module):
-    def __init__(self, dim: int, double: bool, dtype=None, device=None, operations=None):
-        super().__init__()
-        self.is_double = double
-        self.multiplier = 6 if double else 3
-        self.lin = operations.Linear(dim, self.multiplier * dim, bias=True, dtype=dtype, device=device)
-
-    def forward(self, vec: Tensor) -> tuple:
-        out = self.lin(nn.functional.silu(vec)).chunk(self.multiplier, dim=-1)
-
-        return (
-            ModulationOut(*out[:3]),
-            ModulationOut(*out[3:]) if self.is_double else None,
-        )
+class ChromaModulationOut(ModulationOut):
+    @classmethod
+    def from_offset(cls, tensor: torch.Tensor, offset: int = 0) -> ModulationOut:
+        return cls(
+            shift=tensor[:, offset : offset + 1, :],
+            scale=tensor[:, offset + 1 : offset + 2, :],
+            gate=tensor[:, offset + 2 : offset + 3, :],
+        )

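A minimal, self-contained sketch of how the new ChromaModulationOut.from_offset helper is meant to be used: it slices three consecutive rows of a stacked modulation tensor into a shift/scale/gate triple. The inline ModulationOut stand-in (the real one lives in comfy.ldm.flux.layers) and the example batch/hidden sizes are assumptions for illustration, not taken from the commit.

import torch
from dataclasses import dataclass
from torch import Tensor

@dataclass
class ModulationOut:  # stand-in for comfy.ldm.flux.layers.ModulationOut
    shift: Tensor
    scale: Tensor
    gate: Tensor

class ChromaModulationOut(ModulationOut):  # same helper as added in the hunk above
    @classmethod
    def from_offset(cls, tensor: torch.Tensor, offset: int = 0) -> ModulationOut:
        return cls(
            shift=tensor[:, offset : offset + 1, :],
            scale=tensor[:, offset + 1 : offset + 2, :],
            gate=tensor[:, offset + 2 : offset + 3, :],
        )

mods = torch.randn(2, 6, 3072)                        # hypothetical batch of stacked modulation rows
m0 = ChromaModulationOut.from_offset(mods, offset=0)  # rows 0..2 -> shift, scale, gate
m1 = ChromaModulationOut.from_offset(mods, offset=3)  # rows 3..5 -> next triple
assert m0.shift.shape == (2, 1, 3072) and m1.gate.shape == (2, 1, 3072)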
--- a/comfy/ldm/chroma/model.py
+++ b/comfy/ldm/chroma/model.py
@@ -7,15 +7,17 @@ from torch import Tensor, nn
 from einops import rearrange, repeat
 import comfy.ldm.common_dit

+from comfy.ldm.flux.layers import (
+    EmbedND,
+    timestep_embedding,
+)
+
 from .layers import (
     DoubleStreamBlock,
-    EmbedND,
     LastLayer,
-    MLPEmbedder,
     SingleStreamBlock,
-    timestep_embedding,
     Approximator,
-    ModulationOut
+    ChromaModulationOut,
 )

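The timestep_embedding now imported from comfy.ldm.flux.layers is the same sinusoidal embedding whose chroma-local copy is deleted in the first hunk. A standalone sketch of that computation follows; the toy timestep values and dim=64 are chosen purely for illustration.

import math
import torch
from torch import Tensor

def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
    # Sinusoidal embedding, copied from the function removed from the chroma copy above.
    t = time_factor * t
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half)
    args = t[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    if torch.is_floating_point(t):
        embedding = embedding.to(t)
    return embedding

emb = timestep_embedding(torch.tensor([0.0, 0.25, 1.0]), dim=64)  # toy timesteps
print(emb.shape)  # torch.Size([3, 64])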
@@ -39,14 +41,6 @@ class ChromaParams:
     n_layers: int


-class ChromaModulationOut(ModulationOut):
-    @classmethod
-    def from_offset(cls, tensor: torch.Tensor, offset: int = 0) -> ModulationOut:
-        return cls(
-            shift=tensor[:, offset : offset + 1, :],
-            scale=tensor[:, offset + 1 : offset + 2, :],
-            gate=tensor[:, offset + 2 : offset + 3, :],
-        )


 class Chroma(nn.Module):
@@ -77,7 +71,6 @@ class Chroma(nn.Module):
         self.n_layers = params.n_layers
         self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
         self.img_in = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
-        self.time_in = MLPEmbedder(in_dim=64, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations)
         self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, dtype=dtype, device=device)
         # set as nn identity for now, will overwrite it later.
         self.distilled_guidance_layer = Approximator(
@@ -88,9 +81,6 @@ class Chroma(nn.Module):
             dtype=dtype, device=device, operations=operations
         )

-        self.guidance_in = (
-            MLPEmbedder(in_dim=64, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations) if self.distilled_guidance_layer else nn.Identity()
-        )

         self.double_blocks = nn.ModuleList(
             [
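The deleted self.time_in and self.guidance_in lines used the MLPEmbedder whose source appears in the first hunk (and which remains available from comfy.ldm.flux.layers). For reference, a standalone sketch of that two-layer SiLU MLP; plain nn.Linear stands in for the repo's operations.Linear wrapper, and the 64/3072 sizes are illustrative assumptions.

import torch
from torch import Tensor, nn

class MLPEmbedder(nn.Module):
    # Same Linear -> SiLU -> Linear block as in the first hunk, minus the dtype/device/operations plumbing.
    def __init__(self, in_dim: int, hidden_dim: int):
        super().__init__()
        self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
        self.silu = nn.SiLU()
        self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)

    def forward(self, x: Tensor) -> Tensor:
        return self.out_layer(self.silu(self.in_layer(x)))

vec = MLPEmbedder(in_dim=64, hidden_dim=3072)(torch.randn(2, 64))  # e.g. a 64-dim timestep embedding
print(vec.shape)  # torch.Size([2, 3072])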