Improved color matching (#800)

* aura fix

* fix import

* move to vision.py

* changes

* changes

* changes

* changes

* further reduction

* add test

* better test

* change name
This commit is contained in:
Harisreedhar 2024-10-29 18:39:48 +05:30 committed by henryruhs
parent efb7cf41ee
commit 04bbb89756
3 changed files with 34 additions and 32 deletions

View File

@@ -3,7 +3,6 @@ from typing import Any, List
import cv2
import numpy
from cv2.typing import Size
from numpy.typing import NDArray
import facefusion.jobs.job_manager
@@ -22,8 +21,8 @@ from facefusion.processors import choices as processors_choices
from facefusion.processors.typing import AgeModifierInputs
from facefusion.program_helper import find_argument_group
from facefusion.thread_helper import thread_semaphore
from facefusion.typing import ApplyStateItem, Args, Face, InferencePool, Mask, ModelOptions, ModelSet, ProcessMode, QueuePayload, UpdateProgress, VisionFrame
from facefusion.vision import read_image, read_static_image, write_image
from facefusion.typing import ApplyStateItem, Args, Face, InferencePool, ModelOptions, ModelSet, ProcessMode, QueuePayload, UpdateProgress, VisionFrame
from facefusion.vision import match_frame_color, read_image, read_static_image, write_image
MODEL_SET : ModelSet =\
{
@@ -147,7 +146,7 @@ def modify_age(target_face : Face, temp_vision_frame : VisionFrame) -> VisionFra
extend_vision_frame = prepare_vision_frame(extend_vision_frame)
extend_vision_frame = forward(crop_vision_frame, extend_vision_frame)
extend_vision_frame = normalize_extend_frame(extend_vision_frame)
extend_vision_frame = fix_color(extend_vision_frame_raw, extend_vision_frame)
extend_vision_frame = match_frame_color(extend_vision_frame_raw, extend_vision_frame)
extend_affine_matrix *= (model_sizes.get('target')[0] * 4) / model_sizes.get('target_with_background')[0]
crop_mask = numpy.minimum.reduce(crop_masks).clip(0, 1)
crop_mask = cv2.resize(crop_mask, (model_sizes.get('target')[0] * 4, model_sizes.get('target')[1] * 4))
@@ -173,33 +172,6 @@ def forward(crop_vision_frame : VisionFrame, extend_vision_frame : VisionFrame)
return crop_vision_frame
def fix_color(extend_vision_frame_raw : VisionFrame, extend_vision_frame : VisionFrame) -> VisionFrame:
	# Pull the processed frame's colors back toward the raw frame using a
	# coarse (48x48) color-difference map blended through a box mask.
	color_difference = compute_color_difference(extend_vision_frame_raw, extend_vision_frame, (48, 48))
	box_mask = create_static_box_mask(extend_vision_frame.shape[:2][::-1], 1.0, (0, 0, 0, 0))
	color_difference_mask = numpy.stack([ box_mask ] * 3, axis = -1)
	return normalize_color_difference(color_difference, color_difference_mask, extend_vision_frame)
def compute_color_difference(extend_vision_frame_raw : VisionFrame, extend_vision_frame : VisionFrame, size : Size) -> VisionFrame:
	# Return the per-pixel color delta (raw minus processed) measured at a
	# downsampled resolution, with both frames normalized into the 0..1 range.
	raw_small = cv2.resize(extend_vision_frame_raw.astype(numpy.float32) / 255, size, interpolation = cv2.INTER_AREA)
	processed_small = cv2.resize(extend_vision_frame.astype(numpy.float32) / 255, size, interpolation = cv2.INTER_AREA)
	return raw_small - processed_small
def normalize_color_difference(color_difference : VisionFrame, color_difference_mask : Mask, extend_vision_frame : VisionFrame) -> VisionFrame:
	# Upscale the coarse color delta to the frame resolution, apply it through
	# the inverted (capped) mask and convert back to a uint8 frame.
	frame_size = extend_vision_frame.shape[:2][::-1]
	color_difference = cv2.resize(color_difference, frame_size, interpolation = cv2.INTER_CUBIC)
	inverse_mask = 1 - color_difference_mask.clip(0, 0.75)
	normalized_frame = extend_vision_frame.astype(numpy.float32) / 255 + color_difference * inverse_mask
	return (normalized_frame.clip(0, 1) * 255).astype(numpy.uint8)
def prepare_direction(direction : int) -> NDArray[Any]:
	# Map the user-facing direction in [ -100, 100 ] onto the model's
	# (inverted) control range [ 2.5, -2.5 ] as a float32 array.
	direction_value = numpy.interp(float(direction), [ -100, 100 ], [ 2.5, -2.5 ])
	return numpy.asarray(direction_value, dtype = numpy.float32)

View File

@@ -247,3 +247,21 @@ def merge_tile_frames(tile_vision_frames : List[VisionFrame], temp_width : int,
merge_vision_frame[top:bottom, left:right, :] = tile_vision_frame
merge_vision_frame = merge_vision_frame[size[1] : size[1] + temp_height, size[1]: size[1] + temp_width, :]
return merge_vision_frame
def match_frame_color(source_vision_frame : VisionFrame, target_vision_frame : VisionFrame) -> VisionFrame:
	# Progressively transfer the source frame's colors onto the target frame,
	# equalizing at three coarse resolutions before a final full-size pass.
	for color_difference_size in numpy.linspace(16, target_vision_frame.shape[0], 3, endpoint = False):
		equalize_size = normalize_resolution((color_difference_size, color_difference_size))
		source_vision_frame = equalize_frame_color(source_vision_frame, target_vision_frame, equalize_size)
	full_size = target_vision_frame.shape[:2][::-1]
	return equalize_frame_color(source_vision_frame, target_vision_frame, full_size)
def equalize_frame_color(source_vision_frame : VisionFrame, target_vision_frame : VisionFrame, size : Size) -> VisionFrame:
	# Measure the source/target color difference at the given resolution,
	# upscale it and shift the target frame by that difference.
	source_small = cv2.resize(source_vision_frame, size, interpolation = cv2.INTER_AREA).astype(numpy.float32)
	target_small = cv2.resize(target_vision_frame, size, interpolation = cv2.INTER_AREA).astype(numpy.float32)
	color_difference_vision_frame = cv2.resize(source_small - target_small, target_vision_frame.shape[:2][::-1], interpolation = cv2.INTER_CUBIC)
	return (target_vision_frame + color_difference_vision_frame).clip(0, 255).astype(numpy.uint8)

View File

@@ -1,9 +1,10 @@
import subprocess
import cv2
import pytest
from facefusion.download import conditional_download
from facefusion.vision import count_video_frame_total, create_image_resolutions, create_video_resolutions,detect_image_resolution, detect_video_duration, detect_video_fps, detect_video_resolution, get_video_frame, normalize_resolution, pack_resolution, restrict_image_resolution, restrict_video_fps, restrict_video_resolution, unpack_resolution
from facefusion.vision import count_video_frame_total, create_image_resolutions, create_video_resolutions, detect_image_resolution, detect_video_duration, detect_video_fps, detect_video_resolution, get_video_frame, match_frame_color, normalize_resolution, pack_resolution, read_image, restrict_image_resolution, restrict_video_fps, restrict_video_resolution, unpack_resolution
from .helper import get_test_example_file, get_test_examples_directory
@@ -114,3 +115,14 @@ def test_pack_resolution() -> None:
def test_unpack_resolution() -> None:
	# Each packed 'WxH' string must round-trip to its (width, height) tuple.
	for resolution, expected in [ ('0x0', (0, 0)), ('2x2', (2, 2)) ]:
		assert unpack_resolution(resolution) == expected
def test_match_frame_color() -> None:
	# Recoloring a grayscale copy of an image should recover a hue/saturation
	# distribution that correlates with the original.
	def create_normalized_histogram(vision_frame):
		histogram = cv2.calcHist([ cv2.cvtColor(vision_frame, cv2.COLOR_BGR2HSV) ], [ 0, 1 ], None, [ 50, 60 ], [ 0, 180, 0, 256 ])
		cv2.normalize(histogram, histogram, 0, 1, cv2.NORM_MINMAX)
		return histogram

	source_vision_frame = read_image(get_test_example_file('target-1080p.jpg'))
	target_vision_frame = cv2.cvtColor(cv2.cvtColor(source_vision_frame, cv2.COLOR_BGR2GRAY), cv2.COLOR_GRAY2BGR)
	output_vision_frame = match_frame_color(source_vision_frame, target_vision_frame)
	assert cv2.compareHist(create_normalized_histogram(source_vision_frame), create_normalized_histogram(output_vision_frame), cv2.HISTCMP_CORREL) > 0.5