mirror of https://github.com/RicterZ/nhentai.git
synced 2025-11-03 18:50:53 +01:00
	Merge pull request #351 from hzxjy1/master
Use coroutine in url download
@@ -1,20 +1,18 @@
 # coding: utf-8

-import multiprocessing
-import signal
-
 import sys
 import os
-import requests
 import time
 import urllib3.exceptions

 from urllib.parse import urlparse
 from nhentai import constant
 from nhentai.logger import logger
 from nhentai.parser import request
 from nhentai.utils import Singleton

+import asyncio
+import httpx
+
 urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
-semaphore = multiprocessing.Semaphore(1)
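A note on the import swap: requests blocks the calling thread on every request, while httpx exposes a requests-like API behind async/await. A minimal sketch of the new style, separate from this diff (the URL is illustrative):

    import asyncio
    import httpx

    async def fetch(url: str) -> bytes:
        # AsyncClient is the awaitable counterpart of a requests session.
        async with httpx.AsyncClient() as client:
            response = await client.get(url, timeout=30)
            response.raise_for_status()  # raises httpx.HTTPStatusError on 4xx/5xx
            return response.content

    data = asyncio.run(fetch('https://example.com/'))
    print(len(data))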
@@ -40,14 +38,13 @@ def download_callback(result):


 class Downloader(Singleton):

     def __init__(self, path='', size=5, timeout=30, delay=0):
         self.size = size
         self.path = str(path)
         self.timeout = timeout
         self.delay = delay

-    def download(self, url, folder='', filename='', retried=0, proxy=None):
+    async def download(self, url, folder='', filename='', retried=0, proxy=None):
         if self.delay:
             time.sleep(self.delay)
         logger.info(f'Starting to download {url} ...')
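One caveat with the signature change: time.sleep(self.delay) on the unchanged line above now executes inside a coroutine, where it blocks the entire event loop rather than just this one download. The cooperative form would be asyncio.sleep; a minimal sketch for contrast (the delay value is illustrative):

    import asyncio
    import time

    async def blocking_delay():
        time.sleep(0.5)  # freezes the event loop; no other coroutine runs

    async def cooperative_delay():
        await asyncio.sleep(0.5)  # yields to the loop; other downloads continue

    asyncio.run(cooperative_delay())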
@@ -55,48 +52,40 @@ class Downloader(Singleton):
         base_filename, extension = os.path.splitext(filename)

         save_file_path = os.path.join(folder, base_filename.zfill(3) + extension)

         try:
             if not os.path.exists(folder):
                 os.makedirs(folder, exist_ok=True)

             if os.path.exists(save_file_path):
                 logger.warning(f'Skipped download: {save_file_path} already exists')
                 return 1, url

-            response = None
-            with open(save_file_path, "wb") as f:
-                i = 0
-                while i < 10:
-                    try:
-                        response = request('get', url, stream=True, timeout=self.timeout, proxies=proxy)
-                        if response.status_code != 200:
-                            path = urlparse(url).path
-                            for mirror in constant.IMAGE_URL_MIRRORS:
-                                print(f'{mirror}{path}')
-                                mirror_url = f'{mirror}{path}'
-                                response = request('get', mirror_url, stream=True,
-                                                   timeout=self.timeout, proxies=proxy)
-                                if response.status_code == 200:
-                                    break
+            response = await self.async_request(url, self.timeout)  # TODO: Add proxy
+
-                    except Exception as e:
-                        i += 1
-                        if not i < 10:
-                            logger.critical(str(e))
-                            return 0, None
-                        continue
+            if response.status_code != 200:
+                path = urlparse(url).path
+                for mirror in constant.IMAGE_URL_MIRRORS:
+                    print(f'{mirror}{path}')
+                    mirror_url = f'{mirror}{path}'
+                    response = await self.async_request(mirror_url, self.timeout)
+                    if response.status_code == 200:
+                        break
+
-                    break
+            if not await self.save(save_file_path, response):
+                logger.error(f'Can not download image {url}')
+                return 1, None
+
-                length = response.headers.get('content-length')
-                if length is None:
-                    f.write(response.content)
-                else:
-                    for chunk in response.iter_content(2048):
-                        f.write(chunk)

-        except (requests.HTTPError, requests.Timeout) as e:
+        except (httpx.HTTPStatusError, httpx.TimeoutException) as e:
             if retried < 3:
                 logger.warning(f'Warning: {e}, retrying({retried}) ...')
-                return 0, self.download(url=url, folder=folder, filename=filename,
-                                        retried=retried+1, proxy=proxy)
+                return 0, await self.download(
+                    url=url,
+                    folder=folder,
+                    filename=filename,
+                    retried=retried + 1,
+                    proxy=proxy,
+                )
             else:
                 return 0, None
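Worth noting: the old path passed stream=True and wrote via iter_content(), whereas client.get() in async_request() buffers the whole image in memory before save() writes it out. If streaming were wanted back, httpx offers client.stream(); a sketch under that assumption (the function name, chunk size, and paths are illustrative, not part of this commit):

    import asyncio
    import httpx

    async def download_streamed(url: str, save_file_path: str) -> None:
        async with httpx.AsyncClient() as client:
            # client.stream() fetches lazily instead of buffering the whole body.
            async with client.stream('GET', url) as response:
                response.raise_for_status()
                with open(save_file_path, 'wb') as f:
                    async for chunk in response.aiter_bytes(2048):
                        f.write(chunk)

    asyncio.run(download_streamed('https://example.com/', '/tmp/example.html'))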
@@ -106,6 +95,7 @@ class Downloader(Singleton):

         except Exception as e:
             import traceback
+
             traceback.print_stack()
             logger.critical(str(e))
             return 0, None
@@ -115,6 +105,25 @@ class Downloader(Singleton):

         return 1, url

+    async def save(self, save_file_path, response) -> bool:
+        if response is None:
+            logger.error('Error: Response is None')
+            return False
+
+        with open(save_file_path, 'wb') as f:
+            if response is not None:
+                length = response.headers.get('content-length')
+                if length is None:
+                    f.write(response.content)
+                else:
+                    async for chunk in response.aiter_bytes(2048):
+                        f.write(chunk)
+        return True
+
+    async def async_request(self, url, timeout):
+        async with httpx.AsyncClient() as client:
+            return await client.get(url, timeout=timeout)
+
     def start_download(self, queue, folder='') -> bool:
         if not isinstance(folder, (str, )):
             folder = str(folder)
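async_request() above opens and closes a fresh AsyncClient for every image, which gives up connection reuse across a gallery. A sketch of sharing one pooled client instead; the class, attribute, and method names here are assumptions for illustration, not code from this commit:

    import httpx

    class PooledDownloader:
        def __init__(self, timeout=30):
            # One shared client keeps TCP/TLS connections alive between images.
            self.client = httpx.AsyncClient(timeout=timeout)

        async def async_request(self, url):
            return await self.client.get(url)

        async def aclose(self):
            # Close the pool once after the last download.
            await self.client.aclose()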
@@ -132,30 +141,20 @@ class Downloader(Singleton):
         if os.getenv('DEBUG', None) == 'NODOWNLOAD':
             # Assuming we want to continue with rest of process.
             return True
-        queue = [(self, url, folder, constant.CONFIG['proxy']) for url in queue]
-
-        pool = multiprocessing.Pool(self.size, init_worker)
-        [pool.apply_async(download_wrapper, args=item) for item in queue]
+        async def co_wrapper(tasks):
+            for completed_task in asyncio.as_completed(tasks):
+                try:
+                    result = await completed_task
+                    logger.info(f'{result[1]} download completed')
+                except Exception as e:
+                    logger.error(f'An error occurred: {e}')
+
-        pool.close()
-        pool.join()
+        tasks = [
+            self.download(url, filename=os.path.basename(urlparse(url).path))
+            for url in queue
+        ]
+        # Prevent coroutines infection
+        asyncio.run(co_wrapper(tasks))

         return True
-
-
-def download_wrapper(obj, url, folder='', proxy=None):
-    if sys.platform == 'darwin' or semaphore.get_value():
-        return Downloader.download(obj, url=url, folder=folder, proxy=proxy)
-    else:
-        return -3, None
-
-
-def init_worker():
-    signal.signal(signal.SIGINT, subprocess_signal)
-
-
-def subprocess_signal(sig, frame):
-    if semaphore.acquire(timeout=1):
-        logger.warning('Ctrl-C pressed, exiting sub processes ...')
-
-    raise KeyboardInterrupt
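A behavioural difference to keep in mind: the removed multiprocessing.Pool(self.size, ...) capped concurrency at self.size workers, while asyncio.as_completed() starts every download task at once. If the old bound were wanted back, asyncio.Semaphore is the usual tool; a sketch under that assumption (names are illustrative):

    import asyncio

    async def bounded_run(coros, limit=5):
        semaphore = asyncio.Semaphore(limit)

        async def run_one(coro):
            async with semaphore:  # at most `limit` coroutines awaited at once
                return await coro

        return await asyncio.gather(*(run_one(c) for c in coros))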