This commit is contained in:
harisreedhar 2024-11-06 22:19:42 +05:30 committed by henryruhs
parent feaaf6c028
commit ca9ccbfd35
2 changed files with 16 additions and 10 deletions

View File

@ -33,17 +33,17 @@ def calc_face_distance(face : Face, reference_face : Face) -> float:
def sort_and_filter_faces(faces : List[Face]) -> List[Face]:
if faces:
if state_manager.get_item('face_selector_order'):
faces = sort_by_order(faces, state_manager.get_item('face_selector_order'))
faces = sort_faces_by_order(faces, state_manager.get_item('face_selector_order'))
if state_manager.get_item('face_selector_gender'):
faces = filter_by_gender(faces, state_manager.get_item('face_selector_gender'))
faces = filter_faces_by_gender(faces, state_manager.get_item('face_selector_gender'))
if state_manager.get_item('face_selector_race'):
faces = filter_by_race(faces, state_manager.get_item('face_selector_race'))
faces = filter_faces_by_race(faces, state_manager.get_item('face_selector_race'))
if state_manager.get_item('face_selector_age_start') or state_manager.get_item('face_selector_age_end'):
faces = filter_by_age(faces, state_manager.get_item('face_selector_age_start'), state_manager.get_item('face_selector_age_end'))
faces = filter_faces_by_age(faces, state_manager.get_item('face_selector_age_start'), state_manager.get_item('face_selector_age_end'))
return faces
def sort_by_order(faces : List[Face], order : FaceSelectorOrder) -> List[Face]:
def sort_faces_by_order(faces : List[Face], order : FaceSelectorOrder) -> List[Face]:
if order == 'left-right':
return sorted(faces, key = lambda face: face.bounding_box[0])
if order == 'right-left':
@ -63,7 +63,7 @@ def sort_by_order(faces : List[Face], order : FaceSelectorOrder) -> List[Face]:
return faces
def filter_by_gender(faces : List[Face], gender : Gender) -> List[Face]:
def filter_faces_by_gender(faces : List[Face], gender : Gender) -> List[Face]:
filter_faces = []
for face in faces:
@ -72,7 +72,7 @@ def filter_by_gender(faces : List[Face], gender : Gender) -> List[Face]:
return filter_faces
def filter_by_age(faces : List[Face], face_selector_age_start : int, face_selector_age_end : int) -> List[Face]:
def filter_faces_by_age(faces : List[Face], face_selector_age_start : int, face_selector_age_end : int) -> List[Face]:
filter_faces = []
age = range(face_selector_age_start, face_selector_age_end)
@ -82,7 +82,7 @@ def filter_by_age(faces : List[Face], face_selector_age_start : int, face_select
return filter_faces
def filter_by_race(faces : List[Face], race : Race) -> List[Face]:
def filter_faces_by_race(faces : List[Face], race : Race) -> List[Face]:
filter_faces = []
for face in faces:

View File

@ -13,7 +13,7 @@ from facefusion.execution import has_execution_provider
from facefusion.face_analyser import get_average_face, get_many_faces, get_one_face
from facefusion.face_helper import paste_back, warp_face_by_face_landmark_5
from facefusion.face_masker import create_occlusion_mask, create_region_mask, create_static_box_mask
from facefusion.face_selector import find_similar_faces, sort_and_filter_faces
from facefusion.face_selector import find_similar_faces, sort_and_filter_faces, sort_faces_by_order
from facefusion.face_store import get_reference_faces
from facefusion.filesystem import filter_image_paths, has_image, in_directory, is_image, is_video, resolve_relative_path, same_file_extension
from facefusion.inference_manager import get_static_model_initializer
@ -564,7 +564,13 @@ def process_frame(inputs : FaceSwapperInputs) -> VisionFrame:
def process_frames(source_paths : List[str], queue_payloads : List[QueuePayload], update_progress : UpdateProgress) -> None:
reference_faces = get_reference_faces() if 'reference' in state_manager.get_item('face_selector_mode') else None
source_frames = read_static_images(source_paths)
source_faces = get_many_faces(source_frames)
source_faces = []
for source_frame in source_frames:
temp_faces = get_many_faces([ source_frame ])
temp_faces = sort_faces_by_order(temp_faces, 'large-small')
if temp_faces:
source_faces.append(get_first(temp_faces))
source_face = get_average_face(source_faces)
for queue_payload in process_manager.manage(queue_payloads):