adaptive color correction

This commit is contained in:
harisreedhar 2024-11-04 17:26:50 +05:30 committed by henryruhs
parent 95a63ea7a2
commit 447ca53d54
2 changed files with 26 additions and 21 deletions

View File

@ -20,7 +20,7 @@ from facefusion.processors.typing import DeepSwapperInputs
from facefusion.program_helper import find_argument_group
from facefusion.thread_helper import thread_semaphore
from facefusion.typing import ApplyStateItem, Args, Face, InferencePool, Mask, ModelOptions, ModelSet, ProcessMode, QueuePayload, UpdateProgress, VisionFrame
from facefusion.vision import read_image, read_static_image, write_image
from facefusion.vision import adaptive_match_frame_color, read_image, read_static_image, write_image
MODEL_SET : ModelSet =\
{
@ -127,8 +127,11 @@ def swap_face(target_face : Face, temp_vision_frame : VisionFrame) -> VisionFram
crop_vision_frame = prepare_crop_frame(crop_vision_frame)
crop_vision_frame, crop_source_mask, crop_target_mask = forward(crop_vision_frame)
crop_vision_frame = normalize_crop_frame(crop_vision_frame)
crop_vision_frame = match_frame_color_with_mask(crop_vision_frame_raw, crop_vision_frame, crop_source_mask, crop_target_mask)
crop_masks.append(numpy.maximum.reduce([ feather_crop_mask(crop_source_mask), feather_crop_mask(crop_target_mask) ]).clip(0, 1))
crop_vision_frame = adaptive_match_frame_color(crop_vision_frame_raw, crop_vision_frame)
crop_source_mask = feather_crop_mask(crop_source_mask)
crop_target_mask = feather_crop_mask(crop_target_mask)
crop_combine_mask = numpy.maximum.reduce([ crop_source_mask, crop_target_mask ])
crop_masks.append(crop_combine_mask)
crop_mask = numpy.minimum.reduce(crop_masks).clip(0, 1)
paste_vision_frame = paste_back(temp_vision_frame, crop_vision_frame, crop_mask, affine_matrix)
return paste_vision_frame
@ -167,27 +170,11 @@ def normalize_crop_frame(crop_vision_frame : VisionFrame) -> VisionFrame:
def feather_crop_mask(crop_source_mask : Mask) -> Mask:
model_size = get_model_options().get('size')
crop_mask = crop_source_mask.reshape(model_size).clip(0, 1)
crop_mask = cv2.erode(crop_mask, numpy.ones((7, 7), numpy.uint8), iterations = 1)
crop_mask = cv2.GaussianBlur(crop_mask, (15, 15), 0)
crop_mask = cv2.erode(crop_mask, numpy.ones((5, 5), numpy.uint8), iterations = 1)
crop_mask = cv2.GaussianBlur(crop_mask, (7, 7), 0)
return crop_mask
def match_frame_color_with_mask(source_vision_frame : VisionFrame, target_vision_frame : VisionFrame, source_mask : Mask, target_mask : Mask) -> VisionFrame:
target_lab_frame = cv2.cvtColor(target_vision_frame, cv2.COLOR_BGR2LAB).astype(numpy.float32) / 255
source_lab_frame = cv2.cvtColor(source_vision_frame, cv2.COLOR_BGR2LAB).astype(numpy.float32) / 255
source_mask = (source_mask > 0.5).astype(numpy.float32)
target_mask = (target_mask > 0.5).astype(numpy.float32)
target_lab_filter = target_lab_frame * cv2.cvtColor(source_mask, cv2.COLOR_GRAY2BGR)
source_lab_filter = source_lab_frame * cv2.cvtColor(target_mask, cv2.COLOR_GRAY2BGR)
target_lab_frame -= target_lab_filter.mean(axis = ( 0, 1 ))
target_lab_frame /= target_lab_filter.std(axis = ( 0, 1 )) + 1e-6
target_lab_frame *= source_lab_filter.std(axis = ( 0, 1 ))
target_lab_frame += source_lab_filter.mean(axis = ( 0, 1 ))
target_lab_frame = numpy.multiply(target_lab_frame.clip(0, 1), 255).astype(numpy.uint8)
target_vision_frame = cv2.cvtColor(target_lab_frame, cv2.COLOR_LAB2BGR)
return target_vision_frame
def get_reference_frame(source_face : Face, target_face : Face, temp_vision_frame : VisionFrame) -> VisionFrame:
return swap_face(target_face, temp_vision_frame)

View File

@ -210,6 +210,12 @@ def normalize_frame_color(vision_frame : VisionFrame) -> VisionFrame:
return cv2.cvtColor(vision_frame, cv2.COLOR_BGR2RGB)
def adaptive_match_frame_color(source_vision_frame : VisionFrame, target_vision_frame : VisionFrame) -> VisionFrame:
histogram_factor = calc_histogram_difference(source_vision_frame, target_vision_frame)
target_vision_frame = blend_vision_frames(target_vision_frame, match_frame_color(source_vision_frame, target_vision_frame), histogram_factor)
return target_vision_frame
def match_frame_color(source_vision_frame : VisionFrame, target_vision_frame : VisionFrame) -> VisionFrame:
color_difference_sizes = numpy.linspace(16, target_vision_frame.shape[0], 3, endpoint = False)
@ -228,6 +234,18 @@ def equalize_frame_color(source_vision_frame : VisionFrame, target_vision_frame
return target_vision_frame
def calc_histogram_difference(source_vision_frame : VisionFrame, target_vision_frame : VisionFrame) -> float:
histogram_source = cv2.calcHist([cv2.cvtColor(source_vision_frame, cv2.COLOR_BGR2HSV)], [ 0, 1 ], None, [ 50, 60 ], [ 0, 180, 0, 256 ])
histogram_target = cv2.calcHist([cv2.cvtColor(target_vision_frame, cv2.COLOR_BGR2HSV)], [ 0, 1 ], None, [ 50, 60 ], [ 0, 180, 0, 256 ])
histogram_differnce = float(numpy.interp(cv2.compareHist(histogram_source, histogram_target, cv2.HISTCMP_CORREL), [ -1, 1 ], [ 0, 1 ]))
return histogram_differnce
def blend_vision_frames(source_vision_frame : VisionFrame, target_vision_frame : VisionFrame, factor : float) -> VisionFrame:
blend_vision_frame = cv2.addWeighted(target_vision_frame, 1 - factor, source_vision_frame, factor, 0)
return blend_vision_frame
def create_tile_frames(vision_frame : VisionFrame, size : Size) -> Tuple[List[VisionFrame], int, int]:
vision_frame = numpy.pad(vision_frame, ((size[1], size[1]), (size[1], size[1]), (0, 0)))
tile_width = size[0] - 2 * size[2]