mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
Use the lowvram cast_to function for everything.
This commit is contained in:
parent
7390ff3b1e
commit
67158994a4
@ -840,27 +840,21 @@ def force_channels_last():
|
|||||||
#TODO
|
#TODO
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False):
|
||||||
|
if device is None or weight.device == device:
|
||||||
|
if not copy:
|
||||||
|
if dtype is None or weight.dtype == dtype:
|
||||||
|
return weight
|
||||||
|
return weight.to(dtype=dtype, copy=copy)
|
||||||
|
|
||||||
|
r = torch.empty_like(weight, dtype=dtype, device=device)
|
||||||
|
r.copy_(weight, non_blocking=non_blocking)
|
||||||
|
return r
|
||||||
|
|
||||||
def cast_to_device(tensor, device, dtype, copy=False):
|
def cast_to_device(tensor, device, dtype, copy=False):
|
||||||
device_supports_cast = False
|
non_blocking = device_supports_non_blocking(device)
|
||||||
if tensor.dtype == torch.float32 or tensor.dtype == torch.float16:
|
return cast_to(tensor, dtype=dtype, device=device, non_blocking=non_blocking, copy=copy)
|
||||||
device_supports_cast = True
|
|
||||||
elif tensor.dtype == torch.bfloat16:
|
|
||||||
if hasattr(device, 'type') and device.type.startswith("cuda"):
|
|
||||||
device_supports_cast = True
|
|
||||||
elif is_intel_xpu():
|
|
||||||
device_supports_cast = True
|
|
||||||
|
|
||||||
non_blocking = device_should_use_non_blocking(device)
|
|
||||||
|
|
||||||
if device_supports_cast:
|
|
||||||
if copy:
|
|
||||||
if tensor.device == device:
|
|
||||||
return tensor.to(dtype, copy=copy, non_blocking=non_blocking)
|
|
||||||
return tensor.to(device, copy=copy, non_blocking=non_blocking).to(dtype, non_blocking=non_blocking)
|
|
||||||
else:
|
|
||||||
return tensor.to(device, non_blocking=non_blocking).to(dtype, non_blocking=non_blocking)
|
|
||||||
else:
|
|
||||||
return tensor.to(device, dtype, copy=copy, non_blocking=non_blocking)
|
|
||||||
|
|
||||||
def xformers_enabled():
|
def xformers_enabled():
|
||||||
global directml_enabled
|
global directml_enabled
|
||||||
|
17
comfy/ops.py
17
comfy/ops.py
@ -20,19 +20,10 @@ import torch
|
|||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
from comfy.cli_args import args
|
from comfy.cli_args import args
|
||||||
|
|
||||||
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False):
|
cast_to = comfy.model_management.cast_to #TODO: remove once no more references
|
||||||
if device is None or weight.device == device:
|
|
||||||
if not copy:
|
|
||||||
if dtype is None or weight.dtype == dtype:
|
|
||||||
return weight
|
|
||||||
return weight.to(dtype=dtype, copy=copy)
|
|
||||||
|
|
||||||
r = torch.empty_like(weight, dtype=dtype, device=device)
|
|
||||||
r.copy_(weight, non_blocking=non_blocking)
|
|
||||||
return r
|
|
||||||
|
|
||||||
def cast_to_input(weight, input, non_blocking=False, copy=True):
|
def cast_to_input(weight, input, non_blocking=False, copy=True):
|
||||||
return cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
|
return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
|
||||||
|
|
||||||
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
|
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
|
||||||
if input is not None:
|
if input is not None:
|
||||||
@ -47,12 +38,12 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
|
|||||||
non_blocking = comfy.model_management.device_supports_non_blocking(device)
|
non_blocking = comfy.model_management.device_supports_non_blocking(device)
|
||||||
if s.bias is not None:
|
if s.bias is not None:
|
||||||
has_function = s.bias_function is not None
|
has_function = s.bias_function is not None
|
||||||
bias = cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function)
|
bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function)
|
||||||
if has_function:
|
if has_function:
|
||||||
bias = s.bias_function(bias)
|
bias = s.bias_function(bias)
|
||||||
|
|
||||||
has_function = s.weight_function is not None
|
has_function = s.weight_function is not None
|
||||||
weight = cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function)
|
weight = comfy.model_management.cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function)
|
||||||
if has_function:
|
if has_function:
|
||||||
weight = s.weight_function(weight)
|
weight = s.weight_function(weight)
|
||||||
return weight, bias
|
return weight, bias
|
||||||
|
Loading…
Reference in New Issue
Block a user