Move clear over to the UI (#825)

This commit is contained in:
Henry Ruhs 2024-11-21 11:02:26 +01:00 committed by GitHub
parent 48440407e2
commit b4f1a0e083
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 100 additions and 95 deletions

View File

@ -101,7 +101,7 @@ def pre_check() -> bool:
def common_pre_check() -> bool:
modules =\
common_modules =\
[
content_analyser,
face_classifier,
@ -112,7 +112,7 @@ def common_pre_check() -> bool:
voice_extractor
]
return all(module.pre_check() for module in modules)
return all(module.pre_check() for module in common_modules)
def processors_pre_check() -> bool:
@ -122,77 +122,23 @@ def processors_pre_check() -> bool:
return True
def conditional_process() -> ErrorCode:
start_time = time()
for processor_module in get_processors_modules(state_manager.get_item('processors')):
if not processor_module.pre_process('output'):
return 2
conditional_append_reference_faces()
if is_image(state_manager.get_item('target_path')):
return process_image(start_time)
if is_video(state_manager.get_item('target_path')):
return process_video(start_time)
return 0
def conditional_append_reference_faces() -> None:
if 'reference' in state_manager.get_item('face_selector_mode') and not get_reference_faces():
source_frames = read_static_images(state_manager.get_item('source_paths'))
source_faces = get_many_faces(source_frames)
source_face = get_average_face(source_faces)
if is_video(state_manager.get_item('target_path')):
reference_frame = get_video_frame(state_manager.get_item('target_path'), state_manager.get_item('reference_frame_number'))
else:
reference_frame = read_image(state_manager.get_item('target_path'))
reference_faces = sort_and_filter_faces(get_many_faces([ reference_frame ]))
reference_face = get_one_face(reference_faces, state_manager.get_item('reference_face_position'))
append_reference_face('origin', reference_face)
if source_face and reference_face:
for processor_module in get_processors_modules(state_manager.get_item('processors')):
abstract_reference_frame = processor_module.get_reference_frame(source_face, reference_face, reference_frame)
if numpy.any(abstract_reference_frame):
abstract_reference_faces = sort_and_filter_faces(get_many_faces([ abstract_reference_frame ]))
abstract_reference_face = get_one_face(abstract_reference_faces, state_manager.get_item('reference_face_position'))
append_reference_face(processor_module.__name__, abstract_reference_face)
def clear_model_sets() -> None:
available_processors = list_directory('facefusion/processors/modules')
def force_download() -> ErrorCode:
common_modules =\
[
content_analyser,
face_classifier,
face_detector,
face_landmarker,
face_recognizer,
face_masker,
face_recognizer,
voice_extractor
]
available_processors = list_directory('facefusion/processors/modules')
processor_modules = get_processors_modules(available_processors)
for module in common_modules + processor_modules:
if hasattr(module, 'create_static_model_set'):
module.create_static_model_set.cache_clear()
def force_download() -> ErrorCode:
available_processors = list_directory('facefusion/processors/modules')
common_modules =\
[
content_analyser,
face_classifier,
face_detector,
face_landmarker,
face_recognizer,
face_masker,
voice_extractor
]
processor_modules = get_processors_modules(available_processors)
for module in common_modules + processor_modules:
if hasattr(module, 'create_model_set'):
for model in module.create_model_set().values():
for model in module.create_static_model_set().values():
model_hashes = model.get('hashes')
model_sources = model.get('sources')
@ -306,19 +252,6 @@ def route_job_runner() -> ErrorCode:
return 2
def process_step(job_id : str, step_index : int, step_args : Args) -> bool:
clear_reference_faces()
step_total = job_manager.count_step_total(job_id)
step_args.update(collect_job_args())
apply_args(step_args, state_manager.set_item)
logger.info(wording.get('processing_step').format(step_current = step_index + 1, step_total = step_total), __name__)
if common_pre_check() and processors_pre_check():
error_code = conditional_process()
return error_code == 0
return False
def process_headless(args : Args) -> ErrorCode:
job_id = job_helper.suggest_job_id('headless')
step_args = reduce_step_args(args)
@ -357,6 +290,54 @@ def process_batch(args : Args) -> ErrorCode:
return 1
def process_step(job_id : str, step_index : int, step_args : Args) -> bool:
clear_reference_faces()
step_total = job_manager.count_step_total(job_id)
step_args.update(collect_job_args())
apply_args(step_args, state_manager.set_item)
logger.info(wording.get('processing_step').format(step_current = step_index + 1, step_total = step_total), __name__)
if common_pre_check() and processors_pre_check():
error_code = conditional_process()
return error_code == 0
return False
def conditional_process() -> ErrorCode:
start_time = time()
for processor_module in get_processors_modules(state_manager.get_item('processors')):
if not processor_module.pre_process('output'):
return 2
conditional_append_reference_faces()
if is_image(state_manager.get_item('target_path')):
return process_image(start_time)
if is_video(state_manager.get_item('target_path')):
return process_video(start_time)
return 0
def conditional_append_reference_faces() -> None:
if 'reference' in state_manager.get_item('face_selector_mode') and not get_reference_faces():
source_frames = read_static_images(state_manager.get_item('source_paths'))
source_faces = get_many_faces(source_frames)
source_face = get_average_face(source_faces)
if is_video(state_manager.get_item('target_path')):
reference_frame = get_video_frame(state_manager.get_item('target_path'), state_manager.get_item('reference_frame_number'))
else:
reference_frame = read_image(state_manager.get_item('target_path'))
reference_faces = sort_and_filter_faces(get_many_faces([ reference_frame ]))
reference_face = get_one_face(reference_faces, state_manager.get_item('reference_face_position'))
append_reference_face('origin', reference_face)
if source_face and reference_face:
for processor_module in get_processors_modules(state_manager.get_item('processors')):
abstract_reference_frame = processor_module.get_reference_frame(source_face, reference_face, reference_frame)
if numpy.any(abstract_reference_frame):
abstract_reference_faces = sort_and_filter_faces(get_many_faces([ abstract_reference_frame ]))
abstract_reference_face = get_one_face(abstract_reference_faces, state_manager.get_item('reference_face_position'))
append_reference_face(processor_module.__name__, abstract_reference_face)
def process_image(start_time : float) -> ErrorCode:
if analyse_image(state_manager.get_item('target_path')):
return 3

View File

@ -53,12 +53,6 @@ def get_processors_modules(processors : List[str]) -> List[ModuleType]:
return processor_modules
def clear_processors_modules(processors : List[str]) -> None:
for processor in processors:
processor_module = load_processor_module(processor)
processor_module.clear_inference_pool()
def multi_process_frames(source_paths : List[str], temp_frame_paths : List[str], process_frames : ProcessFrames) -> None:
queue_payloads = create_queue_payloads(temp_frame_paths)
with tqdm(total = len(queue_payloads), desc = wording.get('processing'), unit = 'frame', ascii = ' =', disable = state_manager.get_item('log_level') in [ 'warn', 'error' ]) as progress:

View File

@ -2,9 +2,10 @@ from typing import List, Optional
import gradio
from facefusion import state_manager, wording
from facefusion import content_analyser, face_classifier, face_detector, face_landmarker, face_masker, face_recognizer, state_manager, voice_extractor, wording
from facefusion.choices import download_provider_set
from facefusion.core import clear_model_sets
from facefusion.filesystem import list_directory
from facefusion.processors.core import get_processors_modules
from facefusion.typing import DownloadProviderKey
DOWNLOAD_PROVIDERS_CHECKBOX_GROUP : Optional[gradio.CheckboxGroup] = None
@ -25,7 +26,23 @@ def listen() -> None:
def update_download_providers(download_providers : List[DownloadProviderKey]) -> gradio.CheckboxGroup:
clear_model_sets()
common_modules =\
[
content_analyser,
face_classifier,
face_detector,
face_landmarker,
face_recognizer,
face_masker,
voice_extractor
]
available_processors = list_directory('facefusion/processors/modules')
processor_modules = get_processors_modules(available_processors)
for module in common_modules + processor_modules:
if hasattr(module, 'create_static_model_set'):
module.create_static_model_set.cache_clear()
download_providers = download_providers or list(download_provider_set.keys())
state_manager.set_item('download_providers', download_providers)
return gradio.CheckboxGroup(value = state_manager.get_item('download_providers'))

View File

@ -4,7 +4,8 @@ import gradio
from facefusion import content_analyser, face_classifier, face_detector, face_landmarker, face_masker, face_recognizer, state_manager, voice_extractor, wording
from facefusion.execution import get_execution_provider_set
from facefusion.processors.core import clear_processors_modules
from facefusion.filesystem import list_directory
from facefusion.processors.core import get_processors_modules
from facefusion.typing import ExecutionProviderKey
EXECUTION_PROVIDERS_CHECKBOX_GROUP : Optional[gradio.CheckboxGroup] = None
@ -25,14 +26,23 @@ def listen() -> None:
def update_execution_providers(execution_providers : List[ExecutionProviderKey]) -> gradio.CheckboxGroup:
content_analyser.clear_inference_pool()
face_classifier.clear_inference_pool()
face_detector.clear_inference_pool()
face_landmarker.clear_inference_pool()
face_masker.clear_inference_pool()
face_recognizer.clear_inference_pool()
voice_extractor.clear_inference_pool()
clear_processors_modules(state_manager.get_item('processors'))
common_modules =\
[
content_analyser,
face_classifier,
face_detector,
face_landmarker,
face_masker,
face_recognizer,
voice_extractor
]
available_processors = list_directory('facefusion/processors/modules')
processor_modules = get_processors_modules(available_processors)
for module in common_modules + processor_modules:
if hasattr(module, 'clear_inference_pool'):
module.clear_inference_pool()
execution_providers = execution_providers or list(get_execution_provider_set())
state_manager.set_item('execution_providers', execution_providers)
return gradio.CheckboxGroup(value = state_manager.get_item('execution_providers'))

View File

@ -4,7 +4,7 @@ import gradio
from facefusion import state_manager, wording
from facefusion.filesystem import list_directory
from facefusion.processors.core import clear_processors_modules, get_processors_modules
from facefusion.processors.core import get_processors_modules
from facefusion.uis.core import register_ui_component
PROCESSORS_CHECKBOX_GROUP : Optional[gradio.CheckboxGroup] = None
@ -26,12 +26,15 @@ def listen() -> None:
def update_processors(processors : List[str]) -> gradio.CheckboxGroup:
clear_processors_modules(state_manager.get_item('processors'))
state_manager.set_item('processors', processors)
for processor_module in get_processors_modules(state_manager.get_item('processors')):
if hasattr(processor_module, 'clear_inference_pool'):
processor_module.clear_inference_pool()
for processor_module in get_processors_modules(processors):
if not processor_module.pre_check():
return gradio.CheckboxGroup()
state_manager.set_item('processors', processors)
return gradio.CheckboxGroup(value = state_manager.get_item('processors'), choices = sort_processors(state_manager.get_item('processors')))