From 0aa667ed33aae800880153a91c283ac457d0b31c Mon Sep 17 00:00:00 2001
From: comfyanonymous <comfyanonymous@protonmail.com>
Date: Sun, 30 Apr 2023 17:28:55 -0400
Subject: [PATCH] Fix ConditioningAverage.

---
 nodes.py | 25 +++++++++++++++++--------
 1 file changed, 17 insertions(+), 8 deletions(-)

diff --git a/nodes.py b/nodes.py
index fc3d2f18..53e0f74b 100644
--- a/nodes.py
+++ b/nodes.py
@@ -62,21 +62,30 @@ class ConditioningCombine:
 class ConditioningAverage :
     @classmethod
     def INPUT_TYPES(s):
-        return {"required": {"conditioning_from": ("CONDITIONING", ), "conditioning_to": ("CONDITIONING", ),
-                              "conditioning_from_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.1})
+        return {"required": {"conditioning_to": ("CONDITIONING", ), "conditioning_from": ("CONDITIONING", ),
+                              "conditioning_to_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
                              }}
     RETURN_TYPES = ("CONDITIONING",)
     FUNCTION = "addWeighted"
 
     CATEGORY = "conditioning"
 
-    def addWeighted(self, conditioning_from, conditioning_to, conditioning_from_strength):
+    def addWeighted(self, conditioning_to, conditioning_from, conditioning_to_strength):
         out = []
-        for i in range(min(len(conditioning_from),len(conditioning_to))):
-            t0 = conditioning_from[i]
-            t1 = conditioning_to[i]
-            tw = torch.mul(t0[0],(1-conditioning_from_strength)) + torch.mul(t1[0],conditioning_from_strength)
-            n = [tw, t0[1].copy()]
+
+        if len(conditioning_from) > 1:
+            print("Warning: ConditioningAverage conditioning_from contains more than 1 cond, only the first one will actually be applied to conditioning_to.")
+
+        cond_from = conditioning_from[0][0]
+
+        for i in range(len(conditioning_to)):
+            t1 = conditioning_to[i][0]
+            t0 = cond_from[:,:t1.shape[1]]
+            if t0.shape[1] < t1.shape[1]:
+                t0 = torch.cat([t0] + [torch.zeros((1, (t1.shape[1] - t0.shape[1]), t1.shape[2]))], dim=1)
+
+            tw = torch.mul(t1, conditioning_to_strength) + torch.mul(t0, (1.0 - conditioning_to_strength))
+            n = [tw, conditioning_to[i][1].copy()]
             out.append(n)
         return (out, )