Mirror of https://github.com/RicterZ/nhentai.git (synced 2025-04-20 02:41:19 +02:00)
Merge pull request #351 from hzxjy1/master

Use coroutine in url download

Commit: 90b17832cc
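
This commit drops the multiprocessing worker pool in favour of asyncio coroutines and replaces requests with httpx for the HTTP calls. As a rough sketch of the pattern being adopted (names and URLs below are illustrative, not taken from the diff):

import asyncio
import httpx

async def fetch(client, url):
    # Each URL becomes a coroutine; the event loop overlaps the network waits.
    response = await client.get(url, timeout=30)
    return url, response.status_code

async def main(urls):
    # One shared AsyncClient reuses connections across requests.
    async with httpx.AsyncClient() as client:
        return await asyncio.gather(*(fetch(client, u) for u in urls))

# asyncio.run(main(['https://example.com/a.jpg', 'https://example.com/b.jpg']))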
nhentai/downloader.py
@@ -1,20 +1,18 @@
 # coding: utf-8

-import multiprocessing
-import signal

 import sys
 import os
-import requests
 import time
 import urllib3.exceptions

 from urllib.parse import urlparse
 from nhentai import constant
 from nhentai.logger import logger
-from nhentai.parser import request
 from nhentai.utils import Singleton

+import asyncio
+import httpx
+

 urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
-semaphore = multiprocessing.Semaphore(1)
@@ -40,14 +38,13 @@ def download_callback(result):


 class Downloader(Singleton):

     def __init__(self, path='', size=5, timeout=30, delay=0):
         self.size = size
         self.path = str(path)
         self.timeout = timeout
         self.delay = delay

-    def download(self, url, folder='', filename='', retried=0, proxy=None):
+    async def download(self, url, folder='', filename='', retried=0, proxy=None):
         if self.delay:
             time.sleep(self.delay)
         logger.info(f'Starting to download {url} ...')
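
download() is now a coroutine, so every caller must await it. Note that the retained time.sleep(self.delay) is synchronous and blocks the whole event loop while it waits; a non-blocking variant (a sketch, not part of this commit) would swap it for asyncio.sleep:

import asyncio

class Downloader:
    def __init__(self, delay=0):
        self.delay = delay

    async def download(self, url):
        if self.delay:
            # asyncio.sleep() yields to the event loop; time.sleep() here
            # would stall every other in-flight download for the duration.
            await asyncio.sleep(self.delay)
        ...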
@@ -55,48 +52,40 @@ class Downloader(Singleton):
         base_filename, extension = os.path.splitext(filename)

         save_file_path = os.path.join(folder, base_filename.zfill(3) + extension)

         try:
             if not os.path.exists(folder):
                 os.makedirs(folder, exist_ok=True)

             if os.path.exists(save_file_path):
                 logger.warning(f'Skipped download: {save_file_path} already exists')
                 return 1, url

             response = None
-            with open(save_file_path, "wb") as f:
-                i = 0
-                while i < 10:
-                    try:
-                        response = request('get', url, stream=True, timeout=self.timeout, proxies=proxy)
-                        if response.status_code != 200:
-                            path = urlparse(url).path
-                            for mirror in constant.IMAGE_URL_MIRRORS:
-                                print(f'{mirror}{path}')
-                                mirror_url = f'{mirror}{path}'
-                                response = request('get', mirror_url, stream=True,
-                                                   timeout=self.timeout, proxies=proxy)
-                                if response.status_code == 200:
-                                    break
-
-                    except Exception as e:
-                        i += 1
-                        if not i < 10:
-                            logger.critical(str(e))
-                            return 0, None
-                        continue
-
-                    break
-
-                length = response.headers.get('content-length')
-                if length is None:
-                    f.write(response.content)
-                else:
-                    for chunk in response.iter_content(2048):
-                        f.write(chunk)
+            i = 0
+            while i < 10:
+                try:
+                    response = await self.async_request(url, self.timeout)  # TODO: Add proxy
+
+                except Exception as e:
+                    i += 1
+                    if not i < 10:
+                        logger.critical(str(e))
+                        return 0, None
+                    continue
+                if response.status_code != 200:
+                    path = urlparse(url).path
+                    for mirror in constant.IMAGE_URL_MIRRORS:
+                        print(f'{mirror}{path}')
+                        mirror_url = f'{mirror}{path}'
+                        response = await self.async_request(mirror_url, self.timeout)
+                        if response.status_code == 200:
+                            break
+
+                break
+            if not await self.save(save_file_path, response):
+                logger.error(f'Can not download image {url}')
+                return 1, None

-        except (requests.HTTPError, requests.Timeout) as e:
+        except (httpx.HTTPStatusError, httpx.TimeoutException) as e:
             if retried < 3:
                 logger.warning(f'Warning: {e}, retrying({retried}) ...')
-                return 0, self.download(url=url, folder=folder, filename=filename,
-                                        retried=retried+1, proxy=proxy)
+                return 0, await self.download(
+                    url=url,
+                    folder=folder,
+                    filename=filename,
+                    retried=retried + 1,
+                    proxy=proxy,
+                )
             else:
                 return 0, None
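
The rewritten body retries up to ten times and, on a non-200 status, walks constant.IMAGE_URL_MIRRORS with the same URL path. The fallback shape, isolated as a sketch (the mirror hosts here are placeholders for constant.IMAGE_URL_MIRRORS):

from urllib.parse import urlparse
import httpx

MIRRORS = ['https://i2.example.net', 'https://i3.example.net']  # placeholder mirror list

async def fetch_with_fallback(client, url, timeout=30):
    response = await client.get(url, timeout=timeout)
    if response.status_code != 200:
        # Reuse the path component against each mirror host until one answers 200.
        path = urlparse(url).path
        for mirror in MIRRORS:
            response = await client.get(f'{mirror}{path}', timeout=timeout)
            if response.status_code == 200:
                break
    return response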
@@ -106,6 +95,7 @@ class Downloader(Singleton):

         except Exception as e:
             import traceback
+
             traceback.print_stack()
             logger.critical(str(e))
             return 0, None
@@ -115,6 +105,25 @@ class Downloader(Singleton):

         return 1, url

+    async def save(self, save_file_path, response) -> bool:
+        if response is None:
+            logger.error('Error: Response is None')
+            return False
+
+        with open(save_file_path, 'wb') as f:
+            if response is not None:
+                length = response.headers.get('content-length')
+                if length is None:
+                    f.write(response.content)
+                else:
+                    async for chunk in response.aiter_bytes(2048):
+                        f.write(chunk)
+        return True
+
+    async def async_request(self, url, timeout):
+        async with httpx.AsyncClient() as client:
+            return await client.get(url, timeout=timeout)
+
     def start_download(self, queue, folder='') -> bool:
         if not isinstance(folder, (str, )):
             folder = str(folder)
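
One caveat worth noting: async_request() uses client.get(), which reads the entire body into memory before returning, so the aiter_bytes(2048) loop in save() iterates over an already-buffered response rather than streaming from the socket (the requests code it replaces passed stream=True). A genuinely streaming variant would go through httpx's stream API, roughly:

import httpx

async def save_streaming(save_file_path, url, timeout=30):
    async with httpx.AsyncClient() as client:
        # client.stream() defers reading the body; aiter_bytes() then pulls
        # it off the wire chunk by chunk instead of replaying a buffer.
        async with client.stream('GET', url, timeout=timeout) as response:
            with open(save_file_path, 'wb') as f:
                async for chunk in response.aiter_bytes(2048):
                    f.write(chunk)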
@@ -132,30 +141,20 @@ class Downloader(Singleton):
         if os.getenv('DEBUG', None) == 'NODOWNLOAD':
             # Assuming we want to continue with rest of process.
             return True
-        queue = [(self, url, folder, constant.CONFIG['proxy']) for url in queue]

-        pool = multiprocessing.Pool(self.size, init_worker)
-        [pool.apply_async(download_wrapper, args=item) for item in queue]
+        async def co_wrapper(tasks):
+            for completed_task in asyncio.as_completed(tasks):
+                try:
+                    result = await completed_task
+                    logger.info(f'{result[1]} download completed')
+                except Exception as e:
+                    logger.error(f'An error occurred: {e}')

-        pool.close()
-        pool.join()
+        tasks = [
+            self.download(url, filename=os.path.basename(urlparse(url).path))
+            for url in queue
+        ]
+        # Prevent coroutines infection
+        asyncio.run(co_wrapper(tasks))

         return True
-
-
-def download_wrapper(obj, url, folder='', proxy=None):
-    if sys.platform == 'darwin' or semaphore.get_value():
-        return Downloader.download(obj, url=url, folder=folder, proxy=proxy)
-    else:
-        return -3, None
-
-
-def init_worker():
-    signal.signal(signal.SIGINT, subprocess_signal)
-
-
-def subprocess_signal(sig, frame):
-    if semaphore.acquire(timeout=1):
-        logger.warning('Ctrl-C pressed, exiting sub processes ...')
-
-    raise KeyboardInterrupt
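
start_download() now builds one download coroutine per URL and drains them with asyncio.as_completed inside co_wrapper, logging each result as it finishes. Unlike the old multiprocessing.Pool(self.size, ...), nothing here caps concurrency: every request is in flight at once. If a bound were wanted, an asyncio.Semaphore could wrap each task, sketched below (co_wrapper and size mirror the names above; the rest is illustrative):

import asyncio

async def co_wrapper(tasks, size=5):
    semaphore = asyncio.Semaphore(size)

    async def bounded(task):
        # At most `size` coroutines run their request at any one time.
        async with semaphore:
            return await task

    for completed_task in asyncio.as_completed([bounded(t) for t in tasks]):
        result = await completed_task
        print(result)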