searxng/searx/network/__init__.py
2023-09-18 16:20:27 +02:00

267 lines
8.1 KiB
Python

# SPDX-License-Identifier: AGPL-3.0-or-later
# lint: pylint
# pylint: disable=missing-module-docstring, global-statement
import asyncio
import threading
import concurrent.futures
from queue import SimpleQueue
from types import MethodType
from timeit import default_timer
from typing import Iterable, NamedTuple, Tuple, List, Dict, Union
from contextlib import contextmanager
import httpx
import anyio
from .network import get_network, initialize, check_network_configuration # pylint:disable=cyclic-import
from .client import get_loop
from .raise_for_httperror import raise_for_httperror
# Per-thread state of this module.  The attributes are created lazily by the
# functions below: ``total_time``, ``timeout``, ``start_time`` and ``network``.
THREADLOCAL = threading.local()
"""Thread-local data is data for thread specific values."""
def reset_time_for_thread() -> None:
    """Set the HTTP time accumulated for the current thread back to zero.

    Once ``total_time`` exists on :py:obj:`THREADLOCAL`, the duration of
    every request sent by this thread is added to it.
    """
    THREADLOCAL.total_time = 0
def get_time_for_thread():
    """Return the total HTTP time recorded for the current thread.

    :return: the accumulated time, or ``None`` when ``total_time`` has not
        been initialised for this thread.
    """
    return getattr(THREADLOCAL, 'total_time', None)
def set_timeout_for_thread(timeout, start_time=None):
    """Store the timeout and the start time for the current thread.

    :param timeout: default timeout, used for requests that do not pass
        their own ``timeout`` argument.
    :param start_time: reference time on the ``timeit.default_timer`` clock;
        when set, the time elapsed since then is deducted from the timeout.
    """
    THREADLOCAL.start_time = start_time
    THREADLOCAL.timeout = timeout
def set_context_network_name(network_name):
    """Select, by name, the network used by the current thread.

    The name is resolved with :py:obj:`get_network` and the result is kept
    in the thread-local storage.
    """
    selected_network = get_network(network_name)
    THREADLOCAL.network = selected_network
def get_context_network():
    """Return the network of the current thread.

    Falls back to the value of :py:obj:`get_network` (called without
    argument) when the thread has no network set, or a falsy one.
    """
    thread_network = getattr(THREADLOCAL, 'network', None)
    if not thread_network:
        return get_network()
    return thread_network
@contextmanager
def _record_http_time():
    """Context manager measuring the time spent inside the ``with`` block.

    Yields the thread's ``start_time`` if the attribute exists (it may be
    ``None``), otherwise the time at which the block was entered.

    On exit, even when the block raised, the elapsed time is added to the
    thread's ``total_time``, provided the attribute exists
    (see get_time_for_thread() and reset_time_for_thread()).
    """
    entered_at = default_timer()
    reference_time = getattr(THREADLOCAL, 'start_time', entered_at)
    try:
        yield reference_time
    finally:
        # only threads that initialised total_time are accounted for
        if hasattr(THREADLOCAL, 'total_time'):
            elapsed = default_timer() - entered_at
            THREADLOCAL.total_time += elapsed
def _get_timeout(start_time, kwargs):
# pylint: disable=too-many-branches
# timeout (httpx)
if 'timeout' in kwargs:
timeout = kwargs['timeout']
else:
timeout = getattr(THREADLOCAL, 'timeout', None)
if timeout is not None:
kwargs['timeout'] = timeout
# 2 minutes timeout for the requests without timeout
timeout = timeout or 120
# adjust actual timeout
timeout += 0.2 # overhead
if start_time:
timeout -= default_timer() - start_time
return timeout
def request(method, url, **kwargs):
    """same as requests/requests/api.py request(...)

    The coroutine ``network.request(...)`` of the thread's network is
    scheduled on the loop returned by ``get_loop()``; the calling thread
    blocks until the response is available.

    :raises httpx.TimeoutException: when no result is available within the
        time computed by ``_get_timeout``.
    """
    with _record_http_time() as start_time:
        network = get_context_network()
        wait_for = _get_timeout(start_time, kwargs)
        coroutine = network.request(method, url, **kwargs)
        future = asyncio.run_coroutine_threadsafe(coroutine, get_loop())
        try:
            response = future.result(wait_for)
        except concurrent.futures.TimeoutError as exc:
            raise httpx.TimeoutException('Timeout', request=None) from exc
        return response
