adaptive color correction
parent 95a63ea7a2
commit 447ca53d54
@@ -20,7 +20,7 @@ from facefusion.processors.typing import DeepSwapperInputs
 from facefusion.program_helper import find_argument_group
 from facefusion.thread_helper import thread_semaphore
 from facefusion.typing import ApplyStateItem, Args, Face, InferencePool, Mask, ModelOptions, ModelSet, ProcessMode, QueuePayload, UpdateProgress, VisionFrame
-from facefusion.vision import read_image, read_static_image, write_image
+from facefusion.vision import adaptive_match_frame_color, read_image, read_static_image, write_image
 
 MODEL_SET : ModelSet =\
 {
@@ -127,8 +127,11 @@ def swap_face(target_face : Face, temp_vision_frame : VisionFrame) -> VisionFrame
 	crop_vision_frame = prepare_crop_frame(crop_vision_frame)
 	crop_vision_frame, crop_source_mask, crop_target_mask = forward(crop_vision_frame)
 	crop_vision_frame = normalize_crop_frame(crop_vision_frame)
-	crop_vision_frame = match_frame_color_with_mask(crop_vision_frame_raw, crop_vision_frame, crop_source_mask, crop_target_mask)
-	crop_masks.append(numpy.maximum.reduce([ feather_crop_mask(crop_source_mask), feather_crop_mask(crop_target_mask) ]).clip(0, 1))
+	crop_vision_frame = adaptive_match_frame_color(crop_vision_frame_raw, crop_vision_frame)
+	crop_source_mask = feather_crop_mask(crop_source_mask)
+	crop_target_mask = feather_crop_mask(crop_target_mask)
+	crop_combine_mask = numpy.maximum.reduce([ crop_source_mask, crop_target_mask ])
+	crop_masks.append(crop_combine_mask)
 	crop_mask = numpy.minimum.reduce(crop_masks).clip(0, 1)
 	paste_vision_frame = paste_back(temp_vision_frame, crop_vision_frame, crop_mask, affine_matrix)
 	return paste_vision_frame
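The swap now feathers each model mask individually, merges the pair into a union via numpy.maximum.reduce, and keeps the pixel-wise minimum across everything in crop_masks as the final paste mask. A minimal sketch of the two reductions on toy arrays (sizes and values here are illustrative, not the production ones):

import numpy

# two toy 4x4 masks standing in for the model's source and target masks
crop_source_mask = numpy.zeros((4, 4), numpy.float32)
crop_source_mask[0:2, 0:2] = 1.0
crop_target_mask = numpy.zeros((4, 4), numpy.float32)
crop_target_mask[1:3, 1:3] = 1.0

# union of the two masks: a pixel is kept if either mask keeps it
crop_combine_mask = numpy.maximum.reduce([ crop_source_mask, crop_target_mask ])

# intersection across all collected masks: a pixel survives only if every mask keeps it
crop_masks = [ crop_combine_mask, numpy.full((4, 4), 0.9, numpy.float32) ]
crop_mask = numpy.minimum.reduce(crop_masks).clip(0, 1)
print(crop_mask)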
@@ -167,27 +170,11 @@ def normalize_crop_frame(crop_vision_frame : VisionFrame) -> VisionFrame:
 def feather_crop_mask(crop_source_mask : Mask) -> Mask:
 	model_size = get_model_options().get('size')
 	crop_mask = crop_source_mask.reshape(model_size).clip(0, 1)
-	crop_mask = cv2.erode(crop_mask, numpy.ones((7, 7), numpy.uint8), iterations = 1)
-	crop_mask = cv2.GaussianBlur(crop_mask, (15, 15), 0)
+	crop_mask = cv2.erode(crop_mask, numpy.ones((5, 5), numpy.uint8), iterations = 1)
+	crop_mask = cv2.GaussianBlur(crop_mask, (7, 7), 0)
 	return crop_mask
 
 
-def match_frame_color_with_mask(source_vision_frame : VisionFrame, target_vision_frame : VisionFrame, source_mask : Mask, target_mask : Mask) -> VisionFrame:
-	target_lab_frame = cv2.cvtColor(target_vision_frame, cv2.COLOR_BGR2LAB).astype(numpy.float32) / 255
-	source_lab_frame = cv2.cvtColor(source_vision_frame, cv2.COLOR_BGR2LAB).astype(numpy.float32) / 255
-	source_mask = (source_mask > 0.5).astype(numpy.float32)
-	target_mask = (target_mask > 0.5).astype(numpy.float32)
-	target_lab_filter = target_lab_frame * cv2.cvtColor(source_mask, cv2.COLOR_GRAY2BGR)
-	source_lab_filter = source_lab_frame * cv2.cvtColor(target_mask, cv2.COLOR_GRAY2BGR)
-	target_lab_frame -= target_lab_filter.mean(axis = ( 0, 1 ))
-	target_lab_frame /= target_lab_filter.std(axis = ( 0, 1 )) + 1e-6
-	target_lab_frame *= source_lab_filter.std(axis = ( 0, 1 ))
-	target_lab_frame += source_lab_filter.mean(axis = ( 0, 1 ))
-	target_lab_frame = numpy.multiply(target_lab_frame.clip(0, 1), 255).astype(numpy.uint8)
-	target_vision_frame = cv2.cvtColor(target_lab_frame, cv2.COLOR_LAB2BGR)
-	return target_vision_frame
-
-
 def get_reference_frame(source_face : Face, target_face : Face, temp_vision_frame : VisionFrame) -> VisionFrame:
 	return swap_face(target_face, temp_vision_frame)
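The feathering itself is an erode followed by a Gaussian blur, now with smaller kernels so less of the face border is eaten away. A self-contained sketch of the effect on a toy binary mask (the kernel sizes mirror the new values; everything else is illustrative):

import cv2
import numpy

crop_mask = numpy.zeros((64, 64), numpy.float32)
crop_mask[16:48, 16:48] = 1.0

# pull the hard edge inwards, then soften what remains into a smooth ramp
crop_mask = cv2.erode(crop_mask, numpy.ones((5, 5), numpy.uint8), iterations = 1)
crop_mask = cv2.GaussianBlur(crop_mask, (7, 7), 0)

print(crop_mask.min(), crop_mask.max()) # values remain inside [0, 1]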
@@ -210,6 +210,12 @@ def normalize_frame_color(vision_frame : VisionFrame) -> VisionFrame:
 	return cv2.cvtColor(vision_frame, cv2.COLOR_BGR2RGB)
 
 
+def adaptive_match_frame_color(source_vision_frame : VisionFrame, target_vision_frame : VisionFrame) -> VisionFrame:
+	histogram_factor = calc_histogram_difference(source_vision_frame, target_vision_frame)
+	target_vision_frame = blend_vision_frames(target_vision_frame, match_frame_color(source_vision_frame, target_vision_frame), histogram_factor)
+	return target_vision_frame
+
+
 def match_frame_color(source_vision_frame : VisionFrame, target_vision_frame : VisionFrame) -> VisionFrame:
 	color_difference_sizes = numpy.linspace(16, target_vision_frame.shape[0], 3, endpoint = False)
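adaptive_match_frame_color scales the colour correction by how similar the two frames already are: calc_histogram_difference yields a factor in [0, 1], and blend_vision_frames mixes the raw and colour-matched frames accordingly. A rough sketch of that control flow, where match_colors is a toy stand-in for match_frame_color and the fixed factor stands in for the histogram score:

import cv2
import numpy

# toy stand-in for match_frame_color: shift the target towards the source's per-channel mean
def match_colors(source_vision_frame, target_vision_frame):
	difference = source_vision_frame.mean(axis = (0, 1)) - target_vision_frame.mean(axis = (0, 1))
	return (target_vision_frame + difference).clip(0, 255).astype(numpy.uint8)

source_vision_frame = numpy.full((8, 8, 3), 200, numpy.uint8)
target_vision_frame = numpy.full((8, 8, 3), 100, numpy.uint8)

factor = 0.75 # would come from calc_histogram_difference in the real code
blend_vision_frame = cv2.addWeighted(target_vision_frame, 1 - factor, match_colors(source_vision_frame, target_vision_frame), factor, 0)
print(blend_vision_frame[0, 0]) # 0.25 * 100 + 0.75 * 200 = [175 175 175]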
@@ -228,6 +234,18 @@ def equalize_frame_color(source_vision_frame : VisionFrame, target_vision_frame
 	return target_vision_frame
 
 
+def calc_histogram_difference(source_vision_frame : VisionFrame, target_vision_frame : VisionFrame) -> float:
+	histogram_source = cv2.calcHist([cv2.cvtColor(source_vision_frame, cv2.COLOR_BGR2HSV)], [ 0, 1 ], None, [ 50, 60 ], [ 0, 180, 0, 256 ])
+	histogram_target = cv2.calcHist([cv2.cvtColor(target_vision_frame, cv2.COLOR_BGR2HSV)], [ 0, 1 ], None, [ 50, 60 ], [ 0, 180, 0, 256 ])
+	histogram_difference = float(numpy.interp(cv2.compareHist(histogram_source, histogram_target, cv2.HISTCMP_CORREL), [ -1, 1 ], [ 0, 1 ]))
+	return histogram_difference
+
+
+def blend_vision_frames(source_vision_frame : VisionFrame, target_vision_frame : VisionFrame, factor : float) -> VisionFrame:
+	blend_vision_frame = cv2.addWeighted(target_vision_frame, 1 - factor, source_vision_frame, factor, 0)
+	return blend_vision_frame
+
+
 def create_tile_frames(vision_frame : VisionFrame, size : Size) -> Tuple[List[VisionFrame], int, int]:
 	vision_frame = numpy.pad(vision_frame, ((size[1], size[1]), (size[1], size[1]), (0, 0)))
 	tile_width = size[0] - 2 * size[2]
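The factor comes from cv2.HISTCMP_CORREL, whose output lies in [-1, 1]; numpy.interp remaps it linearly onto [0, 1] so it can feed cv2.addWeighted directly. A quick check on two random frames (the RNG seeds and frame size are arbitrary):

import cv2
import numpy

frame_a = numpy.random.default_rng(0).integers(0, 256, (32, 32, 3), numpy.uint8)
frame_b = numpy.random.default_rng(1).integers(0, 256, (32, 32, 3), numpy.uint8)

histogram_a = cv2.calcHist([cv2.cvtColor(frame_a, cv2.COLOR_BGR2HSV)], [ 0, 1 ], None, [ 50, 60 ], [ 0, 180, 0, 256 ])
histogram_b = cv2.calcHist([cv2.cvtColor(frame_b, cv2.COLOR_BGR2HSV)], [ 0, 1 ], None, [ 50, 60 ], [ 0, 180, 0, 256 ])

correlation = cv2.compareHist(histogram_a, histogram_b, cv2.HISTCMP_CORREL) # in [-1, 1]
factor = float(numpy.interp(correlation, [ -1, 1 ], [ 0, 1 ])) # in [0, 1]
print(correlation, factor) # identical frames would give correlation 1.0 and factor 1.0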