Add a --gpu-only argument to keep and run everything on the GPU.

Make the CLIP model work on the GPU.
This commit is contained in:
comfyanonymous 2023-06-15 15:21:37 -04:00
parent 7bf89ba923
commit f7edcfd927
4 changed files with 14 additions and 2 deletions

View File

@ -59,12 +59,14 @@ attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", he
parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.") parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.")
vram_group = parser.add_mutually_exclusive_group() vram_group = parser.add_mutually_exclusive_group()
vram_group.add_argument("--gpu-only", action="store_true", help="Store and run everything (text encoders/CLIP models, etc... on the GPU).")
vram_group.add_argument("--highvram", action="store_true", help="By default models will be unloaded to CPU memory after being used. This option keeps them in GPU memory.") vram_group.add_argument("--highvram", action="store_true", help="By default models will be unloaded to CPU memory after being used. This option keeps them in GPU memory.")
vram_group.add_argument("--normalvram", action="store_true", help="Used to force normal vram use if lowvram gets automatically enabled.") vram_group.add_argument("--normalvram", action="store_true", help="Used to force normal vram use if lowvram gets automatically enabled.")
vram_group.add_argument("--lowvram", action="store_true", help="Split the unet in parts to use less vram.") vram_group.add_argument("--lowvram", action="store_true", help="Split the unet in parts to use less vram.")
vram_group.add_argument("--novram", action="store_true", help="When lowvram isn't enough.") vram_group.add_argument("--novram", action="store_true", help="When lowvram isn't enough.")
vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for everything (slow).") vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for everything (slow).")
parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.") parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.") parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
parser.add_argument("--windows-standalone-build", action="store_true", help="Windows standalone build: Enable convenient things that most people using the standalone windows build will probably enjoy (like auto opening the page on startup).") parser.add_argument("--windows-standalone-build", action="store_true", help="Windows standalone build: Enable convenient things that most people using the standalone windows build will probably enjoy (like auto opening the page on startup).")

View File

@ -151,7 +151,7 @@ if args.lowvram:
lowvram_available = True lowvram_available = True
elif args.novram: elif args.novram:
set_vram_to = VRAMState.NO_VRAM set_vram_to = VRAMState.NO_VRAM
elif args.highvram: elif args.highvram or args.gpu_only:
vram_state = VRAMState.HIGH_VRAM vram_state = VRAMState.HIGH_VRAM
FORCE_FP32 = False FORCE_FP32 = False
@ -307,6 +307,12 @@ def unload_if_low_vram(model):
return model.cpu() return model.cpu()
return model return model
def text_encoder_device():
if args.gpu_only:
return get_torch_device()
else:
return torch.device("cpu")
def get_autocast_device(dev): def get_autocast_device(dev):
if hasattr(dev, 'type'): if hasattr(dev, 'type'):
return dev.type return dev.type

View File

@ -467,7 +467,11 @@ class CLIP:
clip = sd1_clip.SD1ClipModel clip = sd1_clip.SD1ClipModel
tokenizer = sd1_clip.SD1Tokenizer tokenizer = sd1_clip.SD1Tokenizer
self.device = model_management.text_encoder_device()
params["device"] = self.device
self.cond_stage_model = clip(**(params)) self.cond_stage_model = clip(**(params))
self.cond_stage_model = self.cond_stage_model.to(self.device)
self.tokenizer = tokenizer(embedding_directory=embedding_directory) self.tokenizer = tokenizer(embedding_directory=embedding_directory)
self.patcher = ModelPatcher(self.cond_stage_model) self.patcher = ModelPatcher(self.cond_stage_model)
self.layer_idx = None self.layer_idx = None

View File

@ -20,7 +20,7 @@ class ClipTokenWeightEncoder:
output += [z] output += [z]
if (len(output) == 0): if (len(output) == 0):
return self.encode(self.empty_tokens) return self.encode(self.empty_tokens)
return torch.cat(output, dim=-2) return torch.cat(output, dim=-2).cpu()
class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder): class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
"""Uses the CLIP transformer encoder for text (from huggingface)""" """Uses the CLIP transformer encoder for text (from huggingface)"""