mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-11 02:15:17 +00:00
Cleanup T5 code a bit.
This commit is contained in:
parent
80c4590998
commit
b8e58a9394
35
comfy/t5.py
35
comfy/t5.py
@ -13,29 +13,36 @@ class T5LayerNorm(torch.nn.Module):
|
||||
x = x * torch.rsqrt(variance + self.variance_epsilon)
|
||||
return self.weight.to(device=x.device, dtype=x.dtype) * x
|
||||
|
||||
activations = {
|
||||
"gelu_pytorch_tanh": lambda a: torch.nn.functional.gelu(a, approximate="tanh"),
|
||||
"relu": torch.nn.functional.relu,
|
||||
}
|
||||
|
||||
class T5DenseActDense(torch.nn.Module):
|
||||
def __init__(self, model_dim, ff_dim, dtype, device, operations):
|
||||
def __init__(self, model_dim, ff_dim, ff_activation, dtype, device, operations):
|
||||
super().__init__()
|
||||
self.wi = operations.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)
|
||||
self.wo = operations.Linear(ff_dim, model_dim, bias=False, dtype=dtype, device=device)
|
||||
# self.dropout = nn.Dropout(config.dropout_rate)
|
||||
self.act = activations[ff_activation]
|
||||
|
||||
def forward(self, x):
|
||||
x = torch.nn.functional.relu(self.wi(x))
|
||||
x = self.act(self.wi(x))
|
||||
# x = self.dropout(x)
|
||||
x = self.wo(x)
|
||||
return x
|
||||
|
||||
class T5DenseGatedActDense(torch.nn.Module):
|
||||
def __init__(self, model_dim, ff_dim, dtype, device, operations):
|
||||
def __init__(self, model_dim, ff_dim, ff_activation, dtype, device, operations):
|
||||
super().__init__()
|
||||
self.wi_0 = operations.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)
|
||||
self.wi_1 = operations.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)
|
||||
self.wo = operations.Linear(ff_dim, model_dim, bias=False, dtype=dtype, device=device)
|
||||
# self.dropout = nn.Dropout(config.dropout_rate)
|
||||
self.act = activations[ff_activation]
|
||||
|
||||
def forward(self, x):
|
||||
hidden_gelu = torch.nn.functional.gelu(self.wi_0(x), approximate="tanh")
|
||||
hidden_gelu = self.act(self.wi_0(x))
|
||||
hidden_linear = self.wi_1(x)
|
||||
x = hidden_gelu * hidden_linear
|
||||
# x = self.dropout(x)
|
||||
@ -43,12 +50,12 @@ class T5DenseGatedActDense(torch.nn.Module):
|
||||
return x
|
||||
|
||||
class T5LayerFF(torch.nn.Module):
|
||||
def __init__(self, model_dim, ff_dim, ff_activation, dtype, device, operations):
|
||||
def __init__(self, model_dim, ff_dim, ff_activation, gated_act, dtype, device, operations):
|
||||
super().__init__()
|
||||
if ff_activation == "gelu_pytorch_tanh":
|
||||
self.DenseReluDense = T5DenseGatedActDense(model_dim, ff_dim, dtype, device, operations)
|
||||
elif ff_activation == "relu":
|
||||
self.DenseReluDense = T5DenseActDense(model_dim, ff_dim, dtype, device, operations)
|
||||
if gated_act:
|
||||
self.DenseReluDense = T5DenseGatedActDense(model_dim, ff_dim, ff_activation, dtype, device, operations)
|
||||
else:
|
||||
self.DenseReluDense = T5DenseActDense(model_dim, ff_dim, ff_activation, dtype, device, operations)
|
||||
|
||||
self.layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device, operations=operations)
|
||||
# self.dropout = nn.Dropout(config.dropout_rate)
|
||||
@ -171,11 +178,11 @@ class T5LayerSelfAttention(torch.nn.Module):
|
||||
return x, past_bias
|
||||
|
||||
class T5Block(torch.nn.Module):
|
||||
def __init__(self, model_dim, inner_dim, ff_dim, ff_activation, num_heads, relative_attention_bias, dtype, device, operations):
|
||||
def __init__(self, model_dim, inner_dim, ff_dim, ff_activation, gated_act, num_heads, relative_attention_bias, dtype, device, operations):
|
||||
super().__init__()
|
||||
self.layer = torch.nn.ModuleList()
|
||||
self.layer.append(T5LayerSelfAttention(model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias, dtype, device, operations))
|
||||
self.layer.append(T5LayerFF(model_dim, ff_dim, ff_activation, dtype, device, operations))
|
||||
self.layer.append(T5LayerFF(model_dim, ff_dim, ff_activation, gated_act, dtype, device, operations))
|
||||
|
||||
def forward(self, x, mask=None, past_bias=None, optimized_attention=None):
|
||||
x, past_bias = self.layer[0](x, mask, past_bias, optimized_attention)
|
||||
@ -183,11 +190,11 @@ class T5Block(torch.nn.Module):
|
||||
return x, past_bias
|
||||
|
||||
class T5Stack(torch.nn.Module):
|
||||
def __init__(self, num_layers, model_dim, inner_dim, ff_dim, ff_activation, num_heads, dtype, device, operations):
|
||||
def __init__(self, num_layers, model_dim, inner_dim, ff_dim, ff_activation, gated_act, num_heads, dtype, device, operations):
|
||||
super().__init__()
|
||||
|
||||
self.block = torch.nn.ModuleList(
|
||||
[T5Block(model_dim, inner_dim, ff_dim, ff_activation, num_heads, relative_attention_bias=(i == 0), dtype=dtype, device=device, operations=operations) for i in range(num_layers)]
|
||||
[T5Block(model_dim, inner_dim, ff_dim, ff_activation, gated_act, num_heads, relative_attention_bias=(i == 0), dtype=dtype, device=device, operations=operations) for i in range(num_layers)]
|
||||
)
|
||||
self.final_layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device, operations=operations)
|
||||
# self.dropout = nn.Dropout(config.dropout_rate)
|
||||
@ -216,7 +223,7 @@ class T5(torch.nn.Module):
|
||||
self.num_layers = config_dict["num_layers"]
|
||||
model_dim = config_dict["d_model"]
|
||||
|
||||
self.encoder = T5Stack(self.num_layers, model_dim, model_dim, config_dict["d_ff"], config_dict["dense_act_fn"], config_dict["num_heads"], dtype, device, operations)
|
||||
self.encoder = T5Stack(self.num_layers, model_dim, model_dim, config_dict["d_ff"], config_dict["dense_act_fn"], config_dict["is_gated_act"], config_dict["num_heads"], dtype, device, operations)
|
||||
self.dtype = dtype
|
||||
self.shared = torch.nn.Embedding(config_dict["vocab_size"], model_dim, device=device)
|
||||
|
||||
|
@ -8,6 +8,7 @@
|
||||
"dense_act_fn": "relu",
|
||||
"initializer_factor": 1.0,
|
||||
"is_encoder_decoder": true,
|
||||
"is_gated_act": false,
|
||||
"layer_norm_epsilon": 1e-06,
|
||||
"model_type": "t5",
|
||||
"num_decoder_layers": 12,
|
||||
|
@ -8,6 +8,7 @@
|
||||
"dense_act_fn": "gelu_pytorch_tanh",
|
||||
"initializer_factor": 1.0,
|
||||
"is_encoder_decoder": true,
|
||||
"is_gated_act": true,
|
||||
"layer_norm_epsilon": 1e-06,
|
||||
"model_type": "t5",
|
||||
"num_decoder_layers": 24,
|
||||
|
Loading…
Reference in New Issue
Block a user