blend supports any size, dither -> quantize

This commit is contained in:
EllangoK 2023-04-03 09:52:04 -04:00
parent 4c7a9dbcb6
commit fa2febc062

View File

@ -1,5 +1,7 @@
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
class Blend:
@ -28,6 +30,9 @@ class Blend:
CATEGORY = "postprocessing"
def blend_images(self, image1: torch.Tensor, image2: torch.Tensor, blend_factor: float, blend_mode: str):
if image1.shape != image2.shape:
image2 = self.crop_and_resize(image2, image1.shape)
blended_image = self.blend_mode(image1, image2, blend_mode)
blended_image = image1 * (1 - blend_factor) + blended_image * blend_factor
blended_image = torch.clamp(blended_image, 0, 1)
@ -50,6 +55,29 @@ class Blend:
def g(self, x):
return torch.where(x <= 0.25, ((16 * x - 12) * x + 4) * x, torch.sqrt(x))
def crop_and_resize(self, img: torch.Tensor, target_shape: tuple):
batch_size, img_h, img_w, img_c = img.shape
_, target_h, target_w, _ = target_shape
img_aspect_ratio = img_w / img_h
target_aspect_ratio = target_w / target_h
# Crop center of the image to the target aspect ratio
if img_aspect_ratio > target_aspect_ratio:
new_width = int(img_h * target_aspect_ratio)
left = (img_w - new_width) // 2
img = img[:, :, left:left + new_width, :]
else:
new_height = int(img_w / target_aspect_ratio)
top = (img_h - new_height) // 2
img = img[:, top:top + new_height, :, :]
# Resize to target size
img = img.permute(0, 3, 1, 2) # Torch wants (B, C, H, W) we use (B, H, W, C)
img = F.interpolate(img, size=(target_h, target_w), mode='bilinear', align_corners=False)
img = img.permute(0, 2, 3, 1)
return img
class Blur:
def __init__(self):
pass
@ -100,7 +128,7 @@ class Blur:
return (blurred,)
class Dither:
class Quantize:
def __init__(self):
pass
@ -109,51 +137,37 @@ class Dither:
return {
"required": {
"image": ("IMAGE",),
"bits": ("INT", {
"default": 4,
"colors": ("INT", {
"default": 256,
"min": 1,
"max": 8,
"max": 256,
"step": 1
}),
"dither": (["none", "floyd-steinberg"],),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "dither"
FUNCTION = "quantize"
CATEGORY = "postprocessing"
def dither(self, image: torch.Tensor, bits: int):
def quantize(self, image: torch.Tensor, colors: int = 256, dither: str = "FLOYDSTEINBERG"):
batch_size, height, width, _ = image.shape
result = torch.zeros_like(image)
dither_option = Image.Dither.FLOYDSTEINBERG if dither == "floyd-steinberg" else Image.Dither.NONE
for b in range(batch_size):
tensor_image = image[b]
img = (tensor_image * 255)
height, width, _ = img.shape
img = (tensor_image * 255).to(torch.uint8).numpy()
pil_image = Image.fromarray(img, mode='RGB')
scale = 255 / (2**bits - 1)
palette = pil_image.quantize(colors=colors) # Required as described in https://github.com/python-pillow/Pillow/issues/5836
quantized_image = pil_image.quantize(colors=colors, palette=palette, dither=dither_option)
for y in range(height):
for x in range(width):
old_pixel = img[y, x].clone()
new_pixel = torch.round(old_pixel / scale) * scale
img[y, x] = new_pixel
quant_error = old_pixel - new_pixel
if x + 1 < width:
img[y, x + 1] += quant_error * 7 / 16
if y + 1 < height:
if x - 1 >= 0:
img[y + 1, x - 1] += quant_error * 3 / 16
img[y + 1, x] += quant_error * 5 / 16
if x + 1 < width:
img[y + 1, x + 1] += quant_error * 1 / 16
dithered = img / 255
tensor = dithered.unsqueeze(0)
result[b] = tensor
quantized_array = torch.tensor(np.array(quantized_image.convert("RGB"))).float() / 255
result[b] = quantized_array
return (result,)
@ -210,6 +224,6 @@ class Sharpen:
NODE_CLASS_MAPPINGS = {
"Blend": Blend,
"Blur": Blur,
"Dither": Dither,
"Quantize": Quantize,
"Sharpen": Sharpen,
}