Make it easy for models to process the unet state dict on load.

This commit is contained in:
comfyanonymous 2023-11-20 22:27:36 -05:00
parent 2dd5b4dd78
commit ce67dcbcda
2 changed files with 4 additions and 0 deletions

View File

@ -121,6 +121,7 @@ class BaseModel(torch.nn.Module):
if k.startswith(unet_prefix):
to_load[k[len(unet_prefix):]] = sd.pop(k)
to_load = self.model_config.process_unet_state_dict(to_load)
m, u = self.diffusion_model.load_state_dict(to_load, strict=False)
if len(m) > 0:
print("unet missing:", m)

View File

@ -53,6 +53,9 @@ class BASE:
def process_clip_state_dict(self, state_dict):
return state_dict
def process_unet_state_dict(self, state_dict):
return state_dict
def process_clip_state_dict_for_saving(self, state_dict):
replace_prefix = {"": "cond_stage_model."}
return utils.state_dict_prefix_replace(state_dict, replace_prefix)