* Validate the overrides from facefusion.ini

* Break down cli testing

* Remove architecture lookup to support old driver

* Remove architecture lookup to support old driver

* Remove hwaccel auto

* Respect the output video resolution

* Bump next version

* Full directml support (#501)

* Introduce conditional thread management for DML support

* Finish migration to thread helpers

* Introduce dynamic frame colorizer sizes

* Introduce dynamic frame colorizer sizes

* Add 192x192 to frame colorizer

* Fix async audio
This commit is contained in:
Henry Ruhs 2024-04-19 13:35:36 +02:00 committed by GitHub
parent 092dfbb796
commit 4efa5b2c6e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
30 changed files with 350 additions and 191 deletions

View File

@ -96,6 +96,7 @@ frame processors:
--face-swapper-model {blendswap_256,inswapper_128,inswapper_128_fp16,simswap_256,simswap_512_unofficial,uniface_256} choose the model responsible for swapping the face --face-swapper-model {blendswap_256,inswapper_128,inswapper_128_fp16,simswap_256,simswap_512_unofficial,uniface_256} choose the model responsible for swapping the face
--frame-colorizer-model {ddcolor,ddcolor_artistic,deoldify,deoldify_artistic,deoldify_stable} choose the model responsible for colorizing the frame --frame-colorizer-model {ddcolor,ddcolor_artistic,deoldify,deoldify_artistic,deoldify_stable} choose the model responsible for colorizing the frame
--frame-colorizer-blend [0-100] blend the colorized into the previous frame --frame-colorizer-blend [0-100] blend the colorized into the previous frame
--frame-colorizer-size {192x192,256x256,384x384,512x512} specify the size of the frame provided to the frame colorizer
--frame-enhancer-model {lsdir_x4,nomos8k_sc_x4,real_esrgan_x2,real_esrgan_x2_fp16,real_esrgan_x4,real_esrgan_x4_fp16,real_hatgan_x4,span_kendata_x4} choose the model responsible for enhancing the frame --frame-enhancer-model {lsdir_x4,nomos8k_sc_x4,real_esrgan_x2,real_esrgan_x2_fp16,real_esrgan_x4,real_esrgan_x4_fp16,real_hatgan_x4,span_kendata_x4} choose the model responsible for enhancing the frame
--frame-enhancer-blend [0-100] blend the enhanced into the previous frame --frame-enhancer-blend [0-100] blend the enhanced into the previous frame
--lip-syncer-model {wav2lip_gan} choose the model responsible for syncing the lips --lip-syncer-model {wav2lip_gan} choose the model responsible for syncing the lips

View File

@ -63,6 +63,7 @@ face_enhancer_blend =
face_swapper_model = face_swapper_model =
frame_colorizer_model = frame_colorizer_model =
frame_colorizer_blend = frame_colorizer_blend =
frame_colorizer_size =
frame_enhancer_model = frame_enhancer_model =
frame_enhancer_blend = frame_enhancer_blend =
lip_syncer_model = lip_syncer_model =

View File

@ -1,7 +1,6 @@
from typing import Any from typing import Any
from functools import lru_cache from functools import lru_cache
from time import sleep from time import sleep
import threading
import cv2 import cv2
import numpy import numpy
import onnxruntime import onnxruntime
@ -9,6 +8,7 @@ from tqdm import tqdm
import facefusion.globals import facefusion.globals
from facefusion import process_manager, wording from facefusion import process_manager, wording
from facefusion.thread_helper import thread_lock, conditional_thread_semaphore
from facefusion.typing import VisionFrame, ModelSet, Fps from facefusion.typing import VisionFrame, ModelSet, Fps
from facefusion.execution import apply_execution_provider_options from facefusion.execution import apply_execution_provider_options
from facefusion.vision import get_video_frame, count_video_frame_total, read_image, detect_video_fps from facefusion.vision import get_video_frame, count_video_frame_total, read_image, detect_video_fps
@ -16,7 +16,6 @@ from facefusion.filesystem import resolve_relative_path, is_file
from facefusion.download import conditional_download from facefusion.download import conditional_download
CONTENT_ANALYSER = None CONTENT_ANALYSER = None
THREAD_LOCK : threading.Lock = threading.Lock()
MODELS : ModelSet =\ MODELS : ModelSet =\
{ {
'open_nsfw': 'open_nsfw':
@ -33,7 +32,7 @@ STREAM_COUNTER = 0
def get_content_analyser() -> Any: def get_content_analyser() -> Any:
global CONTENT_ANALYSER global CONTENT_ANALYSER
with THREAD_LOCK: with thread_lock():
while process_manager.is_checking(): while process_manager.is_checking():
sleep(0.5) sleep(0.5)
if CONTENT_ANALYSER is None: if CONTENT_ANALYSER is None:
@ -72,6 +71,7 @@ def analyse_stream(vision_frame : VisionFrame, video_fps : Fps) -> bool:
def analyse_frame(vision_frame : VisionFrame) -> bool: def analyse_frame(vision_frame : VisionFrame) -> bool:
content_analyser = get_content_analyser() content_analyser = get_content_analyser()
vision_frame = prepare_frame(vision_frame) vision_frame = prepare_frame(vision_frame)
with conditional_thread_semaphore(facefusion.globals.execution_providers):
probability = content_analyser.run(None, probability = content_analyser.run(None,
{ {
content_analyser.get_inputs()[0].name: vision_frame content_analyser.get_inputs()[0].name: vision_frame

View File

@ -108,6 +108,19 @@ def cli() -> None:
run(program) run(program)
def validate_args(program : ArgumentParser) -> None:
try:
for action in program._actions:
if action.default:
if isinstance(action.default, list):
for default in action.default:
program._check_value(action, default)
else:
program._check_value(action, action.default)
except Exception as exception:
program.error(str(exception))
def apply_args(program : ArgumentParser) -> None: def apply_args(program : ArgumentParser) -> None:
args = program.parse_args() args = program.parse_args()
# general # general
@ -185,6 +198,7 @@ def apply_args(program : ArgumentParser) -> None:
def run(program : ArgumentParser) -> None: def run(program : ArgumentParser) -> None:
validate_args(program)
apply_args(program) apply_args(program)
logger.init(facefusion.globals.log_level) logger.init(facefusion.globals.log_level)

View File

@ -11,14 +11,14 @@ def encode_execution_providers(execution_providers : List[str]) -> List[str]:
return [ execution_provider.replace('ExecutionProvider', '').lower() for execution_provider in execution_providers ] return [ execution_provider.replace('ExecutionProvider', '').lower() for execution_provider in execution_providers ]
def decode_execution_providers(execution_providers: List[str]) -> List[str]: def decode_execution_providers(execution_providers : List[str]) -> List[str]:
available_execution_providers = onnxruntime.get_available_providers() available_execution_providers = onnxruntime.get_available_providers()
encoded_execution_providers = encode_execution_providers(available_execution_providers) encoded_execution_providers = encode_execution_providers(available_execution_providers)
return [ execution_provider for execution_provider, encoded_execution_provider in zip(available_execution_providers, encoded_execution_providers) if any(execution_provider in encoded_execution_provider for execution_provider in execution_providers) ] return [ execution_provider for execution_provider, encoded_execution_provider in zip(available_execution_providers, encoded_execution_providers) if any(execution_provider in encoded_execution_provider for execution_provider in execution_providers) ]
def apply_execution_provider_options(execution_providers: List[str]) -> List[Any]: def apply_execution_provider_options(execution_providers : List[str]) -> List[Any]:
execution_providers_with_options : List[Any] = [] execution_providers_with_options : List[Any] = []
for execution_provider in execution_providers: for execution_provider in execution_providers:
@ -64,13 +64,12 @@ def detect_execution_devices() -> List[ExecutionDevice]:
'framework': 'framework':
{ {
'name': 'CUDA', 'name': 'CUDA',
'version': root_element.find('cuda_version').text, 'version': root_element.find('cuda_version').text
}, },
'product': 'product':
{ {
'vendor': 'NVIDIA', 'vendor': 'NVIDIA',
'name': gpu_element.find('product_name').text.replace('NVIDIA ', ''), 'name': gpu_element.find('product_name').text.replace('NVIDIA ', '')
'architecture': gpu_element.find('product_architecture').text,
}, },
'video_memory': 'video_memory':
{ {

View File

@ -1,6 +1,5 @@
from typing import Any, Optional, List, Tuple from typing import Any, Optional, List, Tuple
from time import sleep from time import sleep
import threading
import cv2 import cv2
import numpy import numpy
import onnxruntime import onnxruntime
@ -13,12 +12,11 @@ from facefusion.face_store import get_static_faces, set_static_faces
from facefusion.execution import apply_execution_provider_options from facefusion.execution import apply_execution_provider_options
from facefusion.download import conditional_download from facefusion.download import conditional_download
from facefusion.filesystem import resolve_relative_path, is_file from facefusion.filesystem import resolve_relative_path, is_file
from facefusion.thread_helper import thread_lock, thread_semaphore, conditional_thread_semaphore
from facefusion.typing import VisionFrame, Face, FaceSet, FaceAnalyserOrder, FaceAnalyserAge, FaceAnalyserGender, ModelSet, BoundingBox, FaceLandmarkSet, FaceLandmark5, FaceLandmark68, Score, FaceScoreSet, Embedding from facefusion.typing import VisionFrame, Face, FaceSet, FaceAnalyserOrder, FaceAnalyserAge, FaceAnalyserGender, ModelSet, BoundingBox, FaceLandmarkSet, FaceLandmark5, FaceLandmark68, Score, FaceScoreSet, Embedding
from facefusion.vision import resize_frame_resolution, unpack_resolution from facefusion.vision import resize_frame_resolution, unpack_resolution
FACE_ANALYSER = None FACE_ANALYSER = None
THREAD_SEMAPHORE : threading.Semaphore = threading.Semaphore()
THREAD_LOCK : threading.Lock = threading.Lock()
MODELS : ModelSet =\ MODELS : ModelSet =\
{ {
'face_detector_retinaface': 'face_detector_retinaface':
@ -85,7 +83,7 @@ def get_face_analyser() -> Any:
face_detectors = {} face_detectors = {}
face_landmarkers = {} face_landmarkers = {}
with THREAD_LOCK: with thread_lock():
while process_manager.is_checking(): while process_manager.is_checking():
sleep(0.5) sleep(0.5)
if FACE_ANALYSER is None: if FACE_ANALYSER is None:
@ -185,7 +183,7 @@ def detect_with_retinaface(vision_frame : VisionFrame, face_detector_size : str)
score_list = [] score_list = []
detect_vision_frame = prepare_detect_frame(temp_vision_frame, face_detector_size) detect_vision_frame = prepare_detect_frame(temp_vision_frame, face_detector_size)
with THREAD_SEMAPHORE: with thread_semaphore():
detections = face_detector.run(None, detections = face_detector.run(None,
{ {
face_detector.get_inputs()[0].name: detect_vision_frame face_detector.get_inputs()[0].name: detect_vision_frame
@ -227,7 +225,7 @@ def detect_with_scrfd(vision_frame : VisionFrame, face_detector_size : str) -> T
score_list = [] score_list = []
detect_vision_frame = prepare_detect_frame(temp_vision_frame, face_detector_size) detect_vision_frame = prepare_detect_frame(temp_vision_frame, face_detector_size)
with THREAD_SEMAPHORE: with thread_semaphore():
detections = face_detector.run(None, detections = face_detector.run(None,
{ {
face_detector.get_inputs()[0].name: detect_vision_frame face_detector.get_inputs()[0].name: detect_vision_frame
@ -266,7 +264,7 @@ def detect_with_yoloface(vision_frame : VisionFrame, face_detector_size : str) -
score_list = [] score_list = []
detect_vision_frame = prepare_detect_frame(temp_vision_frame, face_detector_size) detect_vision_frame = prepare_detect_frame(temp_vision_frame, face_detector_size)
with THREAD_SEMAPHORE: with thread_semaphore():
detections = face_detector.run(None, detections = face_detector.run(None,
{ {
face_detector.get_inputs()[0].name: detect_vision_frame face_detector.get_inputs()[0].name: detect_vision_frame
@ -304,7 +302,7 @@ def detect_with_yunet(vision_frame : VisionFrame, face_detector_size : str) -> T
face_detector.setInputSize((temp_vision_frame.shape[1], temp_vision_frame.shape[0])) face_detector.setInputSize((temp_vision_frame.shape[1], temp_vision_frame.shape[0]))
face_detector.setScoreThreshold(facefusion.globals.face_detector_score) face_detector.setScoreThreshold(facefusion.globals.face_detector_score)
with THREAD_SEMAPHORE: with thread_semaphore():
_, detections = face_detector.detect(temp_vision_frame) _, detections = face_detector.detect(temp_vision_frame)
if numpy.any(detections): if numpy.any(detections):
for detection in detections: for detection in detections:
@ -380,6 +378,7 @@ def calc_embedding(temp_vision_frame : VisionFrame, face_landmark_5 : FaceLandma
crop_vision_frame = crop_vision_frame / 127.5 - 1 crop_vision_frame = crop_vision_frame / 127.5 - 1
crop_vision_frame = crop_vision_frame[:, :, ::-1].transpose(2, 0, 1).astype(numpy.float32) crop_vision_frame = crop_vision_frame[:, :, ::-1].transpose(2, 0, 1).astype(numpy.float32)
crop_vision_frame = numpy.expand_dims(crop_vision_frame, axis = 0) crop_vision_frame = numpy.expand_dims(crop_vision_frame, axis = 0)
with conditional_thread_semaphore(facefusion.globals.execution_providers):
embedding = face_recognizer.run(None, embedding = face_recognizer.run(None,
{ {
face_recognizer.get_inputs()[0].name: crop_vision_frame face_recognizer.get_inputs()[0].name: crop_vision_frame
@ -399,6 +398,7 @@ def detect_face_landmark_68(temp_vision_frame : VisionFrame, bounding_box : Boun
crop_vision_frame[:, :, 0] = cv2.createCLAHE(clipLimit = 2).apply(crop_vision_frame[:, :, 0]) crop_vision_frame[:, :, 0] = cv2.createCLAHE(clipLimit = 2).apply(crop_vision_frame[:, :, 0])
crop_vision_frame = cv2.cvtColor(crop_vision_frame, cv2.COLOR_Lab2RGB) crop_vision_frame = cv2.cvtColor(crop_vision_frame, cv2.COLOR_Lab2RGB)
crop_vision_frame = crop_vision_frame.transpose(2, 0, 1).astype(numpy.float32) / 255.0 crop_vision_frame = crop_vision_frame.transpose(2, 0, 1).astype(numpy.float32) / 255.0
with conditional_thread_semaphore(facefusion.globals.execution_providers):
face_landmark_68, face_heatmap = face_landmarker.run(None, face_landmark_68, face_heatmap = face_landmarker.run(None,
{ {
face_landmarker.get_inputs()[0].name: [ crop_vision_frame ] face_landmarker.get_inputs()[0].name: [ crop_vision_frame ]
@ -416,6 +416,7 @@ def expand_face_landmark_68_from_5(face_landmark_5 : FaceLandmark5) -> FaceLandm
face_landmarker = get_face_analyser().get('face_landmarkers').get('68_5') face_landmarker = get_face_analyser().get('face_landmarkers').get('68_5')
affine_matrix = estimate_matrix_by_face_landmark_5(face_landmark_5, 'ffhq_512', (1, 1)) affine_matrix = estimate_matrix_by_face_landmark_5(face_landmark_5, 'ffhq_512', (1, 1))
face_landmark_5 = cv2.transform(face_landmark_5.reshape(1, -1, 2), affine_matrix).reshape(-1, 2) face_landmark_5 = cv2.transform(face_landmark_5.reshape(1, -1, 2), affine_matrix).reshape(-1, 2)
with conditional_thread_semaphore(facefusion.globals.execution_providers):
face_landmark_68_5 = face_landmarker.run(None, face_landmark_68_5 = face_landmarker.run(None,
{ {
face_landmarker.get_inputs()[0].name: [ face_landmark_5 ] face_landmarker.get_inputs()[0].name: [ face_landmark_5 ]
@ -432,6 +433,7 @@ def detect_gender_age(temp_vision_frame : VisionFrame, bounding_box : BoundingBo
crop_vision_frame, affine_matrix = warp_face_by_translation(temp_vision_frame, translation, scale, (96, 96)) crop_vision_frame, affine_matrix = warp_face_by_translation(temp_vision_frame, translation, scale, (96, 96))
crop_vision_frame = crop_vision_frame[:, :, ::-1].transpose(2, 0, 1).astype(numpy.float32) crop_vision_frame = crop_vision_frame[:, :, ::-1].transpose(2, 0, 1).astype(numpy.float32)
crop_vision_frame = numpy.expand_dims(crop_vision_frame, axis = 0) crop_vision_frame = numpy.expand_dims(crop_vision_frame, axis = 0)
with conditional_thread_semaphore(facefusion.globals.execution_providers):
prediction = gender_age.run(None, prediction = gender_age.run(None,
{ {
gender_age.get_inputs()[0].name: crop_vision_frame gender_age.get_inputs()[0].name: crop_vision_frame

View File

@ -2,13 +2,13 @@ from typing import Any, Dict, List
from cv2.typing import Size from cv2.typing import Size
from functools import lru_cache from functools import lru_cache
from time import sleep from time import sleep
import threading
import cv2 import cv2
import numpy import numpy
import onnxruntime import onnxruntime
import facefusion.globals import facefusion.globals
from facefusion import process_manager from facefusion import process_manager
from facefusion.thread_helper import thread_lock, conditional_thread_semaphore
from facefusion.typing import FaceLandmark68, VisionFrame, Mask, Padding, FaceMaskRegion, ModelSet from facefusion.typing import FaceLandmark68, VisionFrame, Mask, Padding, FaceMaskRegion, ModelSet
from facefusion.execution import apply_execution_provider_options from facefusion.execution import apply_execution_provider_options
from facefusion.filesystem import resolve_relative_path, is_file from facefusion.filesystem import resolve_relative_path, is_file
@ -16,7 +16,6 @@ from facefusion.download import conditional_download
FACE_OCCLUDER = None FACE_OCCLUDER = None
FACE_PARSER = None FACE_PARSER = None
THREAD_LOCK : threading.Lock = threading.Lock()
MODELS : ModelSet =\ MODELS : ModelSet =\
{ {
'face_occluder': 'face_occluder':
@ -48,7 +47,7 @@ FACE_MASK_REGIONS : Dict[FaceMaskRegion, int] =\
def get_face_occluder() -> Any: def get_face_occluder() -> Any:
global FACE_OCCLUDER global FACE_OCCLUDER
with THREAD_LOCK: with thread_lock():
while process_manager.is_checking(): while process_manager.is_checking():
sleep(0.5) sleep(0.5)
if FACE_OCCLUDER is None: if FACE_OCCLUDER is None:
@ -60,7 +59,7 @@ def get_face_occluder() -> Any:
def get_face_parser() -> Any: def get_face_parser() -> Any:
global FACE_PARSER global FACE_PARSER
with THREAD_LOCK: with thread_lock():
while process_manager.is_checking(): while process_manager.is_checking():
sleep(0.5) sleep(0.5)
if FACE_PARSER is None: if FACE_PARSER is None:
@ -120,6 +119,7 @@ def create_occlusion_mask(crop_vision_frame : VisionFrame) -> Mask:
prepare_vision_frame = cv2.resize(crop_vision_frame, face_occluder.get_inputs()[0].shape[1:3][::-1]) prepare_vision_frame = cv2.resize(crop_vision_frame, face_occluder.get_inputs()[0].shape[1:3][::-1])
prepare_vision_frame = numpy.expand_dims(prepare_vision_frame, axis = 0).astype(numpy.float32) / 255 prepare_vision_frame = numpy.expand_dims(prepare_vision_frame, axis = 0).astype(numpy.float32) / 255
prepare_vision_frame = prepare_vision_frame.transpose(0, 1, 2, 3) prepare_vision_frame = prepare_vision_frame.transpose(0, 1, 2, 3)
with conditional_thread_semaphore(facefusion.globals.execution_providers):
occlusion_mask : Mask = face_occluder.run(None, occlusion_mask : Mask = face_occluder.run(None,
{ {
face_occluder.get_inputs()[0].name: prepare_vision_frame face_occluder.get_inputs()[0].name: prepare_vision_frame
@ -135,6 +135,7 @@ def create_region_mask(crop_vision_frame : VisionFrame, face_mask_regions : List
prepare_vision_frame = cv2.flip(cv2.resize(crop_vision_frame, (512, 512)), 1) prepare_vision_frame = cv2.flip(cv2.resize(crop_vision_frame, (512, 512)), 1)
prepare_vision_frame = numpy.expand_dims(prepare_vision_frame, axis = 0).astype(numpy.float32)[:, :, ::-1] / 127.5 - 1 prepare_vision_frame = numpy.expand_dims(prepare_vision_frame, axis = 0).astype(numpy.float32)[:, :, ::-1] / 127.5 - 1
prepare_vision_frame = prepare_vision_frame.transpose(0, 3, 1, 2) prepare_vision_frame = prepare_vision_frame.transpose(0, 3, 1, 2)
with conditional_thread_semaphore(facefusion.globals.execution_providers):
region_mask : Mask = face_parser.run(None, region_mask : Mask = face_parser.run(None,
{ {
face_parser.get_inputs()[0].name: prepare_vision_frame face_parser.get_inputs()[0].name: prepare_vision_frame

View File

@ -44,16 +44,16 @@ def extract_frames(target_path : str, temp_video_resolution : str, temp_video_fp
trim_frame_start = facefusion.globals.trim_frame_start trim_frame_start = facefusion.globals.trim_frame_start
trim_frame_end = facefusion.globals.trim_frame_end trim_frame_end = facefusion.globals.trim_frame_end
temp_frames_pattern = get_temp_frames_pattern(target_path, '%04d') temp_frames_pattern = get_temp_frames_pattern(target_path, '%04d')
commands = [ '-hwaccel', 'auto', '-i', target_path, '-q:v', '0' ] commands = [ '-i', target_path, '-s', str(temp_video_resolution), '-q:v', '0' ]
if trim_frame_start is not None and trim_frame_end is not None: if trim_frame_start is not None and trim_frame_end is not None:
commands.extend([ '-vf', 'trim=start_frame=' + str(trim_frame_start) + ':end_frame=' + str(trim_frame_end) + ',scale=' + str(temp_video_resolution) + ',fps=' + str(temp_video_fps) ]) commands.extend([ '-vf', 'trim=start_frame=' + str(trim_frame_start) + ':end_frame=' + str(trim_frame_end) + ',fps=' + str(temp_video_fps) ])
elif trim_frame_start is not None: elif trim_frame_start is not None:
commands.extend([ '-vf', 'trim=start_frame=' + str(trim_frame_start) + ',scale=' + str(temp_video_resolution) + ',fps=' + str(temp_video_fps) ]) commands.extend([ '-vf', 'trim=start_frame=' + str(trim_frame_start) + ',fps=' + str(temp_video_fps) ])
elif trim_frame_end is not None: elif trim_frame_end is not None:
commands.extend([ '-vf', 'trim=end_frame=' + str(trim_frame_end) + ',scale=' + str(temp_video_resolution) + ',fps=' + str(temp_video_fps) ]) commands.extend([ '-vf', 'trim=end_frame=' + str(trim_frame_end) + ',fps=' + str(temp_video_fps) ])
else: else:
commands.extend([ '-vf', 'scale=' + str(temp_video_resolution) + ',fps=' + str(temp_video_fps) ]) commands.extend([ '-vf', 'fps=' + str(temp_video_fps) ])
commands.extend([ '-vsync', '0', temp_frames_pattern ]) commands.extend([ '-vsync', '0', temp_frames_pattern ])
return run_ffmpeg(commands) return run_ffmpeg(commands)
@ -62,7 +62,7 @@ def merge_video(target_path : str, output_video_resolution : str, output_video_f
temp_video_fps = restrict_video_fps(target_path, output_video_fps) temp_video_fps = restrict_video_fps(target_path, output_video_fps)
temp_output_video_path = get_temp_output_video_path(target_path) temp_output_video_path = get_temp_output_video_path(target_path)
temp_frames_pattern = get_temp_frames_pattern(target_path, '%04d') temp_frames_pattern = get_temp_frames_pattern(target_path, '%04d')
commands = [ '-hwaccel', 'auto', '-s', str(output_video_resolution), '-r', str(temp_video_fps), '-i', temp_frames_pattern, '-c:v', facefusion.globals.output_video_encoder ] commands = [ '-r', str(temp_video_fps), '-i', temp_frames_pattern, '-s', str(output_video_resolution), '-c:v', facefusion.globals.output_video_encoder ]
if facefusion.globals.output_video_encoder in [ 'libx264', 'libx265' ]: if facefusion.globals.output_video_encoder in [ 'libx264', 'libx265' ]:
output_video_compression = round(51 - (facefusion.globals.output_video_quality * 0.51)) output_video_compression = round(51 - (facefusion.globals.output_video_quality * 0.51))
@ -83,13 +83,13 @@ def merge_video(target_path : str, output_video_resolution : str, output_video_f
def copy_image(target_path : str, output_path : str, temp_image_resolution : str) -> bool: def copy_image(target_path : str, output_path : str, temp_image_resolution : str) -> bool:
is_webp = filetype.guess_mime(target_path) == 'image/webp' is_webp = filetype.guess_mime(target_path) == 'image/webp'
temp_image_compression = 100 if is_webp else 0 temp_image_compression = 100 if is_webp else 0
commands = [ '-i', target_path, '-q:v', str(temp_image_compression), '-vf', 'scale=' + str(temp_image_resolution), '-y', output_path ] commands = [ '-i', target_path, '-s', str(temp_image_resolution), '-q:v', str(temp_image_compression), '-y', output_path ]
return run_ffmpeg(commands) return run_ffmpeg(commands)
def finalize_image(output_path : str, output_image_resolution : str) -> bool: def finalize_image(output_path : str, output_image_resolution : str) -> bool:
output_image_compression = round(31 - (facefusion.globals.output_image_quality * 0.31)) output_image_compression = round(31 - (facefusion.globals.output_image_quality * 0.31))
commands = [ '-i', output_path, '-q:v', str(output_image_compression), '-vf', 'scale=' + str(output_image_resolution), '-y', output_path ] commands = [ '-i', output_path, '-s', str(output_image_resolution), '-q:v', str(output_image_compression), '-y', output_path ]
return run_ffmpeg(commands) return run_ffmpeg(commands)
@ -106,7 +106,7 @@ def restore_audio(target_path : str, output_path : str, output_video_fps : Fps)
trim_frame_start = facefusion.globals.trim_frame_start trim_frame_start = facefusion.globals.trim_frame_start
trim_frame_end = facefusion.globals.trim_frame_end trim_frame_end = facefusion.globals.trim_frame_end
temp_output_video_path = get_temp_output_video_path(target_path) temp_output_video_path = get_temp_output_video_path(target_path)
commands = [ '-hwaccel', 'auto', '-i', temp_output_video_path ] commands = [ '-i', temp_output_video_path ]
if trim_frame_start is not None: if trim_frame_start is not None:
start_time = trim_frame_start / output_video_fps start_time = trim_frame_start / output_video_fps
@ -120,7 +120,7 @@ def restore_audio(target_path : str, output_path : str, output_video_fps : Fps)
def replace_audio(target_path : str, audio_path : str, output_path : str) -> bool: def replace_audio(target_path : str, audio_path : str, output_path : str) -> bool:
temp_output_path = get_temp_output_video_path(target_path) temp_output_path = get_temp_output_video_path(target_path)
commands = [ '-hwaccel', 'auto', '-i', temp_output_path, '-i', audio_path, '-af', 'apad', '-shortest', '-y', output_path ] commands = [ '-i', temp_output_path, '-i', audio_path, '-af', 'apad', '-shortest', '-y', output_path ]
return run_ffmpeg(commands) return run_ffmpeg(commands)

View File

@ -2,7 +2,7 @@ METADATA =\
{ {
'name': 'FaceFusion', 'name': 'FaceFusion',
'description': 'Next generation face swapper and enhancer', 'description': 'Next generation face swapper and enhancer',
'version': '2.5.1', 'version': '2.5.2',
'license': 'MIT', 'license': 'MIT',
'author': 'Henry Ruhs', 'author': 'Henry Ruhs',
'url': 'https://facefusion.io' 'url': 'https://facefusion.io'

View File

@ -7,6 +7,7 @@ face_debugger_items : List[FaceDebuggerItem] = [ 'bounding-box', 'face-landmark-
face_enhancer_models : List[FaceEnhancerModel] = [ 'codeformer', 'gfpgan_1.2', 'gfpgan_1.3', 'gfpgan_1.4', 'gpen_bfr_256', 'gpen_bfr_512', 'gpen_bfr_1024', 'gpen_bfr_2048', 'restoreformer_plus_plus' ] face_enhancer_models : List[FaceEnhancerModel] = [ 'codeformer', 'gfpgan_1.2', 'gfpgan_1.3', 'gfpgan_1.4', 'gpen_bfr_256', 'gpen_bfr_512', 'gpen_bfr_1024', 'gpen_bfr_2048', 'restoreformer_plus_plus' ]
face_swapper_models : List[FaceSwapperModel] = [ 'blendswap_256', 'inswapper_128', 'inswapper_128_fp16', 'simswap_256', 'simswap_512_unofficial', 'uniface_256' ] face_swapper_models : List[FaceSwapperModel] = [ 'blendswap_256', 'inswapper_128', 'inswapper_128_fp16', 'simswap_256', 'simswap_512_unofficial', 'uniface_256' ]
frame_colorizer_models : List[FrameColorizerModel] = [ 'ddcolor', 'ddcolor_artistic', 'deoldify', 'deoldify_artistic', 'deoldify_stable' ] frame_colorizer_models : List[FrameColorizerModel] = [ 'ddcolor', 'ddcolor_artistic', 'deoldify', 'deoldify_artistic', 'deoldify_stable' ]
frame_colorizer_sizes : List[str] = [ '192x192', '256x256', '384x384', '512x512' ]
frame_enhancer_models : List[FrameEnhancerModel] = [ 'lsdir_x4', 'nomos8k_sc_x4', 'real_esrgan_x2', 'real_esrgan_x2_fp16', 'real_esrgan_x4', 'real_esrgan_x4_fp16', 'real_hatgan_x4', 'span_kendata_x4' ] frame_enhancer_models : List[FrameEnhancerModel] = [ 'lsdir_x4', 'nomos8k_sc_x4', 'real_esrgan_x2', 'real_esrgan_x2_fp16', 'real_esrgan_x4', 'real_esrgan_x4_fp16', 'real_hatgan_x4', 'span_kendata_x4' ]
lip_syncer_models : List[LipSyncerModel] = [ 'wav2lip_gan' ] lip_syncer_models : List[LipSyncerModel] = [ 'wav2lip_gan' ]

View File

@ -8,6 +8,7 @@ face_enhancer_blend : Optional[int] = None
face_swapper_model : Optional[FaceSwapperModel] = None face_swapper_model : Optional[FaceSwapperModel] = None
frame_colorizer_model : Optional[FrameColorizerModel] = None frame_colorizer_model : Optional[FrameColorizerModel] = None
frame_colorizer_blend : Optional[int] = None frame_colorizer_blend : Optional[int] = None
frame_colorizer_size : Optional[str] = None
frame_enhancer_model : Optional[FrameEnhancerModel] = None frame_enhancer_model : Optional[FrameEnhancerModel] = None
frame_enhancer_blend : Optional[int] = None frame_enhancer_blend : Optional[int] = None
lip_syncer_model : Optional[LipSyncerModel] = None lip_syncer_model : Optional[LipSyncerModel] = None

View File

@ -2,7 +2,6 @@ from typing import Any, List, Literal, Optional
from argparse import ArgumentParser from argparse import ArgumentParser
from time import sleep from time import sleep
import cv2 import cv2
import threading
import numpy import numpy
import onnxruntime import onnxruntime
@ -16,6 +15,7 @@ from facefusion.execution import apply_execution_provider_options
from facefusion.content_analyser import clear_content_analyser from facefusion.content_analyser import clear_content_analyser
from facefusion.face_store import get_reference_faces from facefusion.face_store import get_reference_faces
from facefusion.normalizer import normalize_output_path from facefusion.normalizer import normalize_output_path
from facefusion.thread_helper import thread_lock, thread_semaphore
from facefusion.typing import Face, VisionFrame, UpdateProgress, ProcessMode, ModelSet, OptionsWithModel, QueuePayload from facefusion.typing import Face, VisionFrame, UpdateProgress, ProcessMode, ModelSet, OptionsWithModel, QueuePayload
from facefusion.common_helper import create_metavar from facefusion.common_helper import create_metavar
from facefusion.filesystem import is_file, is_image, is_video, resolve_relative_path from facefusion.filesystem import is_file, is_image, is_video, resolve_relative_path
@ -26,8 +26,6 @@ from facefusion.processors.frame import globals as frame_processors_globals
from facefusion.processors.frame import choices as frame_processors_choices from facefusion.processors.frame import choices as frame_processors_choices
FRAME_PROCESSOR = None FRAME_PROCESSOR = None
THREAD_SEMAPHORE : threading.Semaphore = threading.Semaphore()
THREAD_LOCK : threading.Lock = threading.Lock()
NAME = __name__.upper() NAME = __name__.upper()
MODELS : ModelSet =\ MODELS : ModelSet =\
{ {
@ -101,7 +99,7 @@ OPTIONS : Optional[OptionsWithModel] = None
def get_frame_processor() -> Any: def get_frame_processor() -> Any:
global FRAME_PROCESSOR global FRAME_PROCESSOR
with THREAD_LOCK: with thread_lock():
while process_manager.is_checking(): while process_manager.is_checking():
sleep(0.5) sleep(0.5)
if FRAME_PROCESSOR is None: if FRAME_PROCESSOR is None:
@ -221,7 +219,7 @@ def apply_enhance(crop_vision_frame : VisionFrame) -> VisionFrame:
if frame_processor_input.name == 'weight': if frame_processor_input.name == 'weight':
weight = numpy.array([ 1 ]).astype(numpy.double) weight = numpy.array([ 1 ]).astype(numpy.double)
frame_processor_inputs[frame_processor_input.name] = weight frame_processor_inputs[frame_processor_input.name] = weight
with THREAD_SEMAPHORE: with thread_semaphore():
crop_vision_frame = frame_processor.run(None, frame_processor_inputs)[0][0] crop_vision_frame = frame_processor.run(None, frame_processor_inputs)[0][0]
return crop_vision_frame return crop_vision_frame

View File

@ -2,7 +2,6 @@ from typing import Any, List, Literal, Optional
from argparse import ArgumentParser from argparse import ArgumentParser
from time import sleep from time import sleep
import platform import platform
import threading
import numpy import numpy
import onnx import onnx
import onnxruntime import onnxruntime
@ -18,6 +17,7 @@ from facefusion.face_helper import warp_face_by_face_landmark_5, paste_back
from facefusion.face_store import get_reference_faces from facefusion.face_store import get_reference_faces
from facefusion.content_analyser import clear_content_analyser from facefusion.content_analyser import clear_content_analyser
from facefusion.normalizer import normalize_output_path from facefusion.normalizer import normalize_output_path
from facefusion.thread_helper import thread_lock, conditional_thread_semaphore
from facefusion.typing import Face, Embedding, VisionFrame, UpdateProgress, ProcessMode, ModelSet, OptionsWithModel, QueuePayload from facefusion.typing import Face, Embedding, VisionFrame, UpdateProgress, ProcessMode, ModelSet, OptionsWithModel, QueuePayload
from facefusion.filesystem import is_file, is_image, has_image, is_video, filter_image_paths, resolve_relative_path from facefusion.filesystem import is_file, is_image, has_image, is_video, filter_image_paths, resolve_relative_path
from facefusion.download import conditional_download, is_download_done from facefusion.download import conditional_download, is_download_done
@ -28,7 +28,6 @@ from facefusion.processors.frame import choices as frame_processors_choices
FRAME_PROCESSOR = None FRAME_PROCESSOR = None
MODEL_INITIALIZER = None MODEL_INITIALIZER = None
THREAD_LOCK : threading.Lock = threading.Lock()
NAME = __name__.upper() NAME = __name__.upper()
MODELS : ModelSet =\ MODELS : ModelSet =\
{ {
@ -99,7 +98,7 @@ OPTIONS : Optional[OptionsWithModel] = None
def get_frame_processor() -> Any: def get_frame_processor() -> Any:
global FRAME_PROCESSOR global FRAME_PROCESSOR
with THREAD_LOCK: with thread_lock():
while process_manager.is_checking(): while process_manager.is_checking():
sleep(0.5) sleep(0.5)
if FRAME_PROCESSOR is None: if FRAME_PROCESSOR is None:
@ -117,7 +116,7 @@ def clear_frame_processor() -> None:
def get_model_initializer() -> Any: def get_model_initializer() -> Any:
global MODEL_INITIALIZER global MODEL_INITIALIZER
with THREAD_LOCK: with thread_lock():
while process_manager.is_checking(): while process_manager.is_checking():
sleep(0.5) sleep(0.5)
if MODEL_INITIALIZER is None: if MODEL_INITIALIZER is None:
@ -263,6 +262,7 @@ def apply_swap(source_face : Face, crop_vision_frame : VisionFrame) -> VisionFra
frame_processor_inputs[frame_processor_input.name] = prepare_source_embedding(source_face) frame_processor_inputs[frame_processor_input.name] = prepare_source_embedding(source_face)
if frame_processor_input.name == 'target': if frame_processor_input.name == 'target':
frame_processor_inputs[frame_processor_input.name] = crop_vision_frame frame_processor_inputs[frame_processor_input.name] = crop_vision_frame
with conditional_thread_semaphore(facefusion.globals.execution_providers):
crop_vision_frame = frame_processor.run(None, frame_processor_inputs)[0][0] crop_vision_frame = frame_processor.run(None, frame_processor_inputs)[0][0]
return crop_vision_frame return crop_vision_frame

View File

@ -1,7 +1,6 @@
from typing import Any, List, Literal, Optional from typing import Any, List, Literal, Optional
from argparse import ArgumentParser from argparse import ArgumentParser
from time import sleep from time import sleep
import threading
import cv2 import cv2
import numpy import numpy
import onnxruntime import onnxruntime
@ -13,18 +12,17 @@ from facefusion.face_analyser import clear_face_analyser
from facefusion.content_analyser import clear_content_analyser from facefusion.content_analyser import clear_content_analyser
from facefusion.execution import apply_execution_provider_options from facefusion.execution import apply_execution_provider_options
from facefusion.normalizer import normalize_output_path from facefusion.normalizer import normalize_output_path
from facefusion.thread_helper import thread_lock, thread_semaphore
from facefusion.typing import Face, VisionFrame, UpdateProgress, ProcessMode, ModelSet, OptionsWithModel, QueuePayload from facefusion.typing import Face, VisionFrame, UpdateProgress, ProcessMode, ModelSet, OptionsWithModel, QueuePayload
from facefusion.common_helper import create_metavar from facefusion.common_helper import create_metavar
from facefusion.filesystem import is_file, resolve_relative_path, is_image, is_video from facefusion.filesystem import is_file, resolve_relative_path, is_image, is_video
from facefusion.download import conditional_download, is_download_done from facefusion.download import conditional_download, is_download_done
from facefusion.vision import read_image, read_static_image, write_image from facefusion.vision import read_image, read_static_image, write_image, unpack_resolution
from facefusion.processors.frame.typings import FrameColorizerInputs from facefusion.processors.frame.typings import FrameColorizerInputs
from facefusion.processors.frame import globals as frame_processors_globals from facefusion.processors.frame import globals as frame_processors_globals
from facefusion.processors.frame import choices as frame_processors_choices from facefusion.processors.frame import choices as frame_processors_choices
FRAME_PROCESSOR = None FRAME_PROCESSOR = None
THREAD_LOCK : threading.Lock = threading.Lock()
THREAD_SEMAPHORE : threading.Semaphore = threading.Semaphore()
NAME = __name__.upper() NAME = __name__.upper()
MODELS : ModelSet =\ MODELS : ModelSet =\
{ {
@ -32,36 +30,31 @@ MODELS : ModelSet =\
{ {
'type': 'ddcolor', 'type': 'ddcolor',
'url': 'https://github.com/facefusion/facefusion-assets/releases/download/models/ddcolor.onnx', 'url': 'https://github.com/facefusion/facefusion-assets/releases/download/models/ddcolor.onnx',
'path': resolve_relative_path('../.assets/models/ddcolor.onnx'), 'path': resolve_relative_path('../.assets/models/ddcolor.onnx')
'size': (512, 512)
}, },
'ddcolor_artistic': 'ddcolor_artistic':
{ {
'type': 'ddcolor', 'type': 'ddcolor',
'url': 'https://github.com/facefusion/facefusion-assets/releases/download/models/ddcolor_artistic.onnx', 'url': 'https://github.com/facefusion/facefusion-assets/releases/download/models/ddcolor_artistic.onnx',
'path': resolve_relative_path('../.assets/models/ddcolor_artistic.onnx'), 'path': resolve_relative_path('../.assets/models/ddcolor_artistic.onnx')
'size': (512, 512)
}, },
'deoldify': 'deoldify':
{ {
'type': 'deoldify', 'type': 'deoldify',
'url': 'https://github.com/facefusion/facefusion-assets/releases/download/models/deoldify.onnx', 'url': 'https://github.com/facefusion/facefusion-assets/releases/download/models/deoldify.onnx',
'path': resolve_relative_path('../.assets/models/deoldify.onnx'), 'path': resolve_relative_path('../.assets/models/deoldify.onnx')
'size': (256, 256)
}, },
'deoldify_artistic': 'deoldify_artistic':
{ {
'type': 'deoldify', 'type': 'deoldify',
'url': 'https://github.com/facefusion/facefusion-assets/releases/download/models/deoldify_artistic.onnx', 'url': 'https://github.com/facefusion/facefusion-assets/releases/download/models/deoldify_artistic.onnx',
'path': resolve_relative_path('../.assets/models/deoldify_artistic.onnx'), 'path': resolve_relative_path('../.assets/models/deoldify_artistic.onnx')
'size': (256, 256)
}, },
'deoldify_stable': 'deoldify_stable':
{ {
'type': 'deoldify', 'type': 'deoldify',
'url': 'https://github.com/facefusion/facefusion-assets/releases/download/models/deoldify_stable.onnx', 'url': 'https://github.com/facefusion/facefusion-assets/releases/download/models/deoldify_stable.onnx',
'path': resolve_relative_path('../.assets/models/deoldify_stable.onnx'), 'path': resolve_relative_path('../.assets/models/deoldify_stable.onnx')
'size': (256, 256)
} }
} }
OPTIONS : Optional[OptionsWithModel] = None OPTIONS : Optional[OptionsWithModel] = None
@ -70,7 +63,7 @@ OPTIONS : Optional[OptionsWithModel] = None
def get_frame_processor() -> Any: def get_frame_processor() -> Any:
global FRAME_PROCESSOR global FRAME_PROCESSOR
with THREAD_LOCK: with thread_lock():
while process_manager.is_checking(): while process_manager.is_checking():
sleep(0.5) sleep(0.5)
if FRAME_PROCESSOR is None: if FRAME_PROCESSOR is None:
@ -105,12 +98,14 @@ def set_options(key : Literal['model'], value : Any) -> None:
def register_args(program : ArgumentParser) -> None: def register_args(program : ArgumentParser) -> None:
program.add_argument('--frame-colorizer-model', help = wording.get('help.frame_colorizer_model'), default = config.get_str_value('frame_processors.frame_colorizer_model', 'ddcolor'), choices = frame_processors_choices.frame_colorizer_models) program.add_argument('--frame-colorizer-model', help = wording.get('help.frame_colorizer_model'), default = config.get_str_value('frame_processors.frame_colorizer_model', 'ddcolor'), choices = frame_processors_choices.frame_colorizer_models)
program.add_argument('--frame-colorizer-blend', help = wording.get('help.frame_colorizer_blend'), type = int, default = config.get_int_value('frame_processors.frame_colorizer_blend', '100'), choices = frame_processors_choices.frame_colorizer_blend_range, metavar = create_metavar(frame_processors_choices.frame_colorizer_blend_range)) program.add_argument('--frame-colorizer-blend', help = wording.get('help.frame_colorizer_blend'), type = int, default = config.get_int_value('frame_processors.frame_colorizer_blend', '100'), choices = frame_processors_choices.frame_colorizer_blend_range, metavar = create_metavar(frame_processors_choices.frame_colorizer_blend_range))
program.add_argument('--frame-colorizer-size', help = wording.get('help.frame_colorizer_size'), type = str, default = config.get_str_value('frame_processors.frame_colorizer_size', '256x256'), choices = frame_processors_choices.frame_colorizer_sizes)
def apply_args(program : ArgumentParser) -> None: def apply_args(program : ArgumentParser) -> None:
args = program.parse_args() args = program.parse_args()
frame_processors_globals.frame_colorizer_model = args.frame_colorizer_model frame_processors_globals.frame_colorizer_model = args.frame_colorizer_model
frame_processors_globals.frame_colorizer_blend = args.frame_colorizer_blend frame_processors_globals.frame_colorizer_blend = args.frame_colorizer_blend
frame_processors_globals.frame_colorizer_size = args.frame_colorizer_size
def pre_check() -> bool: def pre_check() -> bool:
@ -160,7 +155,7 @@ def post_process() -> None:
def colorize_frame(temp_vision_frame : VisionFrame) -> VisionFrame: def colorize_frame(temp_vision_frame : VisionFrame) -> VisionFrame:
frame_processor = get_frame_processor() frame_processor = get_frame_processor()
prepare_vision_frame = prepare_temp_frame(temp_vision_frame) prepare_vision_frame = prepare_temp_frame(temp_vision_frame)
with THREAD_SEMAPHORE: with thread_semaphore():
color_vision_frame = frame_processor.run(None, color_vision_frame = frame_processor.run(None,
{ {
frame_processor.get_inputs()[0].name: prepare_vision_frame frame_processor.get_inputs()[0].name: prepare_vision_frame
@ -171,7 +166,7 @@ def colorize_frame(temp_vision_frame : VisionFrame) -> VisionFrame:
def prepare_temp_frame(temp_vision_frame : VisionFrame) -> VisionFrame: def prepare_temp_frame(temp_vision_frame : VisionFrame) -> VisionFrame:
model_size = get_options('model').get('size') model_size = unpack_resolution(frame_processors_globals.frame_colorizer_size)
model_type = get_options('model').get('type') model_type = get_options('model').get('type')
temp_vision_frame = cv2.cvtColor(temp_vision_frame, cv2.COLOR_BGR2GRAY) temp_vision_frame = cv2.cvtColor(temp_vision_frame, cv2.COLOR_BGR2GRAY)
temp_vision_frame = cv2.cvtColor(temp_vision_frame, cv2.COLOR_GRAY2RGB) temp_vision_frame = cv2.cvtColor(temp_vision_frame, cv2.COLOR_GRAY2RGB)

View File

@ -1,7 +1,6 @@
from typing import Any, List, Literal, Optional from typing import Any, List, Literal, Optional
from argparse import ArgumentParser from argparse import ArgumentParser
from time import sleep from time import sleep
import threading
import cv2 import cv2
import numpy import numpy
import onnxruntime import onnxruntime
@ -13,6 +12,7 @@ from facefusion.face_analyser import clear_face_analyser
from facefusion.content_analyser import clear_content_analyser from facefusion.content_analyser import clear_content_analyser
from facefusion.execution import apply_execution_provider_options from facefusion.execution import apply_execution_provider_options
from facefusion.normalizer import normalize_output_path from facefusion.normalizer import normalize_output_path
from facefusion.thread_helper import thread_lock, conditional_thread_semaphore
from facefusion.typing import Face, VisionFrame, UpdateProgress, ProcessMode, ModelSet, OptionsWithModel, QueuePayload from facefusion.typing import Face, VisionFrame, UpdateProgress, ProcessMode, ModelSet, OptionsWithModel, QueuePayload
from facefusion.common_helper import create_metavar from facefusion.common_helper import create_metavar
from facefusion.filesystem import is_file, resolve_relative_path, is_image, is_video from facefusion.filesystem import is_file, resolve_relative_path, is_image, is_video
@ -23,7 +23,6 @@ from facefusion.processors.frame import globals as frame_processors_globals
from facefusion.processors.frame import choices as frame_processors_choices from facefusion.processors.frame import choices as frame_processors_choices
FRAME_PROCESSOR = None FRAME_PROCESSOR = None
THREAD_LOCK : threading.Lock = threading.Lock()
NAME = __name__.upper() NAME = __name__.upper()
MODELS : ModelSet =\ MODELS : ModelSet =\
{ {
@ -90,7 +89,7 @@ OPTIONS : Optional[OptionsWithModel] = None
def get_frame_processor() -> Any: def get_frame_processor() -> Any:
global FRAME_PROCESSOR global FRAME_PROCESSOR
with THREAD_LOCK: with thread_lock():
while process_manager.is_checking(): while process_manager.is_checking():
sleep(0.5) sleep(0.5)
if FRAME_PROCESSOR is None: if FRAME_PROCESSOR is None:
@ -185,6 +184,7 @@ def enhance_frame(temp_vision_frame : VisionFrame) -> VisionFrame:
tile_vision_frames, pad_width, pad_height = create_tile_frames(temp_vision_frame, size) tile_vision_frames, pad_width, pad_height = create_tile_frames(temp_vision_frame, size)
for index, tile_vision_frame in enumerate(tile_vision_frames): for index, tile_vision_frame in enumerate(tile_vision_frames):
with conditional_thread_semaphore(facefusion.globals.execution_providers):
tile_vision_frame = frame_processor.run(None, tile_vision_frame = frame_processor.run(None,
{ {
frame_processor.get_inputs()[0].name : prepare_tile_frame(tile_vision_frame) frame_processor.get_inputs()[0].name : prepare_tile_frame(tile_vision_frame)

View File

@ -1,7 +1,6 @@
from typing import Any, List, Literal, Optional from typing import Any, List, Literal, Optional
from argparse import ArgumentParser from argparse import ArgumentParser
from time import sleep from time import sleep
import threading
import cv2 import cv2
import numpy import numpy
import onnxruntime import onnxruntime
@ -16,6 +15,7 @@ from facefusion.face_helper import warp_face_by_face_landmark_5, warp_face_by_bo
from facefusion.face_store import get_reference_faces from facefusion.face_store import get_reference_faces
from facefusion.content_analyser import clear_content_analyser from facefusion.content_analyser import clear_content_analyser
from facefusion.normalizer import normalize_output_path from facefusion.normalizer import normalize_output_path
from facefusion.thread_helper import thread_lock, conditional_thread_semaphore
from facefusion.typing import Face, VisionFrame, UpdateProgress, ProcessMode, ModelSet, OptionsWithModel, AudioFrame, QueuePayload from facefusion.typing import Face, VisionFrame, UpdateProgress, ProcessMode, ModelSet, OptionsWithModel, AudioFrame, QueuePayload
from facefusion.filesystem import is_file, has_audio, resolve_relative_path from facefusion.filesystem import is_file, has_audio, resolve_relative_path
from facefusion.download import conditional_download, is_download_done from facefusion.download import conditional_download, is_download_done
@ -29,7 +29,6 @@ from facefusion.processors.frame import globals as frame_processors_globals
from facefusion.processors.frame import choices as frame_processors_choices from facefusion.processors.frame import choices as frame_processors_choices
FRAME_PROCESSOR = None FRAME_PROCESSOR = None
THREAD_LOCK : threading.Lock = threading.Lock()
NAME = __name__.upper() NAME = __name__.upper()
MODELS : ModelSet =\ MODELS : ModelSet =\
{ {
@ -45,7 +44,7 @@ OPTIONS : Optional[OptionsWithModel] = None
def get_frame_processor() -> Any: def get_frame_processor() -> Any:
global FRAME_PROCESSOR global FRAME_PROCESSOR
with THREAD_LOCK: with thread_lock():
while process_manager.is_checking(): while process_manager.is_checking():
sleep(0.5) sleep(0.5)
if FRAME_PROCESSOR is None: if FRAME_PROCESSOR is None:
@ -155,6 +154,7 @@ def sync_lip(target_face : Face, temp_audio_frame : AudioFrame, temp_vision_fram
crop_mask_list.append(occlusion_mask) crop_mask_list.append(occlusion_mask)
close_vision_frame, close_matrix = warp_face_by_bounding_box(crop_vision_frame, bounding_box, (96, 96)) close_vision_frame, close_matrix = warp_face_by_bounding_box(crop_vision_frame, bounding_box, (96, 96))
close_vision_frame = prepare_crop_frame(close_vision_frame) close_vision_frame = prepare_crop_frame(close_vision_frame)
with conditional_thread_semaphore(facefusion.globals.execution_providers):
close_vision_frame = frame_processor.run(None, close_vision_frame = frame_processor.run(None,
{ {
'source': temp_audio_frame, 'source': temp_audio_frame,

View File

@ -0,0 +1,21 @@
from typing import List, Union, ContextManager
import threading
from contextlib import nullcontext
THREAD_LOCK : threading.Lock = threading.Lock()
THREAD_SEMAPHORE : threading.Semaphore = threading.Semaphore()
NULL_CONTEXT : ContextManager[None] = nullcontext()
def thread_lock() -> threading.Lock:
return THREAD_LOCK
def thread_semaphore() -> threading.Semaphore:
return THREAD_SEMAPHORE
def conditional_thread_semaphore(execution_providers : List[str]) -> Union[threading.Semaphore, ContextManager[None]]:
if 'DmlExecutionProvider' in execution_providers:
return THREAD_SEMAPHORE
return NULL_CONTEXT

View File

@ -100,8 +100,7 @@ ExecutionDeviceFramework = TypedDict('ExecutionDeviceFramework',
ExecutionDeviceProduct = TypedDict('ExecutionDeviceProduct', ExecutionDeviceProduct = TypedDict('ExecutionDeviceProduct',
{ {
'vendor' : str, 'vendor' : str,
'name' : str, 'name' : str
'architecture' : str,
}) })
ExecutionDeviceVideoMemory = TypedDict('ExecutionDeviceVideoMemory', ExecutionDeviceVideoMemory = TypedDict('ExecutionDeviceVideoMemory',
{ {

View File

@ -14,6 +14,7 @@ FACE_ENHANCER_BLEND_SLIDER : Optional[gradio.Slider] = None
FACE_SWAPPER_MODEL_DROPDOWN : Optional[gradio.Dropdown] = None FACE_SWAPPER_MODEL_DROPDOWN : Optional[gradio.Dropdown] = None
FRAME_COLORIZER_MODEL_DROPDOWN : Optional[gradio.Dropdown] = None FRAME_COLORIZER_MODEL_DROPDOWN : Optional[gradio.Dropdown] = None
FRAME_COLORIZER_BLEND_SLIDER : Optional[gradio.Slider] = None FRAME_COLORIZER_BLEND_SLIDER : Optional[gradio.Slider] = None
FRAME_COLORIZER_SIZE_DROPDOWN : Optional[gradio.Dropdown] = None
FRAME_ENHANCER_MODEL_DROPDOWN : Optional[gradio.Dropdown] = None FRAME_ENHANCER_MODEL_DROPDOWN : Optional[gradio.Dropdown] = None
FRAME_ENHANCER_BLEND_SLIDER : Optional[gradio.Slider] = None FRAME_ENHANCER_BLEND_SLIDER : Optional[gradio.Slider] = None
LIP_SYNCER_MODEL_DROPDOWN : Optional[gradio.Dropdown] = None LIP_SYNCER_MODEL_DROPDOWN : Optional[gradio.Dropdown] = None
@ -26,6 +27,7 @@ def render() -> None:
global FACE_SWAPPER_MODEL_DROPDOWN global FACE_SWAPPER_MODEL_DROPDOWN
global FRAME_COLORIZER_MODEL_DROPDOWN global FRAME_COLORIZER_MODEL_DROPDOWN
global FRAME_COLORIZER_BLEND_SLIDER global FRAME_COLORIZER_BLEND_SLIDER
global FRAME_COLORIZER_SIZE_DROPDOWN
global FRAME_ENHANCER_MODEL_DROPDOWN global FRAME_ENHANCER_MODEL_DROPDOWN
global FRAME_ENHANCER_BLEND_SLIDER global FRAME_ENHANCER_BLEND_SLIDER
global LIP_SYNCER_MODEL_DROPDOWN global LIP_SYNCER_MODEL_DROPDOWN
@ -70,6 +72,12 @@ def render() -> None:
maximum = frame_processors_choices.frame_colorizer_blend_range[-1], maximum = frame_processors_choices.frame_colorizer_blend_range[-1],
visible = 'frame_colorizer' in facefusion.globals.frame_processors visible = 'frame_colorizer' in facefusion.globals.frame_processors
) )
FRAME_COLORIZER_SIZE_DROPDOWN = gradio.Dropdown(
label = wording.get('uis.frame_colorizer_size_dropdown'),
choices = frame_processors_choices.frame_colorizer_sizes,
value = frame_processors_globals.frame_colorizer_size,
visible = 'frame_colorizer' in facefusion.globals.frame_processors
)
FRAME_ENHANCER_MODEL_DROPDOWN = gradio.Dropdown( FRAME_ENHANCER_MODEL_DROPDOWN = gradio.Dropdown(
label = wording.get('uis.frame_enhancer_model_dropdown'), label = wording.get('uis.frame_enhancer_model_dropdown'),
choices = frame_processors_choices.frame_enhancer_models, choices = frame_processors_choices.frame_enhancer_models,
@ -96,6 +104,7 @@ def render() -> None:
register_ui_component('face_swapper_model_dropdown', FACE_SWAPPER_MODEL_DROPDOWN) register_ui_component('face_swapper_model_dropdown', FACE_SWAPPER_MODEL_DROPDOWN)
register_ui_component('frame_colorizer_model_dropdown', FRAME_COLORIZER_MODEL_DROPDOWN) register_ui_component('frame_colorizer_model_dropdown', FRAME_COLORIZER_MODEL_DROPDOWN)
register_ui_component('frame_colorizer_blend_slider', FRAME_COLORIZER_BLEND_SLIDER) register_ui_component('frame_colorizer_blend_slider', FRAME_COLORIZER_BLEND_SLIDER)
register_ui_component('frame_colorizer_size_dropdown', FRAME_COLORIZER_SIZE_DROPDOWN)
register_ui_component('frame_enhancer_model_dropdown', FRAME_ENHANCER_MODEL_DROPDOWN) register_ui_component('frame_enhancer_model_dropdown', FRAME_ENHANCER_MODEL_DROPDOWN)
register_ui_component('frame_enhancer_blend_slider', FRAME_ENHANCER_BLEND_SLIDER) register_ui_component('frame_enhancer_blend_slider', FRAME_ENHANCER_BLEND_SLIDER)
register_ui_component('lip_syncer_model_dropdown', LIP_SYNCER_MODEL_DROPDOWN) register_ui_component('lip_syncer_model_dropdown', LIP_SYNCER_MODEL_DROPDOWN)
@ -108,22 +117,23 @@ def listen() -> None:
FACE_SWAPPER_MODEL_DROPDOWN.change(update_face_swapper_model, inputs = FACE_SWAPPER_MODEL_DROPDOWN, outputs = FACE_SWAPPER_MODEL_DROPDOWN) FACE_SWAPPER_MODEL_DROPDOWN.change(update_face_swapper_model, inputs = FACE_SWAPPER_MODEL_DROPDOWN, outputs = FACE_SWAPPER_MODEL_DROPDOWN)
FRAME_COLORIZER_MODEL_DROPDOWN.change(update_frame_colorizer_model, inputs = FRAME_COLORIZER_MODEL_DROPDOWN, outputs = FRAME_COLORIZER_MODEL_DROPDOWN) FRAME_COLORIZER_MODEL_DROPDOWN.change(update_frame_colorizer_model, inputs = FRAME_COLORIZER_MODEL_DROPDOWN, outputs = FRAME_COLORIZER_MODEL_DROPDOWN)
FRAME_COLORIZER_BLEND_SLIDER.release(update_frame_colorizer_blend, inputs = FRAME_COLORIZER_BLEND_SLIDER) FRAME_COLORIZER_BLEND_SLIDER.release(update_frame_colorizer_blend, inputs = FRAME_COLORIZER_BLEND_SLIDER)
FRAME_COLORIZER_SIZE_DROPDOWN.change(update_frame_colorizer_size, inputs = FRAME_COLORIZER_SIZE_DROPDOWN, outputs = FRAME_COLORIZER_SIZE_DROPDOWN)
FRAME_ENHANCER_MODEL_DROPDOWN.change(update_frame_enhancer_model, inputs = FRAME_ENHANCER_MODEL_DROPDOWN, outputs = FRAME_ENHANCER_MODEL_DROPDOWN) FRAME_ENHANCER_MODEL_DROPDOWN.change(update_frame_enhancer_model, inputs = FRAME_ENHANCER_MODEL_DROPDOWN, outputs = FRAME_ENHANCER_MODEL_DROPDOWN)
FRAME_ENHANCER_BLEND_SLIDER.release(update_frame_enhancer_blend, inputs = FRAME_ENHANCER_BLEND_SLIDER) FRAME_ENHANCER_BLEND_SLIDER.release(update_frame_enhancer_blend, inputs = FRAME_ENHANCER_BLEND_SLIDER)
LIP_SYNCER_MODEL_DROPDOWN.change(update_lip_syncer_model, inputs = LIP_SYNCER_MODEL_DROPDOWN, outputs = LIP_SYNCER_MODEL_DROPDOWN) LIP_SYNCER_MODEL_DROPDOWN.change(update_lip_syncer_model, inputs = LIP_SYNCER_MODEL_DROPDOWN, outputs = LIP_SYNCER_MODEL_DROPDOWN)
frame_processors_checkbox_group = get_ui_component('frame_processors_checkbox_group') frame_processors_checkbox_group = get_ui_component('frame_processors_checkbox_group')
if frame_processors_checkbox_group: if frame_processors_checkbox_group:
frame_processors_checkbox_group.change(update_frame_processors, inputs = frame_processors_checkbox_group, outputs = [ FACE_DEBUGGER_ITEMS_CHECKBOX_GROUP, FACE_ENHANCER_MODEL_DROPDOWN, FACE_ENHANCER_BLEND_SLIDER, FACE_SWAPPER_MODEL_DROPDOWN, FRAME_COLORIZER_MODEL_DROPDOWN, FRAME_COLORIZER_BLEND_SLIDER, FRAME_ENHANCER_MODEL_DROPDOWN, FRAME_ENHANCER_BLEND_SLIDER, LIP_SYNCER_MODEL_DROPDOWN ]) frame_processors_checkbox_group.change(update_frame_processors, inputs = frame_processors_checkbox_group, outputs = [ FACE_DEBUGGER_ITEMS_CHECKBOX_GROUP, FACE_ENHANCER_MODEL_DROPDOWN, FACE_ENHANCER_BLEND_SLIDER, FACE_SWAPPER_MODEL_DROPDOWN, FRAME_COLORIZER_MODEL_DROPDOWN, FRAME_COLORIZER_BLEND_SLIDER, FRAME_COLORIZER_SIZE_DROPDOWN, FRAME_ENHANCER_MODEL_DROPDOWN, FRAME_ENHANCER_BLEND_SLIDER, LIP_SYNCER_MODEL_DROPDOWN ])
def update_frame_processors(frame_processors : List[str]) -> Tuple[gradio.CheckboxGroup, gradio.Dropdown, gradio.Slider, gradio.Dropdown, gradio.Dropdown, gradio.Slider, gradio.Dropdown, gradio.Slider, gradio.Dropdown]: def update_frame_processors(frame_processors : List[str]) -> Tuple[gradio.CheckboxGroup, gradio.Dropdown, gradio.Slider, gradio.Dropdown, gradio.Dropdown, gradio.Slider, gradio.Dropdown, gradio.Dropdown, gradio.Slider, gradio.Dropdown]:
has_face_debugger = 'face_debugger' in frame_processors has_face_debugger = 'face_debugger' in frame_processors
has_face_enhancer = 'face_enhancer' in frame_processors has_face_enhancer = 'face_enhancer' in frame_processors
has_face_swapper = 'face_swapper' in frame_processors has_face_swapper = 'face_swapper' in frame_processors
has_frame_colorizer = 'frame_colorizer' in frame_processors has_frame_colorizer = 'frame_colorizer' in frame_processors
has_frame_enhancer = 'frame_enhancer' in frame_processors has_frame_enhancer = 'frame_enhancer' in frame_processors
has_lip_syncer = 'lip_syncer' in frame_processors has_lip_syncer = 'lip_syncer' in frame_processors
return gradio.CheckboxGroup(visible = has_face_debugger), gradio.Dropdown(visible = has_face_enhancer), gradio.Slider(visible = has_face_enhancer), gradio.Dropdown(visible = has_face_swapper), gradio.Dropdown(visible = has_frame_colorizer), gradio.Slider(visible = has_frame_colorizer), gradio.Dropdown(visible = has_frame_enhancer), gradio.Slider(visible = has_frame_enhancer), gradio.Dropdown(visible = has_lip_syncer) return gradio.CheckboxGroup(visible = has_face_debugger), gradio.Dropdown(visible = has_face_enhancer), gradio.Slider(visible = has_face_enhancer), gradio.Dropdown(visible = has_face_swapper), gradio.Dropdown(visible = has_frame_colorizer), gradio.Slider(visible = has_frame_colorizer), gradio.Dropdown(visible = has_frame_colorizer), gradio.Dropdown(visible = has_frame_enhancer), gradio.Slider(visible = has_frame_enhancer), gradio.Dropdown(visible = has_lip_syncer)
def update_face_debugger_items(face_debugger_items : List[FaceDebuggerItem]) -> None: def update_face_debugger_items(face_debugger_items : List[FaceDebuggerItem]) -> None:
@ -177,6 +187,11 @@ def update_frame_colorizer_blend(frame_colorizer_blend : int) -> None:
frame_processors_globals.frame_colorizer_blend = frame_colorizer_blend frame_processors_globals.frame_colorizer_blend = frame_colorizer_blend
def update_frame_colorizer_size(frame_colorizer_size : str) -> gradio.Dropdown:
frame_processors_globals.frame_colorizer_size = frame_colorizer_size
return gradio.Dropdown(value = frame_processors_globals.frame_colorizer_size)
def update_frame_enhancer_model(frame_enhancer_model : FrameEnhancerModel) -> gradio.Dropdown: def update_frame_enhancer_model(frame_enhancer_model : FrameEnhancerModel) -> gradio.Dropdown:
frame_processors_globals.frame_enhancer_model = frame_enhancer_model frame_processors_globals.frame_enhancer_model = frame_enhancer_model
frame_enhancer_module = load_frame_processor_module('frame_enhancer') frame_enhancer_module = load_frame_processor_module('frame_enhancer')

View File

@ -93,6 +93,7 @@ def listen() -> None:
for ui_component in get_ui_components( for ui_component in get_ui_components(
[ [
'face_debugger_items_checkbox_group', 'face_debugger_items_checkbox_group',
'frame_colorizer_size_dropdown',
'face_selector_mode_dropdown', 'face_selector_mode_dropdown',
'face_mask_types_checkbox_group', 'face_mask_types_checkbox_group',
'face_mask_region_checkbox_group', 'face_mask_region_checkbox_group',

View File

@ -36,6 +36,7 @@ ComponentName = Literal\
'face_swapper_model_dropdown', 'face_swapper_model_dropdown',
'frame_colorizer_model_dropdown', 'frame_colorizer_model_dropdown',
'frame_colorizer_blend_slider', 'frame_colorizer_blend_slider',
'frame_colorizer_size_dropdown',
'frame_enhancer_model_dropdown', 'frame_enhancer_model_dropdown',
'frame_enhancer_blend_slider', 'frame_enhancer_blend_slider',
'lip_syncer_model_dropdown', 'lip_syncer_model_dropdown',

View File

@ -1,20 +1,18 @@
from typing import Any, Tuple from typing import Any, Tuple
from time import sleep from time import sleep
import threading
import scipy import scipy
import numpy import numpy
import onnxruntime import onnxruntime
import facefusion.globals import facefusion.globals
from facefusion import process_manager from facefusion import process_manager
from facefusion.thread_helper import thread_lock, thread_semaphore
from facefusion.typing import ModelSet, AudioChunk, Audio from facefusion.typing import ModelSet, AudioChunk, Audio
from facefusion.execution import apply_execution_provider_options from facefusion.execution import apply_execution_provider_options
from facefusion.filesystem import resolve_relative_path, is_file from facefusion.filesystem import resolve_relative_path, is_file
from facefusion.download import conditional_download from facefusion.download import conditional_download
VOICE_EXTRACTOR = None VOICE_EXTRACTOR = None
THREAD_SEMAPHORE : threading.Semaphore = threading.Semaphore()
THREAD_LOCK : threading.Lock = threading.Lock()
MODELS : ModelSet =\ MODELS : ModelSet =\
{ {
'voice_extractor': 'voice_extractor':
@ -28,7 +26,7 @@ MODELS : ModelSet =\
def get_voice_extractor() -> Any: def get_voice_extractor() -> Any:
global VOICE_EXTRACTOR global VOICE_EXTRACTOR
with THREAD_LOCK: with thread_lock():
while process_manager.is_checking(): while process_manager.is_checking():
sleep(0.5) sleep(0.5)
if VOICE_EXTRACTOR is None: if VOICE_EXTRACTOR is None:
@ -73,7 +71,7 @@ def extract_voice(temp_audio_chunk : AudioChunk) -> AudioChunk:
trim_size = 3840 trim_size = 3840
temp_audio_chunk, pad_size = prepare_audio_chunk(temp_audio_chunk.T, chunk_size, trim_size) temp_audio_chunk, pad_size = prepare_audio_chunk(temp_audio_chunk.T, chunk_size, trim_size)
temp_audio_chunk = decompose_audio_chunk(temp_audio_chunk, trim_size) temp_audio_chunk = decompose_audio_chunk(temp_audio_chunk, trim_size)
with THREAD_SEMAPHORE: with thread_semaphore():
temp_audio_chunk = voice_extractor.run(None, temp_audio_chunk = voice_extractor.run(None,
{ {
voice_extractor.get_inputs()[0].name: temp_audio_chunk voice_extractor.get_inputs()[0].name: temp_audio_chunk

View File

@ -110,6 +110,7 @@ WORDING : Dict[str, Any] =\
'face_swapper_model': 'choose the model responsible for swapping the face', 'face_swapper_model': 'choose the model responsible for swapping the face',
'frame_colorizer_model': 'choose the model responsible for colorizing the frame', 'frame_colorizer_model': 'choose the model responsible for colorizing the frame',
'frame_colorizer_blend': 'blend the colorized into the previous frame', 'frame_colorizer_blend': 'blend the colorized into the previous frame',
'frame_colorizer_size': 'specify the size of the frame provided to the frame colorizer',
'frame_enhancer_model': 'choose the model responsible for enhancing the frame', 'frame_enhancer_model': 'choose the model responsible for enhancing the frame',
'frame_enhancer_blend': 'blend the enhanced into the previous frame', 'frame_enhancer_blend': 'blend the enhanced into the previous frame',
'lip_syncer_model': 'choose the model responsible for syncing the lips', 'lip_syncer_model': 'choose the model responsible for syncing the lips',
@ -166,6 +167,7 @@ WORDING : Dict[str, Any] =\
'face_swapper_model_dropdown': 'FACE SWAPPER MODEL', 'face_swapper_model_dropdown': 'FACE SWAPPER MODEL',
'frame_colorizer_model_dropdown': 'FRAME COLORIZER MODEL', 'frame_colorizer_model_dropdown': 'FRAME COLORIZER MODEL',
'frame_colorizer_blend_slider': 'FRAME COLORIZER BLEND', 'frame_colorizer_blend_slider': 'FRAME COLORIZER BLEND',
'frame_colorizer_size_dropdown': 'FRAME COLORIZER SIZE',
'frame_enhancer_model_dropdown': 'FRAME ENHANCER MODEL', 'frame_enhancer_model_dropdown': 'FRAME ENHANCER MODEL',
'frame_enhancer_blend_slider': 'FRAME ENHANCER BLEND', 'frame_enhancer_blend_slider': 'FRAME ENHANCER BLEND',
'lip_syncer_model_dropdown': 'LIP SYNCER MODEL', 'lip_syncer_model_dropdown': 'LIP SYNCER MODEL',

View File

@ -1,80 +0,0 @@
import subprocess
import sys
import pytest
from facefusion.download import conditional_download
@pytest.fixture(scope = 'module', autouse = True)
def before_all() -> None:
conditional_download('.assets/examples',
[
'https://github.com/facefusion/facefusion-assets/releases/download/examples/source.jpg',
'https://github.com/facefusion/facefusion-assets/releases/download/examples/source.mp3',
'https://github.com/facefusion/facefusion-assets/releases/download/examples/target-240p.mp4'
])
subprocess.run([ 'ffmpeg', '-i', '.assets/examples/target-240p.mp4', '-vframes', '1', '.assets/examples/target-240p.jpg' ])
def test_debug_face_to_image() -> None:
commands = [ sys.executable, 'run.py', '--frame-processors', 'face_debugger', '-t', '.assets/examples/target-240p.jpg', '-o', '.assets/examples/test_debug_face_to_image.jpg', '--headless' ]
run = subprocess.run(commands, stdout = subprocess.PIPE, stderr = subprocess.STDOUT)
assert run.returncode == 0
assert 'image succeed' in run.stdout.decode()
def test_debug_face_to_video() -> None:
commands = [ sys.executable, 'run.py', '--frame-processors', 'face_debugger', '-t', '.assets/examples/target-240p.mp4', '-o', '.assets/examples/test_debug_face_to_video.mp4', '--trim-frame-end', '10', '--headless' ]
run = subprocess.run(commands, stdout = subprocess.PIPE, stderr = subprocess.STDOUT)
assert run.returncode == 0
assert 'video succeed' in run.stdout.decode()
def test_enhance_face_to_image() -> None:
commands = [ sys.executable, 'run.py', '--frame-processors', 'face_enhancer', '-t', '.assets/examples/target-240p.jpg', '-o', '.assets/examples/test_enhance_face_to_image.jpg', '--headless' ]
run = subprocess.run(commands, stdout = subprocess.PIPE, stderr = subprocess.STDOUT)
assert run.returncode == 0
assert 'image succeed' in run.stdout.decode()
def test_enhance_face_to_video() -> None:
commands = [ sys.executable, 'run.py', '--frame-processors', 'face_enhancer', '-t', '.assets/examples/target-240p.mp4', '-o', '.assets/examples/test_enhance_face_to_video.mp4', '--trim-frame-end', '10', '--headless' ]
run = subprocess.run(commands, stdout = subprocess.PIPE, stderr = subprocess.STDOUT)
assert run.returncode == 0
assert 'video succeed' in run.stdout.decode()
def test_swap_face_to_image() -> None:
commands = [ sys.executable, 'run.py', '--frame-processors', 'face_swapper', '-s', '.assets/examples/source.jpg', '-t', '.assets/examples/target-240p.jpg', '-o', '.assets/examples/test_swap_face_to_image.jpg', '--headless' ]
run = subprocess.run(commands, stdout = subprocess.PIPE, stderr = subprocess.STDOUT)
assert run.returncode == 0
assert 'image succeed' in run.stdout.decode()
def test_swap_face_to_video() -> None:
commands = [ sys.executable, 'run.py', '--frame-processors', 'face_swapper', '-s', '.assets/examples/source.jpg', '-t', '.assets/examples/target-240p.mp4', '-o', '.assets/examples/test_swap_face_to_video.mp4', '--trim-frame-end', '10', '--headless' ]
run = subprocess.run(commands, stdout = subprocess.PIPE, stderr = subprocess.STDOUT)
assert run.returncode == 0
assert 'video succeed' in run.stdout.decode()
def test_sync_lip_to_image() -> None:
commands = [ sys.executable, 'run.py', '--frame-processors', 'lip_syncer', '-s', '.assets/examples/source.mp3', '-t', '.assets/examples/target-240p.jpg', '-o', '.assets/examples/test_sync_lip_to_image.jpg', '--headless' ]
run = subprocess.run(commands, stdout = subprocess.PIPE, stderr = subprocess.STDOUT)
assert run.returncode == 0
assert 'image succeed' in run.stdout.decode()
def test_sync_lip_to_video() -> None:
commands = [ sys.executable, 'run.py', '--frame-processors', 'lip_syncer', '-s', '.assets/examples/source.mp3', '-t', '.assets/examples/target-240p.mp4', '-o', '.assets/examples/test_sync_lip_to_video.mp4', '--trim-frame-end', '10', '--headless' ]
run = subprocess.run(commands, stdout = subprocess.PIPE, stderr = subprocess.STDOUT)
assert run.returncode == 0
assert 'video succeed' in run.stdout.decode()

View File

@ -0,0 +1,31 @@
import subprocess
import sys
import pytest
from facefusion.download import conditional_download
@pytest.fixture(scope = 'module', autouse = True)
def before_all() -> None:
conditional_download('.assets/examples',
[
'https://github.com/facefusion/facefusion-assets/releases/download/examples/source.jpg',
'https://github.com/facefusion/facefusion-assets/releases/download/examples/target-240p.mp4'
])
subprocess.run([ 'ffmpeg', '-i', '.assets/examples/target-240p.mp4', '-vframes', '1', '.assets/examples/target-240p.jpg' ])
def test_debug_face_to_image() -> None:
commands = [ sys.executable, 'run.py', '--frame-processors', 'face_debugger', '-t', '.assets/examples/target-240p.jpg', '-o', '.assets/examples/test_debug_face_to_image.jpg', '--headless' ]
run = subprocess.run(commands, stdout = subprocess.PIPE, stderr = subprocess.STDOUT)
assert run.returncode == 0
assert 'image succeed' in run.stdout.decode()
def test_debug_face_to_video() -> None:
commands = [ sys.executable, 'run.py', '--frame-processors', 'face_debugger', '-t', '.assets/examples/target-240p.mp4', '-o', '.assets/examples/test_debug_face_to_video.mp4', '--trim-frame-end', '10', '--headless' ]
run = subprocess.run(commands, stdout = subprocess.PIPE, stderr = subprocess.STDOUT)
assert run.returncode == 0
assert 'video succeed' in run.stdout.decode()

View File

@ -0,0 +1,32 @@
import subprocess
import sys
import pytest
from facefusion.download import conditional_download
@pytest.fixture(scope = 'module', autouse = True)
def before_all() -> None:
conditional_download('.assets/examples',
[
'https://github.com/facefusion/facefusion-assets/releases/download/examples/source.jpg',
'https://github.com/facefusion/facefusion-assets/releases/download/examples/target-240p.mp4'
])
subprocess.run([ 'ffmpeg', '-i', '.assets/examples/target-240p.mp4', '-vframes', '1', '.assets/examples/target-240p.jpg' ])
def test_enhance_face_to_image() -> None:
commands = [ sys.executable, 'run.py', '--frame-processors', 'face_enhancer', '-t', '.assets/examples/target-240p.jpg', '-o', '.assets/examples/test_enhance_face_to_image.jpg', '--headless' ]
run = subprocess.run(commands, stdout = subprocess.PIPE, stderr = subprocess.STDOUT)
assert run.returncode == 0
assert 'image succeed' in run.stdout.decode()
def test_enhance_face_to_video() -> None:
commands = [ sys.executable, 'run.py', '--frame-processors', 'face_enhancer', '-t', '.assets/examples/target-240p.mp4', '-o', '.assets/examples/test_enhance_face_to_video.mp4', '--trim-frame-end', '10', '--headless' ]
run = subprocess.run(commands, stdout = subprocess.PIPE, stderr = subprocess.STDOUT)
assert run.returncode == 0
assert 'video succeed' in run.stdout.decode()

View File

@ -0,0 +1,31 @@
import subprocess
import sys
import pytest
from facefusion.download import conditional_download
@pytest.fixture(scope = 'module', autouse = True)
def before_all() -> None:
conditional_download('.assets/examples',
[
'https://github.com/facefusion/facefusion-assets/releases/download/examples/source.jpg',
'https://github.com/facefusion/facefusion-assets/releases/download/examples/target-240p.mp4'
])
subprocess.run([ 'ffmpeg', '-i', '.assets/examples/target-240p.mp4', '-vframes', '1', '.assets/examples/target-240p.jpg' ])
def test_swap_face_to_image() -> None:
commands = [ sys.executable, 'run.py', '--frame-processors', 'face_swapper', '-s', '.assets/examples/source.jpg', '-t', '.assets/examples/target-240p.jpg', '-o', '.assets/examples/test_swap_face_to_image.jpg', '--headless' ]
run = subprocess.run(commands, stdout = subprocess.PIPE, stderr = subprocess.STDOUT)
assert run.returncode == 0
assert 'image succeed' in run.stdout.decode()
def test_swap_face_to_video() -> None:
commands = [ sys.executable, 'run.py', '--frame-processors', 'face_swapper', '-s', '.assets/examples/source.jpg', '-t', '.assets/examples/target-240p.mp4', '-o', '.assets/examples/test_swap_face_to_video.mp4', '--trim-frame-end', '10', '--headless' ]
run = subprocess.run(commands, stdout = subprocess.PIPE, stderr = subprocess.STDOUT)
assert run.returncode == 0
assert 'video succeed' in run.stdout.decode()

View File

@ -0,0 +1,32 @@
import subprocess
import sys
import pytest
from facefusion.download import conditional_download
@pytest.fixture(scope = 'module', autouse = True)
def before_all() -> None:
conditional_download('.assets/examples',
[
'https://github.com/facefusion/facefusion-assets/releases/download/examples/source.jpg',
'https://github.com/facefusion/facefusion-assets/releases/download/examples/target-240p.mp4'
])
subprocess.run([ 'ffmpeg', '-i', '.assets/examples/target-240p.mp4', '-vframes', '1', '-vf', 'hue=s=0', '.assets/examples/target-240p-0sat.jpg' ])
subprocess.run([ 'ffmpeg', '-i', '.assets/examples/target-240p.mp4', '-vf', 'hue=s=0', '.assets/examples/target-240p-0sat.mp4' ])
def test_colorize_frame_to_image() -> None:
commands = [ sys.executable, 'run.py', '--frame-processors', 'frame_colorizer', '-t', '.assets/examples/target-240p-0sat.jpg', '-o', '.assets/examples/test_colorize_frame_to_image.jpg', '--headless' ]
run = subprocess.run(commands, stdout = subprocess.PIPE, stderr = subprocess.STDOUT)
assert run.returncode == 0
assert 'image succeed' in run.stdout.decode()
def test_colorize_frame_to_video() -> None:
commands = [ sys.executable, 'run.py', '--frame-processors', 'frame_colorizer', '-t', '.assets/examples/target-240p-0sat.mp4', '-o', '.assets/examples/test_colorize_frame_to_video.mp4', '--trim-frame-end', '10', '--headless' ]
run = subprocess.run(commands, stdout = subprocess.PIPE, stderr = subprocess.STDOUT)
assert run.returncode == 0
assert 'video succeed' in run.stdout.decode()

View File

@ -0,0 +1,31 @@
import subprocess
import sys
import pytest
from facefusion.download import conditional_download
@pytest.fixture(scope = 'module', autouse = True)
def before_all() -> None:
conditional_download('.assets/examples',
[
'https://github.com/facefusion/facefusion-assets/releases/download/examples/source.jpg',
'https://github.com/facefusion/facefusion-assets/releases/download/examples/target-240p.mp4'
])
subprocess.run([ 'ffmpeg', '-i', '.assets/examples/target-240p.mp4', '-vframes', '1', '.assets/examples/target-240p.jpg' ])
def test_enhance_frame_to_image() -> None:
commands = [ sys.executable, 'run.py', '--frame-processors', 'frame_enhancer', '-t', '.assets/examples/target-240p.jpg', '-o', '.assets/examples/test_enhance_frame_to_image.jpg', '--headless' ]
run = subprocess.run(commands, stdout = subprocess.PIPE, stderr = subprocess.STDOUT)
assert run.returncode == 0
assert 'image succeed' in run.stdout.decode()
def test_enhance_frame_to_video() -> None:
commands = [ sys.executable, 'run.py', '--frame-processors', 'frame_enhancer', '-t', '.assets/examples/target-240p.mp4', '-o', '.assets/examples/test_enhance_frame_to_video.mp4', '--trim-frame-end', '10', '--headless' ]
run = subprocess.run(commands, stdout = subprocess.PIPE, stderr = subprocess.STDOUT)
assert run.returncode == 0
assert 'video succeed' in run.stdout.decode()

View File

@ -0,0 +1,32 @@
import subprocess
import sys
import pytest
from facefusion.download import conditional_download
@pytest.fixture(scope = 'module', autouse = True)
def before_all() -> None:
conditional_download('.assets/examples',
[
'https://github.com/facefusion/facefusion-assets/releases/download/examples/source.jpg',
'https://github.com/facefusion/facefusion-assets/releases/download/examples/source.mp3',
'https://github.com/facefusion/facefusion-assets/releases/download/examples/target-240p.mp4'
])
subprocess.run([ 'ffmpeg', '-i', '.assets/examples/target-240p.mp4', '-vframes', '1', '.assets/examples/target-240p.jpg' ])
def test_sync_lip_to_image() -> None:
commands = [ sys.executable, 'run.py', '--frame-processors', 'lip_syncer', '-s', '.assets/examples/source.mp3', '-t', '.assets/examples/target-240p.jpg', '-o', '.assets/examples/test_sync_lip_to_image.jpg', '--headless' ]
run = subprocess.run(commands, stdout = subprocess.PIPE, stderr = subprocess.STDOUT)
assert run.returncode == 0
assert 'image succeed' in run.stdout.decode()
def test_sync_lip_to_video() -> None:
commands = [ sys.executable, 'run.py', '--frame-processors', 'lip_syncer', '-s', '.assets/examples/source.mp3', '-t', '.assets/examples/target-240p.mp4', '-o', '.assets/examples/test_sync_lip_to_video.mp4', '--trim-frame-end', '10', '--headless' ]
run = subprocess.run(commands, stdout = subprocess.PIPE, stderr = subprocess.STDOUT)
assert run.returncode == 0
assert 'video succeed' in run.stdout.decode()