diff --git a/facefusion/vision.py b/facefusion/vision.py index 2b7c9f62..26021e32 100644 --- a/facefusion/vision.py +++ b/facefusion/vision.py @@ -210,6 +210,24 @@ def normalize_frame_color(vision_frame : VisionFrame) -> VisionFrame: return cv2.cvtColor(vision_frame, cv2.COLOR_BGR2RGB) +def match_frame_color(source_vision_frame : VisionFrame, target_vision_frame : VisionFrame) -> VisionFrame: + color_difference_sizes = numpy.linspace(16, target_vision_frame.shape[0], 3, endpoint = False) + + for color_difference_size in color_difference_sizes: + source_vision_frame = equalize_frame_color(source_vision_frame, target_vision_frame, normalize_resolution((color_difference_size, color_difference_size))) + target_vision_frame = equalize_frame_color(source_vision_frame, target_vision_frame, target_vision_frame.shape[:2][::-1]) + return target_vision_frame + + +def equalize_frame_color(source_vision_frame : VisionFrame, target_vision_frame : VisionFrame, size : Size) -> VisionFrame: + source_frame_resize = cv2.resize(source_vision_frame, size, interpolation = cv2.INTER_AREA).astype(numpy.float32) + target_frame_resize = cv2.resize(target_vision_frame, size, interpolation = cv2.INTER_AREA).astype(numpy.float32) + color_difference_vision_frame = numpy.subtract(source_frame_resize, target_frame_resize) + color_difference_vision_frame = cv2.resize(color_difference_vision_frame, target_vision_frame.shape[:2][::-1], interpolation = cv2.INTER_CUBIC) + target_vision_frame = numpy.add(target_vision_frame, color_difference_vision_frame).clip(0, 255).astype(numpy.uint8) + return target_vision_frame + + def create_tile_frames(vision_frame : VisionFrame, size : Size) -> Tuple[List[VisionFrame], int, int]: vision_frame = numpy.pad(vision_frame, ((size[1], size[1]), (size[1], size[1]), (0, 0))) tile_width = size[0] - 2 * size[2] @@ -247,21 +265,3 @@ def merge_tile_frames(tile_vision_frames : List[VisionFrame], temp_width : int, merge_vision_frame[top:bottom, left:right, :] = tile_vision_frame merge_vision_frame = merge_vision_frame[size[1] : size[1] + temp_height, size[1]: size[1] + temp_width, :] return merge_vision_frame - - -def match_frame_color(source_vision_frame : VisionFrame, target_vision_frame : VisionFrame) -> VisionFrame: - color_difference_sizes = numpy.linspace(16, target_vision_frame.shape[0], 3, endpoint = False) - - for color_difference_size in color_difference_sizes: - source_vision_frame = equalize_frame_color(source_vision_frame, target_vision_frame, normalize_resolution((color_difference_size, color_difference_size))) - target_vision_frame = equalize_frame_color(source_vision_frame, target_vision_frame, target_vision_frame.shape[:2][::-1]) - return target_vision_frame - - -def equalize_frame_color(source_vision_frame : VisionFrame, target_vision_frame : VisionFrame, size : Size) -> VisionFrame: - source_frame_resize = cv2.resize(source_vision_frame, size, interpolation = cv2.INTER_AREA).astype(numpy.float32) - target_frame_resize = cv2.resize(target_vision_frame, size, interpolation = cv2.INTER_AREA).astype(numpy.float32) - color_difference_vision_frame = numpy.subtract(source_frame_resize, target_frame_resize) - color_difference_vision_frame = cv2.resize(color_difference_vision_frame, target_vision_frame.shape[:2][::-1], interpolation = cv2.INTER_CUBIC) - target_vision_frame = numpy.add(target_vision_frame, color_difference_vision_frame).clip(0, 255).astype(numpy.uint8) - return target_vision_frame diff --git a/tests/test_vision.py b/tests/test_vision.py index 418ce7af..9f5e770c 100644 --- a/tests/test_vision.py +++ b/tests/test_vision.py @@ -123,6 +123,5 @@ def test_match_frame_color() -> None: output_vision_frame = match_frame_color(source_vision_frame, target_vision_frame) histogram_source = cv2.calcHist([ cv2.cvtColor(source_vision_frame, cv2.COLOR_BGR2HSV) ], [ 0, 1 ], None, [ 50, 60 ], [ 0, 180, 0, 256 ]) histogram_output = cv2.calcHist([ cv2.cvtColor(output_vision_frame, cv2.COLOR_BGR2HSV) ], [ 0, 1 ], None, [ 50, 60 ], [ 0, 180, 0, 256 ]) - cv2.normalize(histogram_source, histogram_source, 0, 1, cv2.NORM_MINMAX) - cv2.normalize(histogram_output, histogram_output, 0, 1, cv2.NORM_MINMAX) + assert cv2.compareHist(histogram_source, histogram_output, cv2.HISTCMP_CORREL) > 0.5