Add get_duration method to Comfy VIDEO type (#8122)

* get duration from VIDEO type

* video get_duration unit test

* fix Windows unit test: can't delete opened temp file
This commit is contained in:
Christian Byrne 2025-05-14 21:11:41 -07:00 committed by GitHub
parent 08368f8e00
commit f1f9763b4c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 281 additions and 0 deletions

View File

@ -43,3 +43,13 @@ class VideoInput(ABC):
components = self.get_components()
return components.images.shape[2], components.images.shape[1]
def get_duration(self) -> float:
"""
Returns the duration of the video in seconds.
Returns:
Duration in seconds
"""
components = self.get_components()
frame_count = components.images.shape[0]
return float(frame_count / components.frame_rate)

View File

@ -80,6 +80,38 @@ class VideoFromFile(VideoInput):
return stream.width, stream.height
raise ValueError(f"No video stream found in file '{self.__file}'")
def get_duration(self) -> float:
"""
Returns the duration of the video in seconds.
Returns:
Duration in seconds
"""
if isinstance(self.__file, io.BytesIO):
self.__file.seek(0)
with av.open(self.__file, mode="r") as container:
if container.duration is not None:
return float(container.duration / av.time_base)
# Fallback: calculate from frame count and frame rate
video_stream = next(
(s for s in container.streams if s.type == "video"), None
)
if video_stream and video_stream.frames and video_stream.average_rate:
return float(video_stream.frames / video_stream.average_rate)
# Last resort: decode frames to count them
if video_stream and video_stream.average_rate:
frame_count = 0
container.seek(0)
for packet in container.demux(video_stream):
for _ in packet.decode():
frame_count += 1
if frame_count > 0:
return float(frame_count / video_stream.average_rate)
raise ValueError(f"Could not determine duration for file '{self.__file}'")
def get_components_internal(self, container: InputContainer) -> VideoComponents:
# Get video frames
frames = []

View File

@ -0,0 +1,239 @@
import pytest
import torch
import tempfile
import os
import av
import io
from fractions import Fraction
from comfy_api.input_impl.video_types import VideoFromFile, VideoFromComponents
from comfy_api.util.video_types import VideoComponents
from comfy_api.input.basic_types import AudioInput
from av.error import InvalidDataError
EPSILON = 0.0001
@pytest.fixture
def sample_images():
"""3-frame 2x2 RGB video tensor"""
return torch.rand(3, 2, 2, 3)
@pytest.fixture
def sample_audio():
"""Stereo audio with 44.1kHz sample rate"""
return AudioInput(
{
"waveform": torch.rand(1, 2, 1000),
"sample_rate": 44100,
}
)
@pytest.fixture
def video_components(sample_images, sample_audio):
"""VideoComponents with images, audio, and metadata"""
return VideoComponents(
images=sample_images,
audio=sample_audio,
frame_rate=Fraction(30),
metadata={"test": "metadata"},
)
def create_test_video(width=4, height=4, frames=3, fps=30):
"""Helper to create a temporary video file"""
tmp = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
with av.open(tmp.name, mode="w") as container:
stream = container.add_stream("h264", rate=fps)
stream.width = width
stream.height = height
stream.pix_fmt = "yuv420p"
for i in range(frames):
frame = av.VideoFrame.from_ndarray(
torch.ones(height, width, 3, dtype=torch.uint8).numpy() * (i * 85),
format="rgb24",
)
frame = frame.reformat(format="yuv420p")
packet = stream.encode(frame)
container.mux(packet)
# Flush
packet = stream.encode(None)
container.mux(packet)
return tmp.name
@pytest.fixture
def simple_video_file():
"""4x4 video with 3 frames at 30fps"""
file_path = create_test_video()
yield file_path
os.unlink(file_path)
def test_video_from_components_get_duration(video_components):
"""Duration calculated correctly from frame count and frame rate"""
video = VideoFromComponents(video_components)
duration = video.get_duration()
expected_duration = 3.0 / 30.0
assert duration == pytest.approx(expected_duration)
def test_video_from_components_get_duration_different_frame_rates(sample_images):
"""Duration correct for different frame rates including fractional"""
# Test with 60 fps
components_60fps = VideoComponents(images=sample_images, frame_rate=Fraction(60))
video_60fps = VideoFromComponents(components_60fps)
assert video_60fps.get_duration() == pytest.approx(3.0 / 60.0)
# Test with fractional frame rate (23.976fps)
components_frac = VideoComponents(
images=sample_images, frame_rate=Fraction(24000, 1001)
)
video_frac = VideoFromComponents(components_frac)
expected_frac = 3.0 / (24000.0 / 1001.0)
assert video_frac.get_duration() == pytest.approx(expected_frac)
def test_video_from_components_get_duration_empty_video():
"""Duration is zero for empty video"""
empty_components = VideoComponents(
images=torch.zeros(0, 2, 2, 3), frame_rate=Fraction(30)
)
video = VideoFromComponents(empty_components)
assert video.get_duration() == 0.0
def test_video_from_components_get_dimensions(video_components):
"""Dimensions returned correctly from image tensor shape"""
video = VideoFromComponents(video_components)
width, height = video.get_dimensions()
assert width == 2
assert height == 2
def test_video_from_file_get_duration(simple_video_file):
"""Duration extracted from file metadata"""
video = VideoFromFile(simple_video_file)
duration = video.get_duration()
assert duration == pytest.approx(0.1, abs=0.01)
def test_video_from_file_get_dimensions(simple_video_file):
"""Dimensions read from stream without decoding frames"""
video = VideoFromFile(simple_video_file)
width, height = video.get_dimensions()
assert width == 4
assert height == 4
def test_video_from_file_bytesio_input():
"""VideoFromFile works with BytesIO input"""
buffer = io.BytesIO()
with av.open(buffer, mode="w", format="mp4") as container:
stream = container.add_stream("h264", rate=30)
stream.width = 2
stream.height = 2
stream.pix_fmt = "yuv420p"
frame = av.VideoFrame.from_ndarray(
torch.zeros(2, 2, 3, dtype=torch.uint8).numpy(), format="rgb24"
)
frame = frame.reformat(format="yuv420p")
packet = stream.encode(frame)
container.mux(packet)
packet = stream.encode(None)
container.mux(packet)
buffer.seek(0)
video = VideoFromFile(buffer)
assert video.get_dimensions() == (2, 2)
assert video.get_duration() == pytest.approx(1 / 30, abs=0.01)
def test_video_from_file_invalid_file_error():
"""InvalidDataError raised for non-video files"""
with tempfile.NamedTemporaryFile(suffix=".txt", delete=False) as tmp:
tmp.write(b"not a video file")
tmp.flush()
tmp_name = tmp.name
try:
with pytest.raises(InvalidDataError):
video = VideoFromFile(tmp_name)
video.get_dimensions()
finally:
os.unlink(tmp_name)
def test_video_from_file_audio_only_error():
"""ValueError raised for audio-only files"""
with tempfile.NamedTemporaryFile(suffix=".m4a", delete=False) as tmp:
tmp_name = tmp.name
try:
with av.open(tmp_name, mode="w") as container:
stream = container.add_stream("aac", rate=44100)
stream.sample_rate = 44100
stream.format = "fltp"
audio_data = torch.zeros(1, 1024).numpy()
audio_frame = av.AudioFrame.from_ndarray(
audio_data, format="fltp", layout="mono"
)
audio_frame.sample_rate = 44100
audio_frame.pts = 0
packet = stream.encode(audio_frame)
container.mux(packet)
for packet in stream.encode(None):
container.mux(packet)
with pytest.raises(ValueError, match="No video stream found"):
video = VideoFromFile(tmp_name)
video.get_dimensions()
finally:
os.unlink(tmp_name)
def test_single_frame_video():
"""Single frame video has correct duration"""
components = VideoComponents(
images=torch.rand(1, 10, 10, 3), frame_rate=Fraction(1)
)
video = VideoFromComponents(components)
assert video.get_duration() == 1.0
@pytest.mark.parametrize(
"frame_rate,expected_fps",
[
(Fraction(24000, 1001), 24000 / 1001),
(Fraction(30000, 1001), 30000 / 1001),
(Fraction(25, 1), 25.0),
(Fraction(50, 2), 25.0),
],
)
def test_fractional_frame_rates(frame_rate, expected_fps):
"""Duration calculated correctly for various fractional frame rates"""
components = VideoComponents(images=torch.rand(100, 4, 4, 3), frame_rate=frame_rate)
video = VideoFromComponents(components)
duration = video.get_duration()
expected_duration = 100.0 / expected_fps
assert duration == pytest.approx(expected_duration)
def test_duration_consistency(video_components):
"""get_duration() consistent with manual calculation from components"""
video = VideoFromComponents(video_components)
duration = video.get_duration()
components = video.get_components()
manual_duration = float(components.images.shape[0] / components.frame_rate)
assert duration == pytest.approx(manual_duration)