From 0d7fd5cd07833d5ee831ab262461f0bf8fa5a9f3 Mon Sep 17 00:00:00 2001 From: Henry Ruhs Date: Thu, 19 Dec 2024 12:15:56 +0100 Subject: [PATCH] Fix out of range for trim frame, Fix ffmpeg extraction count (#836) * Fix out of range for trim frame, Fix ffmpeg extraction count * Move restrict of trim frame to the core, Make sure all values are within the range * Fix and merge testing * Fix typing --- facefusion/content_analyser.py | 7 ++- facefusion/core.py | 9 ++-- facefusion/ffmpeg.py | 17 +++----- facefusion/vision.py | 25 +++++++++++ tests/test_audio.py | 4 +- tests/test_ffmpeg.py | 80 +++++++--------------------------- tests/test_vision.py | 24 +++++++++- 7 files changed, 78 insertions(+), 88 deletions(-) diff --git a/facefusion/content_analyser.py b/facefusion/content_analyser.py index 5d1c7ce6..ee25f58b 100644 --- a/facefusion/content_analyser.py +++ b/facefusion/content_analyser.py @@ -9,7 +9,7 @@ from facefusion.download import conditional_download_hashes, conditional_downloa from facefusion.filesystem import resolve_relative_path from facefusion.thread_helper import conditional_thread_semaphore from facefusion.typing import DownloadScope, Fps, InferencePool, ModelOptions, ModelSet, VisionFrame -from facefusion.vision import count_video_frame_total, detect_video_fps, get_video_frame, read_image +from facefusion.vision import detect_video_fps, get_video_frame, read_image PROBABILITY_LIMIT = 0.80 RATE_LIMIT = 10 @@ -108,10 +108,9 @@ def analyse_image(image_path : str) -> bool: @lru_cache(maxsize = None) -def analyse_video(video_path : str, start_frame : int, end_frame : int) -> bool: - video_frame_total = count_video_frame_total(video_path) +def analyse_video(video_path : str, trim_frame_start : int, trim_frame_end : int) -> bool: video_fps = detect_video_fps(video_path) - frame_range = range(start_frame or 0, end_frame or video_frame_total) + frame_range = range(trim_frame_start, trim_frame_end) rate = 0.0 counter = 0 diff --git a/facefusion/core.py b/facefusion/core.py index 25eb7d8f..38d64dbc 100755 --- a/facefusion/core.py +++ b/facefusion/core.py @@ -26,7 +26,7 @@ from facefusion.program_helper import validate_args from facefusion.statistics import conditional_log_statistics from facefusion.temp_helper import clear_temp_directory, create_temp_directory, get_temp_file_path, get_temp_frame_paths, move_temp_file from facefusion.typing import Args, ErrorCode -from facefusion.vision import get_video_frame, pack_resolution, read_image, read_static_images, restrict_image_resolution, restrict_video_fps, restrict_video_resolution, unpack_resolution +from facefusion.vision import get_video_frame, pack_resolution, read_image, read_static_images, restrict_image_resolution, restrict_trim_frame, restrict_video_fps, restrict_video_resolution, unpack_resolution def cli() -> None: @@ -389,7 +389,8 @@ def process_image(start_time : float) -> ErrorCode: def process_video(start_time : float) -> ErrorCode: - if analyse_video(state_manager.get_item('target_path'), state_manager.get_item('trim_frame_start'), state_manager.get_item('trim_frame_end')): + trim_frame_start, trim_frame_end = restrict_trim_frame(state_manager.get_item('target_path'), state_manager.get_item('trim_frame_start'), state_manager.get_item('trim_frame_end')) + if analyse_video(state_manager.get_item('target_path'), trim_frame_start, trim_frame_end): return 3 # clear temp logger.debug(wording.get('clearing_temp'), __name__) @@ -402,7 +403,7 @@ def process_video(start_time : float) -> ErrorCode: temp_video_resolution = pack_resolution(restrict_video_resolution(state_manager.get_item('target_path'), unpack_resolution(state_manager.get_item('output_video_resolution')))) temp_video_fps = restrict_video_fps(state_manager.get_item('target_path'), state_manager.get_item('output_video_fps')) logger.info(wording.get('extracting_frames').format(resolution = temp_video_resolution, fps = temp_video_fps), __name__) - if extract_frames(state_manager.get_item('target_path'), temp_video_resolution, temp_video_fps): + if extract_frames(state_manager.get_item('target_path'), temp_video_resolution, temp_video_fps, trim_frame_start, trim_frame_end): logger.debug(wording.get('extracting_frames_succeed'), __name__) else: if is_process_stopping(): @@ -451,7 +452,7 @@ def process_video(start_time : float) -> ErrorCode: logger.warn(wording.get('replacing_audio_skipped'), __name__) move_temp_file(state_manager.get_item('target_path'), state_manager.get_item('output_path')) else: - if restore_audio(state_manager.get_item('target_path'), state_manager.get_item('output_path'), state_manager.get_item('output_video_fps')): + if restore_audio(state_manager.get_item('target_path'), state_manager.get_item('output_path'), state_manager.get_item('output_video_fps'), trim_frame_start, trim_frame_end): logger.debug(wording.get('restoring_audio_succeed'), __name__) else: if is_process_stopping(): diff --git a/facefusion/ffmpeg.py b/facefusion/ffmpeg.py index e4960d03..48530282 100644 --- a/facefusion/ffmpeg.py +++ b/facefusion/ffmpeg.py @@ -11,7 +11,7 @@ from facefusion import logger, process_manager, state_manager, wording from facefusion.filesystem import remove_file from facefusion.temp_helper import get_temp_file_path, get_temp_frame_paths, get_temp_frames_pattern from facefusion.typing import AudioBuffer, Fps, OutputVideoPreset, UpdateProgress -from facefusion.vision import count_video_frame_total, detect_video_duration, restrict_video_fps +from facefusion.vision import count_trim_frame_total, detect_video_duration, restrict_video_fps def run_ffmpeg_with_progress(args: List[str], update_progress : UpdateProgress) -> subprocess.Popen[bytes]: @@ -73,22 +73,17 @@ def log_debug(process : subprocess.Popen[bytes]) -> None: logger.debug(error.strip(), __name__) -def extract_frames(target_path : str, temp_video_resolution : str, temp_video_fps : Fps) -> bool: - extract_frame_total = count_video_frame_total(state_manager.get_item('target_path')) - trim_frame_start = state_manager.get_item('trim_frame_start') - trim_frame_end = state_manager.get_item('trim_frame_end') +def extract_frames(target_path : str, temp_video_resolution : str, temp_video_fps : Fps, trim_frame_start : int, trim_frame_end : int) -> bool: + extract_frame_total = count_trim_frame_total(target_path, trim_frame_start, trim_frame_end) temp_frames_pattern = get_temp_frames_pattern(target_path, '%08d') commands = [ '-i', target_path, '-s', str(temp_video_resolution), '-q:v', '0' ] if isinstance(trim_frame_start, int) and isinstance(trim_frame_end, int): commands.extend([ '-vf', 'trim=start_frame=' + str(trim_frame_start) + ':end_frame=' + str(trim_frame_end) + ',fps=' + str(temp_video_fps) ]) - extract_frame_total = trim_frame_end - trim_frame_start elif isinstance(trim_frame_start, int): commands.extend([ '-vf', 'trim=start_frame=' + str(trim_frame_start) + ',fps=' + str(temp_video_fps) ]) - extract_frame_total -= trim_frame_start elif isinstance(trim_frame_end, int): commands.extend([ '-vf', 'trim=end_frame=' + str(trim_frame_end) + ',fps=' + str(temp_video_fps) ]) - extract_frame_total -= trim_frame_end else: commands.extend([ '-vf', 'fps=' + str(temp_video_fps) ]) commands.extend([ '-vsync', '0', temp_frames_pattern ]) @@ -99,10 +94,10 @@ def extract_frames(target_path : str, temp_video_resolution : str, temp_video_fp def merge_video(target_path : str, output_video_resolution : str, output_video_fps: Fps) -> bool: - merge_frame_total = len(get_temp_frame_paths(target_path)) output_video_encoder = state_manager.get_item('output_video_encoder') output_video_quality = state_manager.get_item('output_video_quality') output_video_preset = state_manager.get_item('output_video_preset') + merge_frame_total = len(get_temp_frame_paths(target_path)) temp_video_fps = restrict_video_fps(target_path, output_video_fps) temp_file_path = get_temp_file_path(target_path) temp_frames_pattern = get_temp_frames_pattern(target_path, '%08d') @@ -179,9 +174,7 @@ def read_audio_buffer(target_path : str, sample_rate : int, channel_total : int) return None -def restore_audio(target_path : str, output_path : str, output_video_fps : Fps) -> bool: - trim_frame_start = state_manager.get_item('trim_frame_start') - trim_frame_end = state_manager.get_item('trim_frame_end') +def restore_audio(target_path : str, output_path : str, output_video_fps : Fps, trim_frame_start : int, trim_frame_end : int) -> bool: output_audio_encoder = state_manager.get_item('output_audio_encoder') temp_file_path = get_temp_file_path(target_path) temp_video_duration = detect_video_duration(temp_file_path) diff --git a/facefusion/vision.py b/facefusion/vision.py index 211dbe46..04c1b5ec 100644 --- a/facefusion/vision.py +++ b/facefusion/vision.py @@ -122,11 +122,36 @@ def restrict_video_fps(video_path : str, fps : Fps) -> Fps: def detect_video_duration(video_path : str) -> Duration: video_frame_total = count_video_frame_total(video_path) video_fps = detect_video_fps(video_path) + if video_frame_total and video_fps: return video_frame_total / video_fps return 0 +def count_trim_frame_total(video_path : str, trim_frame_start : Optional[int], trim_frame_end : Optional[int]) -> int: + trim_frame_start, trim_frame_end = restrict_trim_frame(video_path, trim_frame_start, trim_frame_end) + + return trim_frame_end - trim_frame_start + + +def restrict_trim_frame(video_path : str, trim_frame_start : Optional[int], trim_frame_end : Optional[int]) -> Tuple[int, int]: + video_frame_total = count_video_frame_total(video_path) + + if isinstance(trim_frame_start, int): + trim_frame_start = max(0, min(trim_frame_start, video_frame_total)) + if isinstance(trim_frame_end, int): + trim_frame_end = max(0, min(trim_frame_end, video_frame_total)) + + if isinstance(trim_frame_start, int) and isinstance(trim_frame_end, int): + return trim_frame_start, trim_frame_end + if isinstance(trim_frame_start, int): + return trim_frame_start, video_frame_total + if isinstance(trim_frame_end, int): + return 0, trim_frame_end + + return 0, video_frame_total + + def detect_video_resolution(video_path : str) -> Optional[Resolution]: if is_video(video_path): if is_windows(): diff --git a/tests/test_audio.py b/tests/test_audio.py index 66039f1e..36faf9b0 100644 --- a/tests/test_audio.py +++ b/tests/test_audio.py @@ -17,8 +17,8 @@ def before_all() -> None: def test_get_audio_frame() -> None: - assert get_audio_frame(get_test_example_file('source.mp3'), 25) is not None - assert get_audio_frame(get_test_example_file('source.wav'), 25) is not None + assert hasattr(get_audio_frame(get_test_example_file('source.mp3'), 25), '__array_interface__') + assert hasattr(get_audio_frame(get_test_example_file('source.wav'), 25), '__array_interface__') assert get_audio_frame('invalid', 25) is None diff --git a/tests/test_ffmpeg.py b/tests/test_ffmpeg.py index fa8e014a..a703ef4a 100644 --- a/tests/test_ffmpeg.py +++ b/tests/test_ffmpeg.py @@ -33,78 +33,30 @@ def before_all() -> None: @pytest.fixture(scope = 'function', autouse = True) def before_each() -> None: - state_manager.clear_item('trim_frame_start') - state_manager.clear_item('trim_frame_end') prepare_test_output_directory() def test_extract_frames() -> None: - target_paths =\ + extract_set =\ [ - get_test_example_file('target-240p-25fps.mp4'), - get_test_example_file('target-240p-30fps.mp4'), - get_test_example_file('target-240p-60fps.mp4') + (get_test_example_file('target-240p-25fps.mp4'), 0, 270, 324), + (get_test_example_file('target-240p-25fps.mp4'), 224, 270, 55), + (get_test_example_file('target-240p-25fps.mp4'), 124, 224, 120), + (get_test_example_file('target-240p-25fps.mp4'), 0, 100, 120), + (get_test_example_file('target-240p-30fps.mp4'), 0, 324, 324), + (get_test_example_file('target-240p-30fps.mp4'), 224, 324, 100), + (get_test_example_file('target-240p-30fps.mp4'), 124, 224, 100), + (get_test_example_file('target-240p-30fps.mp4'), 0, 100, 100), + (get_test_example_file('target-240p-60fps.mp4'), 0, 648, 324), + (get_test_example_file('target-240p-60fps.mp4'), 224, 648, 212), + (get_test_example_file('target-240p-60fps.mp4'), 124, 224, 50), + (get_test_example_file('target-240p-60fps.mp4'), 0, 100, 50) ] - for target_path in target_paths: + for target_path, trim_frame_start, trim_frame_end, frame_total in extract_set: create_temp_directory(target_path) - assert extract_frames(target_path, '452x240', 30.0) is True - assert len(get_temp_frame_paths(target_path)) == 324 - - clear_temp_directory(target_path) - - -def test_extract_frames_with_trim_start() -> None: - state_manager.init_item('trim_frame_start', 224) - target_paths =\ - [ - (get_test_example_file('target-240p-25fps.mp4'), 55), - (get_test_example_file('target-240p-30fps.mp4'), 100), - (get_test_example_file('target-240p-60fps.mp4'), 212) - ] - - for target_path, frame_total in target_paths: - create_temp_directory(target_path) - - assert extract_frames(target_path, '452x240', 30.0) is True - assert len(get_temp_frame_paths(target_path)) == frame_total - - clear_temp_directory(target_path) - - -def test_extract_frames_with_trim_start_and_trim_end() -> None: - state_manager.init_item('trim_frame_start', 124) - state_manager.init_item('trim_frame_end', 224) - target_paths =\ - [ - (get_test_example_file('target-240p-25fps.mp4'), 120), - (get_test_example_file('target-240p-30fps.mp4'), 100), - (get_test_example_file('target-240p-60fps.mp4'), 50) - ] - - for target_path, frame_total in target_paths: - create_temp_directory(target_path) - - assert extract_frames(target_path, '452x240', 30.0) is True - assert len(get_temp_frame_paths(target_path)) == frame_total - - clear_temp_directory(target_path) - - -def test_extract_frames_with_trim_end() -> None: - state_manager.init_item('trim_frame_end', 100) - target_paths =\ - [ - (get_test_example_file('target-240p-25fps.mp4'), 120), - (get_test_example_file('target-240p-30fps.mp4'), 100), - (get_test_example_file('target-240p-60fps.mp4'), 50) - ] - - for target_path, frame_total in target_paths: - create_temp_directory(target_path) - - assert extract_frames(target_path, '426x240', 30.0) is True + assert extract_frames(target_path, '452x240', 30.0, trim_frame_start, trim_frame_end) is True assert len(get_temp_frame_paths(target_path)) == frame_total clear_temp_directory(target_path) @@ -139,7 +91,7 @@ def test_restore_audio() -> None: create_temp_directory(target_path) copy_file(target_path, get_temp_file_path(target_path)) - assert restore_audio(target_path, output_path, 30) is True + assert restore_audio(target_path, output_path, 30, 0, 270) is True clear_temp_directory(target_path) diff --git a/tests/test_vision.py b/tests/test_vision.py index f5d2f5b4..d79fb07c 100644 --- a/tests/test_vision.py +++ b/tests/test_vision.py @@ -3,7 +3,7 @@ import subprocess import pytest from facefusion.download import conditional_download -from facefusion.vision import calc_histogram_difference, count_video_frame_total, create_image_resolutions, create_video_resolutions, detect_image_resolution, detect_video_duration, detect_video_fps, detect_video_resolution, get_video_frame, match_frame_color, normalize_resolution, pack_resolution, read_image, restrict_image_resolution, restrict_video_fps, restrict_video_resolution, unpack_resolution +from facefusion.vision import calc_histogram_difference, count_trim_frame_total, count_video_frame_total, create_image_resolutions, create_video_resolutions, detect_image_resolution, detect_video_duration, detect_video_fps, detect_video_resolution, get_video_frame, match_frame_color, normalize_resolution, pack_resolution, read_image, restrict_image_resolution, restrict_trim_frame, restrict_video_fps, restrict_video_resolution, unpack_resolution from .helper import get_test_example_file, get_test_examples_directory @@ -50,7 +50,7 @@ def test_create_image_resolutions() -> None: def test_get_video_frame() -> None: - assert get_video_frame(get_test_example_file('target-240p-25fps.mp4')) is not None + assert hasattr(get_video_frame(get_test_example_file('target-240p-25fps.mp4')), '__array_interface__') assert get_video_frame('invalid') is None @@ -79,6 +79,26 @@ def test_detect_video_duration() -> None: assert detect_video_duration('invalid') == 0 +def test_count_trim_frame_total() -> None: + assert count_trim_frame_total(get_test_example_file('target-240p.mp4'), 0, 200) == 200 + assert count_trim_frame_total(get_test_example_file('target-240p.mp4'), 70, 270) == 200 + assert count_trim_frame_total(get_test_example_file('target-240p.mp4'), -10, None) == 270 + assert count_trim_frame_total(get_test_example_file('target-240p.mp4'), None, -10) == 0 + assert count_trim_frame_total(get_test_example_file('target-240p.mp4'), 280, None) == 0 + assert count_trim_frame_total(get_test_example_file('target-240p.mp4'), None, 280) == 270 + assert count_trim_frame_total(get_test_example_file('target-240p.mp4'), None, None) == 270 + + +def test_restrict_trim_frame() -> None: + assert restrict_trim_frame(get_test_example_file('target-240p.mp4'), 0, 200) == (0, 200) + assert restrict_trim_frame(get_test_example_file('target-240p.mp4'), 70, 270) == (70, 270) + assert restrict_trim_frame(get_test_example_file('target-240p.mp4'), -10, None) == (0, 270) + assert restrict_trim_frame(get_test_example_file('target-240p.mp4'), None, -10) == (0, 0) + assert restrict_trim_frame(get_test_example_file('target-240p.mp4'), 280, None) == (270, 270) + assert restrict_trim_frame(get_test_example_file('target-240p.mp4'), None, 280) == (0, 270) + assert restrict_trim_frame(get_test_example_file('target-240p.mp4'), None, None) == (0, 270) + + def test_detect_video_resolution() -> None: assert detect_video_resolution(get_test_example_file('target-240p.mp4')) == (426, 226) assert detect_video_resolution(get_test_example_file('target-240p-90deg.mp4')) == (226, 426)