mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
Fix some issues.
This commit is contained in:
parent
1e68002b87
commit
2ba5cc8b86
@ -450,8 +450,7 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
|
|||||||
if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
|
if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
|
||||||
model_size = loaded_model.model_memory_required(torch_dev)
|
model_size = loaded_model.model_memory_required(torch_dev)
|
||||||
current_free_mem = get_free_memory(torch_dev)
|
current_free_mem = get_free_memory(torch_dev)
|
||||||
lowvram_model_memory = max(64 * (1024 * 1024), (current_free_mem - minimum_memory_required))
|
lowvram_model_memory = max(64 * (1024 * 1024), (current_free_mem - minimum_memory_required), current_free_mem * 0.5)
|
||||||
lowvram_model_memory = int(min(current_free_mem * 0.5, lowvram_model_memory))
|
|
||||||
if model_size <= lowvram_model_memory: #only switch to lowvram if really necessary
|
if model_size <= lowvram_model_memory: #only switch to lowvram if really necessary
|
||||||
lowvram_model_memory = 0
|
lowvram_model_memory = 0
|
||||||
|
|
||||||
|
@ -517,7 +517,11 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
|||||||
if model_config is None:
|
if model_config is None:
|
||||||
raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))
|
raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))
|
||||||
|
|
||||||
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=[weight_dtype] + model_config.supported_inference_dtypes)
|
unet_weight_dtype = list(model_config.supported_inference_dtypes)
|
||||||
|
if weight_dtype is not None:
|
||||||
|
unet_weight_dtype.append(weight_dtype)
|
||||||
|
|
||||||
|
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype)
|
||||||
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
|
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
|
||||||
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
|
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
|
||||||
|
|
||||||
|
@ -51,6 +51,9 @@ def weight_dtype(sd, prefix=""):
|
|||||||
w = sd[k]
|
w = sd[k]
|
||||||
dtypes[w.dtype] = dtypes.get(w.dtype, 0) + 1
|
dtypes[w.dtype] = dtypes.get(w.dtype, 0) + 1
|
||||||
|
|
||||||
|
if len(dtypes) == 0:
|
||||||
|
return None
|
||||||
|
|
||||||
return max(dtypes, key=dtypes.get)
|
return max(dtypes, key=dtypes.get)
|
||||||
|
|
||||||
def state_dict_key_replace(state_dict, keys_to_replace):
|
def state_dict_key_replace(state_dict, keys_to_replace):
|
||||||
|
Loading…
Reference in New Issue
Block a user