Feat/download provider fallback (#837)
* Introduce download providers fallback, Use CURL everywhre * Fix CI * Use readlines() over readline() to avoid while * Use readlines() over readline() to avoid while * Use readlines() over readline() to avoid while
This commit is contained in:
parent
e26381753c
commit
034d029a41
@ -55,8 +55,16 @@ execution_provider_set : ExecutionProviderSet =\
|
||||
execution_providers : List[ExecutionProvider] = list(execution_provider_set.keys())
|
||||
download_provider_set : DownloadProviderSet =\
|
||||
{
|
||||
'github': 'https://github.com/facefusion/facefusion-assets/releases/download/{base_name}/{file_name}',
|
||||
'huggingface': 'https://huggingface.co/facefusion/{base_name}/resolve/main/{file_name}'
|
||||
'github':
|
||||
{
|
||||
'url': 'https://github.com',
|
||||
'path': '/facefusion/facefusion-assets/releases/download/{base_name}/{file_name}'
|
||||
},
|
||||
'huggingface':
|
||||
{
|
||||
'url': 'https://huggingface.co',
|
||||
'path': '/facefusion/{base_name}/resolve/main/{file_name}'
|
||||
}
|
||||
}
|
||||
download_providers : List[DownloadProvider] = list(download_provider_set.keys())
|
||||
download_scopes : List[DownloadScope] = [ 'lite', 'full' ]
|
||||
|
@ -1,8 +1,6 @@
|
||||
import os
|
||||
import shutil
|
||||
import ssl
|
||||
import subprocess
|
||||
import urllib.request
|
||||
from functools import lru_cache
|
||||
from typing import List, Optional, Tuple
|
||||
from urllib.parse import urlparse
|
||||
@ -11,13 +9,15 @@ from tqdm import tqdm
|
||||
|
||||
import facefusion.choices
|
||||
from facefusion import logger, process_manager, state_manager, wording
|
||||
from facefusion.common_helper import is_macos
|
||||
from facefusion.filesystem import get_file_size, is_file, remove_file
|
||||
from facefusion.hash_helper import validate_hash
|
||||
from facefusion.typing import DownloadProvider, DownloadSet
|
||||
|
||||
if is_macos():
|
||||
ssl._create_default_https_context = ssl._create_unverified_context
|
||||
|
||||
def open_curl(args : List[str]) -> subprocess.Popen[bytes]:
|
||||
commands = [ shutil.which('curl'), '--silent', '--insecure', '--location' ]
|
||||
commands.extend(args)
|
||||
return subprocess.Popen(commands, stdin = subprocess.PIPE, stdout = subprocess.PIPE)
|
||||
|
||||
|
||||
def conditional_download(download_directory_path : str, urls : List[str]) -> None:
|
||||
@ -25,13 +25,15 @@ def conditional_download(download_directory_path : str, urls : List[str]) -> Non
|
||||
download_file_name = os.path.basename(urlparse(url).path)
|
||||
download_file_path = os.path.join(download_directory_path, download_file_name)
|
||||
initial_size = get_file_size(download_file_path)
|
||||
download_size = get_download_size(url)
|
||||
download_size = get_static_download_size(url)
|
||||
|
||||
if initial_size < download_size:
|
||||
with tqdm(total = download_size, initial = initial_size, desc = wording.get('downloading'), unit = 'B', unit_scale = True, unit_divisor = 1024, ascii = ' =', disable = state_manager.get_item('log_level') in [ 'warn', 'error' ]) as progress:
|
||||
subprocess.Popen([ shutil.which('curl'), '--create-dirs', '--silent', '--insecure', '--location', '--continue-at', '-', '--output', download_file_path, url ])
|
||||
commands = [ '--create-dirs', '--continue-at', '-', '--output', download_file_path, url ]
|
||||
open_curl(commands)
|
||||
current_size = initial_size
|
||||
progress.set_postfix(download_providers = state_manager.get_item('download_providers'), file_name = download_file_name)
|
||||
|
||||
while current_size < download_size:
|
||||
if is_file(download_file_path):
|
||||
current_size = get_file_size(download_file_path)
|
||||
@ -39,13 +41,27 @@ def conditional_download(download_directory_path : str, urls : List[str]) -> Non
|
||||
|
||||
|
||||
@lru_cache(maxsize = None)
|
||||
def get_download_size(url : str) -> int:
|
||||
try:
|
||||
response = urllib.request.urlopen(url, timeout = 10)
|
||||
content_length = response.headers.get('Content-Length')
|
||||
return int(content_length)
|
||||
except (OSError, TypeError, ValueError):
|
||||
return 0
|
||||
def get_static_download_size(url : str) -> int:
|
||||
commands = [ '-I', url ]
|
||||
process = open_curl(commands)
|
||||
lines = reversed(process.stdout.readlines())
|
||||
|
||||
for line in lines:
|
||||
__line__ = line.decode().lower()
|
||||
|
||||
if 'content-length:' in __line__:
|
||||
_, content_length = __line__.split('content-length:')
|
||||
return int(content_length)
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
@lru_cache(maxsize = None)
|
||||
def ping_static_url(url : str) -> bool:
|
||||
commands = [ '-I', url ]
|
||||
process = open_curl(commands)
|
||||
process.wait()
|
||||
return process.returncode == 0
|
||||
|
||||
|
||||
def conditional_download_hashes(hashes : DownloadSet) -> bool:
|
||||
@ -61,6 +77,7 @@ def conditional_download_hashes(hashes : DownloadSet) -> bool:
|
||||
conditional_download(download_directory_path, [ invalid_hash_url ])
|
||||
|
||||
valid_hash_paths, invalid_hash_paths = validate_hash_paths(hash_paths)
|
||||
|
||||
for valid_hash_path in valid_hash_paths:
|
||||
valid_hash_file_name, _ = os.path.splitext(os.path.basename(valid_hash_path))
|
||||
logger.debug(wording.get('validating_hash_succeed').format(hash_file_name = valid_hash_file_name), __name__)
|
||||
@ -86,6 +103,7 @@ def conditional_download_sources(sources : DownloadSet) -> bool:
|
||||
conditional_download(download_directory_path, [ invalid_source_url ])
|
||||
|
||||
valid_source_paths, invalid_source_paths = validate_source_paths(source_paths)
|
||||
|
||||
for valid_source_path in valid_source_paths:
|
||||
valid_source_file_name, _ = os.path.splitext(os.path.basename(valid_source_path))
|
||||
logger.debug(wording.get('validating_source_succeed').format(source_file_name = valid_source_file_name), __name__)
|
||||
@ -128,11 +146,17 @@ def validate_source_paths(source_paths : List[str]) -> Tuple[List[str], List[str
|
||||
def resolve_download_url(base_name : str, file_name : str) -> Optional[str]:
|
||||
download_providers = state_manager.get_item('download_providers')
|
||||
|
||||
for download_provider in facefusion.choices.download_provider_set:
|
||||
if download_provider in download_providers:
|
||||
for download_provider in download_providers:
|
||||
if ping_download_provider(download_provider):
|
||||
return resolve_download_url_by_provider(download_provider, base_name, file_name)
|
||||
return None
|
||||
|
||||
|
||||
def ping_download_provider(download_provider : DownloadProvider) -> bool:
|
||||
download_provider_value = facefusion.choices.download_provider_set.get(download_provider)
|
||||
return ping_static_url(download_provider_value.get('url'))
|
||||
|
||||
|
||||
def resolve_download_url_by_provider(download_provider : DownloadProvider, base_name : str, file_name : str) -> Optional[str]:
|
||||
return facefusion.choices.download_provider_set.get(download_provider).format(base_name = base_name, file_name = file_name)
|
||||
download_provider_value = facefusion.choices.download_provider_set.get(download_provider)
|
||||
return download_provider_value.get('url') + download_provider_value.get('path').format(base_name = base_name, file_name = file_name)
|
||||
|
@ -22,10 +22,15 @@ def run_ffmpeg_with_progress(args: List[str], update_progress : UpdateProgress)
|
||||
|
||||
while process_manager.is_processing():
|
||||
try:
|
||||
while line := process.stdout.readline().decode():
|
||||
if 'frame=' in line:
|
||||
_, frame_number = line.split('frame=')
|
||||
lines = process.stdout.readlines()
|
||||
|
||||
for line in lines:
|
||||
__line__ = line.decode().lower()
|
||||
|
||||
if 'frame=' in __line__:
|
||||
_, frame_number = __line__.split('frame=')
|
||||
update_progress(int(frame_number))
|
||||
|
||||
if log_level == 'debug':
|
||||
log_debug(process)
|
||||
process.wait(timeout = 0.5)
|
||||
|
@ -206,7 +206,7 @@ def create_download_providers_program() -> ArgumentParser:
|
||||
program = ArgumentParser(add_help = False)
|
||||
download_providers = list(facefusion.choices.download_provider_set.keys())
|
||||
group_download = program.add_argument_group('download')
|
||||
group_download.add_argument('--download-providers', help = wording.get('help.download_providers').format(choices = ', '.join(download_providers)), default = config.get_str_list('download.download_providers', 'github'), choices = download_providers, nargs = '+', metavar = 'DOWNLOAD_PROVIDERS')
|
||||
group_download.add_argument('--download-providers', help = wording.get('help.download_providers').format(choices = ', '.join(download_providers)), default = config.get_str_list('download.download_providers', ' '.join(facefusion.choices.download_providers)), choices = download_providers, nargs = '+', metavar = 'DOWNLOAD_PROVIDERS')
|
||||
job_store.register_job_keys([ 'download_providers' ])
|
||||
return program
|
||||
|
||||
|
@ -159,7 +159,12 @@ ExecutionDevice = TypedDict('ExecutionDevice',
|
||||
})
|
||||
|
||||
DownloadProvider = Literal['github', 'huggingface']
|
||||
DownloadProviderSet = Dict[DownloadProvider, str]
|
||||
DownloadProviderValue = TypedDict('DownloadProviderValue',
|
||||
{
|
||||
'url' : str,
|
||||
'path' : str
|
||||
})
|
||||
DownloadProviderSet = Dict[DownloadProvider, DownloadProviderValue]
|
||||
DownloadScope = Literal['lite', 'full']
|
||||
Download = TypedDict('Download',
|
||||
{
|
||||
|
@ -1,18 +1,13 @@
|
||||
import pytest
|
||||
|
||||
from facefusion.download import conditional_download, get_download_size
|
||||
from .helper import get_test_examples_directory
|
||||
from facefusion.download import get_static_download_size, ping_static_url
|
||||
|
||||
|
||||
@pytest.fixture(scope = 'module', autouse = True)
|
||||
def before_all() -> None:
|
||||
conditional_download(get_test_examples_directory(),
|
||||
[
|
||||
'https://github.com/facefusion/facefusion-assets/releases/download/examples-3.0.0/target-240p.mp4'
|
||||
])
|
||||
def test_get_static_download_size() -> None:
|
||||
assert get_static_download_size('https://github.com/facefusion/facefusion-assets/releases/download/models-3.0.0/fairface.onnx') == 85170772
|
||||
assert get_static_download_size('https://huggingface.co/facefusion/models-3.0.0/resolve/main/fairface.onnx') == 85170772
|
||||
assert get_static_download_size('invalid') == 0
|
||||
|
||||
|
||||
def test_get_download_size() -> None:
|
||||
assert get_download_size('https://github.com/facefusion/facefusion-assets/releases/download/examples-3.0.0/target-240p.mp4') == 191675
|
||||
assert get_download_size('https://github.com/facefusion/facefusion-assets/releases/download/examples-3.0.0/target-360p.mp4') == 370732
|
||||
assert get_download_size('invalid') == 0
|
||||
def test_static_ping_url() -> None:
|
||||
assert ping_static_url('https://github.com') is True
|
||||
assert ping_static_url('https://huggingface.co') is True
|
||||
assert ping_static_url('invalid') is False
|
||||
|
Loading…
Reference in New Issue
Block a user