diff --git a/comfy/samplers.py b/comfy/samplers.py index 6fa754b90..f8701c879 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -6,7 +6,6 @@ import contextlib from comfy import model_management from .ldm.models.diffusion.ddim import DDIMSampler from .ldm.modules.diffusionmodules.util import make_ddim_timesteps -from torchvision.ops import masks_to_boxes #The main sampling function shared by all the samplers #Returns predicted noise @@ -31,8 +30,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con assert(mask.shape[1] == x_in.shape[2]) assert(mask.shape[2] == x_in.shape[3]) mask = mask[:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] - if mask.shape[0] != input_x.shape[0]: - mask = mask.repeat(input_x.shape[0], 1, 1) + mask = mask.unsqueeze(1).repeat(input_x.shape[0] // mask.shape[0], input_x.shape[1], 1, 1) else: mask = torch.ones_like(input_x) mult = mask * strength @@ -315,6 +313,29 @@ def blank_inpaint_image_like(latent_image): blank_image[:,3] *= 0.1380 return blank_image +def get_mask_aabb(masks): + if masks.numel() == 0: + return torch.zeros((0, 4), device=masks.device, dtype=torch.int) + + b = masks.shape[0] + + bounding_boxes = torch.zeros((b, 4), device=masks.device, dtype=torch.int) + is_empty = torch.zeros((b), device=masks.device, dtype=torch.bool) + for i in range(b): + mask = masks[i] + if mask.numel() == 0: + continue + if torch.max(mask != 0) == False: + is_empty[i] = True + continue + y, x = torch.where(mask) + bounding_boxes[i, 0] = torch.min(x) + bounding_boxes[i, 1] = torch.min(y) + bounding_boxes[i, 2] = torch.max(x) + bounding_boxes[i, 3] = torch.max(y) + + return bounding_boxes, is_empty + def resolve_cond_masks(conditions, h, w, device): # We need to decide on an area outside the sampling loop in order to properly generate opposite areas of equal sizes. # While we're doing this, we can also resolve the mask device and scaling for performance reasons @@ -329,13 +350,14 @@ def resolve_cond_masks(conditions, h, w, device): if mask.shape[2] != h or mask.shape[3] != w: mask = torch.nn.functional.interpolate(mask.unsqueeze(1), size=(h, w), mode='bilinear', align_corners=False).squeeze(1) - if 'area' not in modified: + if modified.get("set_area_to_bounds", False): bounds = torch.max(torch.abs(mask),dim=0).values.unsqueeze(0) - if torch.max(bounds) == 0: - # Handle the edge-case of an all black mask (where masks_to_boxes would error) - area = (0, 0, 0, 0) + boxes, is_empty = get_mask_aabb(bounds) + if is_empty[0]: + # Use the minimum possible size for efficiency reasons. (Since the mask is all-0, this becomes a noop anyway) + modified['area'] = (8, 8, 0, 0) else: - box = masks_to_boxes(bounds)[0].type(torch.int) + box = boxes[0] H, W, Y, X = (box[3] - box[1] + 1, box[2] - box[0] + 1, box[1], box[0]) # Make sure the height and width are divisible by 8 if X % 8 != 0: @@ -350,8 +372,8 @@ def resolve_cond_masks(conditions, h, w, device): H = H + (8 - (H % 8)) if W % 8 != 0: W = W + (8 - (W % 8)) - area = (int(H), int(W), int(Y), (X)) - modified['area'] = area + area = (int(H), int(W), int(Y), int(X)) + modified['area'] = area modified['mask'] = mask conditions[i] = [c[0], modified] diff --git a/nodes.py b/nodes.py index be02f4676..12fa7e5a3 100644 --- a/nodes.py +++ b/nodes.py @@ -90,6 +90,7 @@ class ConditioningSetMask: def INPUT_TYPES(s): return {"required": {"conditioning": ("CONDITIONING", ), "mask": ("MASK", ), + "set_area_to_bounds": ([False, True],), "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), }} RETURN_TYPES = ("CONDITIONING",) @@ -97,7 +98,7 @@ class ConditioningSetMask: CATEGORY = "conditioning" - def append(self, conditioning, mask, strength, min_sigma=0.0, max_sigma=99.0): + def append(self, conditioning, mask, set_area_to_bounds, strength, min_sigma=0.0, max_sigma=99.0): c = [] if len(mask.shape) < 3: mask = mask.unsqueeze(0) @@ -105,6 +106,7 @@ class ConditioningSetMask: n = [t[0], t[1].copy()] _, h, w = mask.shape n[1]['mask'] = mask + n[1]['set_area_to_bounds'] = set_area_to_bounds n[1]['strength'] = strength n[1]['min_sigma'] = min_sigma n[1]['max_sigma'] = max_sigma