diff --git a/comfy/ldm/hunyuan_video/model.py b/comfy/ldm/hunyuan_video/model.py index fc3a67444..f3f445843 100644 --- a/comfy/ldm/hunyuan_video/model.py +++ b/comfy/ldm/hunyuan_video/model.py @@ -310,7 +310,7 @@ class HunyuanVideo(nn.Module): shape[i] = shape[i] // self.patch_size[i] img = img.reshape([img.shape[0]] + shape + [self.out_channels] + self.patch_size) img = img.permute(0, 4, 1, 5, 2, 6, 3, 7) - img = img.reshape(initial_shape) + img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3], initial_shape[4]) return img def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, control=None, transformer_options={}, **kwargs): diff --git a/comfy/model_base.py b/comfy/model_base.py index 98f462b32..0eeaed790 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -871,6 +871,15 @@ class HunyuanVideo(BaseModel): if cross_attn is not None: out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) + image = kwargs.get("concat_latent_image", None) + noise = kwargs.get("noise", None) + + if image is not None: + padding_shape = (noise.shape[0], 16, noise.shape[2] - 1, noise.shape[3], noise.shape[4]) + latent_padding = torch.zeros(padding_shape, device=noise.device, dtype=noise.dtype) + image_latents = torch.cat([image.to(noise), latent_padding], dim=2) + out['c_concat'] = comfy.conds.CONDNoiseShape(self.process_latent_in(image_latents)) + guidance = kwargs.get("guidance", 6.0) if guidance is not None: out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance])) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 2644dd0dc..5051f821d 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -136,7 +136,7 @@ def detect_unet_config(state_dict, key_prefix): if '{}txt_in.individual_token_refiner.blocks.0.norm1.weight'.format(key_prefix) in state_dict_keys: #Hunyuan Video dit_config = {} dit_config["image_model"] = "hunyuan_video" - dit_config["in_channels"] = 16 + dit_config["in_channels"] = state_dict["img_in.proj.weight"].shape[1] #SkyReels img2video has 32 input channels dit_config["patch_size"] = [1, 2, 2] dit_config["out_channels"] = 16 dit_config["vec_in_dim"] = 768