Feat: process_start_urls in parallel #159

Open · wants to merge 2 commits into master
2 changes: 0 additions & 2 deletions ruia/request.py
@@ -7,15 +7,13 @@

import asyncio
import weakref

from asyncio.locks import Semaphore
from inspect import iscoroutinefunction
from types import AsyncGeneratorType
from typing import Coroutine, Optional, Tuple

import aiohttp
import async_timeout

from ruia.exceptions import InvalidRequestMethod
from ruia.response import Response
from ruia.utils import get_logger
3 changes: 2 additions & 1 deletion ruia/response.py
@@ -6,7 +6,6 @@

import asyncio
import json

from http.cookies import SimpleCookie
from typing import Any, Callable, Optional

@@ -163,6 +162,8 @@ async def text(
) -> str:
"""Read response payload and decode."""
encoding = encoding or self._encoding
if self._aws_text is None:
return ''
self._html = await self._aws_text(encoding=encoding, errors=errors)
return self._html

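For reference, the added guard means Response.text() now returns an empty string when no awaitable body reader is attached, instead of failing on a None call. A minimal, self-contained sketch of that behavior (the stand-in class below is illustrative only, not ruia's actual Response):

import asyncio

class Response:
    """Stripped-down stand-in for ruia.Response, keeping only the guarded text() path."""

    def __init__(self, aws_text=None, encoding="utf-8"):
        self._aws_text = aws_text  # awaitable body reader, or None
        self._encoding = encoding
        self._html = ""

    async def text(self, encoding=None, errors="ignore"):
        encoding = encoding or self._encoding
        if self._aws_text is None:  # the guard added in this PR
            return ""
        self._html = await self._aws_text(encoding=encoding, errors=errors)
        return self._html

async def main():
    # No body reader attached: text() returns '' instead of raising.
    print(repr(await Response().text()))

asyncio.run(main())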
44 changes: 34 additions & 10 deletions ruia/spider.py
@@ -21,7 +21,6 @@
from types import AsyncGeneratorType

from aiohttp import ClientSession

from ruia.exceptions import NothingMatchedError, NotImplementedParseError
from ruia.item import Item
from ruia.middleware import Middleware
@@ -67,6 +66,7 @@ class Spider(SpiderHook):
# Concurrency control
worker_numbers: int = 2
concurrency: int = 3
max_batch_size: int = 30

# Spider entry
start_urls: list = []
@@ -82,6 +82,7 @@ def __init__(
cancel_tasks: bool = True,
**spider_kwargs,
):
print("use hf ruia")
"""
Init spider object.
:param middleware: a list of or a single Middleware
@@ -166,12 +167,13 @@ async def _process_response(self, request: Request, response: Response):
if response:
if response.ok:
# Process succeed response
self.success_counts += 1
await self.process_succeed_response(request, response)
return True
else:
# Process failed response
self.failed_counts += 1
await self.process_failed_response(request, response)
return False
return False

async def _run_request_middleware(self, request: Request):
if self.middleware.request_middleware:
@@ -350,12 +352,15 @@ async def handle_request(
typing.Tuple[AsyncGeneratorType, Request, Response]: Returns a result tuple after each request
"""
callback_result, response = None, None
if_success = False

try:
await self._run_request_middleware(request)
callback_result, response = await request.fetch_callback(self.sem)
await self._run_response_middleware(request, response)
await self._process_response(request=request, response=response)
if_success = await self._process_response(
request=request, response=response
)
except NotImplementedParseError as e:
self.logger.error(e)
except NothingMatchedError as e:
@@ -364,6 +369,11 @@
except Exception as e:
self.logger.error(f"<Callback[{request.callback.__name__}]: {e}")

if if_success:
self.success_counts += 1
else:
self.failed_counts += 1

return callback_result, request, response

async def multiple_request(self, urls, is_gather=False, **kwargs):
@@ -463,31 +473,41 @@ async def start_master(self):
"""
Actually start crawling
"""
async for request_ins in self.process_start_urls():
self.request_queue.put_nowait(self.handle_request(request_ins))
process_urls_task = asyncio.create_task(self.enqueue_start_urls())

workers = [
asyncio.ensure_future(self.start_worker())
asyncio.ensure_future(self.start_worker(i))
for i in range(self.worker_numbers)
]
for worker in workers:
self.logger.info(f"Worker started: {id(worker)}")
await self.request_queue.join()

await asyncio.gather(process_urls_task, self.request_queue.join())

if not self.is_async_start:
await self.stop(SIGINT)
else:
if self.cancel_tasks:
await self.cancel_all_tasks()

async def start_worker(self):
async def enqueue_start_urls(self):
async for request_ins in self.process_start_urls():
await self.request_queue.put(self.handle_request(request_ins))

async def start_worker(self, i: int):
"""
Start spider worker
:return:
"""
while True:
request_item = await self.request_queue.get()
self.worker_tasks.append(request_item)
if self.request_queue.empty():
# TODO: This isn't ideal; it pulls everything off the queue first and only then
# sends the requests. What if the stream of URLs is unbounded?
if (
self.request_queue.empty()
or len(self.worker_tasks) > self.max_batch_size
):
results = await asyncio.gather(
*self.worker_tasks, return_exceptions=True
)
@@ -505,12 +525,16 @@ async def start_worker(self):
self.worker_tasks = []
self.request_queue.task_done()

async def cancel_callback(self):
self.logger.info("Spider Cancaled")

async def stop(self, _signal):
"""
Finish all running tasks, cancel remaining tasks.
:param _signal:
:return:
"""
self.logger.info(f"Stopping spider: {self.name}")
await self.cancel_callback()
await self.cancel_all_tasks()
# self.loop.stop()
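Taken together, the spider.py changes move from "enqueue all start URLs, then start workers" to a producer task that feeds the queue while workers drain it in batches capped by max_batch_size. A minimal standalone sketch of that pattern (fetch, the URL list, and the two-worker count are hypothetical placeholders; only the queue/batch/gather structure mirrors the diff above):

import asyncio

MAX_BATCH_SIZE = 30  # mirrors Spider.max_batch_size introduced in this PR
WORKER_NUMBERS = 2   # mirrors Spider.worker_numbers

async def fetch(url: str) -> str:
    # Placeholder for a real HTTP request (ruia uses aiohttp underneath).
    await asyncio.sleep(0.1)
    return f"fetched {url}"

async def enqueue_start_urls(queue: asyncio.Queue, urls) -> None:
    # Producer: put work on the queue while workers are already running,
    # rather than filling the queue completely before starting them.
    for url in urls:
        await queue.put(fetch(url))

async def worker(queue: asyncio.Queue, worker_id: int) -> None:
    # Consumer: collect coroutines into a batch, then gather the whole batch
    # once the queue is momentarily empty or the batch hits the size cap.
    batch = []
    while True:
        batch.append(await queue.get())
        if queue.empty() or len(batch) > MAX_BATCH_SIZE:
            results = await asyncio.gather(*batch, return_exceptions=True)
            for result in results:
                print(f"worker {worker_id}: {result}")
            batch = []
        queue.task_done()

async def main() -> None:
    queue: asyncio.Queue = asyncio.Queue()
    urls = [f"https://example.com/page/{i}" for i in range(10)]  # hypothetical start URLs

    producer = asyncio.create_task(enqueue_start_urls(queue, urls))
    workers = [asyncio.create_task(worker(queue, i)) for i in range(WORKER_NUMBERS)]

    # As in start_master(): wait for the producer and for the queue to drain.
    await asyncio.gather(producer, queue.join())

    for w in workers:
        w.cancel()

asyncio.run(main())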