From a9ac56fc0db5777de0edf2fe4b8ed628ccab1293 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 26 Jul 2024 04:32:33 -0400 Subject: [PATCH] Own BertModel implementation that works with lowvram. --- comfy/text_encoders/bert.py | 139 +++++++++++++++++++++++++++++++++++ comfy/text_encoders/hydit.py | 47 +----------- 2 files changed, 142 insertions(+), 44 deletions(-) create mode 100644 comfy/text_encoders/bert.py diff --git a/comfy/text_encoders/bert.py b/comfy/text_encoders/bert.py new file mode 100644 index 00000000..b76e7666 --- /dev/null +++ b/comfy/text_encoders/bert.py @@ -0,0 +1,139 @@ +import torch +from comfy.ldm.modules.attention import optimized_attention_for_device + +class BertAttention(torch.nn.Module): + def __init__(self, embed_dim, heads, dtype, device, operations): + super().__init__() + + self.heads = heads + self.query = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device) + self.key = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device) + self.value = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device) + + + def forward(self, x, mask=None, optimized_attention=None): + q = self.query(x) + k = self.key(x) + v = self.value(x) + + out = optimized_attention(q, k, v, self.heads, mask) + return out + +class BertOutput(torch.nn.Module): + def __init__(self, input_dim, output_dim, layer_norm_eps, dtype, device, operations): + super().__init__() + self.dense = operations.Linear(input_dim, output_dim, dtype=dtype, device=device) + self.LayerNorm = operations.LayerNorm(output_dim, eps=layer_norm_eps, dtype=dtype, device=device) + # self.dropout = nn.Dropout(0.0) + + def forward(self, x, y): + x = self.dense(x) + # hidden_states = self.dropout(hidden_states) + x = self.LayerNorm(x + y) + return x + +class BertAttentionBlock(torch.nn.Module): + def __init__(self, embed_dim, heads, layer_norm_eps, dtype, device, operations): + super().__init__() + self.self = BertAttention(embed_dim, heads, dtype, device, operations) + self.output = BertOutput(embed_dim, embed_dim, layer_norm_eps, dtype, device, operations) + + def forward(self, x, mask, optimized_attention): + y = self.self(x, mask, optimized_attention) + return self.output(y, x) + +class BertIntermediate(torch.nn.Module): + def __init__(self, embed_dim, intermediate_dim, dtype, device, operations): + super().__init__() + self.dense = operations.Linear(embed_dim, intermediate_dim, dtype=dtype, device=device) + + def forward(self, x): + x = self.dense(x) + return torch.nn.functional.gelu(x) + + +class BertBlock(torch.nn.Module): + def __init__(self, embed_dim, intermediate_dim, heads, layer_norm_eps, dtype, device, operations): + super().__init__() + self.attention = BertAttentionBlock(embed_dim, heads, layer_norm_eps, dtype, device, operations) + self.intermediate = BertIntermediate(embed_dim, intermediate_dim, dtype, device, operations) + self.output = BertOutput(intermediate_dim, embed_dim, layer_norm_eps, dtype, device, operations) + + def forward(self, x, mask, optimized_attention): + x = self.attention(x, mask, optimized_attention) + y = self.intermediate(x) + return self.output(y, x) + +class BertEncoder(torch.nn.Module): + def __init__(self, num_layers, embed_dim, intermediate_dim, heads, layer_norm_eps, dtype, device, operations): + super().__init__() + self.layer = torch.nn.ModuleList([BertBlock(embed_dim, intermediate_dim, heads, layer_norm_eps, dtype, device, operations) for i in range(num_layers)]) + + def forward(self, x, mask=None, intermediate_output=None): + optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True) + + if intermediate_output is not None: + if intermediate_output < 0: + intermediate_output = len(self.layer) + intermediate_output + + intermediate = None + for i, l in enumerate(self.layer): + x = l(x, mask, optimized_attention) + if i == intermediate_output: + intermediate = x.clone() + return x, intermediate + +class BertEmbeddings(torch.nn.Module): + def __init__(self, vocab_size, max_position_embeddings, type_vocab_size, pad_token_id, embed_dim, layer_norm_eps, dtype, device, operations): + super().__init__() + self.word_embeddings = torch.nn.Embedding(vocab_size, embed_dim, padding_idx=pad_token_id, dtype=dtype, device=device) + self.position_embeddings = torch.nn.Embedding(max_position_embeddings, embed_dim, dtype=dtype, device=device) + self.token_type_embeddings = torch.nn.Embedding(type_vocab_size, embed_dim, dtype=dtype, device=device) + + self.LayerNorm = operations.LayerNorm(embed_dim, eps=layer_norm_eps, dtype=dtype, device=device) + + def forward(self, input_tokens, token_type_ids=None): + x = self.word_embeddings(input_tokens) + x += self.position_embeddings.weight[:x.shape[1]] + if token_type_ids is not None: + x += self.token_type_embeddings(token_type_ids) + else: + x += self.token_type_embeddings.weight[0] + x = self.LayerNorm(x) + return x + + +class BertModel_(torch.nn.Module): + def __init__(self, config_dict, dtype, device, operations): + super().__init__() + embed_dim = config_dict["hidden_size"] + layer_norm_eps = config_dict["layer_norm_eps"] + + self.embeddings = BertEmbeddings(config_dict["vocab_size"], config_dict["max_position_embeddings"], config_dict["type_vocab_size"], config_dict["pad_token_id"], embed_dim, layer_norm_eps, dtype, device, operations) + self.encoder = BertEncoder(config_dict["num_hidden_layers"], embed_dim, config_dict["intermediate_size"], config_dict["num_attention_heads"], layer_norm_eps, dtype, device, operations) + + def forward(self, input_tokens, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True): + x = self.embeddings(input_tokens) + mask = None + if attention_mask is not None: + mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]) + mask = mask.masked_fill(mask.to(torch.bool), float("-inf")) + + x, i = self.encoder(x, mask, intermediate_output) + return x, i + + +class BertModel(torch.nn.Module): + def __init__(self, config_dict, dtype, device, operations): + super().__init__() + self.bert = BertModel_(config_dict, dtype, device, operations) + self.num_layers = config_dict["num_hidden_layers"] + + def get_input_embeddings(self): + return self.bert.embeddings.word_embeddings + + def set_input_embeddings(self, embeddings): + self.bert.embeddings.word_embeddings = embeddings + + def forward(self, *args, **kwargs): + return self.bert(*args, **kwargs) diff --git a/comfy/text_encoders/hydit.py b/comfy/text_encoders/hydit.py index e47c8cb2..fc1d3c75 100644 --- a/comfy/text_encoders/hydit.py +++ b/comfy/text_encoders/hydit.py @@ -1,56 +1,15 @@ from comfy import sd1_clip -from transformers import T5TokenizerFast, BertTokenizer, BertModel, modeling_utils, BertConfig +from transformers import BertTokenizer from .spiece_tokenizer import SPieceTokenizer +from .bert import BertModel import comfy.text_encoders.t5 import os - import torch -import contextlib - -@contextlib.contextmanager -def use_comfy_ops(ops, device=None, dtype=None): - old_torch_nn_linear = torch.nn.Linear - force_device = device - force_dtype = dtype - def linear_with_dtype(in_features: int, out_features: int, bias: bool = True, device=None, dtype=None): - if force_device is not None: - device = force_device - if force_dtype is not None: - dtype = force_dtype - return ops.Linear(in_features, out_features, bias=bias, device=device, dtype=dtype) - - torch.nn.Linear = linear_with_dtype - try: - yield - finally: - torch.nn.Linear = old_torch_nn_linear - - -class RobertaWrapper(torch.nn.Module): - def __init__(self, config_dict, dtype, device, operations): - super().__init__() - config = BertConfig(**config_dict) - with use_comfy_ops(operations, device, dtype): - with modeling_utils.no_init_weights(): - self.bert = BertModel(config, add_pooling_layer=False) - - self.num_layers = config.num_hidden_layers - - def get_input_embeddings(self): - return self.bert.get_input_embeddings() - - def set_input_embeddings(self, value): - return self.bert.set_input_embeddings(value) - - def forward(self, input_tokens, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True): - intermediate = None - out = self.bert(input_ids=input_tokens, output_hidden_states=intermediate_output is not None, attention_mask=attention_mask) - return out.last_hidden_state, intermediate, out.pooler_output class HyditBertModel(sd1_clip.SDClipModel): def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None): textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "hydit_clip.json") - super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 101, "end": 102, "pad": 0}, model_class=RobertaWrapper, enable_attention_masks=True, return_attention_masks=True) + super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 101, "end": 102, "pad": 0}, model_class=BertModel, enable_attention_masks=True, return_attention_masks=True) class HyditBertTokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}):