mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-04-19 10:53:29 +00:00
Merge branch 'comfyanonymous:master' into master
This commit is contained in:
commit
6ee74beec4
@ -9,6 +9,7 @@ import comfy.model_patcher
|
||||
import comfy.model_management
|
||||
import comfy.utils
|
||||
import comfy.clip_model
|
||||
import comfy.image_encoders.dino2
|
||||
|
||||
class Output:
|
||||
def __getitem__(self, key):
|
||||
@ -34,6 +35,11 @@ def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], s
|
||||
image = torch.clip((255. * image), 0, 255).round() / 255.0
|
||||
return (image - mean.view([3,1,1])) / std.view([3,1,1])
|
||||
|
||||
IMAGE_ENCODERS = {
|
||||
"clip_vision": comfy.clip_model.CLIPVisionModelProjection,
|
||||
"dinov2": comfy.image_encoders.dino2.Dinov2Model,
|
||||
}
|
||||
|
||||
class ClipVisionModel():
|
||||
def __init__(self, json_config):
|
||||
with open(json_config) as f:
|
||||
@ -42,10 +48,11 @@ class ClipVisionModel():
|
||||
self.image_size = config.get("image_size", 224)
|
||||
self.image_mean = config.get("image_mean", [0.48145466, 0.4578275, 0.40821073])
|
||||
self.image_std = config.get("image_std", [0.26862954, 0.26130258, 0.27577711])
|
||||
model_class = IMAGE_ENCODERS.get(config.get("model_type", "clip_vision"))
|
||||
self.load_device = comfy.model_management.text_encoder_device()
|
||||
offload_device = comfy.model_management.text_encoder_offload_device()
|
||||
self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
|
||||
self.model = comfy.clip_model.CLIPVisionModelProjection(config, self.dtype, offload_device, comfy.ops.manual_cast)
|
||||
self.model = model_class(config, self.dtype, offload_device, comfy.ops.manual_cast)
|
||||
self.model.eval()
|
||||
|
||||
self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
|
||||
@ -111,6 +118,8 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336.json")
|
||||
else:
|
||||
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json")
|
||||
elif "embeddings.patch_embeddings.projection.weight" in sd:
|
||||
json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_giant.json")
|
||||
else:
|
||||
return None
|
||||
|
||||
|
141
comfy/image_encoders/dino2.py
Normal file
141
comfy/image_encoders/dino2.py
Normal file
@ -0,0 +1,141 @@
|
||||
import torch
|
||||
from comfy.text_encoders.bert import BertAttention
|
||||
import comfy.model_management
|
||||
from comfy.ldm.modules.attention import optimized_attention_for_device
|
||||
|
||||
|
||||
class Dino2AttentionOutput(torch.nn.Module):
|
||||
def __init__(self, input_dim, output_dim, layer_norm_eps, dtype, device, operations):
|
||||
super().__init__()
|
||||
self.dense = operations.Linear(input_dim, output_dim, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x):
|
||||
return self.dense(x)
|
||||
|
||||
|
||||
class Dino2AttentionBlock(torch.nn.Module):
|
||||
def __init__(self, embed_dim, heads, layer_norm_eps, dtype, device, operations):
|
||||
super().__init__()
|
||||
self.attention = BertAttention(embed_dim, heads, dtype, device, operations)
|
||||
self.output = Dino2AttentionOutput(embed_dim, embed_dim, layer_norm_eps, dtype, device, operations)
|
||||
|
||||
def forward(self, x, mask, optimized_attention):
|
||||
return self.output(self.attention(x, mask, optimized_attention))
|
||||
|
||||
|
||||
class LayerScale(torch.nn.Module):
|
||||
def __init__(self, dim, dtype, device, operations):
|
||||
super().__init__()
|
||||
self.lambda1 = torch.nn.Parameter(torch.empty(dim, device=device, dtype=dtype))
|
||||
|
||||
def forward(self, x):
|
||||
return x * comfy.model_management.cast_to_device(self.lambda1, x.device, x.dtype)
|
||||
|
||||
|
||||
class SwiGLUFFN(torch.nn.Module):
|
||||
def __init__(self, dim, dtype, device, operations):
|
||||
super().__init__()
|
||||
in_features = out_features = dim
|
||||
hidden_features = int(dim * 4)
|
||||
hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
|
||||
|
||||
self.weights_in = operations.Linear(in_features, 2 * hidden_features, bias=True, device=device, dtype=dtype)
|
||||
self.weights_out = operations.Linear(hidden_features, out_features, bias=True, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.weights_in(x)
|
||||
x1, x2 = x.chunk(2, dim=-1)
|
||||
x = torch.nn.functional.silu(x1) * x2
|
||||
return self.weights_out(x)
|
||||
|
||||
|
||||
class Dino2Block(torch.nn.Module):
|
||||
def __init__(self, dim, num_heads, layer_norm_eps, dtype, device, operations):
|
||||
super().__init__()
|
||||
self.attention = Dino2AttentionBlock(dim, num_heads, layer_norm_eps, dtype, device, operations)
|
||||
self.layer_scale1 = LayerScale(dim, dtype, device, operations)
|
||||
self.layer_scale2 = LayerScale(dim, dtype, device, operations)
|
||||
self.mlp = SwiGLUFFN(dim, dtype, device, operations)
|
||||
self.norm1 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
|
||||
self.norm2 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x, optimized_attention):
|
||||
x = x + self.layer_scale1(self.attention(self.norm1(x), None, optimized_attention))
|
||||
x = x + self.layer_scale2(self.mlp(self.norm2(x)))
|
||||
return x
|
||||
|
||||
|
||||
class Dino2Encoder(torch.nn.Module):
|
||||
def __init__(self, dim, num_heads, layer_norm_eps, num_layers, dtype, device, operations):
|
||||
super().__init__()
|
||||
self.layer = torch.nn.ModuleList([Dino2Block(dim, num_heads, layer_norm_eps, dtype, device, operations) for _ in range(num_layers)])
|
||||
|
||||
def forward(self, x, intermediate_output=None):
|
||||
optimized_attention = optimized_attention_for_device(x.device, False, small_input=True)
|
||||
|
||||
if intermediate_output is not None:
|
||||
if intermediate_output < 0:
|
||||
intermediate_output = len(self.layer) + intermediate_output
|
||||
|
||||
intermediate = None
|
||||
for i, l in enumerate(self.layer):
|
||||
x = l(x, optimized_attention)
|
||||
if i == intermediate_output:
|
||||
intermediate = x.clone()
|
||||
return x, intermediate
|
||||
|
||||
|
||||
class Dino2PatchEmbeddings(torch.nn.Module):
|
||||
def __init__(self, dim, num_channels=3, patch_size=14, image_size=518, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.projection = operations.Conv2d(
|
||||
in_channels=num_channels,
|
||||
out_channels=dim,
|
||||
kernel_size=patch_size,
|
||||
stride=patch_size,
|
||||
bias=True,
|
||||
dtype=dtype,
|
||||
device=device
|
||||
)
|
||||
|
||||
def forward(self, pixel_values):
|
||||
return self.projection(pixel_values).flatten(2).transpose(1, 2)
|
||||
|
||||
|
||||
class Dino2Embeddings(torch.nn.Module):
|
||||
def __init__(self, dim, dtype, device, operations):
|
||||
super().__init__()
|
||||
patch_size = 14
|
||||
image_size = 518
|
||||
|
||||
self.patch_embeddings = Dino2PatchEmbeddings(dim, patch_size=patch_size, image_size=image_size, dtype=dtype, device=device, operations=operations)
|
||||
self.position_embeddings = torch.nn.Parameter(torch.empty(1, (image_size // patch_size) ** 2 + 1, dim, dtype=dtype, device=device))
|
||||
self.cls_token = torch.nn.Parameter(torch.empty(1, 1, dim, dtype=dtype, device=device))
|
||||
self.mask_token = torch.nn.Parameter(torch.empty(1, dim, dtype=dtype, device=device))
|
||||
|
||||
def forward(self, pixel_values):
|
||||
x = self.patch_embeddings(pixel_values)
|
||||
# TODO: mask_token?
|
||||
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
|
||||
x = x + comfy.model_management.cast_to_device(self.position_embeddings, x.device, x.dtype)
|
||||
return x
|
||||
|
||||
|
||||
class Dinov2Model(torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
num_layers = config_dict["num_hidden_layers"]
|
||||
dim = config_dict["hidden_size"]
|
||||
heads = config_dict["num_attention_heads"]
|
||||
layer_norm_eps = config_dict["layer_norm_eps"]
|
||||
|
||||
self.embeddings = Dino2Embeddings(dim, dtype, device, operations)
|
||||
self.encoder = Dino2Encoder(dim, heads, layer_norm_eps, num_layers, dtype, device, operations)
|
||||
self.layernorm = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, pixel_values, attention_mask=None, intermediate_output=None):
|
||||
x = self.embeddings(pixel_values)
|
||||
x, i = self.encoder(x, intermediate_output=intermediate_output)
|
||||
x = self.layernorm(x)
|
||||
pooled_output = x[:, 0, :]
|
||||
return x, i, pooled_output, None
|
21
comfy/image_encoders/dino2_giant.json
Normal file
21
comfy/image_encoders/dino2_giant.json
Normal file
@ -0,0 +1,21 @@
|
||||
{
|
||||
"attention_probs_dropout_prob": 0.0,
|
||||
"drop_path_rate": 0.0,
|
||||
"hidden_act": "gelu",
|
||||
"hidden_dropout_prob": 0.0,
|
||||
"hidden_size": 1536,
|
||||
"image_size": 518,
|
||||
"initializer_range": 0.02,
|
||||
"layer_norm_eps": 1e-06,
|
||||
"layerscale_value": 1.0,
|
||||
"mlp_ratio": 4,
|
||||
"model_type": "dinov2",
|
||||
"num_attention_heads": 24,
|
||||
"num_channels": 3,
|
||||
"num_hidden_layers": 40,
|
||||
"patch_size": 14,
|
||||
"qkv_bias": true,
|
||||
"use_swiglu_ffn": true,
|
||||
"image_mean": [0.485, 0.456, 0.406],
|
||||
"image_std": [0.229, 0.224, 0.225]
|
||||
}
|
@ -1419,6 +1419,6 @@ def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None
|
||||
old_denoised_d = denoised_d
|
||||
|
||||
if s_noise != 0 and sigmas[i + 1] > 0:
|
||||
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * (sigmas[i + 1] ** 2 - sigmas[i] ** 2 * r ** 2).sqrt()
|
||||
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * (sigmas[i + 1] ** 2 - sigmas[i] ** 2 * r ** 2).sqrt().nan_to_num(nan=0.0)
|
||||
old_denoised = denoised
|
||||
return x
|
||||
|
@ -10,8 +10,8 @@ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor:
|
||||
q_shape = q.shape
|
||||
k_shape = k.shape
|
||||
|
||||
q = q.float().reshape(*q.shape[:-1], -1, 1, 2)
|
||||
k = k.float().reshape(*k.shape[:-1], -1, 1, 2)
|
||||
q = q.to(dtype=pe.dtype).reshape(*q.shape[:-1], -1, 1, 2)
|
||||
k = k.to(dtype=pe.dtype).reshape(*k.shape[:-1], -1, 1, 2)
|
||||
q = (pe[..., 0] * q[..., 0] + pe[..., 1] * q[..., 1]).reshape(*q_shape).type_as(v)
|
||||
k = (pe[..., 0] * k[..., 0] + pe[..., 1] * k[..., 1]).reshape(*k_shape).type_as(v)
|
||||
|
||||
@ -36,8 +36,8 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
|
||||
|
||||
|
||||
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
|
||||
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
|
||||
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
|
||||
xq_ = xq.to(dtype=freqs_cis.dtype).reshape(*xq.shape[:-1], -1, 1, 2)
|
||||
xk_ = xk.to(dtype=freqs_cis.dtype).reshape(*xk.shape[:-1], -1, 1, 2)
|
||||
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
|
||||
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
|
||||
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
|
||||
|
@ -747,6 +747,7 @@ class ModelPatcher:
|
||||
|
||||
def partially_unload(self, device_to, memory_to_free=0):
|
||||
with self.use_ejected():
|
||||
hooks_unpatched = False
|
||||
memory_freed = 0
|
||||
patch_counter = 0
|
||||
unload_list = self._load_list()
|
||||
@ -770,6 +771,10 @@ class ModelPatcher:
|
||||
move_weight = False
|
||||
break
|
||||
|
||||
if not hooks_unpatched:
|
||||
self.unpatch_hooks()
|
||||
hooks_unpatched = True
|
||||
|
||||
if bk.inplace_update:
|
||||
comfy.utils.copy_to_param(self.model, key, bk.weight)
|
||||
else:
|
||||
|
15
comfy/sd.py
15
comfy/sd.py
@ -440,6 +440,10 @@ class VAE:
|
||||
self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
|
||||
logging.info("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype))
|
||||
|
||||
def throw_exception_if_invalid(self):
|
||||
if self.first_stage_model is None:
|
||||
raise RuntimeError("ERROR: VAE is invalid: None\n\nIf the VAE is from a checkpoint loader node your checkpoint does not contain a valid VAE.")
|
||||
|
||||
def vae_encode_crop_pixels(self, pixels):
|
||||
downscale_ratio = self.spacial_compression_encode()
|
||||
|
||||
@ -495,6 +499,7 @@ class VAE:
|
||||
return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, index_formulas=self.downscale_index_formula, output_device=self.output_device)
|
||||
|
||||
def decode(self, samples_in):
|
||||
self.throw_exception_if_invalid()
|
||||
pixel_samples = None
|
||||
try:
|
||||
memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
|
||||
@ -525,6 +530,7 @@ class VAE:
|
||||
return pixel_samples
|
||||
|
||||
def decode_tiled(self, samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None):
|
||||
self.throw_exception_if_invalid()
|
||||
memory_used = self.memory_used_decode(samples.shape, self.vae_dtype) #TODO: calculate mem required for tile
|
||||
model_management.load_models_gpu([self.patcher], memory_required=memory_used)
|
||||
dims = samples.ndim - 2
|
||||
@ -553,6 +559,7 @@ class VAE:
|
||||
return output.movedim(1, -1)
|
||||
|
||||
def encode(self, pixel_samples):
|
||||
self.throw_exception_if_invalid()
|
||||
pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
|
||||
pixel_samples = pixel_samples.movedim(-1, 1)
|
||||
if self.latent_dim == 3 and pixel_samples.ndim < 5:
|
||||
@ -585,6 +592,7 @@ class VAE:
|
||||
return samples
|
||||
|
||||
def encode_tiled(self, pixel_samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None):
|
||||
self.throw_exception_if_invalid()
|
||||
pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
|
||||
dims = self.latent_dim
|
||||
pixel_samples = pixel_samples.movedim(-1, 1)
|
||||
@ -899,7 +907,12 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
|
||||
|
||||
model_config = model_detection.model_config_from_unet(sd, diffusion_model_prefix, metadata=metadata)
|
||||
if model_config is None:
|
||||
return None
|
||||
logging.warning("Warning, This is not a checkpoint file, trying to load it as a diffusion model only.")
|
||||
diffusion_model = load_diffusion_model_state_dict(sd, model_options={})
|
||||
if diffusion_model is None:
|
||||
return None
|
||||
return (diffusion_model, None, VAE(sd={}), None) # The VAE object is there to throw an exception if it's actually used'
|
||||
|
||||
|
||||
unet_weight_dtype = list(model_config.supported_inference_dtypes)
|
||||
if model_config.scaled_fp8 is not None:
|
||||
|
Loading…
Reference in New Issue
Block a user