Raw torch is faster than einops?

This commit is contained in:
comfyanonymous 2024-08-08 22:09:29 -04:00
parent 11200de970
commit 413322645e

View File

@ -2,7 +2,6 @@ import math
from dataclasses import dataclass from dataclasses import dataclass
import torch import torch
from einops import rearrange
from torch import Tensor, nn from torch import Tensor, nn
from .math import attention, rope from .math import attention, rope
@ -37,9 +36,7 @@ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 10
""" """
t = time_factor * t t = time_factor * t
half = dim // 2 half = dim // 2
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to( freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half)
t.device
)
args = t[:, None].float() * freqs[None] args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
@ -49,7 +46,6 @@ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 10
embedding = embedding.to(t) embedding = embedding.to(t)
return embedding return embedding
class MLPEmbedder(nn.Module): class MLPEmbedder(nn.Module):
def __init__(self, in_dim: int, hidden_dim: int, dtype=None, device=None, operations=None): def __init__(self, in_dim: int, hidden_dim: int, dtype=None, device=None, operations=None):
super().__init__() super().__init__()
@ -95,14 +91,6 @@ class SelfAttention(nn.Module):
self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations) self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations)
self.proj = operations.Linear(dim, dim, dtype=dtype, device=device) self.proj = operations.Linear(dim, dim, dtype=dtype, device=device)
def forward(self, x: Tensor, pe: Tensor) -> Tensor:
qkv = self.qkv(x)
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
q, k = self.norm(q, k, v)
x = attention(q, k, v, pe=pe)
x = self.proj(x)
return x
@dataclass @dataclass
class ModulationOut: class ModulationOut:
@ -164,14 +152,14 @@ class DoubleStreamBlock(nn.Module):
img_modulated = self.img_norm1(img) img_modulated = self.img_norm1(img)
img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
img_qkv = self.img_attn.qkv(img_modulated) img_qkv = self.img_attn.qkv(img_modulated)
img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v) img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
# prepare txt for attention # prepare txt for attention
txt_modulated = self.txt_norm1(txt) txt_modulated = self.txt_norm1(txt)
txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
txt_qkv = self.txt_attn.qkv(txt_modulated) txt_qkv = self.txt_attn.qkv(txt_modulated)
txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v) txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
# run actual attention # run actual attention
@ -236,7 +224,7 @@ class SingleStreamBlock(nn.Module):
x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1) qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
q, k = self.norm(q, k, v) q, k = self.norm(q, k, v)
# compute attention # compute attention