Cleanup T5 code a bit.

comfyanonymous 2024-07-06 00:33:25 -04:00
parent 80c4590998
commit b8e58a9394
3 changed files with 23 additions and 14 deletions

View File

@@ -13,29 +13,36 @@ class T5LayerNorm(torch.nn.Module):
         x = x * torch.rsqrt(variance + self.variance_epsilon)
         return self.weight.to(device=x.device, dtype=x.dtype) * x
 
+activations = {
+    "gelu_pytorch_tanh": lambda a: torch.nn.functional.gelu(a, approximate="tanh"),
+    "relu": torch.nn.functional.relu,
+}
+
 class T5DenseActDense(torch.nn.Module):
-    def __init__(self, model_dim, ff_dim, dtype, device, operations):
+    def __init__(self, model_dim, ff_dim, ff_activation, dtype, device, operations):
         super().__init__()
         self.wi = operations.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)
         self.wo = operations.Linear(ff_dim, model_dim, bias=False, dtype=dtype, device=device)
         # self.dropout = nn.Dropout(config.dropout_rate)
+        self.act = activations[ff_activation]
 
     def forward(self, x):
-        x = torch.nn.functional.relu(self.wi(x))
+        x = self.act(self.wi(x))
         # x = self.dropout(x)
         x = self.wo(x)
         return x
 
 class T5DenseGatedActDense(torch.nn.Module):
-    def __init__(self, model_dim, ff_dim, dtype, device, operations):
+    def __init__(self, model_dim, ff_dim, ff_activation, dtype, device, operations):
         super().__init__()
         self.wi_0 = operations.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)
         self.wi_1 = operations.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)
         self.wo = operations.Linear(ff_dim, model_dim, bias=False, dtype=dtype, device=device)
         # self.dropout = nn.Dropout(config.dropout_rate)
+        self.act = activations[ff_activation]
 
     def forward(self, x):
-        hidden_gelu = torch.nn.functional.gelu(self.wi_0(x), approximate="tanh")
+        hidden_gelu = self.act(self.wi_0(x))
         hidden_linear = self.wi_1(x)
         x = hidden_gelu * hidden_linear
         # x = self.dropout(x)
@@ -43,12 +50,12 @@ class T5DenseGatedActDense(torch.nn.Module):
         return x
 
 class T5LayerFF(torch.nn.Module):
-    def __init__(self, model_dim, ff_dim, ff_activation, dtype, device, operations):
+    def __init__(self, model_dim, ff_dim, ff_activation, gated_act, dtype, device, operations):
         super().__init__()
-        if ff_activation == "gelu_pytorch_tanh":
-            self.DenseReluDense = T5DenseGatedActDense(model_dim, ff_dim, dtype, device, operations)
-        elif ff_activation == "relu":
-            self.DenseReluDense = T5DenseActDense(model_dim, ff_dim, dtype, device, operations)
+        if gated_act:
+            self.DenseReluDense = T5DenseGatedActDense(model_dim, ff_dim, ff_activation, dtype, device, operations)
+        else:
+            self.DenseReluDense = T5DenseActDense(model_dim, ff_dim, ff_activation, dtype, device, operations)
 
         self.layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device, operations=operations)
         # self.dropout = nn.Dropout(config.dropout_rate)
@@ -171,11 +178,11 @@ class T5LayerSelfAttention(torch.nn.Module):
         return x, past_bias
 
 class T5Block(torch.nn.Module):
-    def __init__(self, model_dim, inner_dim, ff_dim, ff_activation, num_heads, relative_attention_bias, dtype, device, operations):
+    def __init__(self, model_dim, inner_dim, ff_dim, ff_activation, gated_act, num_heads, relative_attention_bias, dtype, device, operations):
         super().__init__()
         self.layer = torch.nn.ModuleList()
         self.layer.append(T5LayerSelfAttention(model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias, dtype, device, operations))
-        self.layer.append(T5LayerFF(model_dim, ff_dim, ff_activation, dtype, device, operations))
+        self.layer.append(T5LayerFF(model_dim, ff_dim, ff_activation, gated_act, dtype, device, operations))
 
     def forward(self, x, mask=None, past_bias=None, optimized_attention=None):
         x, past_bias = self.layer[0](x, mask, past_bias, optimized_attention)
@@ -183,11 +190,11 @@ class T5Block(torch.nn.Module):
         return x, past_bias
 
 class T5Stack(torch.nn.Module):
-    def __init__(self, num_layers, model_dim, inner_dim, ff_dim, ff_activation, num_heads, dtype, device, operations):
+    def __init__(self, num_layers, model_dim, inner_dim, ff_dim, ff_activation, gated_act, num_heads, dtype, device, operations):
         super().__init__()
 
         self.block = torch.nn.ModuleList(
-            [T5Block(model_dim, inner_dim, ff_dim, ff_activation, num_heads, relative_attention_bias=(i == 0), dtype=dtype, device=device, operations=operations) for i in range(num_layers)]
+            [T5Block(model_dim, inner_dim, ff_dim, ff_activation, gated_act, num_heads, relative_attention_bias=(i == 0), dtype=dtype, device=device, operations=operations) for i in range(num_layers)]
         )
         self.final_layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device, operations=operations)
         # self.dropout = nn.Dropout(config.dropout_rate)
@@ -216,7 +223,7 @@ class T5(torch.nn.Module):
         self.num_layers = config_dict["num_layers"]
         model_dim = config_dict["d_model"]
 
-        self.encoder = T5Stack(self.num_layers, model_dim, model_dim, config_dict["d_ff"], config_dict["dense_act_fn"], config_dict["num_heads"], dtype, device, operations)
+        self.encoder = T5Stack(self.num_layers, model_dim, model_dim, config_dict["d_ff"], config_dict["dense_act_fn"], config_dict["is_gated_act"], config_dict["num_heads"], dtype, device, operations)
         self.dtype = dtype
         self.shared = torch.nn.Embedding(config_dict["vocab_size"], model_dim, device=device)
 
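Note (not part of the commit): the change above replaces hard-coded relu/gelu branches with an activation lookup table plus a gated_act switch. Below is a minimal, self-contained sketch of that pattern, assuming plain torch.nn.Linear layers and arbitrary dimensions; the simplified class names and the make_ff helper are illustrative only, not the repository's API.

# Sketch of the activation-dispatch pattern introduced by this diff (hypothetical names).
import torch

activations = {
    "gelu_pytorch_tanh": lambda a: torch.nn.functional.gelu(a, approximate="tanh"),
    "relu": torch.nn.functional.relu,
}

class DenseActDense(torch.nn.Module):
    def __init__(self, model_dim, ff_dim, ff_activation):
        super().__init__()
        self.wi = torch.nn.Linear(model_dim, ff_dim, bias=False)
        self.wo = torch.nn.Linear(ff_dim, model_dim, bias=False)
        self.act = activations[ff_activation]  # nonlinearity chosen by name

    def forward(self, x):
        return self.wo(self.act(self.wi(x)))

class DenseGatedActDense(torch.nn.Module):
    def __init__(self, model_dim, ff_dim, ff_activation):
        super().__init__()
        self.wi_0 = torch.nn.Linear(model_dim, ff_dim, bias=False)
        self.wi_1 = torch.nn.Linear(model_dim, ff_dim, bias=False)
        self.wo = torch.nn.Linear(ff_dim, model_dim, bias=False)
        self.act = activations[ff_activation]

    def forward(self, x):
        # Gated variant: activated branch multiplied by a linear gate branch.
        return self.wo(self.act(self.wi_0(x)) * self.wi_1(x))

def make_ff(model_dim, ff_dim, ff_activation, gated_act):
    # The boolean flag picks the block type; the activation name picks the nonlinearity.
    cls = DenseGatedActDense if gated_act else DenseActDense
    return cls(model_dim, ff_dim, ff_activation)

ff = make_ff(model_dim=768, ff_dim=2048, ff_activation="relu", gated_act=False)
print(ff(torch.randn(1, 4, 768)).shape)  # torch.Size([1, 4, 768])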

View File

@@ -8,6 +8,7 @@
   "dense_act_fn": "relu",
   "initializer_factor": 1.0,
   "is_encoder_decoder": true,
+  "is_gated_act": false,
   "layer_norm_epsilon": 1e-06,
   "model_type": "t5",
   "num_decoder_layers": 12,

View File

@@ -8,6 +8,7 @@
   "dense_act_fn": "gelu_pytorch_tanh",
   "initializer_factor": 1.0,
   "is_encoder_decoder": true,
+  "is_gated_act": true,
   "layer_norm_epsilon": 1e-06,
   "model_type": "t5",
   "num_decoder_layers": 24,