Keep a set of model_keys for faster add_patches.

This commit is contained in:
comfyanonymous 2023-06-20 19:08:48 -04:00
parent 45beebd33c
commit 8125b51a62

View File

@ -302,12 +302,14 @@ class ModelPatcher:
t = model_sd[k] t = model_sd[k]
size += t.nelement() * t.element_size() size += t.nelement() * t.element_size()
self.size = size self.size = size
self.model_keys = set(model_sd.keys())
return size return size
def clone(self): def clone(self):
n = ModelPatcher(self.model, self.size) n = ModelPatcher(self.model, self.size)
n.patches = self.patches[:] n.patches = self.patches[:]
n.model_options = copy.deepcopy(self.model_options) n.model_options = copy.deepcopy(self.model_options)
n.model_keys = self.model_keys
return n return n
def set_model_tomesd(self, ratio): def set_model_tomesd(self, ratio):
@ -349,9 +351,8 @@ class ModelPatcher:
def add_patches(self, patches, strength_patch=1.0, strength_model=1.0): def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
p = {} p = {}
model_sd = self.model.state_dict()
for k in patches: for k in patches:
if k in model_sd: if k in self.model_keys:
p[k] = patches[k] p[k] = patches[k]
self.patches += [(strength_patch, p, strength_model)] self.patches += [(strength_patch, p, strength_model)]
return p.keys() return p.keys()
@ -365,7 +366,7 @@ class ModelPatcher:
return sd return sd
def patch_model(self): def patch_model(self):
model_sd = self.model.state_dict() model_sd = self.model_state_dict()
for p in self.patches: for p in self.patches:
for k in p[1]: for k in p[1]:
v = p[1][k] v = p[1][k]