Fix multigpu ControlBase get_models and cleanup calls to avoid multiple calls of functions on multigpu_clones versions of controlnets

This commit is contained in:
Jedrzej Kosinski 2025-01-25 06:05:01 -06:00
parent 46969c380a
commit 51af7fa1b4

View File

@ -64,6 +64,18 @@ class StrengthType(Enum):
CONSTANT = 1
LINEAR_UP = 2
class ControlIsolation:
'''Temporarily set a ControlBase object's previous_controlnet to None to prevent cascading calls.'''
def __init__(self, control: ControlBase):
self.control = control
self.orig_previous_controlnet = control.previous_controlnet
def __enter__(self):
self.control.previous_controlnet = None
def __exit__(self, *args):
self.control.previous_controlnet = self.orig_previous_controlnet
class ControlBase:
def __init__(self):
self.cond_hint_original = None
@ -112,7 +124,9 @@ class ControlBase:
def cleanup(self):
if self.previous_controlnet is not None:
self.previous_controlnet.cleanup()
for device_cnet in self.multigpu_clones.values():
with ControlIsolation(device_cnet):
device_cnet.cleanup()
self.cond_hint = None
self.extra_concat = None
self.timestep_range = None
@ -120,19 +134,15 @@ class ControlBase:
def get_models(self):
out = []
for device_cnet in self.multigpu_clones.values():
out += device_cnet.get_models()
out += device_cnet.get_models_only_self()
if self.previous_controlnet is not None:
out += self.previous_controlnet.get_models()
return out
def get_models_only_self(self):
'Calls get_models, but temporarily sets previous_controlnet to None.'
try:
orig_previous_controlnet = self.previous_controlnet
self.previous_controlnet = None
with ControlIsolation(self):
return self.get_models()
finally:
self.previous_controlnet = orig_previous_controlnet
def get_instance_for_device(self, device):
'Returns instance of this Control object intended for selected device.'