Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-07-05 23:07:09 +08:00)

Commit e10cbaddea: Merge remote-tracking branch 'origin/master' into feature/custom_workflow_templates
README.md

@@ -189,7 +189,7 @@ Nvidia users should install stable pytorch using this command:

 This is the command to install pytorch nightly instead which might have performance improvements:

-```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu124```
+```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu126```

 #### Troubleshooting

comfy/model_management.py

@@ -188,6 +188,12 @@ def is_nvidia():
             return True
     return False

+def is_amd():
+    global cpu_state
+    if cpu_state == CPUState.GPU:
+        if torch.version.hip:
+            return True
+    return False

 MIN_WEIGHT_MEMORY_RATIO = 0.4

 if is_nvidia():
@@ -198,27 +204,17 @@ if args.use_pytorch_cross_attention:
     ENABLE_PYTORCH_ATTENTION = True
     XFORMERS_IS_AVAILABLE = False

-VAE_DTYPES = [torch.float32]
-
 try:
     if is_nvidia():
         if int(torch_version[0]) >= 2:
             if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
                 ENABLE_PYTORCH_ATTENTION = True
-            if torch.cuda.is_bf16_supported() and torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 8:
-                VAE_DTYPES = [torch.bfloat16] + VAE_DTYPES
     if is_intel_xpu():
         if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
             ENABLE_PYTORCH_ATTENTION = True
 except:
     pass

-if is_intel_xpu():
-    VAE_DTYPES = [torch.bfloat16] + VAE_DTYPES
-
-if args.cpu_vae:
-    VAE_DTYPES = [torch.float32]
-
 if ENABLE_PYTORCH_ATTENTION:
     torch.backends.cuda.enable_math_sdp(True)
     torch.backends.cuda.enable_flash_sdp(True)
@@ -754,7 +750,6 @@ def vae_offload_device():
         return torch.device("cpu")

 def vae_dtype(device=None, allowed_dtypes=[]):
-    global VAE_DTYPES
     if args.fp16_vae:
         return torch.float16
     elif args.bf16_vae:

@@ -763,12 +758,14 @@ def vae_dtype(device=None, allowed_dtypes=[]):
         return torch.float32

     for d in allowed_dtypes:
-        if d == torch.float16 and should_use_fp16(device, prioritize_performance=False):
-            return d
-        if d in VAE_DTYPES:
+        if d == torch.float16 and should_use_fp16(device):
             return d

-    return VAE_DTYPES[0]
+        # NOTE: bfloat16 seems to work on AMD for the VAE but is extremely slow in some cases compared to fp32
+        if d == torch.bfloat16 and (not is_amd()) and should_use_bf16(device):
+            return d
+
+    return torch.float32

 def get_autocast_device(dev):
     if hasattr(dev, 'type'):
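With the global VAE_DTYPES list gone, the dtype a VAE runs in is now decided per call from the allowed_dtypes the model declares. A minimal standalone sketch of the new selection order; the helper and flag names (pick_vae_dtype, fp16_ok, bf16_ok, on_amd) are illustrative, not the actual module functions:

```python
import torch

def pick_vae_dtype(allowed_dtypes, fp16_ok=False, bf16_ok=False, on_amd=False):
    # Mirrors the new vae_dtype() order: prefer fp16 if the device handles it,
    # then bf16 (skipped on AMD, where it can be much slower than fp32),
    # otherwise fall back to full precision.
    for d in allowed_dtypes:
        if d == torch.float16 and fp16_ok:
            return d
        if d == torch.bfloat16 and not on_amd and bf16_ok:
            return d
    return torch.float32

# A video VAE that declares bf16/fp32 as workable dtypes on a bf16-capable GPU:
print(pick_vae_dtype([torch.bfloat16, torch.float32], bf16_ok=True))  # torch.bfloat16
```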
@@ -889,14 +886,19 @@ def pytorch_attention_flash_attention():
             return True
     return False

+def mac_version():
+    try:
+        return tuple(int(n) for n in platform.mac_ver()[0].split("."))
+    except:
+        return None
+
 def force_upcast_attention_dtype():
     upcast = args.force_upcast_attention
-    try:
-        macos_version = tuple(int(n) for n in platform.mac_ver()[0].split("."))
-        if (14, 5) <= macos_version <= (15, 2): # black image bug on recent versions of macOS
-            upcast = True
-    except:
-        pass
+
+    macos_version = mac_version()
+    if macos_version is not None and ((14, 5) <= macos_version <= (15, 2)): # black image bug on recent versions of macOS
+        upcast = True
+
     if upcast:
         return torch.float32
     else:
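mac_version() returns the macOS version as an integer tuple (or None off macOS), so both the range check above and the mac_version() < (14,) guard added to should_use_bf16 further down rely on plain tuple comparison. A small illustration with made-up version values:

```python
# Tuple comparison is element-wise and left to right, so e.g. (15, 1, 1)
# falls inside the (14, 5) .. (15, 2) window that triggers the upcast,
# while (13, 6) is below the (14,) cutoff used by should_use_bf16.
version = (15, 1, 1)
print((14, 5) <= version <= (15, 2))  # True
print((13, 6) < (14,))                # True
```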
@@ -967,17 +969,13 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
     if FORCE_FP16:
         return True

-    if device is not None:
-        if is_device_mps(device):
-            return True
-
     if FORCE_FP32:
         return False

     if directml_enabled:
         return False

-    if mps_mode():
+    if (device is not None and is_device_mps(device)) or mps_mode():
         return True

     if cpu_mode():
@@ -1026,17 +1024,15 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
     if is_device_cpu(device): #TODO ? bf16 works on CPU but is extremely slow
         return False

-    if device is not None:
-        if is_device_mps(device):
-            return True
-
     if FORCE_FP32:
         return False

     if directml_enabled:
         return False

-    if mps_mode():
+    if (device is not None and is_device_mps(device)) or mps_mode():
+        if mac_version() < (14,):
+            return False
         return True

     if cpu_mode():
comfy/ops.py (18 lines changed)
@@ -255,9 +255,10 @@ def fp8_linear(self, input):
         tensor_2d = True
         input = input.unsqueeze(1)

+    input_shape = input.shape
+    input_dtype = input.dtype
     if len(input.shape) == 3:
-        w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input.dtype)
+        w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype)
         w = w.t()

         scale_weight = self.scale_weight
@@ -269,23 +270,24 @@ def fp8_linear(self, input):

         if scale_input is None:
             scale_input = torch.ones((), device=input.device, dtype=torch.float32)
-            inn = torch.clamp(input, min=-448, max=448).reshape(-1, input.shape[2]).to(dtype)
+            input = torch.clamp(input, min=-448, max=448, out=input)
+            input = input.reshape(-1, input_shape[2]).to(dtype)
         else:
             scale_input = scale_input.to(input.device)
-            inn = (input * (1.0 / scale_input).to(input.dtype)).reshape(-1, input.shape[2]).to(dtype)
+            input = (input * (1.0 / scale_input).to(input_dtype)).reshape(-1, input_shape[2]).to(dtype)

         if bias is not None:
-            o = torch._scaled_mm(inn, w, out_dtype=input.dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight)
+            o = torch._scaled_mm(input, w, out_dtype=input_dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight)
         else:
-            o = torch._scaled_mm(inn, w, out_dtype=input.dtype, scale_a=scale_input, scale_b=scale_weight)
+            o = torch._scaled_mm(input, w, out_dtype=input_dtype, scale_a=scale_input, scale_b=scale_weight)

         if isinstance(o, tuple):
             o = o[0]

         if tensor_2d:
-            return o.reshape(input.shape[0], -1)
+            return o.reshape(input_shape[0], -1)

-        return o.reshape((-1, input.shape[1], self.weight.shape[0]))
+        return o.reshape((-1, input_shape[1], self.weight.shape[0]))

     return None

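The fp8_linear change caches the input's shape and dtype up front, then clamps and reshapes the tensor in place instead of building a separate `inn` copy, so two full activation tensors never coexist. A small sketch of the same caching pattern, using a hypothetical helper that is independent of ComfyUI:

```python
import torch

def clamp_and_flatten(x, dtype=torch.float16):
    # Remember the original shape and dtype before mutating x, mirroring the
    # fp8_linear change: clamp in place, flatten to 2D for the matmul, and keep
    # enough information around to reshape the result afterwards.
    x_shape = x.shape
    x_dtype = x.dtype
    x = torch.clamp(x, min=-448, max=448, out=x)
    x = x.reshape(-1, x_shape[-1]).to(dtype)
    return x, x_shape, x_dtype

flat, shape, orig_dtype = clamp_and_flatten(torch.randn(2, 5, 8) * 1000)
print(shape, orig_dtype, flat.shape, flat.dtype)  # (2, 5, 8) float32 -> (10, 8) float16
```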
comfy/sd.py (40 lines changed)
@@ -111,7 +111,7 @@ class CLIP:
             model_management.load_models_gpu([self.patcher], force_full_load=True)
         self.layer_idx = None
         self.use_clip_schedule = False
-        logging.debug("CLIP model load device: {}, offload device: {}, current: {}".format(load_device, offload_device, params['device']))
+        logging.info("CLIP model load device: {}, offload device: {}, current: {}, dtype: {}".format(load_device, offload_device, params['device'], dtype))

     def clone(self):
         n = CLIP(no_init=True)
@@ -259,6 +259,9 @@ class VAE:
         self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
         self.working_dtypes = [torch.bfloat16, torch.float32]

+        self.downscale_index_formula = None
+        self.upscale_index_formula = None
+
         if config is None:
             if "decoder.mid.block_1.mix_factor" in sd:
                 encoder_config = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
@@ -338,7 +341,9 @@ class VAE:
                 self.memory_used_decode = lambda shape, dtype: (1000 * shape[2] * shape[3] * shape[4] * (6 * 8 * 8)) * model_management.dtype_size(dtype)
                 self.memory_used_encode = lambda shape, dtype: (1.5 * max(shape[2], 7) * shape[3] * shape[4] * (6 * 8 * 8)) * model_management.dtype_size(dtype)
                 self.upscale_ratio = (lambda a: max(0, a * 6 - 5), 8, 8)
+                self.upscale_index_formula = (6, 8, 8)
                 self.downscale_ratio = (lambda a: max(0, math.floor((a + 5) / 6)), 8, 8)
+                self.downscale_index_formula = (6, 8, 8)
                 self.working_dtypes = [torch.float16, torch.float32]
             elif "decoder.up_blocks.0.res_blocks.0.conv1.conv.weight" in sd: #lightricks ltxv
                 tensor_conv1 = sd["decoder.up_blocks.0.res_blocks.0.conv1.conv.weight"]
@@ -353,14 +358,18 @@ class VAE:
                 self.memory_used_decode = lambda shape, dtype: (900 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype)
                 self.memory_used_encode = lambda shape, dtype: (70 * max(shape[2], 7) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
                 self.upscale_ratio = (lambda a: max(0, a * 8 - 7), 32, 32)
+                self.upscale_index_formula = (8, 32, 32)
                 self.downscale_ratio = (lambda a: max(0, math.floor((a + 7) / 8)), 32, 32)
+                self.downscale_index_formula = (8, 32, 32)
                 self.working_dtypes = [torch.bfloat16, torch.float32]
             elif "decoder.conv_in.conv.weight" in sd:
                 ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
                 ddconfig["conv3d"] = True
                 ddconfig["time_compress"] = 4
                 self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
+                self.upscale_index_formula = (4, 8, 8)
                 self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8)
+                self.downscale_index_formula = (4, 8, 8)
                 self.latent_dim = 3
                 self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1]
                 self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1])
@@ -393,7 +402,7 @@ class VAE:
         self.output_device = model_management.intermediate_device()

         self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
-        logging.debug("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype))
+        logging.info("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype))

     def vae_encode_crop_pixels(self, pixels):
         downscale_ratio = self.spacial_compression_encode()
@@ -426,7 +435,7 @@ class VAE:

     def decode_tiled_3d(self, samples, tile_t=999, tile_x=32, tile_y=32, overlap=(1, 8, 8)):
         decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
-        return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device))
+        return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, index_formulas=self.upscale_index_formula, output_device=self.output_device))

     def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
         steps = pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap)
@@ -447,7 +456,7 @@ class VAE:

     def encode_tiled_3d(self, samples, tile_t=9999, tile_x=512, tile_y=512, overlap=(1, 64, 64)):
         encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
-        return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, output_device=self.output_device)
+        return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, index_formulas=self.downscale_index_formula, output_device=self.output_device)

     def decode(self, samples_in):
         pixel_samples = None
@@ -479,7 +488,7 @@ class VAE:
             pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
         return pixel_samples

-    def decode_tiled(self, samples, tile_x=None, tile_y=None, overlap=None):
+    def decode_tiled(self, samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None):
         memory_used = self.memory_used_decode(samples.shape, self.vae_dtype) #TODO: calculate mem required for tile
         model_management.load_models_gpu([self.patcher], memory_required=memory_used)
         dims = samples.ndim - 2
@@ -497,6 +506,13 @@ class VAE:
         elif dims == 2:
             output = self.decode_tiled_(samples, **args)
         elif dims == 3:
+            if overlap_t is None:
+                args["overlap"] = (1, overlap, overlap)
+            else:
+                args["overlap"] = (max(1, overlap_t), overlap, overlap)
+            if tile_t is not None:
+                args["tile_t"] = max(2, tile_t)
+
             output = self.decode_tiled_3d(samples, **args)
         return output.movedim(1, -1)

@@ -532,7 +548,7 @@ class VAE:

         return samples

-    def encode_tiled(self, pixel_samples, tile_x=None, tile_y=None, overlap=None):
+    def encode_tiled(self, pixel_samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None):
         pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
         dims = self.latent_dim
         pixel_samples = pixel_samples.movedim(-1, 1)
@@ -556,6 +572,12 @@ class VAE:
         elif dims == 2:
             samples = self.encode_tiled_(pixel_samples, **args)
         elif dims == 3:
+            if overlap_t is None:
+                args["overlap"] = (1, overlap, overlap)
+            else:
+                args["overlap"] = (overlap_t, overlap, overlap)
+            if tile_t is not None:
+                args["tile_t"] = tile_t
             samples = self.encode_tiled_3d(pixel_samples, **args)

         return samples
@@ -575,6 +597,12 @@ class VAE:
         except:
             return self.downscale_ratio

+    def temporal_compression_decode(self):
+        try:
+            return round(self.upscale_ratio[0](8192) / 8192)
+        except:
+            return None
+
 class StyleModel:
     def __init__(self, model, device="cpu"):
         self.model = model
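temporal_compression_decode() probes the temporal upscale lambda with a large value to recover the per-frame compression factor as an integer; the division washes out the small constant offset in the formula. A quick check of that arithmetic with the 8x-temporal ratio that appears in this diff:

```python
# upscale_ratio[0] for the 8x-temporal VAE is lambda a: max(0, a * 8 - 7),
# so probing it with a large input recovers the factor 8 after rounding.
upscale_t = lambda a: max(0, a * 8 - 7)
print(round(upscale_t(8192) / 8192))  # 8, since (8 * 8192 - 7) / 8192 = 7.99914...
```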
comfy/utils.py

@@ -822,7 +822,7 @@ def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap):
     return rows * cols

 @torch.inference_mode()
-def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_amount=4, out_channels=3, output_device="cpu", downscale=False, pbar=None):
+def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_amount=4, out_channels=3, output_device="cpu", downscale=False, index_formulas=None, pbar=None):
     dims = len(tile)

     if not (isinstance(upscale_amount, (tuple, list))):
@@ -831,6 +831,12 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
     if not (isinstance(overlap, (tuple, list))):
         overlap = [overlap] * dims

+    if index_formulas is None:
+        index_formulas = upscale_amount
+
+    if not (isinstance(index_formulas, (tuple, list))):
+        index_formulas = [index_formulas] * dims
+
     def get_upscale(dim, val):
         up = upscale_amount[dim]
         if callable(up):
@@ -845,10 +851,26 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
         else:
             return val / up

+    def get_upscale_pos(dim, val):
+        up = index_formulas[dim]
+        if callable(up):
+            return up(val)
+        else:
+            return up * val
+
+    def get_downscale_pos(dim, val):
+        up = index_formulas[dim]
+        if callable(up):
+            return up(val)
+        else:
+            return val / up
+
     if downscale:
         get_scale = get_downscale
+        get_pos = get_downscale_pos
     else:
         get_scale = get_upscale
+        get_pos = get_upscale_pos

     def mult_list_upscale(a):
         out = []
@@ -881,7 +903,7 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
                 pos = max(0, min(s.shape[d + 2] - overlap[d], it[d]))
                 l = min(tile[d], s.shape[d + 2] - pos)
                 s_in = s_in.narrow(d + 2, pos, l)
-                upscaled.append(round(get_scale(d, pos)))
+                upscaled.append(round(get_pos(d, pos)))

             ps = function(s_in).to(output_device)
             mask = torch.ones_like(ps)

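The net effect in tiled_scale_multidim is that output sizing still goes through the (possibly non-linear) upscale_amount callables, while tile placement now goes through index_formulas, which default to upscale_amount when not given. A standalone sketch of that split, using the 6x temporal ratio that appears earlier in this diff; the names are illustrative:

```python
# Sizing vs. placement for the temporal axis of a tiled 3D decode.
upscale_amount_t = lambda a: max(0, a * 6 - 5)   # how long a decoded tile is
index_formula_t = 6                              # where a tile starts in the output

def tile_length(latent_len):
    return round(upscale_amount_t(latent_len))

def tile_position(latent_pos):
    return round(index_formula_t * latent_pos)

# A 4-frame latent tile decodes to 19 frames; a tile starting at latent frame 4
# is written at output frame 24 (previously the sizing lambda was reused for
# placement, which would have put it at frame 19).
print(tile_length(4), tile_position(4))  # 19 24
```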
main.py (60 lines changed)
@@ -150,9 +150,10 @@ def cuda_malloc_warning():
     if cuda_malloc_warning:
         logging.warning("\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n")

-def prompt_worker(q, server):
+
+def prompt_worker(q, server_instance):
     current_time: float = 0.0
-    e = execution.PromptExecutor(server, lru_size=args.cache_lru)
+    e = execution.PromptExecutor(server_instance, lru_size=args.cache_lru)
     last_gc_collect = 0
     need_gc = False
     gc_collect_interval = 10.0
@@ -167,7 +168,7 @@ def prompt_worker(q, server):
             item, item_id = queue_item
             execution_start_time = time.perf_counter()
             prompt_id = item[1]
-            server.last_prompt_id = prompt_id
+            server_instance.last_prompt_id = prompt_id

             e.execute(item[2], prompt_id, item[3], item[4])
             need_gc = True
@@ -177,8 +178,8 @@ def prompt_worker(q, server):
                                           status_str='success' if e.success else 'error',
                                           completed=e.success,
                                           messages=e.status_messages))
-            if server.client_id is not None:
-                server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, server.client_id)
+            if server_instance.client_id is not None:
+                server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id}, server_instance.client_id)

             current_time = time.perf_counter()
             execution_time = current_time - execution_start_time
@@ -205,21 +206,23 @@ def prompt_worker(q, server):
                 last_gc_collect = current_time
                 need_gc = False

-async def run(server, address='', port=8188, verbose=True, call_on_start=None):
+
+async def run(server_instance, address='', port=8188, verbose=True, call_on_start=None):
     addresses = []
     for addr in address.split(","):
         addresses.append((addr, port))
-    await asyncio.gather(server.start_multi_address(addresses, call_on_start), server.publish_loop())
+    await asyncio.gather(server_instance.start_multi_address(addresses, call_on_start), server_instance.publish_loop())


-def hijack_progress(server):
+def hijack_progress(server_instance):
     def hook(value, total, preview_image):
         comfy.model_management.throw_exception_if_processing_interrupted()
-        progress = {"value": value, "max": total, "prompt_id": server.last_prompt_id, "node": server.last_node_id}
+        progress = {"value": value, "max": total, "prompt_id": server_instance.last_prompt_id, "node": server_instance.last_node_id}

-        server.send_sync("progress", progress, server.client_id)
+        server_instance.send_sync("progress", progress, server_instance.client_id)
         if preview_image is not None:
-            server.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, server.client_id)
+            server_instance.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, server_instance.client_id)

     comfy.utils.set_progress_bar_global_hook(hook)

@@ -229,7 +232,11 @@ def cleanup_temp():
         shutil.rmtree(temp_dir, ignore_errors=True)


-if __name__ == "__main__":
+def start_comfyui(asyncio_loop=None):
+    """
+    Starts the ComfyUI server using the provided asyncio event loop or creates a new one.
+    Returns the event loop, server instance, and a function to start the server asynchronously.
+    """
     if args.temp_directory:
         temp_dir = os.path.join(os.path.abspath(args.temp_directory), "temp")
         logging.info(f"Setting temp directory to: {temp_dir}")
@@ -243,19 +250,20 @@ if __name__ == "__main__":
         except:
             pass

-    loop = asyncio.new_event_loop()
-    asyncio.set_event_loop(loop)
-    server = server.PromptServer(loop)
-    q = execution.PromptQueue(server)
+    if not asyncio_loop:
+        asyncio_loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(asyncio_loop)
+    prompt_server = server.PromptServer(asyncio_loop)
+    q = execution.PromptQueue(prompt_server)

     nodes.init_extra_nodes(init_custom_nodes=not args.disable_all_custom_nodes)

     cuda_malloc_warning()

-    server.add_routes()
-    hijack_progress(server)
+    prompt_server.add_routes()
+    hijack_progress(prompt_server)

-    threading.Thread(target=prompt_worker, daemon=True, args=(q, server,)).start()
+    threading.Thread(target=prompt_worker, daemon=True, args=(q, prompt_server,)).start()

     if args.quick_test_for_ci:
         exit(0)
@@ -272,9 +280,19 @@ if __name__ == "__main__":
                 webbrowser.open(f"{scheme}://{address}:{port}")
         call_on_start = startup_server

+    async def start_all():
+        await prompt_server.setup()
+        await run(prompt_server, address=args.listen, port=args.port, verbose=not args.dont_print_server, call_on_start=call_on_start)
+
+    # Returning these so that other code can integrate with the ComfyUI loop and server
+    return asyncio_loop, prompt_server, start_all
+
+
+if __name__ == "__main__":
+    # Running directly, just start ComfyUI.
+    event_loop, _, start_all_func = start_comfyui()
     try:
-        loop.run_until_complete(server.setup())
-        loop.run_until_complete(run(server, address=args.listen, port=args.port, verbose=not args.dont_print_server, call_on_start=call_on_start))
+        event_loop.run_until_complete(start_all_func())
     except KeyboardInterrupt:
         logging.info("\nStopped server")

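main.py's entry point is now wrapped in start_comfyui(), which returns the event loop, the PromptServer instance, and an async start_all coroutine instead of doing everything at module level, so another program can embed ComfyUI. A hedged sketch of how a host script might use it; it assumes the script runs from the ComfyUI checkout with the usual CLI arguments available, and is not an official API example:

```python
import asyncio
import main  # ComfyUI's main.py, assuming the working directory is the ComfyUI checkout

# Reuse an existing event loop so the host application stays in control of it.
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)

event_loop, prompt_server, start_all = main.start_comfyui(asyncio_loop=loop)

# The host could register extra routes on prompt_server here before serving.
try:
    event_loop.run_until_complete(start_all())
except KeyboardInterrupt:
    pass
```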
nodes.py (22 lines changed)
@@ -293,17 +293,29 @@ class VAEDecodeTiled:
         return {"required": {"samples": ("LATENT", ), "vae": ("VAE", ),
                              "tile_size": ("INT", {"default": 512, "min": 64, "max": 4096, "step": 32}),
                              "overlap": ("INT", {"default": 64, "min": 0, "max": 4096, "step": 32}),
+                             "temporal_size": ("INT", {"default": 64, "min": 8, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to decode at a time."}),
+                             "temporal_overlap": ("INT", {"default": 8, "min": 4, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to overlap."}),
                              }}
     RETURN_TYPES = ("IMAGE",)
     FUNCTION = "decode"

     CATEGORY = "_for_testing"

-    def decode(self, vae, samples, tile_size, overlap=64):
+    def decode(self, vae, samples, tile_size, overlap=64, temporal_size=64, temporal_overlap=8):
         if tile_size < overlap * 4:
             overlap = tile_size // 4
+        if temporal_size < temporal_overlap * 2:
+            temporal_overlap = temporal_overlap // 2
+        temporal_compression = vae.temporal_compression_decode()
+        if temporal_compression is not None:
+            temporal_size = max(2, temporal_size // temporal_compression)
+            temporal_overlap = min(1, temporal_size // 2, temporal_overlap // temporal_compression)
+        else:
+            temporal_size = None
+            temporal_overlap = None
+
         compression = vae.spacial_compression_decode()
-        images = vae.decode_tiled(samples["samples"], tile_x=tile_size // compression, tile_y=tile_size // compression, overlap=overlap // compression)
+        images = vae.decode_tiled(samples["samples"], tile_x=tile_size // compression, tile_y=tile_size // compression, overlap=overlap // compression, tile_t=temporal_size, overlap_t=temporal_overlap)
         if len(images.shape) == 5: #Combine batches
             images = images.reshape(-1, images.shape[-3], images.shape[-2], images.shape[-1])
         return (images, )
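The new temporal_size and temporal_overlap inputs are expressed in output frames and converted to latent-frame counts via the VAE's temporal compression before being handed to decode_tiled; image VAEs (where temporal_compression_decode() returns None) ignore them. A quick worked example with the node defaults and an assumed 8x-temporal VAE, following the code above:

```python
# Defaults from the node, assuming vae.temporal_compression_decode() == 8.
temporal_size, temporal_overlap = 64, 8
temporal_compression = 8

temporal_size = max(2, temporal_size // temporal_compression)        # 8 latent frames per tile
temporal_overlap = min(1, temporal_size // 2,
                       temporal_overlap // temporal_compression)     # 1 latent frame of overlap
print(temporal_size, temporal_overlap)  # 8 1
```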
@@ -327,14 +339,16 @@ class VAEEncodeTiled:
         return {"required": {"pixels": ("IMAGE", ), "vae": ("VAE", ),
                              "tile_size": ("INT", {"default": 512, "min": 64, "max": 4096, "step": 64}),
                              "overlap": ("INT", {"default": 64, "min": 0, "max": 4096, "step": 32}),
+                             "temporal_size": ("INT", {"default": 64, "min": 8, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to encode at a time."}),
+                             "temporal_overlap": ("INT", {"default": 8, "min": 4, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to overlap."}),
                              }}
     RETURN_TYPES = ("LATENT",)
     FUNCTION = "encode"

     CATEGORY = "_for_testing"

-    def encode(self, vae, pixels, tile_size, overlap):
-        t = vae.encode_tiled(pixels[:,:,:,:3], tile_x=tile_size, tile_y=tile_size, overlap=overlap)
+    def encode(self, vae, pixels, tile_size, overlap, temporal_size=64, temporal_overlap=8):
+        t = vae.encode_tiled(pixels[:,:,:,:3], tile_x=tile_size, tile_y=tile_size, overlap=overlap, tile_t=temporal_size, overlap_t=temporal_overlap)
         return ({"samples": t}, )

 class VAEEncodeForInpaint: