diff --git a/comfy/ops.py b/comfy/ops.py index ce132911..418d59e1 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -21,6 +21,8 @@ import comfy.model_management def cast_to(weight, dtype=None, device=None, non_blocking=False): + if (dtype is None or weight.dtype == dtype) and (device is None or weight.device == device): + return weight r = torch.empty_like(weight, dtype=dtype, device=device) r.copy_(weight, non_blocking=non_blocking) return r