Keep a set of model_keys for faster add_patches.

This commit is contained in:
comfyanonymous 2023-06-20 19:08:48 -04:00
parent 45beebd33c
commit 8125b51a62

View File

@ -302,12 +302,14 @@ class ModelPatcher:
t = model_sd[k]
size += t.nelement() * t.element_size()
self.size = size
self.model_keys = set(model_sd.keys())
return size
def clone(self):
n = ModelPatcher(self.model, self.size)
n.patches = self.patches[:]
n.model_options = copy.deepcopy(self.model_options)
n.model_keys = self.model_keys
return n
def set_model_tomesd(self, ratio):
@ -349,9 +351,8 @@ class ModelPatcher:
def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
p = {}
model_sd = self.model.state_dict()
for k in patches:
if k in model_sd:
if k in self.model_keys:
p[k] = patches[k]
self.patches += [(strength_patch, p, strength_model)]
return p.keys()
@ -365,7 +366,7 @@ class ModelPatcher:
return sd
def patch_model(self):
model_sd = self.model.state_dict()
model_sd = self.model_state_dict()
for p in self.patches:
for k in p[1]:
v = p[1][k]