mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
Add a --gpu-only argument to keep and run everything on the GPU.
Make the CLIP model work on the GPU.
This commit is contained in:
parent
7bf89ba923
commit
f7edcfd927
@ -59,12 +59,14 @@ attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", he
|
|||||||
parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.")
|
parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.")
|
||||||
|
|
||||||
vram_group = parser.add_mutually_exclusive_group()
|
vram_group = parser.add_mutually_exclusive_group()
|
||||||
|
vram_group.add_argument("--gpu-only", action="store_true", help="Store and run everything (text encoders/CLIP models, etc... on the GPU).")
|
||||||
vram_group.add_argument("--highvram", action="store_true", help="By default models will be unloaded to CPU memory after being used. This option keeps them in GPU memory.")
|
vram_group.add_argument("--highvram", action="store_true", help="By default models will be unloaded to CPU memory after being used. This option keeps them in GPU memory.")
|
||||||
vram_group.add_argument("--normalvram", action="store_true", help="Used to force normal vram use if lowvram gets automatically enabled.")
|
vram_group.add_argument("--normalvram", action="store_true", help="Used to force normal vram use if lowvram gets automatically enabled.")
|
||||||
vram_group.add_argument("--lowvram", action="store_true", help="Split the unet in parts to use less vram.")
|
vram_group.add_argument("--lowvram", action="store_true", help="Split the unet in parts to use less vram.")
|
||||||
vram_group.add_argument("--novram", action="store_true", help="When lowvram isn't enough.")
|
vram_group.add_argument("--novram", action="store_true", help="When lowvram isn't enough.")
|
||||||
vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for everything (slow).")
|
vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for everything (slow).")
|
||||||
|
|
||||||
|
|
||||||
parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
|
parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
|
||||||
parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
|
parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
|
||||||
parser.add_argument("--windows-standalone-build", action="store_true", help="Windows standalone build: Enable convenient things that most people using the standalone windows build will probably enjoy (like auto opening the page on startup).")
|
parser.add_argument("--windows-standalone-build", action="store_true", help="Windows standalone build: Enable convenient things that most people using the standalone windows build will probably enjoy (like auto opening the page on startup).")
|
||||||
|
@ -151,7 +151,7 @@ if args.lowvram:
|
|||||||
lowvram_available = True
|
lowvram_available = True
|
||||||
elif args.novram:
|
elif args.novram:
|
||||||
set_vram_to = VRAMState.NO_VRAM
|
set_vram_to = VRAMState.NO_VRAM
|
||||||
elif args.highvram:
|
elif args.highvram or args.gpu_only:
|
||||||
vram_state = VRAMState.HIGH_VRAM
|
vram_state = VRAMState.HIGH_VRAM
|
||||||
|
|
||||||
FORCE_FP32 = False
|
FORCE_FP32 = False
|
||||||
@ -307,6 +307,12 @@ def unload_if_low_vram(model):
|
|||||||
return model.cpu()
|
return model.cpu()
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
def text_encoder_device():
|
||||||
|
if args.gpu_only:
|
||||||
|
return get_torch_device()
|
||||||
|
else:
|
||||||
|
return torch.device("cpu")
|
||||||
|
|
||||||
def get_autocast_device(dev):
|
def get_autocast_device(dev):
|
||||||
if hasattr(dev, 'type'):
|
if hasattr(dev, 'type'):
|
||||||
return dev.type
|
return dev.type
|
||||||
|
@ -467,7 +467,11 @@ class CLIP:
|
|||||||
clip = sd1_clip.SD1ClipModel
|
clip = sd1_clip.SD1ClipModel
|
||||||
tokenizer = sd1_clip.SD1Tokenizer
|
tokenizer = sd1_clip.SD1Tokenizer
|
||||||
|
|
||||||
|
self.device = model_management.text_encoder_device()
|
||||||
|
params["device"] = self.device
|
||||||
self.cond_stage_model = clip(**(params))
|
self.cond_stage_model = clip(**(params))
|
||||||
|
self.cond_stage_model = self.cond_stage_model.to(self.device)
|
||||||
|
|
||||||
self.tokenizer = tokenizer(embedding_directory=embedding_directory)
|
self.tokenizer = tokenizer(embedding_directory=embedding_directory)
|
||||||
self.patcher = ModelPatcher(self.cond_stage_model)
|
self.patcher = ModelPatcher(self.cond_stage_model)
|
||||||
self.layer_idx = None
|
self.layer_idx = None
|
||||||
|
@ -20,7 +20,7 @@ class ClipTokenWeightEncoder:
|
|||||||
output += [z]
|
output += [z]
|
||||||
if (len(output) == 0):
|
if (len(output) == 0):
|
||||||
return self.encode(self.empty_tokens)
|
return self.encode(self.empty_tokens)
|
||||||
return torch.cat(output, dim=-2)
|
return torch.cat(output, dim=-2).cpu()
|
||||||
|
|
||||||
class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||||
"""Uses the CLIP transformer encoder for text (from huggingface)"""
|
"""Uses the CLIP transformer encoder for text (from huggingface)"""
|
||||||
|
Loading…
Reference in New Issue
Block a user