diff --git a/comfy/weight_adapter/boft.py b/comfy/weight_adapter/boft.py index c85adc7a..b2a2f1bd 100644 --- a/comfy/weight_adapter/boft.py +++ b/comfy/weight_adapter/boft.py @@ -24,7 +24,7 @@ class BOFTAdapter(WeightAdapterBase): ) -> Optional["BOFTAdapter"]: if loaded_keys is None: loaded_keys = set() - blocks_name = "{}.boft_blocks".format(x) + blocks_name = "{}.oft_blocks".format(x) rescale_name = "{}.rescale".format(x) blocks = None @@ -32,17 +32,18 @@ class BOFTAdapter(WeightAdapterBase): blocks = lora[blocks_name] if blocks.ndim == 4: loaded_keys.add(blocks_name) + else: + blocks = None + if blocks is None: + return None rescale = None if rescale_name in lora.keys(): rescale = lora[rescale_name] loaded_keys.add(rescale_name) - if blocks is not None: - weights = (blocks, rescale, alpha, dora_scale) - return cls(loaded_keys, weights) - else: - return None + weights = (blocks, rescale, alpha, dora_scale) + return cls(loaded_keys, weights) def calculate_weight( self, @@ -71,7 +72,7 @@ class BOFTAdapter(WeightAdapterBase): # Get r I = torch.eye(boft_b, device=blocks.device, dtype=blocks.dtype) # for Q = -Q^T - q = blocks - blocks.transpose(1, 2) + q = blocks - blocks.transpose(-1, -2) normed_q = q if alpha > 0: # alpha in boft/bboft is for constraint q_norm = torch.norm(q) + 1e-8 @@ -79,9 +80,8 @@ class BOFTAdapter(WeightAdapterBase): normed_q = q * alpha / q_norm # use float() to prevent unsupported type in .inverse() r = (I + normed_q) @ (I - normed_q).float().inverse() - r = r.to(original_weight) - - inp = org = original_weight + r = r.to(weight) + inp = org = weight r_b = boft_b//2 for i in range(boft_m): @@ -91,14 +91,14 @@ class BOFTAdapter(WeightAdapterBase): if strength != 1: bi = bi * strength + (1-strength) * I inp = ( - inp.unflatten(-1, (-1, g, k)) - .transpose(-2, -1) - .flatten(-3) - .unflatten(-1, (-1, boft_b)) + inp.unflatten(0, (-1, g, k)) + .transpose(1, 2) + .flatten(0, 2) + .unflatten(0, (-1, boft_b)) ) - inp = torch.einsum("b n m, b n ... -> b m ...", inp, bi) + inp = torch.einsum("b i j, b j ...-> b i ...", bi, inp) inp = ( - inp.flatten(-2).unflatten(-1, (-1, k, g)).transpose(-2, -1).flatten(-3) + inp.flatten(0, 1).unflatten(0, (-1, k, g)).transpose(1, 2).flatten(0, 2) ) if rescale is not None: @@ -109,7 +109,7 @@ class BOFTAdapter(WeightAdapterBase): if dora_scale is not None: weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function) else: - weight += function(((strength * alpha) * lora_diff).type(weight.dtype)) + weight += function((strength * lora_diff).type(weight.dtype)) except Exception as e: logging.error("ERROR {} {} {}".format(self.name, key, e)) return weight diff --git a/comfy/weight_adapter/oft.py b/comfy/weight_adapter/oft.py index 0ea229b7..25009eca 100644 --- a/comfy/weight_adapter/oft.py +++ b/comfy/weight_adapter/oft.py @@ -32,17 +32,18 @@ class OFTAdapter(WeightAdapterBase): blocks = lora[blocks_name] if blocks.ndim == 3: loaded_keys.add(blocks_name) + else: + blocks = None + if blocks is None: + return None rescale = None if rescale_name in lora.keys(): rescale = lora[rescale_name] loaded_keys.add(rescale_name) - if blocks is not None: - weights = (blocks, rescale, alpha, dora_scale) - return cls(loaded_keys, weights) - else: - return None + weights = (blocks, rescale, alpha, dora_scale) + return cls(loaded_keys, weights) def calculate_weight( self, @@ -79,16 +80,17 @@ class OFTAdapter(WeightAdapterBase): normed_q = q * alpha / q_norm # use float() to prevent unsupported type in .inverse() r = (I + normed_q) @ (I - normed_q).float().inverse() - r = r.to(original_weight) + r = r.to(weight) + _, *shape = weight.shape lora_diff = torch.einsum( "k n m, k n ... -> k m ...", (r * strength) - strength * I, - original_weight, - ) + weight.view(block_num, block_size, *shape), + ).view(-1, *shape) if dora_scale is not None: weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function) else: - weight += function(((strength * alpha) * lora_diff).type(weight.dtype)) + weight += function((strength * lora_diff).type(weight.dtype)) except Exception as e: logging.error("ERROR {} {} {}".format(self.name, key, e)) return weight