mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-25 15:55:18 +00:00
Make HyperTile deterministic
This commit is contained in:
parent
2db86b4676
commit
03eadbb53c
@ -2,9 +2,10 @@
|
|||||||
|
|
||||||
import math
|
import math
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
import random
|
# Use torch rng for consistency across generations
|
||||||
|
from torch import randint
|
||||||
|
|
||||||
def random_divisor(value: int, min_value: int, /, max_options: int = 1, counter = 0) -> int:
|
def random_divisor(value: int, min_value: int, /, max_options: int = 1) -> int:
|
||||||
min_value = min(min_value, value)
|
min_value = min(min_value, value)
|
||||||
|
|
||||||
# All big divisors of value (inclusive)
|
# All big divisors of value (inclusive)
|
||||||
@ -12,8 +13,7 @@ def random_divisor(value: int, min_value: int, /, max_options: int = 1, counter
|
|||||||
|
|
||||||
ns = [value // i for i in divisors[:max_options]] # has at least 1 element
|
ns = [value // i for i in divisors[:max_options]] # has at least 1 element
|
||||||
|
|
||||||
random.seed(counter)
|
idx = randint(low=0, high=len(ns) - 1, size=(1,)).item()
|
||||||
idx = random.randint(0, len(ns) - 1)
|
|
||||||
|
|
||||||
return ns[idx]
|
return ns[idx]
|
||||||
|
|
||||||
@ -42,7 +42,6 @@ class HyperTile:
|
|||||||
|
|
||||||
latent_tile_size = max(32, tile_size) // 8
|
latent_tile_size = max(32, tile_size) // 8
|
||||||
self.temp = None
|
self.temp = None
|
||||||
self.counter = 1
|
|
||||||
|
|
||||||
def hypertile_in(q, k, v, extra_options):
|
def hypertile_in(q, k, v, extra_options):
|
||||||
if q.shape[-1] in apply_to:
|
if q.shape[-1] in apply_to:
|
||||||
@ -53,10 +52,8 @@ class HyperTile:
|
|||||||
h, w = round(math.sqrt(hw * aspect_ratio)), round(math.sqrt(hw / aspect_ratio))
|
h, w = round(math.sqrt(hw * aspect_ratio)), round(math.sqrt(hw / aspect_ratio))
|
||||||
|
|
||||||
factor = 2**((q.shape[-1] // model_channels) - 1) if scale_depth else 1
|
factor = 2**((q.shape[-1] // model_channels) - 1) if scale_depth else 1
|
||||||
nh = random_divisor(h, latent_tile_size * factor, swap_size, self.counter)
|
nh = random_divisor(h, latent_tile_size * factor, swap_size)
|
||||||
self.counter += 1
|
nw = random_divisor(w, latent_tile_size * factor, swap_size)
|
||||||
nw = random_divisor(w, latent_tile_size * factor, swap_size, self.counter)
|
|
||||||
self.counter += 1
|
|
||||||
|
|
||||||
if nh * nw > 1:
|
if nh * nw > 1:
|
||||||
q = rearrange(q, "b (nh h nw w) c -> (b nh nw) (h w) c", h=h // nh, w=w // nw, nh=nh, nw=nw)
|
q = rearrange(q, "b (nh h nw w) c -> (b nh nw) (h w) c", h=h // nh, w=w // nw, nh=nh, nw=nw)
|
||||||
|
Loading…
Reference in New Issue
Block a user