Update ldm dir with latest upstream stable diffusion changes.

comfyanonymous 2023-02-09 13:47:36 -05:00
parent 642516a3a6
commit 1f6a467e92
5 changed files with 21 additions and 10 deletions


@@ -8,16 +8,17 @@ from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, mak
 class DDIMSampler(object):
-    def __init__(self, model, schedule="linear", **kwargs):
+    def __init__(self, model, schedule="linear", device=torch.device("cuda"), **kwargs):
         super().__init__()
         self.model = model
         self.ddpm_num_timesteps = model.num_timesteps
         self.schedule = schedule
+        self.device = device

     def register_buffer(self, name, attr):
         if type(attr) == torch.Tensor:
-            if attr.device != torch.device("cuda"):
-                attr = attr.to(torch.device("cuda"))
+            if attr.device != self.device:
+                attr = attr.to(self.device)
         setattr(self, name, attr)

     def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
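The same device argument is added to the PLMS and DPM-Solver samplers in the hunks further down. As a usage sketch only (the import path follows the upstream Stable Diffusion layout, and model stands for an already-loaded LatentDiffusion instance), the change lets sampler buffers follow an explicit device instead of hard-coded CUDA:

import torch
from ldm.models.diffusion.ddim import DDIMSampler

def build_ddim_sampler(model):
    # Fall back to CPU when CUDA is absent; before this change register_buffer()
    # always moved tensors to torch.device("cuda").
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    return DDIMSampler(model, device=device)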


@@ -1331,7 +1331,13 @@ class DiffusionWrapper(torch.nn.Module):
                 cc = torch.cat(c_crossattn, 1)
             else:
                 cc = c_crossattn
-            out = self.diffusion_model(x, t, context=cc)
+            if hasattr(self, "scripted_diffusion_model"):
+                # TorchScript changes names of the arguments
+                # with argument cc defined as context=cc scripted model will produce
+                # an error: RuntimeError: forward() is missing value for argument 'argument_3'.
+                out = self.scripted_diffusion_model(x, t, cc)
+            else:
+                out = self.diffusion_model(x, t, context=cc)
         elif self.conditioning_key == 'hybrid':
             xc = torch.cat([x] + c_concat, dim=1)
             cc = torch.cat(c_crossattn, 1)
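Nothing in this commit actually sets scripted_diffusion_model; the branch only fires if the attribute has been attached elsewhere. A hedged sketch of one way it could be populated, using torch.jit.trace with positional example inputs (the helper name and arguments are illustrative, not part of the diff):

import torch

def attach_scripted_unet(wrapper, example_x, example_t, example_cc):
    # Trace the wrapped UNet once with example tensors; calling the traced module
    # positionally as (x, t, cc) sidesteps the argument-name mismatch noted above.
    wrapper.scripted_diffusion_model = torch.jit.trace(
        wrapper.diffusion_model, (example_x, example_t, example_cc))
    return wrapper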


@@ -11,16 +11,17 @@ MODEL_TYPES = {
 class DPMSolverSampler(object):
-    def __init__(self, model, **kwargs):
+    def __init__(self, model, device=torch.device("cuda"), **kwargs):
         super().__init__()
         self.model = model
+        self.device = device
         to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device)
         self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod))

     def register_buffer(self, name, attr):
         if type(attr) == torch.Tensor:
-            if attr.device != torch.device("cuda"):
-                attr = attr.to(torch.device("cuda"))
+            if attr.device != self.device:
+                attr = attr.to(self.device)
         setattr(self, name, attr)

     @torch.no_grad()


@@ -10,16 +10,17 @@ from ldm.models.diffusion.sampling_util import norm_thresholding
 class PLMSSampler(object):
-    def __init__(self, model, schedule="linear", **kwargs):
+    def __init__(self, model, schedule="linear", device=torch.device("cuda"), **kwargs):
         super().__init__()
         self.model = model
         self.ddpm_num_timesteps = model.num_timesteps
         self.schedule = schedule
+        self.device = device

     def register_buffer(self, name, attr):
         if type(attr) == torch.Tensor:
-            if attr.device != torch.device("cuda"):
-                attr = attr.to(torch.device("cuda"))
+            if attr.device != self.device:
+                attr = attr.to(self.device)
         setattr(self, name, attr)

     def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):


@@ -454,6 +454,7 @@ class UNetModel(nn.Module):
         num_classes=None,
         use_checkpoint=False,
         use_fp16=False,
+        use_bf16=False,
         num_heads=-1,
         num_head_channels=-1,
         num_heads_upsample=-1,
@@ -518,6 +519,7 @@ class UNetModel(nn.Module):
         self.num_classes = num_classes
         self.use_checkpoint = use_checkpoint
         self.dtype = th.float16 if use_fp16 else th.float32
+        self.dtype = th.bfloat16 if use_bf16 else self.dtype
         self.num_heads = num_heads
         self.num_head_channels = num_head_channels
         self.num_heads_upsample = num_heads_upsample