diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py
index ed917cfb..c9c82e9a 100644
--- a/comfy/clip_vision.py
+++ b/comfy/clip_vision.py
@@ -16,13 +16,18 @@ class Output:
     def __setitem__(self, key, item):
         setattr(self, key, item)
 
-def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711]):
+def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], crop=True):
     mean = torch.tensor(mean, device=image.device, dtype=image.dtype)
     std = torch.tensor(std, device=image.device, dtype=image.dtype)
     image = image.movedim(-1, 1)
     if not (image.shape[2] == size and image.shape[3] == size):
-        scale = (size / min(image.shape[2], image.shape[3]))
-        image = torch.nn.functional.interpolate(image, size=(round(scale * image.shape[2]), round(scale * image.shape[3])), mode="bicubic", antialias=True)
+        if crop:
+            scale = (size / min(image.shape[2], image.shape[3]))
+            scale_size = (round(scale * image.shape[2]), round(scale * image.shape[3]))
+        else:
+            scale_size = (size, size)
+
+        image = torch.nn.functional.interpolate(image, size=scale_size, mode="bicubic", antialias=True)
         h = (image.shape[2] - size)//2
         w = (image.shape[3] - size)//2
         image = image[:,:,h:h+size,w:w+size]
@@ -51,9 +56,9 @@ class ClipVisionModel():
     def get_sd(self):
         return self.model.state_dict()
 
-    def encode_image(self, image):
+    def encode_image(self, image, crop=True):
         comfy.model_management.load_model_gpu(self.patcher)
-        pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std).float()
+        pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float()
         out = self.model(pixel_values=pixel_values, intermediate_output=-2)
 
         outputs = Output()
diff --git a/nodes.py b/nodes.py
index 3a68d43c..cc21181b 100644
--- a/nodes.py
+++ b/nodes.py
@@ -971,15 +971,19 @@ class CLIPVisionEncode:
     @classmethod
     def INPUT_TYPES(s):
         return {"required": { "clip_vision": ("CLIP_VISION",),
-                              "image": ("IMAGE",)
+                              "image": ("IMAGE",),
+                              "crop": (["center", "none"],)
                              }}
     RETURN_TYPES = ("CLIP_VISION_OUTPUT",)
     FUNCTION = "encode"
 
     CATEGORY = "conditioning"
 
-    def encode(self, clip_vision, image):
-        output = clip_vision.encode_image(image)
+    def encode(self, clip_vision, image, crop):
+        crop_image = True
+        if crop != "center":
+            crop_image = False
+        output = clip_vision.encode_image(image, crop=crop_image)
         return (output,)
 
 class StyleModelLoader:
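
For reference, below is a minimal standalone sketch (not part of the patch) of the resize behavior the new `crop` flag selects in `clip_preprocess`: with `crop=True` (the node's `"center"` option, matching the previous behavior) the short side is scaled to `size` and the result is center-cropped, preserving aspect ratio; with `crop=False` (`"none"`) the image is interpolated straight to `size` x `size`, so nothing is cut off but the aspect ratio is not preserved. The function name `resize_for_clip` and the example shapes are illustrative only, and normalization is omitted.

```python
# Minimal sketch mirroring the resize paths toggled by the new `crop` flag.
# Assumes a [B, H, W, C] float tensor in 0..1, as ComfyUI's IMAGE type provides.
import torch

def resize_for_clip(image, size=224, crop=True):
    image = image.movedim(-1, 1)  # [B, H, W, C] -> [B, C, H, W]
    if crop:
        # Scale the short side to `size`, then center-crop (original behavior).
        scale = size / min(image.shape[2], image.shape[3])
        scale_size = (round(scale * image.shape[2]), round(scale * image.shape[3]))
    else:
        # Resize directly to `size` x `size`; nothing is cropped,
        # but the aspect ratio is not preserved (new "none" option).
        scale_size = (size, size)
    image = torch.nn.functional.interpolate(image, size=scale_size, mode="bicubic", antialias=True)
    h = (image.shape[2] - size) // 2
    w = (image.shape[3] - size) // 2
    return image[:, :, h:h + size, w:w + size]  # no-op slice when crop=False

# A 512x768 input keeps its aspect ratio with crop=True (resized to 224x336,
# then center-cropped), and is squashed directly to 224x224 with crop=False.
example = torch.rand(1, 512, 768, 3)
print(resize_for_clip(example, crop=True).shape)   # torch.Size([1, 3, 224, 224])
print(resize_for_clip(example, crop=False).shape)  # torch.Size([1, 3, 224, 224])
```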