From b8e58a939463be85877c1244ee00763368723c07 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 6 Jul 2024 00:33:25 -0400 Subject: [PATCH] Cleanup T5 code a bit. --- comfy/t5.py | 35 +++++++++++++++++++++-------------- comfy/t5_config_base.json | 1 + comfy/t5_config_xxl.json | 1 + 3 files changed, 23 insertions(+), 14 deletions(-) diff --git a/comfy/t5.py b/comfy/t5.py index 06dfe476..d00b560a 100644 --- a/comfy/t5.py +++ b/comfy/t5.py @@ -13,29 +13,36 @@ class T5LayerNorm(torch.nn.Module): x = x * torch.rsqrt(variance + self.variance_epsilon) return self.weight.to(device=x.device, dtype=x.dtype) * x +activations = { + "gelu_pytorch_tanh": lambda a: torch.nn.functional.gelu(a, approximate="tanh"), + "relu": torch.nn.functional.relu, +} + class T5DenseActDense(torch.nn.Module): - def __init__(self, model_dim, ff_dim, dtype, device, operations): + def __init__(self, model_dim, ff_dim, ff_activation, dtype, device, operations): super().__init__() self.wi = operations.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device) self.wo = operations.Linear(ff_dim, model_dim, bias=False, dtype=dtype, device=device) # self.dropout = nn.Dropout(config.dropout_rate) + self.act = activations[ff_activation] def forward(self, x): - x = torch.nn.functional.relu(self.wi(x)) + x = self.act(self.wi(x)) # x = self.dropout(x) x = self.wo(x) return x class T5DenseGatedActDense(torch.nn.Module): - def __init__(self, model_dim, ff_dim, dtype, device, operations): + def __init__(self, model_dim, ff_dim, ff_activation, dtype, device, operations): super().__init__() self.wi_0 = operations.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device) self.wi_1 = operations.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device) self.wo = operations.Linear(ff_dim, model_dim, bias=False, dtype=dtype, device=device) # self.dropout = nn.Dropout(config.dropout_rate) + self.act = activations[ff_activation] def forward(self, x): - hidden_gelu = torch.nn.functional.gelu(self.wi_0(x), approximate="tanh") + hidden_gelu = self.act(self.wi_0(x)) hidden_linear = self.wi_1(x) x = hidden_gelu * hidden_linear # x = self.dropout(x) @@ -43,12 +50,12 @@ class T5DenseGatedActDense(torch.nn.Module): return x class T5LayerFF(torch.nn.Module): - def __init__(self, model_dim, ff_dim, ff_activation, dtype, device, operations): + def __init__(self, model_dim, ff_dim, ff_activation, gated_act, dtype, device, operations): super().__init__() - if ff_activation == "gelu_pytorch_tanh": - self.DenseReluDense = T5DenseGatedActDense(model_dim, ff_dim, dtype, device, operations) - elif ff_activation == "relu": - self.DenseReluDense = T5DenseActDense(model_dim, ff_dim, dtype, device, operations) + if gated_act: + self.DenseReluDense = T5DenseGatedActDense(model_dim, ff_dim, ff_activation, dtype, device, operations) + else: + self.DenseReluDense = T5DenseActDense(model_dim, ff_dim, ff_activation, dtype, device, operations) self.layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device, operations=operations) # self.dropout = nn.Dropout(config.dropout_rate) @@ -171,11 +178,11 @@ class T5LayerSelfAttention(torch.nn.Module): return x, past_bias class T5Block(torch.nn.Module): - def __init__(self, model_dim, inner_dim, ff_dim, ff_activation, num_heads, relative_attention_bias, dtype, device, operations): + def __init__(self, model_dim, inner_dim, ff_dim, ff_activation, gated_act, num_heads, relative_attention_bias, dtype, device, operations): super().__init__() self.layer = torch.nn.ModuleList() self.layer.append(T5LayerSelfAttention(model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias, dtype, device, operations)) - self.layer.append(T5LayerFF(model_dim, ff_dim, ff_activation, dtype, device, operations)) + self.layer.append(T5LayerFF(model_dim, ff_dim, ff_activation, gated_act, dtype, device, operations)) def forward(self, x, mask=None, past_bias=None, optimized_attention=None): x, past_bias = self.layer[0](x, mask, past_bias, optimized_attention) @@ -183,11 +190,11 @@ class T5Block(torch.nn.Module): return x, past_bias class T5Stack(torch.nn.Module): - def __init__(self, num_layers, model_dim, inner_dim, ff_dim, ff_activation, num_heads, dtype, device, operations): + def __init__(self, num_layers, model_dim, inner_dim, ff_dim, ff_activation, gated_act, num_heads, dtype, device, operations): super().__init__() self.block = torch.nn.ModuleList( - [T5Block(model_dim, inner_dim, ff_dim, ff_activation, num_heads, relative_attention_bias=(i == 0), dtype=dtype, device=device, operations=operations) for i in range(num_layers)] + [T5Block(model_dim, inner_dim, ff_dim, ff_activation, gated_act, num_heads, relative_attention_bias=(i == 0), dtype=dtype, device=device, operations=operations) for i in range(num_layers)] ) self.final_layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device, operations=operations) # self.dropout = nn.Dropout(config.dropout_rate) @@ -216,7 +223,7 @@ class T5(torch.nn.Module): self.num_layers = config_dict["num_layers"] model_dim = config_dict["d_model"] - self.encoder = T5Stack(self.num_layers, model_dim, model_dim, config_dict["d_ff"], config_dict["dense_act_fn"], config_dict["num_heads"], dtype, device, operations) + self.encoder = T5Stack(self.num_layers, model_dim, model_dim, config_dict["d_ff"], config_dict["dense_act_fn"], config_dict["is_gated_act"], config_dict["num_heads"], dtype, device, operations) self.dtype = dtype self.shared = torch.nn.Embedding(config_dict["vocab_size"], model_dim, device=device) diff --git a/comfy/t5_config_base.json b/comfy/t5_config_base.json index facd85ef..71f68327 100644 --- a/comfy/t5_config_base.json +++ b/comfy/t5_config_base.json @@ -8,6 +8,7 @@ "dense_act_fn": "relu", "initializer_factor": 1.0, "is_encoder_decoder": true, + "is_gated_act": false, "layer_norm_epsilon": 1e-06, "model_type": "t5", "num_decoder_layers": 12, diff --git a/comfy/t5_config_xxl.json b/comfy/t5_config_xxl.json index bf4feadc..28283b51 100644 --- a/comfy/t5_config_xxl.json +++ b/comfy/t5_config_xxl.json @@ -8,6 +8,7 @@ "dense_act_fn": "gelu_pytorch_tanh", "initializer_factor": 1.0, "is_encoder_decoder": true, + "is_gated_act": true, "layer_norm_epsilon": 1e-06, "model_type": "t5", "num_decoder_layers": 24,