From cb51775d99ee025c98e3038c2759203a3a4eee52 Mon Sep 17 00:00:00 2001 From: henryruhs Date: Wed, 20 Nov 2024 19:27:59 +0100 Subject: [PATCH] Introduce create_static_model_set() everywhere --- facefusion/content_analyser.py | 54 ++++---- facefusion/core.py | 19 +++ facefusion/face_classifier.py | 48 +++---- facefusion/face_detector.py | 105 ++++++++-------- facefusion/face_landmarker.py | 119 +++++++++--------- facefusion/face_masker.py | 103 +++++++-------- facefusion/face_recognizer.py | 44 ++++--- facefusion/processors/modules/age_modifier.py | 6 +- facefusion/processors/modules/deep_swapper.py | 6 +- .../processors/modules/expression_restorer.py | 6 +- facefusion/processors/modules/face_editor.py | 6 +- .../processors/modules/face_enhancer.py | 6 +- .../processors/modules/frame_colorizer.py | 6 +- .../processors/modules/frame_enhancer.py | 6 +- facefusion/processors/modules/lip_syncer.py | 6 +- facefusion/uis/components/download.py | 2 + facefusion/voice_extractor.py | 36 +++--- 17 files changed, 323 insertions(+), 255 deletions(-) diff --git a/facefusion/content_analyser.py b/facefusion/content_analyser.py index bc6fcf4b..2f1853bc 100644 --- a/facefusion/content_analyser.py +++ b/facefusion/content_analyser.py @@ -11,35 +11,39 @@ from facefusion.thread_helper import conditional_thread_semaphore from facefusion.typing import Fps, InferencePool, ModelOptions, ModelSet, VisionFrame from facefusion.vision import count_video_frame_total, detect_video_fps, get_video_frame, read_image -MODEL_SET : ModelSet =\ -{ - 'open_nsfw': - { - 'hashes': - { - 'content_analyser': - { - 'url': 'https://github.com/facefusion/facefusion-assets/releases/download/models-3.0.0/open_nsfw.hash', - 'path': resolve_relative_path('../.assets/models/open_nsfw.hash') - } - }, - 'sources': - { - 'content_analyser': - { - 'url': 'https://github.com/facefusion/facefusion-assets/releases/download/models-3.0.0/open_nsfw.onnx', - 'path': resolve_relative_path('../.assets/models/open_nsfw.onnx') - } - }, - 'size': (224, 224), - 'mean': [ 104, 117, 123 ] - } -} PROBABILITY_LIMIT = 0.80 RATE_LIMIT = 10 STREAM_COUNTER = 0 +@lru_cache(maxsize = None) +def create_static_model_set() -> ModelSet: + return\ + { + 'open_nsfw': + { + 'hashes': + { + 'content_analyser': + { + 'url': 'https://github.com/facefusion/facefusion-assets/releases/download/models-3.0.0/open_nsfw.hash', + 'path': resolve_relative_path('../.assets/models/open_nsfw.hash') + } + }, + 'sources': + { + 'content_analyser': + { + 'url': 'https://github.com/facefusion/facefusion-assets/releases/download/models-3.0.0/open_nsfw.onnx', + 'path': resolve_relative_path('../.assets/models/open_nsfw.onnx') + } + }, + 'size': (224, 224), + 'mean': [ 104, 117, 123 ] + } + } + + def get_inference_pool() -> InferencePool: model_sources = get_model_options().get('sources') return inference_manager.get_inference_pool(__name__, model_sources) @@ -50,7 +54,7 @@ def clear_inference_pool() -> None: def get_model_options() -> ModelOptions: - return MODEL_SET.get('open_nsfw') + return create_static_model_set().get('open_nsfw') def pre_check() -> bool: diff --git a/facefusion/core.py b/facefusion/core.py index 26efc286..739508c2 100755 --- a/facefusion/core.py +++ b/facefusion/core.py @@ -157,6 +157,25 @@ def conditional_append_reference_faces() -> None: append_reference_face(processor_module.__name__, abstract_reference_face) +def clear_model_sets() -> None: + available_processors = list_directory('facefusion/processors/modules') + common_modules =\ + [ + content_analyser, + face_classifier, + face_detector, + face_landmarker, + face_recognizer, + face_masker, + voice_extractor + ] + processor_modules = get_processors_modules(available_processors) + + for module in common_modules + processor_modules: + if hasattr(module, 'create_static_model_set'): + module.create_static_model_set.cache_clear() + + def force_download() -> ErrorCode: available_processors = list_directory('facefusion/processors/modules') common_modules =\ diff --git a/facefusion/face_classifier.py b/facefusion/face_classifier.py index aa72e609..eb011fcb 100644 --- a/facefusion/face_classifier.py +++ b/facefusion/face_classifier.py @@ -1,3 +1,4 @@ +from functools import lru_cache from typing import List, Tuple import numpy @@ -9,32 +10,35 @@ from facefusion.filesystem import resolve_relative_path from facefusion.thread_helper import conditional_thread_semaphore from facefusion.typing import Age, FaceLandmark5, Gender, InferencePool, ModelOptions, ModelSet, Race, VisionFrame -MODEL_SET : ModelSet =\ -{ - 'fairface': + +@lru_cache(maxsize = None) +def create_static_model_set() -> ModelSet: + return\ { - 'hashes': + 'fairface': { - 'face_classifier': + 'hashes': { - 'url': 'https://github.com/facefusion/facefusion-assets/releases/download/models-3.0.0/fairface.hash', - 'path': resolve_relative_path('../.assets/models/fairface.hash') - } - }, - 'sources': - { - 'face_classifier': + 'face_classifier': + { + 'url': 'https://github.com/facefusion/facefusion-assets/releases/download/models-3.0.0/fairface.hash', + 'path': resolve_relative_path('../.assets/models/fairface.hash') + } + }, + 'sources': { - 'url': 'https://github.com/facefusion/facefusion-assets/releases/download/models-3.0.0/fairface.onnx', - 'path': resolve_relative_path('../.assets/models/fairface.onnx') - } - }, - 'template': 'arcface_112_v2', - 'size': (224, 224), - 'mean': [ 0.485, 0.456, 0.406 ], - 'standard_deviation': [ 0.229, 0.224, 0.225 ] + 'face_classifier': + { + 'url': 'https://github.com/facefusion/facefusion-assets/releases/download/models-3.0.0/fairface.onnx', + 'path': resolve_relative_path('../.assets/models/fairface.onnx') + } + }, + 'template': 'arcface_112_v2', + 'size': (224, 224), + 'mean': [ 0.485, 0.456, 0.406 ], + 'standard_deviation': [ 0.229, 0.224, 0.225 ] + } } -} def get_inference_pool() -> InferencePool: @@ -47,7 +51,7 @@ def clear_inference_pool() -> None: def get_model_options() -> ModelOptions: - return MODEL_SET.get('fairface') + return create_static_model_set().get('fairface') def pre_check() -> bool: diff --git a/facefusion/face_detector.py b/facefusion/face_detector.py index 3e0b7f3c..c66dc75d 100644 --- a/facefusion/face_detector.py +++ b/facefusion/face_detector.py @@ -2,6 +2,7 @@ from typing import List, Tuple import cv2 import numpy +from charset_normalizer.md import lru_cache from facefusion import inference_manager, state_manager from facefusion.download import conditional_download_hashes, conditional_download_sources @@ -11,66 +12,69 @@ from facefusion.thread_helper import thread_semaphore from facefusion.typing import Angle, BoundingBox, Detection, DownloadSet, FaceLandmark5, InferencePool, ModelSet, Score, VisionFrame from facefusion.vision import resize_frame_resolution, unpack_resolution -MODEL_SET : ModelSet =\ -{ - 'retinaface': + +@lru_cache(maxsize = None) +def create_static_model_set() -> ModelSet: + return\ { - 'hashes': + 'retinaface': { - 'retinaface': + 'hashes': { - 'url': 'https://github.com/facefusion/facefusion-assets/releases/download/models-3.0.0/retinaface_10g.hash', - 'path': resolve_relative_path('../.assets/models/retinaface_10g.hash') + 'retinaface': + { + 'url': 'https://github.com/facefusion/facefusion-assets/releases/download/models-3.0.0/retinaface_10g.hash', + 'path': resolve_relative_path('../.assets/models/retinaface_10g.hash') + } + }, + 'sources': + { + 'retinaface': + { + 'url': 'https://github.com/facefusion/facefusion-assets/releases/download/models-3.0.0/retinaface_10g.onnx', + 'path': resolve_relative_path('../.assets/models/retinaface_10g.onnx') + } } }, - 'sources': + 'scrfd': { - 'retinaface': + 'hashes': { - 'url': 'https://github.com/facefusion/facefusion-assets/releases/download/models-3.0.0/retinaface_10g.onnx', - 'path': resolve_relative_path('../.assets/models/retinaface_10g.onnx') - } - } - }, - 'scrfd': - { - 'hashes': - { - 'scrfd': + 'scrfd': + { + 'url': 'https://github.com/facefusion/facefusion-assets/releases/download/models-3.0.0/scrfd_2.5g.hash', + 'path': resolve_relative_path('../.assets/models/scrfd_2.5g.hash') + } + }, + 'sources': { - 'url': 'https://github.com/facefusion/facefusion-assets/releases/download/models-3.0.0/scrfd_2.5g.hash', - 'path': resolve_relative_path('../.assets/models/scrfd_2.5g.hash') + 'scrfd': + { + 'url': 'https://github.com/facefusion/facefusion-assets/releases/download/models-3.0.0/scrfd_2.5g.onnx', + 'path': resolve_relative_path('../.assets/models/scrfd_2.5g.onnx') + } } }, - 'sources': + 'yoloface': { - 'scrfd': + 'hashes': { - 'url': 'https://github.com/facefusion/facefusion-assets/releases/download/models-3.0.0/scrfd_2.5g.onnx', - 'path': resolve_relative_path('../.assets/models/scrfd_2.5g.onnx') - } - } - }, - 'yoloface': - { - 'hashes': - { - 'yoloface': + 'yoloface': + { + 'url': 'https://github.com/facefusion/facefusion-assets/releases/download/models-3.0.0/yoloface_8n.hash', + 'path': resolve_relative_path('../.assets/models/yoloface_8n.hash') + } + }, + 'sources': { - 'url': 'https://github.com/facefusion/facefusion-assets/releases/download/models-3.0.0/yoloface_8n.hash', - 'path': resolve_relative_path('../.assets/models/yoloface_8n.hash') - } - }, - 'sources': - { - 'yoloface': - { - 'url': 'https://github.com/facefusion/facefusion-assets/releases/download/models-3.0.0/yoloface_8n.onnx', - 'path': resolve_relative_path('../.assets/models/yoloface_8n.onnx') + 'yoloface': + { + 'url': 'https://github.com/facefusion/facefusion-assets/releases/download/models-3.0.0/yoloface_8n.onnx', + 'path': resolve_relative_path('../.assets/models/yoloface_8n.onnx') + } } } } -} def get_inference_pool() -> InferencePool: @@ -87,16 +91,17 @@ def clear_inference_pool() -> None: def collect_model_downloads() -> Tuple[DownloadSet, DownloadSet]: model_hashes = {} model_sources = {} + model_set = create_static_model_set() if state_manager.get_item('face_detector_model') in [ 'many', 'retinaface' ]: - model_hashes['retinaface'] = MODEL_SET.get('retinaface').get('hashes').get('retinaface') - model_sources['retinaface'] = MODEL_SET.get('retinaface').get('sources').get('retinaface') + model_hashes['retinaface'] = model_set.get('retinaface').get('hashes').get('retinaface') + model_sources['retinaface'] = model_set.get('retinaface').get('sources').get('retinaface') if state_manager.get_item('face_detector_model') in [ 'many', 'scrfd' ]: - model_hashes['scrfd'] = MODEL_SET.get('scrfd').get('hashes').get('scrfd') - model_sources['scrfd'] = MODEL_SET.get('scrfd').get('sources').get('scrfd') + model_hashes['scrfd'] = model_set.get('scrfd').get('hashes').get('scrfd') + model_sources['scrfd'] = model_set.get('scrfd').get('sources').get('scrfd') if state_manager.get_item('face_detector_model') in [ 'many', 'yoloface' ]: - model_hashes['yoloface'] = MODEL_SET.get('yoloface').get('hashes').get('yoloface') - model_sources['yoloface'] = MODEL_SET.get('yoloface').get('sources').get('yoloface') + model_hashes['yoloface'] = model_set.get('yoloface').get('hashes').get('yoloface') + model_sources['yoloface'] = model_set.get('yoloface').get('sources').get('yoloface') return model_hashes, model_sources diff --git a/facefusion/face_landmarker.py b/facefusion/face_landmarker.py index aaf44047..5c69a23a 100644 --- a/facefusion/face_landmarker.py +++ b/facefusion/face_landmarker.py @@ -1,3 +1,4 @@ +from functools import lru_cache from typing import Tuple import cv2 @@ -10,68 +11,71 @@ from facefusion.filesystem import resolve_relative_path from facefusion.thread_helper import conditional_thread_semaphore from facefusion.typing import Angle, BoundingBox, DownloadSet, FaceLandmark5, FaceLandmark68, InferencePool, ModelSet, Prediction, Score, VisionFrame -MODEL_SET : ModelSet =\ -{ - '2dfan4': + +@lru_cache(maxsize = None) +def create_static_model_set() -> ModelSet: + return\ { - 'hashes': + '2dfan4': { - '2dfan4': + 'hashes': { - 'url': 'https://github.com/facefusion/facefusion-assets/releases/download/models-3.0.0/2dfan4.hash', - 'path': resolve_relative_path('../.assets/models/2dfan4.hash') - } + '2dfan4': + { + 'url': 'https://github.com/facefusion/facefusion-assets/releases/download/models-3.0.0/2dfan4.hash', + 'path': resolve_relative_path('../.assets/models/2dfan4.hash') + } + }, + 'sources': + { + '2dfan4': + { + 'url': 'https://github.com/facefusion/facefusion-assets/releases/download/models-3.0.0/2dfan4.onnx', + 'path': resolve_relative_path('../.assets/models/2dfan4.onnx') + } + }, + 'size': (256, 256) }, - 'sources': + 'peppa_wutz': { - '2dfan4': + 'hashes': { - 'url': 'https://github.com/facefusion/facefusion-assets/releases/download/models-3.0.0/2dfan4.onnx', - 'path': resolve_relative_path('../.assets/models/2dfan4.onnx') - } + 'peppa_wutz': + { + 'url': 'https://github.com/facefusion/facefusion-assets/releases/download/models-3.0.0/peppa_wutz.hash', + 'path': resolve_relative_path('../.assets/models/peppa_wutz.hash') + } + }, + 'sources': + { + 'peppa_wutz': + { + 'url': 'https://github.com/facefusion/facefusion-assets/releases/download/models-3.0.0/peppa_wutz.onnx', + 'path': resolve_relative_path('../.assets/models/peppa_wutz.onnx') + } + }, + 'size': (256, 256) }, - 'size': (256, 256) - }, - 'peppa_wutz': - { - 'hashes': + 'fan_68_5': { - 'peppa_wutz': + 'hashes': { - 'url': 'https://github.com/facefusion/facefusion-assets/releases/download/models-3.0.0/peppa_wutz.hash', - 'path': resolve_relative_path('../.assets/models/peppa_wutz.hash') - } - }, - 'sources': - { - 'peppa_wutz': + 'fan_68_5': + { + 'url': 'https://github.com/facefusion/facefusion-assets/releases/download/models-3.0.0/fan_68_5.hash', + 'path': resolve_relative_path('../.assets/models/fan_68_5.hash') + } + }, + 'sources': { - 'url': 'https://github.com/facefusion/facefusion-assets/releases/download/models-3.0.0/peppa_wutz.onnx', - 'path': resolve_relative_path('../.assets/models/peppa_wutz.onnx') - } - }, - 'size': (256, 256) - }, - 'fan_68_5': - { - 'hashes': - { - 'fan_68_5': - { - 'url': 'https://github.com/facefusion/facefusion-assets/releases/download/models-3.0.0/fan_68_5.hash', - 'path': resolve_relative_path('../.assets/models/fan_68_5.hash') - } - }, - 'sources': - { - 'fan_68_5': - { - 'url': 'https://github.com/facefusion/facefusion-assets/releases/download/models-3.0.0/fan_68_5.onnx', - 'path': resolve_relative_path('../.assets/models/fan_68_5.onnx') + 'fan_68_5': + { + 'url': 'https://github.com/facefusion/facefusion-assets/releases/download/models-3.0.0/fan_68_5.onnx', + 'path': resolve_relative_path('../.assets/models/fan_68_5.onnx') + } } } } -} def get_inference_pool() -> InferencePool: @@ -86,21 +90,22 @@ def clear_inference_pool() -> None: def collect_model_downloads() -> Tuple[DownloadSet, DownloadSet]: + model_set = create_static_model_set() model_hashes =\ { - 'fan_68_5': MODEL_SET.get('fan_68_5').get('hashes').get('fan_68_5') + 'fan_68_5': model_set.get('fan_68_5').get('hashes').get('fan_68_5') } model_sources =\ { - 'fan_68_5': MODEL_SET.get('fan_68_5').get('sources').get('fan_68_5') + 'fan_68_5': model_set.get('fan_68_5').get('sources').get('fan_68_5') } if state_manager.get_item('face_landmarker_model') in [ 'many', '2dfan4' ]: - model_hashes['2dfan4'] = MODEL_SET.get('2dfan4').get('hashes').get('2dfan4') - model_sources['2dfan4'] = MODEL_SET.get('2dfan4').get('sources').get('2dfan4') + model_hashes['2dfan4'] = model_set.get('2dfan4').get('hashes').get('2dfan4') + model_sources['2dfan4'] = model_set.get('2dfan4').get('sources').get('2dfan4') if state_manager.get_item('face_landmarker_model') in [ 'many', 'peppa_wutz' ]: - model_hashes['peppa_wutz'] = MODEL_SET.get('peppa_wutz').get('hashes').get('peppa_wutz') - model_sources['peppa_wutz'] = MODEL_SET.get('peppa_wutz').get('sources').get('peppa_wutz') + model_hashes['peppa_wutz'] = model_set.get('peppa_wutz').get('hashes').get('peppa_wutz') + model_sources['peppa_wutz'] = model_set.get('peppa_wutz').get('sources').get('peppa_wutz') return model_hashes, model_sources @@ -127,7 +132,7 @@ def detect_face_landmarks(vision_frame : VisionFrame, bounding_box : BoundingBox def detect_with_2dfan4(temp_vision_frame: VisionFrame, bounding_box: BoundingBox, face_angle: Angle) -> Tuple[FaceLandmark68, Score]: - model_size = MODEL_SET.get('2dfan4').get('size') + model_size = create_static_model_set().get('2dfan4').get('size') scale = 195 / numpy.subtract(bounding_box[2:], bounding_box[:2]).max().clip(1, None) translation = (model_size[0] - numpy.add(bounding_box[2:], bounding_box[:2]) * scale) * 0.5 rotated_matrix, rotated_size = create_rotated_matrix_and_size(face_angle, model_size) @@ -146,7 +151,7 @@ def detect_with_2dfan4(temp_vision_frame: VisionFrame, bounding_box: BoundingBox def detect_with_peppa_wutz(temp_vision_frame : VisionFrame, bounding_box : BoundingBox, face_angle : Angle) -> Tuple[FaceLandmark68, Score]: - model_size = MODEL_SET.get('peppa_wutz').get('size') + model_size = create_static_model_set().get('peppa_wutz').get('size') scale = 195 / numpy.subtract(bounding_box[2:], bounding_box[:2]).max().clip(1, None) translation = (model_size[0] - numpy.add(bounding_box[2:], bounding_box[:2]) * scale) * 0.5 rotated_matrix, rotated_size = create_rotated_matrix_and_size(face_angle, model_size) diff --git a/facefusion/face_masker.py b/facefusion/face_masker.py index 36e55507..0fd6640b 100755 --- a/facefusion/face_masker.py +++ b/facefusion/face_masker.py @@ -11,49 +11,6 @@ from facefusion.filesystem import resolve_relative_path from facefusion.thread_helper import conditional_thread_semaphore from facefusion.typing import DownloadSet, FaceLandmark68, FaceMaskRegion, InferencePool, Mask, ModelSet, Padding, VisionFrame -MODEL_SET : ModelSet =\ -{ - 'face_occluder': - { - 'hashes': - { - 'face_occluder': - { - 'url': 'https://github.com/facefusion/facefusion-assets/releases/download/models-3.0.0/dfl_xseg.hash', - 'path': resolve_relative_path('../.assets/models/dfl_xseg.hash') - } - }, - 'sources': - { - 'face_occluder': - { - 'url': 'https://github.com/facefusion/facefusion-assets/releases/download/models-3.0.0/dfl_xseg.onnx', - 'path': resolve_relative_path('../.assets/models/dfl_xseg.onnx') - } - }, - 'size': (256, 256) - }, - 'face_parser': - { - 'hashes': - { - 'face_parser': - { - 'url': 'https://github.com/facefusion/facefusion-assets/releases/download/models-3.0.0/bisenet_resnet_34.hash', - 'path': resolve_relative_path('../.assets/models/bisenet_resnet_34.hash') - } - }, - 'sources': - { - 'face_parser': - { - 'url': 'https://github.com/facefusion/facefusion-assets/releases/download/models-3.0.0/bisenet_resnet_34.onnx', - 'path': resolve_relative_path('../.assets/models/bisenet_resnet_34.onnx') - } - }, - 'size': (512, 512) - } -} FACE_MASK_REGIONS : Dict[FaceMaskRegion, int] =\ { 'skin': 1, @@ -69,6 +26,53 @@ FACE_MASK_REGIONS : Dict[FaceMaskRegion, int] =\ } +@lru_cache(maxsize = None) +def create_static_model_set() -> ModelSet: + return\ + { + 'face_occluder': + { + 'hashes': + { + 'face_occluder': + { + 'url': 'https://github.com/facefusion/facefusion-assets/releases/download/models-3.0.0/dfl_xseg.hash', + 'path': resolve_relative_path('../.assets/models/dfl_xseg.hash') + } + }, + 'sources': + { + 'face_occluder': + { + 'url': 'https://github.com/facefusion/facefusion-assets/releases/download/models-3.0.0/dfl_xseg.onnx', + 'path': resolve_relative_path('../.assets/models/dfl_xseg.onnx') + } + }, + 'size': (256, 256) + }, + 'face_parser': + { + 'hashes': + { + 'face_parser': + { + 'url': 'https://github.com/facefusion/facefusion-assets/releases/download/models-3.0.0/bisenet_resnet_34.hash', + 'path': resolve_relative_path('../.assets/models/bisenet_resnet_34.hash') + } + }, + 'sources': + { + 'face_parser': + { + 'url': 'https://github.com/facefusion/facefusion-assets/releases/download/models-3.0.0/bisenet_resnet_34.onnx', + 'path': resolve_relative_path('../.assets/models/bisenet_resnet_34.onnx') + } + }, + 'size': (512, 512) + } + } + + def get_inference_pool() -> InferencePool: _, model_sources = collect_model_downloads() return inference_manager.get_inference_pool(__name__, model_sources) @@ -79,15 +83,16 @@ def clear_inference_pool() -> None: def collect_model_downloads() -> Tuple[DownloadSet, DownloadSet]: + model_set = create_static_model_set() model_hashes =\ { - 'face_occluder': MODEL_SET.get('face_occluder').get('hashes').get('face_occluder'), - 'face_parser': MODEL_SET.get('face_parser').get('hashes').get('face_parser') + 'face_occluder': model_set.get('face_occluder').get('hashes').get('face_occluder'), + 'face_parser': model_set.get('face_parser').get('hashes').get('face_parser') } model_sources =\ { - 'face_occluder': MODEL_SET.get('face_occluder').get('sources').get('face_occluder'), - 'face_parser': MODEL_SET.get('face_parser').get('sources').get('face_parser') + 'face_occluder': model_set.get('face_occluder').get('sources').get('face_occluder'), + 'face_parser': model_set.get('face_parser').get('sources').get('face_parser') } return model_hashes, model_sources @@ -113,7 +118,7 @@ def create_static_box_mask(crop_size : Size, face_mask_blur : float, face_mask_p def create_occlusion_mask(crop_vision_frame : VisionFrame) -> Mask: - model_size = MODEL_SET.get('face_occluder').get('size') + model_size = create_static_model_set().get('face_occluder').get('size') prepare_vision_frame = cv2.resize(crop_vision_frame, model_size) prepare_vision_frame = numpy.expand_dims(prepare_vision_frame, axis = 0).astype(numpy.float32) / 255 prepare_vision_frame = prepare_vision_frame.transpose(0, 1, 2, 3) @@ -125,7 +130,7 @@ def create_occlusion_mask(crop_vision_frame : VisionFrame) -> Mask: def create_region_mask(crop_vision_frame : VisionFrame, face_mask_regions : List[FaceMaskRegion]) -> Mask: - model_size = MODEL_SET.get('face_parser').get('size') + model_size = create_static_model_set().get('face_parser').get('size') prepare_vision_frame = cv2.resize(crop_vision_frame, model_size) prepare_vision_frame = prepare_vision_frame[:, :, ::-1].astype(numpy.float32) / 255 prepare_vision_frame = numpy.subtract(prepare_vision_frame, numpy.array([ 0.485, 0.456, 0.406 ]).astype(numpy.float32)) diff --git a/facefusion/face_recognizer.py b/facefusion/face_recognizer.py index bba66811..d1f094e8 100644 --- a/facefusion/face_recognizer.py +++ b/facefusion/face_recognizer.py @@ -1,3 +1,4 @@ +from functools import lru_cache from typing import Tuple import numpy @@ -9,30 +10,33 @@ from facefusion.filesystem import resolve_relative_path from facefusion.thread_helper import conditional_thread_semaphore from facefusion.typing import Embedding, FaceLandmark5, InferencePool, ModelOptions, ModelSet, VisionFrame -MODEL_SET : ModelSet =\ -{ - 'arcface': + +@lru_cache(maxsize = None) +def create_static_model_set() -> ModelSet: + return\ { - 'hashes': + 'arcface': { - 'face_recognizer': + 'hashes': { - 'url': 'https://github.com/facefusion/facefusion-assets/releases/download/models-3.0.0/arcface_w600k_r50.hash', - 'path': resolve_relative_path('../.assets/models/arcface_w600k_r50.hash') - } - }, - 'sources': - { - 'face_recognizer': + 'face_recognizer': + { + 'url': 'https://github.com/facefusion/facefusion-assets/releases/download/models-3.0.0/arcface_w600k_r50.hash', + 'path': resolve_relative_path('../.assets/models/arcface_w600k_r50.hash') + } + }, + 'sources': { - 'url': 'https://github.com/facefusion/facefusion-assets/releases/download/models-3.0.0/arcface_w600k_r50.onnx', - 'path': resolve_relative_path('../.assets/models/arcface_w600k_r50.onnx') - } - }, - 'template': 'arcface_112_v2', - 'size': (112, 112) + 'face_recognizer': + { + 'url': 'https://github.com/facefusion/facefusion-assets/releases/download/models-3.0.0/arcface_w600k_r50.onnx', + 'path': resolve_relative_path('../.assets/models/arcface_w600k_r50.onnx') + } + }, + 'template': 'arcface_112_v2', + 'size': (112, 112) + } } -} def get_inference_pool() -> InferencePool: @@ -45,7 +49,7 @@ def clear_inference_pool() -> None: def get_model_options() -> ModelOptions: - return MODEL_SET.get('arcface') + return create_static_model_set().get('arcface') def pre_check() -> bool: diff --git a/facefusion/processors/modules/age_modifier.py b/facefusion/processors/modules/age_modifier.py index bafb3af6..b0ee0bdd 100755 --- a/facefusion/processors/modules/age_modifier.py +++ b/facefusion/processors/modules/age_modifier.py @@ -1,4 +1,5 @@ from argparse import ArgumentParser +from functools import lru_cache from typing import List import cv2 @@ -24,7 +25,8 @@ from facefusion.typing import ApplyStateItem, Args, Face, InferencePool, ModelOp from facefusion.vision import match_frame_color, read_image, read_static_image, write_image -def create_model_set() -> ModelSet: +@lru_cache(maxsize = None) +def create_static_model_set() -> ModelSet: return\ { 'styleganex_age': @@ -73,7 +75,7 @@ def clear_inference_pool() -> None: def get_model_options() -> ModelOptions: age_modifier_model = state_manager.get_item('age_modifier_model') - return create_model_set().get(age_modifier_model) + return create_static_model_set().get(age_modifier_model) def register_args(program : ArgumentParser) -> None: diff --git a/facefusion/processors/modules/deep_swapper.py b/facefusion/processors/modules/deep_swapper.py index 9d30b967..336b6957 100755 --- a/facefusion/processors/modules/deep_swapper.py +++ b/facefusion/processors/modules/deep_swapper.py @@ -1,4 +1,5 @@ from argparse import ArgumentParser +from functools import lru_cache from typing import List, Tuple import cv2 @@ -24,7 +25,8 @@ from facefusion.typing import ApplyStateItem, Args, Face, InferencePool, Mask, M from facefusion.vision import conditional_match_frame_color, read_image, read_static_image, write_image -def create_model_set() -> ModelSet: +@lru_cache(maxsize = None) +def create_static_model_set() -> ModelSet: model_config =\ [ ('druuzil', 'adrianne_palicki_384', (384, 384)), @@ -217,7 +219,7 @@ def clear_inference_pool() -> None: def get_model_options() -> ModelOptions: deep_swapper_model = state_manager.get_item('deep_swapper_model') - return create_model_set().get(deep_swapper_model) + return create_static_model_set().get(deep_swapper_model) def register_args(program : ArgumentParser) -> None: diff --git a/facefusion/processors/modules/expression_restorer.py b/facefusion/processors/modules/expression_restorer.py index bed16587..7fe3e9d4 100755 --- a/facefusion/processors/modules/expression_restorer.py +++ b/facefusion/processors/modules/expression_restorer.py @@ -1,4 +1,5 @@ from argparse import ArgumentParser +from functools import lru_cache from typing import List, Tuple import cv2 @@ -26,7 +27,8 @@ from facefusion.typing import ApplyStateItem, Args, Face, InferencePool, ModelOp from facefusion.vision import get_video_frame, read_image, read_static_image, write_image -def create_model_set() -> ModelSet: +@lru_cache(maxsize = None) +def create_static_model_set() -> ModelSet: return\ { 'live_portrait': @@ -85,7 +87,7 @@ def clear_inference_pool() -> None: def get_model_options() -> ModelOptions: expression_restorer_model = state_manager.get_item('expression_restorer_model') - return create_model_set().get(expression_restorer_model) + return create_static_model_set().get(expression_restorer_model) def register_args(program : ArgumentParser) -> None: diff --git a/facefusion/processors/modules/face_editor.py b/facefusion/processors/modules/face_editor.py index fafbb1c4..4f9bac95 100755 --- a/facefusion/processors/modules/face_editor.py +++ b/facefusion/processors/modules/face_editor.py @@ -1,4 +1,5 @@ from argparse import ArgumentParser +from functools import lru_cache from typing import List, Tuple import cv2 @@ -25,7 +26,8 @@ from facefusion.typing import ApplyStateItem, Args, Face, FaceLandmark68, Infere from facefusion.vision import read_image, read_static_image, write_image -def create_model_set() -> ModelSet: +@lru_cache(maxsize = None) +def create_static_model_set() -> ModelSet: return\ { 'live_portrait': @@ -115,7 +117,7 @@ def clear_inference_pool() -> None: def get_model_options() -> ModelOptions: face_editor_model = state_manager.get_item('face_editor_model') - return create_model_set().get(face_editor_model) + return create_static_model_set().get(face_editor_model) def register_args(program : ArgumentParser) -> None: diff --git a/facefusion/processors/modules/face_enhancer.py b/facefusion/processors/modules/face_enhancer.py index 2124a4d4..d539d0dd 100755 --- a/facefusion/processors/modules/face_enhancer.py +++ b/facefusion/processors/modules/face_enhancer.py @@ -1,4 +1,5 @@ from argparse import ArgumentParser +from functools import lru_cache from typing import List import cv2 @@ -24,7 +25,8 @@ from facefusion.typing import ApplyStateItem, Args, Face, InferencePool, ModelOp from facefusion.vision import read_image, read_static_image, write_image -def create_model_set() -> ModelSet: +@lru_cache(maxsize = None) +def create_static_model_set() -> ModelSet: return\ { 'codeformer': @@ -232,7 +234,7 @@ def clear_inference_pool() -> None: def get_model_options() -> ModelOptions: face_enhancer_model = state_manager.get_item('face_enhancer_model') - return create_model_set().get(face_enhancer_model) + return create_static_model_set().get(face_enhancer_model) def register_args(program : ArgumentParser) -> None: diff --git a/facefusion/processors/modules/frame_colorizer.py b/facefusion/processors/modules/frame_colorizer.py index f92d1168..521bc474 100644 --- a/facefusion/processors/modules/frame_colorizer.py +++ b/facefusion/processors/modules/frame_colorizer.py @@ -1,4 +1,5 @@ from argparse import ArgumentParser +from functools import lru_cache from typing import List import cv2 @@ -19,7 +20,8 @@ from facefusion.typing import ApplyStateItem, Args, Face, InferencePool, ModelOp from facefusion.vision import read_image, read_static_image, unpack_resolution, write_image -def create_model_set() -> ModelSet: +@lru_cache(maxsize = None) +def create_static_model_set() -> ModelSet: return\ { 'ddcolor': @@ -138,7 +140,7 @@ def clear_inference_pool() -> None: def get_model_options() -> ModelOptions: frame_colorizer_model = state_manager.get_item('frame_colorizer_model') - return create_model_set().get(frame_colorizer_model) + return create_static_model_set().get(frame_colorizer_model) def register_args(program : ArgumentParser) -> None: diff --git a/facefusion/processors/modules/frame_enhancer.py b/facefusion/processors/modules/frame_enhancer.py index 8cf2d9ac..016fffab 100644 --- a/facefusion/processors/modules/frame_enhancer.py +++ b/facefusion/processors/modules/frame_enhancer.py @@ -1,4 +1,5 @@ from argparse import ArgumentParser +from functools import lru_cache from typing import List import cv2 @@ -19,7 +20,8 @@ from facefusion.typing import ApplyStateItem, Args, Face, InferencePool, ModelOp from facefusion.vision import create_tile_frames, merge_tile_frames, read_image, read_static_image, write_image -def create_model_set() -> ModelSet: +@lru_cache(maxsize = None) +def create_static_model_set() -> ModelSet: return\ { 'clear_reality_x4': @@ -395,7 +397,7 @@ def clear_inference_pool() -> None: def get_model_options() -> ModelOptions: frame_enhancer_model = state_manager.get_item('frame_enhancer_model') - return create_model_set().get(frame_enhancer_model) + return create_static_model_set().get(frame_enhancer_model) def register_args(program : ArgumentParser) -> None: diff --git a/facefusion/processors/modules/lip_syncer.py b/facefusion/processors/modules/lip_syncer.py index aa767631..610cb0e6 100755 --- a/facefusion/processors/modules/lip_syncer.py +++ b/facefusion/processors/modules/lip_syncer.py @@ -1,4 +1,5 @@ from argparse import ArgumentParser +from functools import lru_cache from typing import List import cv2 @@ -25,7 +26,8 @@ from facefusion.typing import ApplyStateItem, Args, AudioFrame, Face, InferenceP from facefusion.vision import read_image, read_static_image, restrict_video_fps, write_image -def create_model_set() -> ModelSet: +@lru_cache(maxsize = None) +def create_static_model_set() -> ModelSet: return\ { 'wav2lip_96': @@ -84,7 +86,7 @@ def clear_inference_pool() -> None: def get_model_options() -> ModelOptions: lip_syncer_model = state_manager.get_item('lip_syncer_model') - return create_model_set().get(lip_syncer_model) + return create_static_model_set().get(lip_syncer_model) def register_args(program : ArgumentParser) -> None: diff --git a/facefusion/uis/components/download.py b/facefusion/uis/components/download.py index 7a4c2e2f..4f1b50e0 100644 --- a/facefusion/uis/components/download.py +++ b/facefusion/uis/components/download.py @@ -4,6 +4,7 @@ import gradio from facefusion import state_manager, wording from facefusion.choices import download_provider_set +from facefusion.core import clear_model_sets from facefusion.typing import DownloadProviderKey DOWNLOAD_PROVIDERS_CHECKBOX_GROUP : Optional[gradio.CheckboxGroup] = None @@ -24,6 +25,7 @@ def listen() -> None: def update_download_providers(download_providers : List[DownloadProviderKey]) -> gradio.CheckboxGroup: + clear_model_sets() download_providers = download_providers or list(download_provider_set.keys()) state_manager.set_item('download_providers', download_providers) return gradio.CheckboxGroup(value = state_manager.get_item('download_providers')) diff --git a/facefusion/voice_extractor.py b/facefusion/voice_extractor.py index 8bbedbb0..a66114ee 100644 --- a/facefusion/voice_extractor.py +++ b/facefusion/voice_extractor.py @@ -1,3 +1,4 @@ +from functools import lru_cache from typing import Tuple import numpy @@ -9,28 +10,31 @@ from facefusion.filesystem import resolve_relative_path from facefusion.thread_helper import thread_semaphore from facefusion.typing import Audio, AudioChunk, InferencePool, ModelOptions, ModelSet -MODEL_SET : ModelSet =\ -{ - 'kim_vocal_2': + +@lru_cache(maxsize = None) +def create_static_model_set() -> ModelSet: + return\ { - 'hashes': + 'kim_vocal_2': { - 'voice_extractor': + 'hashes': { - 'url': 'https://github.com/facefusion/facefusion-assets/releases/download/models-3.0.0/kim_vocal_2.hash', - 'path': resolve_relative_path('../.assets/models/kim_vocal_2.hash') - } - }, - 'sources': - { - 'voice_extractor': + 'voice_extractor': + { + 'url': 'https://github.com/facefusion/facefusion-assets/releases/download/models-3.0.0/kim_vocal_2.hash', + 'path': resolve_relative_path('../.assets/models/kim_vocal_2.hash') + } + }, + 'sources': { - 'url': 'https://github.com/facefusion/facefusion-assets/releases/download/models-3.0.0/kim_vocal_2.onnx', - 'path': resolve_relative_path('../.assets/models/kim_vocal_2.onnx') + 'voice_extractor': + { + 'url': 'https://github.com/facefusion/facefusion-assets/releases/download/models-3.0.0/kim_vocal_2.onnx', + 'path': resolve_relative_path('../.assets/models/kim_vocal_2.onnx') + } } } } -} def get_inference_pool() -> InferencePool: @@ -43,7 +47,7 @@ def clear_inference_pool() -> None: def get_model_options() -> ModelOptions: - return MODEL_SET.get('kim_vocal_2') + return create_static_model_set().get('kim_vocal_2') def pre_check() -> bool: