From 2c038ccef0f819ee8693a94dd880f05a4eb3808c Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 31 Jul 2024 01:32:35 -0400 Subject: [PATCH] Lower CLIP memory usage by a bit. --- comfy/clip_model.py | 23 ++++++++++++----------- comfy/sd1_clip.py | 7 ++++--- comfy/text_encoders/bert.py | 21 +++++++++++---------- comfy/text_encoders/t5.py | 2 +- 4 files changed, 28 insertions(+), 25 deletions(-) diff --git a/comfy/clip_model.py b/comfy/clip_model.py index ab775309..3c67b737 100644 --- a/comfy/clip_model.py +++ b/comfy/clip_model.py @@ -1,5 +1,6 @@ import torch from comfy.ldm.modules.attention import optimized_attention_for_device +import comfy.ops class CLIPAttention(torch.nn.Module): def __init__(self, embed_dim, heads, dtype, device, operations): @@ -71,13 +72,13 @@ class CLIPEncoder(torch.nn.Module): return x, intermediate class CLIPEmbeddings(torch.nn.Module): - def __init__(self, embed_dim, vocab_size=49408, num_positions=77, dtype=None, device=None): + def __init__(self, embed_dim, vocab_size=49408, num_positions=77, dtype=None, device=None, operations=None): super().__init__() - self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim, dtype=dtype, device=device) - self.position_embedding = torch.nn.Embedding(num_positions, embed_dim, dtype=dtype, device=device) + self.token_embedding = operations.Embedding(vocab_size, embed_dim, dtype=dtype, device=device) + self.position_embedding = operations.Embedding(num_positions, embed_dim, dtype=dtype, device=device) - def forward(self, input_tokens): - return self.token_embedding(input_tokens) + self.position_embedding.weight + def forward(self, input_tokens, dtype=torch.float32): + return self.token_embedding(input_tokens, out_dtype=dtype) + comfy.ops.cast_to(self.position_embedding.weight, dtype=dtype, device=input_tokens.device) class CLIPTextModel_(torch.nn.Module): @@ -90,12 +91,12 @@ class CLIPTextModel_(torch.nn.Module): self.eos_token_id = config_dict["eos_token_id"] super().__init__() - self.embeddings = CLIPEmbeddings(embed_dim, dtype=torch.float32, device=device) + self.embeddings = CLIPEmbeddings(embed_dim, dtype=dtype, device=device, operations=operations) self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations) self.final_layer_norm = operations.LayerNorm(embed_dim, dtype=dtype, device=device) - def forward(self, input_tokens, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True): - x = self.embeddings(input_tokens) + def forward(self, input_tokens, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=torch.float32): + x = self.embeddings(input_tokens, dtype=dtype) mask = None if attention_mask is not None: mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]) @@ -154,11 +155,11 @@ class CLIPVisionEmbeddings(torch.nn.Module): num_patches = (image_size // patch_size) ** 2 num_positions = num_patches + 1 - self.position_embedding = torch.nn.Embedding(num_positions, embed_dim, dtype=dtype, device=device) + self.position_embedding = operations.Embedding(num_positions, embed_dim, dtype=dtype, device=device) def forward(self, pixel_values): embeds = self.patch_embedding(pixel_values).flatten(2).transpose(1, 2) - return torch.cat([self.class_embedding.to(embeds.device).expand(pixel_values.shape[0], 1, -1), embeds], dim=1) + self.position_embedding.weight.to(embeds.device) + return torch.cat([comfy.ops.cast_to_input(self.class_embedding, embeds).expand(pixel_values.shape[0], 1, -1), embeds], dim=1) + comfy.ops.cast_to_input(self.position_embedding.weight, embeds) class CLIPVision(torch.nn.Module): @@ -170,7 +171,7 @@ class CLIPVision(torch.nn.Module): intermediate_size = config_dict["intermediate_size"] intermediate_activation = config_dict["hidden_act"] - self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], dtype=torch.float32, device=device, operations=operations) + self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], dtype=dtype, device=device, operations=operations) self.pre_layrnorm = operations.LayerNorm(embed_dim) self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations) self.post_layernorm = operations.LayerNorm(embed_dim) diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index f209bed4..d32121d1 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -94,7 +94,8 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): with open(textmodel_json_config) as f: config = json.load(f) - self.transformer = model_class(config, dtype, device, comfy.ops.manual_cast) + self.operations = comfy.ops.manual_cast + self.transformer = model_class(config, dtype, device, self.operations) self.num_layers = self.transformer.num_layers self.max_length = max_length @@ -161,7 +162,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): n = token_dict_size if len(embedding_weights) > 0: - new_embedding = torch.nn.Embedding(next_new_token + 1, current_embeds.weight.shape[1], device=current_embeds.weight.device, dtype=current_embeds.weight.dtype) + new_embedding = self.operations.Embedding(next_new_token + 1, current_embeds.weight.shape[1], device=current_embeds.weight.device, dtype=current_embeds.weight.dtype) new_embedding.weight[:token_dict_size] = current_embeds.weight for x in embedding_weights: new_embedding.weight[n] = x @@ -194,7 +195,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): if self.enable_attention_masks: attention_mask_model = attention_mask - outputs = self.transformer(tokens, attention_mask_model, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state) + outputs = self.transformer(tokens, attention_mask_model, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state, dtype=torch.float32) self.transformer.set_input_embeddings(backup_embeds) if self.layer == "last": diff --git a/comfy/text_encoders/bert.py b/comfy/text_encoders/bert.py index b76e7666..fc9bac1d 100644 --- a/comfy/text_encoders/bert.py +++ b/comfy/text_encoders/bert.py @@ -1,5 +1,6 @@ import torch from comfy.ldm.modules.attention import optimized_attention_for_device +import comfy.ops class BertAttention(torch.nn.Module): def __init__(self, embed_dim, heads, dtype, device, operations): @@ -86,19 +87,19 @@ class BertEncoder(torch.nn.Module): class BertEmbeddings(torch.nn.Module): def __init__(self, vocab_size, max_position_embeddings, type_vocab_size, pad_token_id, embed_dim, layer_norm_eps, dtype, device, operations): super().__init__() - self.word_embeddings = torch.nn.Embedding(vocab_size, embed_dim, padding_idx=pad_token_id, dtype=dtype, device=device) - self.position_embeddings = torch.nn.Embedding(max_position_embeddings, embed_dim, dtype=dtype, device=device) - self.token_type_embeddings = torch.nn.Embedding(type_vocab_size, embed_dim, dtype=dtype, device=device) + self.word_embeddings = operations.Embedding(vocab_size, embed_dim, padding_idx=pad_token_id, dtype=dtype, device=device) + self.position_embeddings = operations.Embedding(max_position_embeddings, embed_dim, dtype=dtype, device=device) + self.token_type_embeddings = operations.Embedding(type_vocab_size, embed_dim, dtype=dtype, device=device) self.LayerNorm = operations.LayerNorm(embed_dim, eps=layer_norm_eps, dtype=dtype, device=device) - def forward(self, input_tokens, token_type_ids=None): - x = self.word_embeddings(input_tokens) - x += self.position_embeddings.weight[:x.shape[1]] + def forward(self, input_tokens, token_type_ids=None, dtype=None): + x = self.word_embeddings(input_tokens, out_dtype=dtype) + x += comfy.ops.cast_to_input(self.position_embeddings.weight[:x.shape[1]], x) if token_type_ids is not None: - x += self.token_type_embeddings(token_type_ids) + x += self.token_type_embeddings(token_type_ids, out_dtype=x.dtype) else: - x += self.token_type_embeddings.weight[0] + x += comfy.ops.cast_to_input(self.token_type_embeddings.weight[0], x) x = self.LayerNorm(x) return x @@ -112,8 +113,8 @@ class BertModel_(torch.nn.Module): self.embeddings = BertEmbeddings(config_dict["vocab_size"], config_dict["max_position_embeddings"], config_dict["type_vocab_size"], config_dict["pad_token_id"], embed_dim, layer_norm_eps, dtype, device, operations) self.encoder = BertEncoder(config_dict["num_hidden_layers"], embed_dim, config_dict["intermediate_size"], config_dict["num_attention_heads"], layer_norm_eps, dtype, device, operations) - def forward(self, input_tokens, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True): - x = self.embeddings(input_tokens) + def forward(self, input_tokens, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None): + x = self.embeddings(input_tokens, dtype=dtype) mask = None if attention_mask is not None: mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]) diff --git a/comfy/text_encoders/t5.py b/comfy/text_encoders/t5.py index 2109f4ea..b6491090 100644 --- a/comfy/text_encoders/t5.py +++ b/comfy/text_encoders/t5.py @@ -200,7 +200,7 @@ class T5Stack(torch.nn.Module): self.final_layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device, operations=operations) # self.dropout = nn.Dropout(config.dropout_rate) - def forward(self, x, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True): + def forward(self, x, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None): mask = None if attention_mask is not None: mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])