added docstrings to parse_parentheses() and token_weights()

This commit is contained in:
N3ther 2024-08-27 05:38:07 -04:00
parent 2ca8f6e23d
commit 8a2abb857b

View File

@ -232,6 +232,27 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
return self.transformer.load_state_dict(sd, strict=False)
def parse_parentheses(string):
"""
Split a string based off top-level nested parentheses.
Parameters
----------
string : string
The string to be split into its top nested groups.
Returns
-------
result : list
A list of each element in string, split by top-level elements
Examples
--------
>>> string = "(foo)(bar)"
['(foo)', '(bar)']
>>> string = "(foo(bar)(test1))(test2(test3))"
['(foo(bar)(test1))', '(test2(test3))']
"""
result = []
current_item = ""
nesting_level = 0
@ -260,6 +281,55 @@ def parse_parentheses(string):
return result
def token_weights(string, current_weight):
"""
Find the requested weight of a token, and multiply it by the current weight. For parentheses groupings with no set weight, multiply by 1.1.
Parameters
----------
string : string
The input of tokens to calculate requested weights for
current_weight : float
The current weight of all tokens
Returns
-------
out : list
A list of each token paired with the calculated weight
Examples
--------
>>> string = "(foo)"
>>> current_weight = 1.0
[('foo', 1.1)]
>>> string = "(foo)(bar)"
>>> current_weight = 1.0
[('foo', 1.1), ('bar', 1.1)]
>>> string = "(foo:2.0)"
>>> current_weight = 1.0
[('foo', 2.0)]
>>> string = "((foo))"
>>> current_weight = 1.0
[('foo', 1.21)]
>>> string = "((foo):1.1)"
>>> current_weight = 1.0
[('foo', 1.21)]
>>> string = "((foo:1.1))"
>>> current_weight = 1.0
[('foo', 1.1)]
>>> string = "(foo:0.0)"
>>> current_weight = 1.0
[('foo', 0.0)]
>>> string = "((foo:1.0):0.0)"
>>> current_weight = 1.0
[('foo', 1.0)]
"""
a = parse_parentheses(string)
out = []
for x in a:
@ -267,6 +337,7 @@ def token_weights(string, current_weight):
if len(x) >= 2 and x[-1] == ')' and x[0] == '(':
x = x[1:-1]
xx = x.rfind(":")
# This line makes *all nestings* multiply the weight by 1.1
weight *= 1.1
if xx > 0:
try: