From c5afda6198b98f4413f51eba865512b165f72432 Mon Sep 17 00:00:00 2001 From: henryruhs Date: Thu, 19 Dec 2024 10:27:18 +0100 Subject: [PATCH] Move restrict of trim frame to the core, Make sure all values are within the range --- facefusion/content_analyser.py | 3 +-- facefusion/core.py | 9 +++++---- facefusion/ffmpeg.py | 10 +++------- facefusion/vision.py | 8 ++++---- tests/test_vision.py | 4 ++++ 5 files changed, 17 insertions(+), 17 deletions(-) diff --git a/facefusion/content_analyser.py b/facefusion/content_analyser.py index 3418680f..ee25f58b 100644 --- a/facefusion/content_analyser.py +++ b/facefusion/content_analyser.py @@ -9,7 +9,7 @@ from facefusion.download import conditional_download_hashes, conditional_downloa from facefusion.filesystem import resolve_relative_path from facefusion.thread_helper import conditional_thread_semaphore from facefusion.typing import DownloadScope, Fps, InferencePool, ModelOptions, ModelSet, VisionFrame -from facefusion.vision import detect_video_fps, get_video_frame, read_image, restrict_trim_frame +from facefusion.vision import detect_video_fps, get_video_frame, read_image PROBABILITY_LIMIT = 0.80 RATE_LIMIT = 10 @@ -110,7 +110,6 @@ def analyse_image(image_path : str) -> bool: @lru_cache(maxsize = None) def analyse_video(video_path : str, trim_frame_start : int, trim_frame_end : int) -> bool: video_fps = detect_video_fps(video_path) - trim_frame_start, trim_frame_end = restrict_trim_frame(video_path, trim_frame_start, trim_frame_end) frame_range = range(trim_frame_start, trim_frame_end) rate = 0.0 counter = 0 diff --git a/facefusion/core.py b/facefusion/core.py index 25eb7d8f..38d64dbc 100755 --- a/facefusion/core.py +++ b/facefusion/core.py @@ -26,7 +26,7 @@ from facefusion.program_helper import validate_args from facefusion.statistics import conditional_log_statistics from facefusion.temp_helper import clear_temp_directory, create_temp_directory, get_temp_file_path, get_temp_frame_paths, move_temp_file from facefusion.typing import Args, ErrorCode -from facefusion.vision import get_video_frame, pack_resolution, read_image, read_static_images, restrict_image_resolution, restrict_video_fps, restrict_video_resolution, unpack_resolution +from facefusion.vision import get_video_frame, pack_resolution, read_image, read_static_images, restrict_image_resolution, restrict_trim_frame, restrict_video_fps, restrict_video_resolution, unpack_resolution def cli() -> None: @@ -389,7 +389,8 @@ def process_image(start_time : float) -> ErrorCode: def process_video(start_time : float) -> ErrorCode: - if analyse_video(state_manager.get_item('target_path'), state_manager.get_item('trim_frame_start'), state_manager.get_item('trim_frame_end')): + trim_frame_start, trim_frame_end = restrict_trim_frame(state_manager.get_item('target_path'), state_manager.get_item('trim_frame_start'), state_manager.get_item('trim_frame_end')) + if analyse_video(state_manager.get_item('target_path'), trim_frame_start, trim_frame_end): return 3 # clear temp logger.debug(wording.get('clearing_temp'), __name__) @@ -402,7 +403,7 @@ def process_video(start_time : float) -> ErrorCode: temp_video_resolution = pack_resolution(restrict_video_resolution(state_manager.get_item('target_path'), unpack_resolution(state_manager.get_item('output_video_resolution')))) temp_video_fps = restrict_video_fps(state_manager.get_item('target_path'), state_manager.get_item('output_video_fps')) logger.info(wording.get('extracting_frames').format(resolution = temp_video_resolution, fps = temp_video_fps), __name__) - if extract_frames(state_manager.get_item('target_path'), temp_video_resolution, temp_video_fps): + if extract_frames(state_manager.get_item('target_path'), temp_video_resolution, temp_video_fps, trim_frame_start, trim_frame_end): logger.debug(wording.get('extracting_frames_succeed'), __name__) else: if is_process_stopping(): @@ -451,7 +452,7 @@ def process_video(start_time : float) -> ErrorCode: logger.warn(wording.get('replacing_audio_skipped'), __name__) move_temp_file(state_manager.get_item('target_path'), state_manager.get_item('output_path')) else: - if restore_audio(state_manager.get_item('target_path'), state_manager.get_item('output_path'), state_manager.get_item('output_video_fps')): + if restore_audio(state_manager.get_item('target_path'), state_manager.get_item('output_path'), state_manager.get_item('output_video_fps'), trim_frame_start, trim_frame_end): logger.debug(wording.get('restoring_audio_succeed'), __name__) else: if is_process_stopping(): diff --git a/facefusion/ffmpeg.py b/facefusion/ffmpeg.py index 3a5b8989..48530282 100644 --- a/facefusion/ffmpeg.py +++ b/facefusion/ffmpeg.py @@ -73,10 +73,8 @@ def log_debug(process : subprocess.Popen[bytes]) -> None: logger.debug(error.strip(), __name__) -def extract_frames(target_path : str, temp_video_resolution : str, temp_video_fps : Fps) -> bool: - trim_frame_start = state_manager.get_item('trim_frame_start') - trim_frame_end = state_manager.get_item('trim_frame_end') - extract_frame_total = count_trim_frame_total(state_manager.get_item('target_path'), trim_frame_start, trim_frame_end) +def extract_frames(target_path : str, temp_video_resolution : str, temp_video_fps : Fps, trim_frame_start : int, trim_frame_end : int) -> bool: + extract_frame_total = count_trim_frame_total(target_path, trim_frame_start, trim_frame_end) temp_frames_pattern = get_temp_frames_pattern(target_path, '%08d') commands = [ '-i', target_path, '-s', str(temp_video_resolution), '-q:v', '0' ] @@ -176,9 +174,7 @@ def read_audio_buffer(target_path : str, sample_rate : int, channel_total : int) return None -def restore_audio(target_path : str, output_path : str, output_video_fps : Fps) -> bool: - trim_frame_start = state_manager.get_item('trim_frame_start') - trim_frame_end = state_manager.get_item('trim_frame_end') +def restore_audio(target_path : str, output_path : str, output_video_fps : Fps, trim_frame_start : int, trim_frame_end : int) -> bool: output_audio_encoder = state_manager.get_item('output_audio_encoder') temp_file_path = get_temp_file_path(target_path) temp_video_duration = detect_video_duration(temp_file_path) diff --git a/facefusion/vision.py b/facefusion/vision.py index b459f028..af2dee83 100644 --- a/facefusion/vision.py +++ b/facefusion/vision.py @@ -137,10 +137,10 @@ def count_trim_frame_total(video_path : str, trim_frame_start : int, trim_frame_ def restrict_trim_frame(video_path : str, trim_frame_start : int, trim_frame_end : int) -> Tuple[int, int]: video_frame_total = count_video_frame_total(video_path) - if isinstance(trim_frame_start, int) and trim_frame_start < 0: - trim_frame_start = 0 - if isinstance(trim_frame_end, int) and trim_frame_end > video_frame_total: - trim_frame_end = video_frame_total + if isinstance(trim_frame_start, int): + trim_frame_start = max(0, min(trim_frame_start, video_frame_total)) + if isinstance(trim_frame_end, int): + trim_frame_end = max(0, min(trim_frame_end, video_frame_total)) if isinstance(trim_frame_start, int) and isinstance(trim_frame_end, int): return trim_frame_start, trim_frame_end diff --git a/tests/test_vision.py b/tests/test_vision.py index 303d780e..d79fb07c 100644 --- a/tests/test_vision.py +++ b/tests/test_vision.py @@ -83,6 +83,8 @@ def test_count_trim_frame_total() -> None: assert count_trim_frame_total(get_test_example_file('target-240p.mp4'), 0, 200) == 200 assert count_trim_frame_total(get_test_example_file('target-240p.mp4'), 70, 270) == 200 assert count_trim_frame_total(get_test_example_file('target-240p.mp4'), -10, None) == 270 + assert count_trim_frame_total(get_test_example_file('target-240p.mp4'), None, -10) == 0 + assert count_trim_frame_total(get_test_example_file('target-240p.mp4'), 280, None) == 0 assert count_trim_frame_total(get_test_example_file('target-240p.mp4'), None, 280) == 270 assert count_trim_frame_total(get_test_example_file('target-240p.mp4'), None, None) == 270 @@ -91,6 +93,8 @@ def test_restrict_trim_frame() -> None: assert restrict_trim_frame(get_test_example_file('target-240p.mp4'), 0, 200) == (0, 200) assert restrict_trim_frame(get_test_example_file('target-240p.mp4'), 70, 270) == (70, 270) assert restrict_trim_frame(get_test_example_file('target-240p.mp4'), -10, None) == (0, 270) + assert restrict_trim_frame(get_test_example_file('target-240p.mp4'), None, -10) == (0, 0) + assert restrict_trim_frame(get_test_example_file('target-240p.mp4'), 280, None) == (270, 270) assert restrict_trim_frame(get_test_example_file('target-240p.mp4'), None, 280) == (0, 270) assert restrict_trim_frame(get_test_example_file('target-240p.mp4'), None, None) == (0, 270)