diff --git a/comfy/cldm/cldm.py b/comfy/cldm/cldm.py index 4a58c823f..1d7294bd6 100644 --- a/comfy/cldm/cldm.py +++ b/comfy/cldm/cldm.py @@ -361,7 +361,7 @@ class ControlNet(nn.Module): controlnet_cond = self.input_hint_block(hint[idx], emb, context) feat_seq = torch.mean(controlnet_cond, dim=(2, 3)) if idx < len(control_type): - feat_seq += self.task_embedding[control_type[idx]] + feat_seq += self.task_embedding[control_type[idx]].to(dtype=feat_seq.dtype, device=feat_seq.device) inputs.append(feat_seq.unsqueeze(1)) condition_list.append(controlnet_cond)