diff --git a/comfy/model_base.py b/comfy/model_base.py index 8e704022e..cda6765e4 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -79,6 +79,7 @@ class BaseModel(torch.nn.Module): denoise_mask = kwargs.get("denoise_mask", None) latent_image = kwargs.get("latent_image", None) noise = kwargs.get("noise", None) + device = kwargs["device"] def blank_inpaint_image_like(latent_image): blank_image = torch.ones_like(latent_image) @@ -92,9 +93,9 @@ class BaseModel(torch.nn.Module): for ck in concat_keys: if denoise_mask is not None: if ck == "mask": - cond_concat.append(denoise_mask[:,:1]) + cond_concat.append(denoise_mask[:,:1].to(device)) elif ck == "masked_image": - cond_concat.append(latent_image) #NOTE: the latent_image should be masked by the mask in pixel space + cond_concat.append(latent_image.to(device)) #NOTE: the latent_image should be masked by the mask in pixel space else: if ck == "mask": cond_concat.append(torch.ones_like(noise)[:,:1]) diff --git a/comfy/samplers.py b/comfy/samplers.py index a56599227..4840b6d9f 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -537,10 +537,11 @@ def encode_adm(model, conds, batch_size, width, height, device, prompt_type): return conds -def encode_cond(model_function, key, conds, **kwargs): +def encode_cond(model_function, key, conds, device, **kwargs): for t in range(len(conds)): x = conds[t] params = x[1].copy() + params["device"] = device for k in kwargs: if k not in params: params[k] = kwargs[k] @@ -677,8 +678,8 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model negative = encode_adm(model, negative, noise.shape[0], noise.shape[3], noise.shape[2], device, "negative") if hasattr(model, 'cond_concat'): - positive = encode_cond(model.cond_concat, "concat", positive, noise=noise, latent_image=latent_image, denoise_mask=denoise_mask) - negative = encode_cond(model.cond_concat, "concat", negative, noise=noise, latent_image=latent_image, denoise_mask=denoise_mask) + positive = encode_cond(model.cond_concat, "concat", positive, device, noise=noise, latent_image=latent_image, denoise_mask=denoise_mask) + negative = encode_cond(model.cond_concat, "concat", negative, device, noise=noise, latent_image=latent_image, denoise_mask=denoise_mask) extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": model_options, "seed":seed}