Refactor/choices and naming (#833)
* Refactor choices, imports and naming
parent 5856ac509e
commit 2ca200bd0a
@@ -2,7 +2,7 @@ import logging
from typing import List, Sequence

from facefusion.common_helper import create_float_range, create_int_range
from facefusion.typing import Angle, DownloadProviderSet, DownloadScope, ExecutionProviderSet, FaceDetectorSet, FaceLandmarkerModel, FaceMaskRegion, FaceMaskType, FaceSelectorMode, FaceSelectorOrder, Gender, JobStatus, LogLevelSet, OutputAudioEncoder, OutputVideoEncoder, OutputVideoPreset, Race, Score, TempFrameFormat, UiWorkflow, VideoMemoryStrategy
from facefusion.typing import Angle, DownloadProvider, DownloadProviderSet, DownloadScope, ExecutionProvider, ExecutionProviderSet, FaceDetectorModel, FaceDetectorSet, FaceLandmarkerModel, FaceMaskRegion, FaceMaskType, FaceSelectorMode, FaceSelectorOrder, Gender, JobStatus, LogLevel, LogLevelSet, OutputAudioEncoder, OutputVideoEncoder, OutputVideoPreset, Race, Score, TempFrameFormat, UiWorkflow, VideoMemoryStrategy

face_detector_set : FaceDetectorSet =\
{
@@ -11,6 +11,7 @@ face_detector_set : FaceDetectorSet =\
'scrfd': [ '160x160', '320x320', '480x480', '512x512', '640x640' ],
'yoloface': [ '640x640' ]
}
face_detector_models : List[FaceDetectorModel] = list(face_detector_set.keys())
face_landmarker_models : List[FaceLandmarkerModel] = [ 'many', '2dfan4', 'peppa_wutz' ]
face_selector_modes : List[FaceSelectorMode] = [ 'many', 'one', 'reference' ]
face_selector_orders : List[FaceSelectorOrder] = [ 'left-right', 'right-left', 'top-bottom', 'bottom-top', 'small-large', 'large-small', 'best-worst', 'worst-best' ]
@@ -36,11 +37,13 @@ execution_provider_set : ExecutionProviderSet =\
'rocm': 'ROCMExecutionProvider',
'tensorrt': 'TensorrtExecutionProvider'
}
execution_providers : List[ExecutionProvider] = list(execution_provider_set.keys())
download_provider_set : DownloadProviderSet =\
{
'github': 'https://github.com/facefusion/facefusion-assets/releases/download/{base_name}/{file_name}',
'huggingface': 'https://huggingface.co/facefusion/{base_name}/resolve/main/{file_name}'
}
download_providers : List[DownloadProvider] = list(download_provider_set.keys())
download_scopes : List[DownloadScope] = [ 'lite', 'full' ]

video_memory_strategies : List[VideoMemoryStrategy] = [ 'strict', 'moderate', 'tolerant' ]
@@ -52,6 +55,7 @@ log_level_set : LogLevelSet =\
'info': logging.INFO,
'debug': logging.DEBUG
}
log_levels : List[LogLevel] = list(log_level_set.keys())

ui_workflows : List[UiWorkflow] = [ 'instant_runner', 'job_runner', 'job_manager' ]
job_statuses : List[JobStatus] = [ 'drafted', 'queued', 'completed', 'failed' ]
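Note: the new *_models / *_providers / log_levels lists above simply precompute list(<set>.keys()), so call sites can reference them directly. A minimal sketch of the intended usage, assuming facefusion is importable (the asserted keys are taken from the sets in this file):

import facefusion.choices

# each list mirrors the keys of its corresponding set dictionary
assert 'yoloface' in facefusion.choices.face_detector_models
assert 'tensorrt' in facefusion.choices.execution_providers
assert 'github' in facefusion.choices.download_providers
assert 'debug' in facefusion.choices.log_levels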
@@ -9,12 +9,12 @@ from urllib.parse import urlparse

from tqdm import tqdm

import facefusion.choices
from facefusion import logger, process_manager, state_manager, wording
from facefusion.choices import download_provider_set
from facefusion.common_helper import is_macos
from facefusion.filesystem import get_file_size, is_file, remove_file
from facefusion.hash_helper import validate_hash
from facefusion.typing import DownloadProviderKey, DownloadSet
from facefusion.typing import DownloadProvider, DownloadSet

if is_macos():
ssl._create_default_https_context = ssl._create_unverified_context
@@ -128,11 +128,11 @@ def validate_source_paths(source_paths : List[str]) -> Tuple[List[str], List[str
def resolve_download_url(base_name : str, file_name : str) -> Optional[str]:
download_providers = state_manager.get_item('download_providers')

for download_provider in download_provider_set:
for download_provider in facefusion.choices.download_provider_set:
if download_provider in download_providers:
return resolve_download_url_by_provider(download_provider, base_name, file_name)
return None


def resolve_download_url_by_provider(download_provider : DownloadProviderKey, base_name : str, file_name : str) -> Optional[str]:
return download_provider_set.get(download_provider).format(base_name = base_name, file_name = file_name)
def resolve_download_url_by_provider(download_provider : DownloadProvider, base_name : str, file_name : str) -> Optional[str]:
return facefusion.choices.download_provider_set.get(download_provider).format(base_name = base_name, file_name = file_name)
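For orientation, a sketch of the renamed resolve_download_url_by_provider; the base_name and file_name arguments here are hypothetical and only illustrate the template substitution from download_provider_set:

from facefusion.download import resolve_download_url_by_provider

# hypothetical arguments, for illustration only
download_url = resolve_download_url_by_provider('github', 'models-3.0.0', 'example.onnx')
# 'https://github.com/facefusion/facefusion-assets/releases/download/models-3.0.0/example.onnx'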
@@ -6,37 +6,38 @@ from typing import Any, List, Optional

from onnxruntime import get_available_providers, set_default_logger_severity

from facefusion.choices import execution_provider_set
from facefusion.typing import ExecutionDevice, ExecutionProviderKey, ExecutionProviderSet, ValueAndUnit
import facefusion.choices
from facefusion.typing import ExecutionDevice, ExecutionProvider, ValueAndUnit

set_default_logger_severity(3)


def has_execution_provider(execution_provider_key : ExecutionProviderKey) -> bool:
return execution_provider_key in get_execution_provider_set().keys()
def has_execution_provider(execution_provider : ExecutionProvider) -> bool:
return execution_provider in get_available_execution_providers()


def get_execution_provider_set() -> ExecutionProviderSet:
available_execution_providers = get_available_providers()
available_execution_provider_set : ExecutionProviderSet = {}
def get_available_execution_providers() -> List[ExecutionProvider]:
inference_execution_providers = get_available_providers()
available_execution_providers = []

for execution_provider_key, execution_provider_value in execution_provider_set.items():
if execution_provider_value in available_execution_providers:
available_execution_provider_set[execution_provider_key] = execution_provider_value
return available_execution_provider_set
for execution_provider, execution_provider_value in facefusion.choices.execution_provider_set.items():
if execution_provider_value in inference_execution_providers:
available_execution_providers.append(execution_provider)

return available_execution_providers


def create_execution_providers(execution_device_id : str, execution_provider_keys : List[ExecutionProviderKey]) -> List[Any]:
execution_providers : List[Any] = []
def create_inference_execution_providers(execution_device_id : str, execution_providers : List[ExecutionProvider]) -> List[Any]:
inference_execution_providers : List[Any] = []

for execution_provider_key in execution_provider_keys:
if execution_provider_key == 'cuda':
execution_providers.append((execution_provider_set.get(execution_provider_key),
for execution_provider in execution_providers:
if execution_provider == 'cuda':
inference_execution_providers.append((facefusion.choices.execution_provider_set.get(execution_provider),
{
'device_id': execution_device_id
}))
if execution_provider_key == 'tensorrt':
execution_providers.append((execution_provider_set.get(execution_provider_key),
if execution_provider == 'tensorrt':
inference_execution_providers.append((facefusion.choices.execution_provider_set.get(execution_provider),
{
'device_id': execution_device_id,
'trt_engine_cache_enable': True,
@@ -45,24 +46,24 @@ def create_execution_providers(execution_device_id : str, execution_provider_key
'trt_timing_cache_path': '.caches',
'trt_builder_optimization_level': 5
}))
if execution_provider_key == 'openvino':
execution_providers.append((execution_provider_set.get(execution_provider_key),
if execution_provider == 'openvino':
inference_execution_providers.append((facefusion.choices.execution_provider_set.get(execution_provider),
{
'device_type': 'GPU' if execution_device_id == '0' else 'GPU.' + execution_device_id,
'precision': 'FP32'
}))
if execution_provider_key in [ 'directml', 'rocm' ]:
execution_providers.append((execution_provider_set.get(execution_provider_key),
if execution_provider in [ 'directml', 'rocm' ]:
inference_execution_providers.append((facefusion.choices.execution_provider_set.get(execution_provider),
{
'device_id': execution_device_id
}))
if execution_provider_key == 'coreml':
execution_providers.append(execution_provider_set.get(execution_provider_key))
if execution_provider == 'coreml':
inference_execution_providers.append(facefusion.choices.execution_provider_set.get(execution_provider))

if 'cpu' in execution_provider_keys:
execution_providers.append(execution_provider_set.get('cpu'))
if 'cpu' in execution_providers:
inference_execution_providers.append(facefusion.choices.execution_provider_set.get('cpu'))

return execution_providers
return inference_execution_providers


def run_nvidia_smi() -> subprocess.Popen[bytes]:
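A minimal usage sketch of the renamed helpers, assuming onnxruntime exposes CUDA on the machine; the behaviour is unchanged, only the names and the List[ExecutionProvider] interface differ:

from facefusion.execution import create_inference_execution_providers, get_available_execution_providers, has_execution_provider

if has_execution_provider('cuda'):
    # yields [('CUDAExecutionProvider', { 'device_id': '0' }), 'CPUExecutionProvider']
    inference_execution_providers = create_inference_execution_providers('0', [ 'cpu', 'cuda' ])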
@@ -5,9 +5,9 @@ from onnxruntime import InferenceSession

from facefusion import process_manager, state_manager
from facefusion.app_context import detect_app_context
from facefusion.execution import create_execution_providers
from facefusion.execution import create_inference_execution_providers
from facefusion.thread_helper import thread_lock
from facefusion.typing import DownloadSet, ExecutionProviderKey, InferencePool, InferencePoolSet
from facefusion.typing import DownloadSet, ExecutionProvider, InferencePool, InferencePoolSet

INFERENCE_POOLS : InferencePoolSet =\
{
@@ -35,11 +35,11 @@ def get_inference_pool(model_context : str, model_sources : DownloadSet) -> Infe
return INFERENCE_POOLS.get(app_context).get(inference_context)


def create_inference_pool(model_sources : DownloadSet, execution_device_id : str, execution_provider_keys : List[ExecutionProviderKey]) -> InferencePool:
def create_inference_pool(model_sources : DownloadSet, execution_device_id : str, execution_providers : List[ExecutionProvider]) -> InferencePool:
inference_pool : InferencePool = {}

for model_name in model_sources.keys():
inference_pool[model_name] = create_inference_session(model_sources.get(model_name).get('path'), execution_device_id, execution_provider_keys)
inference_pool[model_name] = create_inference_session(model_sources.get(model_name).get('path'), execution_device_id, execution_providers)
return inference_pool


@@ -53,9 +53,9 @@ def clear_inference_pool(model_context : str) -> None:
del INFERENCE_POOLS[app_context][inference_context]


def create_inference_session(model_path : str, execution_device_id : str, execution_provider_keys : List[ExecutionProviderKey]) -> InferenceSession:
execution_providers = create_execution_providers(execution_device_id, execution_provider_keys)
return InferenceSession(model_path, providers = execution_providers)
def create_inference_session(model_path : str, execution_device_id : str, execution_providers : List[ExecutionProvider]) -> InferenceSession:
inference_execution_providers = create_inference_execution_providers(execution_device_id, execution_providers)
return InferenceSession(model_path, providers = inference_execution_providers)


def get_inference_context(model_context : str) -> str:
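A sketch of the renamed session creation path; the model path below is hypothetical:

from facefusion.inference_manager import create_inference_session

# hypothetical model path, for illustration only
inference_session = create_inference_session('.assets/models/example.onnx', '0', [ 'cpu' ])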
@@ -2,7 +2,7 @@ import os
from copy import copy
from typing import List, Optional

from facefusion.choices import job_statuses
import facefusion.choices
from facefusion.date_helper import get_current_date_time
from facefusion.filesystem import create_directory, is_directory, is_file, move_file, remove_directory, remove_file, resolve_file_pattern
from facefusion.jobs.job_helper import get_step_output_path
@@ -16,7 +16,7 @@ def init_jobs(jobs_path : str) -> bool:
global JOBS_PATH

JOBS_PATH = jobs_path
job_status_paths = [ os.path.join(JOBS_PATH, job_status) for job_status in job_statuses ]
job_status_paths = [ os.path.join(JOBS_PATH, job_status) for job_status in facefusion.choices.job_statuses ]

for job_status_path in job_status_paths:
create_directory(job_status_path)
@@ -245,7 +245,7 @@ def find_job_path(job_id : str) -> Optional[str]:
job_file_name = get_job_file_name(job_id)

if job_file_name:
for job_status in job_statuses:
for job_status in facefusion.choices.job_statuses:
job_pattern = os.path.join(JOBS_PATH, job_status, job_file_name)
job_paths = resolve_file_pattern(job_pattern)
@@ -1,14 +1,14 @@
from logging import Logger, basicConfig, getLogger
from typing import Tuple

from facefusion.choices import log_level_set
import facefusion.choices
from facefusion.common_helper import get_first, get_last
from facefusion.typing import LogLevel, TableContents, TableHeaders


def init(log_level : LogLevel) -> None:
basicConfig(format = '%(message)s')
get_package_logger().setLevel(log_level_set.get(log_level))
get_package_logger().setLevel(facefusion.choices.log_level_set.get(log_level))


def get_package_logger() -> Logger:
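For reference, a small sketch of how init now resolves the textual level through facefusion.choices.log_level_set:

import logging

from facefusion import logger

logger.init('info')
assert logger.get_package_logger().level == logging.INFO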
@@ -2,7 +2,7 @@ from typing import List, Sequence

from facefusion.common_helper import create_float_range, create_int_range
from facefusion.filesystem import list_directory, resolve_relative_path
from facefusion.processors.typing import AgeModifierModel, DeepSwapperModel, ExpressionRestorerModel, FaceDebuggerItem, FaceEditorModel, FaceEnhancerModel, FaceSwapperSet, FrameColorizerModel, FrameEnhancerModel, LipSyncerModel
from facefusion.processors.typing import AgeModifierModel, DeepSwapperModel, ExpressionRestorerModel, FaceDebuggerItem, FaceEditorModel, FaceEnhancerModel, FaceSwapperModel, FaceSwapperSet, FrameColorizerModel, FrameEnhancerModel, LipSyncerModel

age_modifier_models : List[AgeModifierModel] = [ 'styleganex_age' ]
deep_swapper_models : List[DeepSwapperModel] =\
@@ -156,12 +156,12 @@ deep_swapper_models : List[DeepSwapperModel] =\
'rumateus/taylor_swift_224'
]

model_files = list_directory(resolve_relative_path('../.assets/models/local'))
custom_model_files = list_directory(resolve_relative_path('../.assets/models/custom'))

if model_files:
if custom_model_files:

for model_file in model_files:
model_id = '/'.join([ 'local', model_file.get('name') ])
for model_file in custom_model_files:
model_id = '/'.join([ 'custom', model_file.get('name') ])
deep_swapper_models.append(model_id)

expression_restorer_models : List[ExpressionRestorerModel] = [ 'live_portrait' ]
@@ -181,6 +181,7 @@ face_swapper_set : FaceSwapperSet =\
'simswap_unofficial_512': [ '512x512', '768x768', '1024x1024' ],
'uniface_256': [ '256x256', '512x512', '768x768', '1024x1024' ]
}
face_swapper_models : List[FaceSwapperModel] = list(face_swapper_set.keys())
frame_colorizer_models : List[FrameColorizerModel] = [ 'ddcolor', 'ddcolor_artistic', 'deoldify', 'deoldify_artistic', 'deoldify_stable' ]
frame_colorizer_sizes : List[str] = [ '192x192', '256x256', '384x384', '512x512' ]
frame_enhancer_models : List[FrameEnhancerModel] = [ 'clear_reality_x4', 'lsdir_x4', 'nomos8k_sc_x4', 'real_esrgan_x2', 'real_esrgan_x2_fp16', 'real_esrgan_x4', 'real_esrgan_x4_fp16', 'real_esrgan_x8', 'real_esrgan_x8_fp16', 'real_hatgan_x4', 'real_web_photo_x4', 'realistic_rescaler_x4', 'remacri_x4', 'siax_x4', 'span_kendata_x4', 'swin2_sr_x4', 'ultra_sharp_x4' ]
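Note: custom deep swapper models are now collected from ../.assets/models/custom and registered under a custom/ prefix. A small sketch of the id construction, with a hypothetical file name:

# 'my_model_224' stands in for a file found by list_directory(...)
model_file = { 'name': 'my_model_224' }
model_id = '/'.join([ 'custom', model_file.get('name') ])
# 'custom/my_model_224'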
@@ -5,11 +5,11 @@ from typing import List
import cv2
import numpy

import facefusion.choices
import facefusion.jobs.job_manager
import facefusion.jobs.job_store
import facefusion.processors.core as processors
from facefusion import config, content_analyser, face_classifier, face_detector, face_landmarker, face_masker, face_recognizer, inference_manager, logger, process_manager, state_manager, wording
from facefusion.choices import execution_provider_set
from facefusion.common_helper import create_int_metavar
from facefusion.download import conditional_download_hashes, conditional_download_sources, resolve_download_url
from facefusion.execution import has_execution_provider
@@ -163,7 +163,7 @@ def forward(crop_vision_frame : VisionFrame, extend_vision_frame : VisionFrame,
age_modifier_inputs = {}

if has_execution_provider('coreml'):
age_modifier.set_providers([ execution_provider_set.get('cpu') ])
age_modifier.set_providers([ facefusion.choices.execution_provider_set.get('cpu') ])

for age_modifier_input in age_modifier.get_inputs():
if age_modifier_input.name == 'target':
@@ -215,12 +215,12 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
'template': 'dfl_whole_face'
}

model_files = list_directory(resolve_relative_path('../.assets/models/local'))
custom_model_files = list_directory(resolve_relative_path('../.assets/models/custom'))

if model_files:
if custom_model_files:

for model_file in model_files:
model_id = '/'.join([ 'local', model_file.get('name') ])
for model_file in custom_model_files:
model_id = '/'.join([ 'custom', model_file.get('name') ])

model_set[model_id] =\
{
@@ -4,11 +4,11 @@ from typing import List, Tuple

import numpy

import facefusion.choices
import facefusion.jobs.job_manager
import facefusion.jobs.job_store
import facefusion.processors.core as processors
from facefusion import config, content_analyser, face_classifier, face_detector, face_landmarker, face_masker, face_recognizer, inference_manager, logger, process_manager, state_manager, wording
from facefusion.choices import execution_provider_set
from facefusion.common_helper import get_first
from facefusion.download import conditional_download_hashes, conditional_download_sources, resolve_download_url
from facefusion.execution import has_execution_provider
@@ -354,7 +354,7 @@ def get_model_options() -> ModelOptions:
def register_args(program : ArgumentParser) -> None:
group_processors = find_argument_group(program, 'processors')
if group_processors:
group_processors.add_argument('--face-swapper-model', help = wording.get('help.face_swapper_model'), default = config.get_str_value('processors.face_swapper_model', 'inswapper_128_fp16'), choices = processors_choices.face_swapper_set.keys())
group_processors.add_argument('--face-swapper-model', help = wording.get('help.face_swapper_model'), default = config.get_str_value('processors.face_swapper_model', 'inswapper_128_fp16'), choices = processors_choices.face_swapper_models)
known_args, _ = program.parse_known_args()
face_swapper_pixel_boost_choices = processors_choices.face_swapper_set.get(known_args.face_swapper_model)
group_processors.add_argument('--face-swapper-pixel-boost', help = wording.get('help.face_swapper_pixel_boost'), default = config.get_str_value('processors.face_swapper_pixel_boost', get_first(face_swapper_pixel_boost_choices)), choices = face_swapper_pixel_boost_choices)
@@ -449,7 +449,7 @@ def forward_swap_face(source_face : Face, crop_vision_frame : VisionFrame) -> Vi
face_swapper_inputs = {}

if has_execution_provider('coreml') and model_type in [ 'ghost', 'uniface' ]:
face_swapper.set_providers([ execution_provider_set.get('cpu') ])
face_swapper.set_providers([ facefusion.choices.execution_provider_set.get('cpu') ])

for face_swapper_input in face_swapper.get_inputs():
if face_swapper_input.name == 'source':
@@ -4,7 +4,7 @@ from argparse import ArgumentParser, HelpFormatter
import facefusion.choices
from facefusion import config, metadata, state_manager, wording
from facefusion.common_helper import create_float_metavar, create_int_metavar, get_last
from facefusion.execution import get_execution_provider_set
from facefusion.execution import get_available_execution_providers
from facefusion.filesystem import list_directory
from facefusion.jobs import job_store
from facefusion.processors.core import get_processors_modules
@@ -94,7 +94,7 @@ def create_output_pattern_program() -> ArgumentParser:
def create_face_detector_program() -> ArgumentParser:
program = ArgumentParser(add_help = False)
group_face_detector = program.add_argument_group('face detector')
group_face_detector.add_argument('--face-detector-model', help = wording.get('help.face_detector_model'), default = config.get_str_value('face_detector.face_detector_model', 'yoloface'), choices = list(facefusion.choices.face_detector_set.keys()))
group_face_detector.add_argument('--face-detector-model', help = wording.get('help.face_detector_model'), default = config.get_str_value('face_detector.face_detector_model', 'yoloface'), choices = facefusion.choices.face_detector_models)
known_args, _ = program.parse_known_args()
face_detector_size_choices = facefusion.choices.face_detector_set.get(known_args.face_detector_model)
group_face_detector.add_argument('--face-detector-size', help = wording.get('help.face_detector_size'), default = config.get_str_value('face_detector.face_detector_size', get_last(face_detector_size_choices)), choices = face_detector_size_choices)
@@ -190,9 +190,10 @@ def create_uis_program() -> ArgumentParser:

def create_execution_program() -> ArgumentParser:
program = ArgumentParser(add_help = False)
available_execution_providers = get_available_execution_providers()
group_execution = program.add_argument_group('execution')
group_execution.add_argument('--execution-device-id', help = wording.get('help.execution_device_id'), default = config.get_str_value('execution.execution_device_id', '0'))
group_execution.add_argument('--execution-providers', help = wording.get('help.execution_providers').format(choices = ', '.join(list(get_execution_provider_set().keys()))), default = config.get_str_list('execution.execution_providers', 'cpu'), choices = list(get_execution_provider_set().keys()), nargs = '+', metavar = 'EXECUTION_PROVIDERS')
group_execution.add_argument('--execution-providers', help = wording.get('help.execution_providers').format(choices = ', '.join(available_execution_providers)), default = config.get_str_list('execution.execution_providers', 'cpu'), choices = available_execution_providers, nargs = '+', metavar = 'EXECUTION_PROVIDERS')
group_execution.add_argument('--execution-thread-count', help = wording.get('help.execution_thread_count'), type = int, default = config.get_int_value('execution.execution_thread_count', '4'), choices = facefusion.choices.execution_thread_count_range, metavar = create_int_metavar(facefusion.choices.execution_thread_count_range))
group_execution.add_argument('--execution-queue-count', help = wording.get('help.execution_queue_count'), type = int, default = config.get_int_value('execution.execution_queue_count', '1'), choices = facefusion.choices.execution_queue_count_range, metavar = create_int_metavar(facefusion.choices.execution_queue_count_range))
job_store.register_job_keys([ 'execution_device_id', 'execution_providers', 'execution_thread_count', 'execution_queue_count' ])
@@ -201,8 +202,9 @@ def create_execution_program() -> ArgumentParser:

def create_download_providers_program() -> ArgumentParser:
program = ArgumentParser(add_help = False)
download_providers = list(facefusion.choices.download_provider_set.keys())
group_download = program.add_argument_group('download')
group_download.add_argument('--download-providers', help = wording.get('help.download_providers').format(choices = ', '.join(list(facefusion.choices.download_provider_set.keys()))), default = config.get_str_list('download.download_providers', 'github'), choices = list(facefusion.choices.download_provider_set.keys()), nargs = '+', metavar = 'DOWNLOAD_PROVIDERS')
group_download.add_argument('--download-providers', help = wording.get('help.download_providers').format(choices = ', '.join(download_providers)), default = config.get_str_list('download.download_providers', 'github'), choices = download_providers, nargs = '+', metavar = 'DOWNLOAD_PROVIDERS')
job_store.register_job_keys([ 'download_providers' ])
return program

@@ -210,7 +212,7 @@ def create_download_providers_program() -> ArgumentParser:
def create_download_scope_program() -> ArgumentParser:
program = ArgumentParser(add_help = False)
group_download = program.add_argument_group('download')
group_download.add_argument('--download-scope', help = wording.get('help.download_scope'), default = config.get_str_value('download.download_scope', 'lite'), choices = list(facefusion.choices.download_scopes))
group_download.add_argument('--download-scope', help = wording.get('help.download_scope'), default = config.get_str_value('download.download_scope', 'lite'), choices = facefusion.choices.download_scopes)
job_store.register_job_keys([ 'download_scope' ])
return program

@@ -226,8 +228,9 @@ def create_memory_program() -> ArgumentParser:

def create_misc_program() -> ArgumentParser:
program = ArgumentParser(add_help = False)
log_level_keys = list(facefusion.choices.log_level_set.keys())
group_misc = program.add_argument_group('misc')
group_misc.add_argument('--log-level', help = wording.get('help.log_level'), default = config.get_str_value('misc.log_level', 'info'), choices = list(facefusion.choices.log_level_set.keys()))
group_misc.add_argument('--log-level', help = wording.get('help.log_level'), default = config.get_str_value('misc.log_level', 'info'), choices = log_level_keys)
job_store.register_job_keys([ 'log_level' ])
return program
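The argument groups keep their defaults, and the parser still restricts --execution-providers to what onnxruntime actually exposes. A sketch of the distinction between the declared provider keys in facefusion.choices and the machine-dependent available ones (output varies per machine):

import facefusion.choices

from facefusion.execution import get_available_execution_providers

print(facefusion.choices.execution_providers)   # every declared provider key
print(get_available_execution_providers())      # the subset reported by onnxruntime here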
@@ -112,9 +112,9 @@ ModelOptions = Dict[str, Any]
ModelSet = Dict[str, ModelOptions]
ModelInitializer = NDArray[Any]

ExecutionProviderKey = Literal['cpu', 'coreml', 'cuda', 'directml', 'openvino', 'rocm', 'tensorrt']
ExecutionProvider = Literal['cpu', 'coreml', 'cuda', 'directml', 'openvino', 'rocm', 'tensorrt']
ExecutionProviderValue = Literal['CPUExecutionProvider', 'CoreMLExecutionProvider', 'CUDAExecutionProvider', 'DmlExecutionProvider', 'OpenVINOExecutionProvider', 'ROCMExecutionProvider', 'TensorrtExecutionProvider']
ExecutionProviderSet = Dict[ExecutionProviderKey, ExecutionProviderValue]
ExecutionProviderSet = Dict[ExecutionProvider, ExecutionProviderValue]
ValueAndUnit = TypedDict('ValueAndUnit',
{
'value' : int,
@@ -155,8 +155,8 @@ ExecutionDevice = TypedDict('ExecutionDevice',
'utilization' : ExecutionDeviceUtilization
})

DownloadProviderKey = Literal['github', 'huggingface']
DownloadProviderSet = Dict[DownloadProviderKey, str]
DownloadProvider = Literal['github', 'huggingface']
DownloadProviderSet = Dict[DownloadProvider, str]
DownloadScope = Literal['lite', 'full']
Download = TypedDict('Download',
{
@@ -314,10 +314,10 @@ State = TypedDict('State',
'ui_layouts' : List[str],
'ui_workflow' : UiWorkflow,
'execution_device_id' : str,
'execution_providers' : List[ExecutionProviderKey],
'execution_providers' : List[ExecutionProvider],
'execution_thread_count' : int,
'execution_queue_count' : int,
'download_providers' : List[DownloadProviderKey],
'download_providers' : List[DownloadProvider],
'download_scope' : DownloadScope,
'video_memory_strategy' : VideoMemoryStrategy,
'system_memory_limit' : int,
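The renamed aliases remain plain Literal types, so annotations only swap the names. A minimal sketch:

from typing import List

from facefusion.typing import DownloadProvider, ExecutionProvider

execution_providers : List[ExecutionProvider] = [ 'cpu', 'cuda' ]
download_providers : List[DownloadProvider] = [ 'github', 'huggingface' ]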
@@ -2,11 +2,11 @@ from typing import List, Optional

import gradio

import facefusion.choices
from facefusion import content_analyser, face_classifier, face_detector, face_landmarker, face_masker, face_recognizer, state_manager, voice_extractor, wording
from facefusion.choices import download_provider_set
from facefusion.filesystem import list_directory
from facefusion.processors.core import get_processors_modules
from facefusion.typing import DownloadProviderKey
from facefusion.typing import DownloadProvider

DOWNLOAD_PROVIDERS_CHECKBOX_GROUP : Optional[gradio.CheckboxGroup] = None

@@ -16,7 +16,7 @@ def render() -> None:

DOWNLOAD_PROVIDERS_CHECKBOX_GROUP = gradio.CheckboxGroup(
label = wording.get('uis.download_providers_checkbox_group'),
choices = list(download_provider_set.keys()),
choices = facefusion.choices.download_providers,
value = state_manager.get_item('download_providers')
)

@@ -25,7 +25,7 @@ def listen() -> None:
DOWNLOAD_PROVIDERS_CHECKBOX_GROUP.change(update_download_providers, inputs = DOWNLOAD_PROVIDERS_CHECKBOX_GROUP, outputs = DOWNLOAD_PROVIDERS_CHECKBOX_GROUP)


def update_download_providers(download_providers : List[DownloadProviderKey]) -> gradio.CheckboxGroup:
def update_download_providers(download_providers : List[DownloadProvider]) -> gradio.CheckboxGroup:
common_modules =\
[
content_analyser,
@@ -43,6 +43,6 @@ def update_download_providers(download_providers : List[DownloadProviderKey]) ->
if hasattr(module, 'create_static_model_set'):
module.create_static_model_set.cache_clear()

download_providers = download_providers or list(download_provider_set.keys())
download_providers = download_providers or facefusion.choices.download_providers
state_manager.set_item('download_providers', download_providers)
return gradio.CheckboxGroup(value = state_manager.get_item('download_providers'))
@@ -3,10 +3,10 @@ from typing import List, Optional
import gradio

from facefusion import content_analyser, face_classifier, face_detector, face_landmarker, face_masker, face_recognizer, state_manager, voice_extractor, wording
from facefusion.execution import get_execution_provider_set
from facefusion.execution import get_available_execution_providers
from facefusion.filesystem import list_directory
from facefusion.processors.core import get_processors_modules
from facefusion.typing import ExecutionProviderKey
from facefusion.typing import ExecutionProvider

EXECUTION_PROVIDERS_CHECKBOX_GROUP : Optional[gradio.CheckboxGroup] = None

@@ -16,7 +16,7 @@ def render() -> None:

EXECUTION_PROVIDERS_CHECKBOX_GROUP = gradio.CheckboxGroup(
label = wording.get('uis.execution_providers_checkbox_group'),
choices = list(get_execution_provider_set().keys()),
choices = get_available_execution_providers(),
value = state_manager.get_item('execution_providers')
)

@@ -25,7 +25,7 @@ def listen() -> None:
EXECUTION_PROVIDERS_CHECKBOX_GROUP.change(update_execution_providers, inputs = EXECUTION_PROVIDERS_CHECKBOX_GROUP, outputs = EXECUTION_PROVIDERS_CHECKBOX_GROUP)


def update_execution_providers(execution_providers : List[ExecutionProviderKey]) -> gradio.CheckboxGroup:
def update_execution_providers(execution_providers : List[ExecutionProvider]) -> gradio.CheckboxGroup:
common_modules =\
[
content_analyser,
@@ -43,6 +43,6 @@ def update_execution_providers(execution_providers : List[ExecutionProviderKey])
if hasattr(module, 'clear_inference_pool'):
module.clear_inference_pool()

execution_providers = execution_providers or list(get_execution_provider_set())
execution_providers = execution_providers or get_available_execution_providers()
state_manager.set_item('execution_providers', execution_providers)
return gradio.CheckboxGroup(value = state_manager.get_item('execution_providers'))
@@ -31,7 +31,7 @@ def render() -> None:
with gradio.Row():
FACE_DETECTOR_MODEL_DROPDOWN = gradio.Dropdown(
label = wording.get('uis.face_detector_model_dropdown'),
choices = list(facefusion.choices.face_detector_set.keys()),
choices = facefusion.choices.face_detector_models,
value = state_manager.get_item('face_detector_model')
)
FACE_DETECTOR_SIZE_DROPDOWN = gradio.Dropdown(**face_detector_size_dropdown_options)
@@ -20,7 +20,7 @@ def render() -> None:
has_face_swapper = 'face_swapper' in state_manager.get_item('processors')
FACE_SWAPPER_MODEL_DROPDOWN = gradio.Dropdown(
label = wording.get('uis.face_swapper_model_dropdown'),
choices = list(processors_choices.face_swapper_set.keys()),
choices = processors_choices.face_swapper_models,
value = state_manager.get_item('face_swapper_model'),
visible = has_face_swapper
)
@@ -7,8 +7,8 @@ from typing import Optional
import gradio
from tqdm import tqdm

import facefusion.choices
from facefusion import logger, state_manager, wording
from facefusion.choices import log_level_set
from facefusion.typing import LogLevel

LOG_LEVEL_DROPDOWN : Optional[gradio.Dropdown] = None
@@ -24,7 +24,7 @@ def render() -> None:

LOG_LEVEL_DROPDOWN = gradio.Dropdown(
label = wording.get('uis.log_level_dropdown'),
choices = list(log_level_set.keys()),
choices = facefusion.choices.log_levels,
value = state_manager.get_item('log_level')
)
TERMINAL_TEXTBOX = gradio.Textbox(
@@ -5,7 +5,7 @@ import cv2
import numpy
from cv2.typing import Size

from facefusion.choices import image_template_sizes, video_template_sizes
import facefusion.choices
from facefusion.common_helper import is_windows
from facefusion.filesystem import is_image, is_video, sanitize_path_for_windows
from facefusion.typing import Duration, Fps, Orientation, Resolution, VisionFrame
@@ -64,8 +64,8 @@ def create_image_resolutions(resolution : Resolution) -> List[str]:
if resolution:
width, height = resolution
temp_resolutions.append(normalize_resolution(resolution))
for template_size in image_template_sizes:
temp_resolutions.append(normalize_resolution((width * template_size, height * template_size)))
for image_template_size in facefusion.choices.image_template_sizes:
temp_resolutions.append(normalize_resolution((width * image_template_size, height * image_template_size)))
temp_resolutions = sorted(set(temp_resolutions))
for temp_resolution in temp_resolutions:
resolutions.append(pack_resolution(temp_resolution))
@@ -155,11 +155,11 @@ def create_video_resolutions(resolution : Resolution) -> List[str]:
if resolution:
width, height = resolution
temp_resolutions.append(normalize_resolution(resolution))
for template_size in video_template_sizes:
for video_template_size in facefusion.choices.video_template_sizes:
if width > height:
temp_resolutions.append(normalize_resolution((template_size * width / height, template_size)))
temp_resolutions.append(normalize_resolution((video_template_size * width / height, video_template_size)))
else:
temp_resolutions.append(normalize_resolution((template_size, template_size * height / width)))
temp_resolutions.append(normalize_resolution((video_template_size, video_template_size * height / width)))
temp_resolutions = sorted(set(temp_resolutions))
for temp_resolution in temp_resolutions:
resolutions.append(pack_resolution(temp_resolution))
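Only the loop variables are renamed here. For orientation, a worked instance of the scaling the video loop performs, assuming 720 is one of facefusion.choices.video_template_sizes (the actual values live in choices.py and are not part of this diff):

width, height = 1920, 1080
video_template_size = 720

if width > height:
    temp_resolution = (video_template_size * width / height, video_template_size)
    # (1280.0, 720) before normalize_resolution and pack_resolution are applied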
@@ -1,8 +1,8 @@
from facefusion.execution import create_execution_providers, get_execution_provider_set, has_execution_provider
from facefusion.execution import create_inference_execution_providers, get_available_execution_providers, has_execution_provider


def test_get_execution_provider_set() -> None:
assert 'cpu' in get_execution_provider_set().keys()
assert 'cpu' in get_available_execution_providers()


def test_has_execution_provider() -> None:
@@ -20,4 +20,4 @@ def test_multiple_execution_providers() -> None:
'CPUExecutionProvider'
]

assert create_execution_providers('1', [ 'cpu', 'cuda' ]) == execution_providers
assert create_inference_execution_providers('1', [ 'cpu', 'cuda' ]) == execution_providers