mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-03-15 14:09:36 +00:00
Make sure cond_concat is on the right device.
This commit is contained in:
parent
45c972aba8
commit
e6962120c6
@ -79,6 +79,7 @@ class BaseModel(torch.nn.Module):
|
|||||||
denoise_mask = kwargs.get("denoise_mask", None)
|
denoise_mask = kwargs.get("denoise_mask", None)
|
||||||
latent_image = kwargs.get("latent_image", None)
|
latent_image = kwargs.get("latent_image", None)
|
||||||
noise = kwargs.get("noise", None)
|
noise = kwargs.get("noise", None)
|
||||||
|
device = kwargs["device"]
|
||||||
|
|
||||||
def blank_inpaint_image_like(latent_image):
|
def blank_inpaint_image_like(latent_image):
|
||||||
blank_image = torch.ones_like(latent_image)
|
blank_image = torch.ones_like(latent_image)
|
||||||
@ -92,9 +93,9 @@ class BaseModel(torch.nn.Module):
|
|||||||
for ck in concat_keys:
|
for ck in concat_keys:
|
||||||
if denoise_mask is not None:
|
if denoise_mask is not None:
|
||||||
if ck == "mask":
|
if ck == "mask":
|
||||||
cond_concat.append(denoise_mask[:,:1])
|
cond_concat.append(denoise_mask[:,:1].to(device))
|
||||||
elif ck == "masked_image":
|
elif ck == "masked_image":
|
||||||
cond_concat.append(latent_image) #NOTE: the latent_image should be masked by the mask in pixel space
|
cond_concat.append(latent_image.to(device)) #NOTE: the latent_image should be masked by the mask in pixel space
|
||||||
else:
|
else:
|
||||||
if ck == "mask":
|
if ck == "mask":
|
||||||
cond_concat.append(torch.ones_like(noise)[:,:1])
|
cond_concat.append(torch.ones_like(noise)[:,:1])
|
||||||
|
@ -537,10 +537,11 @@ def encode_adm(model, conds, batch_size, width, height, device, prompt_type):
|
|||||||
|
|
||||||
return conds
|
return conds
|
||||||
|
|
||||||
def encode_cond(model_function, key, conds, **kwargs):
|
def encode_cond(model_function, key, conds, device, **kwargs):
|
||||||
for t in range(len(conds)):
|
for t in range(len(conds)):
|
||||||
x = conds[t]
|
x = conds[t]
|
||||||
params = x[1].copy()
|
params = x[1].copy()
|
||||||
|
params["device"] = device
|
||||||
for k in kwargs:
|
for k in kwargs:
|
||||||
if k not in params:
|
if k not in params:
|
||||||
params[k] = kwargs[k]
|
params[k] = kwargs[k]
|
||||||
@ -677,8 +678,8 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
|
|||||||
negative = encode_adm(model, negative, noise.shape[0], noise.shape[3], noise.shape[2], device, "negative")
|
negative = encode_adm(model, negative, noise.shape[0], noise.shape[3], noise.shape[2], device, "negative")
|
||||||
|
|
||||||
if hasattr(model, 'cond_concat'):
|
if hasattr(model, 'cond_concat'):
|
||||||
positive = encode_cond(model.cond_concat, "concat", positive, noise=noise, latent_image=latent_image, denoise_mask=denoise_mask)
|
positive = encode_cond(model.cond_concat, "concat", positive, device, noise=noise, latent_image=latent_image, denoise_mask=denoise_mask)
|
||||||
negative = encode_cond(model.cond_concat, "concat", negative, noise=noise, latent_image=latent_image, denoise_mask=denoise_mask)
|
negative = encode_cond(model.cond_concat, "concat", negative, device, noise=noise, latent_image=latent_image, denoise_mask=denoise_mask)
|
||||||
|
|
||||||
extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": model_options, "seed":seed}
|
extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": model_options, "seed":seed}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user