From 853e96ada31bfd22640c740f9ee317dcfa52bc04 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 8 Feb 2023 14:05:31 -0500 Subject: [PATCH] Increase it/s by batching together some stuff sent to unet. --- comfy/model_management.py | 35 ++++++++-- comfy/samplers.py | 131 ++++++++++++++++++++++++++++---------- 2 files changed, 126 insertions(+), 40 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 97ffc2833..ece6db3f5 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -7,6 +7,7 @@ NORMAL_VRAM = 3 accelerate_enabled = False vram_state = NORMAL_VRAM +total_vram = 0 total_vram_available_mb = -1 import sys @@ -17,6 +18,12 @@ if "--lowvram" in sys.argv: if "--novram" in sys.argv: set_vram_to = NO_VRAM +try: + import torch + total_vram = torch.cuda.mem_get_info(torch.cuda.current_device())[1] / (1024 * 1024) +except: + pass + if set_vram_to != NORMAL_VRAM: try: import accelerate @@ -26,12 +33,8 @@ if set_vram_to != NORMAL_VRAM: import traceback print(traceback.format_exc()) print("ERROR: COULD NOT ENABLE LOW VRAM MODE.") - try: - import torch - total_vram_available_mb = torch.cuda.mem_get_info(torch.cuda.current_device())[1] / (1024 * 1024) - except: - pass - total_vram_available_mb = (total_vram_available_mb - 1024) // 2 + + total_vram_available_mb = (total_vram - 1024) // 2 total_vram_available_mb = int(max(256, total_vram_available_mb)) @@ -81,6 +84,26 @@ def load_model_gpu(model): device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "256MiB", "cpu": "16GiB"}) elif vram_state == LOW_VRAM: device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "{}MiB".format(total_vram_available_mb), "cpu": "16GiB"}) + print(device_map, "{}MiB".format(total_vram_available_mb)) accelerate.dispatch_model(real_model, device_map=device_map, main_device="cuda") model_accelerated = True return current_loaded_model + + +def get_free_memory(): + dev = torch.cuda.current_device() + stats = torch.cuda.memory_stats(dev) + mem_active = stats['active_bytes.all.current'] + mem_reserved = stats['reserved_bytes.all.current'] + mem_free_cuda, _ = torch.cuda.mem_get_info(dev) + mem_free_torch = mem_reserved - mem_active + return mem_free_cuda + mem_free_torch + +def maximum_batch_area(): + global vram_state + if vram_state == NO_VRAM: + return 0 + + memory_free = get_free_memory() / (1024 * 1024) + area = ((memory_free - 1024) * 0.9) / (0.6) + return int(max(area, 0)) diff --git a/comfy/samplers.py b/comfy/samplers.py index 84df79522..7ab57fc9f 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -2,6 +2,7 @@ import k_diffusion.sampling import k_diffusion.external import torch import contextlib +import model_management class CFGDenoiser(torch.nn.Module): def __init__(self, model): @@ -24,55 +25,117 @@ class CFGDenoiserComplex(torch.nn.Module): super().__init__() self.inner_model = model def forward(self, x, sigma, uncond, cond, cond_scale): - def calc_cond(cond, x_in, sigma): + def get_area_and_mult(cond, x_in, sigma): + area = (x_in.shape[2], x_in.shape[3], 0, 0) + strength = 1.0 + min_sigma = 0.0 + max_sigma = 999.0 + if 'area' in cond[1]: + area = cond[1]['area'] + if 'strength' in cond[1]: + strength = cond[1]['strength'] + if 'min_sigma' in cond[1]: + min_sigma = cond[1]['min_sigma'] + if 'max_sigma' in cond[1]: + max_sigma = cond[1]['max_sigma'] + if sigma < min_sigma or sigma > max_sigma: + return None + input_x = x_in[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] + mult = torch.ones_like(input_x) * strength + + rr = 8 + if area[2] != 0: + for t in range(rr): + mult[:,:,area[2]+t:area[2]+1+t,:] *= ((1.0/rr) * (t + 1)) + if (area[0] + area[2]) < x_in.shape[2]: + for t in range(rr): + mult[:,:,area[0] + area[2] - 1 - t:area[0] + area[2] - t,:] *= ((1.0/rr) * (t + 1)) + if area[3] != 0: + for t in range(rr): + mult[:,:,:,area[3]+t:area[3]+1+t] *= ((1.0/rr) * (t + 1)) + if (area[1] + area[3]) < x_in.shape[3]: + for t in range(rr): + mult[:,:,:,area[1] + area[3] - 1 - t:area[1] + area[3] - t] *= ((1.0/rr) * (t + 1)) + return (input_x, mult, cond[0], area) + + def calc_cond_uncond_batch(cond, uncond, x_in, sigma, max_total_area): out_cond = torch.zeros_like(x_in) out_count = torch.ones_like(x_in)/100000.0 + + out_uncond = torch.zeros_like(x_in) + out_uncond_count = torch.ones_like(x_in)/100000.0 + sigma_cmp = sigma[0] + COND = 0 + UNCOND = 1 + to_run = [] for x in cond: - area = (x_in.shape[2], x_in.shape[3], 0, 0) - strength = 1.0 - min_sigma = 0.0 - max_sigma = 999.0 - if 'area' in x[1]: - area = x[1]['area'] - if 'strength' in x[1]: - strength = x[1]['strength'] - if 'min_sigma' in x[1]: - min_sigma = x[1]['min_sigma'] - if 'max_sigma' in x[1]: - max_sigma = x[1]['max_sigma'] - if sigma_cmp < min_sigma or sigma_cmp > max_sigma: + p = get_area_and_mult(x, x_in, sigma_cmp) + if p is None: continue - input_x = x_in[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] - mult = torch.ones_like(input_x) * strength - rr = 8 - if area[2] != 0: - for t in range(rr): - mult[:,:,area[2]+t:area[2]+1+t,:] *= ((1.0/rr) * (t + 1)) - if (area[0] + area[2]) < x_in.shape[2]: - for t in range(rr): - mult[:,:,area[0] + area[2] - 1 - t:area[0] + area[2] - t,:] *= ((1.0/rr) * (t + 1)) - if area[3] != 0: - for t in range(rr): - mult[:,:,:,area[3]+t:area[3]+1+t] *= ((1.0/rr) * (t + 1)) - if (area[1] + area[3]) < x_in.shape[3]: - for t in range(rr): - mult[:,:,:,area[1] + area[3] - 1 - t:area[1] + area[3] - t] *= ((1.0/rr) * (t + 1)) + to_run += [(p, COND)] + for x in uncond: + p = get_area_and_mult(x, x_in, sigma_cmp) + if p is None: + continue - out_cond[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] += self.inner_model(input_x, sigma, cond=x[0]) * mult - out_count[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] += mult + to_run += [(p, UNCOND)] + + while len(to_run) > 0: + first = to_run[0] + first_shape = first[0][0].shape + to_batch = [] + for x in range(len(to_run)): + if to_run[x][0][0].shape == first_shape: + if to_run[x][0][2].shape == first[0][2].shape: + to_batch += [x] + if (len(to_batch) * first_shape[0] * first_shape[2] * first_shape[3] >= max_total_area): + break + + to_batch.reverse() + input_x = [] + mult = [] + c = [] + cond_or_uncond = [] + area = [] + for x in to_batch: + o = to_run.pop(x) + p = o[0] + input_x += [p[0]] + mult += [p[1]] + c += [p[2]] + area += [p[3]] + cond_or_uncond += [o[1]] + + batch_chunks = len(cond_or_uncond) + input_x = torch.cat(input_x) + c = torch.cat(c) + sigma_ = torch.cat([sigma] * batch_chunks) + + output = self.inner_model(input_x, sigma_, cond=c).chunk(batch_chunks) del input_x + + for o in range(batch_chunks): + if cond_or_uncond[o] == COND: + out_cond[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o] + out_count[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o] + else: + out_uncond[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o] + out_uncond_count[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o] del mult out_cond /= out_count del out_count - return out_cond + out_uncond /= out_uncond_count + del out_uncond_count - cond = calc_cond(cond, x, sigma) - uncond = calc_cond(uncond, x, sigma) + return out_cond, out_uncond + + max_total_area = model_management.maximum_batch_area() + cond, uncond = calc_cond_uncond_batch(cond, uncond, x, sigma, max_total_area) return uncond + (cond - uncond) * cond_scale def simple_scheduler(model, steps):