Merge pull request #351 from hzxjy1/master

Use coroutine in url download
Ricter Zheng 2024-11-17 10:10:54 +08:00 committed by GitHub
commit 90b17832cc

@@ -1,20 +1,18 @@
 # coding: utf-8
 import multiprocessing
-import signal
-import sys
 import os
-import requests
 import time
 import urllib3.exceptions

 from urllib.parse import urlparse
 from nhentai import constant
 from nhentai.logger import logger
-from nhentai.parser import request
 from nhentai.utils import Singleton
+import asyncio
+import httpx

 urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
 semaphore = multiprocessing.Semaphore(1)
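
Note: the import swap above is the heart of the change — `requests` and the process-pool helpers go away, and `httpx` plus `asyncio` come in so downloads can overlap inside a single process. A rough standalone sketch of the pattern (URLs are placeholders, not project code; the per-call client mirrors the commit's own `async_request`):

```python
import asyncio
import httpx


async def fetch(url: str) -> bytes:
    # A client per call, as in the commit's async_request helper.
    async with httpx.AsyncClient() as client:
        response = await client.get(url, timeout=30)
        return response.content


async def main() -> None:
    # Placeholder URLs; with blocking requests, these would run one after another.
    urls = ['https://example.com/a.jpg', 'https://example.com/b.jpg']
    contents = await asyncio.gather(*(fetch(u) for u in urls))
    print([len(c) for c in contents])


asyncio.run(main())
```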
@@ -40,14 +38,13 @@ def download_callback(result):

 class Downloader(Singleton):
     def __init__(self, path='', size=5, timeout=30, delay=0):
         self.size = size
         self.path = str(path)
         self.timeout = timeout
         self.delay = delay

-    def download(self, url, folder='', filename='', retried=0, proxy=None):
+    async def download(self, url, folder='', filename='', retried=0, proxy=None):
         if self.delay:
             time.sleep(self.delay)
         logger.info(f'Starting to download {url} ...')
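
Note: `download` is now a coroutine but still calls `time.sleep(self.delay)`, which blocks the whole event loop and stalls every other in-flight download for the duration of the delay. The non-blocking equivalent is `asyncio.sleep`; a minimal runnable sketch of the difference (illustrative only, not part of the commit):

```python
import asyncio


async def delayed_download(delay: float) -> None:
    # time.sleep(delay) here would freeze the event loop for every coroutine;
    # asyncio.sleep(delay) suspends only this one and lets the others proceed.
    await asyncio.sleep(delay)
    print('starting download ...')


async def main() -> None:
    # Both delays elapse concurrently: total wall time ~0.5s, not ~1.0s.
    await asyncio.gather(delayed_download(0.5), delayed_download(0.5))


asyncio.run(main())
```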
@@ -55,48 +52,40 @@ class Downloader(Singleton):
         base_filename, extension = os.path.splitext(filename)
         save_file_path = os.path.join(folder, base_filename.zfill(3) + extension)
         try:
-            if not os.path.exists(folder):
-                os.makedirs(folder, exist_ok=True)
             if os.path.exists(save_file_path):
                 logger.warning(f'Skipped download: {save_file_path} already exists')
                 return 1, url

-            response = None
-            with open(save_file_path, "wb") as f:
-                i = 0
-                while i < 10:
-                    try:
-                        response = request('get', url, stream=True, timeout=self.timeout, proxies=proxy)
-                        if response.status_code != 200:
-                            path = urlparse(url).path
-                            for mirror in constant.IMAGE_URL_MIRRORS:
-                                print(f'{mirror}{path}')
-                                mirror_url = f'{mirror}{path}'
-                                response = request('get', mirror_url, stream=True,
-                                                   timeout=self.timeout, proxies=proxy)
-                                if response.status_code == 200:
-                                    break
-                    except Exception as e:
-                        i += 1
-                        if not i < 10:
-                            logger.critical(str(e))
-                            return 0, None
-                        continue
-                    break
-
-                length = response.headers.get('content-length')
-                if length is None:
-                    f.write(response.content)
-                else:
-                    for chunk in response.iter_content(2048):
-                        f.write(chunk)
-
-        except (requests.HTTPError, requests.Timeout) as e:
+            response = await self.async_request(url, self.timeout)  # TODO: Add proxy
+
+            if response.status_code != 200:
+                path = urlparse(url).path
+                for mirror in constant.IMAGE_URL_MIRRORS:
+                    print(f'{mirror}{path}')
+                    mirror_url = f'{mirror}{path}'
+                    response = await self.async_request(mirror_url, self.timeout)
+                    if response.status_code == 200:
+                        break
+
+            if not await self.save(save_file_path, response):
+                logger.error(f'Can not download image {url}')
+                return 1, None
+
+        except (httpx.HTTPStatusError, httpx.TimeoutException) as e:
             if retried < 3:
                 logger.warning(f'Warning: {e}, retrying({retried}) ...')
-                return 0, self.download(url=url, folder=folder, filename=filename,
-                                        retried=retried+1, proxy=proxy)
+                return 0, await self.download(
+                    url=url,
+                    folder=folder,
+                    filename=filename,
+                    retried=retried + 1,
+                    proxy=proxy,
+                )
             else:
                 return 0, None
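
Note: the rewritten body preserves the mirror-fallback behaviour — on a non-200 response, the URL's path is replayed against each host in `constant.IMAGE_URL_MIRRORS` until one succeeds. A self-contained sketch of that pattern (the mirror hosts and URL below are placeholders):

```python
# Standalone sketch of the mirror-fallback pattern used above.
# MIRRORS and the example URL are placeholders, not values from the project.
import asyncio
from urllib.parse import urlparse

import httpx

MIRRORS = ['https://i2.example.net', 'https://i3.example.net']


async def fetch_with_fallback(url: str, timeout: int = 30) -> httpx.Response:
    async with httpx.AsyncClient() as client:
        response = await client.get(url, timeout=timeout)
        if response.status_code == 200:
            return response
        # Primary host failed: reuse only the path against each mirror host.
        path = urlparse(url).path
        for mirror in MIRRORS:
            response = await client.get(f'{mirror}{path}', timeout=timeout)
            if response.status_code == 200:
                break
        return response  # caller still checks status_code


# asyncio.run(fetch_with_fallback('https://i.example.net/galleries/1/1.jpg'))
```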
@@ -106,6 +95,7 @@ class Downloader(Singleton):
         except Exception as e:
             import traceback
             traceback.print_stack()
             logger.critical(str(e))
             return 0, None
@@ -115,6 +105,25 @@ class Downloader(Singleton):
         return 1, url

+    async def save(self, save_file_path, response) -> bool:
+        if response is None:
+            logger.error('Error: Response is None')
+            return False
+        with open(save_file_path, 'wb') as f:
+            if response is not None:
+                length = response.headers.get('content-length')
+                if length is None:
+                    f.write(response.content)
+                else:
+                    async for chunk in response.aiter_bytes(2048):
+                        f.write(chunk)
+        return True
+
+    async def async_request(self, url, timeout):
+        async with httpx.AsyncClient() as client:
+            return await client.get(url, timeout=timeout)
+
     def start_download(self, queue, folder='') -> bool:
         if not isinstance(folder, (str, )):
             folder = str(folder)
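
Note: two design points in the new helpers are worth flagging. `async_request` opens and closes a fresh `httpx.AsyncClient` per request, giving up connection pooling across a gallery's images; and since a plain `client.get` loads the full body before the client closes, `save`'s `aiter_bytes` loop appears to iterate over an already-buffered response rather than streaming from the socket (the inner `response is not None` check is also redundant after the early return). A pooled, genuinely streaming variant might look like this (hypothetical refactor, not part of the commit):

```python
# Hypothetical refactor (not part of the commit): one shared client for
# connection pooling, and client.stream() so large images never sit fully
# in memory before hitting disk.
import asyncio
import httpx


async def save_streamed(client: httpx.AsyncClient, url: str, path: str, timeout: int = 30) -> None:
    async with client.stream('GET', url, timeout=timeout) as response:
        response.raise_for_status()
        with open(path, 'wb') as f:
            async for chunk in response.aiter_bytes(2048):
                f.write(chunk)


async def main() -> None:
    async with httpx.AsyncClient() as client:  # reused across all downloads
        await save_streamed(client, 'https://example.com/1.jpg', '001.jpg')


# asyncio.run(main())
```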
@@ -132,30 +141,20 @@ class Downloader(Singleton):
         if os.getenv('DEBUG', None) == 'NODOWNLOAD':
             # Assuming we want to continue with rest of process.
             return True

-        queue = [(self, url, folder, constant.CONFIG['proxy']) for url in queue]
-
-        pool = multiprocessing.Pool(self.size, init_worker)
-        [pool.apply_async(download_wrapper, args=item) for item in queue]
-
-        pool.close()
-        pool.join()
+        async def co_wrapper(tasks):
+            for completed_task in asyncio.as_completed(tasks):
+                try:
+                    result = await completed_task
+                    logger.info(f'{result[1]} download completed')
+                except Exception as e:
+                    logger.error(f'An error occurred: {e}')
+
+        tasks = [
+            self.download(url, filename=os.path.basename(urlparse(url).path))
+            for url in queue
+        ]
+        # Prevent coroutines infection
+        asyncio.run(co_wrapper(tasks))
         return True

-
-def download_wrapper(obj, url, folder='', proxy=None):
-    if sys.platform == 'darwin' or semaphore.get_value():
-        return Downloader.download(obj, url=url, folder=folder, proxy=proxy)
-    else:
-        return -3, None
-
-
-def init_worker():
-    signal.signal(signal.SIGINT, subprocess_signal)
-
-
-def subprocess_signal(sig, frame):
-    if semaphore.acquire(timeout=1):
-        logger.warning('Ctrl-C pressed, exiting sub processes ...')
-        raise KeyboardInterrupt
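
Note: with the process pool gone, `self.size` no longer bounds anything — `co_wrapper` schedules every download at once and `asyncio.as_completed` simply reports them as they finish. If a concurrency cap is wanted, the usual asyncio idiom is a semaphore; a runnable sketch (hypothetical, not in the commit):

```python
# Hypothetical sketch (not in the commit): cap concurrent downloads with an
# asyncio.Semaphore, restoring the role the pool's `size` used to play.
import asyncio


async def worker(sem: asyncio.Semaphore, n: int) -> int:
    async with sem:  # at most `size` coroutines run this block at once
        await asyncio.sleep(0.1)  # stand-in for one image download
        return n


async def main(size: int = 5) -> None:
    sem = asyncio.Semaphore(size)
    tasks = [worker(sem, n) for n in range(20)]
    for completed in asyncio.as_completed(tasks):
        print('completed:', await completed)


asyncio.run(main())
```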