Greatly improve lowvram sampling speed by getting rid of accelerate.

Let me know if this breaks anything.
This commit is contained in:
comfyanonymous 2023-12-22 14:24:04 -05:00
parent 261bcbb0d9
commit 36a7953142
5 changed files with 103 additions and 50 deletions

View File

@ -283,7 +283,7 @@ class ControlLora(ControlNet):
cm = self.control_model.state_dict() cm = self.control_model.state_dict()
for k in sd: for k in sd:
weight = comfy.model_management.resolve_lowvram_weight(sd[k], diffusion_model, k) weight = sd[k]
try: try:
comfy.utils.set_attr(self.control_model, k, weight) comfy.utils.set_attr(self.control_model, k, weight)
except: except:

View File

@ -162,11 +162,7 @@ class BaseModel(torch.nn.Module):
def state_dict_for_saving(self, clip_state_dict, vae_state_dict): def state_dict_for_saving(self, clip_state_dict, vae_state_dict):
clip_state_dict = self.model_config.process_clip_state_dict_for_saving(clip_state_dict) clip_state_dict = self.model_config.process_clip_state_dict_for_saving(clip_state_dict)
unet_sd = self.diffusion_model.state_dict() unet_state_dict = self.diffusion_model.state_dict()
unet_state_dict = {}
for k in unet_sd:
unet_state_dict[k] = comfy.model_management.resolve_lowvram_weight(unet_sd[k], self.diffusion_model, k)
unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict) unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
vae_state_dict = self.model_config.process_vae_state_dict_for_saving(vae_state_dict) vae_state_dict = self.model_config.process_vae_state_dict_for_saving(vae_state_dict)
if self.get_dtype() == torch.float16: if self.get_dtype() == torch.float16:

View File

