Merge branch 'comfyanonymous:master' into master

This commit is contained in:
Red 2025-03-17 07:55:05 -04:00 committed by GitHub
commit 6ee74beec4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 197 additions and 7 deletions

View File

@ -9,6 +9,7 @@ import comfy.model_patcher
import comfy.model_management
import comfy.utils
import comfy.clip_model
import comfy.image_encoders.dino2
class Output:
def __getitem__(self, key):
@ -34,6 +35,11 @@ def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], s
image = torch.clip((255. * image), 0, 255).round() / 255.0
return (image - mean.view([3,1,1])) / std.view([3,1,1])
IMAGE_ENCODERS = {
"clip_vision": comfy.clip_model.CLIPVisionModelProjection,
"dinov2": comfy.image_encoders.dino2.Dinov2Model,
}
class ClipVisionModel():
def __init__(self, json_config):
with open(json_config) as f:
@ -42,10 +48,11 @@ class ClipVisionModel():
self.image_size = config.get("image_size", 224)
self.image_mean = config.get("image_mean", [0.48145466, 0.4578275, 0.40821073])
self.image_std = config.get("image_std", [0.26862954, 0.26130258, 0.27577711])
model_class = IMAGE_ENCODERS.get(config.get("model_type", "clip_vision"))
self.load_device = comfy.model_management.text_encoder_device()
offload_device = comfy.model_management.text_encoder_offload_device()
self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
self.model = comfy.clip_model.CLIPVisionModelProjection(config, self.dtype, offload_device, comfy.ops.manual_cast)
self.model = model_class(config, self.dtype, offload_device, comfy.ops.manual_cast)
self.model.eval()
self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
@ -111,6 +118,8 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336.json")
else:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json")
elif "embeddings.patch_embeddings.projection.weight" in sd:
json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_giant.json")
else:
return None

View File

