diff --git a/nhentai/command.py b/nhentai/command.py index 6661592..39607b1 100644 --- a/nhentai/command.py +++ b/nhentai/command.py @@ -77,7 +77,7 @@ def main(): doujinshi_ids = list(set(map(int, doujinshi_ids)) - set(data)) if not options.is_show: - downloader = Downloader(path=options.output_dir, size=options.threads, + downloader = Downloader(path=options.output_dir, threads=options.threads, timeout=options.timeout, delay=options.delay) for doujinshi_id in doujinshi_ids: diff --git a/nhentai/downloader.py b/nhentai/downloader.py index e030b86..c28e944 100644 --- a/nhentai/downloader.py +++ b/nhentai/downloader.py @@ -1,9 +1,6 @@ # coding: utf- -import multiprocessing - import os -import time import urllib3.exceptions from urllib.parse import urlparse @@ -15,8 +12,6 @@ import asyncio import httpx urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) -semaphore = multiprocessing.Semaphore(1) - class NHentaiImageNotExistException(Exception): pass @@ -38,16 +33,23 @@ def download_callback(result): class Downloader(Singleton): - def __init__(self, path='', size=5, timeout=30, delay=0): - self.size = size + def __init__(self, path='', threads=5, timeout=30, delay=0): + self.semaphore = asyncio.Semaphore(threads) self.path = str(path) self.timeout = timeout self.delay = delay + async def _semaphore_download(self, *args, **kwargs): + # This sets a concurrency limit for AsyncIO + async with self.semaphore: + return await self.download(*args, **kwargs) + async def download(self, url, folder='', filename='', retried=0, proxy=None): - if self.delay: - time.sleep(self.delay) logger.info(f'Starting to download {url} ...') + + if self.delay: + await asyncio.sleep(self.delay) + filename = filename if filename else os.path.basename(urlparse(url).path) save_file_path = os.path.join(self.folder, filename) @@ -150,7 +152,7 @@ class Downloader(Singleton): logger.error(f'An error occurred: {e}') tasks = [ - self.download(url, filename=os.path.basename(urlparse(url).path)) + self._semaphore_download(url, filename=os.path.basename(urlparse(url).path)) for url in queue ] # Prevent coroutines infection