From afff30fc0a4d11be4823ccce78d281a4e504c914 Mon Sep 17 00:00:00 2001
From: comfyanonymous <comfyanonymous@protonmail.com>
Date: Mon, 6 Mar 2023 10:50:50 -0500
Subject: [PATCH] Add --cpu to use the cpu for inference.

---
 comfy/model_management.py | 16 +++++++++++++++-
 comfy/samplers.py         |  2 +-
 comfy/sd.py               | 16 +++++++++++-----
 main.py                   |  1 +
 nodes.py                  | 22 +++++++---------------
 5 files changed, 35 insertions(+), 22 deletions(-)

diff --git a/comfy/model_management.py b/comfy/model_management.py
index 32159b82..4b061c32 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -31,6 +31,8 @@ try:
 except:
     pass
 
+if "--cpu" in sys.argv:
+    vram_state = CPU
 if "--lowvram" in sys.argv:
     set_vram_to = LOW_VRAM
 if "--novram" in sys.argv:
@@ -118,6 +120,8 @@ def load_model_gpu(model):
 def load_controlnet_gpu(models):
     global current_gpu_controlnets
     global vram_state
+    if vram_state == CPU:
+        return
 
     if vram_state == LOW_VRAM or vram_state == NO_VRAM:
         #don't load controlnets like this if low vram because they will be loaded right before running and unloaded right after
@@ -144,10 +148,20 @@ def unload_if_low_vram(model):
         return model.cpu()
     return model
 
+def get_torch_device():
+    if vram_state == CPU:
+        return torch.device("cpu")
+    else:
+        return torch.cuda.current_device()
+
+def get_autocast_device(dev):
+    if hasattr(dev, 'type'):
+        return dev.type
+    return "cuda"
 
 def get_free_memory(dev=None, torch_free_too=False):
     if dev is None:
-        dev = torch.cuda.current_device()
+        dev = get_torch_device()
 
     if hasattr(dev, 'type') and dev.type == 'cpu':
         mem_free_total = psutil.virtual_memory().available
diff --git a/comfy/samplers.py b/comfy/samplers.py
index 3562f89d..569c32f4 100644
--- a/comfy/samplers.py
+++ b/comfy/samplers.py
@@ -438,7 +438,7 @@ class KSampler:
         else:
             max_denoise = True
 
-        with precision_scope(self.device):
+        with precision_scope(model_management.get_autocast_device(self.device)):
             if self.sampler == "uni_pc":
                 samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask)
             elif self.sampler == "uni_pc_bh2":
diff --git a/comfy/sd.py b/comfy/sd.py
index eb4ea793..67a207cb 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -299,7 +299,7 @@ class CLIP:
         return cond
 
 class VAE:
-    def __init__(self, ckpt_path=None, scale_factor=0.18215, device="cuda", config=None):
+    def __init__(self, ckpt_path=None, scale_factor=0.18215, device=None, config=None):
         if config is None:
             #default SD1.x/SD2.x VAE parameters
             ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
@@ -308,6 +308,8 @@ class VAE:
             self.first_stage_model = AutoencoderKL(**(config['params']), ckpt_path=ckpt_path)
         self.first_stage_model = self.first_stage_model.eval()
         self.scale_factor = scale_factor
+        if device is None:
+            device = model_management.get_torch_device()
         self.device = device
 
     def decode(self, samples):
@@ -381,11 +383,13 @@ def resize_image_to(tensor, target_latent_tensor, batched_number):
         return torch.cat([tensor] * batched_number, dim=0)
 
 class ControlNet:
-    def __init__(self, control_model, device="cuda"):
+    def __init__(self, control_model, device=None):
         self.control_model = control_model
         self.cond_hint_original = None
         self.cond_hint = None
         self.strength = 1.0
+        if device is None:
+            device = model_management.get_torch_device()
         self.device = device
         self.previous_controlnet = None
 
@@ -406,7 +410,7 @@ class ControlNet:
         else:
             precision_scope = contextlib.nullcontext
 
