diff --git a/comfy/controlnet.py b/comfy/controlnet.py index 1d24afa6..2c0305b4 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -60,7 +60,7 @@ class StrengthType(Enum): LINEAR_UP = 2 class ControlBase: - def __init__(self, device=None): + def __init__(self): self.cond_hint_original = None self.cond_hint = None self.strength = 1.0 @@ -72,10 +72,6 @@ class ControlBase: self.compression_ratio = 8 self.upscale_algorithm = 'nearest-exact' self.extra_args = {} - - if device is None: - device = comfy.model_management.get_torch_device() - self.device = device self.previous_controlnet = None self.extra_conds = [] self.strength_type = StrengthType.CONSTANT @@ -185,8 +181,8 @@ class ControlBase: class ControlNet(ControlBase): - def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None, extra_conds=["y"], strength_type=StrengthType.CONSTANT, concat_mask=False): - super().__init__(device) + def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, load_device=None, manual_cast_dtype=None, extra_conds=["y"], strength_type=StrengthType.CONSTANT, concat_mask=False): + super().__init__() self.control_model = control_model self.load_device = load_device if control_model is not None: @@ -242,7 +238,7 @@ class ControlNet(ControlBase): to_concat.append(comfy.utils.repeat_to_batch_size(c, self.cond_hint.shape[0])) self.cond_hint = torch.cat([self.cond_hint] + to_concat, dim=1) - self.cond_hint = self.cond_hint.to(device=self.device, dtype=dtype) + self.cond_hint = self.cond_hint.to(device=self.load_device, dtype=dtype) if x_noisy.shape[0] != self.cond_hint.shape[0]: self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number) @@ -341,8 +337,8 @@ class ControlLoraOps: class ControlLora(ControlNet): - def __init__(self, control_weights, global_average_pooling=False, device=None, model_options={}): #TODO? model_options - ControlBase.__init__(self, device) + def __init__(self, control_weights, global_average_pooling=False, model_options={}): #TODO? model_options + ControlBase.__init__(self) self.control_weights = control_weights self.global_average_pooling = global_average_pooling self.extra_conds += ["y"] @@ -662,12 +658,15 @@ def load_controlnet(ckpt_path, model=None, model_options={}): class T2IAdapter(ControlBase): def __init__(self, t2i_model, channels_in, compression_ratio, upscale_algorithm, device=None): - super().__init__(device) + super().__init__() self.t2i_model = t2i_model self.channels_in = channels_in self.control_input = None self.compression_ratio = compression_ratio self.upscale_algorithm = upscale_algorithm + if device is None: + device = comfy.model_management.get_torch_device() + self.device = device def scale_image_to(self, width, height): unshuffle_amount = self.t2i_model.unshuffle_amount