Kill resolve_execution_provider_keys() and move fallbacks where they belong
This commit is contained in:
parent
2f98ac8471
commit
bae8c65cf0
@ -5,7 +5,7 @@ from onnxruntime import InferenceSession
|
||||
|
||||
from facefusion import process_manager, state_manager
|
||||
from facefusion.app_context import detect_app_context
|
||||
from facefusion.execution import create_execution_providers, has_execution_provider
|
||||
from facefusion.execution import create_execution_providers
|
||||
from facefusion.thread_helper import thread_lock
|
||||
from facefusion.typing import DownloadSet, ExecutionProviderKey, InferencePool, InferencePoolSet
|
||||
|
||||
@ -30,8 +30,7 @@ def get_inference_pool(model_context : str, model_sources : DownloadSet) -> Infe
|
||||
if app_context == 'ui' and INFERENCE_POOLS.get('cli').get(inference_context):
|
||||
INFERENCE_POOLS['ui'][inference_context] = INFERENCE_POOLS.get('cli').get(inference_context)
|
||||
if not INFERENCE_POOLS.get(app_context).get(inference_context):
|
||||
execution_provider_keys = resolve_execution_provider_keys(model_context)
|
||||
INFERENCE_POOLS[app_context][inference_context] = create_inference_pool(model_sources, state_manager.get_item('execution_device_id'), execution_provider_keys)
|
||||
INFERENCE_POOLS[app_context][inference_context] = create_inference_pool(model_sources, state_manager.get_item('execution_device_id'), state_manager.get_item('execution_providers'))
|
||||
|
||||
return INFERENCE_POOLS.get(app_context).get(inference_context)
|
||||
|
||||
@ -59,13 +58,6 @@ def create_inference_session(model_path : str, execution_device_id : str, execut
|
||||
return InferenceSession(model_path, providers = execution_providers)
|
||||
|
||||
|
||||
def resolve_execution_provider_keys(model_context : str) -> List[ExecutionProviderKey]:
|
||||
if has_execution_provider('coreml') and (model_context.startswith('facefusion.processors.modules.age_modifier') or model_context.startswith('facefusion.processors.modules.frame_colorizer')):
|
||||
return [ 'cpu' ]
|
||||
return state_manager.get_item('execution_providers')
|
||||
|
||||
|
||||
def get_inference_context(model_context : str) -> str:
|
||||
execution_provider_keys = resolve_execution_provider_keys(model_context)
|
||||
inference_context = model_context + '.' + '_'.join(execution_provider_keys)
|
||||
inference_context = model_context + '.' + '_'.join(state_manager.get_item('execution_providers'))
|
||||
return inference_context
|
||||
|
@ -9,8 +9,10 @@ import facefusion.jobs.job_manager
|
||||
import facefusion.jobs.job_store
|
||||
import facefusion.processors.core as processors
|
||||
from facefusion import config, content_analyser, face_classifier, face_detector, face_landmarker, face_masker, face_recognizer, inference_manager, logger, process_manager, state_manager, wording
|
||||
from facefusion.choices import execution_provider_set
|
||||
from facefusion.common_helper import create_int_metavar
|
||||
from facefusion.download import conditional_download_hashes, conditional_download_sources, resolve_download_url
|
||||
from facefusion.execution import has_execution_provider
|
||||
from facefusion.face_analyser import get_many_faces, get_one_face
|
||||
from facefusion.face_helper import merge_matrix, paste_back, scale_face_landmark_5, warp_face_by_face_landmark_5
|
||||
from facefusion.face_masker import create_occlusion_mask, create_static_box_mask
|
||||
@ -160,6 +162,9 @@ def forward(crop_vision_frame : VisionFrame, extend_vision_frame : VisionFrame,
|
||||
age_modifier = get_inference_pool().get('age_modifier')
|
||||
age_modifier_inputs = {}
|
||||
|
||||
if has_execution_provider('coreml'):
|
||||
age_modifier.set_providers(execution_provider_set.get('cpu'))
|
||||
|
||||
for age_modifier_input in age_modifier.get_inputs():
|
||||
if age_modifier_input.name == 'target':
|
||||
age_modifier_inputs[age_modifier_input.name] = crop_vision_frame
|
||||
|
@ -8,6 +8,7 @@ import facefusion.jobs.job_manager
|
||||
import facefusion.jobs.job_store
|
||||
import facefusion.processors.core as processors
|
||||
from facefusion import config, content_analyser, face_classifier, face_detector, face_landmarker, face_masker, face_recognizer, inference_manager, logger, process_manager, state_manager, wording
|
||||
from facefusion.choices import execution_provider_set
|
||||
from facefusion.common_helper import get_first
|
||||
from facefusion.download import conditional_download_hashes, conditional_download_sources, resolve_download_url
|
||||
from facefusion.execution import has_execution_provider
|
||||
@ -347,7 +348,6 @@ def clear_inference_pool() -> None:
|
||||
|
||||
def get_model_options() -> ModelOptions:
|
||||
face_swapper_model = state_manager.get_item('face_swapper_model')
|
||||
face_swapper_model = 'inswapper_128' if has_execution_provider('coreml') and face_swapper_model == 'inswapper_128_fp16' else face_swapper_model
|
||||
return create_static_model_set().get(face_swapper_model)
|
||||
|
||||
|
||||
@ -448,6 +448,9 @@ def forward_swap_face(source_face : Face, crop_vision_frame : VisionFrame) -> Vi
|
||||
model_type = get_model_options().get('type')
|
||||
face_swapper_inputs = {}
|
||||
|
||||
if has_execution_provider('coreml') and model_type in [ 'ghost', 'uniface' ]:
|
||||
face_swapper.set_providers(execution_provider_set.get('cpu'))
|
||||
|
||||
for face_swapper_input in face_swapper.get_inputs():
|
||||
if face_swapper_input.name == 'source':
|
||||
if model_type in [ 'blendswap', 'uniface' ]:
|
||||
|
@ -9,8 +9,10 @@ import facefusion.jobs.job_manager
|
||||
import facefusion.jobs.job_store
|
||||
import facefusion.processors.core as processors
|
||||
from facefusion import config, content_analyser, inference_manager, logger, process_manager, state_manager, wording
|
||||
from facefusion.choices import execution_provider_set
|
||||
from facefusion.common_helper import create_int_metavar
|
||||
from facefusion.download import conditional_download_hashes, conditional_download_sources, resolve_download_url
|
||||
from facefusion.execution import has_execution_provider
|
||||
from facefusion.filesystem import in_directory, is_image, is_video, resolve_relative_path, same_file_extension
|
||||
from facefusion.processors import choices as processors_choices
|
||||
from facefusion.processors.typing import FrameColorizerInputs
|
||||
@ -197,6 +199,9 @@ def colorize_frame(temp_vision_frame : VisionFrame) -> VisionFrame:
|
||||
def forward(color_vision_frame : VisionFrame) -> VisionFrame:
|
||||
frame_colorizer = get_inference_pool().get('frame_colorizer')
|
||||
|
||||
if has_execution_provider('coreml'):
|
||||
frame_colorizer.set_providers(execution_provider_set.get('cpu'))
|
||||
|
||||
with thread_semaphore():
|
||||
color_vision_frame = frame_colorizer.run(None,
|
||||
{
|
||||
|
Loading…
Reference in New Issue
Block a user