-        with precision_scope(self.device):
+        with precision_scope(model_management.get_autocast_device(self.device)):
             self.control_model = model_management.load_if_low_vram(self.control_model)
             control = self.control_model(x=x_noisy, hint=self.cond_hint, timesteps=t, context=cond_txt)
             self.control_model = model_management.unload_if_low_vram(self.control_model)
@@ -481,7 +485,7 @@ def load_controlnet(ckpt_path, model=None):
     context_dim = controlnet_data[key].shape[1]
 
     use_fp16 = False
-    if controlnet_data[key].dtype == torch.float16:
+    if model_management.should_use_fp16() and controlnet_data[key].dtype == torch.float16:
         use_fp16 = True
 
     control_model = cldm.ControlNet(image_size=32,
@@ -527,10 +531,12 @@ def load_controlnet(ckpt_path, model=None):
     return control
 
 class T2IAdapter:
-    def __init__(self, t2i_model, channels_in, device="cuda"):
+    def __init__(self, t2i_model, channels_in, device=None):
         self.t2i_model = t2i_model
         self.channels_in = channels_in
         self.strength = 1.0
+        if device is None:
+            device = model_management.get_torch_device()
         self.device = device
         self.previous_controlnet = None
         self.control_input = None
diff --git a/main.py b/main.py
index 43dff955..ca8674b5 100644
--- a/main.py
+++ b/main.py
@@ -24,6 +24,7 @@ if __name__ == "__main__":
         print("\t--lowvram\t\t\tSplit the unet in parts to use less vram.")
         print("\t--novram\t\t\tWhen lowvram isn't enough.")
         print()
+        print("\t--cpu\t\t\tTo use the CPU for everything (slow).")
         exit()
 
     if '--dont-upcast-attention' in sys.argv:
diff --git a/nodes.py b/nodes.py
index 84510a05..e5800d0d 100644
--- a/nodes.py
+++ b/nodes.py
@@ -628,9 +628,10 @@ class SetLatentNoiseMask:
         return (s,)
 
 
-def common_ksampler(device, model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False):
+def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False):
     latent_image = latent["samples"]
     noise_mask = None
+    device = model_management.get_torch_device()
 
     if disable_noise:
         noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu")
@@ -646,12 +647,9 @@ def common_ksampler(device, model, seed, steps, cfg, sampler_name, scheduler, po
         noise_mask = noise_mask.to(device)
 
     real_model = None
-    if device != "cpu":
-        model_management.load_model_gpu(model)
-        real_model = model.model
-    else:
-        #TODO: cpu support
-        real_model = model.patch_model()
+    model_management.load_model_gpu(model)
+    real_model = model.model
+
     noise = noise.to(device)
     latent_image = latent_image.to(device)
 
@@ -697,9 +695,6 @@ def common_ksampler(device, model, seed, steps, cfg, sampler_name, scheduler, po
     return (out, )
 
 class KSampler:
-    def __init__(self, device="cuda"):
-        self.device = device
-
     @classmethod
     def INPUT_TYPES(s):
         return {"required":
@@ -721,12 +716,9 @@ class KSampler:
     CATEGORY = "sampling"
 
     def sample(self, model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0):
-        return common_ksampler(self.device, model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise)
+        return common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise)
 
 class KSamplerAdvanced:
-    def __init__(self, device="cuda"):
-        self.device = device
-
     @classmethod
     def INPUT_TYPES(s):
         return {"required":
@@ -757,7 +749,7 @@ class KSamplerAdvanced:
         disable_noise = False
         if add_noise == "disable":
             disable_noise = True
-        return common_ksampler(self.device, model, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise, disable_noise=disable_noise, start_step=start_at_step, last_step=end_at_step, force_full_denoise=force_full_denoise)
+        return common_ksampler(model, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise, disable_noise=disable_noise, start_step=start_at_step, last_step=end_at_step, force_full_denoise=force_full_denoise)
 
 class SaveImage:
     def __init__(self):