Add a way to disable cropping in the CLIPVisionEncode node.

This commit is contained in:
comfyanonymous 2024-11-28 20:24:47 -05:00
parent bf2650a80e
commit 26fb2c68e8
2 changed files with 17 additions and 8 deletions

View File

@ -16,13 +16,18 @@ class Output:
def __setitem__(self, key, item): def __setitem__(self, key, item):
setattr(self, key, item) setattr(self, key, item)
def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711]): def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], crop=True):
mean = torch.tensor(mean, device=image.device, dtype=image.dtype) mean = torch.tensor(mean, device=image.device, dtype=image.dtype)
std = torch.tensor(std, device=image.device, dtype=image.dtype) std = torch.tensor(std, device=image.device, dtype=image.dtype)
image = image.movedim(-1, 1) image = image.movedim(-1, 1)
if not (image.shape[2] == size and image.shape[3] == size): if not (image.shape[2] == size and image.shape[3] == size):
if crop:
scale = (size / min(image.shape[2], image.shape[3])) scale = (size / min(image.shape[2], image.shape[3]))
image = torch.nn.functional.interpolate(image, size=(round(scale * image.shape[2]), round(scale * image.shape[3])), mode="bicubic", antialias=True) scale_size = (round(scale * image.shape[2]), round(scale * image.shape[3]))
else:
scale_size = (size, size)
image = torch.nn.functional.interpolate(image, size=scale_size, mode="bicubic", antialias=True)
h = (image.shape[2] - size)//2 h = (image.shape[2] - size)//2
w = (image.shape[3] - size)//2 w = (image.shape[3] - size)//2
image = image[:,:,h:h+size,w:w+size] image = image[:,:,h:h+size,w:w+size]
@ -51,9 +56,9 @@ class ClipVisionModel():
def get_sd(self): def get_sd(self):
return self.model.state_dict() return self.model.state_dict()
def encode_image(self, image): def encode_image(self, image, crop=True):
comfy.model_management.load_model_gpu(self.patcher) comfy.model_management.load_model_gpu(self.patcher)
pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std).float() pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float()
out = self.model(pixel_values=pixel_values, intermediate_output=-2) out = self.model(pixel_values=pixel_values, intermediate_output=-2)
outputs = Output() outputs = Output()

View File

@ -971,15 +971,19 @@ class CLIPVisionEncode:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": { "clip_vision": ("CLIP_VISION",), return {"required": { "clip_vision": ("CLIP_VISION",),
"image": ("IMAGE",) "image": ("IMAGE",),
"crop": (["center", "none"],)
}} }}
RETURN_TYPES = ("CLIP_VISION_OUTPUT",) RETURN_TYPES = ("CLIP_VISION_OUTPUT",)
FUNCTION = "encode" FUNCTION = "encode"
CATEGORY = "conditioning" CATEGORY = "conditioning"
def encode(self, clip_vision, image): def encode(self, clip_vision, image, crop):
output = clip_vision.encode_image(image) crop_image = True
if crop != "center":
crop_image = False
output = clip_vision.encode_image(image, crop=crop_image)
return (output,) return (output,)
class StyleModelLoader: class StyleModelLoader: