diff --git a/comfy/model_management.py b/comfy/model_management.py
index 4f3f2857..f10d1ca8 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -327,12 +327,18 @@ def unload_if_low_vram(model):
         return model.cpu()
     return model
 
-def text_encoder_device():
+def text_encoder_offload_device():
     if args.gpu_only:
         return get_torch_device()
     else:
         return torch.device("cpu")
 
+def text_encoder_device():
+    if vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.SHARED or vram_state == VRAMState.NORMAL_VRAM:
+        return get_torch_device()
+    else:
+        return torch.device("cpu")
+
 def get_autocast_device(dev):
     if hasattr(dev, 'type'):
         return dev.type
@@ -422,10 +428,15 @@ def mps_mode():
     global cpu_state
     return cpu_state == CPUState.MPS
 
-def should_use_fp16():
+def should_use_fp16(device=None):
     global xpu_available
     global directml_enabled
 
+    if device is not None: #TODO
+        if hasattr(device, 'type'):
+            if (device.type == 'cpu' or device.type == 'mps'):
+                return False
+
     if FORCE_FP32:
         return False
 
diff --git a/comfy/sd.py b/comfy/sd.py
index 8eac1f8e..320b0fb7 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -526,9 +526,10 @@ class CLIP:
         tokenizer = target.tokenizer
 
         self.device = model_management.text_encoder_device()
-        params["device"] = self.device
         self.cond_stage_model = clip(**(params))
-        self.cond_stage_model = self.cond_stage_model.to(self.device)
+        if model_management.should_use_fp16(self.device):
+            self.cond_stage_model.half()
+        self.cond_stage_model = self.cond_stage_model.to(model_management.text_encoder_offload_device())
 
         self.tokenizer = tokenizer(embedding_directory=embedding_directory)
         self.patcher = ModelPatcher(self.cond_stage_model)
@@ -559,11 +560,14 @@ class CLIP:
         if self.layer_idx is not None:
             self.cond_stage_model.clip_layer(self.layer_idx)
         try:
+            self.cond_stage_model.to(self.device)
             self.patch_model()
             cond, pooled = self.cond_stage_model.encode_token_weights(tokens)
             self.unpatch_model()
+            self.cond_stage_model.to(model_management.text_encoder_offload_device())
         except Exception as e:
             self.unpatch_model()
+            self.cond_stage_model.to(model_management.text_encoder_offload_device())
             raise e
 
         cond_out = cond
diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py
index 02a998e5..5c627cb8 100644
--- a/comfy/sd1_clip.py
+++ b/comfy/sd1_clip.py
@@ -5,6 +5,8 @@ import comfy.ops
 import torch
 import traceback
 import zipfile
+from . import model_management
+import contextlib
 
 class ClipTokenWeightEncoder:
     def encode_token_weights(self, token_weight_pairs):
@@ -46,7 +48,6 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
             with modeling_utils.no_init_weights():
                 self.transformer = CLIPTextModel(config)
 
-        self.device = device
         self.max_length = max_length
         if freeze:
             self.freeze()
@@ -95,7 +96,7 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
             out_tokens += [tokens_temp]
 
         if len(embedding_weights) > 0:
-            new_embedding = torch.nn.Embedding(next_new_token, current_embeds.weight.shape[1], device=self.device)
+            new_embedding = torch.nn.Embedding(next_new_token, current_embeds.weight.shape[1], device=current_embeds.weight.device, dtype=current_embeds.weight.dtype)
             new_embedding.weight[:token_dict_size] = current_embeds.weight[:]
             n = token_dict_size
             for x in embedding_weights:
@@ -106,24 +107,34 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
 
     def forward(self, tokens):
         backup_embeds = self.transformer.get_input_embeddings()
+        device = backup_embeds.weight.device
         tokens = self.set_up_textual_embeddings(tokens, backup_embeds)
-        tokens = torch.LongTensor(tokens).to(self.device)
-        outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden")
-        self.transformer.set_input_embeddings(backup_embeds)
+        tokens = torch.LongTensor(tokens).to(device)
 
-        if self.layer == "last":
-            z = outputs.last_hidden_state
-        elif self.layer == "pooled":
-            z = outputs.pooler_output[:, None, :]
+        if backup_embeds.weight.dtype != torch.float32:
+            print("autocast clip")
+            precision_scope = torch.autocast
         else:
-            z = outputs.hidden_states[self.layer_idx]
-            if self.layer_norm_hidden_state:
-                z = self.transformer.text_model.final_layer_norm(z)
+            precision_scope = contextlib.nullcontext
+            print("no autocast clip")
 
-        pooled_output = outputs.pooler_output
-        if self.text_projection is not None:
-            pooled_output = pooled_output @ self.text_projection
-        return z, pooled_output
+        with precision_scope(model_management.get_autocast_device(device)):
+            outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden")
+            self.transformer.set_input_embeddings(backup_embeds)
+
+            if self.layer == "last":
+                z = outputs.last_hidden_state
+            elif self.layer == "pooled":
+                z = outputs.pooler_output[:, None, :]
+            else:
+                z = outputs.hidden_states[self.layer_idx]
+                if self.layer_norm_hidden_state:
+                    z = self.transformer.text_model.final_layer_norm(z)
+
+            pooled_output = outputs.pooler_output
+            if self.text_projection is not None:
+                pooled_output = pooled_output @ self.text_projection
+        return z.float(), pooled_output.float()
 
     def encode(self, tokens):
         return self(tokens)
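
Note (not part of the patch): the sketch below illustrates the load/offload cycle this diff sets up for the text encoder, using only the model_management helpers added or changed above. DummyTextEncoder and the encode() wrapper are hypothetical stand-ins for the real CLIP cond_stage_model and CLIP.encode() in comfy/sd.py; treat it as a minimal illustration under those assumptions, not as code from the repository.

import torch
from comfy import model_management  # assumes the comfy package is importable

class DummyTextEncoder(torch.nn.Module):
    """Hypothetical stand-in for the CLIP cond_stage_model."""
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(768, 768)

    def forward(self, x):
        return self.proj(x)

# Load time: decide precision from the *compute* device, then park the weights
# on the offload device (CPU unless --gpu-only is passed), as CLIP.__init__ now does.
compute_device = model_management.text_encoder_device()
offload_device = model_management.text_encoder_offload_device()

encoder = DummyTextEncoder()
if model_management.should_use_fp16(compute_device):
    encoder.half()              # fp16 only on devices that support it (not cpu/mps)
encoder.to(offload_device)      # stays here between prompts

# Encode time: move to the compute device, run, move back, mirroring the
# to()/offload calls added around encode_token_weights() in comfy/sd.py.
def encode(embeds):
    encoder.to(compute_device)
    try:
        dtype = next(encoder.parameters()).dtype
        out = encoder(embeds.to(compute_device, dtype=dtype))
    finally:
        encoder.to(offload_device)
    return out.float()          # results come back in fp32, like forward() in sd1_clip.py

The apparent design choice: keeping the text encoder on the offload device between prompts frees VRAM for the diffusion model, while casting it to fp16 (where the device allows) and autocasting the transformer call keeps the occasional trips to the GPU cheap.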