diff --git a/nhentai/command.py b/nhentai/command.py
index 6661592..39607b1 100644
--- a/nhentai/command.py
+++ b/nhentai/command.py
@@ -77,7 +77,7 @@ def main():
         doujinshi_ids = list(set(map(int, doujinshi_ids)) - set(data))
 
     if not options.is_show:
-        downloader = Downloader(path=options.output_dir, size=options.threads,
+        downloader = Downloader(path=options.output_dir, threads=options.threads,
                                 timeout=options.timeout, delay=options.delay)
 
         for doujinshi_id in doujinshi_ids:
diff --git a/nhentai/downloader.py b/nhentai/downloader.py
index e030b86..9bcb6f2 100644
--- a/nhentai/downloader.py
+++ b/nhentai/downloader.py
@@ -1,22 +1,17 @@
 # coding: utf-
-import multiprocessing
-
 import os
-import time
+import asyncio
+import httpx
 import urllib3.exceptions
 
 from urllib.parse import urlparse
 from nhentai import constant
 from nhentai.logger import logger
-from nhentai.utils import Singleton
+from nhentai.utils import Singleton, async_request
 
-import asyncio
-import httpx
 
 urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
 
-semaphore = multiprocessing.Semaphore(1)
-
 
 class NHentaiImageNotExistException(Exception):
     pass
@@ -37,17 +32,32 @@ def download_callback(result):
     logger.log(16, f'{data} downloaded successfully')
 
 
+async def fiber(tasks):
+    for completed_task in asyncio.as_completed(tasks):
+        try:
+            result = await completed_task
+            logger.info(f'{result[1]} download completed')
+        except Exception as e:
+            logger.error(f'An error occurred: {e}')
+
+
 class Downloader(Singleton):
-    def __init__(self, path='', size=5, timeout=30, delay=0):
-        self.size = size
+    def __init__(self, path='', threads=5, timeout=30, delay=0):
+        self.threads = threads
         self.path = str(path)
         self.timeout = timeout
         self.delay = delay
 
+    async def _semaphore_download(self, semaphore, *args, **kwargs):
+        async with semaphore:
+            return await self.download(*args, **kwargs)
+
     async def download(self, url, folder='', filename='', retried=0, proxy=None):
-        if self.delay:
-            time.sleep(self.delay)
         logger.info(f'Starting to download {url} ...')
+
+        if self.delay:
+            await asyncio.sleep(self.delay)
+
         filename = filename if filename else os.path.basename(urlparse(url).path)
 
         save_file_path = os.path.join(self.folder, filename)
@@ -57,14 +67,14 @@ class Downloader(Singleton):
                 logger.warning(f'Skipped download: {save_file_path} already exists')
                 return 1, url
 
-            response = await self.async_request(url, self.timeout)  # TODO: Add proxy
+            response = await async_request('GET', url, timeout=self.timeout, proxies=proxy)
 
             if response.status_code != 200:
                 path = urlparse(url).path
                 for mirror in constant.IMAGE_URL_MIRRORS:
                     logger.info(f"Try mirror: {mirror}{path}")
                     mirror_url = f'{mirror}{path}'
-                    response = await self.async_request(mirror_url, self.timeout)
+                    response = await async_request('GET', mirror_url, timeout=self.timeout, proxies=proxy)
                     if response.status_code == 200:
                         break
 
@@ -117,13 +127,9 @@ class Downloader(Singleton):
                     f.write(chunk)
 
         return True
 
-    async def async_request(self, url, timeout):
-        async with httpx.AsyncClient() as client:
-            return await client.get(url, timeout=timeout)
     def start_download(self, queue, folder='') -> bool:
-        logger.warning("Proxy temporarily unavailable, it will be fixed later. ")
-        if not isinstance(folder, (str, )):
+        if not isinstance(folder, (str,)):
             folder = str(folder)
 
         if self.path:
@@ -141,19 +147,14 @@ class Downloader(Singleton):
             # Assuming we want to continue with rest of process.
             return True
 
-        async def fiber(tasks):
-            for completed_task in asyncio.as_completed(tasks):
-                try:
-                    result = await completed_task
-                    logger.info(f'{result[1]} download completed')
-                except Exception as e:
-                    logger.error(f'An error occurred: {e}')
+        semaphore = asyncio.Semaphore(self.threads)
 
-        tasks = [
-            self.download(url, filename=os.path.basename(urlparse(url).path))
+        coroutines = [
+            self._semaphore_download(semaphore, url, filename=os.path.basename(urlparse(url).path))
             for url in queue
         ]
 
-        asyncio.run(fiber(tasks))
+        # Prevent coroutines infection
+        asyncio.run(fiber(coroutines))
 
         return True
diff --git a/nhentai/utils.py b/nhentai/utils.py
index 6e50a7e..f445201 100644
--- a/nhentai/utils.py
+++ b/nhentai/utils.py
@@ -6,6 +6,7 @@ import os
 import zipfile
 import shutil
 
+import httpx
 import requests
 import sqlite3
 import urllib.parse
@@ -32,8 +33,28 @@ def request(method, url, **kwargs):
     return getattr(session, method)(url, verify=False, **kwargs)
 
 
+async def async_request(method, url, proxies = None, **kwargs):
+    headers = {
+        'Referer': constant.LOGIN_URL,
+        'User-Agent': constant.CONFIG['useragent'],
+        'Cookie': constant.CONFIG['cookie'],
+    }
+
+    if proxies is None:
+        proxies = constant.CONFIG['proxy']
+
+    if proxies.get('http') == '' and proxies.get('https') == '':
+        proxies = None
+
+    async with httpx.AsyncClient(headers=headers, verify=False, proxies=proxies, **kwargs) as client:
+        response = await client.request(method, url, **kwargs)
+
+    return response
+
+
 def check_cookie():
     response = request('get', constant.BASE_URL)
+
     if response.status_code == 403 and 'Just a moment...' in response.text:
         logger.error('Blocked by Cloudflare captcha, please set your cookie and useragent')
         sys.exit(1)
diff --git a/tests/test_download.py b/tests/test_download.py
index 9ffcdd1..68f3074 100644
--- a/tests/test_download.py
+++ b/tests/test_download.py
@@ -20,7 +20,7 @@ class TestDownload(unittest.TestCase):
     def test_download(self):
         did = 440546
         info = Doujinshi(**doujinshi_parser(did), name_format='%i')
-        info.downloader = Downloader(path='/tmp', size=5)
+        info.downloader = Downloader(path='/tmp', threads=5)
         info.download()
 
         self.assertTrue(os.path.exists(f'/tmp/{did}/001.jpg'))
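For readers skimming the patch, the concurrency change reduces to one pattern: a per-call `asyncio.Semaphore` (replacing the old module-level `multiprocessing.Semaphore(1)`) bounds how many download coroutines run at once via `_semaphore_download`, and the module-level `fiber()` drains them with `asyncio.as_completed`, while the blocking `time.sleep` becomes `await asyncio.sleep` so the delay no longer stalls the event loop. Below is a minimal, self-contained sketch of that pattern, not code from the patch: `fetch()`, `bounded_fetch()` and the example URLs are illustrative stand-ins, and it assumes Python 3.10+, where `asyncio.Semaphore` no longer binds to an event loop at construction (the same assumption `start_download()` makes when it creates the semaphore before `asyncio.run`).

```python
import asyncio


async def fetch(url, delay=0.1):
    """Stand-in for Downloader.download(): pretend to do some network I/O."""
    await asyncio.sleep(delay)
    return 1, url


async def bounded_fetch(semaphore, url):
    """Mirrors _semaphore_download(): hold a semaphore slot while the download runs."""
    async with semaphore:
        return await fetch(url)


async def fiber(tasks):
    """Same shape as the patch's module-level fiber(): report results as they finish."""
    for completed_task in asyncio.as_completed(tasks):
        try:
            result = await completed_task
            print(f'{result[1]} download completed')
        except Exception as e:
            print(f'An error occurred: {e}')


def start_download(urls, threads=5):
    # At most `threads` downloads run concurrently, instead of the old
    # single shared multiprocessing.Semaphore(1).
    semaphore = asyncio.Semaphore(threads)
    coroutines = [bounded_fetch(semaphore, url) for url in urls]
    asyncio.run(fiber(coroutines))


if __name__ == '__main__':
    start_download([f'https://example.com/{n}.jpg' for n in range(10)], threads=3)
```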