mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-03-15 05:57:20 +00:00
Update ldm dir with latest upstream stable diffusion changes.
This commit is contained in:
parent
642516a3a6
commit
1f6a467e92
@ -8,16 +8,17 @@ from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, mak
|
|||||||
|
|
||||||
|
|
||||||
class DDIMSampler(object):
|
class DDIMSampler(object):
|
||||||
def __init__(self, model, schedule="linear", **kwargs):
|
def __init__(self, model, schedule="linear", device=torch.device("cuda"), **kwargs):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.model = model
|
self.model = model
|
||||||
self.ddpm_num_timesteps = model.num_timesteps
|
self.ddpm_num_timesteps = model.num_timesteps
|
||||||
self.schedule = schedule
|
self.schedule = schedule
|
||||||
|
self.device = device
|
||||||
|
|
||||||
def register_buffer(self, name, attr):
|
def register_buffer(self, name, attr):
|
||||||
if type(attr) == torch.Tensor:
|
if type(attr) == torch.Tensor:
|
||||||
if attr.device != torch.device("cuda"):
|
if attr.device != self.device:
|
||||||
attr = attr.to(torch.device("cuda"))
|
attr = attr.to(self.device)
|
||||||
setattr(self, name, attr)
|
setattr(self, name, attr)
|
||||||
|
|
||||||
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
|
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
|
||||||
|
@ -1331,6 +1331,12 @@ class DiffusionWrapper(torch.nn.Module):
|
|||||||
cc = torch.cat(c_crossattn, 1)
|
cc = torch.cat(c_crossattn, 1)
|
||||||
else:
|
else:
|
||||||
cc = c_crossattn
|
cc = c_crossattn
|
||||||
|
if hasattr(self, "scripted_diffusion_model"):
|
||||||
|
# TorchScript changes names of the arguments
|
||||||
|
# with argument cc defined as context=cc scripted model will produce
|
||||||
|
# an error: RuntimeError: forward() is missing value for argument 'argument_3'.
|
||||||
|
out = self.scripted_diffusion_model(x, t, cc)
|
||||||
|
else:
|
||||||
out = self.diffusion_model(x, t, context=cc)
|
out = self.diffusion_model(x, t, context=cc)
|
||||||
elif self.conditioning_key == 'hybrid':
|
elif self.conditioning_key == 'hybrid':
|
||||||
xc = torch.cat([x] + c_concat, dim=1)
|
xc = torch.cat([x] + c_concat, dim=1)
|
||||||
|
@ -11,16 +11,17 @@ MODEL_TYPES = {
|
|||||||
|
|
||||||
|
|
||||||
class DPMSolverSampler(object):
|
class DPMSolverSampler(object):
|
||||||
def __init__(self, model, **kwargs):
|
def __init__(self, model, device=torch.device("cuda"), **kwargs):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.model = model
|
self.model = model
|
||||||
|
self.device = device
|
||||||
to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device)
|
to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device)
|
||||||
self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod))
|
self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod))
|
||||||
|
|
||||||
def register_buffer(self, name, attr):
|
def register_buffer(self, name, attr):
|
||||||
if type(attr) == torch.Tensor:
|
if type(attr) == torch.Tensor:
|
||||||
if attr.device != torch.device("cuda"):
|
if attr.device != self.device:
|
||||||
attr = attr.to(torch.device("cuda"))
|
attr = attr.to(self.device)
|
||||||
setattr(self, name, attr)
|
setattr(self, name, attr)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
|
@ -10,16 +10,17 @@ from ldm.models.diffusion.sampling_util import norm_thresholding
|
|||||||
|
|
||||||
|
|
||||||
class PLMSSampler(object):
|
class PLMSSampler(object):
|
||||||
def __init__(self, model, schedule="linear", **kwargs):
|
def __init__(self, model, schedule="linear", device=torch.device("cuda"), **kwargs):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.model = model
|
self.model = model
|
||||||
self.ddpm_num_timesteps = model.num_timesteps
|
self.ddpm_num_timesteps = model.num_timesteps
|
||||||
self.schedule = schedule
|
self.schedule = schedule
|
||||||
|
self.device = device
|
||||||
|
|
||||||
def register_buffer(self, name, attr):
|
def register_buffer(self, name, attr):
|
||||||
if type(attr) == torch.Tensor:
|
if type(attr) == torch.Tensor:
|
||||||
if attr.device != torch.device("cuda"):
|
if attr.device != self.device:
|
||||||
attr = attr.to(torch.device("cuda"))
|
attr = attr.to(self.device)
|
||||||
setattr(self, name, attr)
|
setattr(self, name, attr)
|
||||||
|
|
||||||
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
|
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
|
||||||
|
@ -454,6 +454,7 @@ class UNetModel(nn.Module):
|
|||||||
num_classes=None,
|
num_classes=None,
|
||||||
use_checkpoint=False,
|
use_checkpoint=False,
|
||||||
use_fp16=False,
|
use_fp16=False,
|
||||||
|
use_bf16=False,
|
||||||
num_heads=-1,
|
num_heads=-1,
|
||||||
num_head_channels=-1,
|
num_head_channels=-1,
|
||||||
num_heads_upsample=-1,
|
num_heads_upsample=-1,
|
||||||
@ -518,6 +519,7 @@ class UNetModel(nn.Module):
|
|||||||
self.num_classes = num_classes
|
self.num_classes = num_classes
|
||||||
self.use_checkpoint = use_checkpoint
|
self.use_checkpoint = use_checkpoint
|
||||||
self.dtype = th.float16 if use_fp16 else th.float32
|
self.dtype = th.float16 if use_fp16 else th.float32
|
||||||
|
self.dtype = th.bfloat16 if use_bf16 else self.dtype
|
||||||
self.num_heads = num_heads
|
self.num_heads = num_heads
|
||||||
self.num_head_channels = num_head_channels
|
self.num_head_channels = num_head_channels
|
||||||
self.num_heads_upsample = num_heads_upsample
|
self.num_heads_upsample = num_heads_upsample
|
||||||
|
Loading…
Reference in New Issue
Block a user