diff --git a/sanic/__version__.py b/sanic/__version__.py index 325664388b..ec8701ae2a 100644 --- a/sanic/__version__.py +++ b/sanic/__version__.py @@ -1 +1 @@ -__version__ = "21.9.0" +__version__ = "21.9.3" diff --git a/sanic/app.py b/sanic/app.py index 01aa07cb0b..ac1a9a1b28 100644 --- a/sanic/app.py +++ b/sanic/app.py @@ -173,18 +173,18 @@ def __init__( self.asgi = False self.auto_reload = False self.blueprints: Dict[str, Blueprint] = {} - self.config = config or Config( - load_env=load_env, env_prefix=env_prefix + self.config: Config = config or Config( + load_env=load_env, + env_prefix=env_prefix, + app=self, ) - self.configure_logging = configure_logging - self.ctx = ctx or SimpleNamespace() + self.configure_logging: bool = configure_logging + self.ctx: Any = ctx or SimpleNamespace() self.debug = None - self.error_handler = error_handler or ErrorHandler( - fallback=self.config.FALLBACK_ERROR_FORMAT, - ) + self.error_handler: ErrorHandler = error_handler or ErrorHandler() self.is_running = False self.is_stopping = False - self.listeners: Dict[str, List[ListenerType]] = defaultdict(list) + self.listeners: Dict[str, List[ListenerType[Any]]] = defaultdict(list) self.named_request_middleware: Dict[str, Deque[MiddlewareType]] = {} self.named_response_middleware: Dict[str, Deque[MiddlewareType]] = {} self.reload_dirs: Set[Path] = set() @@ -1474,6 +1474,9 @@ def signalize(self): async def _startup(self): self.signalize() self.finalize() + ErrorHandler.finalize( + self.error_handler, fallback=self.config.FALLBACK_ERROR_FORMAT + ) TouchUp.run(self) async def _server_event( diff --git a/sanic/config.py b/sanic/config.py index 649d9414bc..0484450f0f 100644 --- a/sanic/config.py +++ b/sanic/config.py @@ -1,7 +1,9 @@ +from __future__ import annotations + from inspect import isclass from os import environ from pathlib import Path -from typing import Any, Dict, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, Optional, Union from warnings import warn from sanic.errorpages import check_error_format @@ -10,6 +12,10 @@ from .utils import load_module_from_file_location, str_to_bool +if TYPE_CHECKING: # no cov + from sanic import Sanic + + SANIC_PREFIX = "SANIC_" BASE_LOGO = """ @@ -71,11 +77,14 @@ def __init__( load_env: Optional[Union[bool, str]] = True, env_prefix: Optional[str] = SANIC_PREFIX, keep_alive: Optional[bool] = None, + *, + app: Optional[Sanic] = None, ): defaults = defaults or {} super().__init__({**DEFAULT_CONFIG, **defaults}) - self.LOGO = BASE_LOGO + self._app = app + self._LOGO = BASE_LOGO if keep_alive is not None: self.KEEP_ALIVE = keep_alive @@ -97,6 +106,7 @@ def __init__( self._configure_header_size() self._check_error_format() + self._init = True def __getattr__(self, attr): try: @@ -104,16 +114,51 @@ def __getattr__(self, attr): except KeyError as ke: raise AttributeError(f"Config has no '{ke.args[0]}'") - def __setattr__(self, attr, value): - self[attr] = value - if attr in ( - "REQUEST_MAX_HEADER_SIZE", - "REQUEST_BUFFER_SIZE", - "REQUEST_MAX_SIZE", - ): - self._configure_header_size() - elif attr == "FALLBACK_ERROR_FORMAT": - self._check_error_format() + def __setattr__(self, attr, value) -> None: + self.update({attr: value}) + + def __setitem__(self, attr, value) -> None: + self.update({attr: value}) + + def update(self, *other, **kwargs) -> None: + other_mapping = {k: v for item in other for k, v in dict(item).items()} + super().update(*other, **kwargs) + for attr, value in {**other_mapping, **kwargs}.items(): + self._post_set(attr, value) + + def _post_set(self, attr, value) -> None: + if self.get("_init"): + if attr in ( + "REQUEST_MAX_HEADER_SIZE", + "REQUEST_BUFFER_SIZE", + "REQUEST_MAX_SIZE", + ): + self._configure_header_size() + elif attr == "FALLBACK_ERROR_FORMAT": + self._check_error_format() + if self.app and value != self.app.error_handler.fallback: + if self.app.error_handler.fallback != "auto": + warn( + "Overriding non-default ErrorHandler fallback " + "value. Changing from " + f"{self.app.error_handler.fallback} to {value}." + ) + self.app.error_handler.fallback = value + elif attr == "LOGO": + self._LOGO = value + warn( + "Setting the config.LOGO is deprecated and will no longer " + "be supported starting in v22.6.", + DeprecationWarning, + ) + + @property + def app(self): + return self._app + + @property + def LOGO(self): + return self._LOGO def _configure_header_size(self): Http.set_header_max_size( diff --git a/sanic/errorpages.py b/sanic/errorpages.py index 82cdd57a5c..4ef59390b6 100644 --- a/sanic/errorpages.py +++ b/sanic/errorpages.py @@ -393,7 +393,8 @@ def exception_response( # from the route if request.route: try: - render_format = request.route.ctx.error_format + if request.route.ctx.error_format: + render_format = request.route.ctx.error_format except AttributeError: ... diff --git a/sanic/handlers.py b/sanic/handlers.py index ffeb76b8d8..35deff0f35 100644 --- a/sanic/handlers.py +++ b/sanic/handlers.py @@ -1,3 +1,4 @@ +from inspect import signature from typing import Dict, List, Optional, Tuple, Type from sanic.errorpages import BaseRenderer, HTMLRenderer, exception_response @@ -25,7 +26,9 @@ class ErrorHandler: """ # Beginning in v22.3, the base renderer will be TextRenderer - def __init__(self, fallback: str, base: Type[BaseRenderer] = HTMLRenderer): + def __init__( + self, fallback: str = "auto", base: Type[BaseRenderer] = HTMLRenderer + ): self.handlers: List[Tuple[Type[BaseException], RouteHandler]] = [] self.cached_handlers: Dict[ Tuple[Type[BaseException], Optional[str]], Optional[RouteHandler] @@ -34,6 +37,41 @@ def __init__(self, fallback: str, base: Type[BaseRenderer] = HTMLRenderer): self.fallback = fallback self.base = base + @classmethod + def finalize(cls, error_handler, fallback: Optional[str] = None): + if ( + fallback + and fallback != "auto" + and error_handler.fallback == "auto" + ): + error_handler.fallback = fallback + + if not isinstance(error_handler, cls): + error_logger.warning( + f"Error handler is non-conforming: {type(error_handler)}" + ) + + sig = signature(error_handler.lookup) + if len(sig.parameters) == 1: + error_logger.warning( + DeprecationWarning( + "You are using a deprecated error handler. The lookup " + "method should accept two positional parameters: " + "(exception, route_name: Optional[str]). " + "Until you upgrade your ErrorHandler.lookup, Blueprint " + "specific exceptions will not work properly. Beginning " + "in v22.3, the legacy style lookup method will not " + "work at all." + ), + ) + error_handler._lookup = error_handler._legacy_lookup + + def _full_lookup(self, exception, route_name: Optional[str] = None): + return self.lookup(exception, route_name) + + def _legacy_lookup(self, exception, route_name: Optional[str] = None): + return self.lookup(exception) + def add(self, exception, handler, route_names: Optional[List[str]] = None): """ Add a new exception handler to an already existing handler object. @@ -56,7 +94,7 @@ def add(self, exception, handler, route_names: Optional[List[str]] = None): else: self.cached_handlers[(exception, None)] = handler - def lookup(self, exception, route_name: Optional[str]): + def lookup(self, exception, route_name: Optional[str] = None): """ Lookup the existing instance of :class:`ErrorHandler` and fetch the registered handler for a specific type of exception. @@ -94,6 +132,8 @@ def lookup(self, exception, route_name: Optional[str]): handler = None return handler + _lookup = _full_lookup + def response(self, request, exception): """Fetches and executes an exception handler and returns a response object @@ -109,7 +149,7 @@ def response(self, request, exception): or registered handler for that type of exception. """ route_name = request.name if request else None - handler = self.lookup(exception, route_name) + handler = self._lookup(exception, route_name) response = None try: if handler: diff --git a/sanic/http.py b/sanic/http.py index d30e4c82b8..6f59ef250f 100644 --- a/sanic/http.py +++ b/sanic/http.py @@ -105,7 +105,6 @@ def __init__(self, protocol): self.keep_alive = True self.stage: Stage = Stage.IDLE self.dispatch = self.protocol.app.dispatch - self.init_for_request() def init_for_request(self): """Init/reset all per-request variables.""" @@ -129,14 +128,20 @@ async def http1(self): """ HTTP 1.1 connection handler """ - while True: # As long as connection stays keep-alive + # Handle requests while the connection stays reusable + while self.keep_alive and self.stage is Stage.IDLE: + self.init_for_request() + # Wait for incoming bytes (in IDLE stage) + if not self.recv_buffer: + await self._receive_more() + self.stage = Stage.REQUEST try: # Receive and handle a request - self.stage = Stage.REQUEST self.response_func = self.http1_response_header await self.http1_request_header() + self.stage = Stage.HANDLER self.request.conn_info = self.protocol.conn_info await self.protocol.request_handler(self.request) @@ -187,16 +192,6 @@ async def http1(self): if self.response: self.response.stream = None - # Exit and disconnect if no more requests can be taken - if self.stage is not Stage.IDLE or not self.keep_alive: - break - - self.init_for_request() - - # Wait for the next request - if not self.recv_buffer: - await self._receive_more() - async def http1_request_header(self): # no cov """ Receive and parse request header into self.request. @@ -299,7 +294,6 @@ async def http1_request_header(self): # no cov # Remove header and its trailing CRLF del buf[: pos + 4] - self.stage = Stage.HANDLER self.request, request.stream = request, self self.protocol.state["requests_count"] += 1 diff --git a/sanic/mixins/routes.py b/sanic/mixins/routes.py index 8467a2e340..7139cd3c8a 100644 --- a/sanic/mixins/routes.py +++ b/sanic/mixins/routes.py @@ -918,7 +918,7 @@ def _register_static( return route - def _determine_error_format(self, handler) -> str: + def _determine_error_format(self, handler) -> Optional[str]: if not isinstance(handler, CompositionView): try: src = dedent(getsource(handler)) @@ -930,7 +930,7 @@ def _determine_error_format(self, handler) -> str: except (OSError, TypeError): ... - return "auto" + return None def _get_response_types(self, node): types = set() diff --git a/sanic/router.py b/sanic/router.py index 6995ed6da4..b15c2a3e16 100644 --- a/sanic/router.py +++ b/sanic/router.py @@ -139,11 +139,10 @@ def add( # type: ignore route.ctx.stream = stream route.ctx.hosts = hosts route.ctx.static = static - route.ctx.error_format = ( - error_format or self.ctx.app.config.FALLBACK_ERROR_FORMAT - ) + route.ctx.error_format = error_format - check_error_format(route.ctx.error_format) + if error_format: + check_error_format(route.ctx.error_format) routes.append(route) diff --git a/tests/test_config.py b/tests/test_config.py index 42a7e3ecdb..67324f1e25 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -3,6 +3,7 @@ from pathlib import Path from tempfile import TemporaryDirectory from textwrap import dedent +from unittest.mock import Mock import pytest @@ -350,3 +351,40 @@ def test_update_from_lowercase_key(app): d = {"test_setting_value": 1} app.update_config(d) assert "test_setting_value" not in app.config + + +def test_deprecation_notice_when_setting_logo(app): + message = ( + "Setting the config.LOGO is deprecated and will no longer be " + "supported starting in v22.6." + ) + with pytest.warns(DeprecationWarning, match=message): + app.config.LOGO = "My Custom Logo" + + +def test_config_set_methods(app, monkeypatch): + post_set = Mock() + monkeypatch.setattr(Config, "_post_set", post_set) + + app.config.FOO = 1 + post_set.assert_called_once_with("FOO", 1) + post_set.reset_mock() + + app.config["FOO"] = 2 + post_set.assert_called_once_with("FOO", 2) + post_set.reset_mock() + + app.config.update({"FOO": 3}) + post_set.assert_called_once_with("FOO", 3) + post_set.reset_mock() + + app.config.update([("FOO", 4)]) + post_set.assert_called_once_with("FOO", 4) + post_set.reset_mock() + + app.config.update(FOO=5) + post_set.assert_called_once_with("FOO", 5) + post_set.reset_mock() + + app.config.update_config({"FOO": 6}) + post_set.assert_called_once_with("FOO", 6) diff --git a/tests/test_errorpages.py b/tests/test_errorpages.py index 5af4ca5fe0..1843f6a707 100644 --- a/tests/test_errorpages.py +++ b/tests/test_errorpages.py @@ -1,8 +1,10 @@ import pytest from sanic import Sanic +from sanic.config import Config from sanic.errorpages import HTMLRenderer, exception_response from sanic.exceptions import NotFound, SanicException +from sanic.handlers import ErrorHandler from sanic.request import Request from sanic.response import HTTPResponse, html, json, text @@ -271,3 +273,72 @@ def test_combinations_for_auto(fake_request, accept, content_type, expected): ) assert response.content_type == expected + + +def test_allow_fallback_error_format_set_main_process_start(app): + @app.main_process_start + async def start(app, _): + app.config.FALLBACK_ERROR_FORMAT = "text" + + request, response = app.test_client.get("/error") + assert request.app.error_handler.fallback == "text" + assert response.status == 500 + assert response.content_type == "text/plain; charset=utf-8" + + +def test_setting_fallback_to_non_default_raise_warning(app): + app.error_handler = ErrorHandler(fallback="text") + + assert app.error_handler.fallback == "text" + + with pytest.warns( + UserWarning, + match=( + "Overriding non-default ErrorHandler fallback value. " + "Changing from text to auto." + ), + ): + app.config.FALLBACK_ERROR_FORMAT = "auto" + + assert app.error_handler.fallback == "auto" + + app.config.FALLBACK_ERROR_FORMAT = "text" + + with pytest.warns( + UserWarning, + match=( + "Overriding non-default ErrorHandler fallback value. " + "Changing from text to json." + ), + ): + app.config.FALLBACK_ERROR_FORMAT = "json" + + assert app.error_handler.fallback == "json" + + +def test_allow_fallback_error_format_in_config_injection(): + class MyConfig(Config): + FALLBACK_ERROR_FORMAT = "text" + + app = Sanic("test", config=MyConfig()) + + @app.route("/error", methods=["GET", "POST"]) + def err(request): + raise Exception("something went wrong") + + request, response = app.test_client.get("/error") + assert request.app.error_handler.fallback == "text" + assert response.status == 500 + assert response.content_type == "text/plain; charset=utf-8" + + +def test_allow_fallback_error_format_in_config_replacement(app): + class MyConfig(Config): + FALLBACK_ERROR_FORMAT = "text" + + app.config = MyConfig() + + request, response = app.test_client.get("/error") + assert request.app.error_handler.fallback == "text" + assert response.status == 500 + assert response.content_type == "text/plain; charset=utf-8" diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index 29797e1e1f..503e47cbb1 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -4,6 +4,7 @@ import pytest from bs4 import BeautifulSoup +from websockets.version import version as websockets_version from sanic import Sanic from sanic.exceptions import ( @@ -16,7 +17,6 @@ abort, ) from sanic.response import text -from websockets.version import version as websockets_version class SanicExceptionTestException(Exception): diff --git a/tests/test_exceptions_handler.py b/tests/test_exceptions_handler.py index dbf9fcbb9b..9bedf7e67c 100644 --- a/tests/test_exceptions_handler.py +++ b/tests/test_exceptions_handler.py @@ -1,4 +1,5 @@ import asyncio +import logging import pytest @@ -206,3 +207,23 @@ def test_exception_handler_processed_request_middleware(exception_handler_app): request, response = exception_handler_app.test_client.get("/8") assert response.status == 200 assert response.text == "Done." + + +def test_single_arg_exception_handler_notice(exception_handler_app, caplog): + class CustomErrorHandler(ErrorHandler): + def lookup(self, exception): + return super().lookup(exception, None) + + exception_handler_app.error_handler = CustomErrorHandler() + + with caplog.at_level(logging.WARNING): + _, response = exception_handler_app.test_client.get("/1") + + assert caplog.records[0].message == ( + "You are using a deprecated error handler. The lookup method should " + "accept two positional parameters: (exception, route_name: " + "Optional[str]). Until you upgrade your ErrorHandler.lookup, " + "Blueprint specific exceptions will not work properly. Beginning in " + "v22.3, the legacy style lookup method will not work at all." + ) + assert response.status == 400 diff --git a/tests/test_request_timeout.py b/tests/test_request_timeout.py deleted file mode 100644 index 48e23f1d63..0000000000 --- a/tests/test_request_timeout.py +++ /dev/null @@ -1,109 +0,0 @@ -import asyncio - -import httpcore -import httpx -import pytest - -from sanic_testing.testing import SanicTestClient - -from sanic import Sanic -from sanic.response import text - - -class DelayableHTTPConnection(httpcore._async.connection.AsyncHTTPConnection): - async def arequest(self, *args, **kwargs): - await asyncio.sleep(2) - return await super().arequest(*args, **kwargs) - - async def _open_socket(self, *args, **kwargs): - retval = await super()._open_socket(*args, **kwargs) - if self._request_delay: - await asyncio.sleep(self._request_delay) - return retval - - -class DelayableSanicConnectionPool(httpcore.AsyncConnectionPool): - def __init__(self, request_delay=None, *args, **kwargs): - self._request_delay = request_delay - super().__init__(*args, **kwargs) - - async def _add_to_pool(self, connection, timeout): - connection.__class__ = DelayableHTTPConnection - connection._request_delay = self._request_delay - await super()._add_to_pool(connection, timeout) - - -class DelayableSanicSession(httpx.AsyncClient): - def __init__(self, request_delay=None, *args, **kwargs) -> None: - transport = DelayableSanicConnectionPool(request_delay=request_delay) - super().__init__(transport=transport, *args, **kwargs) - - -class DelayableSanicTestClient(SanicTestClient): - def __init__(self, app, request_delay=None): - super().__init__(app) - self._request_delay = request_delay - self._loop = None - - def get_new_session(self): - return DelayableSanicSession(request_delay=self._request_delay) - - -@pytest.fixture -def request_no_timeout_app(): - app = Sanic("test_request_no_timeout") - app.config.REQUEST_TIMEOUT = 0.6 - - @app.route("/1") - async def handler2(request): - return text("OK") - - return app - - -@pytest.fixture -def request_timeout_default_app(): - app = Sanic("test_request_timeout_default") - app.config.REQUEST_TIMEOUT = 0.6 - - @app.route("/1") - async def handler1(request): - return text("OK") - - @app.websocket("/ws1") - async def ws_handler1(request, ws): - await ws.send("OK") - - return app - - -def test_default_server_error_request_timeout(request_timeout_default_app): - client = DelayableSanicTestClient(request_timeout_default_app, 2) - _, response = client.get("/1") - assert response.status == 408 - assert "Request Timeout" in response.text - - -def test_default_server_error_request_dont_timeout(request_no_timeout_app): - client = DelayableSanicTestClient(request_no_timeout_app, 0.2) - _, response = client.get("/1") - assert response.status == 200 - assert response.text == "OK" - - -def test_default_server_error_websocket_request_timeout( - request_timeout_default_app, -): - - headers = { - "Upgrade": "websocket", - "Connection": "upgrade", - "Sec-WebSocket-Key": "dGhlIHNhbXBsZSBub25jZQ==", - "Sec-WebSocket-Version": "13", - } - - client = DelayableSanicTestClient(request_timeout_default_app, 2) - _, response = client.get("/ws1", headers=headers) - - assert response.status == 408 - assert "Request Timeout" in response.text diff --git a/tests/test_timeout_logic.py b/tests/test_timeout_logic.py index 05249f11cf..497deda92a 100644 --- a/tests/test_timeout_logic.py +++ b/tests/test_timeout_logic.py @@ -26,6 +26,7 @@ def protocol(app, mock_transport): protocol = HttpProtocol(loop=loop, app=app) protocol.connection_made(mock_transport) protocol._setup_connection() + protocol._http.init_for_request() protocol._task = Mock(spec=asyncio.Task) protocol._task.cancel = Mock() return protocol