API experiment part2 + ugly frontend

This commit is contained in:
henryruhs 2024-12-09 14:29:01 +01:00
parent 382a036f66
commit 098f64916f
5 changed files with 287 additions and 14 deletions

View File

@@ -1,4 +1,5 @@
from time import time
from typing import Optional
import cv2
import numpy
@@ -15,7 +16,7 @@ from facefusion.typing import AudioFrame, Face, FaceSet, VisionFrame
from facefusion.vision import get_video_frame, read_static_image, read_static_images, resize_frame_resolution
def process_frame(frame_number : int = 0) -> None:
def process_frame(frame_number : int = 0) -> Optional[VisionFrame]:
core.conditional_append_reference_faces()
reference_faces = get_reference_faces() if 'reference' in state_manager.get_item('face_selector_mode') else None
source_frames = read_static_images(state_manager.get_item('source_paths'))
@@ -48,6 +49,8 @@ def process_frame(frame_number : int = 0) -> None:
preview_vision_frame = process_preview_frame(reference_faces, source_face, source_audio_frame, temp_vision_frame)
return preview_vision_frame
return None
def process_preview_frame(reference_faces : FaceSet, source_face : Face, source_audio_frame : AudioFrame, target_vision_frame : VisionFrame) -> VisionFrame:
target_vision_frame = resize_frame_resolution(target_vision_frame, (1024, 1024))
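
With this change process_frame returns Optional[VisionFrame] instead of None, so callers such as the API below are expected to handle the missing-frame case themselves. A minimal sketch of such a caller, assuming the facefusion state (target path, source paths, processors) has already been populated; the helper name is illustrative, not part of the commit:

# Hypothetical helper: encode a preview frame to PNG bytes, guarding against the new None return.
from typing import Optional
import cv2
from facefusion import _preview

def encode_preview_png(frame_number : int = 0) -> Optional[bytes]:
    vision_frame = _preview.process_frame(frame_number)
    if vision_frame is None:
        # nothing to preview (for example, no target frame for this frame number)
        return None
    is_encoded, png_buffer = cv2.imencode('.png', vision_frame)
    return png_buffer.tobytes() if is_encoded else None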

View File

@@ -1,13 +1,13 @@
import asyncio
import time
from io import BytesIO
import json
from typing import Any, List
import cv2
import uvicorn
from litestar import Litestar, WebSocket, get as read, websocket as stream
from litestar import Litestar, WebSocket, get as read, websocket as stream, websocket_listener
from litestar.static_files import create_static_files_router
from facefusion import choices, execution, _preview
from facefusion import _preview, choices, execution, state_manager, vision
from facefusion.processors import choices as processors_choices
from facefusion.state_manager import get_state
from facefusion.typing import ExecutionDevice
@@ -47,7 +47,7 @@ async def read_execution_providers() -> Any:
@stream('/execution/devices')
async def stream_execution_devices(socket : WebSocket) -> None:
async def stream_execution_devices(socket : WebSocket[Any, Any, Any]) -> None:
await socket.accept()
while True:
@@ -66,7 +66,7 @@ async def read_static_execution_devices() -> List[ExecutionDevice]:
@stream('/state')
async def stream_state(socket : WebSocket) -> None:
async def stream_state(socket : WebSocket[Any, Any, Any]) -> None:
await socket.accept()
while True:
@@ -74,12 +74,30 @@ async def stream_state(socket : WebSocket) -> None:
await asyncio.sleep(0.5)
@read('/preview', media_type = 'image/png')
async def read_preview() -> None:
_, preview_vision_frame = cv2.imencode('.png', _preview.process_frame())
@read('/preview', media_type = 'image/png', mode = "binary")
async def read_preview(frame_number : int) -> bytes:
_, preview_vision_frame = cv2.imencode('.png', _preview.process_frame(frame_number)) #type:ignore
return preview_vision_frame.tobytes()
@websocket_listener("/preview", send_mode = "binary")
async def stream_preview(data : str) -> bytes:
frame_number = int(json.loads(data).get('frame_number'))
_, preview_vision_frame = cv2.imencode('.png', _preview.process_frame(frame_number)) #type:ignore
return preview_vision_frame.tobytes()
@read('/ui/preview_slider')
async def read_ui_preview_slider() -> Any:
target_path = state_manager.get_item('target_path')
video_frame_total = vision.count_video_frame_total(target_path)
return\
{
'video_frame_total': video_frame_total
}
api = Litestar(
[
read_choices,
@@ -88,7 +106,14 @@ api = Litestar(
read_execution_devices,
read_static_execution_devices,
stream_state,
read_preview
read_preview,
read_ui_preview_slider,
stream_preview,
create_static_files_router(
path = '/frontend',
directories = [ 'facefusion/static' ],
html_mode = True,
)
])
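
Taken together, the new HTTP surface can be exercised with nothing beyond the standard library. A rough sketch, assuming the API is running locally on port 8000 (the address hard-coded throughout this commit) and a target has already been configured in the state:

# Sketch: read the frame total, then fetch a single preview frame as PNG over plain HTTP.
import json
from urllib.request import urlopen

with urlopen('http://127.0.0.1:8000/ui/preview_slider') as response:
    video_frame_total = json.load(response).get('video_frame_total')
print(f'frames available: {video_frame_total}')

with urlopen('http://127.0.0.1:8000/preview?frame_number=0') as response:
    png_bytes = response.read()
with open('preview_0.png', 'wb') as image_file:
    image_file.write(png_bytes)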

View File

@@ -2,6 +2,7 @@ import itertools
import shutil
import signal
import sys
import webbrowser
from time import time
import numpy
@@ -61,6 +62,9 @@ def route(args : Args) -> None:
if not pre_check():
return conditional_exit(2)
if state_manager.get_item('command') == 'run':
if state_manager.get_item('open_browser'):
webbrowser.open('http://127.0.0.1:8000/frontend')
logger.info('http://127.0.0.1:8000/frontend', __name__)
api.run()
if state_manager.get_item('command') == 'headless-run':
if not job_manager.init_jobs(state_manager.get_item('jobs_path')):

View File

@@ -0,0 +1,242 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Demo</title>
<style>
.preview, meter {
width: 100%;
}
meter { border-radius: 0}
img, textarea {
width: 100%;
}
input[type="range"] {
width: 100%;
margin-top: 1em;
}
</style>
</head>
<body>
<div style="display:flex">
<div class="preview">
<h1>Preview</h1>
<div class="image-container">
<img id="image" alt="Frame Image">
</div>
<input type="range" id="slider" value="0">
<p>Frame: <span id="frameValue">0</span></p>
<button id="playBtn">Play</button>
<button id="stopBtn" disabled>Stop</button>
<input type="checkbox" id="useWebSocket" /> Use WebSocket for Preview
</div>
<div>
<h1>Debug</h1>
<p>Video Memory <meter id="video_memory" min="0" max="100" value="0"></meter></p>
<p>GPU Utilization <meter id="gpu_utilization" min="0" max="100" value="0"></meter></p>
<textarea id="debug" rows="1" cols="80" readonly></textarea>
<textarea id="devices" rows="4" cols="80" readonly></textarea>
<textarea id="state" rows="30" cols="80" readonly></textarea>
<textarea id="log" rows="10" cols="80" readonly></textarea>
<textarea id="fps" rows="2" cols="80" readonly></textarea>
</div>
</div>
<script>
function createWebSocketConnection(url, debug, container) {
const socket = new WebSocket(url);
socket.onopen = () => {
debug.value = `WebSocket connection established for URL: ${url}`;
};
socket.onmessage = event => {
debug.value = `WebSocket Event: ${event.type}`;
container.value = event.data;
};
socket.onerror = error => {
debug.value = `WebSocket Error: ${error}`;
};
socket.onclose = () => {
debug.value = `WebSocket connection closed for URL: ${url} -> Reloading Page`;
setTimeout(() => location.reload(), 1000);
};
return socket;
}
const devicesSocket = createWebSocketConnection('ws://127.0.0.1:8000/execution/devices', debug, devices);
createWebSocketConnection('ws://127.0.0.1:8000/state', debug, state);
devicesSocket.addEventListener('message', event => {
const data = JSON.parse(event.data)[0]
const freeMemory = data.video_memory.free.value;
const totalMemory = data.video_memory.total.value;
const usedMemory = totalMemory - freeMemory;
const usedMemoryPercentage = (usedMemory / totalMemory) * 100;
video_memory.value = usedMemoryPercentage;
gpu_utilization.value = data.utilization.gpu.value
})
</script>
<script>
const slider = document.getElementById('slider');
const image = document.getElementById('image');
const frameValue = document.getElementById('frameValue');
const playBtn = document.getElementById('playBtn');
const stopBtn = document.getElementById('stopBtn');
const logTextarea = document.getElementById('log');
const useWebSocketCheckbox = document.getElementById('useWebSocket');
let totalFrames = 0;
let currentFrame = 0;
let totalFps = 0;
let requestCount = 0;
let isPlaying = false;
let socket;
// Fetch the total frame count and set up the slider
async function fetchSliderTotal() {
try {
const start = performance.now();
const response = await fetch('http://127.0.0.1:8000/ui/preview_slider');
const end = performance.now();
logRequest('GET', 'http://127.0.0.1:8000/ui/preview_slider', start, end);
if (response.ok) {
const data = await response.json();
totalFrames = data.video_frame_total;
slider.max = totalFrames;
slider.value = 0;
frameValue.textContent = 0;
image.src = `http://127.0.0.1:8000/preview?frame_number=0`;
} else {
console.error('Failed to fetch total frame count');
}
} catch (error) {
console.error('Error fetching total frame count:', error);
}
}
// Function to log request details to the textarea
function logRequest(method, url, startTime, endTime) {
const duration = (endTime - startTime).toFixed(2);
const logMessage = `${method} ${url} | Duration: ${duration}ms\n`;
// Append to the log textarea
logTextarea.value += logMessage;
logTextarea.scrollTop = logTextarea.scrollHeight; // Auto scroll to the bottom
}
function logFps(startTime, endTime) {
const duration = (endTime - startTime).toFixed(2);
const durationInSeconds = duration / 1000; // Convert ms to seconds
const fps = (1 / durationInSeconds).toFixed(2); // FPS for this request
// Update total FPS and request count
totalFps += parseFloat(fps);
requestCount++;
// Calculate average FPS
const averageFps = (totalFps / requestCount).toFixed(2);
// Update the textarea with id 'fps' to show the average FPS
const fpsTextarea = document.getElementById('fps');
fpsTextarea.value = `Average FPS: ${averageFps}\n`;
// Optionally, you can append the current FPS to the textarea as well:
fpsTextarea.value += `Current FPS: ${fps}\n`;
}
// Function to update the image based on the slider's value or WebSocket message
function updateImage() {
const frameNumber = slider.value;
// If WebSocket is enabled, use WebSocket to fetch the image
if (useWebSocketCheckbox.checked) {
image.onload = null;
const payload = JSON.stringify({ frame_number: frameNumber });
const start = performance.now();
// (Re)create the socket when it is missing or no longer open, and defer the send until the connection opens.
if (!socket || socket.readyState !== WebSocket.OPEN) {
socket = new WebSocket('ws://127.0.0.1:8000/preview');
socket.onopen = () => socket.send(payload);
} else {
socket.send(payload);
}
socket.onmessage = function (event) {
const end = performance.now();
logRequest('WEBSOCKET', 'ws://127.0.0.1:8000/preview', start, end);
logFps(start, end)
// Create a Blob URL from the WebSocket message (assumed to be a Blob)
const imageUrl = URL.createObjectURL(event.data);
// Set the image source to the Blob URL
image.src = imageUrl;
frameValue.textContent = frameNumber;
// Continue if playing
if (isPlaying && currentFrame < totalFrames) {
currentFrame++;
slider.value = currentFrame;
updateImage(); // Continue to next frame
}
};
} else {
socket = null
// Use default fetch for the image
const start = performance.now();
image.src = `http://127.0.0.1:8000/preview?frame_number=${frameNumber}`;
image.onload = function () {
const end = performance.now();
logRequest('GET', image.src, start, end);
logFps(start, end)
frameValue.textContent = frameNumber;
// Continue if playing
if (isPlaying && currentFrame < totalFrames) {
currentFrame++;
slider.value = currentFrame;
updateImage(); // Continue to next frame
}
};
}
}
// Function to start the play action (without setInterval, only on image load)
function startPlay() {
playBtn.disabled = true;
stopBtn.disabled = false;
isPlaying = true;
currentFrame = parseInt(slider.value, 10);
// Start loading the first image
updateImage();
}
// Function to stop the play action
function stopPlay() {
isPlaying = false;
playBtn.disabled = false;
stopBtn.disabled = true;
if (socket) {
socket.close(); // Close WebSocket when stopping
}
}
// Event listeners for Play/Stop buttons
playBtn.addEventListener('click', startPlay);
stopBtn.addEventListener('click', stopPlay);
// Slider manual update
slider.addEventListener('change', function () {
currentFrame = slider.value;
updateImage();
});
// Fetch the total number of frames when the page loads
window.onload = fetchSliderTotal;
</script>
</body>
</html>
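
The frontend's "Use WebSocket for Preview" mode talks to the stream_preview listener above: the client sends a JSON payload carrying a frame_number and receives the PNG bytes back on the same connection. The same exchange from Python, as a sketch that assumes the third-party websockets package (version 12 or newer for the sync client), which is not part of this commit's requirements:

# Sketch: request one preview frame over the WebSocket endpoint and save it to disk.
import json
from websockets.sync.client import connect

with connect('ws://127.0.0.1:8000/preview') as websocket:
    websocket.send(json.dumps({ 'frame_number': 0 }))
    png_bytes = websocket.recv()
with open('preview_0.png', 'wb') as image_file:
    image_file.write(png_bytes)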

View File

@@ -1,6 +1,5 @@
filetype==1.2.0
gradio==5.7.1
gradio-rangeslider==0.0.8
litestar==2.13.0
numpy==2.1.3
onnx==1.17.0
onnxruntime==1.20.1