mirror of https://github.com/comfyanonymous/ComfyUI.git
synced 2025-04-20 03:13:30 +00:00
adding custom_node permanently
This commit is contained in:
parent
497390aff6
commit
e512458a79
17	.gitignore
vendored

@@ -3,10 +3,18 @@ __pycache__/
/output/
/input/
!/input/example.png
!comfy/ldm/models/autoencoder.py
/models/
/temp/
/custom_nodes/
# Ignore everything in custom_nodes
/custom_nodes/*
!comfy/ldm/models/autoencoder.py

# Explicitly allow these files/directories
!custom_nodes/example_node.py.example
!custom_nodes/MemedeckComfyNodes/
!custom_nodes/MemedeckComfyNodes/**

extra_model_paths.yaml
/.vs
.vscode/

@@ -26,12 +34,13 @@ web_custom_versions/

models
custom_nodes
models/
flux_loras/
custom_nodes/
pysssss-workflows/

comfy-venv
comfy_venv_3.11
models-2

!comfy/ldm/models/autoencoder.py
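The first hunk ignores everything under /custom_nodes/ with a wildcard, then re-includes specific entries with `!` negation patterns. A quick way to verify the allow-list from a checkout, assuming a working git repo (illustrative sketch, not part of the commit):

import subprocess

# `git check-ignore -v` prints the rule that matches a path; exit code 0 means
# the path is ignored, 1 means it is not ignored (re-included or never matched).
for path in ["custom_nodes/some_other_plugin/node.py",        # hypothetical path; expect: ignored by /custom_nodes/*
             "custom_nodes/example_node.py.example",          # expect: re-included by its ! rule
             "custom_nodes/MemedeckComfyNodes/__init__.py"]:  # expect: re-included by !custom_nodes/MemedeckComfyNodes/**
    result = subprocess.run(["git", "check-ignore", "-v", path], capture_output=True, text=True)
    print(path, "->", result.stdout.strip() or f"not ignored (exit {result.returncode})")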
2	custom_nodes/MemedeckComfyNodes/.gitignore
vendored
Normal file

@@ -0,0 +1,2 @@

__pycache__/
12	custom_nodes/MemedeckComfyNodes/README.md
Normal file

@@ -0,0 +1,12 @@
# MemedeckComfyNodes

This is a collection of custom ComfyUI nodes for MemeDeck.

## Nodes

- MD_LoadImageFromUrl
- MD_ImageToMotionPrompt
- MD_SaveAnimatedWEBP
- MD_LoadVideoModel
- MD_ImgToVideo
- MD_VideoToImg
32	custom_nodes/MemedeckComfyNodes/__init__.py
Normal file

@@ -0,0 +1,32 @@
from .nodes_preprocessing import MD_LoadImageFromUrl, MD_CompressAdjustNode, MD_ImageToMotionPrompt
from .nodes_model import MD_LoadVideoModel, MD_ImgToVideo, MD_VideoSampler
from .nodes_output import MD_SaveAnimatedWEBP, MD_SaveMP4

NODE_CLASS_MAPPINGS = {
    # PREPROCESSING
    "Memedeck_ImageToMotionPrompt": MD_ImageToMotionPrompt,
    "Memedeck_CompressAdjustNode": MD_CompressAdjustNode,
    "Memedeck_LoadImageFromUrl": MD_LoadImageFromUrl,
    # MODEL NODES
    "Memedeck_LoadVideoModel": MD_LoadVideoModel,
    "Memedeck_ImgToVideo": MD_ImgToVideo,
    "Memedeck_VideoSampler": MD_VideoSampler,
    # POSTPROCESSING
    "Memedeck_SaveMP4": MD_SaveMP4,
    "Memedeck_SaveAnimatedWEBP": MD_SaveAnimatedWEBP,
    # "Memedeck_SaveAnimatedGIF": MD_SaveAnimatedGIF
}

NODE_DISPLAY_NAME_MAPPINGS = {
    # PREPROCESSING
    "Memedeck_ImageToMotionPrompt": "MemeDeck: Image To Motion Prompt",
    "Memedeck_CompressAdjustNode": "MemeDeck: Compression Detector & Adjuster",
    "Memedeck_LoadImageFromUrl": "MemeDeck: Load Image From URL",
    # MODEL NODES
    "Memedeck_LoadVideoModel": "MemeDeck: Load Video Model",
    "Memedeck_ImgToVideo": "MemeDeck: Image To Video",
    "Memedeck_VideoSampler": "MemeDeck: Video Sampler",
    # POSTPROCESSING
    "Memedeck_SaveMP4": "MemeDeck: Save MP4",
    "Memedeck_SaveAnimatedWEBP": "MemeDeck: Save Animated WEBP",
    # "Memedeck_SaveAnimatedGIF": "MemeDeck: Save Animated GIF"
}
32	custom_nodes/MemedeckComfyNodes/lib/image.py
Normal file

@@ -0,0 +1,32 @@
import base64
from io import BytesIO

import numpy as np
import PIL
import torch
from PIL import Image
from torch import Tensor


def tensor2pil(image: Tensor) -> PIL.Image.Image:
    return Image.fromarray(np.clip(255. * image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8))


def pil2base64(image: PIL.Image.Image) -> str:
    buffered = BytesIO()
    image.save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")


def pil2tensor(images: Image.Image | list[Image.Image]) -> torch.Tensor:
    """Converts a PIL Image or a list of PIL Images to a tensor."""

    def single_pil2tensor(image: Image.Image) -> torch.Tensor:
        np_image = np.array(image).astype(np.float32) / 255.0
        if np_image.ndim == 2:  # Grayscale
            return torch.from_numpy(np_image).unsqueeze(0)  # (1, H, W)
        else:  # RGB or RGBA
            return torch.from_numpy(np_image).unsqueeze(0)  # (1, H, W, C)

    if isinstance(images, Image.Image):
        return single_pil2tensor(images)
    else:
        return torch.cat([single_pil2tensor(img) for img in images], dim=0)
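A round-trip sanity check for these helpers (illustrative only, not part of the commit): pil2tensor should produce a (1, H, W, C) float tensor in [0, 1], and tensor2pil should map it back to the same image size.

if __name__ == "__main__":
    test_img = Image.new("RGB", (64, 32), color=(255, 0, 0))
    t = pil2tensor(test_img)            # shape (1, 32, 64, 3), values in [0, 1]
    assert t.shape == (1, 32, 64, 3)
    assert 0.0 <= float(t.min()) and float(t.max()) <= 1.0
    back = tensor2pil(t)                # back to a 64x32 RGB image
    assert back.size == (64, 32)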
56	custom_nodes/MemedeckComfyNodes/lib/utils.py
Normal file

@@ -0,0 +1,56 @@
import os
import shutil
import subprocess


def ffmpeg_suitability(path):
    try:
        version = subprocess.run([path, "-version"], check=True,
                                 capture_output=True).stdout.decode("utf-8")
    except Exception:
        return 0
    score = 0
    # rough layout of the importance of various features
    simple_criterion = [("libvpx", 20), ("264", 10), ("265", 3),
                        ("svtav1", 5), ("libopus", 1)]
    for criterion in simple_criterion:
        if version.find(criterion[0]) >= 0:
            score += criterion[1]
    # obtain rough compile year from copyright information
    copyright_index = version.find('2000-2')
    if copyright_index >= 0:
        copyright_year = version[copyright_index + 6:copyright_index + 9]
        if copyright_year.isnumeric():
            score += int(copyright_year)
    return score


if "VHS_FORCE_FFMPEG_PATH" in os.environ:
    ffmpeg_path = os.environ.get("VHS_FORCE_FFMPEG_PATH")
else:
    ffmpeg_paths = []
    try:
        from imageio_ffmpeg import get_ffmpeg_exe
        imageio_ffmpeg_path = get_ffmpeg_exe()
        ffmpeg_paths.append(imageio_ffmpeg_path)
    except Exception:
        if "VHS_USE_IMAGEIO_FFMPEG" in os.environ:
            raise
        # logger.warn("Failed to import imageio_ffmpeg")
    if "VHS_USE_IMAGEIO_FFMPEG" in os.environ:
        ffmpeg_path = imageio_ffmpeg_path
    else:
        system_ffmpeg = shutil.which("ffmpeg")
        if system_ffmpeg is not None:
            ffmpeg_paths.append(system_ffmpeg)
        if os.path.isfile("ffmpeg"):
            ffmpeg_paths.append(os.path.abspath("ffmpeg"))
        if os.path.isfile("ffmpeg.exe"):
            ffmpeg_paths.append(os.path.abspath("ffmpeg.exe"))
        if len(ffmpeg_paths) == 0:
            # logger.error("No valid ffmpeg found.")
            ffmpeg_path = None
        elif len(ffmpeg_paths) == 1:
            # Evaluation of suitability isn't required, can take sole option
            # to reduce startup time
            ffmpeg_path = ffmpeg_paths[0]
        else:
            ffmpeg_path = max(ffmpeg_paths, key=ffmpeg_suitability)
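For reference, this is how ffmpeg_suitability would score a hypothetical `-version` banner (the banner text below is made up for illustration, not part of the commit): feature substrings add fixed points, and the three digits after "2000-2" in the copyright line contribute the compile year's tail, so newer builds with more codecs win ties.

sample = "ffmpeg version 6.0 --enable-libvpx --enable-libx264 Copyright (c) 2000-2023"
score = sum(points for token, points in
            [("libvpx", 20), ("264", 10), ("265", 3), ("svtav1", 5), ("libopus", 1)]
            if token in sample)                 # libvpx (20) + "264" in libx264 (10)
idx = sample.find('2000-2')
score += int(sample[idx + 6:idx + 9])           # "023" from 2023 -> 23
assert score == 53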
186	custom_nodes/MemedeckComfyNodes/modules/video_model.py
Normal file

@@ -0,0 +1,186 @@
import math

import torch
from torch import nn

import comfy.ldm.modules.attention
import comfy.ldm.common_dit

from comfy.ldm.lightricks.model import apply_rotary_emb, precompute_freqs_cis, LTXVModel, BasicTransformerBlock


class GenesisModifiedCrossAttention(nn.Module):
    def forward(self, x, context=None, mask=None, pe=None, transformer_options={}):
        context = x if context is None else context
        context_v = x if context is None else context

        # Attention-bank bookkeeping: on the forward pass, cache activations
        # for the configured blocks; on the reverse pass, inject them back
        # into q/k/v depending on inject_settings.
        step = transformer_options.get('step', -1)
        total_steps = transformer_options.get('total_steps', 0)
        attn_bank = transformer_options.get('attn_bank', None)
        sample_mode = transformer_options.get('sample_mode', None)
        if attn_bank is not None and self.idx in attn_bank['block_map']:
            len_conds = len(transformer_options['cond_or_uncond'])
            pred_order = transformer_options['pred_order']
            if sample_mode == 'forward' and total_steps - step - 1 < attn_bank['save_steps']:
                step_idx = f'{pred_order}_{total_steps - step - 1}'
                attn_bank['block_map'][self.idx][step_idx] = x.cpu()
            elif sample_mode == 'reverse' and step < attn_bank['inject_steps']:
                step_idx = f'{pred_order}_{step}'
                inject_settings = attn_bank.get('inject_settings', {})
                if len(inject_settings) > 0:
                    inj = attn_bank['block_map'][self.idx][step_idx].to(x.device).repeat(len_conds, 1, 1)
                    if 'q' in inject_settings:
                        x = inj
                    if 'k' in inject_settings:
                        context = inj
                    if 'v' in inject_settings:
                        context_v = inj

        q = self.to_q(x)
        k = self.to_k(context)
        v = self.to_v(context_v)

        q = self.q_norm(q)
        k = self.k_norm(k)

        if pe is not None:
            q = apply_rotary_emb(q, pe)
            k = apply_rotary_emb(k, pe)

        alt_attn_fn = transformer_options.get('patches_replace', {}).get('layer', {}).get(('self_attn', self.idx), None)
        if alt_attn_fn is not None:
            out = alt_attn_fn(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options)
        elif mask is None:
            out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision)
        else:
            out = comfy.ldm.modules.attention.optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision)
        return self.to_out(out)


class GenesisModifiedBasicTransformerBlock(BasicTransformerBlock):
    def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}):
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2)
        x += self.attn1(comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe, transformer_options=transformer_options) * gate_msa

        x += self.attn2(x, context=context, mask=attention_mask)

        y = comfy.ldm.common_dit.rms_norm(x) * (1 + scale_mlp) + shift_mlp
        x += self.ff(y) * gate_mlp

        return x


class GenesisModelModified(LTXVModel):

    def forward(self, x, timestep, context, attention_mask, frame_rate=25, guiding_latent=None, guiding_latents={}, transformer_options={}, **kwargs):
        patches_replace = transformer_options.get("patches_replace", {})

        guiding_latents = transformer_options.get('patches', {}).get('guiding_latents', None)

        indices_grid = self.patchifier.get_grid(
            orig_num_frames=x.shape[2],
            orig_height=x.shape[3],
            orig_width=x.shape[4],
            batch_size=x.shape[0],
            scale_grid=((1 / frame_rate) * 8, 32, 32),
            device=x.device,
        )

        ts = None
        input_x = None

        if guiding_latents is not None:
            # Pin the guided frames: zero their timestep and substitute the guide latent.
            input_x = x.clone()
            ts = torch.ones([x.shape[0], 1, x.shape[2], x.shape[3], x.shape[4]], device=x.device, dtype=x.dtype)
            input_ts = timestep.view([timestep.shape[0]] + [1] * (x.ndim - 1))
            ts *= input_ts
            for guide in guiding_latents:
                ts[:, :, guide.index] = 0.0
                x[:, :, guide.index] = guide.latent[:, :, 0]
            timestep = self.patchifier.patchify(ts)

        orig_shape = list(x.shape)

        x = self.patchifier.patchify(x)

        x = self.patchify_proj(x)
        timestep = timestep * 1000.0

        attention_mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1]))
        attention_mask = attention_mask.masked_fill(attention_mask.to(torch.bool), float("-inf"))  # not sure about this
        # attention_mask = (context != 0).any(dim=2).to(dtype=x.dtype)

        pe = precompute_freqs_cis(indices_grid, dim=self.inner_dim, out_dtype=x.dtype)

        batch_size = x.shape[0]
        timestep, embedded_timestep = self.adaln_single(
            timestep.flatten(),
            {"resolution": None, "aspect_ratio": None},
            batch_size=batch_size,
            hidden_dtype=x.dtype,
        )
        # Second dimension is 1 or number of tokens (if timestep_per_token)
        timestep = timestep.view(batch_size, -1, timestep.shape[-1])
        embedded_timestep = embedded_timestep.view(
            batch_size, -1, embedded_timestep.shape[-1]
        )

        # 2. Blocks
        if self.caption_projection is not None:
            batch_size = x.shape[0]
            context = self.caption_projection(context)
            context = context.view(
                batch_size, -1, x.shape[-1]
            )

        blocks_replace = patches_replace.get("dit", {})
        for i, block in enumerate(self.transformer_blocks):
            if ("double_block", i) in blocks_replace:
                def block_wrap(args):
                    out = {}
                    out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"])
                    return out

                out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe}, {"original_block": block_wrap})
                x = out["img"]
            else:
                x = block(
                    x,
                    context=context,
                    attention_mask=attention_mask,
                    timestep=timestep,
                    pe=pe,
                    transformer_options=transformer_options
                )

        # 3. Output
        scale_shift_values = (
            self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + embedded_timestep[:, :, None]
        )
        shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
        x = self.norm_out(x)
        # Modulation
        x = x * (1 + scale) + shift
        x = self.proj_out(x)

        x = self.patchifier.unpatchify(
            latents=x,
            output_height=orig_shape[3],
            output_width=orig_shape[4],
            output_num_frames=orig_shape[2],
            out_channels=orig_shape[1] // math.prod(self.patchifier.patch_size),
        )

        if guiding_latents is not None:
            # Recover the prediction for guided frames from the pinned inputs.
            for guide in guiding_latents:
                x[:, :, guide.index] = (input_x[:, :, guide.index] - guide.latent[:, :, 0]) / input_ts[:, :, 0]

        return x


def inject_model(diffusion_model):
    diffusion_model.__class__ = GenesisModelModified
    for idx, transformer_block in enumerate(diffusion_model.transformer_blocks):
        transformer_block.__class__ = GenesisModifiedBasicTransformerBlock
        transformer_block.idx = idx
        transformer_block.attn1.__class__ = GenesisModifiedCrossAttention
        transformer_block.attn1.idx = idx
    return diffusion_model
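inject_model patches by reassigning `__class__`, so the already-loaded modules keep their weights and buffers but dispatch to the overridden forward(). A toy sketch of the same pattern (illustrative, not part of the commit):

class Base:
    def forward(self):
        return "base"

class Patched(Base):
    def forward(self):
        return "patched"

obj = Base()
obj.__class__ = Patched       # same object and attributes, new method table
assert obj.forward() == "patched"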
390	custom_nodes/MemedeckComfyNodes/nodes_model.py
Normal file

@@ -0,0 +1,390 @@
import json
import math

from comfy_extras.nodes_custom_sampler import Noise_RandomNoise
import latent_preview
from .modules.video_model import inject_model
import folder_paths
import node_helpers
import torch
import comfy.model_management
import comfy.model_patcher
import comfy.sample
import comfy.samplers
import comfy.sd
import comfy.utils


class MD_LoadVideoModel:
    """
    Loads the DIT model for video generation.
    """
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "chkpt_name": (folder_paths.get_filename_list("checkpoints"), {
                    "default": "genesis-dit-video-2b.safetensors",
                    "tooltip": "The name of the checkpoint (model) to load."
                }),
                "clip_name": (folder_paths.get_filename_list("text_encoders"), {
                    "default": "t5xxl_fp16.safetensors",
                    "tooltip": "The name of the CLIP (text encoder) model to load."
                }),
            },
        }

    RETURN_TYPES = ("MODEL", "CLIP", "VAE")
    RETURN_NAMES = ("model", "clip", "vae")
    FUNCTION = "load_model"
    CATEGORY = "MemeDeck"

    def load_model(self, chkpt_name, clip_name):
        ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", chkpt_name)
        out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
        model = out[0]
        vae = out[2]

        clip_path = folder_paths.get_full_path_or_raise("text_encoders", clip_name)
        clip = comfy.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=8)

        # modify model
        model.model.diffusion_model = inject_model(model.model.diffusion_model)

        return (model, clip, vae, )


class LatentGuide(torch.nn.Module):
    def __init__(self, latent: torch.Tensor, index) -> None:
        super().__init__()
        self.index = index
        self.register_buffer('latent', latent)


class MD_ImgToVideo:
    """
    Sets the conditioning and dimensions for video generation.
    """
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "model": ("MODEL",),
                "positive": ("CONDITIONING",),
                "negative": ("CONDITIONING",),
                "vae": ("VAE",),
                "image": ("IMAGE",),
                "width": ("INT", {
                    "default": 832,
                    "description": "The width of the video."
                }),
                "height": ("INT", {
                    "default": 832,
                    "description": "The height of the video."
                }),
                "length": ("INT", {
                    "default": 97,
                    "description": "The length of the video in frames."
                }),
                "fps": ("INT", {
                    "default": 24,
                    "description": "The fps of the video."
                }),
                # LATENT GUIDE INPUTS
                "add_latent_guide_index": ("INT", {
                    "default": 0,
                    "description": "The frame index at which to place the guide latent."
                }),
                "add_latent_guide_insert": ("BOOLEAN", {
                    "default": False,
                    "description": "Insert the guide latent at the index instead of replacing the frame."
                }),
                # SCHEDULER INPUTS
                "steps": ("INT", {
                    "default": 40,
                    "description": "Number of steps to generate the video."
                }),
                "max_shift": ("FLOAT", {
                    "default": 1.5,
                    "step": 0.01,
                    "description": "The maximum sigma shift."
                }),
                "base_shift": ("FLOAT", {
                    "default": 0.95,
                    "step": 0.01,
                    "description": "The base sigma shift."
                }),
                "stretch": ("BOOLEAN", {
                    "default": True,
                    "description": "Stretch the sigmas to be in the range [terminal, 1]."
                }),
                "terminal": ("FLOAT", {
                    "default": 0.1,
                    "description": "The terminal value of the sigmas after stretching."
                }),
                # ATTENTION OVERRIDE INPUTS
                "attention_override": ("STRING", {
                    "default": "14",
                    "description": "Comma-separated indices of the transformer layers whose self-attention is overridden."
                }),
                "attention_adjustment_scale": ("FLOAT", {
                    "default": 1.0,
                    "description": "The scale of the attention adjustment."
                }),
                "attention_adjustment_rescale": ("FLOAT", {
                    "default": 0.5,
                    "description": "The rescale factor of the attention adjustment."
                }),
                "attention_adjustment_cfg": ("FLOAT", {
                    "default": 3.0,
                    "description": "The CFG scale used when recombining the attention-adjusted prediction."
                }),
            },
        }

    RETURN_TYPES = ("MODEL", "CONDITIONING", "CONDITIONING", "SIGMAS", "LATENT", "STRING")
    RETURN_NAMES = ("model", "positive", "negative", "sigmas", "latent", "img2vid_metadata")
    FUNCTION = "img_to_video"
    CATEGORY = "MemeDeck"

    def img_to_video(self, model, positive, negative, vae, image, width, height, length, fps, add_latent_guide_index, add_latent_guide_insert, steps, max_shift, base_shift, stretch, terminal, attention_override, attention_adjustment_scale, attention_adjustment_rescale, attention_adjustment_cfg):
        batch_size = 1

        # 1. encode the input image and seed the video latent with it
        pixels = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
        encode_pixels = pixels[:, :, :, :3]
        t = vae.encode(encode_pixels)
        positive = node_helpers.conditioning_set_values(positive, {"guiding_latent": t})
        negative = node_helpers.conditioning_set_values(negative, {"guiding_latent": t})

        latent = torch.zeros([batch_size, 128, ((length - 1) // 8) + 1, height // 32, width // 32], device=comfy.model_management.intermediate_device())
        latent[:, :, :t.shape[2]] = t
        latent_samples = {"samples": latent}

        positive = node_helpers.conditioning_set_values(positive, {"frame_rate": fps})
        negative = node_helpers.conditioning_set_values(negative, {"frame_rate": fps})

        # 2. add latent guide
        model, latent_updated = self.add_latent_guide(model, latent_samples, latent_samples, add_latent_guide_index, add_latent_guide_insert)

        # 3. apply attention override
        attn_override_layers = self.attention_override(attention_override)
        model = self.apply_attention_override(model, attention_adjustment_scale, attention_adjustment_rescale, attention_adjustment_cfg, attn_override_layers)

        # 4. configure scheduler
        sigmas = self.get_sigmas(steps, max_shift, base_shift, stretch, terminal, latent_updated)

        # all parameters starting with width, height, fps, etc.
        img2vid_metadata = {
            "width": width,
            "height": height,
            "length": length,
            "fps": fps,
            "steps": steps,
            "max_shift": max_shift,
            "base_shift": base_shift,
            "stretch": stretch,
            "terminal": terminal,
            "attention_override": attention_override,
            "attention_adjustment_scale": attention_adjustment_scale,
            "attention_adjustment_rescale": attention_adjustment_rescale,
            "attention_adjustment_cfg": attention_adjustment_cfg,
        }

        json_img2vid_metadata = json.dumps(img2vid_metadata)
        return (model, positive, negative, sigmas, latent_updated, json_img2vid_metadata)
    # -----------------------------
    # Attention functions
    # -----------------------------
    # 1. Add latent guide
    def add_latent_guide(self, model, latent, image_latent, index, insert):
        image_latent = image_latent['samples']
        latent = latent['samples'].clone()

        if insert:
            # Handle insertion
            if index == 0:
                # Insert at beginning
                latent = torch.cat([image_latent[:, :, 0:1], latent], dim=2)
            elif index >= latent.shape[2] or index < 0:
                # Append to end
                latent = torch.cat([latent, image_latent[:, :, 0:1]], dim=2)
            else:
                # Insert in middle
                latent = torch.cat([
                    latent[:, :, :index],
                    image_latent[:, :, 0:1],
                    latent[:, :, index:]
                ], dim=2)
        else:
            # Original replacement behavior
            latent[:, :, index] = image_latent[:, :, 0]

        model = model.clone()
        guiding_latent = LatentGuide(image_latent, index)
        model.set_model_patch(guiding_latent, 'guiding_latents')

        return (model, {"samples": latent},)

    # 2. Apply attention override
    def is_integer(self, string):
        try:
            int(string)
            return True
        except ValueError:
            return False

    def attention_override(self, layers: str = "14"):
        layers_map = set()
        for block in layers.split(','):
            block = block.strip()
            if self.is_integer(block):
                layers_map.add(block)

        return layers_map

    def apply_attention_override(self, model, scale, rescale, cfg, attention_override: set):
        m = model.clone()

        def pag_fn(q, k, v, heads, attn_precision=None, transformer_options=None):
            # Perturbed attention: return the value projection directly,
            # i.e. attention with an identity attention map.
            return v

        def post_cfg_function(args):
            model = args["model"]

            cond_pred = args["cond_denoised"]
            uncond_pred = args["uncond_denoised"]

            len_conds = 1 if args.get('uncond', None) is None else 2

            cond = args["cond"]
            sigma = args["sigma"]
            model_options = args["model_options"].copy()
            x = args["input"]

            if scale == 0:
                if len_conds == 1:
                    return cond_pred
                return uncond_pred + (cond_pred - uncond_pred)

            for block_idx in attention_override:
                model_options = comfy.model_patcher.set_model_options_patch_replace(model_options, pag_fn, "layer", "self_attn", int(block_idx))

            (perturbed,) = comfy.samplers.calc_cond_batch(model, [cond], x, sigma, model_options)

            output = uncond_pred + cfg * (cond_pred - uncond_pred) \
                + scale * (cond_pred - perturbed)
            if rescale > 0:
                factor = cond_pred.std() / output.std()
                factor = rescale * factor + (1 - rescale)
                output = output * factor

            return output

        m.set_model_sampler_post_cfg_function(post_cfg_function)

        return m
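The post-CFG mix above combines standard classifier-free guidance with a perturbed-attention term, then optionally rescales toward the conditional prediction's spread. A numeric sketch with made-up values (illustrative, not part of the commit):

import torch

cond = torch.tensor([1.0, 0.0])
uncond = torch.tensor([0.2, 0.0])
perturbed = torch.tensor([0.7, 0.0])     # hypothetical identity-attention prediction
cfg, scale, rescale = 3.0, 1.0, 0.5

output = uncond + cfg * (cond - uncond) + scale * (cond - perturbed)
# first component: 0.2 + 3.0 * 0.8 + 1.0 * 0.3 = 2.9

factor = cond.std() / output.std()       # shrink toward cond's standard deviation
factor = rescale * factor + (1 - rescale)
output = output * factor                 # tempers the amplification from stacked guidance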
    # -----------------------------
    # Scheduler
    # -----------------------------
    def get_sigmas(self, steps, max_shift, base_shift, stretch, terminal, latent=None):
        if latent is None:
            tokens = 4096
        else:
            tokens = math.prod(latent["samples"].shape[2:])

        sigmas = torch.linspace(1.0, 0.0, steps + 1)

        # Linearly interpolate the shift between base_shift (at 1024 tokens)
        # and max_shift (at 4096 tokens).
        x1 = 1024
        x2 = 4096
        mm = (max_shift - base_shift) / (x2 - x1)
        b = base_shift - mm * x1
        sigma_shift = tokens * mm + b

        power = 1
        sigmas = torch.where(
            sigmas != 0,
            math.exp(sigma_shift) / (math.exp(sigma_shift) + (1 / sigmas - 1) ** power),
            0,
        )

        # Stretch sigmas so that the final value matches the given terminal value.
        if stretch:
            non_zero_mask = sigmas != 0
            non_zero_sigmas = sigmas[non_zero_mask]
            one_minus_z = 1.0 - non_zero_sigmas
            scale_factor = one_minus_z[-1] / (1.0 - terminal)
            stretched = 1.0 - (one_minus_z / scale_factor)
            sigmas[non_zero_mask] = stretched

        return sigmas
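To make the shift concrete, a worked example under the node defaults (illustrative, not part of the commit): a 97-frame 832x832 request yields a latent of shape (1, 128, 13, 26, 26), so get_sigmas sees tokens = 13 * 26 * 26 = 8788 and interpolates the shift between the two anchor points.

import math

tokens = 13 * 26 * 26                       # 8788, from latent["samples"].shape[2:]
mm = (1.5 - 0.95) / (4096 - 1024)           # slope between the anchor points
b = 0.95 - mm * 1024
sigma_shift = tokens * mm + b               # ~2.34 for this resolution/length

# each nonzero sigma is then remapped; e.g. the mid-schedule value 0.5:
sigma = 0.5
shifted = math.exp(sigma_shift) / (math.exp(sigma_shift) + (1 / sigma - 1))
# exp(2.34) / (exp(2.34) + 1) ~ 0.91, so larger latents spend more steps at high noise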
KSAMPLER_NAMES = ["euler", "ddim", "euler_ancestral", "euler_cfg_pp", "euler_ancestral_cfg_pp", "heun", "heunpp2", "dpm_2", "dpm_2_ancestral",
                  "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
                  "dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
                  "ipndm", "ipndm_v", "deis"]


class MD_VideoSampler:
    """
    Samples the video.
    """
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "model": ("MODEL",),
                "positive": ("CONDITIONING",),
                "negative": ("CONDITIONING",),
                "sigmas": ("SIGMAS",),
                "latent_image": ("LATENT",),
                "sampler": (KSAMPLER_NAMES, ),
                "noise_seed": ("INT", {
                    "default": 42,
                    "description": "The seed of the noise."
                }),
                "cfg": ("FLOAT", {
                    "default": 5.0,
                    "min": 0.0,
                    "max": 30.0,
                    "step": 0.01,
                    "description": "The CFG scale of the video."
                }),
            },
        }

    RETURN_TYPES = ("LATENT", "LATENT", "STRING")
    RETURN_NAMES = ("output", "denoised_output", "sampler_metadata")
    FUNCTION = "video_sampler"
    CATEGORY = "MemeDeck"

    def video_sampler(self, model, positive, negative, sigmas, latent_image, sampler, noise_seed, cfg):
        latent = latent_image
        latent_image = latent["samples"]
        latent = latent.copy()
        latent_image = comfy.sample.fix_empty_latent_channels(model, latent_image)
        latent["samples"] = latent_image

        sampler_name = sampler
        noise = Noise_RandomNoise(noise_seed).generate_noise(latent)
        sampler = comfy.samplers.sampler_object(sampler)

        noise_mask = None
        if "noise_mask" in latent:
            noise_mask = latent["noise_mask"]

        x0_output = {}
        callback = latent_preview.prepare_callback(model, sigmas.shape[-1] - 1, x0_output)

        disable_pbar = not comfy.utils.PROGRESS_BAR_ENABLED
        samples = comfy.sample.sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=noise_seed)

        out = latent.copy()
        out["samples"] = samples
        if "x0" in x0_output:
            out_denoised = latent.copy()
            out_denoised["samples"] = model.model.process_latent_out(x0_output["x0"].cpu())
        else:
            out_denoised = out

        sampler_metadata = {
            "sampler": sampler_name,
            "noise_seed": noise_seed,
            "cfg": cfg,
        }

        json_sampler_metadata = json.dumps(sampler_metadata)
        return (out, out_denoised, json_sampler_metadata)
142	custom_nodes/MemedeckComfyNodes/nodes_output.py
Normal file

@@ -0,0 +1,142 @@
import folder_paths
from comfy.cli_args import args

from PIL import Image
from PIL.PngImagePlugin import PngInfo

import cv2
import numpy as np
import json
import os


class MD_SaveAnimatedWEBP:
    def __init__(self):
        self.output_dir = folder_paths.get_output_directory()
        self.type = "output"
        self.prefix_append = ""

    methods = {"default": 4, "fastest": 0, "slowest": 6}

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "images": ("IMAGE", ),
                "filename_prefix": ("STRING", {"default": "memedeck_video"}),
                "fps": ("FLOAT", {"default": 24.0, "min": 0.01, "max": 1000.0, "step": 0.01}),
                "lossless": ("BOOLEAN", {"default": False}),
                "quality": ("INT", {"default": 90, "min": 0, "max": 100}),
                "method": (list(s.methods.keys()),),
                "crf": ("INT",),
                "motion_prompt": ("STRING", ),
                "negative_prompt": ("STRING", ),
                "img2vid_metadata": ("STRING", ),
                "sampler_metadata": ("STRING", ),
            },
            # "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
        }

    RETURN_TYPES = ()
    FUNCTION = "save_images"

    OUTPUT_NODE = True

    CATEGORY = "MemeDeck"

    def save_images(self, images, fps, filename_prefix, lossless, quality, method, crf=None, motion_prompt=None, negative_prompt=None, img2vid_metadata=None, sampler_metadata=None):
        method = self.methods.get(method)
        filename_prefix += self.prefix_append
        full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0])
        results = list()

        pil_images = []
        for image in images:
            i = 255. * image.cpu().numpy()
            img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
            pil_images.append(img)

        metadata = pil_images[0].getexif()
        num_frames = len(pil_images)

        json_metadata = json.dumps({
            "crf": crf,
            "motion_prompt": motion_prompt,
            "negative_prompt": negative_prompt,
            "img2vid_metadata": img2vid_metadata,
            "sampler_metadata": sampler_metadata,
        }, indent=4)

        c = len(pil_images)
        for i in range(0, c, num_frames):
            file = f"{filename}_{counter:05}_.webp"
            pil_images[i].save(os.path.join(full_output_folder, file), save_all=True, duration=int(1000.0 / fps), append_images=pil_images[i + 1:i + num_frames], exif=metadata, lossless=lossless, quality=quality, method=method)
            results.append({
                "filename": file,
                "subfolder": subfolder,
                "type": self.type,
            })
            counter += 1

        animated = num_frames != 1

        return { "ui": { "images": results, "animated": (animated,), "metadata": json_metadata } }
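One subtlety in save_images above (an observation, not part of the commit): Pillow stores WEBP frame duration as integer milliseconds, so int(1000.0 / fps) truncates and nudges the effective playback rate slightly upward for fps values that don't divide 1000 evenly.

fps = 24.0
duration_ms = int(1000.0 / fps)         # 41, not 41.666...
effective_fps = 1000.0 / duration_ms    # ~24.39 fps on playback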
class MD_SaveMP4:
    def __init__(self):
        # Get absolute path of the output directory
        self.output_dir = os.path.abspath("output/video_gen")
        self.type = "output"
        self.prefix_append = ""

    methods = {"default": 4, "fastest": 0, "slowest": 6}

    @classmethod
    def INPUT_TYPES(s):
        return {"required":
                    {"images": ("IMAGE", ),
                     "filename_prefix": ("STRING", {"default": "ComfyUI"}),
                     "fps": ("FLOAT", {"default": 24.0, "min": 0.01, "max": 1000.0, "step": 0.01}),
                     "quality": ("INT", {"default": 80, "min": 0, "max": 100}),
                     },
                "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
                }

    RETURN_TYPES = ()
    FUNCTION = "save_video"

    OUTPUT_NODE = True

    CATEGORY = "MemeDeck"

    def save_video(self, images, fps, filename_prefix, quality, prompt=None, extra_pnginfo=None):
        filename_prefix += self.prefix_append
        full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(
            filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0]
        )
        results = list()
        video_path = os.path.join(full_output_folder, f"{filename}_{counter:05}.mp4")

        # Determine video resolution; each frame tensor is (H, W, C)
        height, width = images[0].shape[0], images[0].shape[1]
        video_writer = cv2.VideoWriter(
            video_path,
            cv2.VideoWriter_fourcc(*'mp4v'),
            fps,
            (width, height)
        )

        # Write each frame to the video
        for image in images:
            i = 255. * image.cpu().numpy()
            frame = np.clip(i, 0, 255).astype(np.uint8)
            frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)  # Convert RGB to BGR for OpenCV
            video_writer.write(frame)

        video_writer.release()

        results.append({
            "filename": os.path.basename(video_path),
            "subfolder": subfolder,
            "type": self.type
        })

        return {"ui": {"videos": results}}
276	custom_nodes/MemedeckComfyNodes/nodes_preprocessing.py
Normal file

@@ -0,0 +1,276 @@
import io
import logging
import os
import subprocess
import sys
import time
from pathlib import Path
from typing import Tuple

import cv2
import numpy as np
import requests
import torch
import torchvision.transforms as transforms
from PIL import Image, ImageOps

import folder_paths
from .lib import image, utils
from .lib.image import pil2tensor

# setup logger
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


class MD_LoadImageFromUrl:
    """Load an image from the given URL"""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "url": (
                    "STRING",
                    {
                        "default": "https://media.memedeck.xyz/memes/user:08bdc8ed_6015_44f2_9808_7cb54051c666/35c95dfd_b186_4a40_9ef1_ac770f453706.jpeg"
                    },
                ),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "load"
    CATEGORY = "MemeDeck"

    def load(self, url):
        # strip out any quote characters
        url = url.replace("'", "")
        url = url.replace('"', '')

        img = Image.open(requests.get(url, stream=True).raw)
        img = ImageOps.exif_transpose(img)
        return (pil2tensor(img),)


class MD_ImageToMotionPrompt:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "Image": ("IMAGE", {}),
                "clip": ("CLIP", {"tooltip": "The CLIP model used for encoding the text."}),
                "pre_prompt": (
                    "STRING",
                    {
                        "multiline": False,
                        "default": "masterpiece, 4k, HDR, cinematic,",
                    },
                ),
                "prompt": (
                    "STRING",
                    {
                        "multiline": True,
                        "default": "Respond in a single flowing paragraph. Start with the main action in a single sentence. Then add specific details about movements and gestures. Then describe character/object appearances precisely. After that, specify camera angles and movements, static camera motion, or minimal camera motion. Then describe lighting and colors.\nNo more than 200 words.\nAdditional instructions:",
                    },
                ),
                "negative_prompt": (
                    "STRING",
                    {
                        "multiline": True,
                        "default": "low quality, worst quality, deformed, distorted, disfigured, motion smear, motion artifacts, unnatural motion, fused fingers, extra limbs, floating away, bad anatomy, weird hand, ugly, disappearing objects, closed captions, cross-eyed",
                    },
                ),
                "max_tokens": ("INT", {"min": 1, "max": 2048, "default": 200}),
            }
        }

    RETURN_TYPES = ("STRING", "STRING", "CONDITIONING", "CONDITIONING",)
    RETURN_NAMES = ("prompt_string", "negative_prompt", "positive_conditioning", "negative_conditioning")
    FUNCTION = "generate_completion"
    CATEGORY = "MemeDeck"

    def generate_completion(
        self, pre_prompt: str, Image: torch.Tensor, clip, prompt: str, negative_prompt: str, max_tokens: int
    ) -> Tuple[str]:
        # start a timer
        start_time = time.time()
        b64image = image.pil2base64(image.tensor2pil(Image))

        # call the local inference endpoint with the image and the prompt
        response = requests.post("http://127.0.0.1:5010/inference", json={"image_url": f"data:image/jpeg;base64,{b64image}", "prompt": prompt})
        if response.status_code != 200:
            raise Exception(f"Failed to generate completion: {response.text}")
        end_time = time.time()

        logger.info(f"Motion prompt took: {end_time - start_time} seconds")
        full_prompt = f"{pre_prompt}\n{response.json()['result']}"

        pos_tokens = clip.tokenize(full_prompt)
        pos_output = clip.encode_from_tokens(pos_tokens, return_pooled=True, return_dict=True)
        pos_cond = pos_output.pop("cond")

        neg_tokens = clip.tokenize(negative_prompt)
        neg_output = clip.encode_from_tokens(neg_tokens, return_pooled=True, return_dict=True)
        neg_cond = neg_output.pop("cond")

        return (full_prompt, negative_prompt, [[pos_cond, pos_output]], [[neg_cond, neg_output]])


class MD_CompressAdjustNode:
    """
    Detect compression level and adjust to the desired CRF.
    """
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "image": ("IMAGE",),
                "desired_crf": ("INT", {
                    "default": 25,
                    "min": 0,
                    "max": 51,
                    "step": 1
                }),
            },
        }

    RETURN_TYPES = ("IMAGE", "INT")
    RETURN_NAMES = ("adjusted_image", "crf")
    FUNCTION = "tensor_to_video_and_back"
    CATEGORY = "MemeDeck"

    def tensor_to_int(self, tensor, bits):
        tensor = tensor.cpu().numpy() * (2**bits - 1)
        return np.clip(tensor, 0, (2**bits - 1))

    def tensor_to_bytes(self, tensor):
        return self.tensor_to_int(tensor, 8).astype(np.uint8)

    def ffmpeg_process(self, args, file_path, env):
        res = None
        frame_data = yield
        total_frames_output = 0
        if res != b'':
            with subprocess.Popen(args + [file_path], stderr=subprocess.PIPE,
                                  stdin=subprocess.PIPE, env=env) as proc:
                try:
                    while frame_data is not None:
                        proc.stdin.write(frame_data)
                        frame_data = yield
                        total_frames_output += 1
                    proc.stdin.flush()
                    proc.stdin.close()
                    res = proc.stderr.read()
                except BrokenPipeError:
                    res = proc.stderr.read()
                    raise Exception("An error occurred in the ffmpeg subprocess:\n"
                                    + res.decode("utf-8"))
        yield total_frames_output
        if len(res) > 0:
            print(res.decode("utf-8"), end="", file=sys.stderr)

    def detect_image_clarity(self, image):
        # detect the clarity of the image
        # return a score between 0 and 100:
        # 0 is the lowest clarity, 100 is the highest clarity
        return 100

    def tensor_to_video_and_back(self, image, desired_crf=30):
        temp_dir = "temp_video"
        filename = f"frame_{time.time()}".split('.')[0]
        os.makedirs(temp_dir, exist_ok=True)

        # Convert single image to list if necessary
        if len(image.shape) == 3:
            image = [image]

        first_image = image[0]

        has_alpha = first_image.shape[-1] == 4
        dim_alignment = 8
        if (first_image.shape[1] % dim_alignment) or (first_image.shape[0] % dim_alignment):
            # pad the image to the nearest multiple of 8
            to_pad = (-first_image.shape[1] % dim_alignment,
                      -first_image.shape[0] % dim_alignment)
            padding = (to_pad[0] // 2, to_pad[0] - to_pad[0] // 2,
                       to_pad[1] // 2, to_pad[1] - to_pad[1] // 2)
            padfunc = torch.nn.ReplicationPad2d(padding)

            def pad(image):
                image = image.permute((2, 0, 1))  # HWC to CHW
                padded = padfunc(image.to(dtype=torch.float32))
                return padded.permute((1, 2, 0))

            # pad single image
            first_image = pad(first_image)
            new_dims = (-first_image.shape[1] % dim_alignment + first_image.shape[1],
                        -first_image.shape[0] % dim_alignment + first_image.shape[0])
            dimensions = f"{new_dims[0]}x{new_dims[1]}"
            logger.warning("Output images were not of valid resolution and have had padding applied")
        else:
            dimensions = f"{first_image.shape[1]}x{first_image.shape[0]}"

        first_image_bytes = self.tensor_to_bytes(first_image).tobytes()

        if has_alpha:
            i_pix_fmt = 'rgba'
        else:
            i_pix_fmt = 'rgb24'

        # default frame rate
        frame_rate = 25
        args = [
            utils.ffmpeg_path,
            "-v", "error",
            "-f", "rawvideo",
            "-pix_fmt", i_pix_fmt,
            "-s", dimensions,
            "-r", str(frame_rate),
            "-i", "-",
            "-y",
            "-c:v", "libx264",
            "-pix_fmt", "yuv420p",
            "-crf", str(desired_crf),
        ]

        video_path = os.path.abspath(str(Path(temp_dir) / f"{filename}.mp4"))
        env = os.environ.copy()
        output_process = self.ffmpeg_process(args, video_path, env)

        # Proceed to first yield
        output_process.send(None)
        output_process.send(first_image_bytes)
        try:
            output_process.send(None)  # Signal end of input
            next(output_process)  # Get the final yield
        except StopIteration:
            pass

        time.sleep(0.5)

        if not os.path.exists(video_path):
            raise FileNotFoundError(f"Video file not created at {video_path}")

        # load the video with the h264 codec
        video = cv2.VideoCapture(video_path, cv2.CAP_FFMPEG)
        if not video.isOpened():
            raise RuntimeError(f"Failed to open video file: {video_path}")

        # read the first frame
        ret, frame = video.read()
        if not ret:
            raise RuntimeError("Failed to read frame from video")

        video.release()
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

        try:
            os.remove(video_path)
        except OSError as e:
            print(f"Warning: Could not remove temporary file {video_path}: {e}")

        # convert the frame to a PIL image for ComfyUI
        frame = Image.fromarray(frame)
        frame_tensor = pil2tensor(frame)
        return (frame_tensor, desired_crf)
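A note on the padding arithmetic in tensor_to_video_and_back above (illustrative): Python's modulo on a negated operand returns exactly the distance to the next multiple, which is why `-n % 8` gives the pad amount directly.

for n in (832, 833, 838):
    pad = -n % 8                 # 0, 7, 2 respectively
    assert (n + pad) % 8 == 0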
6	custom_nodes/MemedeckComfyNodes/requirements.txt
Normal file

@@ -0,0 +1,6 @@
requests
numpy
torch
Pillow
opencv-python
torchvision
155	custom_nodes/example_node.py.example
Normal file

@@ -0,0 +1,155 @@
class Example:
    """
    An example node

    Class methods
    -------------
    INPUT_TYPES (dict):
        Tells the main program the input parameters of the node.
    IS_CHANGED:
        Optional method to control when the node is re-executed.

    Attributes
    ----------
    RETURN_TYPES (`tuple`):
        The type of each element in the output tuple.
    RETURN_NAMES (`tuple`):
        Optional: The name of each output in the output tuple.
    FUNCTION (`str`):
        The name of the entry-point method. For example, if `FUNCTION = "execute"` then it will run Example().execute()
    OUTPUT_NODE ([`bool`]):
        If this node is an output node that outputs a result/image from the graph. The SaveImage node is an example.
        The backend iterates on these output nodes and tries to execute all their parents if their parent graph is properly connected.
        Assumed to be False if not present.
    CATEGORY (`str`):
        The category the node should appear in the UI.
    DEPRECATED (`bool`):
        Indicates whether the node is deprecated. Deprecated nodes are hidden by default in the UI, but remain
        functional in existing workflows that use them.
    EXPERIMENTAL (`bool`):
        Indicates whether the node is experimental. Experimental nodes are marked as such in the UI and may be subject to
        significant changes or removal in future versions. Use with caution in production workflows.
    execute(s) -> tuple || None:
        The entry-point method. The name of this method must be the same as the value of property `FUNCTION`.
        For example, if `FUNCTION = "execute"` then this method's name must be `execute`, if `FUNCTION = "foo"` then it must be `foo`.
    """
    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(s):
        """
        Return a dictionary which contains config for all input fields.
        Some types (string): "MODEL", "VAE", "CLIP", "CONDITIONING", "LATENT", "IMAGE", "INT", "STRING", "FLOAT".
        Input types "INT", "STRING" or "FLOAT" are special values for fields on the node.
        The type can be a list for selection.

        Returns: `dict`:
            - Key input_fields_group (`string`): Can be either required, hidden or optional. A node class must have property `required`
            - Value input_fields (`dict`): Contains input fields config:
                * Key field_name (`string`): Name of an entry-point method's argument
                * Value field_config (`tuple`):
                    + First value is a string indicating the type of field or a list for selection.
                    + Second value is a config for type "INT", "STRING" or "FLOAT".
        """
        return {
            "required": {
                "image": ("IMAGE",),
                "int_field": ("INT", {
                    "default": 0,
                    "min": 0,  # Minimum value
                    "max": 4096,  # Maximum value
                    "step": 64,  # Slider's step
                    "display": "number",  # Cosmetic only: display as "number" or "slider"
                    "lazy": True  # Will only be evaluated if check_lazy_status requires it
                }),
                "float_field": ("FLOAT", {
                    "default": 1.0,
                    "min": 0.0,
                    "max": 10.0,
                    "step": 0.01,
                    "round": 0.001,  # The precision to round to; set to the step value by default. Can be set to False to disable rounding.
                    "display": "number",
                    "lazy": True
                }),
                "print_to_screen": (["enable", "disable"],),
                "string_field": ("STRING", {
                    "multiline": False,  # True if you want the field to look like the one on the ClipTextEncode node
                    "default": "Hello World!",
                    "lazy": True
                }),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    #RETURN_NAMES = ("image_output_name",)

    FUNCTION = "test"

    #OUTPUT_NODE = False

    CATEGORY = "Example"

    def check_lazy_status(self, image, string_field, int_field, float_field, print_to_screen):
        """
        Return a list of input names that need to be evaluated.

        This function will be called if there are any lazy inputs which have not yet been
        evaluated. As long as you return at least one field which has not yet been evaluated
        (and more exist), this function will be called again once the value of the requested
        field is available.

        Any evaluated inputs will be passed as arguments to this function. Any unevaluated
        inputs will have the value None.
        """
        if print_to_screen == "enable":
            return ["int_field", "float_field", "string_field"]
        else:
            return []

    def test(self, image, string_field, int_field, float_field, print_to_screen):
        if print_to_screen == "enable":
            print(f"""Your input contains:
                string_field aka input text: {string_field}
                int_field: {int_field}
                float_field: {float_field}
            """)
        # do some processing on the image; in this example we just invert it
        image = 1.0 - image
        return (image,)

    """
    The node will always be re-executed if any of the inputs change, but
    this method can be used to force the node to execute again even when the inputs don't change.
    You can make this node return a number or a string. This value will be compared to the one returned the last time the node was
    executed; if it is different, the node will be executed again.
    This method is used in the core repo for the LoadImage node, where they return the image hash as a string: if the image hash
    changes between executions, the LoadImage node is executed again.
    """
    #@classmethod
    #def IS_CHANGED(s, image, string_field, int_field, float_field, print_to_screen):
    #    return ""


# Set the web directory; any .js file in that directory will be loaded by the frontend as a frontend extension
# WEB_DIRECTORY = "./somejs"


# Add custom API routes, using router
from aiohttp import web
from server import PromptServer


@PromptServer.instance.routes.get("/hello")
async def get_hello(request):
    return web.json_response("hello")


# A dictionary that contains all nodes you want to export with their names
# NOTE: names should be globally unique
NODE_CLASS_MAPPINGS = {
    "Example": Example
}

# A dictionary that contains the friendly/humanly readable titles for the nodes
NODE_DISPLAY_NAME_MAPPINGS = {
    "Example": "Example Node"
}
Binary file not shown (before: 8.4 KiB).
480
memedeck.py
480
memedeck.py
@ -5,7 +5,6 @@ import logging
|
||||
import uuid
|
||||
from PIL import Image, ImageOps
|
||||
from functools import partial
|
||||
|
||||
import pika
|
||||
import json
|
||||
|
||||
@ -63,6 +62,13 @@ class MemedeckWorker:
|
||||
self.api_key = os.getenv('API_KEY') or 'eb46e20a-cc25-4ed4-a39b-f47ca8ff3383'
|
||||
|
||||
self.training_only = os.getenv('TRAINING_ONLY') or False
|
||||
self.video_gen_only = False
|
||||
|
||||
self.azure_storage = MemedeckAzureStorage()
|
||||
|
||||
if self.queue_name == 'video-gen-queue':
|
||||
print(f"[memedeck]: video gen only mode enabled")
|
||||
self.video_gen_only = True
|
||||
|
||||
if self.training_only:
|
||||
self.queue_name = 'training-queue'
|
||||
@ -145,10 +151,35 @@ class MemedeckWorker:
|
||||
valid = self.validate_prompt(prompt)
|
||||
|
||||
routing_key = method.routing_key
|
||||
workflow = 'training' if self.training_only else 'faceswap' if routing_key == 'faceswap-queue' else 'generation'
|
||||
workflow = 'faceswap' if routing_key == 'faceswap-queue' else 'generation'
|
||||
user_id = None
|
||||
|
||||
if self.video_gen_only:
|
||||
workflow = 'video_gen'
|
||||
user_id = payload["user_id"]
|
||||
|
||||
if self.training_only:
|
||||
workflow = 'training'
|
||||
|
||||
end_node_id = None
|
||||
video_source_image_id = None
|
||||
|
||||
# Find the end_node_id
|
||||
if not self.video_gen_only and not self.training_only:
|
||||
for node in prompt:
|
||||
if isinstance(prompt[node], dict) and prompt[node].get("class_type") == "SaveImageWebsocket":
|
||||
end_node_id = node
|
||||
break
|
||||
elif self.video_gen_only:
|
||||
end_node_id = payload['end_node_id']
|
||||
video_source_image_id = payload['image_id']
|
||||
elif self.training_only:
|
||||
end_node_id = "130"
|
||||
|
||||
self.logger.info(f"[memedeck]: end_node_id: {end_node_id}")
|
||||
|
||||
# Prepare task_info
|
||||
prompt_id = str(uuid.uuid4())
|
||||
prompt_id = str(uuid.uuid4()).replace("-", "_")
|
||||
outputs_to_execute = valid[2]
|
||||
task_info = {
|
||||
"workflow": workflow,
|
||||
@ -157,7 +188,7 @@ class MemedeckWorker:
|
||||
"outputs_to_execute": outputs_to_execute,
|
||||
"client_id": "memedeck-1",
|
||||
"is_memedeck": True,
|
||||
"end_node_id": "130" if self.training_only else None,
|
||||
"end_node_id": end_node_id,
|
||||
"ws_id": payload["source_ws_id"],
|
||||
"context": payload["req_ctx"] if "req_ctx" in payload else {},
|
||||
"current_node": None,
|
||||
@ -181,15 +212,12 @@ class MemedeckWorker:
|
||||
"59": "loop-3",
|
||||
"60": "loop-4",
|
||||
},
|
||||
"training_filename": payload['filename'] if 'filename' in payload else None
|
||||
"training_filename": payload['filename'] if 'filename' in payload else None,
|
||||
# video data
|
||||
"image_id": video_source_image_id,
|
||||
"user_id": user_id,
|
||||
}
|
||||
|
||||
# Find the end_node_id
|
||||
for node in prompt:
|
||||
if isinstance(prompt[node], dict) and prompt[node].get("class_type") == "SaveImageWebsocket":
|
||||
task_info['end_node_id'] = node
|
||||
break
|
||||
|
||||
if valid[0]:
|
||||
# Enqueue the task into the internal job queue
|
||||
self.loop.call_soon_threadsafe(self.internal_job_queue.put_nowait, (prompt_id, prompt, task_info))
|
||||
@ -218,7 +246,7 @@ class MemedeckWorker:
|
||||
'context': task_info['context']
|
||||
}, task_info['outputs_to_execute']))
|
||||
|
||||
if task_info['workflow'] != 'training':
|
||||
if task_info['workflow'] != 'training' and task_info['workflow'] != 'video_gen':
|
||||
if 'faceswap_strength' not in task_info['context']['prompt_config']:
|
||||
# pretty print the prompt config
|
||||
self.logger.info(f"[memedeck]: prompt: {task_info['context']['prompt_config']['character']['id']} {task_info['context']['prompt_config']['positive_prompt']}")
|
||||
@ -279,6 +307,8 @@ class MemedeckWorker:
|
||||
if task['workflow'] == 'training':
|
||||
return await self.handle_training_send(event=event, task=task, sid=sid, data=data)
|
||||
|
||||
if task['workflow'] == 'video_gen':
|
||||
return await self.handle_video_gen_send(event=event, task=task, sid=sid, data=data)
|
||||
|
||||
# this logic is for generation and faceswap (todo move to separate function)
|
||||
if event == MemedeckWorker.BinaryEventTypes.UNENCODED_PREVIEW_IMAGE:
|
||||
@ -340,7 +370,6 @@ class MemedeckWorker:
|
||||
# if event == "progress":
|
||||
# self.logger.info(f"[memedeck]: training progress: {data}")
|
||||
if event == "executing":
|
||||
|
||||
training_progress = task['training_loop'][data['node']] if data['node'] in task['training_loop'] else None
|
||||
status = task['training_status'][data['node']] if data['node'] in task['training_status'] else None
|
||||
|
||||
@ -362,9 +391,54 @@ class MemedeckWorker:
|
||||
})
|
||||
# training task is done
|
||||
del self.tasks_by_ws_id[sid]
|
||||
|
||||
async def handle_video_gen_send(self, event, task, sid, data):
|
||||
if event == "progress":
|
||||
if data['max'] > 1:
|
||||
progress = data['value']
|
||||
max_progress = data['max']
|
||||
# calculate the percentage
|
||||
percentage = (progress / max_progress)
|
||||
await self.send_to_api({
|
||||
"ws_id": sid,
|
||||
"image_id": task['image_id'],
|
||||
"user_id": task['user_id'],
|
||||
"status": "generating",
|
||||
"progress": percentage * 0.9 # 90% of the progress is the gen step, 10% is the video encode step
|
||||
})
|
||||
|
||||
        if event == "executed":
            if data['node'] == task['end_node_id']:
                filename = data['output']['images'][0]['filename']

                current_dir = os.path.dirname(os.path.abspath(__file__))
                file_path = os.path.join(current_dir, "output", filename)
                blob_name = f"{task['user_id']}/video_gen/video_{task['image_id'].replace('image:', '')}_{task['prompt_id']}.webp"

                # load the rendered file and upload it to azure blob storage
                with open(file_path, "rb") as image_file:
                    image_bytes = image_file.read()

                self.logger.info(f"[memedeck]: video gen completed for {sid}, file={file_path}, blob={blob_name}")

                url = await self.azure_storage.save_image(blob_name, "image/webp", image_bytes)

                self.logger.info(f"[memedeck]: video uploaded for {sid}, {url}")
                await self.send_to_api({
                    "ws_id": sid,
                    "progress": 1.0,
                    "image_id": task['image_id'],
                    "user_id": task['user_id'],
                    "status": "completed",
                    "output_video_url": url
                })
                # video gen task is done
                del self.tasks_by_ws_id[sid]

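    # For reference, a hypothetical shape of `data` for the "executed" branch
    # above; the exact keys come from the save node's output and are assumed
    # here, with invented values:
    #
    #     data = {
    #         "node": "9",  # compared against task['end_node_id']
    #         "output": {"images": [{"filename": "video_00001.webp"}]},
    #     }
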
    async def send_preview(self, image_data, sid=None, progress=None, context=None, workflow=None):
        self.logger.info(f"[memedeck]: send_preview: {sid}")
        if sid is None:
            self.logger.warning("Received preview without sid")
            return
@ -406,6 +480,8 @@ class MemedeckWorker:
            "context": context
        }

        self.logger.info(f"[memedeck]: progress kind: {kind}")
        self.logger.info(f"[memedeck]: progress: {progress}")
        # set the kind to faceswap_generated if workflow is faceswap
        if workflow == 'faceswap':
            ai_queue_progress['kind'] = "faceswap_generated"
@ -430,7 +506,13 @@ class MemedeckWorker:
            self.logger.error(f"[memedeck]: end_node_id is None for {ws_id}")
            return

        api_endpoint = '/generation/update' if task['workflow'] != 'training' else '/training/update'
        api_endpoint = '/generation/update'
        if task['workflow'] == 'training':
            api_endpoint = '/training/update'

        if task['workflow'] == 'video_gen':
            api_endpoint = '/generation/video/update'

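        # The chain above is equivalent to a table lookup with a default; a
        # minimal sketch, not part of the commit:
        #
        #     API_ENDPOINTS = {
        #         'training': '/training/update',
        #         'video_gen': '/generation/video/update',
        #     }
        #     api_endpoint = API_ENDPOINTS.get(task['workflow'], '/generation/update')
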
        # self.logger.info(f"[memedeck]: sending to api: {api_endpoint}")
        # self.logger.info(f"[memedeck]: data: {data}")
        try:
@ -443,250 +525,188 @@ class MemedeckWorker:
# --------------------------------------------------------------------------
# MemedeckAzureStorage
# --------------------------------------------------------------------------
# from azure.storage.blob.aio import BlobClient, BlobServiceClient
# from azure.storage.blob import ContentSettings
# from typing import Optional, Tuple
# import cairosvg
from azure.storage.blob.aio import BlobClient, BlobServiceClient
from azure.storage.blob import ContentSettings
from typing import Optional, Tuple
import cairosvg

# WATERMARK = '<svg width="256" height="256" viewBox="0 0 256 256" fill="none" xmlns="http://www.w3.org/2000/svg"><path d="M60.0859 196.8C65.9526 179.067 71.5526 161.667 76.8859 144.6C79.1526 137.4 81.4859 129.867 83.8859 122C86.2859 114.133 88.6859 106.333 91.0859 98.6C93.4859 90.8667 95.6859 83.4 97.6859 76.2C99.8193 69 101.686 62.3333 103.286 56.2C110.619 56.2 117.553 55.8 124.086 55C130.619 54.2 137.686 53.4667 145.286 52.8C144.886 55.7333 144.419 59.0667 143.886 62.8C143.486 66.4 142.953 70.2 142.286 74.2C141.753 78.2 141.153 82.3333 140.486 86.6C139.819 90.8667 139.019 96.3333 138.086 103C137.153 109.667 135.886 118 134.286 128H136.886C140.753 117.867 143.953 109.467 146.486 102.8C149.019 96 151.086 90.4667 152.686 86.2C154.286 81.9333 155.886 77.8 157.486 73.8C159.219 69.6667 160.819 65.8 162.286 62.2C163.886 58.4667 165.353 55.2 166.686 52.4C170.019 52.1333 173.153 51.8 176.086 51.4C179.019 51 181.953 50.6 184.886 50.2C187.819 49.6667 190.753 49.2 193.686 48.8C196.753 48.2667 200.086 47.6667 203.686 47C202.353 54.7333 201.086 62.6667 199.886 70.8C198.686 78.9333 197.619 87.0667 196.686 95.2C195.753 103.2 194.819 111.133 193.886 119C193.086 126.867 192.353 134.333 191.686 141.4C190.086 157.933 188.686 174.067 187.486 189.8L152.686 196C152.686 195.333 152.753 193.533 152.886 190.6C153.153 187.667 153.419 184.067 153.686 179.8C154.086 175.533 154.553 170.8 155.086 165.6C155.753 160.4 156.353 155.2 156.886 150C157.553 144.8 158.219 139.8 158.886 135C159.553 130.067 160.219 125.867 160.886 122.4H159.086C157.219 128 155.153 133.933 152.886 140.2C150.619 146.333 148.286 152.6 145.886 159C143.619 165.4 141.353 171.667 139.086 177.8C136.819 183.933 134.819 189.8 133.086 195.4C128.419 195.533 124.419 195.733 121.086 196C117.753 196.133 113.886 196.333 109.486 196.6L115.886 122.4H112.886C112.619 124.133 111.953 127.067 110.886 131.2C109.819 135.2 108.553 139.867 107.086 145.2C105.753 150.4 104.286 155.867 102.686 161.6C101.086 167.2 99.5526 172.467 98.0859 177.4C96.7526 182.2 95.6193 186.2 94.6859 189.4C93.7526 192.467 93.2193 194.2 93.0859 194.6L60.0859 196.8Z" fill="white"/></svg>'
# WATERMARK_SIZE = 40
WATERMARK = '<svg width="256" height="256" viewBox="0 0 256 256" fill="none" xmlns="http://www.w3.org/2000/svg"><path d="M60.0859 196.8C65.9526 179.067 71.5526 161.667 76.8859 144.6C79.1526 137.4 81.4859 129.867 83.8859 122C86.2859 114.133 88.6859 106.333 91.0859 98.6C93.4859 90.8667 95.6859 83.4 97.6859 76.2C99.8193 69 101.686 62.3333 103.286 56.2C110.619 56.2 117.553 55.8 124.086 55C130.619 54.2 137.686 53.4667 145.286 52.8C144.886 55.7333 144.419 59.0667 143.886 62.8C143.486 66.4 142.953 70.2 142.286 74.2C141.753 78.2 141.153 82.3333 140.486 86.6C139.819 90.8667 139.019 96.3333 138.086 103C137.153 109.667 135.886 118 134.286 128H136.886C140.753 117.867 143.953 109.467 146.486 102.8C149.019 96 151.086 90.4667 152.686 86.2C154.286 81.9333 155.886 77.8 157.486 73.8C159.219 69.6667 160.819 65.8 162.286 62.2C163.886 58.4667 165.353 55.2 166.686 52.4C170.019 52.1333 173.153 51.8 176.086 51.4C179.019 51 181.953 50.6 184.886 50.2C187.819 49.6667 190.753 49.2 193.686 48.8C196.753 48.2667 200.086 47.6667 203.686 47C202.353 54.7333 201.086 62.6667 199.886 70.8C198.686 78.9333 197.619 87.0667 196.686 95.2C195.753 103.2 194.819 111.133 193.886 119C193.086 126.867 192.353 134.333 191.686 141.4C190.086 157.933 188.686 174.067 187.486 189.8L152.686 196C152.686 195.333 152.753 193.533 152.886 190.6C153.153 187.667 153.419 184.067 153.686 179.8C154.086 175.533 154.553 170.8 155.086 165.6C155.753 160.4 156.353 155.2 156.886 150C157.553 144.8 158.219 139.8 158.886 135C159.553 130.067 160.219 125.867 160.886 122.4H159.086C157.219 128 155.153 133.933 152.886 140.2C150.619 146.333 148.286 152.6 145.886 159C143.619 165.4 141.353 171.667 139.086 177.8C136.819 183.933 134.819 189.8 133.086 195.4C128.419 195.533 124.419 195.733 121.086 196C117.753 196.133 113.886 196.333 109.486 196.6L115.886 122.4H112.886C112.619 124.133 111.953 127.067 110.886 131.2C109.819 135.2 108.553 139.867 107.086 145.2C105.753 150.4 104.286 155.867 102.686 161.6C101.086 167.2 99.5526 172.467 98.0859 177.4C96.7526 182.2 95.6193 186.2 94.6859 189.4C93.7526 192.467 93.2193 194.2 93.0859 194.6L60.0859 196.8Z" fill="white"/></svg>'
WATERMARK_SIZE = 40

# class MemedeckAzureStorage:
#     def __init__(self, connection_string):
#         # get environment variables
#         self.storage_account = os.getenv('STORAGE_ACCOUNT')
#         self.storage_access_key = os.getenv('STORAGE_ACCESS_KEY')
#         self.storage_container = os.getenv('STORAGE_CONTAINER')
#         self.logger = logging.getLogger(__name__)
class MemedeckAzureStorage:
    def __init__(self):
        # get environment variables
        self.account = os.getenv('STORAGE_ACCOUNT')
        self.access_key = os.getenv('STORAGE_ACCESS_KEY')
        self.container = os.getenv('STORAGE_CONTAINER')
        self.logger = logging.getLogger(__name__)

        # self.blob_service_client = BlobServiceClient.from_connection_string(conn_str=connection_string)
        if not all([self.account, self.access_key, self.container]):
            raise EnvironmentError("Missing STORAGE_ACCOUNT, STORAGE_ACCESS_KEY, or STORAGE_CONTAINER environment variables")

#     async def upload_image(
#         self,
#         by: str,
#         image_id: str,
#         source_url: Optional[str],
#         bytes_data: Optional[bytes],
#         filetype: Optional[str],
#     ) -> Tuple[str, Tuple[int, int]]:
#         """
#         Uploads an image to Azure Blob Storage.
        # Initialize BlobServiceClient
        self.blob_service_client = BlobServiceClient(
            account_url=f"https://{self.account}.blob.core.windows.net",
            credential=self.access_key
        )
        self.logger.info("[memedeck]: Azure Storage connected.")
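        # The constructor above expects three environment variables; a sketch
        # of the configuration with placeholder values (not real credentials):
        #
        #     export STORAGE_ACCOUNT=mystorageaccount
        #     export STORAGE_ACCESS_KEY=<access key>
        #     export STORAGE_CONTAINER=media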
#         Args:
#             by (str): Identifier for the uploader.
#             image_id (str): Unique identifier for the image.
#             source_url (Optional[str]): URL to fetch the image from.
#             bytes_data (Optional[bytes]): Image data in bytes.
#             filetype (Optional[str]): Desired file type (e.g., 'jpeg', 'png').
    async def save_image(
        self,
        blob_name: str,
        content_type: str,
        bytes_data: bytes
    ) -> str:
        """
        Saves image bytes to Azure Blob Storage.

#         Returns:
#             Tuple[str, Tuple[int, int]]: URL of the uploaded image and its dimensions.
#         """
#         # Retrieve image bytes either from the provided bytes_data or by fetching from source_url
#         if source_url is None:
#             if bytes_data is None:
#                 raise ValueError("Could not get image bytes")
#             image_bytes = bytes_data
#         else:
#             self.logger.info(f"Requesting image from URL: {source_url}")
#             async with aiohttp.ClientSession() as session:
#                 try:
#                     async with session.get(source_url) as response:
#                         if response.status != 200:
#                             raise Exception(f"Failed to fetch image, status code {response.status}")
#                         image_bytes = await response.read()
#                 except Exception as e:
#                     raise Exception(f"Error fetching image from URL: {e}")
        Args:
            blob_name (str): Name of the blob in Azure Storage.
            content_type (str): MIME type of the content.
            bytes_data (bytes): Image data in bytes.

#         # Open image using Pillow to get dimensions and format
#         try:
#             img = Image.open(BytesIO(image_bytes))
#             width, height = img.size
#             inferred_filetype = img.format.lower()
#         except Exception as e:
#             raise Exception(f"Failed to decode image: {e}")
        Returns:
            str: URL of the uploaded blob.
        """

#         # Determine the final file type
#         final_filetype = filetype.lower() if filetype else inferred_filetype
        blob_client = self.blob_service_client.get_blob_client(container=self.container, blob=blob_name)

#         # Construct the blob name
#         blob_name = f"{by}/{image_id.replace('image:', '')}.{final_filetype}"
        # Upload the blob
        try:
            await blob_client.upload_blob(
                bytes_data,
                overwrite=True,
                content_settings=ContentSettings(content_type=content_type)
            )
        except Exception as e:
            raise Exception(f"Failed to upload blob: {e}")

        # close the blob client (the aio client's close is a coroutine, so it must be awaited)
        await blob_client.close()

#         # Upload the image to Azure Blob Storage
#         try:
#             image_url = await self.save_image(blob_name, img.format, image_bytes)
#             return image_url, (width, height)
#         except Exception as e:
#             self.logger.error(f"Trouble saving image: {e}")
#             raise Exception(f"Trouble saving image: {e}")
        # Construct and return the blob URL
        blob_url = f"https://media.memedeck.xyz/{self.container}/{blob_name}"
        return blob_url

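# Minimal usage sketch for save_image, assuming the three STORAGE_* variables
# are set and the call runs inside an event loop; the blob name and file are
# invented for the example:
#
#     storage = MemedeckAzureStorage()
#     with open("output/video_00001.webp", "rb") as f:
#         video_bytes = f.read()
#     url = await storage.save_image(
#         "user-123/video_gen/video_abc.webp", "image/webp", video_bytes
#     )
#     # -> "https://media.memedeck.xyz/<container>/user-123/video_gen/video_abc.webp"
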
#     async def save_image(
#         self,
#         blob_name: str,
#         content_type: str,
#         bytes_data: bytes
#     ) -> str:
#         """
#         Saves image bytes to Azure Blob Storage.
#     async def add_watermark(
#         self,
#         base_blob_name: str,
#         base_image: bytes
#     ) -> str:
#         """
#         Adds a watermark to the provided image and uploads the watermarked image.

#         Args:
#             blob_name (str): Name of the blob in Azure Storage.
#             content_type (str): MIME type of the content.
#             bytes_data (bytes): Image data in bytes.
#         Args:
#             base_blob_name (str): Original blob name of the image.
#             base_image (bytes): Image data in bytes.

#         Returns:
#             str: URL of the uploaded blob.
#         """
#         # Retrieve environment variables
#         account = os.getenv("STORAGE_ACCOUNT")
#         access_key = os.getenv("STORAGE_ACCESS_KEY")
#         container = os.getenv("STORAGE_CONTAINER")
#         Returns:
#             str: URL of the watermarked image.
#         """
#         # Load the input image
#         try:
#             img = Image.open(BytesIO(base_image)).convert("RGBA")
#         except Exception as e:
#             raise Exception(f"Failed to load image: {e}")

#         if not all([account, access_key, container]):
#             raise EnvironmentError("Missing STORAGE_ACCOUNT, STORAGE_ACCESS_KEY, or STORAGE_CONTAINER environment variables")
#         # Calculate position for the watermark (bottom right corner with padding)
#         padding = 12
#         x = img.width - WATERMARK_SIZE - padding
#         y = img.height - WATERMARK_SIZE - padding

#         # Initialize BlobServiceClient
#         blob_service_client = BlobServiceClient(
#             account_url=f"https://{account}.blob.core.windows.net",
#             credential=access_key
#         )
#         blob_client = blob_service_client.get_blob_client(container=container, blob=blob_name)
#         # Analyze background brightness where the watermark will be placed
#         background_brightness = self.analyze_background_brightness(img, x, y, WATERMARK_SIZE)
#         self.logger.info(f"Background brightness: {background_brightness}")

#         # Upload the blob
#         try:
#             await blob_client.upload_blob(
#                 bytes_data,
#                 overwrite=True,
#                 content_settings=ContentSettings(content_type=content_type)
#             )
#         except Exception as e:
#             raise Exception(f"Failed to upload blob: {e}")
#         # Render SVG watermark to PNG bytes using cairosvg
#         try:
#             watermark_png_bytes = cairosvg.svg2png(bytestring=WATERMARK.encode('utf-8'), output_width=WATERMARK_SIZE, output_height=WATERMARK_SIZE)
#             watermark = Image.open(BytesIO(watermark_png_bytes)).convert("RGBA")
#         except Exception as e:
#             raise Exception(f"Failed to render watermark SVG: {e}")

#         self.logger.debug(f"Blob uploaded: name={blob_name}, content_type={content_type}")
#         # Determine watermark color based on background brightness
#         if background_brightness > 128:
#             # Dark watermark for light backgrounds
#             watermark_color = (0, 0, 0, int(255 * 0.65)) # Black with 65% opacity
#         else:
#             # Light watermark for dark backgrounds
#             watermark_color = (255, 255, 255, int(255 * 0.65)) # White with 65% opacity

#         # Construct and return the blob URL
#         blob_url = f"https://media.memedeck.xyz//{container}/{blob_name}"
#         return blob_url
#         # Apply the watermark color by blending
#         solid_color = Image.new("RGBA", watermark.size, watermark_color)
#         watermark = Image.alpha_composite(watermark, solid_color)

#     async def add_watermark(
#         self,
#         base_blob_name: str,
#         base_image: bytes
#     ) -> str:
#         """
#         Adds a watermark to the provided image and uploads the watermarked image.
#         # Overlay the watermark onto the original image
#         img.paste(watermark, (x, y), watermark)

#         Args:
#             base_blob_name (str): Original blob name of the image.
#             base_image (bytes): Image data in bytes.
#         # Save the watermarked image to bytes
#         buffer = BytesIO()
#         img = img.convert("RGB") # Convert back to RGB for JPEG format
#         img.save(buffer, format="JPEG")
#         buffer.seek(0)
#         jpeg_bytes = buffer.read()

#         Returns:
#             str: URL of the watermarked image.
#         """
#         # Load the input image
#         try:
#             img = Image.open(BytesIO(base_image)).convert("RGBA")
#         except Exception as e:
#             raise Exception(f"Failed to load image: {e}")
#         # Modify the blob name to include '_watermarked'
#         try:
#             if "memes/" in base_blob_name:
#                 base_blob_name_right = base_blob_name.split("memes/", 1)[1]
#             else:
#                 base_blob_name_right = base_blob_name
#             base_blob_name_split = base_blob_name_right.rsplit(".", 1)
#             base_blob_name_without_extension = base_blob_name_split[0]
#             extension = base_blob_name_split[1]
#         except Exception as e:
#             raise Exception(f"Failed to process blob name: {e}")

#         # Calculate position for the watermark (bottom right corner with padding)
#         padding = 12
#         x = img.width - WATERMARK_SIZE - padding
#         y = img.height - WATERMARK_SIZE - padding
#         watermarked_blob_name = f"{base_blob_name_without_extension}_watermarked.{extension}"

#         # Analyze background brightness where the watermark will be placed
#         background_brightness = self.analyze_background_brightness(img, x, y, WATERMARK_SIZE)
#         self.logger.info(f"Background brightness: {background_brightness}")
#         # Upload the watermarked image
#         try:
#             watermarked_blob_url = await self.save_image(
#                 watermarked_blob_name,
#                 "image/jpeg",
#                 jpeg_bytes
#             )
#             return watermarked_blob_url
#         except Exception as e:
#             raise Exception(f"Failed to upload watermarked image: {e}")

#         # Render SVG watermark to PNG bytes using cairosvg
#         try:
#             watermark_png_bytes = cairosvg.svg2png(bytestring=WATERMARK.encode('utf-8'), output_width=WATERMARK_SIZE, output_height=WATERMARK_SIZE)
#             watermark = Image.open(BytesIO(watermark_png_bytes)).convert("RGBA")
#         except Exception as e:
#             raise Exception(f"Failed to render watermark SVG: {e}")
#     def analyze_background_brightness(
#         self,
#         img: Image.Image,
#         x: int,
#         y: int,
#         size: int
#     ) -> int:
#         """
#         Analyzes the brightness of a specific region in the image.

#         # Determine watermark color based on background brightness
#         if background_brightness > 128:
#             # Dark watermark for light backgrounds
#             watermark_color = (0, 0, 0, int(255 * 0.65)) # Black with 65% opacity
#         else:
#             # Light watermark for dark backgrounds
#             watermark_color = (255, 255, 255, int(255 * 0.65)) # White with 65% opacity
#         Args:
#             img (Image.Image): The image to analyze.
#             x (int): X-coordinate of the top-left corner of the region.
#             y (int): Y-coordinate of the top-left corner of the region.
#             size (int): Size of the square region to analyze.

#         # Apply the watermark color by blending
#         solid_color = Image.new("RGBA", watermark.size, watermark_color)
#         watermark = Image.alpha_composite(watermark, solid_color)
#         Returns:
#             int: Average brightness (0-255) of the region.
#         """
#         # Crop the specified region
#         sub_image = img.crop((x, y, x + size, y + size)).convert("RGB")

#         # Overlay the watermark onto the original image
#         img.paste(watermark, (x, y), watermark)
#         # Calculate average brightness using the luminance formula
#         total_brightness = 0
#         pixel_count = 0
#         for pixel in sub_image.getdata():
#             r, g, b = pixel
#             brightness = (r * 299 + g * 587 + b * 114) // 1000
#             total_brightness += brightness
#             pixel_count += 1

#         # Save the watermarked image to bytes
#         buffer = BytesIO()
#         img = img.convert("RGB") # Convert back to RGB for JPEG format
#         img.save(buffer, format="JPEG")
#         buffer.seek(0)
#         jpeg_bytes = buffer.read()
#         if pixel_count == 0:
#             return 0

#         # Modify the blob name to include '_watermarked'
#         try:
#             if "memes/" in base_blob_name:
#                 base_blob_name_right = base_blob_name.split("memes/", 1)[1]
#             else:
#                 base_blob_name_right = base_blob_name
#             base_blob_name_split = base_blob_name_right.rsplit(".", 1)
#             base_blob_name_without_extension = base_blob_name_split[0]
#             extension = base_blob_name_split[1]
#         except Exception as e:
#             raise Exception(f"Failed to process blob name: {e}")

#         watermarked_blob_name = f"{base_blob_name_without_extension}_watermarked.{extension}"

#         # Upload the watermarked image
#         try:
#             watermarked_blob_url = await self.save_image(
#                 watermarked_blob_name,
#                 "image/jpeg",
#                 jpeg_bytes
#             )
#             return watermarked_blob_url
#         except Exception as e:
#             raise Exception(f"Failed to upload watermarked image: {e}")

#     def analyze_background_brightness(
#         self,
#         img: Image.Image,
#         x: int,
#         y: int,
#         size: int
#     ) -> int:
#         """
#         Analyzes the brightness of a specific region in the image.

#         Args:
#             img (Image.Image): The image to analyze.
#             x (int): X-coordinate of the top-left corner of the region.
#             y (int): Y-coordinate of the top-left corner of the region.
#             size (int): Size of the square region to analyze.

#         Returns:
#             int: Average brightness (0-255) of the region.
#         """
#         # Crop the specified region
#         sub_image = img.crop((x, y, x + size, y + size)).convert("RGB")

#         # Calculate average brightness using the luminance formula
#         total_brightness = 0
#         pixel_count = 0
#         for pixel in sub_image.getdata():
#             r, g, b = pixel
#             brightness = (r * 299 + g * 587 + b * 114) // 1000
#             total_brightness += brightness
#             pixel_count += 1

#         if pixel_count == 0:
#             return 0

#         average_brightness = total_brightness // pixel_count
#         return average_brightness
#         average_brightness = total_brightness // pixel_count
#         return average_brightness
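# Worked example of the commented-out luma formula above, which uses the
# integer Rec. 601 weights (sample pixel invented for illustration):
#
#     r, g, b = 200, 120, 40
#     brightness = (r * 299 + g * 587 + b * 114) // 1000
#     # (59800 + 70440 + 4560) // 1000 = 134; since 134 > 128, the dark
#     # (black) watermark variant would be chosen for this background.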
277
notebooks/memedeck-nodes-testing.ipynb
Normal file
File diff suppressed because one or more lines are too long
@ -1,4 +0,0 @@
# Config for testing nodes
testing:
  custom_nodes: tests/inference/testing_nodes