diff --git a/comfy/utils.py b/comfy/utils.py index 985cd9a1b..3add621ff 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -46,7 +46,13 @@ def load_torch_file(ckpt, safe_load=False, device=None): if "state_dict" in pl_sd: sd = pl_sd["state_dict"] else: - sd = pl_sd + if len(pl_sd) == 1: + key = list(pl_sd.keys())[0] + sd = pl_sd[key] + if not isinstance(sd, dict): + sd = pl_sd + else: + sd = pl_sd return sd def save_torch_file(sd, ckpt, metadata=None):