No need to check filename extensions to detect shuffle controlnet.

This commit is contained in:
comfyanonymous 2023-08-28 16:49:06 -04:00
parent 4e89b2c25a
commit 65cae62c71

View File

@ -1,5 +1,6 @@
import torch
import math
import os
import comfy.utils
import comfy.model_management
import comfy.model_detection
@ -386,7 +387,8 @@ def load_controlnet(ckpt_path, model=None):
control_model = control_model.half()
global_average_pooling = False
if ckpt_path.endswith("_shuffle.pth") or ckpt_path.endswith("_shuffle.safetensors") or ckpt_path.endswith("_shuffle_fp16.safetensors"): #TODO: smarter way of enabling global_average_pooling
filename = os.path.splitext(ckpt_path)[0]
if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling
global_average_pooling = True
control = ControlNet(control_model, global_average_pooling=global_average_pooling)