mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-11 18:35:17 +00:00
26 lines
724 B
Python
26 lines
724 B
Python
|
import torch
|
||
|
import comfy.ops
|
||
|
|
||
|
# Layer namespace used to build the Linear layers of ReduxImageEncoder below.
# NOTE(review): comfy's `manual_cast` ops presumably cast weights to the
# input's dtype/device at call time rather than at load time — confirm
# against comfy.ops.
ops = comfy.ops.manual_cast
|
||
|
|
||
|
class ReduxImageEncoder(torch.nn.Module):
    """Two-layer MLP projecting image-encoder embeddings to ``txt_in_features``.

    The projection is ``redux_down(silu(redux_up(x)))`` with a hidden width of
    ``3 * txt_in_features``.

    NOTE(review): the input name ``sigclip_embeds`` suggests SigLIP image
    embeddings whose last dimension is ``redux_dim`` — confirm against callers.

    Args:
        redux_dim: Input feature size of ``redux_up``.
        txt_in_features: Output feature size of ``redux_down``; the hidden
            layer is three times this width.
        device: Device the layer parameters are created on (``None`` keeps
            the layer constructor's default). Also stored as ``self.device``.
        dtype: Dtype of the layer parameters (``None`` keeps the layer
            constructor's default). Also stored as ``self.dtype``.
    """

    def __init__(
        self,
        redux_dim: int = 1152,
        txt_in_features: int = 4096,
        device=None,
        dtype=None,
    ) -> None:
        super().__init__()

        self.redux_dim = redux_dim
        self.device = device
        self.dtype = dtype

        hidden_features = txt_in_features * 3
        # Forward `device` as well as `dtype` so the constructor argument is
        # honoured; it used to be stored on the module but never reached the
        # layers, so the requested device was silently ignored.
        self.redux_up = ops.Linear(redux_dim, hidden_features, device=device, dtype=dtype)
        self.redux_down = ops.Linear(hidden_features, txt_in_features, device=device, dtype=dtype)

    def forward(self, sigclip_embeds) -> torch.Tensor:
        """Project ``sigclip_embeds`` through up-projection, SiLU, down-projection.

        Args:
            sigclip_embeds: Tensor fed to ``redux_up``; presumably its last
                dimension is ``redux_dim`` (TODO confirm against callers).

        Returns:
            The output of ``redux_down``, with ``txt_in_features`` output
            features.
        """
        hidden = torch.nn.functional.silu(self.redux_up(sigclip_embeds))
        return self.redux_down(hidden)
|