mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-03-13 14:21:20 +00:00
Cleanup.
This commit is contained in:
parent
98f828fad9
commit
0bdc2b15c7
@ -6,7 +6,7 @@ from einops import rearrange, repeat
|
||||
from typing import Optional, Any
|
||||
import logging
|
||||
|
||||
from .diffusionmodules.util import checkpoint, AlphaBlender, timestep_embedding
|
||||
from .diffusionmodules.util import AlphaBlender, timestep_embedding
|
||||
from .sub_quadratic_attention import efficient_dot_product_attention
|
||||
|
||||
from comfy import model_management
|
||||
@ -454,15 +454,11 @@ class BasicTransformerBlock(nn.Module):
|
||||
|
||||
self.norm1 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
|
||||
self.norm3 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
|
||||
self.checkpoint = checkpoint
|
||||
self.n_heads = n_heads
|
||||
self.d_head = d_head
|
||||
self.switch_temporal_ca_to_sa = switch_temporal_ca_to_sa
|
||||
|
||||
def forward(self, x, context=None, transformer_options={}):
|
||||
return checkpoint(self._forward, (x, context, transformer_options), self.parameters(), self.checkpoint)
|
||||
|
||||
def _forward(self, x, context=None, transformer_options={}):
|
||||
extra_options = {}
|
||||
block = transformer_options.get("block", None)
|
||||
block_index = transformer_options.get("block_index", 0)
|
||||
@ -629,7 +625,7 @@ class SpatialTransformer(nn.Module):
|
||||
x = self.norm(x)
|
||||
if not self.use_linear:
|
||||
x = self.proj_in(x)
|
||||
x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
|
||||
x = x.movedim(1, -1).flatten(1, 2).contiguous()
|
||||
if self.use_linear:
|
||||
x = self.proj_in(x)
|
||||
for i, block in enumerate(self.transformer_blocks):
|
||||
@ -637,7 +633,7 @@ class SpatialTransformer(nn.Module):
|
||||
x = block(x, context=context[i], transformer_options=transformer_options)
|
||||
if self.use_linear:
|
||||
x = self.proj_out(x)
|
||||
x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
|
||||
x = x.reshape(x.shape[0], h, w, x.shape[-1]).movedim(-1, 1).contiguous()
|
||||
if not self.use_linear:
|
||||
x = self.proj_out(x)
|
||||
return x + x_in
|
||||
|
@ -258,7 +258,7 @@ class ResBlock(TimestepBlock):
|
||||
else:
|
||||
if emb_out is not None:
|
||||
if self.exchange_temb_dims:
|
||||
emb_out = rearrange(emb_out, "b t c ... -> b c t ...")
|
||||
emb_out = emb_out.movedim(1, 2)
|
||||
h = h + emb_out
|
||||
h = self.out_layers(h)
|
||||
return self.skip_connection(x) + h
|
||||
|
Loading…
Reference in New Issue
Block a user