mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-06-12 14:32:08 +08:00
Don't unload/reload model from CPU uselessly.
This commit is contained in:
parent
e3e65947f2
commit
a84cd0d1ad
26
comfy/model_management.py
Normal file
26
comfy/model_management.py
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
|
||||||
|
|
||||||
|
current_loaded_model = None
|
||||||
|
|
||||||
|
|
||||||
|
def unload_model():
|
||||||
|
global current_loaded_model
|
||||||
|
if current_loaded_model is not None:
|
||||||
|
current_loaded_model.model.cpu()
|
||||||
|
current_loaded_model.unpatch_model()
|
||||||
|
current_loaded_model = None
|
||||||
|
|
||||||
|
|
||||||
|
def load_model_gpu(model):
|
||||||
|
global current_loaded_model
|
||||||
|
if model is current_loaded_model:
|
||||||
|
return
|
||||||
|
unload_model()
|
||||||
|
try:
|
||||||
|
real_model = model.patch_model()
|
||||||
|
except Exception as e:
|
||||||
|
model.unpatch_model()
|
||||||
|
raise e
|
||||||
|
current_loaded_model = model
|
||||||
|
real_model.cuda()
|
||||||
|
return current_loaded_model
|
@ -2,6 +2,7 @@ import torch
|
|||||||
|
|
||||||
import sd1_clip
|
import sd1_clip
|
||||||
import sd2_clip
|
import sd2_clip
|
||||||
|
import model_management
|
||||||
from ldm.util import instantiate_from_config
|
from ldm.util import instantiate_from_config
|
||||||
from ldm.models.autoencoder import AutoencoderKL
|
from ldm.models.autoencoder import AutoencoderKL
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
@ -304,6 +305,7 @@ class VAE:
|
|||||||
self.device = device
|
self.device = device
|
||||||
|
|
||||||
def decode(self, samples):
|
def decode(self, samples):
|
||||||
|
model_management.unload_model()
|
||||||
self.first_stage_model = self.first_stage_model.to(self.device)
|
self.first_stage_model = self.first_stage_model.to(self.device)
|
||||||
samples = samples.to(self.device)
|
samples = samples.to(self.device)
|
||||||
pixel_samples = self.first_stage_model.decode(1. / self.scale_factor * samples)
|
pixel_samples = self.first_stage_model.decode(1. / self.scale_factor * samples)
|
||||||
@ -313,6 +315,7 @@ class VAE:
|
|||||||
return pixel_samples
|
return pixel_samples
|
||||||
|
|
||||||
def encode(self, pixel_samples):
|
def encode(self, pixel_samples):
|
||||||
|
model_management.unload_model()
|
||||||
self.first_stage_model = self.first_stage_model.to(self.device)
|
self.first_stage_model = self.first_stage_model.to(self.device)
|
||||||
pixel_samples = pixel_samples.movedim(-1,1).to(self.device)
|
pixel_samples = pixel_samples.movedim(-1,1).to(self.device)
|
||||||
samples = self.first_stage_model.encode(2. * pixel_samples - 1.).sample() * self.scale_factor
|
samples = self.first_stage_model.encode(2. * pixel_samples - 1.).sample() * self.scale_factor
|
||||||
|
61
nodes.py
61
nodes.py
@ -15,6 +15,7 @@ sys.path.append(os.path.join(sys.path[0], "comfy"))
|
|||||||
|
|
||||||
import comfy.samplers
|
import comfy.samplers
|
||||||
import comfy.sd
|
import comfy.sd
|
||||||
|
import model_management
|
||||||
|
|
||||||
supported_ckpt_extensions = ['.ckpt']
|
supported_ckpt_extensions = ['.ckpt']
|
||||||
supported_pt_extensions = ['.ckpt', '.pt', '.bin']
|
supported_pt_extensions = ['.ckpt', '.pt', '.bin']
|
||||||
@ -353,43 +354,39 @@ def common_ksampler(device, model, seed, steps, cfg, sampler_name, scheduler, po
|
|||||||
noise = torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=torch.manual_seed(seed), device="cpu")
|
noise = torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=torch.manual_seed(seed), device="cpu")
|
||||||
|
|
||||||
real_model = None
|
real_model = None
|
||||||
try:
|
if device != "cpu":
|
||||||
|
model_management.load_model_gpu(model)
|
||||||
|
real_model = model.model
|
||||||
|
else:
|
||||||
|
#TODO: cpu support
|
||||||
real_model = model.patch_model()
|
real_model = model.patch_model()
|
||||||
real_model.to(device)
|
noise = noise.to(device)
|
||||||
noise = noise.to(device)
|
latent_image = latent_image.to(device)
|
||||||
latent_image = latent_image.to(device)
|
|
||||||
|
|
||||||
positive_copy = []
|
positive_copy = []
|
||||||
negative_copy = []
|
negative_copy = []
|
||||||
|
|
||||||
for p in positive:
|
for p in positive:
|
||||||
t = p[0]
|
t = p[0]
|
||||||
if t.shape[0] < noise.shape[0]:
|
if t.shape[0] < noise.shape[0]:
|
||||||
t = torch.cat([t] * noise.shape[0])
|
t = torch.cat([t] * noise.shape[0])
|
||||||
t = t.to(device)
|
t = t.to(device)
|
||||||
positive_copy += [[t] + p[1:]]
|
positive_copy += [[t] + p[1:]]
|
||||||
for n in negative:
|
for n in negative:
|
||||||
t = n[0]
|
t = n[0]
|
||||||
if t.shape[0] < noise.shape[0]:
|
if t.shape[0] < noise.shape[0]:
|
||||||
t = torch.cat([t] * noise.shape[0])
|
t = torch.cat([t] * noise.shape[0])
|
||||||
t = t.to(device)
|
t = t.to(device)
|
||||||
negative_copy += [[t] + n[1:]]
|
negative_copy += [[t] + n[1:]]
|
||||||
|
|
||||||
if sampler_name in comfy.samplers.KSampler.SAMPLERS:
|
if sampler_name in comfy.samplers.KSampler.SAMPLERS:
|
||||||
sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise)
|
sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise)
|
||||||
else:
|
else:
|
||||||
#other samplers
|
#other samplers
|
||||||
pass
|
pass
|
||||||
|
|
||||||
samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise)
|
samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise)
|
||||||
samples = samples.cpu()
|
samples = samples.cpu()
|
||||||
real_model.cpu()
|
|
||||||
model.unpatch_model()
|
|
||||||
except Exception as e:
|
|
||||||
if real_model is not None:
|
|
||||||
real_model.cpu()
|
|
||||||
model.unpatch_model()
|
|
||||||
raise e
|
|
||||||
|
|
||||||
return (samples, )
|
return (samples, )
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user