@ -218,15 +218,8 @@ if args.force_fp16:
FORCE_FP16 = True FORCE_FP16 = True
if lowvram_available: if lowvram_available:
try:
import accelerate
if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM): if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM):
vram_state = set_vram_to vram_state = set_vram_to
except Exception as e:
import traceback
print(traceback.format_exc())
print("ERROR: LOW VRAM MODE NEEDS accelerate.")
lowvram_available = False
if cpu_state != CPUState.GPU: if cpu_state != CPUState.GPU:
@ -298,8 +291,20 @@ class LoadedModel:
if lowvram_model_memory > 0: if lowvram_model_memory > 0:
print("loading in lowvram mode", lowvram_model_memory/(1024 * 1024)) print("loading in lowvram mode", lowvram_model_memory/(1024 * 1024))
device_map = accelerate.infer_auto_device_map(self.real_model, max_memory={0: "{}MiB".format(lowvram_model_memory // (1024 * 1024)), "cpu": "16GiB"}) mem_counter = 0
accelerate.dispatch_model(self.real_model, device_map=device_map, main_device=self.device) for m in self.real_model.modules():
if hasattr(m, "comfy_cast_weights"):
m.prev_comfy_cast_weights = m.comfy_cast_weights
m.comfy_cast_weights = True
module_mem = 0
sd = m.state_dict()
for k in sd:
t = sd[k]
module_mem += t.nelement() * t.element_size()
if mem_counter + module_mem < lowvram_model_memory:
m.to(self.device)
mem_counter += module_mem
self.model_accelerated = True self.model_accelerated = True
if is_intel_xpu() and not args.disable_ipex_optimize: if is_intel_xpu() and not args.disable_ipex_optimize:
@ -309,7 +314,11 @@ class LoadedModel:
def model_unload(self): def model_unload(self):
if self.model_accelerated: if self.model_accelerated:
accelerate.hooks.remove_hook_from_submodules(self.real_model) for m in self.real_model.modules():
if hasattr(m, "prev_comfy_cast_weights"):
m.comfy_cast_weights = m.prev_comfy_cast_weights
del m.prev_comfy_cast_weights
self.model_accelerated = False self.model_accelerated = False
self.model.unpatch_model(self.model.offload_device) self.model.unpatch_model(self.model.offload_device)
@ -402,14 +411,14 @@ def load_models_gpu(models, memory_required=0):
if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM): if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
model_size = loaded_model.model_memory_required(torch_dev) model_size = loaded_model.model_memory_required(torch_dev)
current_free_mem = get_free_memory(torch_dev) current_free_mem = get_free_memory(torch_dev)
lowvram_model_memory = int(max(256 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.3 )) lowvram_model_memory = int(max(64 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.3 ))
if model_size > (current_free_mem - inference_memory): #only switch to lowvram if really necessary if model_size > (current_free_mem - inference_memory): #only switch to lowvram if really necessary
vram_set_state = VRAMState.LOW_VRAM vram_set_state = VRAMState.LOW_VRAM
else: else:
lowvram_model_memory = 0 lowvram_model_memory = 0
if vram_set_state == VRAMState.NO_VRAM: if vram_set_state == VRAMState.NO_VRAM:
lowvram_model_memory = 256 * 1024 * 1024 lowvram_model_memory = 64 * 1024 * 1024
cur_loaded_model = loaded_model.model_load(lowvram_model_memory) cur_loaded_model = loaded_model.model_load(lowvram_model_memory)
current_loaded_models.insert(0, loaded_model) current_loaded_models.insert(0, loaded_model)
@ -566,6 +575,11 @@ def supports_dtype(device, dtype): #TODO
return True return True
return False return False
def device_supports_non_blocking(device):
if is_device_mps(device):
return False #pytorch bug? mps doesn't support non blocking
return True
def cast_to_device(tensor, device, dtype, copy=False): def cast_to_device(tensor, device, dtype, copy=False):
device_supports_cast = False device_supports_cast = False
if tensor.dtype == torch.float32 or tensor.dtype == torch.float16: if tensor.dtype == torch.float32 or tensor.dtype == torch.float16:
@ -576,9 +590,7 @@ def cast_to_device(tensor, device, dtype, copy=False):
elif is_intel_xpu(): elif is_intel_xpu():
device_supports_cast = True device_supports_cast = True
non_blocking = True non_blocking = device_supports_non_blocking(device)
if is_device_mps(device):
non_blocking = False #pytorch bug? mps doesn't support non blocking
if device_supports_cast: if device_supports_cast:
if copy: if copy:
@ -742,11 +754,7 @@ def soft_empty_cache(force=False):
torch.cuda.empty_cache() torch.cuda.empty_cache()
torch.cuda.ipc_collect() torch.cuda.ipc_collect()
def resolve_lowvram_weight(weight, model, key): def resolve_lowvram_weight(weight, model, key): #TODO: remove
if weight.device == torch.device("meta"): #lowvram NOTE: this depends on the inner working of the accelerate library so it might break.
key_split = key.split('.') # I have no idea why they don't just leave the weight there instead of using the meta device.
op = comfy.utils.get_attr(model, '.'.join(key_split[:-1]))
weight = op._hf_hook.weights_map[key_split[-1]]
return weight return weight
#TODO: might be cleaner to put this somewhere else #TODO: might be cleaner to put this somewhere else

View File

@ -1,27 +1,93 @@
import torch import torch
from contextlib import contextmanager from contextlib import contextmanager
import comfy.model_management
def cast_bias_weight(s, input):
bias = None
non_blocking = comfy.model_management.device_supports_non_blocking(input.device)
if s.bias is not None:
bias = s.bias.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
weight = s.weight.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
return weight, bias
class disable_weight_init: class disable_weight_init:
class Linear(torch.nn.Linear): class Linear(torch.nn.Linear):
comfy_cast_weights = False
def reset_parameters(self): def reset_parameters(self):
return None return None
def forward_comfy_cast_weights(self, input):
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.linear(input, weight, bias)
def forward(self, *args, **kwargs):
if self.comfy_cast_weights:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
class Conv2d(torch.nn.Conv2d): class Conv2d(torch.nn.Conv2d):
comfy_cast_weights = False
def reset_parameters(self): def reset_parameters(self):
return None return None
def forward_comfy_cast_weights(self, input):
weight, bias = cast_bias_weight(self, input)
return self._conv_forward(input, weight, bias)
def forward(self, *args, **kwargs):
if self.comfy_cast_weights:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
class Conv3d(torch.nn.Conv3d): class Conv3d(torch.nn.Conv3d):
comfy_cast_weights = False
def reset_parameters(self): def reset_parameters(self):
return None return None
def forward_comfy_cast_weights(self, input):
weight, bias = cast_bias_weight(self, input)
return self._conv_forward(input, weight, bias)
def forward(self, *args, **kwargs):
if self.comfy_cast_weights:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
class GroupNorm(torch.nn.GroupNorm): class GroupNorm(torch.nn.GroupNorm):
comfy_cast_weights = False
def reset_parameters(self): def reset_parameters(self):
return None return None
def forward_comfy_cast_weights(self, input):
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
def forward(self, *args, **kwargs):
if self.comfy_cast_weights:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
class LayerNorm(torch.nn.LayerNorm): class LayerNorm(torch.nn.LayerNorm):
comfy_cast_weights = False
def reset_parameters(self): def reset_parameters(self):
return None return None
def forward_comfy_cast_weights(self, input):
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
def forward(self, *args, **kwargs):
if self.comfy_cast_weights:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
@classmethod @classmethod
def conv_nd(s, dims, *args, **kwargs): def conv_nd(s, dims, *args, **kwargs):
if dims == 2: if dims == 2:
@ -31,35 +97,19 @@ class disable_weight_init:
else: else:
raise ValueError(f"unsupported dimensions: {dims}") raise ValueError(f"unsupported dimensions: {dims}")
def cast_bias_weight(s, input):
bias = None
if s.bias is not None:
bias = s.bias.to(device=input.device, dtype=input.dtype)
weight = s.weight.to(device=input.device, dtype=input.dtype)
return weight, bias
class manual_cast(disable_weight_init): class manual_cast(disable_weight_init):
class Linear(disable_weight_init.Linear): class Linear(disable_weight_init.Linear):
def forward(self, input): comfy_cast_weights = True
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.linear(input, weight, bias)
class Conv2d(disable_weight_init.Conv2d): class Conv2d(disable_weight_init.Conv2d):
def forward(self, input): comfy_cast_weights = True
weight, bias = cast_bias_weight(self, input)
return self._conv_forward(input, weight, bias)
class Conv3d(disable_weight_init.Conv3d): class Conv3d(disable_weight_init.Conv3d):
def forward(self, input): comfy_cast_weights = True
weight, bias = cast_bias_weight(self, input)
return self._conv_forward(input, weight, bias)
class GroupNorm(disable_weight_init.GroupNorm): class GroupNorm(disable_weight_init.GroupNorm):
def forward(self, input): comfy_cast_weights = True
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
class LayerNorm(disable_weight_init.LayerNorm): class LayerNorm(disable_weight_init.LayerNorm):
def forward(self, input): comfy_cast_weights = True
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)

View File

@ -4,7 +4,6 @@ einops
transformers>=4.25.1 transformers>=4.25.1
safetensors>=0.3.0 safetensors>=0.3.0
aiohttp aiohttp
accelerate
pyyaml pyyaml
Pillow Pillow
scipy scipy