diff --git a/comfy/model_management.py b/comfy/model_management.py index dcfd57b5..574fbf21 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -456,7 +456,13 @@ def mps_mode(): def is_device_cpu(device): if hasattr(device, 'type'): - if (device.type == 'cpu' or device.type == 'mps'): + if (device.type == 'cpu'): + return True + return False + +def is_device_mps(device): + if hasattr(device, 'type'): + if (device.type == 'mps'): return True return False @@ -468,7 +474,7 @@ def should_use_fp16(device=None, model_params=0): return True if device is not None: #TODO - if is_device_cpu(device): + if is_device_cpu(device) or is_device_mps(device): return False if FORCE_FP32: