mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-11 10:25:16 +00:00
766c7b3815
Don't add SRFormer because the code license is incompatible with the GPL. Remove MAT because it's unused and the license is incompatible with GPL.
100 lines
3.4 KiB
Python
100 lines
3.4 KiB
Python
import logging as logger
|
|
|
|
from .architecture.DAT import DAT
|
|
from .architecture.face.codeformer import CodeFormer
|
|
from .architecture.face.gfpganv1_clean_arch import GFPGANv1Clean
|
|
from .architecture.face.restoreformer_arch import RestoreFormer
|
|
from .architecture.HAT import HAT
|
|
from .architecture.LaMa import LaMa
|
|
from .architecture.OmniSR.OmniSR import OmniSR
|
|
from .architecture.RRDB import RRDBNet as ESRGAN
|
|
from .architecture.SCUNet import SCUNet
|
|
from .architecture.SPSR import SPSRNet as SPSR
|
|
from .architecture.SRVGG import SRVGGNetCompact as RealESRGANv2
|
|
from .architecture.SwiftSRGAN import Generator as SwiftSRGAN
|
|
from .architecture.Swin2SR import Swin2SR
|
|
from .architecture.SwinIR import SwinIR
|
|
from .types import PyTorchModel
|
|
|
|
|
|
class UnsupportedModel(Exception):
    """Raised by ``load_state_dict`` when no supported architecture can be built from a state dict."""
|
|
|
|
|
|
def _unwrap_state_dict(state_dict):
    """Return the weights nested under a wrapper key, or ``state_dict`` itself.

    The wrapper keys are checked in priority order ``params_ema``,
    ``params-ema``, ``params``; only the first one present is unwrapped,
    and only a single level of nesting is removed.
    """
    for wrapper_key in ("params_ema", "params-ema", "params"):
        if wrapper_key in state_dict:
            return state_dict[wrapper_key]
    return state_dict


def load_state_dict(state_dict) -> PyTorchModel:
    """Instantiate the model architecture that matches ``state_dict``.

    The architecture is detected purely from characteristic key names. The
    checks run in a fixed order and the first match wins; anything that
    matches none of them is tried as ESRGAN (regular, "new-arch",
    Real-ESRGAN v1) as a final fallback.

    Args:
        state_dict: Mapping keyed by parameter name (values presumably
            tensors -- TODO confirm). If the weights are nested under
            ``params_ema``, ``params-ema`` or ``params``, they are unwrapped
            first.

    Returns:
        The constructed model.

    Raises:
        UnsupportedModel: No detector matched and constructing the ESRGAN
            fallback failed; the underlying error is chained as
            ``__cause__``. Errors raised by the other architecture
            constructors propagate unchanged.
    """
    logger.debug("Loading state dict into pytorch model arch")

    state_dict = _unwrap_state_dict(state_dict)
    # A set keeps each of the many membership tests below O(1) instead of
    # rescanning a potentially very long key list every time.
    state_dict_keys = set(state_dict.keys())

    # SRVGGNet Real-ESRGAN (v2)
    if "body.0.weight" in state_dict_keys and "body.1.weight" in state_dict_keys:
        model = RealESRGANv2(state_dict)
    # SPSR (ESRGAN with lots of extra layers)
    elif "f_HR_conv1.0.weight" in state_dict_keys:
        model = SPSR(state_dict)
    # Swift-SRGAN: its weights live one level deeper, under "model".
    elif (
        "model" in state_dict_keys
        and "initial.cnn.depthwise.weight" in state_dict["model"].keys()
    ):
        model = SwiftSRGAN(state_dict)
    # SwinIR, Swin2SR, HAT all share this key, so they are told apart by
    # more specific keys; HAT must be checked first.
    elif "layers.0.residual_group.blocks.0.norm1.weight" in state_dict_keys:
        if (
            "layers.0.residual_group.blocks.0.conv_block.cab.0.weight"
            in state_dict_keys
        ):
            model = HAT(state_dict)
        elif "patch_embed.proj.weight" in state_dict_keys:
            model = Swin2SR(state_dict)
        else:
            model = SwinIR(state_dict)
    # GFPGAN
    elif (
        "toRGB.0.weight" in state_dict_keys
        and "stylegan_decoder.style_mlp.1.weight" in state_dict_keys
    ):
        model = GFPGANv1Clean(state_dict)
    # RestoreFormer
    elif (
        "encoder.conv_in.weight" in state_dict_keys
        and "encoder.down.0.block.0.norm1.weight" in state_dict_keys
    ):
        model = RestoreFormer(state_dict)
    # CodeFormer
    elif (
        "encoder.blocks.0.weight" in state_dict_keys
        and "quantize.embedding.weight" in state_dict_keys
    ):
        model = CodeFormer(state_dict)
    # LaMa (keys may be prefixed with either "model." or "generator.")
    elif (
        "model.model.1.bn_l.running_mean" in state_dict_keys
        or "generator.model.1.bn_l.running_mean" in state_dict_keys
    ):
        model = LaMa(state_dict)
    # Omni-SR
    elif "residual_layer.0.residual_layer.0.layer.0.fn.0.weight" in state_dict_keys:
        model = OmniSR(state_dict)
    # SCUNet
    elif "m_head.0.weight" in state_dict_keys and "m_tail.0.weight" in state_dict_keys:
        model = SCUNet(state_dict)
    # DAT
    elif "layers.0.blocks.2.attn.attn_mask_0" in state_dict_keys:
        model = DAT(state_dict)
    # Regular ESRGAN, "new-arch" ESRGAN, Real-ESRGAN v1
    else:
        try:
            model = ESRGAN(state_dict)
        except Exception as err:
            # Catch Exception rather than using a bare except, so that
            # KeyboardInterrupt/SystemExit are not masked. Chain the
            # original error so the reason for rejection stays visible.
            raise UnsupportedModel from err
    return model
|