From 06c661004ed67ef33a59df8de7136ad6f4542945 Mon Sep 17 00:00:00 2001
From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com>
Date: Tue, 27 May 2025 12:09:05 -0700
Subject: [PATCH] Memory estimation code can now take into account conds. (#8307)

---
 comfy/conds.py           |  8 ++++++++
 comfy/model_base.py      | 16 +++++++++++++---
 comfy/sampler_helpers.py | 22 +++++++++++++++++++---
 comfy/samplers.py        |  8 +++++++-
 4 files changed, 47 insertions(+), 7 deletions(-)

diff --git a/comfy/conds.py b/comfy/conds.py
index 211fb8d57..920e25488 100644
--- a/comfy/conds.py
+++ b/comfy/conds.py
@@ -24,6 +24,10 @@ class CONDRegular:
             conds.append(x.cond)
         return torch.cat(conds)
 
+    def size(self):
+        return list(self.cond.size())
+
+
 class CONDNoiseShape(CONDRegular):
     def process_cond(self, batch_size, device, area, **kwargs):
         data = self.cond
@@ -64,6 +68,7 @@ class CONDCrossAttn(CONDRegular):
             out.append(c)
         return torch.cat(out)
 
+
 class CONDConstant(CONDRegular):
     def __init__(self, cond):
         self.cond = cond
@@ -78,3 +83,6 @@ class CONDConstant(CONDRegular):
 
     def concat(self, others):
         return self.cond
+
+    def size(self):
+        return [1]
diff --git a/comfy/model_base.py b/comfy/model_base.py
index fb4724690..cfd10d726 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -135,6 +135,7 @@ class BaseModel(torch.nn.Module):
         logging.info("model_type {}".format(model_type.name))
         logging.debug("adm {}".format(self.adm_channels))
         self.memory_usage_factor = model_config.memory_usage_factor
+        self.memory_usage_factor_conds = ()
 
     def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
         return comfy.patcher_extension.WrapperExecutor.new_class_executor(
@@ -325,19 +326,28 @@ class BaseModel(torch.nn.Module):
     def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
         return self.model_sampling.noise_scaling(sigma.reshape([sigma.shape[0]] + [1] * (len(noise.shape) - 1)), noise, latent_image)
 
-    def memory_required(self, input_shape):
+    def memory_required(self, input_shape, cond_shapes={}):
+        input_shapes = [input_shape]
+        for c in self.memory_usage_factor_conds:
+            shape = cond_shapes.get(c, None)
+            if shape is not None and len(shape) > 0:
+                input_shapes += shape
+
         if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention():
             dtype = self.get_dtype()
             if self.manual_cast_dtype is not None:
                 dtype = self.manual_cast_dtype
             #TODO: this needs to be tweaked
-            area = input_shape[0] * math.prod(input_shape[2:])
+            area = sum(map(lambda input_shape: input_shape[0] * math.prod(input_shape[2:]), input_shapes))
             return (area * comfy.model_management.dtype_size(dtype) * 0.01 * self.memory_usage_factor) * (1024 * 1024)
         else:
             #TODO: this formula might be too aggressive since I tweaked the sub-quad and split algorithms to use less memory.
-            area = input_shape[0] * math.prod(input_shape[2:])
+            area = sum(map(lambda input_shape: input_shape[0] * math.prod(input_shape[2:]), input_shapes))
             return (area * 0.15 * self.memory_usage_factor) * (1024 * 1024)
 
+    def extra_conds_shapes(self, **kwargs):
+        return {}
+
 
 def unclip_adm(unclip_conditioning, device, noise_augmentor, noise_augment_merge=0.0, seed=None):
     adm_inputs = []
diff --git a/comfy/sampler_helpers.py b/comfy/sampler_helpers.py
index 96a3040a1..8dbc41455 100644
--- a/comfy/sampler_helpers.py
+++ b/comfy/sampler_helpers.py
@@ -1,5 +1,7 @@
 from __future__ import annotations
 import uuid
+import math
+import collections
 import comfy.model_management
 import comfy.conds
 import comfy.utils
@@ -104,6 +106,21 @@ def cleanup_additional_models(models):
         if hasattr(m, 'cleanup'):
             m.cleanup()
 
+def estimate_memory(model, noise_shape, conds):
+    cond_shapes = collections.defaultdict(list)
+    cond_shapes_min = {}
+    for _, cs in conds.items():
+        for cond in cs:
+            for k, v in model.model.extra_conds_shapes(**cond).items():
+                cond_shapes[k].append(v)
+                if cond_shapes_min.get(k, None) is None:
+                    cond_shapes_min[k] = [v]
+                elif math.prod(v) > math.prod(cond_shapes_min[k][0]):
+                    cond_shapes_min[k] = [v]
+
+    memory_required = model.model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:]), cond_shapes=cond_shapes)
+    minimum_memory_required = model.model.memory_required([noise_shape[0]] + list(noise_shape[1:]), cond_shapes=cond_shapes_min)
+    return memory_required, minimum_memory_required
 
 def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None):
     executor = comfy.patcher_extension.WrapperExecutor.new_executor(
@@ -117,9 +134,8 @@ def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=Non
     models, inference_memory = get_additional_models(conds, model.model_dtype())
     models += get_additional_models_from_model_options(model_options)
     models += model.get_nested_additional_models() # TODO: does this require inference_memory update?
-    memory_required = model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory
-    minimum_memory_required = model.memory_required([noise_shape[0]] + list(noise_shape[1:])) + inference_memory
-    comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required, minimum_memory_required=minimum_memory_required)
+    memory_required, minimum_memory_required = estimate_memory(model, noise_shape, conds)
+    comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required + inference_memory, minimum_memory_required=minimum_memory_required + inference_memory)
 
     real_model = model.model
     return real_model, conds, models
diff --git a/comfy/samplers.py b/comfy/samplers.py
index 67ae09a25..efe9bf867 100644
--- a/comfy/samplers.py
+++ b/comfy/samplers.py
@@ -256,7 +256,13 @@ def _calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Te
         for i in range(1, len(to_batch_temp) + 1):
             batch_amount = to_batch_temp[:len(to_batch_temp)//i]
             input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
-            if model.memory_required(input_shape) * 1.5 < free_memory:
+            cond_shapes = collections.defaultdict(list)
+            for tt in batch_amount:
+                cond = {k: v.size() for k, v in to_run[tt][0].conditioning.items()}
+                for k, v in to_run[tt][0].conditioning.items():
+                    cond_shapes[k].append(v.size())
+
+            if model.memory_required(input_shape, cond_shapes=cond_shapes) * 1.5 < free_memory:
                 to_batch = batch_amount
                 break
 
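
A minimal sketch of how a model class could hook into the cond-aware estimate added by this patch. The subclass name and the "reference_latents" cond key are assumptions made for illustration only; the hooks themselves (memory_usage_factor_conds, extra_conds_shapes() and the cond_shapes argument of memory_required()) are the ones introduced in the diff above.

    import math

    import comfy.model_base


    class MyVideoModel(comfy.model_base.BaseModel):
        def __init__(self, model_config, model_type=comfy.model_base.ModelType.EPS, device=None):
            super().__init__(model_config, model_type, device=device)
            # memory_required() only adds shapes reported under the keys listed
            # here to the set of input shapes it sums over.
            self.memory_usage_factor_conds = ("reference_latents",)

        def extra_conds_shapes(self, **kwargs):
            # estimate_memory() calls this with the raw cond dict as kwargs and
            # collects the returned {cond key: shape} entries into cond_shapes.
            out = {}
            ref_latents = kwargs.get("reference_latents", None)  # hypothetical cond key
            if ref_latents is not None:
                # Collapse the list of reference latents into one synthetic shape
                # so their total element count enters the area computation.
                out["reference_latents"] = [1, 16, sum(math.prod(x.size()) for x in ref_latents)]
            return out

With such an override in place, prepare_sampling() -> estimate_memory() folds the extra cond into both memory_required and minimum_memory_required before load_models_gpu() is called, instead of budgeting for the noise shape alone.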