adding custom_node permanently

This commit is contained in:
drunkplato 2024-12-11 20:36:06 +00:00 committed by Ubuntu
parent 497390aff6
commit e512458a79
18 changed files with 1829 additions and 238 deletions

.gitignore vendored
View File

@ -3,10 +3,18 @@ __pycache__/
/output/
/input/
!/input/example.png
!comfy/ldm/models/autoencoder.py
/models/
/temp/
/custom_nodes/
# Ignore everything in custom_nodes
/custom_nodes/*
!comfy/ldm/models/autoencoder.py
# Explicitly allow these files/directories
!custom_nodes/example_node.py.example
!custom_nodes/MemedeckComfyNodes/
!custom_nodes/MemedeckComfyNodes/**
extra_model_paths.yaml
/.vs
.vscode/
@ -26,12 +34,13 @@ web_custom_versions/
models
custom_nodes
models/
flux_loras/
custom_nodes/
pysssss-workflows/
comfy-venv
comfy_venv_3.11
models-2
models-2
!comfy/ldm/models/autoencoder.py

View File

@ -0,0 +1,2 @@
__pycache__/

View File

@ -0,0 +1,12 @@
# MemedeckComfyNodes
A collection of custom nodes for MemeDeck's ComfyUI video-generation pipeline.
## Nodes
- MD_LoadImageFromUrl
- MD_ImageToMotionPrompt
- MD_CompressAdjustNode
- MD_LoadVideoModel
- MD_ImgToVideo
- MD_VideoSampler
- MD_SaveAnimatedWEBP
- MD_SaveMP4

View File

@ -0,0 +1,32 @@
from .nodes_preprocessing import MD_LoadImageFromUrl, MD_CompressAdjustNode, MD_ImageToMotionPrompt
from .nodes_model import MD_LoadVideoModel, MD_ImgToVideo, MD_VideoSampler
from .nodes_output import MD_SaveAnimatedWEBP, MD_SaveMP4
NODE_CLASS_MAPPINGS = {
# PREPROCESSING
"Memedeck_ImageToMotionPrompt": MD_ImageToMotionPrompt,
"Memedeck_CompressAdjustNode": MD_CompressAdjustNode,
"Memedeck_LoadImageFromUrl": MD_LoadImageFromUrl,
# MODEL NODES
"Memedeck_LoadVideoModel": MD_LoadVideoModel,
"Memedeck_ImgToVideo": MD_ImgToVideo,
"Memedeck_VideoSampler": MD_VideoSampler,
# POSTPROCESSING
"Memedeck_SaveMP4": MD_SaveMP4,
"Memedeck_SaveAnimatedWEBP": MD_SaveAnimatedWEBP
# "Memedeck_SaveAnimatedGIF": MD_SaveAnimatedGIF
}
NODE_DISPLAY_NAME_MAPPINGS = {
# PREPROCESSING
"Memedeck_ImageToMotionPrompt": "MemeDeck: Image To Motion Prompt",
"Memedeck_CompressAdjustNode": "MemeDeck: Compression Detector & Adjuster",
"Memedeck_LoadImageFromUrl": "MemeDeck: Load Image From URL",
# MODEL NODES
"Memedeck_LoadVideoModel": "MemeDeck: Load Video Model",
"Memedeck_VideoScheduler": "MemeDeck: Video Scheduler",
"Memedeck_ImgToVideo": "MemeDeck: Image To Video",
"Memedeck_VideoSampler": "MemeDeck: Video Sampler",
# POSTPROCESSING
"Memedeck_SaveMP4": "MemeDeck: Save MP4"
# "Memedeck_SaveAnimatedGIF": "MemeDeck: Save Animated GIF"
}

View File

@ -0,0 +1,32 @@
import base64
import PIL
import numpy as np
from PIL import Image
from torch import Tensor
import torch
def tensor2pil(image: Tensor) -> PIL.Image.Image:
return Image.fromarray(np.clip(255. * image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8))
def pil2base64(image: PIL.Image.Image) -> str:
from io import BytesIO
buffered = BytesIO()
image.save(buffered, format="JPEG")
return base64.b64encode(buffered.getvalue()).decode("utf-8")
def pil2tensor(images: Image.Image | list[Image.Image]) -> torch.Tensor:
"""Converts a PIL Image or a list of PIL Images to a tensor."""
def single_pil2tensor(image: Image.Image) -> torch.Tensor:
np_image = np.array(image).astype(np.float32) / 255.0
if np_image.ndim == 2: # Grayscale
return torch.from_numpy(np_image).unsqueeze(0) # (1, H, W)
else: # RGB or RGBA
return torch.from_numpy(np_image).unsqueeze(0) # (1, H, W, C)
if isinstance(images, Image.Image):
return single_pil2tensor(images)
else:
return torch.cat([single_pil2tensor(img) for img in images], dim=0)
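To show how these helpers compose, here is a minimal round-trip sketch. It is illustrative only and not part of the commit; it relies on the imports and definitions above, and the solid-red test image is hypothetical.

# Illustrative round-trip using the helpers defined above (not part of the commit).
pil_img = Image.new("RGB", (64, 64), color=(255, 0, 0))
t = pil2tensor(pil_img)       # torch.Size([1, 64, 64, 3]), float32 values in [0, 1]
restored = tensor2pil(t)      # back to a uint8 PIL.Image.Image
b64 = pil2base64(restored)    # JPEG bytes encoded as a base64 string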

View File

@ -0,0 +1,56 @@
import os
import shutil
import subprocess
def ffmpeg_suitability(path):
try:
version = subprocess.run([path, "-version"], check=True,
capture_output=True).stdout.decode("utf-8")
except:
return 0
score = 0
#rough layout of the importance of various features
simple_criterion = [("libvpx", 20),("264",10), ("265",3),
("svtav1",5),("libopus", 1)]
for criterion in simple_criterion:
if version.find(criterion[0]) >= 0:
score += criterion[1]
#obtain rough compile year from copyright information
copyright_index = version.find('2000-2')
if copyright_index >= 0:
copyright_year = version[copyright_index+6:copyright_index+9]
if copyright_year.isnumeric():
score += int(copyright_year)
return score
if "VHS_FORCE_FFMPEG_PATH" in os.environ:
ffmpeg_path = os.environ.get("VHS_FORCE_FFMPEG_PATH")
else:
ffmpeg_paths = []
try:
from imageio_ffmpeg import get_ffmpeg_exe
imageio_ffmpeg_path = get_ffmpeg_exe()
ffmpeg_paths.append(imageio_ffmpeg_path)
except:
if "VHS_USE_IMAGEIO_FFMPEG" in os.environ:
raise
# logger.warn("Failed to import imageio_ffmpeg")
if "VHS_USE_IMAGEIO_FFMPEG" in os.environ:
ffmpeg_path = imageio_ffmpeg_path
else:
system_ffmpeg = shutil.which("ffmpeg")
if system_ffmpeg is not None:
ffmpeg_paths.append(system_ffmpeg)
if os.path.isfile("ffmpeg"):
ffmpeg_paths.append(os.path.abspath("ffmpeg"))
if os.path.isfile("ffmpeg.exe"):
ffmpeg_paths.append(os.path.abspath("ffmpeg.exe"))
if len(ffmpeg_paths) == 0:
# logger.error("No valid ffmpeg found.")
ffmpeg_path = None
elif len(ffmpeg_paths) == 1:
#Evaluation of suitability isn't required, can take sole option
#to reduce startup time
ffmpeg_path = ffmpeg_paths[0]
else:
ffmpeg_path = max(ffmpeg_paths, key=ffmpeg_suitability)

View File

@ -0,0 +1,186 @@
import torch
from torch import nn
import comfy.ldm.modules.attention
import comfy.ldm.common_dit
import math
from comfy.ldm.lightricks.model import apply_rotary_emb, precompute_freqs_cis, LTXVModel, BasicTransformerBlock
class GenesisModifiedCrossAttention(nn.Module):
def forward(self, x, context=None, mask=None, pe=None, transformer_options={}):
context = x if context is None else context
context_v = x if context is None else context
step = transformer_options.get('step', -1)
total_steps = transformer_options.get('total_steps', 0)
attn_bank = transformer_options.get('attn_bank', None)
sample_mode = transformer_options.get('sample_mode', None)
if attn_bank is not None and self.idx in attn_bank['block_map']:
len_conds = len(transformer_options['cond_or_uncond'])
pred_order = transformer_options['pred_order']
if sample_mode == 'forward' and total_steps-step-1 < attn_bank['save_steps']:
step_idx = f'{pred_order}_{total_steps-step-1}'
attn_bank['block_map'][self.idx][step_idx] = x.cpu()
elif sample_mode == 'reverse' and step < attn_bank['inject_steps']:
step_idx = f'{pred_order}_{step}'
inject_settings = attn_bank.get('inject_settings', {})
if len(inject_settings) > 0:
inj = attn_bank['block_map'][self.idx][step_idx].to(x.device).repeat(len_conds, 1, 1)
if 'q' in inject_settings:
x = inj
if 'k' in inject_settings:
context = inj
if 'v' in inject_settings:
context_v = inj
q = self.to_q(x)
k = self.to_k(context)
v = self.to_v(context_v)
q = self.q_norm(q)
k = self.k_norm(k)
if pe is not None:
q = apply_rotary_emb(q, pe)
k = apply_rotary_emb(k, pe)
alt_attn_fn = transformer_options.get('patches_replace', {}).get(f'layer', {}).get(('self_attn', self.idx), None)
if alt_attn_fn is not None:
out = alt_attn_fn(q,k,v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options)
elif mask is None:
out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision)
else:
out = comfy.ldm.modules.attention.optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision)
return self.to_out(out)
class GenesisModifiedBasicTransformerBlock(BasicTransformerBlock):
def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2)
x += self.attn1(comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe, transformer_options=transformer_options) * gate_msa
x += self.attn2(x, context=context, mask=attention_mask)
y = comfy.ldm.common_dit.rms_norm(x) * (1 + scale_mlp) + shift_mlp
x += self.ff(y) * gate_mlp
return x
class GenesisModelModified(LTXVModel):
def forward(self, x, timestep, context, attention_mask, frame_rate=25, guiding_latent=None, guiding_latents={}, transformer_options={}, **kwargs):
patches_replace = transformer_options.get("patches_replace", {})
guiding_latents = transformer_options.get('patches', {}).get('guiding_latents', None)
indices_grid = self.patchifier.get_grid(
orig_num_frames=x.shape[2],
orig_height=x.shape[3],
orig_width=x.shape[4],
batch_size=x.shape[0],
scale_grid=((1 / frame_rate) * 8, 32, 32),
device=x.device,
)
ts = None
input_x = None
if guiding_latents is not None:
input_x = x.clone()
ts = torch.ones([x.shape[0], 1, x.shape[2], x.shape[3], x.shape[4]], device=x.device, dtype=x.dtype)
input_ts = timestep.view([timestep.shape[0]] + [1] * (x.ndim - 1))
ts *= input_ts
for guide in guiding_latents:
ts[:, :, guide.index] = 0.0
x[:,:,guide.index] = guide.latent[:,:,0]
timestep = self.patchifier.patchify(ts)
orig_shape = list(x.shape)
x = self.patchifier.patchify(x)
x = self.patchify_proj(x)
timestep = timestep * 1000.0
attention_mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1]))
attention_mask = attention_mask.masked_fill(attention_mask.to(torch.bool), float("-inf")) # not sure about this
# attention_mask = (context != 0).any(dim=2).to(dtype=x.dtype)
pe = precompute_freqs_cis(indices_grid, dim=self.inner_dim, out_dtype=x.dtype)
batch_size = x.shape[0]
timestep, embedded_timestep = self.adaln_single(
timestep.flatten(),
{"resolution": None, "aspect_ratio": None},
batch_size=batch_size,
hidden_dtype=x.dtype,
)
# Second dimension is 1 or number of tokens (if timestep_per_token)
timestep = timestep.view(batch_size, -1, timestep.shape[-1])
embedded_timestep = embedded_timestep.view(
batch_size, -1, embedded_timestep.shape[-1]
)
# 2. Blocks
if self.caption_projection is not None:
batch_size = x.shape[0]
context = self.caption_projection(context)
context = context.view(
batch_size, -1, x.shape[-1]
)
blocks_replace = patches_replace.get("dit", {})
for i, block in enumerate(self.transformer_blocks):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"])
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe}, {"original_block": block_wrap})
x = out["img"]
else:
x = block(
x,
context=context,
attention_mask=attention_mask,
timestep=timestep,
pe=pe,
transformer_options=transformer_options
)
# 3. Output
scale_shift_values = (
self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + embedded_timestep[:, :, None]
)
shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
x = self.norm_out(x)
# Modulation
x = x * (1 + scale) + shift
x = self.proj_out(x)
x = self.patchifier.unpatchify(
latents=x,
output_height=orig_shape[3],
output_width=orig_shape[4],
output_num_frames=orig_shape[2],
out_channels=orig_shape[1] // math.prod(self.patchifier.patch_size),
)
if guiding_latents is not None:
for guide in guiding_latents:
x[:, :, guide.index] = (input_x[:, :, guide.index] - guide.latent[:, :, 0]) / input_ts[:, :, 0]
return x
def inject_model(diffusion_model):
diffusion_model.__class__ = GenesisModelModified
for idx, transformer_block in enumerate(diffusion_model.transformer_blocks):
transformer_block.__class__ = GenesisModifiedBasicTransformerBlock
transformer_block.idx = idx
transformer_block.attn1.__class__ = GenesisModifiedCrossAttention
transformer_block.attn1.idx = idx
return diffusion_model
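inject_model works by rebinding each instance's `__class__`, so the patched `forward` methods take effect without copying weights or re-registering modules. A toy sketch of that pattern, illustrative only and not part of the commit:

# Toy illustration of the __class__-swap pattern used by inject_model above.
class Original:
    def forward(self, x):
        return x

class Patched(Original):
    def forward(self, x):
        return x * 2  # new behaviour; existing attributes/state are untouched

obj = Original()
obj.__class__ = Patched   # rebinds methods in place
assert obj.forward(3) == 6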

View File

@ -0,0 +1,390 @@
import json
import math
from comfy_extras.nodes_custom_sampler import Noise_RandomNoise
import latent_preview
from .modules.video_model import inject_model
import folder_paths
import node_helpers
import torch
import comfy
class MD_LoadVideoModel:
"""
Loads the DIT model for video generation.
"""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"chkpt_name": (folder_paths.get_filename_list("checkpoints"), {
"default": "genesis-dit-video-2b.safetensors",
"tooltip": "The name of the checkpoint (model) to load."
}),
"clip_name": (folder_paths.get_filename_list("text_encoders"), {
"default": "t5xxl_fp16.safetensors",
"tooltip": "The name of the clip (model) to load."
}),
},
}
RETURN_TYPES = ("MODEL", "CLIP", "VAE")
RETURN_NAMES = ("model", "clip", "vae")
FUNCTION = "load_model"
CATEGORY = "MemeDeck"
def load_model(self, chkpt_name, clip_name):
ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", chkpt_name)
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
model = out[0]
vae = out[2]
clip_path = folder_paths.get_full_path_or_raise("text_encoders", clip_name)
clip = comfy.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=8)
# modify model
model.model.diffusion_model = inject_model(model.model.diffusion_model)
return (model, clip, vae, )
class LatentGuide(torch.nn.Module):
def __init__(self, latent: torch.Tensor, index) -> None:
super().__init__()
self.index = index
self.register_buffer('latent', latent)
class MD_ImgToVideo:
"""
Sets the conditioning and dimensions for video generation.
"""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"model": ("MODEL",),
"positive": ("CONDITIONING",),
"negative": ("CONDITIONING",),
"vae": ("VAE",),
"image": ("IMAGE",),
"width": ("INT", {
"default": 832,
"description": "The width of the video."
}),
"height": ("INT", {
"default": 832,
"description": "The height of the video."
}),
"length": ("INT", {
"default": 97,
"description": "The length of the video."
}),
"fps": ("INT", {
"default": 24,
"description": "The fps of the video."
}),
# LATENT GUIDE INPUTS
"add_latent_guide_index": ("INT", {
"default": 0,
"description": "The index of the latent to add to the guide."
}),
"add_latent_guide_insert": ("BOOLEAN", {
"default": False,
"description": "Whether to add the latent to the guide."
}),
# SCHEDULER INPUTS
"steps": ("INT", {
"default": 40,
"description": "Number of steps to generate the video."
}),
"max_shift": ("FLOAT", {
"default": 1.5,
"step": 0.01,
"description": "The maximum shift of the video."
}),
"base_shift": ("FLOAT", {
"default": 0.95,
"step": 0.01,
"description": "The base shift of the video."
}),
"stretch": ("BOOLEAN", {
"default": True,
"description": "Stretch the sigmas to be in the range [terminal, 1]."
}),
"terminal": ("FLOAT", {
"default": 0.1,
"description": "The terminal values of the sigmas after stretching."
}),
# ATTENTION OVERRIDE INPUTS
"attention_override": ("STRING", {
"default": 14,
"description": "The amount of attention to override the model with."
}),
"attention_adjustment_scale": ("FLOAT", {
"default": 1.0,
"description": "The scale of the attention adjustment."
}),
"attention_adjustment_rescale": ("FLOAT", {
"default": 0.5,
"description": "The scale of the attention adjustment."
}),
"attention_adjustment_cfg": ("FLOAT", {
"default": 3.0,
"description": "The scale of the attention adjustment."
}),
},
}
RETURN_TYPES = ("MODEL", "CONDITIONING", "CONDITIONING", "SIGMAS", "LATENT", "STRING")
RETURN_NAMES = ("model", "positive", "negative", "sigmas", "latent", "img2vid_metadata")
FUNCTION = "img_to_video"
CATEGORY = "MemeDeck"
def img_to_video(self, model, positive, negative, vae, image, width, height, length, fps, add_latent_guide_index, add_latent_guide_insert, steps, max_shift, base_shift, stretch, terminal, attention_override, attention_adjustment_scale, attention_adjustment_rescale, attention_adjustment_cfg):
batch_size = 1
pixels = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
encode_pixels = pixels[:, :, :, :3]
t = vae.encode(encode_pixels)
positive = node_helpers.conditioning_set_values(positive, {"guiding_latent": t})
negative = node_helpers.conditioning_set_values(negative, {"guiding_latent": t})
latent = torch.zeros([batch_size, 128, ((length - 1) // 8) + 1, height // 32, width // 32], device=comfy.model_management.intermediate_device())
latent[:, :, :t.shape[2]] = t
latent_samples = {"samples": latent}
positive = node_helpers.conditioning_set_values(positive, {"frame_rate": fps})
negative = node_helpers.conditioning_set_values(negative, {"frame_rate": fps})
# 2. add latent guide
model, latent_updated = self.add_latent_guide(model, latent_samples, latent_samples, add_latent_guide_index, add_latent_guide_insert)
# 3. apply attention override
attn_override_layers = self.attention_override(attention_override)
model = self.apply_attention_override(model, attention_adjustment_scale, attention_adjustment_rescale, attention_adjustment_cfg, attn_override_layers)
# 4. configure scheduler
sigmas = self.get_sigmas(steps, max_shift, base_shift, stretch, terminal, latent_updated)
# all parameters starting with width, height, fps, crf, etc
img2vid_metadata = {
"width": width,
"height": height,
"length": length,
"fps": fps,
"steps": steps,
"max_shift": max_shift,
"base_shift": base_shift,
"stretch": stretch,
"terminal": terminal,
"attention_override": attention_override,
"attention_adjustment_scale": attention_adjustment_scale,
"attention_adjustment_rescale": attention_adjustment_rescale,
"attention_adjustment_cfg": attention_adjustment_cfg,
}
json_img2vid_metadata = json.dumps(img2vid_metadata)
return (model, positive, negative, sigmas, latent_updated, json_img2vid_metadata)
# -----------------------------
# Attention functions
# -----------------------------
# 1. Add latent guide
def add_latent_guide(self, model, latent, image_latent, index, insert):
image_latent = image_latent['samples']
latent = latent['samples'].clone()
if insert:
# Handle insertion
if index == 0:
# Insert at beginning
latent = torch.cat([image_latent[:,:,0:1], latent], dim=2)
elif index >= latent.shape[2] or index < 0:
# Append to end
latent = torch.cat([latent, image_latent[:,:,0:1]], dim=2)
else:
# Insert in middle
latent = torch.cat([
latent[:,:,:index],
image_latent[:,:,0:1],
latent[:,:,index:]
], dim=2)
else:
# Original replacement behavior
latent[:,:,index] = image_latent[:,:,0]
model = model.clone()
guiding_latent = LatentGuide(image_latent, index)
model.set_model_patch(guiding_latent, 'guiding_latents')
return (model, {"samples": latent},)
# 2. Apply attention override
def is_integer(self, string):
try:
int(string)
return True
except ValueError:
return False
def attention_override(self, layers: str = "14"):
layers_map = set([])
for block in layers.split(','):
block = block.strip()
if self.is_integer(block):
layers_map.add(block)
return layers_map
def apply_attention_override(self, model, scale, rescale, cfg, attention_override: set):
m = model.clone()
def pag_fn(q, k,v, heads, attn_precision=None, transformer_options=None):
return v
def post_cfg_function(args):
model = args["model"]
cond_pred = args["cond_denoised"]
uncond_pred = args["uncond_denoised"]
len_conds = 1 if args.get('uncond', None) is None else 2
cond = args["cond"]
sigma = args["sigma"]
model_options = args["model_options"].copy()
x = args["input"]
if scale == 0:
if len_conds == 1:
return cond_pred
return uncond_pred + (cond_pred - uncond_pred)
for block_idx in attention_override:
model_options = comfy.model_patcher.set_model_options_patch_replace(model_options, pag_fn, f"layer", "self_attn", int(block_idx))
(perturbed,) = comfy.samplers.calc_cond_batch(model, [cond], x, sigma, model_options)
output = uncond_pred + cfg * (cond_pred - uncond_pred) \
+ scale * (cond_pred - perturbed)
if rescale > 0:
factor = cond_pred.std() / output.std()
factor = rescale * factor + (1 - rescale)
output = output * factor
return output
m.set_model_sampler_post_cfg_function(post_cfg_function)
return m
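In post_cfg_function above, `pag_fn` replaces the selected self-attention layers with a pass-through of `v`, and the final prediction is combined as output = uncond + cfg * (cond - uncond) + scale * (cond - perturbed); when rescale > 0 the result is multiplied by rescale * (std(cond) / std(output)) + (1 - rescale) so its variance stays close to the conditional prediction.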
# -----------------------------
# Scheduler
# -----------------------------
def get_sigmas(self, steps, max_shift, base_shift, stretch, terminal, latent=None):
if latent is None:
tokens = 4096
else:
tokens = math.prod(latent["samples"].shape[2:])
sigmas = torch.linspace(1.0, 0.0, steps + 1)
x1 = 1024
x2 = 4096
mm = (max_shift - base_shift) / (x2 - x1)
b = base_shift - mm * x1
sigma_shift = (tokens) * mm + b
power = 1
sigmas = torch.where(
sigmas != 0,
math.exp(sigma_shift) / (math.exp(sigma_shift) + (1 / sigmas - 1) ** power),
0,
)
# Stretch sigmas so that its final value matches the given terminal value.
if stretch:
non_zero_mask = sigmas != 0
non_zero_sigmas = sigmas[non_zero_mask]
one_minus_z = 1.0 - non_zero_sigmas
scale_factor = one_minus_z[-1] / (1.0 - terminal)
stretched = 1.0 - (one_minus_z / scale_factor)
sigmas[non_zero_mask] = stretched
return sigmas
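A quick numeric check of the shift mapping in get_sigmas, using the node defaults (max_shift=1.5, base_shift=0.95). The figures below are illustrative arithmetic only, not output from the commit.

# Worked example of the sigma shift used above (illustrative arithmetic only).
import math
tokens = 4096                       # fallback token count when no latent is supplied
mm = (1.5 - 0.95) / (4096 - 1024)   # slope between the two reference token counts
b = 0.95 - mm * 1024
sigma_shift = tokens * mm + b       # 1.5 at 4096 tokens, 0.95 at 1024 tokens
sigma = 0.5
shifted = math.exp(sigma_shift) / (math.exp(sigma_shift) + (1 / sigma - 1) ** 1)
# exp(1.5) ≈ 4.4817, so shifted ≈ 4.4817 / 5.4817 ≈ 0.818 — mid-schedule values are pushed higher.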
KSAMPLER_NAMES = ["euler", "ddim", "euler_ancestral", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2","dpm_2", "dpm_2_ancestral",
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
"dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
"ipndm", "ipndm_v", "deis"]
class MD_VideoSampler:
"""
Samples the video.
"""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"model": ("MODEL",),
"positive": ("CONDITIONING",),
"negative": ("CONDITIONING",),
"sigmas": ("SIGMAS",),
"latent_image": ("LATENT",),
"sampler": (KSAMPLER_NAMES, ),
"noise_seed": ("INT", {
"default": 42,
"description": "The seed of the noise."
}),
"cfg": ("FLOAT", {
"default": 5.0,
"min": 0.0,
"max": 30.0,
"step": 0.01,
"description": "The cfg of the video."
}),
},
}
RETURN_TYPES = ("LATENT", "LATENT", "STRING")
RETURN_NAMES = ("output", "denoised_output", "img2vid_metadata")
FUNCTION = "video_sampler"
CATEGORY = "MemeDeck"
def video_sampler(self, model, positive, negative, sigmas, latent_image, sampler, noise_seed, cfg):
latent = latent_image
latent_image = latent["samples"]
latent = latent.copy()
latent_image = comfy.sample.fix_empty_latent_channels(model, latent_image)
latent["samples"] = latent_image
sampler_name = sampler
noise = Noise_RandomNoise(noise_seed).generate_noise(latent)
sampler = comfy.samplers.sampler_object(sampler)
noise_mask = None
if "noise_mask" in latent:
noise_mask = latent["noise_mask"]
x0_output = {}
callback = latent_preview.prepare_callback(model, sigmas.shape[-1] - 1, x0_output)
disable_pbar = not comfy.utils.PROGRESS_BAR_ENABLED
samples = comfy.sample.sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=noise_seed)
out = latent.copy()
out["samples"] = samples
if "x0" in x0_output:
out_denoised = latent.copy()
out_denoised["samples"] = model.model.process_latent_out(x0_output["x0"].cpu())
else:
out_denoised = out
sampler_metadata = {
"sampler": sampler_name,
"noise_seed": noise_seed,
"cfg": cfg,
}
json_sampler_metadata = json.dumps(sampler_metadata)
return (out, out_denoised, json_sampler_metadata)

View File

@ -0,0 +1,142 @@
import folder_paths
from comfy.cli_args import args
from PIL import Image
from PIL.PngImagePlugin import PngInfo
import numpy as np
import json
import os
import cv2  # used by MD_SaveMP4's VideoWriter
class MD_SaveAnimatedWEBP:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
self.type = "output"
self.prefix_append = ""
methods = {"default": 4, "fastest": 0, "slowest": 6}
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"images": ("IMAGE", ),
"filename_prefix": ("STRING", {"default": "memedeck_video"}),
"fps": ("FLOAT", {"default": 24.0, "min": 0.01, "max": 1000.0, "step": 0.01}),
"lossless": ("BOOLEAN", {"default": False}),
"quality": ("INT", {"default": 90, "min": 0, "max": 100}),
"method": (list(s.methods.keys()),),
"crf": ("INT",),
"motion_prompt": ("STRING", ),
"negative_prompt": ("STRING", ),
"img2vid_metadata": ("STRING", ),
"sampler_metadata": ("STRING", ),
},
# "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
}
RETURN_TYPES = ()
FUNCTION = "save_images"
OUTPUT_NODE = True
CATEGORY = "MemeDeck"
def save_images(self, images, fps, filename_prefix, lossless, quality, method, crf=None, motion_prompt=None, negative_prompt=None, img2vid_metadata=None, sampler_metadata=None):
method = self.methods.get(method)
filename_prefix += self.prefix_append
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0])
results = list()
pil_images = []
for image in images:
i = 255. * image.cpu().numpy()
img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
pil_images.append(img)
metadata = pil_images[0].getexif()
num_frames = len(pil_images)
json_metadata = json.dumps({
"crf": crf,
"motion_prompt": motion_prompt,
"negative_prompt": negative_prompt,
"img2vid_metadata": img2vid_metadata,
"sampler_metadata": sampler_metadata,
}, indent=4)
c = len(pil_images)
for i in range(0, c, num_frames):
file = f"{filename}_{counter:05}_.webp"
pil_images[i].save(os.path.join(full_output_folder, file), save_all=True, duration=int(1000.0/fps), append_images=pil_images[i + 1:i + num_frames], exif=metadata, lossless=lossless, quality=quality, method=method)
results.append({
"filename": file,
"subfolder": subfolder,
"type": self.type,
})
counter += 1
animated = num_frames != 1
return { "ui": { "images": results, "animated": (animated,), "metadata": json_metadata } }
class MD_SaveMP4:
def __init__(self):
# Get absolute path of the output directory
self.output_dir = os.path.abspath("output/video_gen")
self.type = "output"
self.prefix_append = ""
methods = {"default": 4, "fastest": 0, "slowest": 6}
@classmethod
def INPUT_TYPES(s):
return {"required":
{"images": ("IMAGE", ),
"filename_prefix": ("STRING", {"default": "ComfyUI"}),
"fps": ("FLOAT", {"default": 24.0, "min": 0.01, "max": 1000.0, "step": 0.01}),
"quality": ("INT", {"default": 80, "min": 0, "max": 100}),
},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
}
RETURN_TYPES = ()
FUNCTION = "save_video"
OUTPUT_NODE = True
CATEGORY = "MemeDeck"
def save_video(self, images, fps, filename_prefix, quality, prompt=None, extra_pnginfo=None):
filename_prefix += self.prefix_append
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(
filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0]
)
results = list()
video_path = os.path.join(full_output_folder, f"{filename}_{counter:05}.mp4")
# Determine video resolution
height, width = images[0].shape[0], images[0].shape[1]  # frames are (H, W, C)
video_writer = cv2.VideoWriter(
video_path,
cv2.VideoWriter_fourcc(*'mp4v'),
fps,
(width, height)
)
# Write each frame to the video
for image in images:
i = 255. * image.cpu().numpy()
frame = np.clip(i, 0, 255).astype(np.uint8)
frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR) # Convert RGB to BGR for OpenCV
video_writer.write(frame)
video_writer.release()
results.append({
"filename": os.path.basename(video_path),
"subfolder": subfolder,
"type": self.type
})
return {"ui": {"videos": results}}

View File

@ -0,0 +1,276 @@
from pathlib import Path
import sys
import time
from typing import Tuple
import requests
import folder_paths
import numpy as np
import torch
from PIL import Image, ImageOps
import cv2
import io
from typing import Tuple
import torch
import subprocess
import torchvision.transforms as transforms
from .lib import image, utils
from .lib.image import pil2tensor
import os
import logging
# setup logger
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
class MD_LoadImageFromUrl:
"""Load an image from the given URL"""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"url": (
"STRING",
{
"default": "https://media.memedeck.xyz/memes/user:08bdc8ed_6015_44f2_9808_7cb54051c666/35c95dfd_b186_4a40_9ef1_ac770f453706.jpeg"
},
),
}
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "load"
CATEGORY = "MemeDeck"
def load(self, url):
# strip out any quote characters
url = url.replace("'", "")
url = url.replace('"', '')
img = Image.open(requests.get(url, stream=True).raw)
img = ImageOps.exif_transpose(img)
return (pil2tensor(img),)
class MD_ImageToMotionPrompt:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"Image": ("IMAGE", {}),
"clip": ("CLIP", {"tooltip": "The CLIP model used for encoding the text."}),
"pre_prompt": (
"STRING",
{
"multiline": False,
"default": "masterpiece, 4k, HDR, cinematic,",
},
),
"prompt": (
"STRING",
{
"multiline": True,
"default": "Respond in a single flowing paragraph. Start with main action in a single sentence. Then add specific details about movements and gestures. Then describe character/object appearances precisely. After that, specify camera angles and movements, static camera motion, or minimal camera motion. Then describe lighting and colors.\nNo more than 200 words.\nAdditional instructions:",
},
),
"negative_prompt": (
"STRING",
{
"multiline": True,
"default": "low quality, worst quality, deformed, distorted, disfigured, motion smear, motion artifacts, unnatural motion, fused fingers, extra limbs, floating away, bad anatomy, weird hand, ugly, disappearing objects, closed captions, cross-eyed",
},
),
"max_tokens": ("INT", {"min": 1, "max": 2048, "default": 200}),
}
}
RETURN_TYPES = ("STRING", "STRING", "CONDITIONING", "CONDITIONING",)
RETURN_NAMES = ("prompt_string", "negative_prompt", "positive_conditioning", "negative_conditioning")
FUNCTION = "generate_completion"
CATEGORY = "MemeDeck"
def generate_completion(
self, pre_prompt: str, Image: torch.Tensor, clip, prompt: str, negative_prompt: str, max_tokens: int
) -> Tuple[str]:
# start a timer
start_time = time.time()
b64image = image.pil2base64(image.tensor2pil(Image))
# send the base64 image and the prompt as JSON to the local inference endpoint (localhost:5010/inference)
response = requests.post("http://127.0.0.1:5010/inference", json={"image_url": f"data:image/jpeg;base64,{b64image}", "prompt": prompt})
if response.status_code != 200:
raise Exception(f"Failed to generate completion: {response.text}")
end_time = time.time()
logger.info(f"Motion prompt took: {end_time - start_time} seconds")
full_prompt = f"{pre_prompt}\n{response.json()['result']}"
pos_tokens = clip.tokenize(full_prompt)
pos_output = clip.encode_from_tokens(pos_tokens, return_pooled=True, return_dict=True)
pos_cond = pos_output.pop("cond")
neg_tokens = clip.tokenize(negative_prompt)
neg_output = clip.encode_from_tokens(neg_tokens, return_pooled=True, return_dict=True)
neg_cond = neg_output.pop("cond")
return (full_prompt, negative_prompt, [[pos_cond, pos_output]], [[neg_cond, neg_output]])
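For reference, the request and response shapes implied by generate_completion above. The field names come from the code; the endpoint behaviour and the placeholder values are assumptions about the local captioning service.

# Shape of the exchange with the local captioning service (inferred from the call above).
request_body = {
    "image_url": "data:image/jpeg;base64,<...>",   # base64-encoded JPEG of the input frame
    "prompt": "<instruction text from the 'prompt' input>",
}
response_body = {
    "result": "<generated motion prompt>",          # appended to pre_prompt before CLIP encoding
}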
class MD_CompressAdjustNode:
"""
Detect compression level and adjust to desired CRF.
"""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"image": ("IMAGE",),
"desired_crf": ("INT", {
"default": 25,
"min": 0,
"max": 51,
"step": 1
}),
},
}
RETURN_TYPES = ("IMAGE", "INT")
RETURN_NAMES = ("adjusted_image", "crf")
FUNCTION = "tensor_to_video_and_back"
CATEGORY = "MemeDeck"
def tensor_to_int(self,tensor, bits):
tensor = tensor.cpu().numpy() * (2**bits-1)
return np.clip(tensor, 0, (2**bits-1))
def tensor_to_bytes(self, tensor):
return self.tensor_to_int(tensor, 8).astype(np.uint8)
def ffmpeg_process(self, args, file_path, env):
res = None
frame_data = yield
total_frames_output = 0
if res != b'':
with subprocess.Popen(args + [file_path], stderr=subprocess.PIPE,
stdin=subprocess.PIPE, env=env) as proc:
try:
while frame_data is not None:
proc.stdin.write(frame_data)
frame_data = yield
total_frames_output+=1
proc.stdin.flush()
proc.stdin.close()
res = proc.stderr.read()
except BrokenPipeError as e:
res = proc.stderr.read()
raise Exception("An error occurred in the ffmpeg subprocess:\n" \
+ res.decode("utf-8"))
yield total_frames_output
if len(res) > 0:
print(res.decode("utf-8"), end="", file=sys.stderr)
def detect_image_clarity(self, image):
# detect the clarity of the image
# return a score between 0 and 100
# 0 is the lowest clarity
# 100 is the highest clarity
return 100
def tensor_to_video_and_back(self, image, desired_crf=30):
temp_dir = "temp_video"
filename = f"frame_{time.time()}".split('.')[0]
os.makedirs(temp_dir, exist_ok=True)
# Convert single image to list if necessary
if len(image.shape) == 3:
image = [image]
first_image = image[0]
has_alpha = first_image.shape[-1] == 4
dim_alignment = 8
if (first_image.shape[1] % dim_alignment) or (first_image.shape[0] % dim_alignment):
# pad the image to the nearest multiple of 8
to_pad = (-first_image.shape[1] % dim_alignment,
-first_image.shape[0] % dim_alignment)
padding = (to_pad[0]//2, to_pad[0] - to_pad[0]//2,
to_pad[1]//2, to_pad[1] - to_pad[1]//2)
padfunc = torch.nn.ReplicationPad2d(padding)
def pad(image):
image = image.permute((2,0,1))#HWC to CHW
padded = padfunc(image.to(dtype=torch.float32))
return padded.permute((1,2,0))
# pad single image
first_image = pad(first_image)
new_dims = (-first_image.shape[1] % dim_alignment + first_image.shape[1],
-first_image.shape[0] % dim_alignment + first_image.shape[0])
dimensions = f"{new_dims[0]}x{new_dims[1]}"
logger.warn("Output images were not of valid resolution and have had padding applied")
else:
dimensions = f"{first_image.shape[1]}x{first_image.shape[0]}"
first_image_bytes = self.tensor_to_bytes(first_image).tobytes()
if has_alpha:
i_pix_fmt = 'rgba'
else:
i_pix_fmt = 'rgb24'
# default bitrate and frame rate
frame_rate = 25
args = [
utils.ffmpeg_path,
"-v", "error",
"-f", "rawvideo",
"-pix_fmt", i_pix_fmt,
"-s", dimensions,
"-r", str(frame_rate),
"-i", "-",
"-y",
"-c:v", "libx264",
"-pix_fmt", "yuv420p",
"-crf", str(desired_crf),
]
video_path = os.path.abspath(str(Path(temp_dir) / f"{filename}.mp4"))
env = os.environ.copy()
output_process = self.ffmpeg_process(args, video_path, env)
# Proceed to first yield
output_process.send(None)
output_process.send(first_image_bytes)
try:
output_process.send(None) # Signal end of input
next(output_process) # Get the final yield
except StopIteration:
pass
time.sleep(0.5)
if not os.path.exists(video_path):
raise FileNotFoundError(f"Video file not created at {video_path}")
# load the video h264 codec
video = cv2.VideoCapture(video_path, cv2.CAP_FFMPEG)
if not video.isOpened():
raise RuntimeError(f"Failed to open video file: {video_path}")
# read the first frame
ret, frame = video.read()
if not ret:
raise RuntimeError("Failed to read frame from video")
video.release()
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
try:
os.remove(video_path)
except OSError as e:
print(f"Warning: Could not remove temporary file {video_path}: {e}")
# convert the frame to a PIL image for ComfyUI
frame = Image.fromarray(frame)
frame_tensor = pil2tensor(frame)
return (frame_tensor, desired_crf)

View File

@ -0,0 +1,6 @@
requests
numpy
torch
Pillow
opencv-python
torchvision

View File

@ -0,0 +1,155 @@
class Example:
"""
An example node
Class methods
-------------
INPUT_TYPES (dict):
Tell the main program input parameters of nodes.
IS_CHANGED:
Optional method to control when the node is re-executed.
Attributes
----------
RETURN_TYPES (`tuple`):
The type of each element in the output tuple.
RETURN_NAMES (`tuple`):
Optional: The name of each output in the output tuple.
FUNCTION (`str`):
The name of the entry-point method. For example, if `FUNCTION = "execute"` then it will run Example().execute()
OUTPUT_NODE ([`bool`]):
If this node is an output node that outputs a result/image from the graph. The SaveImage node is an example.
The backend iterates on these output nodes and tries to execute all their parents if their parent graph is properly connected.
Assumed to be False if not present.
CATEGORY (`str`):
The category the node should appear in the UI.
DEPRECATED (`bool`):
Indicates whether the node is deprecated. Deprecated nodes are hidden by default in the UI, but remain
functional in existing workflows that use them.
EXPERIMENTAL (`bool`):
Indicates whether the node is experimental. Experimental nodes are marked as such in the UI and may be subject to
significant changes or removal in future versions. Use with caution in production workflows.
execute(s) -> tuple || None:
The entry point method. The name of this method must be the same as the value of property `FUNCTION`.
For example, if `FUNCTION = "execute"` then this method's name must be `execute`, if `FUNCTION = "foo"` then it must be `foo`.
"""
def __init__(self):
pass
@classmethod
def INPUT_TYPES(s):
"""
Return a dictionary which contains config for all input fields.
Some types (string): "MODEL", "VAE", "CLIP", "CONDITIONING", "LATENT", "IMAGE", "INT", "STRING", "FLOAT".
Input types "INT", "STRING" or "FLOAT" are special values for fields on the node.
The type can be a list for selection.
Returns: `dict`:
- Key input_fields_group (`string`): Can be either required, hidden or optional. A node class must have property `required`
- Value input_fields (`dict`): Contains input fields config:
* Key field_name (`string`): Name of an entry-point method's argument
* Value field_config (`tuple`):
+ First value is a string indicate the type of field or a list for selection.
+ Second value is a config for type "INT", "STRING" or "FLOAT".
"""
return {
"required": {
"image": ("IMAGE",),
"int_field": ("INT", {
"default": 0,
"min": 0, #Minimum value
"max": 4096, #Maximum value
"step": 64, #Slider's step
"display": "number", # Cosmetic only: display as "number" or "slider"
"lazy": True # Will only be evaluated if check_lazy_status requires it
}),
"float_field": ("FLOAT", {
"default": 1.0,
"min": 0.0,
"max": 10.0,
"step": 0.01,
"round": 0.001, #The value representing the precision to round to, will be set to the step value by default. Can be set to False to disable rounding.
"display": "number",
"lazy": True
}),
"print_to_screen": (["enable", "disable"],),
"string_field": ("STRING", {
"multiline": False, #True if you want the field to look like the one on the ClipTextEncode node
"default": "Hello World!",
"lazy": True
}),
},
}
RETURN_TYPES = ("IMAGE",)
#RETURN_NAMES = ("image_output_name",)
FUNCTION = "test"
#OUTPUT_NODE = False
CATEGORY = "Example"
def check_lazy_status(self, image, string_field, int_field, float_field, print_to_screen):
"""
Return a list of input names that need to be evaluated.
This function will be called if there are any lazy inputs which have not yet been
evaluated. As long as you return at least one field which has not yet been evaluated
(and more exist), this function will be called again once the value of the requested
field is available.
Any evaluated inputs will be passed as arguments to this function. Any unevaluated
inputs will have the value None.
"""
if print_to_screen == "enable":
return ["int_field", "float_field", "string_field"]
else:
return []
def test(self, image, string_field, int_field, float_field, print_to_screen):
if print_to_screen == "enable":
print(f"""Your input contains:
string_field aka input text: {string_field}
int_field: {int_field}
float_field: {float_field}
""")
#do some processing on the image, in this example I just invert it
image = 1.0 - image
return (image,)
"""
The node will always be re-executed if any of the inputs change, but
this method can be used to force the node to execute again even when the inputs don't change.
You can make this node return a number or a string. This value will be compared to the one returned the last time the node was
executed, if it is different the node will be executed again.
This method is used in the core repo for the LoadImage node where they return the image hash as a string, if the image hash
changes between executions the LoadImage node is executed again.
"""
#@classmethod
#def IS_CHANGED(s, image, string_field, int_field, float_field, print_to_screen):
# return ""
# Set the web directory, any .js file in that directory will be loaded by the frontend as a frontend extension
# WEB_DIRECTORY = "./somejs"
# Add custom API routes, using router
from aiohttp import web
from server import PromptServer
@PromptServer.instance.routes.get("/hello")
async def get_hello(request):
return web.json_response("hello")
# A dictionary that contains all nodes you want to export with their names
# NOTE: names should be globally unique
NODE_CLASS_MAPPINGS = {
"Example": Example
}
# A dictionary that contains the friendly/humanly readable titles for the nodes
NODE_DISPLAY_NAME_MAPPINGS = {
"Example": "Example Node"
}

Binary file not shown.


View File

@ -5,7 +5,6 @@ import logging
import uuid
from PIL import Image, ImageOps
from functools import partial
import pika
import json
@ -63,6 +62,13 @@ class MemedeckWorker:
self.api_key = os.getenv('API_KEY') or 'eb46e20a-cc25-4ed4-a39b-f47ca8ff3383'
self.training_only = os.getenv('TRAINING_ONLY') or False
self.video_gen_only = False
self.azure_storage = MemedeckAzureStorage()
if self.queue_name == 'video-gen-queue':
print(f"[memedeck]: video gen only mode enabled")
self.video_gen_only = True
if self.training_only:
self.queue_name = 'training-queue'
@ -145,10 +151,35 @@ class MemedeckWorker:
valid = self.validate_prompt(prompt)
routing_key = method.routing_key
workflow = 'training' if self.training_only else 'faceswap' if routing_key == 'faceswap-queue' else 'generation'
workflow = 'faceswap' if routing_key == 'faceswap-queue' else 'generation'
user_id = None
if self.video_gen_only:
workflow = 'video_gen'
user_id = payload["user_id"]
if self.training_only:
workflow = 'training'
end_node_id = None
video_source_image_id = None
# Find the end_node_id
if not self.video_gen_only and not self.training_only:
for node in prompt:
if isinstance(prompt[node], dict) and prompt[node].get("class_type") == "SaveImageWebsocket":
end_node_id = node
break
elif self.video_gen_only:
end_node_id = payload['end_node_id']
video_source_image_id = payload['image_id']
elif self.training_only:
end_node_id = "130"
self.logger.info(f"[memedeck]: end_node_id: {end_node_id}")
# Prepare task_info
prompt_id = str(uuid.uuid4())
prompt_id = str(uuid.uuid4()).replace("-", "_")
outputs_to_execute = valid[2]
task_info = {
"workflow": workflow,
@ -157,7 +188,7 @@ class MemedeckWorker:
"outputs_to_execute": outputs_to_execute,
"client_id": "memedeck-1",
"is_memedeck": True,
"end_node_id": "130" if self.training_only else None,
"end_node_id": end_node_id,
"ws_id": payload["source_ws_id"],
"context": payload["req_ctx"] if "req_ctx" in payload else {},
"current_node": None,
@ -181,15 +212,12 @@ class MemedeckWorker:
"59": "loop-3",
"60": "loop-4",
},
"training_filename": payload['filename'] if 'filename' in payload else None
"training_filename": payload['filename'] if 'filename' in payload else None,
# video data
"image_id": video_source_image_id,
"user_id": user_id,
}
# Find the end_node_id
for node in prompt:
if isinstance(prompt[node], dict) and prompt[node].get("class_type") == "SaveImageWebsocket":
task_info['end_node_id'] = node
break
if valid[0]:
# Enqueue the task into the internal job queue
self.loop.call_soon_threadsafe(self.internal_job_queue.put_nowait, (prompt_id, prompt, task_info))
@ -218,7 +246,7 @@ class MemedeckWorker:
'context': task_info['context']
}, task_info['outputs_to_execute']))
if task_info['workflow'] != 'training':
if task_info['workflow'] != 'training' and task_info['workflow'] != 'video_gen':
if 'faceswap_strength' not in task_info['context']['prompt_config']:
# pretty print the prompt config
self.logger.info(f"[memedeck]: prompt: {task_info['context']['prompt_config']['character']['id']} {task_info['context']['prompt_config']['positive_prompt']}")
@ -279,6 +307,8 @@ class MemedeckWorker:
if task['workflow'] == 'training':
return await self.handle_training_send(event=event, task=task, sid=sid, data=data)
if task['workflow'] == 'video_gen':
return await self.handle_video_gen_send(event=event, task=task, sid=sid, data=data)
# this logic is for generation and faceswap (todo move to separate function)
if event == MemedeckWorker.BinaryEventTypes.UNENCODED_PREVIEW_IMAGE:
@ -340,7 +370,6 @@ class MemedeckWorker:
# if event == "progress":
# self.logger.info(f"[memedeck]: training progress: {data}")
if event == "executing":
training_progress = task['training_loop'][data['node']] if data['node'] in task['training_loop'] else None
status = task['training_status'][data['node']] if data['node'] in task['training_status'] else None
@ -362,9 +391,54 @@ class MemedeckWorker:
})
# training task is done
del self.tasks_by_ws_id[sid]
async def handle_video_gen_send(self, event, task, sid, data):
if event == "progress":
if data['max'] > 1:
progress = data['value']
max_progress = data['max']
# calculate the percentage
percentage = (progress / max_progress)
await self.send_to_api({
"ws_id": sid,
"image_id": task['image_id'],
"user_id": task['user_id'],
"status": "generating",
"progress": percentage * 0.9 # 90% of the progress is the gen step, 10% is the video encode step
})
if event == "executed":
if data['node'] == task['end_node_id']:
filename = data['output']['images'][0]['filename']
current_dir = os.path.dirname(os.path.abspath(__file__))
file_path = os.path.join(current_dir, "output", filename)
blob_name = f"{task['user_id']}/video_gen/video_{task['image_id'].replace('image:', '')}_{task['prompt_id']}.webp"
# TODO: take the file path and upload to azure blob storage
# load image bytes
with open(file_path, "rb") as image_file:
image_bytes = image_file.read()
self.logger.info(f"[memedeck]: video gen completed for {sid}, file={file_path}, blob={blob_name}")
url = await self.azure_storage.save_image(blob_name, "image/webp", image_bytes)
self.logger.info(f"[memedeck]: video gen completed for {sid}, {url}")
await self.send_to_api({
"ws_id": sid,
"progress": 1.0,
"image_id": task['image_id'],
"user_id": task['user_id'],
"status": "completed",
"output_video_url": url
})
# video gen task is done
del self.tasks_by_ws_id[sid]
async def send_preview(self, image_data, sid=None, progress=None, context=None, workflow=None):
self.logger.info(f"[memedeck]: send_preview: {sid}")
if sid is None:
self.logger.warning("Received preview without sid")
return
@ -406,6 +480,8 @@ class MemedeckWorker:
"context": context
}
self.logger.info(f"[memedeck]: progress kind: {kind}")
self.logger.info(f"[memedeck]: progress: {progress}")
# set the kind to faceswap_generated if workflow is faceswap
if workflow == 'faceswap':
ai_queue_progress['kind'] = "faceswap_generated"
@ -430,7 +506,13 @@ class MemedeckWorker:
self.logger.error(f"[memedeck]: end_node_id is None for {ws_id}")
return
api_endpoint = '/generation/update' if task['workflow'] != 'training' else '/training/update'
api_endpoint = '/generation/update'
if task['workflow'] == 'training':
api_endpoint = '/training/update'
if task['workflow'] == 'video_gen':
api_endpoint = '/generation/video/update'
# self.logger.info(f"[memedeck]: sending to api: {api_endpoint}")
# self.logger.info(f"[memedeck]: data: {data}")
try:
@ -443,250 +525,188 @@ class MemedeckWorker:
# --------------------------------------------------------------------------
# MemedeckAzureStorage
# --------------------------------------------------------------------------
# from azure.storage.blob.aio import BlobClient, BlobServiceClient
# from azure.storage.blob import ContentSettings
# from typing import Optional, Tuple
# import cairosvg
from azure.storage.blob.aio import BlobClient, BlobServiceClient
from azure.storage.blob import ContentSettings
from typing import Optional, Tuple
import cairosvg
# WATERMARK = '<svg width="256" height="256" viewBox="0 0 256 256" fill="none" xmlns="http://www.w3.org/2000/svg"><path d="M60.0859 196.8C65.9526 179.067 71.5526 161.667 76.8859 144.6C79.1526 137.4 81.4859 129.867 83.8859 122C86.2859 114.133 88.6859 106.333 91.0859 98.6C93.4859 90.8667 95.6859 83.4 97.6859 76.2C99.8193 69 101.686 62.3333 103.286 56.2C110.619 56.2 117.553 55.8 124.086 55C130.619 54.2 137.686 53.4667 145.286 52.8C144.886 55.7333 144.419 59.0667 143.886 62.8C143.486 66.4 142.953 70.2 142.286 74.2C141.753 78.2 141.153 82.3333 140.486 86.6C139.819 90.8667 139.019 96.3333 138.086 103C137.153 109.667 135.886 118 134.286 128H136.886C140.753 117.867 143.953 109.467 146.486 102.8C149.019 96 151.086 90.4667 152.686 86.2C154.286 81.9333 155.886 77.8 157.486 73.8C159.219 69.6667 160.819 65.8 162.286 62.2C163.886 58.4667 165.353 55.2 166.686 52.4C170.019 52.1333 173.153 51.8 176.086 51.4C179.019 51 181.953 50.6 184.886 50.2C187.819 49.6667 190.753 49.2 193.686 48.8C196.753 48.2667 200.086 47.6667 203.686 47C202.353 54.7333 201.086 62.6667 199.886 70.8C198.686 78.9333 197.619 87.0667 196.686 95.2C195.753 103.2 194.819 111.133 193.886 119C193.086 126.867 192.353 134.333 191.686 141.4C190.086 157.933 188.686 174.067 187.486 189.8L152.686 196C152.686 195.333 152.753 193.533 152.886 190.6C153.153 187.667 153.419 184.067 153.686 179.8C154.086 175.533 154.553 170.8 155.086 165.6C155.753 160.4 156.353 155.2 156.886 150C157.553 144.8 158.219 139.8 158.886 135C159.553 130.067 160.219 125.867 160.886 122.4H159.086C157.219 128 155.153 133.933 152.886 140.2C150.619 146.333 148.286 152.6 145.886 159C143.619 165.4 141.353 171.667 139.086 177.8C136.819 183.933 134.819 189.8 133.086 195.4C128.419 195.533 124.419 195.733 121.086 196C117.753 196.133 113.886 196.333 109.486 196.6L115.886 122.4H112.886C112.619 124.133 111.953 127.067 110.886 131.2C109.819 135.2 108.553 139.867 107.086 145.2C105.753 150.4 104.286 155.867 102.686 161.6C101.086 167.2 99.5526 172.467 98.0859 177.4C96.7526 182.2 95.6193 186.2 94.6859 189.4C93.7526 192.467 93.2193 194.2 93.0859 194.6L60.0859 196.8Z" fill="white"/></svg>'
# WATERMARK_SIZE = 40
WATERMARK = '<svg width="256" height="256" viewBox="0 0 256 256" fill="none" xmlns="http://www.w3.org/2000/svg"><path d="M60.0859 196.8C65.9526 179.067 71.5526 161.667 76.8859 144.6C79.1526 137.4 81.4859 129.867 83.8859 122C86.2859 114.133 88.6859 106.333 91.0859 98.6C93.4859 90.8667 95.6859 83.4 97.6859 76.2C99.8193 69 101.686 62.3333 103.286 56.2C110.619 56.2 117.553 55.8 124.086 55C130.619 54.2 137.686 53.4667 145.286 52.8C144.886 55.7333 144.419 59.0667 143.886 62.8C143.486 66.4 142.953 70.2 142.286 74.2C141.753 78.2 141.153 82.3333 140.486 86.6C139.819 90.8667 139.019 96.3333 138.086 103C137.153 109.667 135.886 118 134.286 128H136.886C140.753 117.867 143.953 109.467 146.486 102.8C149.019 96 151.086 90.4667 152.686 86.2C154.286 81.9333 155.886 77.8 157.486 73.8C159.219 69.6667 160.819 65.8 162.286 62.2C163.886 58.4667 165.353 55.2 166.686 52.4C170.019 52.1333 173.153 51.8 176.086 51.4C179.019 51 181.953 50.6 184.886 50.2C187.819 49.6667 190.753 49.2 193.686 48.8C196.753 48.2667 200.086 47.6667 203.686 47C202.353 54.7333 201.086 62.6667 199.886 70.8C198.686 78.9333 197.619 87.0667 196.686 95.2C195.753 103.2 194.819 111.133 193.886 119C193.086 126.867 192.353 134.333 191.686 141.4C190.086 157.933 188.686 174.067 187.486 189.8L152.686 196C152.686 195.333 152.753 193.533 152.886 190.6C153.153 187.667 153.419 184.067 153.686 179.8C154.086 175.533 154.553 170.8 155.086 165.6C155.753 160.4 156.353 155.2 156.886 150C157.553 144.8 158.219 139.8 158.886 135C159.553 130.067 160.219 125.867 160.886 122.4H159.086C157.219 128 155.153 133.933 152.886 140.2C150.619 146.333 148.286 152.6 145.886 159C143.619 165.4 141.353 171.667 139.086 177.8C136.819 183.933 134.819 189.8 133.086 195.4C128.419 195.533 124.419 195.733 121.086 196C117.753 196.133 113.886 196.333 109.486 196.6L115.886 122.4H112.886C112.619 124.133 111.953 127.067 110.886 131.2C109.819 135.2 108.553 139.867 107.086 145.2C105.753 150.4 104.286 155.867 102.686 161.6C101.086 167.2 99.5526 172.467 98.0859 177.4C96.7526 182.2 95.6193 186.2 94.6859 189.4C93.7526 192.467 93.2193 194.2 93.0859 194.6L60.0859 196.8Z" fill="white"/></svg>'
WATERMARK_SIZE = 40
# class MemedeckAzureStorage:
# def __init__(self, connection_string):
# # get environment variables
# self.storage_account = os.getenv('STORAGE_ACCOUNT')
# self.storage_access_key = os.getenv('STORAGE_ACCESS_KEY')
# self.storage_container = os.getenv('STORAGE_CONTAINER')
# self.logger = logging.getLogger(__name__)
class MemedeckAzureStorage:
def __init__(self):
# get environment variables
self.account = os.getenv('STORAGE_ACCOUNT')
self.access_key = os.getenv('STORAGE_ACCESS_KEY')
self.container = os.getenv('STORAGE_CONTAINER')
self.logger = logging.getLogger(__name__)
# self.blob_service_client = BlobServiceClient.from_connection_string(conn_str=connection_string)
if not all([self.account, self.access_key, self.container]):
raise EnvironmentError("Missing STORAGE_ACCOUNT, STORAGE_ACCESS_KEY, or STORAGE_CONTAINER environment variables")
# async def upload_image(
# self,
# by: str,
# image_id: str,
# source_url: Optional[str],
# bytes_data: Optional[bytes],
# filetype: Optional[str],
# ) -> Tuple[str, Tuple[int, int]]:
# """
# Uploads an image to Azure Blob Storage.
# Initialize BlobServiceClient
self.blob_service_client = BlobServiceClient(
account_url=f"https://{self.account}.blob.core.windows.net",
credential=self.access_key
)
self.logger.info(f"[memedeck]: Azure Storage connected.")
# Args:
# by (str): Identifier for the uploader.
# image_id (str): Unique identifier for the image.
# source_url (Optional[str]): URL to fetch the image from.
# bytes_data (Optional[bytes]): Image data in bytes.
# filetype (Optional[str]): Desired file type (e.g., 'jpeg', 'png').
async def save_image(
self,
blob_name: str,
content_type: str,
bytes_data: bytes
) -> str:
"""
Saves image bytes to Azure Blob Storage.
# Returns:
# Tuple[str, Tuple[int, int]]: URL of the uploaded image and its dimensions.
# """
# # Retrieve image bytes either from the provided bytes_data or by fetching from source_url
# if source_url is None:
# if bytes_data is None:
# raise ValueError("Could not get image bytes")
# image_bytes = bytes_data
# else:
# self.logger.info(f"Requesting image from URL: {source_url}")
# async with aiohttp.ClientSession() as session:
# try:
# async with session.get(source_url) as response:
# if response.status != 200:
# raise Exception(f"Failed to fetch image, status code {response.status}")
# image_bytes = await response.read()
# except Exception as e:
# raise Exception(f"Error fetching image from URL: {e}")
Args:
blob_name (str): Name of the blob in Azure Storage.
content_type (str): MIME type of the content.
bytes_data (bytes): Image data in bytes.
# # Open image using Pillow to get dimensions and format
# try:
# img = Image.open(BytesIO(image_bytes))
# width, height = img.size
# inferred_filetype = img.format.lower()
# except Exception as e:
# raise Exception(f"Failed to decode image: {e}")
Returns:
str: URL of the uploaded blob.
"""
# # Determine the final file type
# final_filetype = filetype.lower() if filetype else inferred_filetype
blob_client = self.blob_service_client.get_blob_client(container=self.container, blob=blob_name)
# # Construct the blob name
# blob_name = f"{by}/{image_id.replace('image:', '')}.{final_filetype}"
# Upload the blob
try:
await blob_client.upload_blob(
bytes_data,
overwrite=True,
content_settings=ContentSettings(content_type=content_type)
)
except Exception as e:
raise Exception(f"Failed to upload blob: {e}")
# close the blob client
await blob_client.close()
# # Upload the image to Azure Blob Storage
# try:
# image_url = await self.save_image(blob_name, img.format, image_bytes)
# return image_url, (width, height)
# except Exception as e:
# self.logger.error(f"Trouble saving image: {e}")
# raise Exception(f"Trouble saving image: {e}")
# Construct and return the blob URL
blob_url = f"https://media.memedeck.xyz/{self.container}/{blob_name}"
return blob_url
# async def save_image(
#     self,
#     blob_name: str,
#     content_type: str,
#     bytes_data: bytes
# ) -> str:
#     """
#     Saves image bytes to Azure Blob Storage.
#
#     Args:
#         blob_name (str): Name of the blob in Azure Storage.
#         content_type (str): MIME type of the content.
#         bytes_data (bytes): Image data in bytes.
#
#     Returns:
#         str: URL of the uploaded blob.
#     """
#     # Retrieve environment variables
#     account = os.getenv("STORAGE_ACCOUNT")
#     access_key = os.getenv("STORAGE_ACCESS_KEY")
#     container = os.getenv("STORAGE_CONTAINER")
#
#     if not all([account, access_key, container]):
#         raise EnvironmentError("Missing STORAGE_ACCOUNT, STORAGE_ACCESS_KEY, or STORAGE_CONTAINER environment variables")
#
#     # Initialize BlobServiceClient
#     blob_service_client = BlobServiceClient(
#         account_url=f"https://{account}.blob.core.windows.net",
#         credential=access_key
#     )
#     blob_client = blob_service_client.get_blob_client(container=container, blob=blob_name)
#
#     # Upload the blob
#     try:
#         await blob_client.upload_blob(
#             bytes_data,
#             overwrite=True,
#             content_settings=ContentSettings(content_type=content_type)
#         )
#     except Exception as e:
#         raise Exception(f"Failed to upload blob: {e}")
#
#     self.logger.debug(f"Blob uploaded: name={blob_name}, content_type={content_type}")
#
#     # Construct and return the blob URL
#     blob_url = f"https://media.memedeck.xyz//{container}/{blob_name}"
#     return blob_url

# async def add_watermark(
#     self,
#     base_blob_name: str,
#     base_image: bytes
# ) -> str:
#     """
#     Adds a watermark to the provided image and uploads the watermarked image.
#
#     Args:
#         base_blob_name (str): Original blob name of the image.
#         base_image (bytes): Image data in bytes.
#
#     Returns:
#         str: URL of the watermarked image.
#     """
#     # Load the input image
#     try:
#         img = Image.open(BytesIO(base_image)).convert("RGBA")
#     except Exception as e:
#         raise Exception(f"Failed to load image: {e}")
#
#     # Calculate position for the watermark (bottom right corner with padding)
#     padding = 12
#     x = img.width - WATERMARK_SIZE - padding
#     y = img.height - WATERMARK_SIZE - padding
#
#     # Analyze background brightness where the watermark will be placed
#     background_brightness = self.analyze_background_brightness(img, x, y, WATERMARK_SIZE)
#     self.logger.info(f"Background brightness: {background_brightness}")
#
#     # Render SVG watermark to PNG bytes using cairosvg
#     try:
#         watermark_png_bytes = cairosvg.svg2png(bytestring=WATERMARK.encode('utf-8'), output_width=WATERMARK_SIZE, output_height=WATERMARK_SIZE)
#         watermark = Image.open(BytesIO(watermark_png_bytes)).convert("RGBA")
#     except Exception as e:
#         raise Exception(f"Failed to render watermark SVG: {e}")
#
#     # Determine watermark color based on background brightness
#     if background_brightness > 128:
#         # Dark watermark for light backgrounds
#         watermark_color = (0, 0, 0, int(255 * 0.65))  # Black with 65% opacity
#     else:
#         # Light watermark for dark backgrounds
#         watermark_color = (255, 255, 255, int(255 * 0.65))  # White with 65% opacity
#
#     # Apply the watermark color by blending
#     solid_color = Image.new("RGBA", watermark.size, watermark_color)
#     watermark = Image.alpha_composite(watermark, solid_color)
#
#     # Overlay the watermark onto the original image
#     img.paste(watermark, (x, y), watermark)
#
#     # Save the watermarked image to bytes
#     buffer = BytesIO()
#     img = img.convert("RGB")  # Convert back to RGB for JPEG format
#     img.save(buffer, format="JPEG")
#     buffer.seek(0)
#     jpeg_bytes = buffer.read()
#
#     # Modify the blob name to include '_watermarked'
#     try:
#         if "memes/" in base_blob_name:
#             base_blob_name_right = base_blob_name.split("memes/", 1)[1]
#         else:
#             base_blob_name_right = base_blob_name
#         base_blob_name_split = base_blob_name_right.rsplit(".", 1)
#         base_blob_name_without_extension = base_blob_name_split[0]
#         extension = base_blob_name_split[1]
#     except Exception as e:
#         raise Exception(f"Failed to process blob name: {e}")
#
#     watermarked_blob_name = f"{base_blob_name_without_extension}_watermarked.{extension}"
#
#     # Upload the watermarked image
#     try:
#         watermarked_blob_url = await self.save_image(
#             watermarked_blob_name,
#             "image/jpeg",
#             jpeg_bytes
#         )
#         return watermarked_blob_url
#     except Exception as e:
#         raise Exception(f"Failed to upload watermarked image: {e}")

# def analyze_background_brightness(
#     self,
#     img: Image.Image,
#     x: int,
#     y: int,
#     size: int
# ) -> int:
#     """
#     Analyzes the brightness of a specific region in the image.
#
#     Args:
#         img (Image.Image): The image to analyze.
#         x (int): X-coordinate of the top-left corner of the region.
#         y (int): Y-coordinate of the top-left corner of the region.
#         size (int): Size of the square region to analyze.
#
#     Returns:
#         int: Average brightness (0-255) of the region.
#     """
#     # Crop the specified region
#     sub_image = img.crop((x, y, x + size, y + size)).convert("RGB")
#
#     # Calculate average brightness using the luminance formula
#     total_brightness = 0
#     pixel_count = 0
#     for pixel in sub_image.getdata():
#         r, g, b = pixel
#         brightness = (r * 299 + g * 587 + b * 114) // 1000
#         total_brightness += brightness
#         pixel_count += 1
#
#     if pixel_count == 0:
#         return 0
#
#     average_brightness = total_brightness // pixel_count
#     return average_brightness
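
# Note: if analyze_background_brightness is ever re-enabled, the per-pixel loop above
# could likely be vectorized. A minimal sketch, assuming numpy is imported as np in this
# module (an assumption; only the package's utils file is known to import numpy):
#
#   region = np.asarray(sub_image, dtype=np.uint32)   # (H, W, 3) RGB array
#   weighted = (region[..., 0] * 299 + region[..., 1] * 587 + region[..., 2] * 114) // 1000
#   average_brightness = int(weighted.mean()) if weighted.size else 0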

File diff suppressed because one or more lines are too long

View File

@ -1,4 +0,0 @@
# Config for testing nodes
testing:
custom_nodes: tests/inference/testing_nodes