API experiment part2 + ugly frontend
This commit is contained in:
parent
382a036f66
commit
098f64916f
@ -1,4 +1,5 @@
|
|||||||
from time import time
|
from typing import Optional
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
import numpy
|
import numpy
|
||||||
|
|
||||||
@ -15,7 +16,7 @@ from facefusion.typing import AudioFrame, Face, FaceSet, VisionFrame
|
|||||||
from facefusion.vision import get_video_frame, read_static_image, read_static_images, resize_frame_resolution
|
from facefusion.vision import get_video_frame, read_static_image, read_static_images, resize_frame_resolution
|
||||||
|
|
||||||
|
|
||||||
def process_frame(frame_number : int = 0) -> None:
|
def process_frame(frame_number : int = 0) -> Optional[VisionFrame]:
|
||||||
core.conditional_append_reference_faces()
|
core.conditional_append_reference_faces()
|
||||||
reference_faces = get_reference_faces() if 'reference' in state_manager.get_item('face_selector_mode') else None
|
reference_faces = get_reference_faces() if 'reference' in state_manager.get_item('face_selector_mode') else None
|
||||||
source_frames = read_static_images(state_manager.get_item('source_paths'))
|
source_frames = read_static_images(state_manager.get_item('source_paths'))
|
||||||
@ -48,6 +49,8 @@ def process_frame(frame_number : int = 0) -> None:
|
|||||||
preview_vision_frame = process_preview_frame(reference_faces, source_face, source_audio_frame, temp_vision_frame)
|
preview_vision_frame = process_preview_frame(reference_faces, source_face, source_audio_frame, temp_vision_frame)
|
||||||
return preview_vision_frame
|
return preview_vision_frame
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def process_preview_frame(reference_faces : FaceSet, source_face : Face, source_audio_frame : AudioFrame, target_vision_frame : VisionFrame) -> VisionFrame:
|
def process_preview_frame(reference_faces : FaceSet, source_face : Face, source_audio_frame : AudioFrame, target_vision_frame : VisionFrame) -> VisionFrame:
|
||||||
target_vision_frame = resize_frame_resolution(target_vision_frame, (1024, 1024))
|
target_vision_frame = resize_frame_resolution(target_vision_frame, (1024, 1024))
|
||||||
|
@ -1,13 +1,13 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import time
|
import json
|
||||||
from io import BytesIO
|
|
||||||
from typing import Any, List
|
from typing import Any, List
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
import uvicorn
|
import uvicorn
|
||||||
from litestar import Litestar, WebSocket, get as read, websocket as stream
|
from litestar import Litestar, WebSocket, get as read, websocket as stream, websocket_listener
|
||||||
|
from litestar.static_files import create_static_files_router
|
||||||
|
|
||||||
from facefusion import choices, execution, _preview
|
from facefusion import _preview, choices, execution, state_manager, vision
|
||||||
from facefusion.processors import choices as processors_choices
|
from facefusion.processors import choices as processors_choices
|
||||||
from facefusion.state_manager import get_state
|
from facefusion.state_manager import get_state
|
||||||
from facefusion.typing import ExecutionDevice
|
from facefusion.typing import ExecutionDevice
|
||||||
@ -47,7 +47,7 @@ async def read_execution_providers() -> Any:
|
|||||||
|
|
||||||
|
|
||||||
@stream('/execution/devices')
|
@stream('/execution/devices')
|
||||||
async def stream_execution_devices(socket : WebSocket) -> None:
|
async def stream_execution_devices(socket : WebSocket[Any, Any, Any]) -> None:
|
||||||
await socket.accept()
|
await socket.accept()
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
@ -66,7 +66,7 @@ async def read_static_execution_devices() -> List[ExecutionDevice]:
|
|||||||
|
|
||||||
|
|
||||||
@stream('/state')
|
@stream('/state')
|
||||||
async def stream_state(socket : WebSocket) -> None:
|
async def stream_state(socket : WebSocket[Any, Any, Any]) -> None:
|
||||||
await socket.accept()
|
await socket.accept()
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
@ -74,12 +74,30 @@ async def stream_state(socket : WebSocket) -> None:
|
|||||||
await asyncio.sleep(0.5)
|
await asyncio.sleep(0.5)
|
||||||
|
|
||||||
|
|
||||||
@read('/preview', media_type = 'image/png')
|
@read('/preview', media_type = 'image/png', mode = "binary")
|
||||||
async def read_preview() -> None:
|
async def read_preview(frame_number : int) -> bytes:
|
||||||
_, preview_vision_frame = cv2.imencode('.png', _preview.process_frame())
|
_, preview_vision_frame = cv2.imencode('.png', _preview.process_frame(frame_number)) #type:ignore
|
||||||
return preview_vision_frame.tobytes()
|
return preview_vision_frame.tobytes()
|
||||||
|
|
||||||
|
|
||||||
|
@websocket_listener("/preview", send_mode = "binary")
|
||||||
|
async def stream_preview(data : str) -> bytes:
|
||||||
|
frame_number = int(json.loads(data).get('frame_number'))
|
||||||
|
_, preview_vision_frame = cv2.imencode('.png', _preview.process_frame(frame_number)) #type:ignore
|
||||||
|
return preview_vision_frame.tobytes()
|
||||||
|
|
||||||
|
|
||||||
|
@read('/ui/preview_slider')
|
||||||
|
async def read_ui_preview_slider() -> Any:
|
||||||
|
target_path = state_manager.get_item('target_path')
|
||||||
|
video_frame_total = vision.count_video_frame_total(target_path)
|
||||||
|
|
||||||
|
return\
|
||||||
|
{
|
||||||
|
'video_frame_total': video_frame_total
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
api = Litestar(
|
api = Litestar(
|
||||||
[
|
[
|
||||||
read_choices,
|
read_choices,
|
||||||
@ -88,7 +106,14 @@ api = Litestar(
|
|||||||
read_execution_devices,
|
read_execution_devices,
|
||||||
read_static_execution_devices,
|
read_static_execution_devices,
|
||||||
stream_state,
|
stream_state,
|
||||||
read_preview
|
read_preview,
|
||||||
|
read_ui_preview_slider,
|
||||||
|
stream_preview,
|
||||||
|
create_static_files_router(
|
||||||
|
path = '/frontend',
|
||||||
|
directories = [ 'facefusion/static' ],
|
||||||
|
html_mode = True,
|
||||||
|
)
|
||||||
])
|
])
|
||||||
|
|
||||||
|
|
||||||
|
@ -2,6 +2,7 @@ import itertools
|
|||||||
import shutil
|
import shutil
|
||||||
import signal
|
import signal
|
||||||
import sys
|
import sys
|
||||||
|
import webbrowser
|
||||||
from time import time
|
from time import time
|
||||||
|
|
||||||
import numpy
|
import numpy
|
||||||
@ -61,6 +62,9 @@ def route(args : Args) -> None:
|
|||||||
if not pre_check():
|
if not pre_check():
|
||||||
return conditional_exit(2)
|
return conditional_exit(2)
|
||||||
if state_manager.get_item('command') == 'run':
|
if state_manager.get_item('command') == 'run':
|
||||||
|
if state_manager.get_item('open_browser'):
|
||||||
|
webbrowser.open('http://127.0.0.1:8000/frontend')
|
||||||
|
logger.info('http://127.0.0.1:8000/frontend', __name__)
|
||||||
api.run()
|
api.run()
|
||||||
if state_manager.get_item('command') == 'headless-run':
|
if state_manager.get_item('command') == 'headless-run':
|
||||||
if not job_manager.init_jobs(state_manager.get_item('jobs_path')):
|
if not job_manager.init_jobs(state_manager.get_item('jobs_path')):
|
||||||
|
242
facefusion/static/index.html
Normal file
242
facefusion/static/index.html
Normal file
@ -0,0 +1,242 @@
|
|||||||
|
<!DOCTYPE html>
|
||||||
|
<html lang="en">
|
||||||
|
<head>
|
||||||
|
<meta charset="UTF-8">
|
||||||
|
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||||
|
<title>Demo</title>
|
||||||
|
<style>
|
||||||
|
.preview, meter {
|
||||||
|
width: 100%;
|
||||||
|
}
|
||||||
|
meter { border-radius: 0}
|
||||||
|
img, textarea {
|
||||||
|
width: 100%;
|
||||||
|
}
|
||||||
|
input[type="range"] {
|
||||||
|
width: 100%;
|
||||||
|
margin-top: 1em;
|
||||||
|
}
|
||||||
|
</style>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<div style="display:flex">
|
||||||
|
<div class="preview">
|
||||||
|
<h1>Preview</h1>
|
||||||
|
<div class="image-container">
|
||||||
|
<img id="image" alt="Frame Image">
|
||||||
|
</div>
|
||||||
|
<input type="range" id="slider" value="0">
|
||||||
|
<p>Frame: <span id="frameValue">0</span></p>
|
||||||
|
<button id="playBtn">Play</button>
|
||||||
|
<button id="stopBtn" disabled>Stop</button>
|
||||||
|
<input type="checkbox" id="useWebSocket" /> Use WebSocket for Preview
|
||||||
|
</div>
|
||||||
|
<div>
|
||||||
|
<h1>Debug</h1>
|
||||||
|
<p>Video Memory <meter id="video_memory" min="0" max="100" value="0"></meter></p>
|
||||||
|
<p>GPU Utilization <meter id="gpu_utilization" min="0" max="100" value="0"></meter></p>
|
||||||
|
<textarea id="debug" rows="1" cols="80" readonly></textarea>
|
||||||
|
<textarea id="devices" rows="4" cols="80" readonly></textarea>
|
||||||
|
<textarea id="state" rows="30" cols="80" readonly></textarea>
|
||||||
|
<textarea id="log" rows="10" cols="80" readonly></textarea>
|
||||||
|
<textarea id="fps" rows="2" cols="80" readonly></textarea>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<script>
|
||||||
|
function createWebSocketConnection(url, debug, container) {
|
||||||
|
const socket = new WebSocket(url);
|
||||||
|
|
||||||
|
socket.onopen = () => {
|
||||||
|
debug.value = `WebSocket connection established for URL: ${url}`;
|
||||||
|
};
|
||||||
|
|
||||||
|
socket.onmessage = event => {
|
||||||
|
debug.value = `WebSocket Event: ${event.type}`;
|
||||||
|
container.value = event.data;
|
||||||
|
};
|
||||||
|
|
||||||
|
socket.onerror = error => {
|
||||||
|
debug.value = `WebSocket Error: ${error}`;
|
||||||
|
};
|
||||||
|
|
||||||
|
socket.onclose = () => {
|
||||||
|
debug.value = `WebSocket connection closed for URL: ${url} -> Reloading Page`;
|
||||||
|
setTimeout(() => location.reload(), 1000);
|
||||||
|
};
|
||||||
|
|
||||||
|
return socket;
|
||||||
|
}
|
||||||
|
|
||||||
|
devicesSocket = createWebSocketConnection('ws://127.0.0.1:8000/execution/devices', debug, devices);
|
||||||
|
createWebSocketConnection('ws://127.0.0.1:8000/state', debug, state);
|
||||||
|
|
||||||
|
devicesSocket.addEventListener('message', event => {
|
||||||
|
const data = JSON.parse(event.data)[0]
|
||||||
|
const freeMemory = data.video_memory.free.value;
|
||||||
|
const totalMemory = data.video_memory.total.value;
|
||||||
|
const usedMemory = totalMemory - freeMemory;
|
||||||
|
const usedMemoryPercentage = (usedMemory / totalMemory) * 100;
|
||||||
|
|
||||||
|
video_memory.value = usedMemoryPercentage;
|
||||||
|
gpu_utilization.value = data.utilization.gpu.value
|
||||||
|
})
|
||||||
|
</script>
|
||||||
|
<script>
|
||||||
|
const slider = document.getElementById('slider');
|
||||||
|
const image = document.getElementById('image');
|
||||||
|
const frameValue = document.getElementById('frameValue');
|
||||||
|
const playBtn = document.getElementById('playBtn');
|
||||||
|
const stopBtn = document.getElementById('stopBtn');
|
||||||
|
const logTextarea = document.getElementById('log');
|
||||||
|
const useWebSocketCheckbox = document.getElementById('useWebSocket');
|
||||||
|
|
||||||
|
let totalFrames = 0;
|
||||||
|
let currentFrame = 0;
|
||||||
|
let totalFps = 0;
|
||||||
|
let requestCount = 0;
|
||||||
|
let isPlaying = false;
|
||||||
|
let socket;
|
||||||
|
|
||||||
|
// Fetch the total frame count and set up the slider
|
||||||
|
async function fetchSliderTotal() {
|
||||||
|
try {
|
||||||
|
const start = performance.now();
|
||||||
|
const response = await fetch('http://127.0.0.1:8000/ui/preview_slider');
|
||||||
|
const end = performance.now();
|
||||||
|
|
||||||
|
logRequest('GET', 'http://127.0.0.1:8000/ui/preview_slider', start, end);
|
||||||
|
|
||||||
|
if (response.ok) {
|
||||||
|
const data = await response.json();
|
||||||
|
totalFrames = data.video_frame_total;
|
||||||
|
slider.max = totalFrames;
|
||||||
|
slider.value = 0;
|
||||||
|
frameValue.textContent = 0;
|
||||||
|
image.src = `http://127.0.0.1:8000/preview?frame_number=0`;
|
||||||
|
} else {
|
||||||
|
console.error('Failed to fetch total frame count');
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
console.error('Error fetching total frame count:', error);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Function to log request details to the textarea
|
||||||
|
function logRequest(method, url, startTime, endTime) {
|
||||||
|
const duration = (endTime - startTime).toFixed(2);
|
||||||
|
const logMessage = `${method} ${url} | Duration: ${duration}ms\n`;
|
||||||
|
|
||||||
|
// Append to the log textarea
|
||||||
|
logTextarea.value += logMessage;
|
||||||
|
logTextarea.scrollTop = logTextarea.scrollHeight; // Auto scroll to the bottom
|
||||||
|
}
|
||||||
|
|
||||||
|
function logFps(startTime, endTime) {
|
||||||
|
const duration = (endTime - startTime).toFixed(2);
|
||||||
|
const durationInSeconds = duration / 1000; // Convert ms to seconds
|
||||||
|
const fps = (1 / durationInSeconds).toFixed(2); // FPS for this request
|
||||||
|
|
||||||
|
// Update total FPS and request count
|
||||||
|
totalFps += parseFloat(fps);
|
||||||
|
requestCount++;
|
||||||
|
|
||||||
|
// Calculate average FPS
|
||||||
|
const averageFps = (totalFps / requestCount).toFixed(2);
|
||||||
|
|
||||||
|
// Update the textarea with id 'fps' to show the average FPS
|
||||||
|
const fpsTextarea = document.getElementById('fps');
|
||||||
|
fpsTextarea.value = `Average FPS: ${averageFps}\n`;
|
||||||
|
|
||||||
|
// Optionally, you can append the current FPS to the textarea as well:
|
||||||
|
fpsTextarea.value += `Current FPS: ${fps}\n`;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Function to update the image based on the slider's value or WebSocket message
|
||||||
|
function updateImage() {
|
||||||
|
const frameNumber = slider.value;
|
||||||
|
|
||||||
|
|
||||||
|
// If WebSocket is enabled, use WebSocket to fetch the image
|
||||||
|
if (useWebSocketCheckbox.checked) {
|
||||||
|
image.onload = null
|
||||||
|
if (!socket) {
|
||||||
|
socket = new WebSocket('ws://127.0.0.1:8000/preview');
|
||||||
|
}
|
||||||
|
const start = performance.now();
|
||||||
|
socket.send(JSON.stringify({ frame_number: frameNumber }));
|
||||||
|
socket.onmessage = function (event) {
|
||||||
|
const end = performance.now();
|
||||||
|
logRequest('WEBSOCKET', 'ws://127.0.0.1:8000/preview', start, end);
|
||||||
|
logFps(start, end)
|
||||||
|
|
||||||
|
// Create a Blob URL from the WebSocket message (assumed to be a Blob)
|
||||||
|
const imageUrl = URL.createObjectURL(event.data);
|
||||||
|
|
||||||
|
// Set the image source to the Blob URL
|
||||||
|
image.src = imageUrl;
|
||||||
|
frameValue.textContent = frameNumber;
|
||||||
|
|
||||||
|
// Continue if playing
|
||||||
|
if (isPlaying && currentFrame < totalFrames) {
|
||||||
|
currentFrame++;
|
||||||
|
slider.value = currentFrame;
|
||||||
|
updateImage(); // Continue to next frame
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} else {
|
||||||
|
socket = null
|
||||||
|
// Use default fetch for the image
|
||||||
|
const start = performance.now();
|
||||||
|
image.src = `http://127.0.0.1:8000/preview?frame_number=${frameNumber}`;
|
||||||
|
image.onload = function () {
|
||||||
|
const end = performance.now();
|
||||||
|
logRequest('GET', image.src, start, end);
|
||||||
|
logFps(start, end)
|
||||||
|
frameValue.textContent = frameNumber;
|
||||||
|
|
||||||
|
// Continue if playing
|
||||||
|
if (isPlaying && currentFrame < totalFrames) {
|
||||||
|
currentFrame++;
|
||||||
|
slider.value = currentFrame;
|
||||||
|
updateImage(); // Continue to next frame
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Function to start the play action (without setInterval, only on image load)
|
||||||
|
function startPlay() {
|
||||||
|
playBtn.disabled = true;
|
||||||
|
stopBtn.disabled = false;
|
||||||
|
isPlaying = true;
|
||||||
|
currentFrame = parseInt(slider.value, 10);
|
||||||
|
|
||||||
|
// Start loading the first image
|
||||||
|
updateImage();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Function to stop the play action
|
||||||
|
function stopPlay() {
|
||||||
|
isPlaying = false;
|
||||||
|
playBtn.disabled = false;
|
||||||
|
stopBtn.disabled = true;
|
||||||
|
if (socket) {
|
||||||
|
socket.close(); // Close WebSocket when stopping
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Event listeners for Play/Stop buttons
|
||||||
|
playBtn.addEventListener('click', startPlay);
|
||||||
|
stopBtn.addEventListener('click', stopPlay);
|
||||||
|
|
||||||
|
// Slider manual update
|
||||||
|
slider.addEventListener('change', function () {
|
||||||
|
currentFrame = slider.value;
|
||||||
|
updateImage();
|
||||||
|
});
|
||||||
|
|
||||||
|
// Fetch the total number of frames when the page loads
|
||||||
|
window.onload = fetchSliderTotal;
|
||||||
|
</script>
|
||||||
|
</body>
|
||||||
|
</html>
|
@ -1,6 +1,5 @@
|
|||||||
filetype==1.2.0
|
filetype==1.2.0
|
||||||
gradio==5.7.1
|
litestar==2.13.0
|
||||||
gradio-rangeslider==0.0.8
|
|
||||||
numpy==2.1.3
|
numpy==2.1.3
|
||||||
onnx==1.17.0
|
onnx==1.17.0
|
||||||
onnxruntime==1.20.1
|
onnxruntime==1.20.1
|
||||||
|
Loading…
Reference in New Issue
Block a user