This commit is contained in:
dave-juicelabs 2025-04-11 11:52:08 -04:00 committed by GitHub
commit 02987dd4a8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 10 additions and 0 deletions

View File

@@ -73,6 +73,7 @@ fpvae_group.add_argument("--fp32-vae", action="store_true", help="Run the VAE in
fpvae_group.add_argument("--bf16-vae", action="store_true", help="Run the VAE in bf16.")
parser.add_argument("--cpu-vae", action="store_true", help="Run the VAE on the CPU.")
parser.add_argument("--cpu-model-sampling", action="store_true", help="Run the model sampling on the CPU.")
fpte_group = parser.add_mutually_exclusive_group()
fpte_group.add_argument("--fp8_e4m3fn-text-enc", action="store_true", help="Store text encoder weights in fp8 (e4m3fn variant).")

View File

@@ -845,6 +845,11 @@ def vae_device():
        return torch.device("cpu")
    return get_torch_device()
def model_sampling_device():
    """Return the torch device that model sampling should run on.

    Respects the --cpu-model-sampling command-line flag: when the flag is
    set, sampling is pinned to the CPU; otherwise the default torch device
    is used.
    """
    # Guard clause: no CPU override requested -> use the default device.
    if not args.cpu_model_sampling:
        return get_torch_device()
    return torch.device("cpu")
def vae_offload_device():
    if args.gpu_only:
        return get_torch_device()

View File

@@ -984,6 +984,10 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
    if inital_load_device != torch.device("cpu"):
        logging.info("loaded diffusion model directly to GPU")
        model_management.load_models_gpu([model_patcher], force_full_load=True)
#damcclos: move the model_sampling back to the CPU. The work needed for this is not worth the gpu.
model_sampling_device = model_management.model_sampling_device()
if model_sampling_device == torch.device("cpu"):
model_patcher.model.model_sampling.to(model_sampling_device)
    return (model_patcher, clip, vae, clipvision)