From 14c6db9cc30c3cbe0e98e1b43bb70785ed6241ef Mon Sep 17 00:00:00 2001
From: Hellagur4225
Date: Sat, 16 Nov 2024 15:57:59 +0800
Subject: [PATCH] Use coroutine in url download and improve the extensibility
 of class Downloader

---
 nhentai/downloader.py | 121 +++++++++++++++++++++---------------------
 1 file changed, 60 insertions(+), 61 deletions(-)

diff --git a/nhentai/downloader.py b/nhentai/downloader.py
index a8e0e2f..fce8606 100644
--- a/nhentai/downloader.py
+++ b/nhentai/downloader.py
@@ -1,20 +1,18 @@
 # coding: utf-8
 
 import multiprocessing
-import signal
-import sys
 import os
-import requests
 import time
 import urllib3.exceptions
 
 from urllib.parse import urlparse
 from nhentai import constant
 from nhentai.logger import logger
-from nhentai.parser import request
 from nhentai.utils import Singleton
 
+import asyncio
+import httpx
 
 urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
 semaphore = multiprocessing.Semaphore(1)
@@ -40,14 +38,13 @@ class Downloader(Singleton):
-
     def __init__(self, path='', size=5, timeout=30, delay=0):
         self.size = size
         self.path = str(path)
         self.timeout = timeout
         self.delay = delay
 
-    def download(self, url, folder='', filename='', retried=0, proxy=None):
+    async def download(self, url, folder='', filename='', retried=0, proxy=None):
         if self.delay:
             time.sleep(self.delay)
         logger.info(f'Starting to download {url} ...')
@@ -55,48 +52,40 @@ class Downloader(Singleton):
         base_filename, extension = os.path.splitext(filename)
 
         save_file_path = os.path.join(folder, base_filename.zfill(3) + extension)
+
         try:
+            if not os.path.exists(folder):
+                os.makedirs(folder, exist_ok=True)
+
             if os.path.exists(save_file_path):
                 logger.warning(f'Skipped download: {save_file_path} already exists')
                 return 1, url
 
-            response = None
-            with open(save_file_path, "wb") as f:
-                i = 0
-                while i < 10:
-                    try:
-                        response = request('get', url, stream=True, timeout=self.timeout, proxies=proxy)
-                        if response.status_code != 200:
-                            path = urlparse(url).path
-                            for mirror in constant.IMAGE_URL_MIRRORS:
-                                print(f'{mirror}{path}')
-                                mirror_url = f'{mirror}{path}'
-                                response = request('get', mirror_url, stream=True,
-                                                   timeout=self.timeout, proxies=proxy)
-                                if response.status_code == 200:
-                                    break
+            response = await self.async_request(url, self.timeout)  # TODO: Add proxy
 
-                    except Exception as e:
-                        i += 1
-                        if not i < 10:
-                            logger.critical(str(e))
-                            return 0, None
-                        continue
+            if response.status_code != 200:
+                path = urlparse(url).path
+                for mirror in constant.IMAGE_URL_MIRRORS:
+                    print(f'{mirror}{path}')
+                    mirror_url = f'{mirror}{path}'
+                    response = await self.async_request(mirror_url, self.timeout)
+                    if response.status_code == 200:
+                        break
 
-                    break
+            if not await self.save(save_file_path, response):
+                logger.error(f'Can not download image {url}')
+                return 0, None
 
-                length = response.headers.get('content-length')
-                if length is None:
-                    f.write(response.content)
-                else:
-                    for chunk in response.iter_content(2048):
-                        f.write(chunk)
-
-        except (requests.HTTPError, requests.Timeout) as e:
+        except (httpx.HTTPStatusError, httpx.TimeoutException) as e:
             if retried < 3:
                 logger.warning(f'Warning: {e}, retrying({retried}) ...')
-                return 0, self.download(url=url, folder=folder, filename=filename,
-                                        retried=retried+1, proxy=proxy)
+                return 0, await self.download(
+                    url=url,
+                    folder=folder,
+                    filename=filename,
+                    retried=retried + 1,
+                    proxy=proxy,
+                )
             else:
                 return 0, None
@@ -106,6 +95,7 @@ class Downloader(Singleton):
 
         except Exception as e:
             import traceback
+            traceback.print_stack()
             logger.critical(str(e))
             return 0, None
@@ -115,6 +105,25 @@
 
         return 1, url
 
+    async def save(self, save_file_path, response) -> bool:
+        if response is None:
+            logger.error('Error: Response is None')
+            return False
+
+        with open(save_file_path, 'wb') as f:
+            if response is not None:
+                length = response.headers.get('content-length')
+                if length is None:
+                    f.write(response.content)
+                else:
+                    async for chunk in response.aiter_bytes(2048):
+                        f.write(chunk)
+            return True
+
+    async def async_request(self, url, timeout):
+        async with httpx.AsyncClient() as client:
+            return await client.get(url, timeout=timeout)
+
     def start_download(self, queue, folder='') -> bool:
         if not isinstance(folder, (str, )):
             folder = str(folder)
@@ -132,30 +141,20 @@ class Downloader(Singleton):
         if os.getenv('DEBUG', None) == 'NODOWNLOAD':
            # Assuming we want to continue with rest of process.
            return True
-        queue = [(self, url, folder, constant.CONFIG['proxy']) for url in queue]
 
-        pool = multiprocessing.Pool(self.size, init_worker)
-        [pool.apply_async(download_wrapper, args=item) for item in queue]
+        async def co_wrapper(tasks):
+            for completed_task in asyncio.as_completed(tasks):
+                try:
+                    result = await completed_task
+                    logger.info(f'{result[1]} download completed')
+                except Exception as e:
+                    logger.error(f'An error occurred: {e}')
 
-        pool.close()
-        pool.join()
+        tasks = [
+            self.download(url, filename=os.path.basename(urlparse(url).path))
+            for url in queue
+        ]
+        # Run the event loop here so callers of start_download() stay synchronous
+        asyncio.run(co_wrapper(tasks))
 
         return True
-
-
-def download_wrapper(obj, url, folder='', proxy=None):
-    if sys.platform == 'darwin' or semaphore.get_value():
-        return Downloader.download(obj, url=url, folder=folder, proxy=proxy)
-    else:
-        return -3, None
-
-
-def init_worker():
-    signal.signal(signal.SIGINT, subprocess_signal)
-
-
-def subprocess_signal(sig, frame):
-    if semaphore.acquire(timeout=1):
-        logger.warning('Ctrl-C pressed, exiting sub processes ...')
-
-    raise KeyboardInterrupt
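
Note on the delay handling: download() is now a coroutine, but the
time.sleep(self.delay) it still calls is synchronous and blocks the whole
event loop, stalling every other in-flight download. A minimal non-blocking
sketch (a follow-up suggestion, not part of this patch):

    # Inside the async download() coroutine: await the sleep so the
    # event loop can keep driving the other download coroutines.
    if self.delay:
        await asyncio.sleep(self.delay)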
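Note on the "# TODO: Add proxy" marker: httpx takes the proxy when the client
is constructed. A sketch of how async_request might forward it, assuming a
single proxy URL string is available (the proxy= keyword exists in
httpx >= 0.26; older releases spell it proxies=):

    async def async_request(self, url, timeout, proxy=None):
        # proxy is a URL string such as 'http://127.0.0.1:1080', or None
        async with httpx.AsyncClient(proxy=proxy) as client:
            return await client.get(url, timeout=timeout)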
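Note on concurrency: the removed multiprocessing.Pool capped parallel
downloads at self.size, whereas asyncio.as_completed starts every queued
coroutine at once. If that cap is still wanted, one option is to bound the
tasks with an asyncio semaphore (a hypothetical variant of co_wrapper, not
part of this patch):

    async def co_wrapper(tasks):
        semaphore = asyncio.Semaphore(self.size)

        async def bounded(coro):
            # At most self.size downloads run at any one time.
            async with semaphore:
                return await coro

        for completed_task in asyncio.as_completed([bounded(c) for c in tasks]):
            try:
                result = await completed_task
                logger.info(f'{result[1]} download completed')
            except Exception as e:
                logger.error(f'An error occurred: {e}')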