adaptive color correction

This commit is contained in:
harisreedhar 2024-11-04 17:26:50 +05:30 committed by henryruhs
parent 95a63ea7a2
commit 447ca53d54
2 changed files with 26 additions and 21 deletions

View File

@ -20,7 +20,7 @@ from facefusion.processors.typing import DeepSwapperInputs
from facefusion.program_helper import find_argument_group from facefusion.program_helper import find_argument_group
from facefusion.thread_helper import thread_semaphore from facefusion.thread_helper import thread_semaphore
from facefusion.typing import ApplyStateItem, Args, Face, InferencePool, Mask, ModelOptions, ModelSet, ProcessMode, QueuePayload, UpdateProgress, VisionFrame from facefusion.typing import ApplyStateItem, Args, Face, InferencePool, Mask, ModelOptions, ModelSet, ProcessMode, QueuePayload, UpdateProgress, VisionFrame
from facefusion.vision import read_image, read_static_image, write_image from facefusion.vision import adaptive_match_frame_color, read_image, read_static_image, write_image
MODEL_SET : ModelSet =\ MODEL_SET : ModelSet =\
{ {
@ -127,8 +127,11 @@ def swap_face(target_face : Face, temp_vision_frame : VisionFrame) -> VisionFram
crop_vision_frame = prepare_crop_frame(crop_vision_frame) crop_vision_frame = prepare_crop_frame(crop_vision_frame)
crop_vision_frame, crop_source_mask, crop_target_mask = forward(crop_vision_frame) crop_vision_frame, crop_source_mask, crop_target_mask = forward(crop_vision_frame)
crop_vision_frame = normalize_crop_frame(crop_vision_frame) crop_vision_frame = normalize_crop_frame(crop_vision_frame)
crop_vision_frame = match_frame_color_with_mask(crop_vision_frame_raw, crop_vision_frame, crop_source_mask, crop_target_mask) crop_vision_frame = adaptive_match_frame_color(crop_vision_frame_raw, crop_vision_frame)
crop_masks.append(numpy.maximum.reduce([ feather_crop_mask(crop_source_mask), feather_crop_mask(crop_target_mask) ]).clip(0, 1)) crop_source_mask = feather_crop_mask(crop_source_mask)
crop_target_mask = feather_crop_mask(crop_target_mask)
crop_combine_mask = numpy.maximum.reduce([ crop_source_mask, crop_target_mask ])
crop_masks.append(crop_combine_mask)
crop_mask = numpy.minimum.reduce(crop_masks).clip(0, 1) crop_mask = numpy.minimum.reduce(crop_masks).clip(0, 1)
paste_vision_frame = paste_back(temp_vision_frame, crop_vision_frame, crop_mask, affine_matrix) paste_vision_frame = paste_back(temp_vision_frame, crop_vision_frame, crop_mask, affine_matrix)
return paste_vision_frame return paste_vision_frame
@ -167,27 +170,11 @@ def normalize_crop_frame(crop_vision_frame : VisionFrame) -> VisionFrame:
def feather_crop_mask(crop_source_mask : Mask) -> Mask: def feather_crop_mask(crop_source_mask : Mask) -> Mask:
model_size = get_model_options().get('size') model_size = get_model_options().get('size')
crop_mask = crop_source_mask.reshape(model_size).clip(0, 1) crop_mask = crop_source_mask.reshape(model_size).clip(0, 1)
crop_mask = cv2.erode(crop_mask, numpy.ones((7, 7), numpy.uint8), iterations = 1) crop_mask = cv2.erode(crop_mask, numpy.ones((5, 5), numpy.uint8), iterations = 1)
crop_mask = cv2.GaussianBlur(crop_mask, (15, 15), 0) crop_mask = cv2.GaussianBlur(crop_mask, (7, 7), 0)
return crop_mask return crop_mask
def match_frame_color_with_mask(source_vision_frame : VisionFrame, target_vision_frame : VisionFrame, source_mask : Mask, target_mask : Mask) -> VisionFrame:
target_lab_frame = cv2.cvtColor(target_vision_frame, cv2.COLOR_BGR2LAB).astype(numpy.float32) / 255
source_lab_frame = cv2.cvtColor(source_vision_frame, cv2.COLOR_BGR2LAB).astype(numpy.float32) / 255
source_mask = (source_mask > 0.5).astype(numpy.float32)
target_mask = (target_mask > 0.5).astype(numpy.float32)
target_lab_filter = target_lab_frame * cv2.cvtColor(source_mask, cv2.COLOR_GRAY2BGR)
source_lab_filter = source_lab_frame * cv2.cvtColor(target_mask, cv2.COLOR_GRAY2BGR)
target_lab_frame -= target_lab_filter.mean(axis = ( 0, 1 ))
target_lab_frame /= target_lab_filter.std(axis = ( 0, 1 )) + 1e-6
target_lab_frame *= source_lab_filter.std(axis = ( 0, 1 ))
target_lab_frame += source_lab_filter.mean(axis = ( 0, 1 ))
target_lab_frame = numpy.multiply(target_lab_frame.clip(0, 1), 255).astype(numpy.uint8)
target_vision_frame = cv2.cvtColor(target_lab_frame, cv2.COLOR_LAB2BGR)
return target_vision_frame
def get_reference_frame(source_face : Face, target_face : Face, temp_vision_frame : VisionFrame) -> VisionFrame: def get_reference_frame(source_face : Face, target_face : Face, temp_vision_frame : VisionFrame) -> VisionFrame:
return swap_face(target_face, temp_vision_frame) return swap_face(target_face, temp_vision_frame)

View File

@ -210,6 +210,12 @@ def normalize_frame_color(vision_frame : VisionFrame) -> VisionFrame:
return cv2.cvtColor(vision_frame, cv2.COLOR_BGR2RGB) return cv2.cvtColor(vision_frame, cv2.COLOR_BGR2RGB)
def adaptive_match_frame_color(source_vision_frame : VisionFrame, target_vision_frame : VisionFrame) -> VisionFrame:
histogram_factor = calc_histogram_difference(source_vision_frame, target_vision_frame)
target_vision_frame = blend_vision_frames(target_vision_frame, match_frame_color(source_vision_frame, target_vision_frame), histogram_factor)
return target_vision_frame
def match_frame_color(source_vision_frame : VisionFrame, target_vision_frame : VisionFrame) -> VisionFrame: def match_frame_color(source_vision_frame : VisionFrame, target_vision_frame : VisionFrame) -> VisionFrame:
color_difference_sizes = numpy.linspace(16, target_vision_frame.shape[0], 3, endpoint = False) color_difference_sizes = numpy.linspace(16, target_vision_frame.shape[0], 3, endpoint = False)
@ -228,6 +234,18 @@ def equalize_frame_color(source_vision_frame : VisionFrame, target_vision_frame
return target_vision_frame return target_vision_frame
def calc_histogram_difference(source_vision_frame : VisionFrame, target_vision_frame : VisionFrame) -> float:
histogram_source = cv2.calcHist([cv2.cvtColor(source_vision_frame, cv2.COLOR_BGR2HSV)], [ 0, 1 ], None, [ 50, 60 ], [ 0, 180, 0, 256 ])
histogram_target = cv2.calcHist([cv2.cvtColor(target_vision_frame, cv2.COLOR_BGR2HSV)], [ 0, 1 ], None, [ 50, 60 ], [ 0, 180, 0, 256 ])
histogram_differnce = float(numpy.interp(cv2.compareHist(histogram_source, histogram_target, cv2.HISTCMP_CORREL), [ -1, 1 ], [ 0, 1 ]))
return histogram_differnce
def blend_vision_frames(source_vision_frame : VisionFrame, target_vision_frame : VisionFrame, factor : float) -> VisionFrame:
blend_vision_frame = cv2.addWeighted(target_vision_frame, 1 - factor, source_vision_frame, factor, 0)
return blend_vision_frame
def create_tile_frames(vision_frame : VisionFrame, size : Size) -> Tuple[List[VisionFrame], int, int]: def create_tile_frames(vision_frame : VisionFrame, size : Size) -> Tuple[List[VisionFrame], int, int]:
vision_frame = numpy.pad(vision_frame, ((size[1], size[1]), (size[1], size[1]), (0, 0))) vision_frame = numpy.pad(vision_frame, ((size[1], size[1]), (size[1], size[1]), (0, 0)))
tile_width = size[0] - 2 * size[2] tile_width = size[0] - 2 * size[2]