Local model support for deep swapper

This commit is contained in:
henryruhs 2024-12-09 23:23:19 +01:00
parent cc5ee3be33
commit 341545ab94
2 changed files with 38 additions and 4 deletions

View File

@ -1,6 +1,7 @@
from typing import List, Sequence
from facefusion.common_helper import create_float_range, create_int_range
from facefusion.filesystem import list_directory, resolve_relative_path
from facefusion.processors.typing import AgeModifierModel, DeepSwapperModel, ExpressionRestorerModel, FaceDebuggerItem, FaceEditorModel, FaceEnhancerModel, FaceSwapperSet, FrameColorizerModel, FrameEnhancerModel, LipSyncerModel
age_modifier_models : List[AgeModifierModel] = [ 'styleganex_age' ]
@ -153,6 +154,13 @@ deep_swapper_models : List[DeepSwapperModel] =\
'rumateus/sophie_turner_224',
'rumateus/taylor_swift_224'
]
model_files = list_directory(resolve_relative_path('../.assets/models/local'))
for model_file in model_files:
model_id = '/'.join(['local', model_file.get('name') ])
deep_swapper_models.append(model_id)
expression_restorer_models : List[ExpressionRestorerModel] = [ 'live_portrait' ]
face_debugger_items : List[FaceDebuggerItem] = [ 'bounding-box', 'face-landmark-5', 'face-landmark-5/68', 'face-landmark-68', 'face-landmark-68/5', 'face-mask', 'face-detector-score', 'face-landmarker-score', 'age', 'gender', 'race' ]
face_editor_models : List[FaceEditorModel] = [ 'live_portrait' ]

View File

@ -4,6 +4,7 @@ from typing import List, Tuple
import cv2
import numpy
from cv2.typing import Size
import facefusion.jobs.job_manager
import facefusion.jobs.job_store
@ -16,7 +17,7 @@ from facefusion.face_helper import paste_back, warp_face_by_face_landmark_5
from facefusion.face_masker import create_occlusion_mask, create_static_box_mask
from facefusion.face_selector import find_similar_faces, sort_and_filter_faces
from facefusion.face_store import get_reference_faces
from facefusion.filesystem import in_directory, is_image, is_video, resolve_relative_path, same_file_extension
from facefusion.filesystem import in_directory, is_image, is_video, list_directory, resolve_relative_path, same_file_extension
from facefusion.processors import choices as processors_choices
from facefusion.processors.typing import DeepSwapperInputs, DeepSwapperMorph
from facefusion.program_helper import find_argument_group
@ -214,6 +215,23 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
'size': model_size
}
model_files = list_directory(resolve_relative_path('../.assets/models/local'))
if model_files:
for model_file in model_files:
model_id = '/'.join([ 'local', model_file.get('name') ])
model_set[model_id] =\
{
'sources':
{
'deep_swapper':
{
'path': resolve_relative_path(model_file.get('path'))
}
},
'template': 'dfl_whole_face'
}
return model_set
@ -233,6 +251,12 @@ def get_model_options() -> ModelOptions:
return create_static_model_set('full').get(deep_swapper_model)
def get_model_size() -> Size:
deep_swapper = get_inference_pool().get('deep_swapper')
model_size = deep_swapper.get_outputs()[-1].shape[1:3]
return model_size
def register_args(program : ArgumentParser) -> None:
group_processors = find_argument_group(program, 'processors')
if group_processors:
@ -250,7 +274,9 @@ def pre_check() -> bool:
model_hashes = get_model_options().get('hashes')
model_sources = get_model_options().get('sources')
return conditional_download_hashes(model_hashes) and conditional_download_sources(model_sources)
if model_hashes and model_sources:
return conditional_download_hashes(model_hashes) and conditional_download_sources(model_sources)
return True
def pre_process(mode : ProcessMode) -> bool:
@ -281,7 +307,7 @@ def post_process() -> None:
def swap_face(target_face : Face, temp_vision_frame : VisionFrame) -> VisionFrame:
model_template = get_model_options().get('template')
model_size = get_model_options().get('size')
model_size = get_model_options().get('size') or get_model_size()
crop_vision_frame, affine_matrix = warp_face_by_face_landmark_5(temp_vision_frame, target_face.landmark_set.get('5/68'), model_template, model_size)
crop_vision_frame_raw = crop_vision_frame.copy()
box_mask = create_static_box_mask(crop_vision_frame.shape[:2][::-1], state_manager.get_item('face_mask_blur'), state_manager.get_item('face_mask_padding'))
@ -344,7 +370,7 @@ def normalize_crop_frame(crop_vision_frame : VisionFrame) -> VisionFrame:
def prepare_crop_mask(crop_source_mask : Mask, crop_target_mask : Mask) -> Mask:
model_size = get_model_options().get('size')
model_size = get_model_options().get('size') or get_model_size()
blur_size = 6.25
kernel_size = 3
crop_mask = numpy.minimum.reduce([ crop_source_mask, crop_target_mask ])