Use coroutines for URL downloads and improve the extensibility of the Downloader class

Hellagur4225 2024-11-16 15:57:59 +08:00
parent f30ff59b2b
commit 14c6db9cc3
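
In short: the multiprocessing worker pool is replaced by asyncio coroutines, and requests is swapped for httpx. A minimal sketch of the pattern the commit moves to (illustrative only, not the committed code; fetch and fetch_all are made-up names):

import asyncio
import httpx

async def fetch(url: str, timeout: int = 30) -> bytes:
    # one coroutine per URL; httpx.AsyncClient replaces requests.get
    async with httpx.AsyncClient() as client:
        response = await client.get(url, timeout=timeout)
        response.raise_for_status()
        return response.content

async def fetch_all(urls: list[str]) -> None:
    # as_completed yields each download as it finishes,
    # mirroring the co_wrapper helper added below
    for completed in asyncio.as_completed([fetch(u) for u in urls]):
        try:
            data = await completed
            print(f'downloaded {len(data)} bytes')
        except Exception as e:
            print(f'download failed: {e}')

# asyncio.run contains the event loop so callers stay synchronous:
# asyncio.run(fetch_all(['https://example.org/a.jpg']))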


@@ -1,20 +1,18 @@
# coding: utf-8
import multiprocessing
import signal
import sys
import os
import requests
import time
import urllib3.exceptions
from urllib.parse import urlparse
from nhentai import constant
from nhentai.logger import logger
from nhentai.parser import request
from nhentai.utils import Singleton
import asyncio
import httpx
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
semaphore = multiprocessing.Semaphore(1)
@@ -40,14 +38,13 @@ def download_callback(result):
class Downloader(Singleton):
def __init__(self, path='', size=5, timeout=30, delay=0):
self.size = size
self.path = str(path)
self.timeout = timeout
self.delay = delay
def download(self, url, folder='', filename='', retried=0, proxy=None):
async def download(self, url, folder='', filename='', retried=0, proxy=None):
if self.delay:
time.sleep(self.delay)
logger.info(f'Starting to download {url} ...')
@@ -55,48 +52,40 @@ class Downloader(Singleton):
base_filename, extension = os.path.splitext(filename)
save_file_path = os.path.join(folder, base_filename.zfill(3) + extension)
try:
if not os.path.exists(folder):
os.makedirs(folder, exist_ok=True)
if os.path.exists(save_file_path):
logger.warning(f'Skipped download: {save_file_path} already exists')
return 1, url
response = None
with open(save_file_path, "wb") as f:
i = 0
while i < 10:
try:
response = request('get', url, stream=True, timeout=self.timeout, proxies=proxy)
response = await self.async_request(url, self.timeout) # TODO: Add proxy
if response.status_code != 200:
path = urlparse(url).path
for mirror in constant.IMAGE_URL_MIRRORS:
print(f'{mirror}{path}')
mirror_url = f'{mirror}{path}'
response = request('get', mirror_url, stream=True,
timeout=self.timeout, proxies=proxy)
response = await self.async_request(mirror_url, self.timeout)
if response.status_code == 200:
break
except Exception as e:
i += 1
if i >= 10:
logger.critical(str(e))
return 0, None
continue
if not await self.save(save_file_path, response):
logger.error(f'Cannot download image {url}')
return 1, None
break
length = response.headers.get('content-length')
if length is None:
f.write(response.content)
else:
for chunk in response.iter_content(2048):
f.write(chunk)
except (requests.HTTPError, requests.Timeout) as e:
except (httpx.HTTPStatusError, httpx.TimeoutException) as e:
if retried < 3:
logger.warning(f'Warning: {e}, retrying({retried}) ...')
return 0, self.download(url=url, folder=folder, filename=filename,
retried=retried+1, proxy=proxy)
return 0, await self.download(
url=url,
folder=folder,
filename=filename,
retried=retried + 1,
proxy=proxy,
)
else:
return 0, None
@@ -106,6 +95,7 @@ class Downloader(Singleton):
except Exception as e:
import traceback
traceback.print_stack()
logger.critical(str(e))
return 0, None
@@ -115,6 +105,25 @@ class Downloader(Singleton):
return 1, url
async def save(self, save_file_path, response) -> bool:
if response is None:
logger.error('Error: Response is None')
return False
with open(save_file_path, 'wb') as f:
length = response.headers.get('content-length')
if length is None:
f.write(response.content)
else:
async for chunk in response.aiter_bytes(2048):
f.write(chunk)
return True
async def async_request(self, url, timeout):
async with httpx.AsyncClient() as client:
return await client.get(url, timeout=timeout)
def start_download(self, queue, folder='') -> bool:
if not isinstance(folder, (str, )):
folder = str(folder)
@@ -132,30 +141,20 @@ class Downloader(Singleton):
if os.getenv('DEBUG', None) == 'NODOWNLOAD':
# Assuming we want to continue with the rest of the process.
return True
queue = [(self, url, folder, constant.CONFIG['proxy']) for url in queue]
pool = multiprocessing.Pool(self.size, init_worker)
[pool.apply_async(download_wrapper, args=item) for item in queue]
async def co_wrapper(tasks):
for completed_task in asyncio.as_completed(tasks):
try:
result = await completed_task
logger.info(f'{result[1]} download completed')
except Exception as e:
logger.error(f'An error occurred: {e}')
pool.close()
pool.join()
tasks = [
self.download(url, filename=os.path.basename(urlparse(url).path))
for url in queue
]
# Run the event loop here so async does not spread ("infect") into synchronous callers
asyncio.run(co_wrapper(tasks))
return True
def download_wrapper(obj, url, folder='', proxy=None):
if sys.platform == 'darwin' or semaphore.get_value():
return Downloader.download(obj, url=url, folder=folder, proxy=proxy)
else:
return -3, None
def init_worker():
signal.signal(signal.SIGINT, subprocess_signal)
def subprocess_signal(sig, frame):
if semaphore.acquire(timeout=1):
logger.warning('Ctrl-C pressed, exiting sub processes ...')
raise KeyboardInterrupt
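
A note on concurrency: the removed multiprocessing.Pool(self.size, init_worker) capped parallel downloads at size, whereas asyncio.as_completed starts every task at once, so the size attribute no longer bounds anything. A sketch of how an asyncio.Semaphore could restore that cap (an assumption about intent, not part of the commit; bounded_fetch is a made-up name):

import asyncio
import httpx

async def bounded_fetch(urls: list[str], size: int = 5, timeout: int = 30) -> None:
    semaphore = asyncio.Semaphore(size)  # at most `size` downloads in flight

    async def one(url: str) -> bytes:
        async with semaphore:
            async with httpx.AsyncClient() as client:
                response = await client.get(url, timeout=timeout)
                response.raise_for_status()
                return response.content

    for completed in asyncio.as_completed([one(u) for u in urls]):
        try:
            await completed
        except Exception as e:
            print(f'download failed: {e}')

The `# TODO: Add proxy` left in async_request could be handled the same way, by constructing the client with httpx's proxy support (`proxies=` in older httpx releases, `proxy=` from 0.26 on); the commit does not do this yet.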