Add get_nested_additional_models to handle cases where additional_models have their own additional_models, and guard against infinite recursion from circular additional_models references

This commit is contained in:
Jedrzej Kosinski 2024-12-01 15:58:41 -06:00
parent bdde26b212
commit 000a21a0c1
2 changed files with 24 additions and 2 deletions

View File

@ -852,12 +852,34 @@ class ModelPatcher:
if key in self.additional_models:
self.additional_models.pop(key)
def get_additional_models_with_key(self, key: str):
    """Return the list of additional models registered under *key*.

    Falls back to an empty list when the key has no registered models.
    """
    try:
        return self.additional_models[key]
    except KeyError:
        return []
def get_additional_models(self):
    """Flatten every registered additional-models list into a single list.

    Iterates the lists stored in self.additional_models in dict order and
    concatenates them, returning a fresh list each call.
    """
    return [model
            for model_list in self.additional_models.values()
            for model in model_list]
def get_nested_additional_models(self):
def _evaluate_sub_additional_models(prev_models: list[ModelPatcher], cache_set: set[ModelPatcher]):
'''Make sure circular references do not cause infinite recursion.'''
next_models = []
for model in prev_models:
candidates = model.get_additional_models()
for c in candidates:
if c not in cache_set:
next_models.append(c)
cache_set.add(c)
if len(next_models) == 0:
return prev_models
return prev_models + _evaluate_sub_additional_models(next_models, cache_set)
all_models = self.get_additional_models()
models_set = set(all_models)
real_all_models = _evaluate_sub_additional_models(prev_models=all_models, cache_set=models_set)
return real_all_models
def use_ejected(self, skip_and_inject_on_exit_only=False):
    """Return an AutoPatcherEjector wrapping this patcher.

    skip_and_inject_on_exit_only is forwarded unchanged; see
    AutoPatcherEjector for its exact semantics.
    """
    ejector = AutoPatcherEjector(
        self, skip_and_inject_on_exit_only=skip_and_inject_on_exit_only)
    return ejector

View File

@ -107,7 +107,7 @@ def prepare_sampling(model: 'ModelPatcher', noise_shape, conds):
device = model.load_device
real_model: 'BaseModel' = None
models, inference_memory = get_additional_models(conds, model.model_dtype())
models += model.get_all_additional_models() # TODO: does this require inference_memory update?
models += model.get_nested_additional_models() # TODO: does this require inference_memory update?
memory_required = model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory
minimum_memory_required = model.memory_required([noise_shape[0]] + list(noise_shape[1:])) + inference_memory
comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required, minimum_memory_required=minimum_memory_required)