Fix control lora not working in fp32.

comfyanonymous 2023-08-21 20:38:31 -04:00
parent bc76b3829f
commit 763b0cf024


@@ -926,8 +926,8 @@ class ControlLora(ControlNet):
         controlnet_config["hint_channels"] = self.control_weights["input_hint_block.0.weight"].shape[1]
         controlnet_config["operations"] = ControlLoraOps()
         self.control_model = cldm.ControlNet(**controlnet_config)
-        if model_management.should_use_fp16():
-            self.control_model.half()
+        dtype = model.get_dtype()
+        self.control_model.to(dtype)
         self.control_model.to(model_management.get_torch_device())
         diffusion_model = model.diffusion_model
         sd = diffusion_model.state_dict()
@@ -947,7 +947,7 @@ class ControlLora(ControlNet):
         for k in self.control_weights:
             if k not in {"lora_controlnet"}:
-                set_attr(self.control_model, k, self.control_weights[k].to(model_management.get_torch_device()))
+                set_attr(self.control_model, k, self.control_weights[k].to(dtype).to(model_management.get_torch_device()))

     def copy(self):
         c = ControlLora(self.control_weights, global_average_pooling=self.global_average_pooling)
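The pattern the diff applies is: instead of unconditionally halving the control model when fp16 is preferred, follow the base model's dtype and cast any extra weights the same way before assigning them. Below is a minimal, self-contained sketch of that pattern; BaseModel, ControlAdapter, and extra_weight are hypothetical stand-ins for illustration only, not ComfyUI's actual classes or weights.

    # Sketch of the dtype-matching fix; names here are hypothetical stand-ins.
    import torch
    import torch.nn as nn

    class BaseModel(nn.Module):
        def __init__(self, dtype=torch.float32):
            super().__init__()
            self.linear = nn.Linear(4, 4, dtype=dtype)

        def get_dtype(self):
            # Report the dtype the base model actually runs in.
            return self.linear.weight.dtype

    class ControlAdapter(nn.Module):
        def __init__(self):
            super().__init__()
            self.proj = nn.Linear(4, 4)

    base = BaseModel(dtype=torch.float32)
    control = ControlAdapter()

    # Before the fix: the control model was forced to fp16 whenever fp16 was
    # "preferred", which mismatches an fp32 base model:
    # control.half()

    # After the fix: match the base model's dtype, and cast any extra weights
    # to the same dtype before loading them onto the control model.
    dtype = base.get_dtype()
    control.to(dtype)
    extra_weight = torch.randn(4, 4)       # stand-in for a control-lora weight
    extra_weight = extra_weight.to(dtype)  # cast to match before assignment

    print(base.get_dtype(), control.proj.weight.dtype, extra_weight.dtype)

With an fp32 base model all three printed dtypes are torch.float32, whereas the old unconditional .half() call would have left the control weights in fp16 and broken fp32 inference.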