From 334ba48cea2961994e92c2fb25de9417b19897ed Mon Sep 17 00:00:00 2001
From: comfyanonymous <comfyanonymous@protonmail.com>
Date: Tue, 23 Jul 2024 14:13:32 -0400
Subject: [PATCH] More generic unet prefix detection code.

---
 comfy/model_detection.py | 21 +++++++++++++++------
 1 file changed, 15 insertions(+), 6 deletions(-)

diff --git a/comfy/model_detection.py b/comfy/model_detection.py
index c62e2b82..ae88eeb9 100644
--- a/comfy/model_detection.py
+++ b/comfy/model_detection.py
@@ -261,13 +261,22 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
         return model_config
 
 def unet_prefix_from_state_dict(state_dict):
-    if "model.model.postprocess_conv.weight" in state_dict: #audio models
-        unet_key_prefix = "model.model."
-    elif "model.double_layers.0.attn.w1q.weight" in state_dict: #aura flow
-        unet_key_prefix = "model."
+    candidates = ["model.diffusion_model.", #ldm/sgm models
+                  "model.model.", #audio models
+                  ]
+    counts = {k: 0 for k in candidates}
+    for k in state_dict:
+        for c in candidates:
+            if k.startswith(c):
+                counts[c] += 1
+                break
+
+    top = max(counts, key=counts.get)
+    if counts[top] > 5:
+        return top
     else:
-        unet_key_prefix = "model.diffusion_model."
-    return unet_key_prefix
+        return "model." #aura flow and others
+
 
 def convert_config(unet_config):
     new_config = unet_config.copy()