mirror of https://github.com/comfyanonymous/ComfyUI.git
synced 2025-06-01 17:18:37 +08:00

Compare commits: 90d6a52cfe...ba55063ac5

2 Commits

| Author | SHA1 | Date |
|---|---|---|
| | ba55063ac5 | |
| | 06c661004e | |
comfy/conds.py

```diff
@@ -24,6 +24,10 @@ class CONDRegular:
             conds.append(x.cond)
         return torch.cat(conds)
 
+    def size(self):
+        return list(self.cond.size())
+
+
 class CONDNoiseShape(CONDRegular):
     def process_cond(self, batch_size, device, area, **kwargs):
         data = self.cond
```
```diff
@@ -64,6 +68,7 @@ class CONDCrossAttn(CONDRegular):
             out.append(c)
         return torch.cat(out)
 
+
 class CONDConstant(CONDRegular):
     def __init__(self, cond):
         self.cond = cond
```
```diff
@@ -78,3 +83,6 @@ class CONDConstant(CONDRegular):
 
     def concat(self, others):
         return self.cond
+
+    def size(self):
+        return [1]
```
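For a quick sanity check of the new accessors: `CONDRegular.size()` reports the wrapped tensor's shape as a list, while `CONDConstant.size()` returns a fixed `[1]`, which sidesteps calling `.size()` on whatever non-tensor value it wraps. A minimal sketch, assuming it runs inside a ComfyUI checkout where `comfy.conds` is importable; the tensor shape and the string value are illustrative only.

```python
# Illustrative values only; exercises the size() methods added in the hunks above.
import torch
import comfy.conds

cond = comfy.conds.CONDRegular(torch.zeros(1, 77, 768))
print(cond.size())    # [1, 77, 768], i.e. list(tensor.size())

const = comfy.conds.CONDConstant("example")  # CONDConstant can wrap non-tensor values
print(const.size())   # [1], a fixed placeholder shape
```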
comfy/model_base.py

```diff
@@ -135,6 +135,7 @@ class BaseModel(torch.nn.Module):
         logging.info("model_type {}".format(model_type.name))
         logging.debug("adm {}".format(self.adm_channels))
         self.memory_usage_factor = model_config.memory_usage_factor
+        self.memory_usage_factor_conds = ()
 
     def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
         return comfy.patcher_extension.WrapperExecutor.new_class_executor(
```
```diff
@@ -325,19 +326,28 @@ class BaseModel(torch.nn.Module):
     def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
         return self.model_sampling.noise_scaling(sigma.reshape([sigma.shape[0]] + [1] * (len(noise.shape) - 1)), noise, latent_image)
 
-    def memory_required(self, input_shape):
+    def memory_required(self, input_shape, cond_shapes={}):
+        input_shapes = [input_shape]
+        for c in self.memory_usage_factor_conds:
+            shape = cond_shapes.get(c, None)
+            if shape is not None and len(shape) > 0:
+                input_shapes += shape
+
         if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention():
             dtype = self.get_dtype()
             if self.manual_cast_dtype is not None:
                 dtype = self.manual_cast_dtype
             #TODO: this needs to be tweaked
-            area = input_shape[0] * math.prod(input_shape[2:])
+            area = sum(map(lambda input_shape: input_shape[0] * math.prod(input_shape[2:]), input_shapes))
             return (area * comfy.model_management.dtype_size(dtype) * 0.01 * self.memory_usage_factor) * (1024 * 1024)
         else:
             #TODO: this formula might be too aggressive since I tweaked the sub-quad and split algorithms to use less memory.
-            area = input_shape[0] * math.prod(input_shape[2:])
+            area = sum(map(lambda input_shape: input_shape[0] * math.prod(input_shape[2:]), input_shapes))
             return (area * 0.15 * self.memory_usage_factor) * (1024 * 1024)
 
+    def extra_conds_shapes(self, **kwargs):
+        return {}
+
 
 def unclip_adm(unclip_conditioning, device, noise_augmentor, noise_augment_merge=0.0, seed=None):
     adm_inputs = []
```
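Taken together, `memory_required()` now sums the batch-times-spatial area of the noise input plus the areas of any shapes passed in for the cond names listed in `self.memory_usage_factor_conds`, and `extra_conds_shapes()` is the hook a model can override to report those shapes. The subclass below is a hypothetical sketch (it does not exist in the repository); the cond name `"c_crossattn"` and the `cross_attn` kwarg are assumptions used for illustration.

```python
# Hypothetical sketch of a model that opts into per-cond memory accounting.
import comfy.model_base

class MyCrossAttnHeavyModel(comfy.model_base.BaseModel):  # hypothetical subclass
    def __init__(self, model_config, model_type=comfy.model_base.ModelType.EPS, device=None):
        super().__init__(model_config, model_type, device=device)
        # names listed here are looked up in the cond_shapes dict by memory_required()
        self.memory_usage_factor_conds = ("c_crossattn",)

    def extra_conds_shapes(self, **kwargs):
        # report the shape of the cross-attention conditioning, if present
        cross_attn = kwargs.get("cross_attn", None)  # assumed kwarg name
        if cross_attn is not None:
            return {"c_crossattn": list(cross_attn.shape)}
        return {}

# With input_shape = [2, 4, 64, 64] and one reported cond shape [2, 77, 768], the area becomes
#   2 * (64 * 64) + 2 * 768 = 8192 + 1536 = 9728
# (the formula skips dim 1 of each shape), which then feeds the same dtype/factor scaling as before.
```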
comfy/sampler_helpers.py

```diff
@@ -1,5 +1,7 @@
 from __future__ import annotations
 import uuid
+import math
+import collections
 import comfy.model_management
 import comfy.conds
 import comfy.utils
```
```diff
@@ -104,6 +106,21 @@ def cleanup_additional_models(models):
         if hasattr(m, 'cleanup'):
             m.cleanup()
 
+def estimate_memory(model, noise_shape, conds):
+    cond_shapes = collections.defaultdict(list)
+    cond_shapes_min = {}
+    for _, cs in conds.items():
+        for cond in cs:
+            for k, v in model.model.extra_conds_shapes(**cond).items():
+                cond_shapes[k].append(v)
+                if cond_shapes_min.get(k, None) is None:
+                    cond_shapes_min[k] = [v]
+                elif math.prod(v) > math.prod(cond_shapes_min[k][0]):
+                    cond_shapes_min[k] = [v]
+
+    memory_required = model.model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:]), cond_shapes=cond_shapes)
+    minimum_memory_required = model.model.memory_required([noise_shape[0]] + list(noise_shape[1:]), cond_shapes=cond_shapes_min)
+    return memory_required, minimum_memory_required
 
 def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None):
     executor = comfy.patcher_extension.WrapperExecutor.new_executor(
```
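`estimate_memory()` collects, per cond key, every shape reported by `extra_conds_shapes()` (used for the full estimate, with the noise batch doubled) and only the single largest shape (used for the minimum estimate). The snippet below re-traces just that bookkeeping on invented shapes; `toy_extra_conds_shapes` stands in for `model.model.extra_conds_shapes`.

```python
# Re-trace of the shape bookkeeping in estimate_memory(), on invented data.
import math
import collections

def toy_extra_conds_shapes(**cond):
    # stand-in for model.model.extra_conds_shapes(**cond)
    return {"c_crossattn": cond["cross_attn_shape"]}

conds = {"positive": [{"cross_attn_shape": [1, 77, 768]}, {"cross_attn_shape": [1, 231, 768]}],
         "negative": [{"cross_attn_shape": [1, 77, 768]}]}

cond_shapes = collections.defaultdict(list)   # every reported shape, for the full estimate
cond_shapes_min = {}                          # only the largest shape per key, for the minimum estimate
for _, cs in conds.items():
    for cond in cs:
        for k, v in toy_extra_conds_shapes(**cond).items():
            cond_shapes[k].append(v)
            if cond_shapes_min.get(k, None) is None:
                cond_shapes_min[k] = [v]
            elif math.prod(v) > math.prod(cond_shapes_min[k][0]):
                cond_shapes_min[k] = [v]

print(dict(cond_shapes))    # {'c_crossattn': [[1, 77, 768], [1, 231, 768], [1, 77, 768]]}
print(cond_shapes_min)      # {'c_crossattn': [[1, 231, 768]]}
```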
```diff
@@ -117,9 +134,8 @@ def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=Non
     models, inference_memory = get_additional_models(conds, model.model_dtype())
     models += get_additional_models_from_model_options(model_options)
     models += model.get_nested_additional_models()  # TODO: does this require inference_memory update?
-    memory_required = model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory
-    minimum_memory_required = model.memory_required([noise_shape[0]] + list(noise_shape[1:])) + inference_memory
-    comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required, minimum_memory_required=minimum_memory_required)
+    memory_required, minimum_memory_required = estimate_memory(model, noise_shape, conds)
+    comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required + inference_memory, minimum_memory_required=minimum_memory_required + inference_memory)
     real_model = model.model
 
     return real_model, conds, models
```
comfy/samplers.py

```diff
@@ -256,7 +256,13 @@ def _calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Te
         for i in range(1, len(to_batch_temp) + 1):
             batch_amount = to_batch_temp[:len(to_batch_temp)//i]
             input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
-            if model.memory_required(input_shape) * 1.5 < free_memory:
+            cond_shapes = collections.defaultdict(list)
+            for tt in batch_amount:
+                cond = {k: v.size() for k, v in to_run[tt][0].conditioning.items()}
+                for k, v in to_run[tt][0].conditioning.items():
+                    cond_shapes[k].append(v.size())
+
+            if model.memory_required(input_shape, cond_shapes=cond_shapes) * 1.5 < free_memory:
                 to_batch = batch_amount
                 break
```
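In `_calc_cond_batch` this loop tries progressively smaller prefixes of `to_batch_temp` (`len`, `len//2`, `len//3`, ...) until the memory estimate, now including the per-cond shapes of the entries that would be batched together, fits in free memory with a 1.5x margin. A toy version of that search is sketched below; the free-memory figure, the shapes, and the simplified cost function are made up for illustration.

```python
# Toy re-trace of the batch-size search; numbers and the cost model are invented.
import math

def toy_memory_required(input_shape, cond_shapes={}):
    shapes = [input_shape] + [s for ss in cond_shapes.values() for s in ss]
    area = sum(s[0] * math.prod(s[2:]) for s in shapes)
    return area * 0.15 * (1024 * 1024)    # same form as the non-flash-attention branch above

free_memory = 4 * 1024 * 1024 * 1024      # pretend 4 GiB are free
first_shape = (1, 4, 128, 128)            # latent shape of a single cond entry
to_batch_temp = [3, 2, 1, 0]              # indices of entries that could be concatenated
cond_size = [1, 77, 768]                  # every entry reports this c_crossattn shape

to_batch = to_batch_temp[:1]
for i in range(1, len(to_batch_temp) + 1):
    batch_amount = to_batch_temp[:len(to_batch_temp)//i]
    input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
    cond_shapes = {"c_crossattn": [cond_size] * len(batch_amount)}
    if toy_memory_required(input_shape, cond_shapes=cond_shapes) * 1.5 < free_memory:
        to_batch = batch_amount
        break

print(to_batch)    # [3]: only a batch of one fits under the invented 4 GiB budget
```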
execution.py (15 changed lines)
```diff
@@ -435,6 +435,20 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
 
     return (ExecutionResult.SUCCESS, None, None)
 
+def clean_inputs(prompt):
+    for unique_id, node in prompt.items():
+        inputs = node['inputs']
+        class_type = node['class_type']
+        obj_class = nodes.NODE_CLASS_MAPPINGS[class_type]
+
+        class_inputs = obj_class.INPUT_TYPES()
+        valid_inputs = set(class_inputs.get('required',{})).union(set(class_inputs.get('optional',{})))
+        for k in list(inputs.keys()):
+            if k not in valid_inputs:
+                inputs.pop(k)
+    return prompt
+
+
 class PromptExecutor:
     def __init__(self, server, cache_type=False, cache_size=None):
        self.cache_size = cache_size
```
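`clean_inputs()` drops any prompt input whose key is not declared in the node class's `INPUT_TYPES()`. The sketch below shows the effect with a stand-in node registry: `ToyNode` and the prompt are invented, and the local `NODE_CLASS_MAPPINGS` stands in for `nodes.NODE_CLASS_MAPPINGS`; the cleaning logic itself mirrors the hunk above.

```python
# Illustration with a stand-in node registry; the real code looks classes up in nodes.NODE_CLASS_MAPPINGS.
class ToyNode:                                   # hypothetical node class
    @classmethod
    def INPUT_TYPES(cls):
        return {"required": {"text": ("STRING", {})},
                "optional": {"seed": ("INT", {})}}

NODE_CLASS_MAPPINGS = {"ToyNode": ToyNode}       # stand-in registry

def clean_inputs(prompt):                        # same logic as the diff, minus the nodes import
    for unique_id, node in prompt.items():
        inputs = node['inputs']
        class_type = node['class_type']
        obj_class = NODE_CLASS_MAPPINGS[class_type]

        class_inputs = obj_class.INPUT_TYPES()
        valid_inputs = set(class_inputs.get('required', {})).union(set(class_inputs.get('optional', {})))
        for k in list(inputs.keys()):
            if k not in valid_inputs:
                inputs.pop(k)
    return prompt

prompt = {"1": {"class_type": "ToyNode",
                "inputs": {"text": "hello", "seed": 3, "stale_widget": 0.5}}}
print(clean_inputs(prompt))   # "stale_widget" is removed; "text" and "seed" survive
```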
```diff
@@ -486,6 +500,7 @@ class PromptExecutor:
 
     def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
         nodes.interrupt_processing(False)
+        prompt = clean_inputs(prompt)
 
         if "client_id" in extra_data:
             self.server.client_id = extra_data["client_id"]
```