logic to upload images from this server

This commit is contained in:
drunkplato 2025-01-21 19:00:08 +00:00 committed by Ubuntu
parent 47ee984278
commit 3f7db39fda
5 changed files with 178 additions and 126 deletions

3
.gitignore vendored
View File

@ -14,6 +14,8 @@ __pycache__/
!custom_nodes/example_node.py.example
!custom_nodes/MemedeckComfyNodes/
!custom_nodes/MemedeckComfyNodes/**
!comfy/ldm/models/autoencoder.py
!comfy/ldm/models/
extra_model_paths.yaml
/.vs
@ -43,4 +45,3 @@ comfy_venv_3.11
models-2
!comfy/ldm/models/autoencoder.py

View File

@ -1,3 +1,4 @@
<<<<<<< HEAD
import logging
import math
import torch
@ -7,6 +8,15 @@ from typing import Any, Dict, Tuple, Union
from comfy.ldm.modules.distributions.distributions import DiagonalGaussianDistribution
from comfy.ldm.util import get_obj_from_str, instantiate_from_config
=======
import torch
from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Tuple, Union
from comfy.ldm.modules.distributions.distributions import DiagonalGaussianDistribution
from comfy.ldm.util import instantiate_from_config
>>>>>>> 0e1536b4 (logic to upload images from this server)
from comfy.ldm.modules.ema import LitEma
import comfy.ops
@ -54,7 +64,11 @@ class AbstractAutoencoder(torch.nn.Module):
if self.use_ema:
self.model_ema = LitEma(self, decay=ema_decay)
<<<<<<< HEAD
logging.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
=======
logpy.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
>>>>>>> 0e1536b4 (logic to upload images from this server)
def get_input(self, batch) -> Any:
raise NotImplementedError()
@ -70,14 +84,22 @@ class AbstractAutoencoder(torch.nn.Module):
self.model_ema.store(self.parameters())
self.model_ema.copy_to(self)
if context is not None:
<<<<<<< HEAD
logging.info(f"{context}: Switched to EMA weights")
=======
logpy.info(f"{context}: Switched to EMA weights")
>>>>>>> 0e1536b4 (logic to upload images from this server)
try:
yield None
finally:
if self.use_ema:
self.model_ema.restore(self.parameters())
if context is not None:
<<<<<<< HEAD
logging.info(f"{context}: Restored training weights")
=======
logpy.info(f"{context}: Restored training weights")
>>>>>>> 0e1536b4 (logic to upload images from this server)
def encode(self, *args, **kwargs) -> torch.Tensor:
raise NotImplementedError("encode()-method of abstract base class called")
@ -86,7 +108,11 @@ class AbstractAutoencoder(torch.nn.Module):
raise NotImplementedError("decode()-method of abstract base class called")
def instantiate_optimizer_from_config(self, params, lr, cfg):
<<<<<<< HEAD
logging.info(f"loading >>> {cfg['target']} <<< optimizer from config")
=======
logpy.info(f"loading >>> {cfg['target']} <<< optimizer from config")
>>>>>>> 0e1536b4 (logic to upload images from this server)
return get_obj_from_str(cfg["target"])(
params, lr=lr, **cfg.get("params", dict())
)
@ -114,7 +140,11 @@ class AutoencodingEngine(AbstractAutoencoder):
self.encoder: torch.nn.Module = instantiate_from_config(encoder_config)
self.decoder: torch.nn.Module = instantiate_from_config(decoder_config)
<<<<<<< HEAD
self.regularization = instantiate_from_config(
=======
self.regularization: AbstractRegularizer = instantiate_from_config(
>>>>>>> 0e1536b4 (logic to upload images from this server)
regularizer_config
)
@ -162,6 +192,7 @@ class AutoencodingEngineLegacy(AutoencodingEngine):
},
**kwargs,
)
<<<<<<< HEAD
if ddconfig.get("conv3d", False):
conv_op = comfy.ops.disable_weight_init.Conv3d
@ -169,12 +200,19 @@ class AutoencodingEngineLegacy(AutoencodingEngine):
conv_op = comfy.ops.disable_weight_init.Conv2d
self.quant_conv = conv_op(
=======
self.quant_conv = comfy.ops.disable_weight_init.Conv2d(
>>>>>>> 0e1536b4 (logic to upload images from this server)
(1 + ddconfig["double_z"]) * ddconfig["z_channels"],
(1 + ddconfig["double_z"]) * embed_dim,
1,
)
<<<<<<< HEAD
self.post_quant_conv = conv_op(embed_dim, ddconfig["z_channels"], 1)
=======
self.post_quant_conv = comfy.ops.disable_weight_init.Conv2d(embed_dim, ddconfig["z_channels"], 1)
>>>>>>> 0e1536b4 (logic to upload images from this server)
self.embed_dim = embed_dim
def get_autoencoder_params(self) -> list:

View File

@ -339,6 +339,7 @@ class MD_CompressAdjustNode:
image_cv2 = cv2.cvtColor(np.array(tensor2pil(image)), cv2.COLOR_RGB2BGR)
# calculate the crf based on the image
analysis_results = self.analyze_compression_artifacts(image_cv2, width=width, height=height)
logger.info(f"compression analysis_results: {analysis_results}")
calculated_crf = self.calculate_crf(analysis_results, self.ideal_blockiness, self.ideal_edge_density,
self.ideal_color_variation, self.blockiness_weight,
self.edge_density_weight, self.color_variation_weight)
@ -346,6 +347,8 @@ class MD_CompressAdjustNode:
if desired_crf is 0:
desired_crf = calculated_crf
logger.info(f"calculated_crf: {calculated_crf}")
# logger.info(f"desired_crf: {desired_crf}")
args = [
utils.ffmpeg_path,
"-v", "error",

View File

@ -7,6 +7,9 @@ from PIL import Image, ImageOps
from functools import partial
import pika
import json
import numpy as np
from lxml import etree
import io
import requests
@ -153,8 +156,10 @@ class MemedeckWorker:
routing_key = method.routing_key
workflow = 'faceswap' if routing_key == 'faceswap-queue' else 'generation'
user_id = None
user_id = payload["user_id"] if 'user_id' in payload else None
self.logger.info(f"[memedeck]: workflow {workflow} user_id: {user_id}")
if self.video_gen_only:
workflow = 'video_gen'
user_id = payload["user_id"]
@ -449,7 +454,7 @@ class MemedeckWorker:
async def send_preview(self, image_data, sid=None, progress=None, context=None, workflow=None):
self.logger.info(f"[memedeck]: send_preview: {sid}")
# self.logger.info(f"[memedeck]: send_preview: {sid}")
if sid is None:
self.logger.warning("Received preview without sid")
return
@ -483,16 +488,31 @@ class MemedeckWorker:
kind = "image_generating" if progress < 100 else "image_generated"
image_id = None
url = None
watermarked_url = None
if kind == "image_generated" and task['workflow'] != 'faceswap':
image_uuid = str(uuid.uuid4()).replace("-", "_") # create uuid for the image
blob_name = f"{task['user_id']}/{image_uuid}"
# upload to azure blob storage
url = await self.azure_storage.save_image(blob_name + ".jpeg", "image/jpeg", preview_bytes)
watermarked_url = await self.azure_storage.save_image_watermarked(blob_name + "_watermarked.jpeg", "image/jpeg", preview_bytes)
image_id = f"image:{image_uuid}"
ai_queue_progress = {
"ws_id": sid,
"kind": kind,
"data": list(preview_bytes),
"data": list(preview_bytes) if kind == "image_generating" else None,
"progress": int(progress),
"context": context
"context": context,
"user_id": task['user_id'],
"image_id": image_id,
"url": url,
"url_watermarked": watermarked_url
}
self.logger.info(f"[memedeck]: progress kind: {kind}")
self.logger.info(f"[memedeck]: progress: {progress}")
# self.logger.info(f"[memedeck]: progress kind: {kind}")
# self.logger.info(f"[memedeck]: progress: {progress}")
# set the kind to faceswap_generated if workflow is faceswap
if workflow == 'faceswap':
ai_queue_progress['kind'] = "faceswap_generated"
@ -501,7 +521,7 @@ class MemedeckWorker:
await self.send_to_api(ai_queue_progress)
if progress == 100 or workflow == 'faceswap':
del self.tasks_by_ws_id[sid] # Remove the task from tasks_by_ws_id
del self.tasks_by_ws_id[sid] # Remove the task from tasks_by_ws_id
# self.logger.info(f"[memedeck]: Task {sid} completed")
async def send_to_api(self, data):
@ -559,13 +579,17 @@ class MemedeckWorker:
# --------------------------------------------------------------------------
# MemedeckAzureStorage
# --------------------------------------------------------------------------
from azure.storage.blob.aio import BlobClient, BlobServiceClient
from azure.storage.blob.aio import BlobServiceClient
from azure.storage.blob import ContentSettings
from typing import Optional, Tuple
import cairosvg
WATERMARK = '<svg width="256" height="256" viewBox="0 0 256 256" fill="none" xmlns="http://www.w3.org/2000/svg"><path d="M60.0859 196.8C65.9526 179.067 71.5526 161.667 76.8859 144.6C79.1526 137.4 81.4859 129.867 83.8859 122C86.2859 114.133 88.6859 106.333 91.0859 98.6C93.4859 90.8667 95.6859 83.4 97.6859 76.2C99.8193 69 101.686 62.3333 103.286 56.2C110.619 56.2 117.553 55.8 124.086 55C130.619 54.2 137.686 53.4667 145.286 52.8C144.886 55.7333 144.419 59.0667 143.886 62.8C143.486 66.4 142.953 70.2 142.286 74.2C141.753 78.2 141.153 82.3333 140.486 86.6C139.819 90.8667 139.019 96.3333 138.086 103C137.153 109.667 135.886 118 134.286 128H136.886C140.753 117.867 143.953 109.467 146.486 102.8C149.019 96 151.086 90.4667 152.686 86.2C154.286 81.9333 155.886 77.8 157.486 73.8C159.219 69.6667 160.819 65.8 162.286 62.2C163.886 58.4667 165.353 55.2 166.686 52.4C170.019 52.1333 173.153 51.8 176.086 51.4C179.019 51 181.953 50.6 184.886 50.2C187.819 49.6667 190.753 49.2 193.686 48.8C196.753 48.2667 200.086 47.6667 203.686 47C202.353 54.7333 201.086 62.6667 199.886 70.8C198.686 78.9333 197.619 87.0667 196.686 95.2C195.753 103.2 194.819 111.133 193.886 119C193.086 126.867 192.353 134.333 191.686 141.4C190.086 157.933 188.686 174.067 187.486 189.8L152.686 196C152.686 195.333 152.753 193.533 152.886 190.6C153.153 187.667 153.419 184.067 153.686 179.8C154.086 175.533 154.553 170.8 155.086 165.6C155.753 160.4 156.353 155.2 156.886 150C157.553 144.8 158.219 139.8 158.886 135C159.553 130.067 160.219 125.867 160.886 122.4H159.086C157.219 128 155.153 133.933 152.886 140.2C150.619 146.333 148.286 152.6 145.886 159C143.619 165.4 141.353 171.667 139.086 177.8C136.819 183.933 134.819 189.8 133.086 195.4C128.419 195.533 124.419 195.733 121.086 196C117.753 196.133 113.886 196.333 109.486 196.6L115.886 122.4H112.886C112.619 124.133 111.953 127.067 110.886 131.2C109.819 135.2 108.553 139.867 107.086 145.2C105.753 150.4 104.286 155.867 102.686 161.6C101.086 167.2 99.5526 172.467 98.0859 177.4C96.7526 182.2 95.6193 186.2 94.6859 189.4C93.7526 192.467 93.2193 194.2 93.0859 194.6L60.0859 196.8Z" fill="white"/></svg>'
WATERMARK_SIZE = 40
WATERMARK = """
<svg width="256" height="256" viewBox="0 0 256 256" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M60.0859 196.8C65.9526 179.067 71.5526 161.667 76.8859 144.6C79.1526 137.4 81.4859 129.867 83.8859 122C86.2859 114.133 88.6859 106.333 91.0859 98.6C93.4859 90.8667 95.6859 83.4 97.6859 76.2C99.8193 69 101.686 62.3333 103.286 56.2C110.619 56.2 117.553 55.8 124.086 55C130.619 54.2 137.686 53.4667 145.286 52.8C144.886 55.7333 144.419 59.0667 143.886 62.8C143.486 66.4 142.953 70.2 142.286 74.2C141.753 78.2 141.153 82.3333 140.486 86.6C139.819 90.8667 139.019 96.3333 138.086 103C137.153 109.667 135.886 118 134.286 128H136.886C140.753 117.867 143.953 109.467 146.486 102.8C149.019 96 151.086 90.4667 152.686 86.2C154.286 81.9333 155.886 77.8 157.486 73.8C159.219 69.6667 160.819 65.8 162.286 62.2C163.886 58.4667 165.353 55.2 166.686 52.4C170.019 52.1333 173.153 51.8 176.086 51.4C179.019 51 181.953 50.6 184.886 50.2C187.819 49.6667 190.753 49.2 193.686 48.8C196.753 48.2667 200.086 47.6667 203.686 47C202.353 54.7333 201.086 62.6667 199.886 70.8C198.686 78.9333 197.619 87.0667 196.686 95.2C195.753 103.2 194.819 111.133 193.886 119C193.086 126.867 192.353 134.333 191.686 141.4C190.086 157.933 188.686 174.067 187.486 189.8L152.686 196C152.686 195.333 152.753 193.533 152.886 190.6C153.153 187.667 153.419 184.067 153.686 179.8C154.086 175.533 154.553 170.8 155.086 165.6C155.753 160.4 156.353 155.2 156.886 150C157.553 144.8 158.219 139.8 158.886 135C159.553 130.067 160.219 125.867 160.886 122.4H159.086C157.219 128 155.153 133.933 152.886 140.2C150.619 146.333 148.286 152.6 145.886 159C143.619 165.4 141.353 171.667 139.086 177.8C136.819 183.933 134.819 189.8 133.086 195.4C128.419 195.533 124.419 195.733 121.086 196C117.753 196.133 113.886 196.333 109.486 196.6L115.886 122.4H112.886C112.619 124.133 111.953 127.067 110.886 131.2C109.819 135.2 108.553 139.867 107.086 145.2C105.753 150.4 104.286 155.867 102.686 161.6C101.086 167.2 99.5526 172.467 98.0859 177.4C96.7526 182.2 95.6193 186.2 94.6859 189.4C93.7526 192.467 93.2193 194.2 93.0859 194.6L60.0859 196.8Z" fill="white"/>
</svg>
"""
WATERMARK_SIZE = 40
class MemedeckAzureStorage:
def __init__(self):
@ -573,7 +597,9 @@ class MemedeckAzureStorage:
self.account = os.getenv('STORAGE_ACCOUNT')
self.access_key = os.getenv('STORAGE_ACCESS_KEY')
self.container = os.getenv('STORAGE_CONTAINER')
self.logger = logging.getLogger(__name__)
logging.getLogger('azure.core.pipeline.policies.http_logging_policy').setLevel(logging.WARNING)
logging.getLogger("azure.storage.common.storageclient").setLevel(logging.WARNING)
self.logger = logging.getLogger('azure.storage.common')
if not all([self.account, self.access_key, self.container]):
raise EnvironmentError("Missing STORAGE_ACCOUNT, STORAGE_ACCESS_KEY, or STORAGE_CONTAINER environment variables")
@ -607,6 +633,7 @@ class MemedeckAzureStorage:
# Upload the blob
try:
# prevent logging the request
await blob_client.upload_blob(
bytes_data,
overwrite=True,
@ -621,126 +648,109 @@ class MemedeckAzureStorage:
# Construct and return the blob URL
blob_url = f"https://media.memedeck.xyz/{self.container}/{blob_name}"
return blob_url
async def save_image_watermarked(
self,
blob_name: str,
content_type: str,
bytes_data: bytes
) -> str:
image = Image.open(BytesIO(bytes_data))
watermarked_image = self.add_watermark_to_image(image)
# convert pil to bytes
img_byte_arr = BytesIO() # Create an in-memory byte stream
watermarked_image.save(img_byte_arr, format=image.format, quality=100, compress_level=1) # Save the image to the in-memory stream
watermarked_image_bytes = img_byte_arr.getvalue()
return await self.save_image(blob_name, content_type, watermarked_image_bytes)
def add_watermark_to_image(self, img, background_brightness=None):
"""
Adds a watermark to a single PIL Image.
# async def add_watermark(
# self,
# base_blob_name: str,
# base_image: bytes
# ) -> str:
# """
# Adds a watermark to the provided image and uploads the watermarked image.
Args:
img: A PIL Image object.
# Args:
# base_blob_name (str): Original blob name of the image.
# base_image (bytes): Image data in bytes.
Returns:
A PIL Image object with the watermark added.
"""
# Returns:
# str: URL of the watermarked image.
# """
# # Load the input image
# try:
# img = Image.open(BytesIO(base_image)).convert("RGBA")
# except Exception as e:
# raise Exception(f"Failed to load image: {e}")
padding = 12
x = img.width - WATERMARK_SIZE - padding
y = img.height - WATERMARK_SIZE - padding
# # Calculate position for the watermark (bottom right corner with padding)
# padding = 12
# x = img.width - WATERMARK_SIZE - padding
# y = img.height - WATERMARK_SIZE - padding
if background_brightness is None:
background_brightness = self.analyze_background_brightness(img, x, y, WATERMARK_SIZE)
# # Analyze background brightness where the watermark will be placed
# background_brightness = self.analyze_background_brightness(img, x, y, WATERMARK_SIZE)
# self.logger.info(f"Background brightness: {background_brightness}")
# Generate watermark image (replace this with your actual watermark generation)
watermark = self.generate_watermark(WATERMARK_SIZE, background_brightness)
# # Render SVG watermark to PNG bytes using cairosvg
# try:
# watermark_png_bytes = cairosvg.svg2png(bytestring=WATERMARK.encode('utf-8'), output_width=WATERMARK_SIZE, output_height=WATERMARK_SIZE)
# watermark = Image.open(BytesIO(watermark_png_bytes)).convert("RGBA")
# except Exception as e:
# raise Exception(f"Failed to render watermark SVG: {e}")
# Overlay the watermark
img.paste(watermark, (x, y), watermark)
# # Determine watermark color based on background brightness
# if background_brightness > 128:
# # Dark watermark for light backgrounds
# watermark_color = (0, 0, 0, int(255 * 0.65)) # Black with 65% opacity
# else:
# # Light watermark for dark backgrounds
# watermark_color = (255, 255, 255, int(255 * 0.65)) # White with 65% opacity
# # Apply the watermark color by blending
# solid_color = Image.new("RGBA", watermark.size, watermark_color)
# watermark = Image.alpha_composite(watermark, solid_color)
# # Overlay the watermark onto the original image
# img.paste(watermark, (x, y), watermark)
# # Save the watermarked image to bytes
# buffer = BytesIO()
# img = img.convert("RGB") # Convert back to RGB for JPEG format
# img.save(buffer, format="JPEG")
# buffer.seek(0)
# jpeg_bytes = buffer.read()
# # Modify the blob name to include '_watermarked'
# try:
# if "memes/" in base_blob_name:
# base_blob_name_right = base_blob_name.split("memes/", 1)[1]
# else:
# base_blob_name_right = base_blob_name
# base_blob_name_split = base_blob_name_right.rsplit(".", 1)
# base_blob_name_without_extension = base_blob_name_split[0]
# extension = base_blob_name_split[1]
# except Exception as e:
# raise Exception(f"Failed to process blob name: {e}")
# watermarked_blob_name = f"{base_blob_name_without_extension}_watermarked.{extension}"
# # Upload the watermarked image
# try:
# watermarked_blob_url = await self.save_image(
# watermarked_blob_name,
# "image/jpeg",
# jpeg_bytes
# )
# return watermarked_blob_url
# except Exception as e:
# raise Exception(f"Failed to upload watermarked image: {e}")
# def analyze_background_brightness(
# self,
# img: Image.Image,
# x: int,
# y: int,
# size: int
# ) -> int:
# """
# Analyzes the brightness of a specific region in the image.
# Args:
# img (Image.Image): The image to analyze.
# x (int): X-coordinate of the top-left corner of the region.
# y (int): Y-coordinate of the top-left corner of the region.
# size (int): Size of the square region to analyze.
# Returns:
# int: Average brightness (0-255) of the region.
# """
# # Crop the specified region
# sub_image = img.crop((x, y, x + size, y + size)).convert("RGB")
# # Calculate average brightness using the luminance formula
# total_brightness = 0
# pixel_count = 0
# for pixel in sub_image.getdata():
# r, g, b = pixel
# brightness = (r * 299 + g * 587 + b * 114) // 1000
# total_brightness += brightness
# pixel_count += 1
# if pixel_count == 0:
# return 0
# average_brightness = total_brightness // pixel_count
# return average_brightness
return img
def analyze_background_brightness(self, img, x, y, size):
"""
Analyzes the average brightness of a region in the image.
Args:
img: A PIL Image object.
x: The x-coordinate of the top-left corner of the region.
y: The y-coordinate of the top-left corner of the region.
size: The size of the region (square).
Returns:
The average brightness of the region as an integer.
"""
region = img.crop((x, y, x + size, y + size))
pixels = np.array(region)
total_brightness = np.sum(
0.299 * pixels[:, :, 0] + 0.587 * pixels[:, :, 1] + 0.114 * pixels[:, :, 2]
) / 1000
print(f"total_brightness: {total_brightness}")
return max(0, min(255, total_brightness))
def generate_watermark(self, size, background_brightness):
"""
Generates a watermark image from an SVG string.
Args:
size: The size of the watermark (square).
background_brightness: The background brightness at the watermark position.
Returns:
A PIL Image object representing the watermark.
"""
# Determine watermark color based on background brightness
watermark_color = (0, 0, 0, 165) if background_brightness > 128 else (255, 255, 255, 165)
# Parse the SVG string
svg_tree = etree.fromstring(WATERMARK)
# Find the path element and set its fill attribute
path_element = svg_tree.find(".//{http://www.w3.org/2000/svg}path")
if path_element is not None:
r, g, b, a = watermark_color
fill_color = f"rgba({r},{g},{b},{a/255})" # Convert to rgba string
path_element.set("fill", fill_color)
# Convert the modified SVG tree back to a string
modified_svg = etree.tostring(svg_tree, encoding="unicode")
# Render the modified SVG to a PNG image with a transparent background
png_data = cairosvg.svg2png(
bytestring=modified_svg,
output_width=size,
output_height=size,
background_color="transparent"
)
watermark_img = Image.open(BytesIO(png_data))
# Convert the watermark to RGBA to handle transparency
watermark_img = watermark_img.convert("RGBA")
return watermark_img