diff --git a/comfy/model_management.py b/comfy/model_management.py index 2008229f..bb4bcbb2 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -450,8 +450,7 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM): model_size = loaded_model.model_memory_required(torch_dev) current_free_mem = get_free_memory(torch_dev) - lowvram_model_memory = max(64 * (1024 * 1024), (current_free_mem - minimum_memory_required)) - lowvram_model_memory = int(min(current_free_mem * 0.5, lowvram_model_memory)) + lowvram_model_memory = max(64 * (1024 * 1024), (current_free_mem - minimum_memory_required), current_free_mem * 0.5) if model_size <= lowvram_model_memory: #only switch to lowvram if really necessary lowvram_model_memory = 0 diff --git a/comfy/sd.py b/comfy/sd.py index bf336c85..fac1a487 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -517,7 +517,11 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o if model_config is None: raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path)) - unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=[weight_dtype] + model_config.supported_inference_dtypes) + unet_weight_dtype = list(model_config.supported_inference_dtypes) + if weight_dtype is not None: + unet_weight_dtype.append(weight_dtype) + + unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype) manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes) model_config.set_inference_dtype(unet_dtype, manual_cast_dtype) diff --git a/comfy/utils.py b/comfy/utils.py index d9fe36f9..06e09170 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -51,6 +51,9 @@ def weight_dtype(sd, prefix=""): w = sd[k] dtypes[w.dtype] = dtypes.get(w.dtype, 0) + 1 + if len(dtypes) == 0: + return None + return max(dtypes, key=dtypes.get) def state_dict_key_replace(state_dict, keys_to_replace):