diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index 676653f7..a1856030 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -232,6 +232,27 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): return self.transformer.load_state_dict(sd, strict=False) def parse_parentheses(string): + """ + Split a string based off top-level nested parentheses. + + Parameters + ---------- + string : string + The string to be split into its top nested groups. + + Returns + ------- + result : list + A list of each element in string, split by top-level elements + + Examples + -------- + >>> string = "(foo)(bar)" + ['(foo)', '(bar)'] + + >>> string = "(foo(bar)(test1))(test2(test3))" + ['(foo(bar)(test1))', '(test2(test3))'] + """ result = [] current_item = "" nesting_level = 0 @@ -260,6 +281,55 @@ def parse_parentheses(string): return result def token_weights(string, current_weight): + """ + Find the requested weight of a token, and multiply it by the current weight. For parentheses groupings with no set weight, multiply by 1.1. + + Parameters + ---------- + string : string + The input of tokens to calculate requested weights for + current_weight : float + The current weight of all tokens + + Returns + ------- + out : list + A list of each token paired with the calculated weight + + Examples + -------- + >>> string = "(foo)" + >>> current_weight = 1.0 + [('foo', 1.1)] + + >>> string = "(foo)(bar)" + >>> current_weight = 1.0 + [('foo', 1.1), ('bar', 1.1)] + + >>> string = "(foo:2.0)" + >>> current_weight = 1.0 + [('foo', 2.0)] + + >>> string = "((foo))" + >>> current_weight = 1.0 + [('foo', 1.21)] + + >>> string = "((foo):1.1)" + >>> current_weight = 1.0 + [('foo', 1.21)] + + >>> string = "((foo:1.1))" + >>> current_weight = 1.0 + [('foo', 1.1)] + + >>> string = "(foo:0.0)" + >>> current_weight = 1.0 + [('foo', 0.0)] + + >>> string = "((foo:1.0):0.0)" + >>> current_weight = 1.0 + [('foo', 1.0)] + """ a = parse_parentheses(string) out = [] for x in a: @@ -267,6 +337,7 @@ def token_weights(string, current_weight): if len(x) >= 2 and x[-1] == ')' and x[0] == '(': x = x[1:-1] xx = x.rfind(":") + # This line makes *all nestings* multiply the weight by 1.1 weight *= 1.1 if xx > 0: try: