From 82bf5de0ba6b9098c548664bb19eae5837dfc48c Mon Sep 17 00:00:00 2001 From: Noelle Wang <73260931+No767@users.noreply.github.com> Date: Sun, 17 Nov 2024 20:48:55 -0800 Subject: [PATCH] Remove HTTP 422 responses from OpenAPI schema (#24) --- server/core.py | 56 ++++++++++++++++++++++++++++++++++++++++++++-- server/launcher.py | 1 - 2 files changed, 54 insertions(+), 3 deletions(-) diff --git a/server/core.py b/server/core.py index 5ef3078..2d92c8d 100644 --- a/server/core.py +++ b/server/core.py @@ -1,12 +1,22 @@ +from __future__ import annotations + import asyncio +from collections import OrderedDict from contextlib import asynccontextmanager -from typing import Literal, NamedTuple, Optional +from itertools import chain +from typing import TYPE_CHECKING, Any, Generator, Literal, NamedTuple, Optional import asyncpg -from fastapi import FastAPI +from fastapi import FastAPI, status +from fastapi.exceptions import RequestValidationError +from fastapi.openapi.utils import get_openapi +from fastapi.responses import ORJSONResponse from typing_extensions import Self from utils.config import KanaeConfig +if TYPE_CHECKING: + from utils.request import RouteRequest + class VersionInfo(NamedTuple): major: int @@ -47,8 +57,50 @@ def __init__( lifespan=self.lifespan, ) self.config = config + self.add_exception_handler( + RequestValidationError, + self.request_validation_error_handler, # type: ignore + ) + + ### Exception Handlers + + async def request_validation_error_handler( + self, request: RouteRequest, exc: RequestValidationError + ) -> ORJSONResponse: + errors = ", ".join( + OrderedDict.fromkeys( + chain.from_iterable(exception["loc"] for exception in exc.errors()) + ).keys() + ) + return ORJSONResponse(content=f"Field required at: {errors}", status_code=422) + + ### Server-related utilities @asynccontextmanager async def lifespan(self, app: Self): async with asyncpg.create_pool(dsn=self.config["postgres_uri"]) as app.pool: yield + + def get_db(self) -> Generator[asyncpg.Pool, None, None]: + yield self.pool + + def openapi(self) -> dict[str, Any]: + if not self.openapi_schema: + self.openapi_schema = get_openapi( + title=self.title, + version=self.version, + openapi_version=self.openapi_version, + description=self.description, + terms_of_service=self.terms_of_service, + contact=self.contact, + license_info=self.license_info, + routes=self.routes, + tags=self.openapi_tags, + servers=self.servers, + ) + for path in self.openapi_schema["paths"].values(): + for method in path.values(): + responses = method.get("responses") + if str(status.HTTP_422_UNPROCESSABLE_ENTITY) in responses: + del responses[str(status.HTTP_422_UNPROCESSABLE_ENTITY)] + return self.openapi_schema diff --git a/server/launcher.py b/server/launcher.py index ee142a7..864cfee 100644 --- a/server/launcher.py +++ b/server/launcher.py @@ -19,7 +19,6 @@ app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler) # type: ignore app.state.limiter = router.limiter - if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument(