Introduce download providers fallback, Use CURL everywhre

This commit is contained in:
henryruhs 2024-12-21 17:16:18 +01:00
parent e26381753c
commit 5e6b3d55c2
6 changed files with 75 additions and 27 deletions

View File

@ -55,8 +55,16 @@ execution_provider_set : ExecutionProviderSet =\
execution_providers : List[ExecutionProvider] = list(execution_provider_set.keys())
download_provider_set : DownloadProviderSet =\
{
'github': 'https://github.com/facefusion/facefusion-assets/releases/download/{base_name}/{file_name}',
'huggingface': 'https://huggingface.co/facefusion/{base_name}/resolve/main/{file_name}'
'github':
{
'url': 'https://github.com',
'path': '/facefusion/facefusion-assets/releases/download/{base_name}/{file_name}'
},
'huggingface':
{
'url': 'https://huggingface.co',
'path': '/facefusion/{base_name}/resolve/main/{file_name}'
}
}
download_providers : List[DownloadProvider] = list(download_provider_set.keys())
download_scopes : List[DownloadScope] = [ 'lite', 'full' ]

View File

@ -1,8 +1,6 @@
import os
import shutil
import ssl
import subprocess
import urllib.request
from functools import lru_cache
from typing import List, Optional, Tuple
from urllib.parse import urlparse
@ -11,13 +9,15 @@ from tqdm import tqdm
import facefusion.choices
from facefusion import logger, process_manager, state_manager, wording
from facefusion.common_helper import is_macos
from facefusion.filesystem import get_file_size, is_file, remove_file
from facefusion.hash_helper import validate_hash
from facefusion.typing import DownloadProvider, DownloadSet
if is_macos():
ssl._create_default_https_context = ssl._create_unverified_context
def open_curl(args : List[str]) -> subprocess.Popen[bytes]:
commands = [ shutil.which('curl'), '--silent', '--insecure', '--location' ]
commands.extend(args)
return subprocess.Popen(commands, stdin = subprocess.PIPE, stdout = subprocess.PIPE)
def conditional_download(download_directory_path : str, urls : List[str]) -> None:
@ -25,13 +25,15 @@ def conditional_download(download_directory_path : str, urls : List[str]) -> Non
download_file_name = os.path.basename(urlparse(url).path)
download_file_path = os.path.join(download_directory_path, download_file_name)
initial_size = get_file_size(download_file_path)
download_size = get_download_size(url)
download_size = get_static_download_size(url)
if initial_size < download_size:
with tqdm(total = download_size, initial = initial_size, desc = wording.get('downloading'), unit = 'B', unit_scale = True, unit_divisor = 1024, ascii = ' =', disable = state_manager.get_item('log_level') in [ 'warn', 'error' ]) as progress:
subprocess.Popen([ shutil.which('curl'), '--create-dirs', '--silent', '--insecure', '--location', '--continue-at', '-', '--output', download_file_path, url ])
commands = [ '--create-dirs', '--continue-at', '-', '--output', download_file_path, url ]
open_curl(commands)
current_size = initial_size
progress.set_postfix(download_providers = state_manager.get_item('download_providers'), file_name = download_file_name)
while current_size < download_size:
if is_file(download_file_path):
current_size = get_file_size(download_file_path)
@ -39,13 +41,28 @@ def conditional_download(download_directory_path : str, urls : List[str]) -> Non
@lru_cache(maxsize = None)
def get_download_size(url : str) -> int:
try:
response = urllib.request.urlopen(url, timeout = 10)
content_length = response.headers.get('Content-Length')
return int(content_length)
except (OSError, TypeError, ValueError):
return 0
def get_static_download_size(url : str) -> int:
commands = [ '-I', url ]
process = open_curl(commands)
process.wait()
while line := process.stdout.readline().decode().lower():
if 'content-length:' in line:
_, content_length = line.split('content-length:')
content_length = int(content_length)
if content_length > 0:
return content_length
return 0
@lru_cache(maxsize = None)
def ping_static_url(url : str) -> bool:
commands = [ '-I', url ]
process = open_curl(commands)
process.wait()
return process.returncode == 0
def conditional_download_hashes(hashes : DownloadSet) -> bool:
@ -61,6 +78,7 @@ def conditional_download_hashes(hashes : DownloadSet) -> bool:
conditional_download(download_directory_path, [ invalid_hash_url ])
valid_hash_paths, invalid_hash_paths = validate_hash_paths(hash_paths)
for valid_hash_path in valid_hash_paths:
valid_hash_file_name, _ = os.path.splitext(os.path.basename(valid_hash_path))
logger.debug(wording.get('validating_hash_succeed').format(hash_file_name = valid_hash_file_name), __name__)
@ -86,6 +104,7 @@ def conditional_download_sources(sources : DownloadSet) -> bool:
conditional_download(download_directory_path, [ invalid_source_url ])
valid_source_paths, invalid_source_paths = validate_source_paths(source_paths)
for valid_source_path in valid_source_paths:
valid_source_file_name, _ = os.path.splitext(os.path.basename(valid_source_path))
logger.debug(wording.get('validating_source_succeed').format(source_file_name = valid_source_file_name), __name__)
@ -128,11 +147,17 @@ def validate_source_paths(source_paths : List[str]) -> Tuple[List[str], List[str
def resolve_download_url(base_name : str, file_name : str) -> Optional[str]:
download_providers = state_manager.get_item('download_providers')
for download_provider in facefusion.choices.download_provider_set:
if download_provider in download_providers:
for download_provider in download_providers:
if ping_download_provider(download_provider):
return resolve_download_url_by_provider(download_provider, base_name, file_name)
return None
def ping_download_provider(download_provider : DownloadProvider) -> bool:
download_provider_value = facefusion.choices.download_provider_set.get(download_provider)
return ping_static_url(download_provider_value.get('url'))
def resolve_download_url_by_provider(download_provider : DownloadProvider, base_name : str, file_name : str) -> Optional[str]:
return facefusion.choices.download_provider_set.get(download_provider).format(base_name = base_name, file_name = file_name)
download_provider_value = facefusion.choices.download_provider_set.get(download_provider)
return download_provider_value.get('url') + download_provider_value.get('path').format(base_name = base_name, file_name = file_name)

View File

@ -22,10 +22,15 @@ def run_ffmpeg_with_progress(args: List[str], update_progress : UpdateProgress)
while process_manager.is_processing():
try:
while line := process.stdout.readline().decode():
while line := process.stdout.readline().decode().lower():
if 'frame=' in line:
_, frame_number = line.split('frame=')
update_progress(int(frame_number))
frame_number = int(frame_number)
if frame_number > 0:
update_progress(frame_number)
if log_level == 'debug':
log_debug(process)
process.wait(timeout = 0.5)

View File

@ -206,7 +206,7 @@ def create_download_providers_program() -> ArgumentParser:
program = ArgumentParser(add_help = False)
download_providers = list(facefusion.choices.download_provider_set.keys())
group_download = program.add_argument_group('download')
group_download.add_argument('--download-providers', help = wording.get('help.download_providers').format(choices = ', '.join(download_providers)), default = config.get_str_list('download.download_providers', 'github'), choices = download_providers, nargs = '+', metavar = 'DOWNLOAD_PROVIDERS')
group_download.add_argument('--download-providers', help = wording.get('help.download_providers').format(choices = ', '.join(download_providers)), default = config.get_str_list('download.download_providers', ' '.join(facefusion.choices.download_providers)), choices = download_providers, nargs = '+', metavar = 'DOWNLOAD_PROVIDERS')
job_store.register_job_keys([ 'download_providers' ])
return program

View File

@ -159,7 +159,12 @@ ExecutionDevice = TypedDict('ExecutionDevice',
})
DownloadProvider = Literal['github', 'huggingface']
DownloadProviderSet = Dict[DownloadProvider, str]
DownloadProviderValue = TypedDict('DownloadProviderValue',
{
'url' : str,
'path' : str
})
DownloadProviderSet = Dict[DownloadProvider, DownloadProviderValue]
DownloadScope = Literal['lite', 'full']
Download = TypedDict('Download',
{

View File

@ -1,6 +1,6 @@
import pytest
from facefusion.download import conditional_download, get_download_size
from facefusion.download import conditional_download, get_static_download_size, ping_url
from .helper import get_test_examples_directory
@ -13,6 +13,11 @@ def before_all() -> None:
def test_get_download_size() -> None:
assert get_download_size('https://github.com/facefusion/facefusion-assets/releases/download/examples-3.0.0/target-240p.mp4') == 191675
assert get_download_size('https://github.com/facefusion/facefusion-assets/releases/download/examples-3.0.0/target-360p.mp4') == 370732
assert get_download_size('invalid') == 0
assert get_static_download_size('https://github.com/facefusion/facefusion-assets/releases/download/examples-3.0.0/target-240p.mp4') == 191675
assert get_static_download_size('https://github.com/facefusion/facefusion-assets/releases/download/examples-3.0.0/target-360p.mp4') == 370732
assert get_static_download_size('invalid') == 0
def test_ping_url() -> None:
assert ping_url('https://github.com') is True
assert ping_url('invalid') is False