mirror of https://github.com/comfyanonymous/ComfyUI.git
Support loading and using SkyReels-V1-Hunyuan-I2V (#6862)
* Support SkyReels-V1-Hunyuan-I2V
* VAE scaling
* Fix T2V oops
* Proper latent scaling
parent b07258cef2
commit acc152b674
@@ -310,7 +310,7 @@ class HunyuanVideo(nn.Module):
             shape[i] = shape[i] // self.patch_size[i]
         img = img.reshape([img.shape[0]] + shape + [self.out_channels] + self.patch_size)
         img = img.permute(0, 4, 1, 5, 2, 6, 3, 7)
-        img = img.reshape(initial_shape)
+        img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3], initial_shape[4])
         return img

     def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, control=None, transformer_options={}, **kwargs):
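The reshape change above matters because SkyReels-V1 I2V feeds the transformer a 32-channel latent (noise concatenated with the reference-image latent, see the c_concat hunk below) while its output is still 16 channels, so initial_shape, which is captured from the input, can no longer be reused verbatim when folding the patches back into a video latent. A minimal standalone sketch of the shape math, with illustrative sizes rather than the model's real code:

    import torch

    # Illustrative sizes only: 32-channel input (I2V), 16-channel output, patch_size = [1, 2, 2].
    B, C_in, C_out = 1, 32, 16
    T, H, W = 9, 64, 64
    pt, ph, pw = 1, 2, 2

    initial_shape = (B, C_in, T, H, W)          # shape of the pre-patchify input latent
    shape = [T // pt, H // ph, W // pw]         # token grid after patchify

    # Fake transformer output already reshaped to [B, T', H', W', C_out, pt, ph, pw].
    img = torch.randn([B] + shape + [C_out, pt, ph, pw])
    img = img.permute(0, 4, 1, 5, 2, 6, 3, 7)   # -> [B, C_out, T', pt, H', ph, W', pw]

    # Old code: img.reshape(initial_shape) -- only valid while C_in == C_out (16 == 16).
    # New code spells out out_channels so a 32-channel input still yields a 16-channel video.
    img = img.reshape(initial_shape[0], C_out, initial_shape[2], initial_shape[3], initial_shape[4])
    print(img.shape)  # torch.Size([1, 16, 9, 64, 64])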
@@ -871,6 +871,15 @@ class HunyuanVideo(BaseModel):
         if cross_attn is not None:
             out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)

+        image = kwargs.get("concat_latent_image", None)
+        noise = kwargs.get("noise", None)
+
+        if image is not None:
+            padding_shape = (noise.shape[0], 16, noise.shape[2] - 1, noise.shape[3], noise.shape[4])
+            latent_padding = torch.zeros(padding_shape, device=noise.device, dtype=noise.dtype)
+            image_latents = torch.cat([image.to(noise), latent_padding], dim=2)
+            out['c_concat'] = comfy.conds.CONDNoiseShape(self.process_latent_in(image_latents))
+
         guidance = kwargs.get("guidance", 6.0)
         if guidance is not None:
             out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
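The new extra_conds branch above builds the image-to-video conditioning: the reference image occupies the first latent frame, the remaining frames are zero-padded along the temporal axis (dim 2), and the result goes through the model's latent scaling (process_latent_in) before being attached as c_concat. A plain-torch sketch of the same logic, with process_latent_in stubbed out as an identity placeholder and no comfy imports:

    import torch

    def build_concat_latents(image, noise, process_latent_in=lambda x: x):
        # image: [B, 16, 1, H, W] latent of the reference frame
        # noise: [B, 16, T, H, W] noise latent for the whole clip
        padding_shape = (noise.shape[0], 16, noise.shape[2] - 1, noise.shape[3], noise.shape[4])
        latent_padding = torch.zeros(padding_shape, device=noise.device, dtype=noise.dtype)
        # Reference frame first, zeros for the remaining T - 1 latent frames (temporal dim 2).
        image_latents = torch.cat([image.to(noise), latent_padding], dim=2)
        return process_latent_in(image_latents)

    noise = torch.randn(1, 16, 9, 64, 64)
    image = torch.randn(1, 16, 1, 64, 64)
    print(build_concat_latents(image, noise).shape)  # torch.Size([1, 16, 9, 64, 64])

Stacked channel-wise with the 16-channel noise latent, this conditioning is where the 32 input channels detected in the next hunk come from.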
@@ -136,7 +136,7 @@ def detect_unet_config(state_dict, key_prefix):
     if '{}txt_in.individual_token_refiner.blocks.0.norm1.weight'.format(key_prefix) in state_dict_keys: #Hunyuan Video
         dit_config = {}
         dit_config["image_model"] = "hunyuan_video"
-        dit_config["in_channels"] = 16
+        dit_config["in_channels"] = state_dict["img_in.proj.weight"].shape[1] #SkyReels img2video has 32 input channels
         dit_config["patch_size"] = [1, 2, 2]
         dit_config["out_channels"] = 16
         dit_config["vec_in_dim"] = 768
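Reading in_channels off the img_in.proj weight instead of hard-coding 16 is what lets the 32-channel SkyReels I2V checkpoints be detected automatically: for a conv patch-embedding weight of shape [hidden, in_channels, *patch_size], dim 1 is the channel count the model expects. A small sketch of that lookup (the helper name, the optional prefix handling, and the 3072 hidden size are illustrative, not part of the commit):

    import torch

    def detect_in_channels(state_dict, key_prefix=""):
        # Hypothetical helper: dim 1 of the patch-embedding weight is the latent channel count.
        w = state_dict["{}img_in.proj.weight".format(key_prefix)]
        return w.shape[1]

    # Hunyuan Video T2V expects 16 latent channels; SkyReels I2V expects 16 noise + 16 image = 32.
    sd_t2v = {"img_in.proj.weight": torch.zeros(3072, 16, 1, 2, 2)}
    sd_i2v = {"img_in.proj.weight": torch.zeros(3072, 32, 1, 2, 2)}
    print(detect_in_channels(sd_t2v), detect_in_channels(sd_i2v))  # 16 32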