Use the lowvram cast_to function for everything.

This commit is contained in:
comfyanonymous 2024-10-17 17:25:56 -04:00
parent 7390ff3b1e
commit 67158994a4
2 changed files with 17 additions and 32 deletions

View File

@ -840,27 +840,21 @@ def force_channels_last():
#TODO #TODO
return False return False
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False):
if device is None or weight.device == device:
if not copy:
if dtype is None or weight.dtype == dtype:
return weight
return weight.to(dtype=dtype, copy=copy)
r = torch.empty_like(weight, dtype=dtype, device=device)
r.copy_(weight, non_blocking=non_blocking)
return r
def cast_to_device(tensor, device, dtype, copy=False): def cast_to_device(tensor, device, dtype, copy=False):
device_supports_cast = False non_blocking = device_supports_non_blocking(device)
if tensor.dtype == torch.float32 or tensor.dtype == torch.float16: return cast_to(tensor, dtype=dtype, device=device, non_blocking=non_blocking, copy=copy)
device_supports_cast = True
elif tensor.dtype == torch.bfloat16:
if hasattr(device, 'type') and device.type.startswith("cuda"):
device_supports_cast = True
elif is_intel_xpu():
device_supports_cast = True
non_blocking = device_should_use_non_blocking(device)
if device_supports_cast:
if copy:
if tensor.device == device:
return tensor.to(dtype, copy=copy, non_blocking=non_blocking)
return tensor.to(device, copy=copy, non_blocking=non_blocking).to(dtype, non_blocking=non_blocking)
else:
return tensor.to(device, non_blocking=non_blocking).to(dtype, non_blocking=non_blocking)
else:
return tensor.to(device, dtype, copy=copy, non_blocking=non_blocking)
def xformers_enabled(): def xformers_enabled():
global directml_enabled global directml_enabled

View File

@ -20,19 +20,10 @@ import torch
import comfy.model_management import comfy.model_management
from comfy.cli_args import args from comfy.cli_args import args
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False): cast_to = comfy.model_management.cast_to #TODO: remove once no more references
if device is None or weight.device == device:
if not copy:
if dtype is None or weight.dtype == dtype:
return weight
return weight.to(dtype=dtype, copy=copy)
r = torch.empty_like(weight, dtype=dtype, device=device)
r.copy_(weight, non_blocking=non_blocking)
return r
def cast_to_input(weight, input, non_blocking=False, copy=True): def cast_to_input(weight, input, non_blocking=False, copy=True):
return cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy) return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None): def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
if input is not None: if input is not None:
@ -47,12 +38,12 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
non_blocking = comfy.model_management.device_supports_non_blocking(device) non_blocking = comfy.model_management.device_supports_non_blocking(device)
if s.bias is not None: if s.bias is not None:
has_function = s.bias_function is not None has_function = s.bias_function is not None
bias = cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function) bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function)
if has_function: if has_function:
bias = s.bias_function(bias) bias = s.bias_function(bias)
has_function = s.weight_function is not None has_function = s.weight_function is not None
weight = cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function) weight = comfy.model_management.cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function)
if has_function: if has_function:
weight = s.weight_function(weight) weight = s.weight_function(weight)
return weight, bias return weight, bias