@ -0,0 +1,141 @@
import torch
from comfy.text_encoders.bert import BertAttention
import comfy.model_management
from comfy.ldm.modules.attention import optimized_attention_for_device
class Dino2AttentionOutput(torch.nn.Module):
def __init__(self, input_dim, output_dim, layer_norm_eps, dtype, device, operations):
super().__init__()
self.dense = operations.Linear(input_dim, output_dim, dtype=dtype, device=device)
def forward(self, x):
return self.dense(x)
class Dino2AttentionBlock(torch.nn.Module):
def __init__(self, embed_dim, heads, layer_norm_eps, dtype, device, operations):
super().__init__()
self.attention = BertAttention(embed_dim, heads, dtype, device, operations)
self.output = Dino2AttentionOutput(embed_dim, embed_dim, layer_norm_eps, dtype, device, operations)
def forward(self, x, mask, optimized_attention):
return self.output(self.attention(x, mask, optimized_attention))
class LayerScale(torch.nn.Module):
def __init__(self, dim, dtype, device, operations):
super().__init__()
self.lambda1 = torch.nn.Parameter(torch.empty(dim, device=device, dtype=dtype))
def forward(self, x):
return x * comfy.model_management.cast_to_device(self.lambda1, x.device, x.dtype)
class SwiGLUFFN(torch.nn.Module):
def __init__(self, dim, dtype, device, operations):
super().__init__()
in_features = out_features = dim
hidden_features = int(dim * 4)
hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
self.weights_in = operations.Linear(in_features, 2 * hidden_features, bias=True, device=device, dtype=dtype)
self.weights_out = operations.Linear(hidden_features, out_features, bias=True, device=device, dtype=dtype)
def forward(self, x):
x = self.weights_in(x)
x1, x2 = x.chunk(2, dim=-1)
x = torch.nn.functional.silu(x1) * x2
return self.weights_out(x)
class Dino2Block(torch.nn.Module):
def __init__(self, dim, num_heads, layer_norm_eps, dtype, device, operations):
super().__init__()
self.attention = Dino2AttentionBlock(dim, num_heads, layer_norm_eps, dtype, device, operations)
self.layer_scale1 = LayerScale(dim, dtype, device, operations)
self.layer_scale2 = LayerScale(dim, dtype, device, operations)
self.mlp = SwiGLUFFN(dim, dtype, device, operations)
self.norm1 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
self.norm2 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
def forward(self, x, optimized_attention):
x = x + self.layer_scale1(self.attention(self.norm1(x), None, optimized_attention))
x = x + self.layer_scale2(self.mlp(self.norm2(x)))
return x
class Dino2Encoder(torch.nn.Module):
def __init__(self, dim, num_heads, layer_norm_eps, num_layers, dtype, device, operations):
super().__init__()
self.layer = torch.nn.ModuleList([Dino2Block(dim, num_heads, layer_norm_eps, dtype, device, operations) for _ in range(num_layers)])
def forward(self, x, intermediate_output=None):
optimized_attention = optimized_attention_for_device(x.device, False, small_input=True)
if intermediate_output is not None:
if intermediate_output < 0:
intermediate_output = len(self.layer) + intermediate_output
intermediate = None
for i, l in enumerate(self.layer):
x = l(x, optimized_attention)
if i == intermediate_output:
intermediate = x.clone()
return x, intermediate
class Dino2PatchEmbeddings(torch.nn.Module):
def __init__(self, dim, num_channels=3, patch_size=14, image_size=518, dtype=None, device=None, operations=None):
super().__init__()
self.projection = operations.Conv2d(
in_channels=num_channels,
out_channels=dim,
kernel_size=patch_size,
stride=patch_size,
bias=True,
dtype=dtype,
device=device
)
def forward(self, pixel_values):
return self.projection(pixel_values).flatten(2).transpose(1, 2)
class Dino2Embeddings(torch.nn.Module):
def __init__(self, dim, dtype, device, operations):
super().__init__()
patch_size = 14
image_size = 518
self.patch_embeddings = Dino2PatchEmbeddings(dim, patch_size=patch_size, image_size=image_size, dtype=dtype, device=device, operations=operations)
self.position_embeddings = torch.nn.Parameter(torch.empty(1, (image_size // patch_size) ** 2 + 1, dim, dtype=dtype, device=device))
self.cls_token = torch.nn.Parameter(torch.empty(1, 1, dim, dtype=dtype, device=device))
self.mask_token = torch.nn.Parameter(torch.empty(1, dim, dtype=dtype, device=device))
def forward(self, pixel_values):
x = self.patch_embeddings(pixel_values)
# TODO: mask_token?
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
x = x + comfy.model_management.cast_to_device(self.position_embeddings, x.device, x.dtype)
return x
class Dinov2Model(torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
num_layers = config_dict["num_hidden_layers"]
dim = config_dict["hidden_size"]
heads = config_dict["num_attention_heads"]
layer_norm_eps = config_dict["layer_norm_eps"]
self.embeddings = Dino2Embeddings(dim, dtype, device, operations)
self.encoder = Dino2Encoder(dim, heads, layer_norm_eps, num_layers, dtype, device, operations)
self.layernorm = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
def forward(self, pixel_values, attention_mask=None, intermediate_output=None):
x = self.embeddings(pixel_values)
x, i = self.encoder(x, intermediate_output=intermediate_output)
x = self.layernorm(x)
pooled_output = x[:, 0, :]
return x, i, pooled_output, None

View File

@ -0,0 +1,21 @@
{
"attention_probs_dropout_prob": 0.0,
"drop_path_rate": 0.0,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.0,
"hidden_size": 1536,
"image_size": 518,
"initializer_range": 0.02,
"layer_norm_eps": 1e-06,
"layerscale_value": 1.0,
"mlp_ratio": 4,
"model_type": "dinov2",
"num_attention_heads": 24,
"num_channels": 3,
"num_hidden_layers": 40,
"patch_size": 14,
"qkv_bias": true,
"use_swiglu_ffn": true,
"image_mean": [0.485, 0.456, 0.406],
"image_std": [0.229, 0.224, 0.225]
}

View File

@ -1419,6 +1419,6 @@ def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None
old_denoised_d = denoised_d
if s_noise != 0 and sigmas[i + 1] > 0:
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * (sigmas[i + 1] ** 2 - sigmas[i] ** 2 * r ** 2).sqrt()
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * (sigmas[i + 1] ** 2 - sigmas[i] ** 2 * r ** 2).sqrt().nan_to_num(nan=0.0)
old_denoised = denoised
return x

View File

@ -10,8 +10,8 @@ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor:
q_shape = q.shape
k_shape = k.shape
q = q.float().reshape(*q.shape[:-1], -1, 1, 2)
k = k.float().reshape(*k.shape[:-1], -1, 1, 2)
q = q.to(dtype=pe.dtype).reshape(*q.shape[:-1], -1, 1, 2)
k = k.to(dtype=pe.dtype).reshape(*k.shape[:-1], -1, 1, 2)
q = (pe[..., 0] * q[..., 0] + pe[..., 1] * q[..., 1]).reshape(*q_shape).type_as(v)
k = (pe[..., 0] * k[..., 0] + pe[..., 1] * k[..., 1]).reshape(*k_shape).type_as(v)
@ -36,8 +36,8 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
xq_ = xq.to(dtype=freqs_cis.dtype).reshape(*xq.shape[:-1], -1, 1, 2)
xk_ = xk.to(dtype=freqs_cis.dtype).reshape(*xk.shape[:-1], -1, 1, 2)
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)

View File

@ -747,6 +747,7 @@ class ModelPatcher:
def partially_unload(self, device_to, memory_to_free=0):
with self.use_ejected():
hooks_unpatched = False
memory_freed = 0
patch_counter = 0
unload_list = self._load_list()
@ -770,6 +771,10 @@ class ModelPatcher:
move_weight = False
break
if not hooks_unpatched:
self.unpatch_hooks()
hooks_unpatched = True
if bk.inplace_update:
comfy.utils.copy_to_param(self.model, key, bk.weight)
else:

View File

@ -440,6 +440,10 @@ class VAE:
self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
logging.info("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype))
def throw_exception_if_invalid(self):
if self.first_stage_model is None:
raise RuntimeError("ERROR: VAE is invalid: None\n\nIf the VAE is from a checkpoint loader node your checkpoint does not contain a valid VAE.")
def vae_encode_crop_pixels(self, pixels):
downscale_ratio = self.spacial_compression_encode()
@ -495,6 +499,7 @@ class VAE:
return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, index_formulas=self.downscale_index_formula, output_device=self.output_device)
def decode(self, samples_in):
self.throw_exception_if_invalid()
pixel_samples = None
try:
memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
@ -525,6 +530,7 @@ class VAE:
return pixel_samples
def decode_tiled(self, samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None):
self.throw_exception_if_invalid()
memory_used = self.memory_used_decode(samples.shape, self.vae_dtype) #TODO: calculate mem required for tile
model_management.load_models_gpu([self.patcher], memory_required=memory_used)
dims = samples.ndim - 2
@ -553,6 +559,7 @@ class VAE:
return output.movedim(1, -1)
def encode(self, pixel_samples):
self.throw_exception_if_invalid()
pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
pixel_samples = pixel_samples.movedim(-1, 1)
if self.latent_dim == 3 and pixel_samples.ndim < 5:
@ -585,6 +592,7 @@ class VAE:
return samples
def encode_tiled(self, pixel_samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None):
self.throw_exception_if_invalid()
pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
dims = self.latent_dim
pixel_samples = pixel_samples.movedim(-1, 1)
@ -899,7 +907,12 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
model_config = model_detection.model_config_from_unet(sd, diffusion_model_prefix, metadata=metadata)
if model_config is None:
return None
logging.warning("Warning, This is not a checkpoint file, trying to load it as a diffusion model only.")
diffusion_model = load_diffusion_model_state_dict(sd, model_options={})
if diffusion_model is None:
return None
return (diffusion_model, None, VAE(sd={}), None) # The VAE object is there to throw an exception if it's actually used'
unet_weight_dtype = list(model_config.supported_inference_dtypes)
if model_config.scaled_fp8 is not None:

View File

@ -770,6 +770,7 @@ class VAELoader:
vae_path = folder_paths.get_full_path_or_raise("vae", vae_name)
sd = comfy.utils.load_torch_file(vae_path)
vae = comfy.sd.VAE(sd=sd)
vae.throw_exception_if_invalid()
return (vae,)
class ControlNetLoader: