Hunyuan Custom initial untested implementation. (#8101)

This commit is contained in:
comfyanonymous 2025-05-13 12:53:47 -07:00 committed by GitHub
parent 8a7c894d54
commit 4a9014e201
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 26 additions and 5 deletions

View File

@ -228,6 +228,7 @@ class HunyuanVideo(nn.Module):
y: Tensor,
guidance: Tensor = None,
guiding_frame_index=None,
ref_latent=None,
control=None,
transformer_options={},
) -> Tensor:
@ -238,6 +239,14 @@ class HunyuanVideo(nn.Module):
img = self.img_in(img)
vec = self.time_in(timestep_embedding(timesteps, 256, time_factor=1.0).to(img.dtype))
if ref_latent is not None:
ref_latent_ids = self.img_ids(ref_latent)
ref_latent = self.img_in(ref_latent)
img = torch.cat([ref_latent, img], dim=-2)
ref_latent_ids[..., 0] = -1
ref_latent_ids[..., 2] += (initial_shape[-1] // self.patch_size[-1])
img_ids = torch.cat([ref_latent_ids, img_ids], dim=-2)
if guiding_frame_index is not None:
token_replace_vec = self.time_in(timestep_embedding(guiding_frame_index, 256, time_factor=1.0))
vec_ = self.vector_in(y[:, :self.params.vec_in_dim])
@ -313,6 +322,8 @@ class HunyuanVideo(nn.Module):
img[:, : img_len] += add
img = img[:, : img_len]
if ref_latent is not None:
img = img[:, ref_latent.shape[1]:]
img = self.final_layer(img, vec, modulation_dims=modulation_dims) # (N, T, patch_size ** 2 * out_channels)
@ -324,7 +335,7 @@ class HunyuanVideo(nn.Module):
img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3], initial_shape[4])
return img
def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, guiding_frame_index=None, control=None, transformer_options={}, **kwargs):
def img_ids(self, x):
bs, c, t, h, w = x.shape
patch_size = self.patch_size
t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
@ -334,7 +345,11 @@ class HunyuanVideo(nn.Module):
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).reshape(-1, 1, 1)
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).reshape(1, -1, 1)
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).reshape(1, 1, -1)
img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
return repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs):
bs, c, t, h, w = x.shape
img_ids = self.img_ids(x)
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, guidance, guiding_frame_index, control, transformer_options)
out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, guidance, guiding_frame_index, ref_latent, control=control, transformer_options=transformer_options)
return out

View File

@ -924,6 +924,10 @@ class HunyuanVideo(BaseModel):
if guiding_frame_index is not None:
out['guiding_frame_index'] = comfy.conds.CONDRegular(torch.FloatTensor([guiding_frame_index]))
ref_latent = kwargs.get("ref_latent", None)
if ref_latent is not None:
out['ref_latent'] = comfy.conds.CONDRegular(self.process_latent_in(ref_latent))
return out
def scale_latent_inpaint(self, latent_image, **kwargs):

View File

@ -77,7 +77,7 @@ class HunyuanImageToVideo:
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"length": ("INT", {"default": 53, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
"guidance_type": (["v1 (concat)", "v2 (replace)"], )
"guidance_type": (["v1 (concat)", "v2 (replace)", "custom"], )
},
"optional": {"start_image": ("IMAGE", ),
}}
@ -101,10 +101,12 @@ class HunyuanImageToVideo:
if guidance_type == "v1 (concat)":
cond = {"concat_latent_image": concat_latent_image, "concat_mask": mask}
else:
elif guidance_type == "v2 (replace)":
cond = {'guiding_frame_index': 0}
latent[:, :, :concat_latent_image.shape[2]] = concat_latent_image
out_latent["noise_mask"] = mask
elif guidance_type == "custom":
cond = {"ref_latent": concat_latent_image}
positive = node_helpers.conditioning_set_values(positive, cond)