def multi_requests(request_list: List["Request"]) -> List[Union[httpx.Response, Exception]]:
    """send multiple HTTP requests in parallel. Wait for all requests to finish.

    :param request_list: descriptions of the requests to send; they are not
        modified.
    :return: one item per request, in the order of ``request_list``: the
        response, or the exception describing the failure of that request
        (a ``httpx.TimeoutException`` when its deadline expired).
    """
    with _record_http_time() as start_time:
        network = get_context_network()
        loop = get_loop()

        # send the requests
        pending = []
        for request_desc in request_list:
            # work on a copy: _get_timeout() may add a 'timeout' entry, and
            # that must not leak into the dict owned by the caller
            kwargs = dict(request_desc.kwargs)
            timeout = _get_timeout(start_time, kwargs)
            future = asyncio.run_coroutine_threadsafe(
                network.request(request_desc.method, request_desc.url, **kwargs), loop
            )
            # keep an absolute deadline: the responses are read one after the
            # other, so a relative timeout would start again for every
            # future and the total wait could exceed the timeout many times
            pending.append((future, default_timer() + timeout))

        # read the responses
        responses = []
        for future, deadline in pending:
            remaining = max(deadline - default_timer(), 0)
            try:
                responses.append(future.result(remaining))
            except concurrent.futures.TimeoutError:
                responses.append(httpx.TimeoutException('Timeout', request=None))
            except Exception as e:  # pylint: disable=broad-except
                responses.append(e)
        return responses
class _RequestFields(NamedTuple):
    # Fields of Request.  They live in a separate base class because a
    # NamedTuple class can not define its own __new__.
    method: str  # HTTP method, for example 'GET'
    url: str  # URL of the request
    kwargs: Dict[str, object]  # keyword arguments forwarded with the request


class Request(_RequestFields):
    """Request description for the multi_requests function

    ``Request(method, url)`` gets its own, empty ``kwargs`` dict: a default
    value declared on the field would be one single dict shared by every
    instance, so a change made through one request would show up in all
    the others.
    """

    __slots__ = ()

    def __new__(cls, method, url, kwargs=None):
        if kwargs is None:
            kwargs = {}
        return super().__new__(cls, method, url, kwargs)

    @staticmethod
    def get(url, **kwargs):
        """Describe a GET request."""
        return Request('GET', url, kwargs)

    @staticmethod
    def options(url, **kwargs):
        """Describe an OPTIONS request."""
        return Request('OPTIONS', url, kwargs)

    @staticmethod
    def head(url, **kwargs):
        """Describe a HEAD request."""
        return Request('HEAD', url, kwargs)

    @staticmethod
    def post(url, **kwargs):
        """Describe a POST request."""
        return Request('POST', url, kwargs)

    @staticmethod
    def put(url, **kwargs):
        """Describe a PUT request."""
        return Request('PUT', url, kwargs)

    @staticmethod
    def patch(url, **kwargs):
        """Describe a PATCH request."""
        return Request('PATCH', url, kwargs)

    @staticmethod
    def delete(url, **kwargs):
        """Describe a DELETE request."""
        return Request('DELETE', url, kwargs)
def get(url, **kwargs):
    """Send a ``get`` request with :py:func:`request`.

    ``allow_redirects`` is set to ``True`` unless the caller passes it.
    """
    if 'allow_redirects' not in kwargs:
        kwargs['allow_redirects'] = True
    return request('get', url, **kwargs)
def options(url, **kwargs):
    """Send an ``options`` request with :py:func:`request`.

    ``allow_redirects`` is set to ``True`` unless the caller passes it.
    """
    if 'allow_redirects' not in kwargs:
        kwargs['allow_redirects'] = True
    return request('options', url, **kwargs)
def head(url, **kwargs):
    """Send a ``head`` request with :py:func:`request`.

    ``allow_redirects`` is set to ``False`` unless the caller passes it.
    """
    if 'allow_redirects' not in kwargs:
        kwargs['allow_redirects'] = False
    return request('head', url, **kwargs)
def post(url, data=None, **kwargs):
    """Send a ``post`` request with :py:func:`request`.

    ``data`` is always forwarded, even when it is ``None``.
    """
    kwargs['data'] = data
    return request('post', url, **kwargs)
