commit 8f0009aad0 (parent 41444b5236)

    Support new flux model variants.
comfy/clip_model.py
@@ -23,6 +23,7 @@ class CLIPAttention(torch.nn.Module):
 
 ACTIVATIONS = {"quick_gelu": lambda a: a * torch.sigmoid(1.702 * a),
                "gelu": torch.nn.functional.gelu,
+               "gelu_pytorch_tanh": lambda a: torch.nn.functional.gelu(a, approximate="tanh"),
 }
 
 class CLIPMLP(torch.nn.Module):
@@ -139,27 +140,35 @@ class CLIPTextModel(torch.nn.Module):
 
 
 class CLIPVisionEmbeddings(torch.nn.Module):
-    def __init__(self, embed_dim, num_channels=3, patch_size=14, image_size=224, dtype=None, device=None, operations=None):
+    def __init__(self, embed_dim, num_channels=3, patch_size=14, image_size=224, model_type="", dtype=None, device=None, operations=None):
         super().__init__()
-        self.class_embedding = torch.nn.Parameter(torch.empty(embed_dim, dtype=dtype, device=device))
+
+        num_patches = (image_size // patch_size) ** 2
+        if model_type == "siglip_vision_model":
+            self.class_embedding = None
+            patch_bias = True
+        else:
+            num_patches = num_patches + 1
+            self.class_embedding = torch.nn.Parameter(torch.empty(embed_dim, dtype=dtype, device=device))
+            patch_bias = False
 
         self.patch_embedding = operations.Conv2d(
             in_channels=num_channels,
             out_channels=embed_dim,
             kernel_size=patch_size,
             stride=patch_size,
-            bias=False,
+            bias=patch_bias,
             dtype=dtype,
             device=device
         )
 
-        num_patches = (image_size // patch_size) ** 2
-        num_positions = num_patches + 1
-        self.position_embedding = operations.Embedding(num_positions, embed_dim, dtype=dtype, device=device)
+        self.position_embedding = operations.Embedding(num_patches, embed_dim, dtype=dtype, device=device)
 
     def forward(self, pixel_values):
         embeds = self.patch_embedding(pixel_values).flatten(2).transpose(1, 2)
-        return torch.cat([comfy.ops.cast_to_input(self.class_embedding, embeds).expand(pixel_values.shape[0], 1, -1), embeds], dim=1) + comfy.ops.cast_to_input(self.position_embedding.weight, embeds)
+        if self.class_embedding is not None:
+            embeds = torch.cat([comfy.ops.cast_to_input(self.class_embedding, embeds).expand(pixel_values.shape[0], 1, -1), embeds], dim=1)
+        return embeds + comfy.ops.cast_to_input(self.position_embedding.weight, embeds)
 
 
 class CLIPVision(torch.nn.Module):
@@ -170,9 +179,15 @@ class CLIPVision(torch.nn.Module):
         heads = config_dict["num_attention_heads"]
         intermediate_size = config_dict["intermediate_size"]
         intermediate_activation = config_dict["hidden_act"]
+        model_type = config_dict["model_type"]
 
-        self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], dtype=dtype, device=device, operations=operations)
-        self.pre_layrnorm = operations.LayerNorm(embed_dim)
+        self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], model_type=model_type, dtype=dtype, device=device, operations=operations)
+        if model_type == "siglip_vision_model":
+            self.pre_layrnorm = lambda a: a
+            self.output_layernorm = True
+        else:
+            self.pre_layrnorm = operations.LayerNorm(embed_dim)
+            self.output_layernorm = False
         self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
         self.post_layernorm = operations.LayerNorm(embed_dim)
 
@@ -181,14 +196,21 @@ class CLIPVision(torch.nn.Module):
         x = self.pre_layrnorm(x)
         #TODO: attention_mask?
         x, i = self.encoder(x, mask=None, intermediate_output=intermediate_output)
-        pooled_output = self.post_layernorm(x[:, 0, :])
+        if self.output_layernorm:
+            x = self.post_layernorm(x)
+            pooled_output = x
+        else:
+            pooled_output = self.post_layernorm(x[:, 0, :])
         return x, i, pooled_output
 
 class CLIPVisionModelProjection(torch.nn.Module):
     def __init__(self, config_dict, dtype, device, operations):
         super().__init__()
         self.vision_model = CLIPVision(config_dict, dtype, device, operations)
-        self.visual_projection = operations.Linear(config_dict["hidden_size"], config_dict["projection_dim"], bias=False)
+        if "projection_dim" in config_dict:
+            self.visual_projection = operations.Linear(config_dict["hidden_size"], config_dict["projection_dim"], bias=False)
+        else:
+            self.visual_projection = lambda a: a
 
     def forward(self, *args, **kwargs):
         x = self.vision_model(*args, **kwargs)
comfy/clip_vision.py
@@ -16,9 +16,9 @@ class Output:
     def __setitem__(self, key, item):
         setattr(self, key, item)
 
-def clip_preprocess(image, size=224):
-    mean = torch.tensor([ 0.48145466,0.4578275,0.40821073], device=image.device, dtype=image.dtype)
-    std = torch.tensor([0.26862954,0.26130258,0.27577711], device=image.device, dtype=image.dtype)
+def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711]):
+    mean = torch.tensor(mean, device=image.device, dtype=image.dtype)
+    std = torch.tensor(std, device=image.device, dtype=image.dtype)
     image = image.movedim(-1, 1)
     if not (image.shape[2] == size and image.shape[3] == size):
         scale = (size / min(image.shape[2], image.shape[3]))
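
Note: clip_preprocess now takes the normalization statistics as arguments instead of hard-coding the CLIP values, which lets the SigLIP config (added below as clip_vision_siglip_384.json, with mean/std of 0.5) reuse the same preprocessing path. A minimal usage sketch, assuming ComfyUI is importable and using a made-up input batch:

import torch
from comfy.clip_vision import clip_preprocess

# hypothetical input: 2 RGB images in HWC layout with values in [0, 1],
# the layout ComfyUI passes to encode_image
images = torch.rand(2, 512, 512, 3)

# SigLIP-style preprocessing: 384x384, mean/std of 0.5 (values from the new JSON config)
pixel_values = clip_preprocess(images, size=384, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
print(pixel_values.shape)  # torch.Size([2, 3, 384, 384])
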
@@ -35,6 +35,8 @@ class ClipVisionModel():
             config = json.load(f)
 
         self.image_size = config.get("image_size", 224)
+        self.image_mean = config.get("image_mean", [0.48145466, 0.4578275, 0.40821073])
+        self.image_std = config.get("image_std", [0.26862954, 0.26130258, 0.27577711])
         self.load_device = comfy.model_management.text_encoder_device()
         offload_device = comfy.model_management.text_encoder_offload_device()
         self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
@@ -51,7 +53,7 @@ class ClipVisionModel():
 
     def encode_image(self, image):
         comfy.model_management.load_model_gpu(self.patcher)
-        pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size).float()
+        pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std).float()
         out = self.model(pixel_values=pixel_values, intermediate_output=-2)
 
         outputs = Output()
@@ -94,7 +96,9 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
     elif "vision_model.encoder.layers.30.layer_norm1.weight" in sd:
         json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_h.json")
     elif "vision_model.encoder.layers.22.layer_norm1.weight" in sd:
-        if sd["vision_model.embeddings.position_embedding.weight"].shape[0] == 577:
+        if sd["vision_model.encoder.layers.0.layer_norm1.weight"].shape[0] == 1152:
+            json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_384.json")
+        elif sd["vision_model.embeddings.position_embedding.weight"].shape[0] == 577:
             json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336.json")
         else:
             json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json")
comfy/clip_vision_siglip_384.json (new file)
@@ -0,0 +1,13 @@
+{
+    "num_channels": 3,
+    "hidden_act": "gelu_pytorch_tanh",
+    "hidden_size": 1152,
+    "image_size": 384,
+    "intermediate_size": 4304,
+    "model_type": "siglip_vision_model",
+    "num_attention_heads": 16,
+    "num_hidden_layers": 27,
+    "patch_size": 14,
+    "image_mean": [0.5, 0.5, 0.5],
+    "image_std": [0.5, 0.5, 0.5]
+}
comfy/ldm/flux/model.py
@@ -20,6 +20,7 @@ import comfy.ldm.common_dit
 @dataclass
 class FluxParams:
     in_channels: int
+    out_channels: int
     vec_in_dim: int
     context_in_dim: int
     hidden_size: int
@@ -29,6 +30,7 @@ class FluxParams:
     depth_single_blocks: int
     axes_dim: list
     theta: int
+    patch_size: int
     qkv_bias: bool
     guidance_embed: bool
 
@@ -43,8 +45,9 @@ class Flux(nn.Module):
         self.dtype = dtype
         params = FluxParams(**kwargs)
         self.params = params
-        self.in_channels = params.in_channels * 2 * 2
-        self.out_channels = self.in_channels
+        self.patch_size = params.patch_size
+        self.in_channels = params.in_channels * params.patch_size * params.patch_size
+        self.out_channels = params.out_channels * params.patch_size * params.patch_size
         if params.hidden_size % params.num_heads != 0:
             raise ValueError(
                 f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
@@ -165,7 +168,7 @@ class Flux(nn.Module):
 
     def forward(self, x, timestep, context, y, guidance, control=None, transformer_options={}, **kwargs):
         bs, c, h, w = x.shape
-        patch_size = 2
+        patch_size = self.patch_size
         x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
 
         img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
comfy/ldm/flux/redux.py (new file)
@@ -0,0 +1,25 @@
+import torch
+import comfy.ops
+
+ops = comfy.ops.manual_cast
+
+class ReduxImageEncoder(torch.nn.Module):
+    def __init__(
+        self,
+        redux_dim: int = 1152,
+        txt_in_features: int = 4096,
+        device=None,
+        dtype=None,
+    ) -> None:
+        super().__init__()
+
+        self.redux_dim = redux_dim
+        self.device = device
+        self.dtype = dtype
+
+        self.redux_up = ops.Linear(redux_dim, txt_in_features * 3, dtype=dtype)
+        self.redux_down = ops.Linear(txt_in_features * 3, txt_in_features, dtype=dtype)
+
+    def forward(self, sigclip_embeds) -> torch.Tensor:
+        projected_x = self.redux_down(torch.nn.functional.silu(self.redux_up(sigclip_embeds)))
+        return projected_x
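
Note: a hedged sketch of how the Redux encoder is intended to be used: SigLIP vision tokens (hidden size 1152, matching redux_dim) are projected up, gated with SiLU, and brought down to the T5-XXL width (4096) so they can be appended to the text conditioning tokens. The shapes are illustrative; the weights below are uninitialized and would normally be loaded from the Redux checkpoint via load_style_model:

import torch
from comfy.ldm.flux.redux import ReduxImageEncoder

encoder = ReduxImageEncoder()          # redux_dim=1152, txt_in_features=4096

# stand-in for SigLIP 384 vision tokens: (384 // 14) ** 2 == 729 patches of width 1152
sigclip_embeds = torch.randn(1, 729, 1152)

img_tokens = encoder(sigclip_embeds)   # (1, 729, 4096)
t5_tokens = torch.randn(1, 256, 4096)  # stand-in for the T5 text conditioning
conditioning = torch.cat((t5_tokens, img_tokens), dim=1)
print(conditioning.shape)              # torch.Size([1, 985, 4096])
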
comfy/lora_convert.py (new file)
@@ -0,0 +1,17 @@
+import torch
+
+
+def convert_lora_bfl_control(sd): #BFL loras for Flux
+    sd_out = {}
+    for k in sd:
+        k_to = "diffusion_model.{}".format(k.replace(".lora_B.bias", ".diff_b").replace("_norm.scale", "_norm.scale.set_weight"))
+        sd_out[k_to] = sd[k]
+
+    sd_out["diffusion_model.img_in.reshape_weight"] = torch.tensor([sd["img_in.lora_B.weight"].shape[0], sd["img_in.lora_A.weight"].shape[1]])
+    return sd_out
+
+
+def convert_lora(sd):
+    if "img_in.lora_A.weight" in sd and "single_blocks.0.norm.key_norm.scale" in sd:
+        return convert_lora_bfl_control(sd)
+    return sd
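
Note: for reference, what convert_lora_bfl_control does to a BFL-format Flux control LoRA, run on a dummy state dict (shapes are illustrative; only the key names and the derived reshape_weight matter):

import torch
from comfy.lora_convert import convert_lora

sd = {
    "img_in.lora_A.weight": torch.zeros(16, 384),
    "img_in.lora_B.weight": torch.zeros(3072, 16),
    "img_in.lora_B.bias": torch.zeros(3072),
    "single_blocks.0.norm.key_norm.scale": torch.zeros(128),
}

out = convert_lora(sd)
# every key gains the "diffusion_model." prefix, and two patterns are renamed:
#   img_in.lora_B.bias                  -> diffusion_model.img_in.diff_b
#   single_blocks.0.norm.key_norm.scale -> diffusion_model.single_blocks.0.norm.key_norm.scale.set_weight
print(out["diffusion_model.img_in.reshape_weight"])  # tensor([3072, 384])
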
comfy/model_base.py
@@ -710,6 +710,38 @@ class Flux(BaseModel):
     def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
         super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.flux.model.Flux)
 
+    def concat_cond(self, **kwargs):
+        num_channels = self.diffusion_model.img_in.weight.shape[1] // (self.diffusion_model.patch_size * self.diffusion_model.patch_size)
+        out_channels = self.model_config.unet_config["out_channels"]
+
+        if num_channels <= out_channels:
+            return None
+
+        image = kwargs.get("concat_latent_image", None)
+        noise = kwargs.get("noise", None)
+        device = kwargs["device"]
+
+        if image is None:
+            image = torch.zeros_like(noise)
+
+        image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
+        image = utils.resize_to_batch_size(image, noise.shape[0])
+        image = self.process_latent_in(image)
+        if num_channels <= out_channels * 2:
+            return image
+
+        #inpaint model
+        mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
+        if mask is None:
+            mask = torch.ones_like(noise)[:, :1]
+
+        mask = torch.mean(mask, dim=1, keepdim=True)
+        print(mask.shape)
+        mask = utils.common_upscale(mask.to(device), noise.shape[-1] * 8, noise.shape[-2] * 8, "bilinear", "center")
+        mask = mask.view(mask.shape[0], mask.shape[2] // 8, 8, mask.shape[3] // 8, 8).permute(0, 2, 4, 1, 3).reshape(mask.shape[0], -1, mask.shape[2] // 8, mask.shape[3] // 8)
+        mask = utils.resize_to_batch_size(mask, noise.shape[0])
+        return torch.cat((image, mask), dim=1)
+
     def encode_adm(self, **kwargs):
         return kwargs["pooled_output"]
 
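
Note: the channel bookkeeping in concat_cond, written out. The base latent has 16 channels; an image-conditioned variant concatenates the conditioning image (16 more), and the inpaint/fill variant additionally folds the mask into 8x8 blocks (64 channels). A hedged arithmetic sketch of what that layout implies for the img_in weight a fill-style checkpoint would carry, assuming those values:

patch_size = 2
latent_channels = 16      # base Flux latent fed as x
image_channels = 16       # concat_latent_image after process_latent_in
mask_channels = 8 * 8     # mask folded into 8x8 blocks by the view/permute/reshape above

num_channels = latent_channels + image_channels + mask_channels   # 96
print(num_channels * patch_size * patch_size)                     # 384 == img_in.weight.shape[1]

# detect_unet_config (next hunk) inverts this to recover in_channels from the checkpoint:
# in_channels = img_in.weight.shape[1] // (patch_size * patch_size)  # 384 // 4 == 96
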
comfy/model_detection.py
@@ -137,6 +137,12 @@ def detect_unet_config(state_dict, key_prefix):
         dit_config = {}
         dit_config["image_model"] = "flux"
         dit_config["in_channels"] = 16
+        patch_size = 2
+        dit_config["patch_size"] = patch_size
+        in_key = "{}img_in.weight".format(key_prefix)
+        if in_key in state_dict_keys:
+            dit_config["in_channels"] = state_dict[in_key].shape[1] // (patch_size * patch_size)
+        dit_config["out_channels"] = 16
         dit_config["vec_in_dim"] = 768
         dit_config["context_in_dim"] = 4096
         dit_config["hidden_size"] = 3072
comfy/sd.py
@@ -30,9 +30,12 @@ import comfy.text_encoders.genmo
 
 import comfy.model_patcher
 import comfy.lora
+import comfy.lora_convert
 import comfy.t2i_adapter.adapter
 import comfy.taesd.taesd
 
+import comfy.ldm.flux.redux
+
 def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
     key_map = {}
     if model is not None:
@@ -40,6 +43,7 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
     if clip is not None:
         key_map = comfy.lora.model_lora_keys_clip(clip.cond_stage_model, key_map)
 
+    lora = comfy.lora_convert.convert_lora(lora)
     loaded = comfy.lora.load_lora(lora, key_map)
     if model is not None:
         new_modelpatcher = model.clone()
@@ -433,6 +437,8 @@ def load_style_model(ckpt_path):
     keys = model_data.keys()
     if "style_embedding" in keys:
         model = comfy.t2i_adapter.adapter.StyleAdapter(width=1024, context_dim=768, num_head=8, n_layes=3, num_token=8)
+    elif "redux_down.weight" in keys:
+        model = comfy.ldm.flux.redux.ReduxImageEncoder()
     else:
         raise Exception("invalid style model {}".format(ckpt_path))
     model.load_state_dict(model_data)