Simswap model (#171)

* Add simswap models

* Add ghost models

* Introduce normed template

* Conditional prepare and normalize for ghost

* Conditional prepare and normalize for ghost

* Get simswap working

* Get simswap working
This commit is contained in:
Henry Ruhs 2023-10-26 16:59:56 +02:00 committed by GitHub
parent 2209109c8f
commit 22c0de3fe6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 56 additions and 27 deletions

View File

@ -10,17 +10,23 @@ from facefusion.face_helper import warp_face
from facefusion.typing import Frame, Face, FaceAnalyserDirection, FaceAnalyserAge, FaceAnalyserGender, ModelValue, Kps, Embedding from facefusion.typing import Frame, Face, FaceAnalyserDirection, FaceAnalyserAge, FaceAnalyserGender, ModelValue, Kps, Embedding
from facefusion.utilities import resolve_relative_path, conditional_download from facefusion.utilities import resolve_relative_path, conditional_download
from facefusion.vision import resize_frame_dimension from facefusion.vision import resize_frame_dimension
from facefusion.processors.frame import globals as frame_processors_globals
FACE_ANALYSER = None FACE_ANALYSER = None
THREAD_SEMAPHORE : threading.Semaphore = threading.Semaphore() THREAD_SEMAPHORE : threading.Semaphore = threading.Semaphore()
THREAD_LOCK : threading.Lock = threading.Lock() THREAD_LOCK : threading.Lock = threading.Lock()
MODELS : Dict[str, ModelValue] =\ MODELS : Dict[str, ModelValue] =\
{ {
'face_recognition_arcface': 'face_recognition_arcface_inswapper':
{ {
'url': 'https://huggingface.co/bluefoxcreation/insightface-retinaface-arcface-model/resolve/main/w600k_r50.onnx', 'url': 'https://huggingface.co/bluefoxcreation/insightface-retinaface-arcface-model/resolve/main/w600k_r50.onnx',
'path': resolve_relative_path('../.assets/models/w600k_r50.onnx') 'path': resolve_relative_path('../.assets/models/w600k_r50.onnx')
}, },
'face_recognition_arcface_simswap':
{
'url': 'https://github.com/harisreedhar/Face-Swappers-ONNX/releases/download/simswap/simswap_arcface_backbone.onnx',
'path': resolve_relative_path('../.assets/models/simswap_arcface_backbone.onnx')
},
'face_detection_yunet': 'face_detection_yunet':
{ {
'url': 'https://github.com/opencv/opencv_zoo/raw/main/models/face_detection_yunet/face_detection_yunet_2023mar.onnx', 'url': 'https://github.com/opencv/opencv_zoo/raw/main/models/face_detection_yunet/face_detection_yunet_2023mar.onnx',
@ -39,11 +45,15 @@ def get_face_analyser() -> Any:
with THREAD_LOCK: with THREAD_LOCK:
if FACE_ANALYSER is None: if FACE_ANALYSER is None:
if frame_processors_globals.face_swapper_model == 'inswapper_128' or frame_processors_globals.face_swapper_model == 'inswapper_128_fp16':
face_recognition_model_path = MODELS.get('face_recognition_arcface_inswapper').get('path')
if frame_processors_globals.face_swapper_model == 'simswap_244':
face_recognition_model_path = MODELS.get('face_recognition_arcface_simswap').get('path')
FACE_ANALYSER =\ FACE_ANALYSER =\
{ {
'face_detector': cv2.FaceDetectorYN.create(MODELS.get('face_detection_yunet').get('path'), None, (0, 0)), 'face_detector': cv2.FaceDetectorYN.create(MODELS.get('face_detection_yunet').get('path'), None, (0, 0)),
'face_recognition': onnxruntime.InferenceSession(MODELS.get('face_recognition_arcface').get('path'), providers = facefusion.globals.execution_providers), 'face_recognition': onnxruntime.InferenceSession(face_recognition_model_path, providers = facefusion.globals.execution_providers),
'gender_age': onnxruntime.InferenceSession(MODELS.get('gender_age').get('path'), providers = facefusion.globals.execution_providers), 'gender_age': onnxruntime.InferenceSession(MODELS.get('gender_age').get('path'), providers = facefusion.globals.execution_providers)
} }
return FACE_ANALYSER return FACE_ANALYSER
@ -57,7 +67,13 @@ def clear_face_analyser() -> Any:
def pre_check() -> bool: def pre_check() -> bool:
if not facefusion.globals.skip_download: if not facefusion.globals.skip_download:
download_directory_path = resolve_relative_path('../.assets/models') download_directory_path = resolve_relative_path('../.assets/models')
model_urls = [ MODELS.get('face_recognition_arcface').get('url'), MODELS.get('face_detection_yunet').get('url'), MODELS.get('gender_age').get('url') ] model_urls =\
[
MODELS.get('face_recognition_arcface_inswapper').get('url'),
MODELS.get('face_recognition_arcface_simswap').get('url'),
MODELS.get('face_detection_yunet').get('url'),
MODELS.get('gender_age').get('url')
]
conditional_download(download_directory_path, model_urls) conditional_download(download_directory_path, model_urls)
return True return True

View File

@ -28,8 +28,9 @@ TEMPLATES : Dict[Template, numpy.ndarray[Any, Any]] =\
def warp_face(temp_frame : Frame, kps : Kps, template : Template, size : Size) -> Tuple[Frame, Matrix]: def warp_face(temp_frame : Frame, kps : Kps, template : Template, size : Size) -> Tuple[Frame, Matrix]:
affine_matrix = cv2.estimateAffinePartial2D(kps, TEMPLATES[template], method = cv2.LMEDS)[0] normed_template = TEMPLATES.get(template) * size[1] / size[0]
crop_frame = cv2.warpAffine(temp_frame, affine_matrix, size) affine_matrix = cv2.estimateAffinePartial2D(kps, normed_template, method = cv2.LMEDS)[0]
crop_frame = cv2.warpAffine(temp_frame, affine_matrix, (size[1], size[1]))
return crop_frame, affine_matrix return crop_frame, affine_matrix

View File

@ -1,5 +1,5 @@
from typing import List from typing import List
face_swapper_models : List[str] = [ 'inswapper_128', 'inswapper_128_fp16' ] face_swapper_models : List[str] = [ 'inswapper_128', 'inswapper_128_fp16', 'simswap_244' ]
face_enhancer_models : List[str] = [ 'codeformer', 'gfpgan_1.2', 'gfpgan_1.3', 'gfpgan_1.4', 'gpen_bfr_512' ] face_enhancer_models : List[str] = [ 'codeformer', 'gfpgan_1.2', 'gfpgan_1.3', 'gfpgan_1.4', 'gpen_bfr_512' ]
frame_enhancer_models : List[str] = [ 'realesrgan_x2plus', 'realesrgan_x4plus', 'realesrnet_x4plus' ] frame_enhancer_models : List[str] = [ 'realesrgan_x2plus', 'realesrgan_x4plus', 'realesrnet_x4plus' ]

View File

@ -97,7 +97,7 @@ def set_options(key : Literal[ 'model' ], value : Any) -> None:
def register_args(program : ArgumentParser) -> None: def register_args(program : ArgumentParser) -> None:
program.add_argument('--face-enhancer-model', help = wording.get('frame_processor_model_help'), dest = 'face_enhancer_model', default = 'gfpgan_1.4', choices = frame_processors_choices.face_enhancer_models) program.add_argument('--face-enhancer-model', help = wording.get('frame_processor_model_help'), dest = 'face_enhancer_model', default = 'gfpgan_1.4', choices = frame_processors_choices.face_enhancer_models)
program.add_argument('--face-enhancer-blend', help = wording.get('frame_processor_blend_help'), dest= 'face_enhancer_blend', type = int, default= 100, choices = range(101), metavar = '[0-100]') program.add_argument('--face-enhancer-blend', help = wording.get('frame_processor_blend_help'), dest= 'face_enhancer_blend', type = int, default= 80, choices = range(101), metavar = '[0-100]')
def apply_args(program : ArgumentParser) -> None: def apply_args(program : ArgumentParser) -> None:

View File

@ -13,7 +13,7 @@ from facefusion.face_analyser import get_one_face, get_many_faces, find_similar_
from facefusion.face_helper import warp_face, paste_back from facefusion.face_helper import warp_face, paste_back
from facefusion.face_reference import get_face_reference, set_face_reference from facefusion.face_reference import get_face_reference, set_face_reference
from facefusion.predictor import clear_predictor from facefusion.predictor import clear_predictor
from facefusion.typing import Face, Frame, Update_Process, ProcessMode, ModelValue, OptionsWithModel from facefusion.typing import Face, Frame, Update_Process, ProcessMode, ModelValue, OptionsWithModel, Embedding
from facefusion.utilities import conditional_download, resolve_relative_path, is_image, is_video, is_file, is_download_done, update_status from facefusion.utilities import conditional_download, resolve_relative_path, is_image, is_video, is_file, is_download_done, update_status
from facefusion.vision import read_image, read_static_image, write_image from facefusion.vision import read_image, read_static_image, write_image
from facefusion.processors.frame import globals as frame_processors_globals from facefusion.processors.frame import globals as frame_processors_globals
@ -38,6 +38,13 @@ MODELS : Dict[str, ModelValue] =\
'path': resolve_relative_path('../.assets/models/inswapper_128_fp16.onnx'), 'path': resolve_relative_path('../.assets/models/inswapper_128_fp16.onnx'),
'template': 'arcface', 'template': 'arcface',
'size': (128, 128) 'size': (128, 128)
},
'simswap_244':
{
'url': 'https://github.com/harisreedhar/Face-Swappers-ONNX/releases/download/simswap/simswap.onnx',
'path': resolve_relative_path('../.assets/models/simswap.onnx'),
'template': 'arcface',
'size': (112, 224)
} }
} }
OPTIONS : Optional[OptionsWithModel] = None OPTIONS : Optional[OptionsWithModel] = None
@ -146,13 +153,14 @@ def swap_face(source_face : Face, target_face : Face, temp_frame : Frame) -> Fra
frame_processor = get_frame_processor() frame_processor = get_frame_processor()
model_template = get_options('model').get('template') model_template = get_options('model').get('template')
model_size = get_options('model').get('size') model_size = get_options('model').get('size')
source_face = prepare_source_face(source_face)
crop_frame, affine_matrix = warp_face(temp_frame, target_face.kps, model_template, model_size) crop_frame, affine_matrix = warp_face(temp_frame, target_face.kps, model_template, model_size)
crop_frame = prepare_crop_frame(crop_frame) crop_frame = prepare_crop_frame(crop_frame)
frame_processor_inputs = {} frame_processor_inputs = {}
for frame_processor_input in frame_processor.get_inputs(): for frame_processor_input in frame_processor.get_inputs():
if frame_processor_input.name == 'source': if frame_processor_input.name == 'source':
frame_processor_inputs[frame_processor_input.name] = source_face frame_processor_inputs[frame_processor_input.name] = prepare_source_face(source_face)
if frame_processor_input.name == 'source_embedding':
frame_processor_inputs[frame_processor_input.name] = prepare_source_embedding(source_face) # type: ignore[assignment]
if frame_processor_input.name == 'target': if frame_processor_input.name == 'target':
frame_processor_inputs[frame_processor_input.name] = crop_frame # type: ignore[assignment] frame_processor_inputs[frame_processor_input.name] = crop_frame # type: ignore[assignment]
crop_frame = frame_processor.run(None, frame_processor_inputs)[0][0] crop_frame = frame_processor.run(None, frame_processor_inputs)[0][0]
@ -168,17 +176,22 @@ def prepare_source_face(source_face : Face) -> Face:
return source_face return source_face
def prepare_source_embedding(source_face : Face) -> Embedding:
source_embedding = source_face.normed_embedding.reshape(1, -1)
return source_embedding
def prepare_crop_frame(crop_frame : Frame) -> Frame: def prepare_crop_frame(crop_frame : Frame) -> Frame:
crop_frame = crop_frame / 255.0 crop_frame = crop_frame / 255.0
crop_frame = crop_frame[:, :, ::-1] crop_frame = crop_frame[:, :, ::-1].transpose(2, 0, 1)
crop_frame = numpy.expand_dims(crop_frame, axis = 0).transpose(0, 3, 1, 2).astype(numpy.float32) crop_frame = numpy.expand_dims(crop_frame, axis = 0).astype(numpy.float32)
return crop_frame return crop_frame
def normalize_crop_frame(crop_frame : Frame) -> Frame: def normalize_crop_frame(crop_frame : Frame) -> Frame:
crop_frame = crop_frame.transpose(1, 2, 0) crop_frame = crop_frame.transpose(1, 2, 0)
crop_frame = (crop_frame * 255.0).round() crop_frame = (crop_frame * 255.0).round()
crop_frame = crop_frame.astype(numpy.uint8)[:, :, ::-1] crop_frame = crop_frame[:, :, ::-1].astype(numpy.uint8)
return crop_frame return crop_frame

View File

@ -89,7 +89,7 @@ def set_options(key : Literal[ 'model' ], value : Any) -> None:
def register_args(program : ArgumentParser) -> None: def register_args(program : ArgumentParser) -> None:
program.add_argument('--frame-enhancer-model', help = wording.get('frame_processor_model_help'), dest = 'frame_enhancer_model', default = 'realesrgan_x2plus', choices = frame_processors_choices.frame_enhancer_models) program.add_argument('--frame-enhancer-model', help = wording.get('frame_processor_model_help'), dest = 'frame_enhancer_model', default = 'realesrgan_x2plus', choices = frame_processors_choices.frame_enhancer_models)
program.add_argument('--frame-enhancer-blend', help = wording.get('frame_processor_blend_help'), dest = 'frame_enhancer_blend', type = int, default = 100, choices = range(101), metavar = '[0-100]') program.add_argument('--frame-enhancer-blend', help = wording.get('frame_processor_blend_help'), dest = 'frame_enhancer_blend', type = int, default = 80, choices = range(101), metavar = '[0-100]')
def apply_args(program : ArgumentParser) -> None: def apply_args(program : ArgumentParser) -> None:

View File

@ -3,6 +3,9 @@ import gradio
import facefusion.globals import facefusion.globals
from facefusion import wording from facefusion import wording
from facefusion.face_analyser import clear_face_analyser
from facefusion.face_cache import clear_faces_cache
from facefusion.face_reference import clear_face_reference
from facefusion.processors.frame.core import load_frame_processor_module from facefusion.processors.frame.core import load_frame_processor_module
from facefusion.processors.frame import globals as frame_processors_globals, choices as frame_processors_choices from facefusion.processors.frame import globals as frame_processors_globals, choices as frame_processors_choices
from facefusion.uis.core import get_ui_component, register_ui_component from facefusion.uis.core import get_ui_component, register_ui_component
@ -78,6 +81,9 @@ def update_face_swapper_model(face_swapper_model : str) -> gradio.Dropdown:
face_swapper_module = load_frame_processor_module('face_swapper') face_swapper_module = load_frame_processor_module('face_swapper')
face_swapper_module.clear_frame_processor() face_swapper_module.clear_frame_processor()
face_swapper_module.set_options('model', face_swapper_module.MODELS[face_swapper_model]) face_swapper_module.set_options('model', face_swapper_module.MODELS[face_swapper_model])
clear_face_analyser()
clear_face_reference()
clear_faces_cache()
if not face_swapper_module.pre_check(): if not face_swapper_module.pre_check():
return gradio.Dropdown() return gradio.Dropdown()
return gradio.Dropdown(value = face_swapper_model) return gradio.Dropdown(value = face_swapper_model)

View File

@ -78,18 +78,6 @@ def listen() -> None:
if component: if component:
for method in [ 'upload', 'change', 'clear' ]: for method in [ 'upload', 'change', 'clear' ]:
getattr(component, method)(update_preview_frame_slider, outputs = PREVIEW_FRAME_SLIDER) getattr(component, method)(update_preview_frame_slider, outputs = PREVIEW_FRAME_SLIDER)
update_component_names : List[ComponentName] =\
[
'face_recognition_dropdown',
'frame_processors_checkbox_group',
'face_swapper_model_dropdown',
'face_enhancer_model_dropdown',
'frame_enhancer_model_dropdown'
]
for component_name in update_component_names:
component = get_ui_component(component_name)
if component:
component.change(update_preview_image, inputs = PREVIEW_FRAME_SLIDER, outputs = PREVIEW_IMAGE)
select_component_names : List[ComponentName] =\ select_component_names : List[ComponentName] =\
[ [
'reference_face_position_gallery', 'reference_face_position_gallery',
@ -103,8 +91,13 @@ def listen() -> None:
component.select(update_preview_image, inputs = PREVIEW_FRAME_SLIDER, outputs = PREVIEW_IMAGE) component.select(update_preview_image, inputs = PREVIEW_FRAME_SLIDER, outputs = PREVIEW_IMAGE)
change_component_names : List[ComponentName] =\ change_component_names : List[ComponentName] =\
[ [
'face_recognition_dropdown',
'reference_face_distance_slider', 'reference_face_distance_slider',
'frame_processors_checkbox_group',
'face_swapper_model_dropdown',
'face_enhancer_model_dropdown',
'face_enhancer_blend_slider', 'face_enhancer_blend_slider',
'frame_enhancer_model_dropdown',
'frame_enhancer_blend_slider' 'frame_enhancer_blend_slider'
] ]
for component_name in change_component_names: for component_name in change_component_names: