diff --git a/facefusion/inference_manager.py b/facefusion/inference_manager.py index 76bb4556..91da50c5 100644 --- a/facefusion/inference_manager.py +++ b/facefusion/inference_manager.py @@ -27,6 +27,10 @@ def get_inference_pool(model_context : str, model_sources : DownloadSet) -> Infe app_context = detect_app_context() inference_context = get_inference_context(model_context) + if app_context == 'cli' and INFERENCE_POOLS.get('ui').get(inference_context): + INFERENCE_POOLS['cli'][inference_context] = INFERENCE_POOLS.get('ui').get(inference_context) + if app_context == 'ui' and INFERENCE_POOLS.get('cli').get(inference_context): + INFERENCE_POOLS['ui'][inference_context] = INFERENCE_POOLS.get('cli').get(inference_context) if not INFERENCE_POOLS.get(app_context).get(inference_context): execution_provider_keys = resolve_execution_provider_keys(model_context) INFERENCE_POOLS[app_context][inference_context] = create_inference_pool(model_sources, state_manager.get_item('execution_device_id'), execution_provider_keys) diff --git a/tests/test_inference_pool.py b/tests/test_inference_pool.py index 70352bb8..563f1df0 100644 --- a/tests/test_inference_pool.py +++ b/tests/test_inference_pool.py @@ -27,3 +27,4 @@ def test_get_inference_pool() -> None: assert isinstance(INFERENCE_POOLS.get('ui').get('test.cpu').get('content_analyser'), InferenceSession) + assert INFERENCE_POOLS.get('cli').get('test.cpu').get('content_analyser') == INFERENCE_POOLS.get('ui').get('test.cpu').get('content_analyser')