def put(url, data=None, **kwargs):
    """Send a ``put`` request with :py:func:`request`.

    ``data`` is always forwarded, even when it is ``None``.
    """
    kwargs['data'] = data
    return request('put', url, **kwargs)
def patch(url, data=None, **kwargs):
    """Send a ``patch`` request with :py:func:`request`.

    ``data`` is always forwarded, even when it is ``None``.
    """
    kwargs['data'] = data
    return request('patch', url, **kwargs)
def delete(url, **kwargs):
    """Send a ``delete`` request with :py:func:`request`."""
    http_method = 'delete'
    return request(http_method, url, **kwargs)
async def stream_chunk_to_queue(network, queue, method, url, **kwargs):
    """Stream a response and push it, piece by piece, into ``queue``.

    Items are queued in this order:

    1. the response object,
    2. each non-empty chunk of raw bytes,
    3. the exception, if one other than a closed stream occurred,
    4. ``None``, always, to mark the end.
    """
    try:
        stream_context = await network.stream(method, url, **kwargs)
        async with stream_context as response:
            queue.put(response)
            # aiter_raw: access the raw bytes on the response without applying any HTTP content decoding
            # https://www.python-httpx.org/quickstart/#streaming-responses
            async for raw_chunk in response.aiter_raw(65536):
                if len(raw_chunk) == 0:
                    continue
                queue.put(raw_chunk)
    except (httpx.StreamClosed, anyio.ClosedResourceError):
        # The stream was closed while iterating over aiter_raw, the response
        # is already in the queue.  Nothing to report: the None queued in the
        # finally block lets the consumer of the queue stop.
        pass
    except Exception as exc:  # pylint: disable=broad-except
        # Broad except on purpose.  Without it, an exception raised by
        # network.stream(method, url, **kwargs) would not be caught here,
        # only None would be queued (in finally) and the consumer of the
        # queue would have nothing to return.
        queue.put(exc)
    finally:
        queue.put(None)
def _stream_generator(method, url, **kwargs):
    """Generator over the items queued by :py:func:`stream_chunk_to_queue`.

    The coroutine is scheduled on the loop returned by ``get_loop()`` and
    this generator reads the queue it fills: every item is yielded until
    ``None`` is read.  An exception found in the queue is raised.  After the
    last item, the result of the coroutine is awaited.
    """
    queue = SimpleQueue()
    network = get_context_network()
    producer = stream_chunk_to_queue(network, queue, method, url, **kwargs)
    future = asyncio.run_coroutine_threadsafe(producer, get_loop())

    # yield chunks
    while True:
        item = queue.get()
        if item is None:
            break
        if isinstance(item, Exception):
            raise item
        yield item

    future.result()
def _close_response_method(self):
    """``close()`` for a response whose ``_generator`` attribute is set.

    Schedules ``self.aclose()`` on the loop returned by ``get_loop()``, then
    consumes what is left of ``self._generator``.
    """
    closing = self.aclose()
    asyncio.run_coroutine_threadsafe(closing, get_loop())
    # Reach the end of self._generator ( _stream_generator ) to avoid a
    # memory leak.  It makes sure that:
    # * the httpx response is closed (see the stream_chunk_to_queue function)
    # * future.result() is called in _stream_generator
    remaining_chunks = self._generator  # pylint: disable=protected-access
    for _chunk in remaining_chunks:
        pass
def stream(method, url, **kwargs) -> Tuple[httpx.Response, Iterable[bytes]]:
    """Replace httpx.stream.

    Usage::

        response, chunks = stream(...)
        for chunk in chunks:
            ...

    The returned response gets a ``close()`` method that also consumes the
    rest of the chunks, see ``_close_response_method``.

    NOTE(review): the previous docstring explained that ``httpx.Client.stream``
    is not used because it would require a ``httpx.HTTPTransport`` version of
    the project's ``httpx.AsyncHTTPTransport``; that transport is not
    declared in this module.
    """
    chunks = _stream_generator(method, url, **kwargs)

    # the first item of the generator is the response
    response = next(chunks)  # pylint: disable=stop-iteration-return
    # defensive: _stream_generator raises the queued exceptions by itself
    if isinstance(response, Exception):
        raise response

    # pylint: disable=protected-access
    response._generator = chunks
    response.close = MethodType(_close_response_method, response)
    return response, chunks