From e87a8669b66af5dc08d83c7ef29c386618db9927 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 15 Feb 2023 17:39:42 -0500 Subject: [PATCH] Add a LoadImageMask node to load one colour channel in an image as a mask. --- nodes.py | 48 +++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 41 insertions(+), 7 deletions(-) diff --git a/nodes.py b/nodes.py index c190d633..0a41511c 100644 --- a/nodes.py +++ b/nodes.py @@ -410,11 +410,8 @@ def common_ksampler(device, model, seed, steps, cfg, sampler_name, scheduler, po if "noise_mask" in latent: noise_mask = latent['noise_mask'] - print(noise_mask.shape, noise.shape) - noise_mask = torch.nn.functional.interpolate(noise_mask[None,None,], size=(noise.shape[2], noise.shape[3]), mode="bilinear") - noise_mask = noise_mask.floor() - noise_mask = torch.ones_like(noise_mask) - noise_mask + noise_mask = noise_mask.round() noise_mask = torch.cat([noise_mask] * noise.shape[1], dim=1) noise_mask = torch.cat([noise_mask] * noise.shape[0]) noise_mask = noise_mask.to(device) @@ -581,10 +578,11 @@ class LoadImage: FUNCTION = "load_image" def load_image(self, image): image_path = os.path.join(self.input_dir, image) - image = Image.open(image_path).convert("RGB") + i = Image.open(image_path) + image = i.convert("RGB") image = np.array(image).astype(np.float32) / 255.0 - image = torch.from_numpy(image[None])[None,] - return image + image = torch.from_numpy(image)[None,] + return (image,) @classmethod def IS_CHANGED(s, image): @@ -594,6 +592,41 @@ class LoadImage: m.update(f.read()) return m.digest().hex() +class LoadImageMask: + input_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input") + @classmethod + def INPUT_TYPES(s): + return {"required": + {"image": (os.listdir(s.input_dir), ), + "channel": (["alpha", "red", "green", "blue"], ),} + } + + CATEGORY = "image" + + RETURN_TYPES = ("MASK",) + FUNCTION = "load_image" + def load_image(self, image, channel): + image_path = os.path.join(self.input_dir, image) + i = Image.open(image_path) + mask = None + c = channel[0].upper() + if c in i.getbands(): + mask = np.array(i.getchannel(c)).astype(np.float32) / 255.0 + mask = torch.from_numpy(mask) + if c == 'A': + mask = 1. - mask + else: + mask = torch.zeros((64,64), dtype=torch.float32, device="cpu") + return (mask,) + + @classmethod + def IS_CHANGED(s, image, channel): + image_path = os.path.join(s.input_dir, image) + m = hashlib.sha256() + with open(image_path, 'rb') as f: + m.update(f.read()) + return m.digest().hex() + class ImageScale: upscale_methods = ["nearest-exact", "bilinear", "area"] crop_methods = ["disabled", "center"] @@ -626,6 +659,7 @@ NODE_CLASS_MAPPINGS = { "LatentUpscale": LatentUpscale, "SaveImage": SaveImage, "LoadImage": LoadImage, + "LoadImageMask": LoadImageMask, "ImageScale": ImageScale, "ConditioningCombine": ConditioningCombine, "ConditioningSetArea": ConditioningSetArea,