diff --git a/server/core.py b/server/core.py index 724261e..e580c9a 100644 --- a/server/core.py +++ b/server/core.py @@ -1,12 +1,24 @@ +from __future__ import annotations + import asyncio from contextlib import asynccontextmanager -from typing import Any, Generator, Optional, Self +from typing import TYPE_CHECKING, Any, Generator, Optional, Self import asyncpg from fastapi import Depends, FastAPI, status +from fastapi.exceptions import HTTPException, RequestValidationError from fastapi.openapi.utils import get_openapi -from fastapi.responses import ORJSONResponse +from fastapi.responses import ORJSONResponse, Response +from fastapi.utils import is_body_allowed_for_status_code from utils.config import KanaeConfig +from utils.errors import ( + HTTPExceptionMessage, + RequestValidationErrorDetails, + RequestValidationErrorMessage, +) + +if TYPE_CHECKING: + from utils.request import RouteRequest __title__ = "Kanae" __description__ = """ @@ -36,12 +48,50 @@ def __init__( version=__version__, dependencies=[Depends(self.get_db)], default_response_class=ORJSONResponse, + responses={400: {"model": RequestValidationErrorMessage}}, loop=self.loop, redoc_url="/docs", docs_url=None, lifespan=self.lifespan, ) self.config = config + self.add_exception_handler( + HTTPException, + self.http_exception_handler, # type: ignore + ) + self.add_exception_handler( + RequestValidationError, + self.request_validation_error_handler, # type: ignore + ) + + ### Exception Handlers + + async def http_exception_handler( + self, request: RouteRequest, exc: HTTPException + ) -> Response: + headers = getattr(exc, "headers", None) + if not is_body_allowed_for_status_code(exc.status_code): + return Response(status_code=exc.status_code, headers=headers) + message = HTTPExceptionMessage(detail=exc.detail) + return ORJSONResponse( + content=message.model_dump(), status_code=exc.status_code, headers=headers + ) + + async def request_validation_error_handler( + self, request: RouteRequest, exc: RequestValidationError + ) -> ORJSONResponse: + message = RequestValidationErrorMessage( + errors=[ + RequestValidationErrorDetails( + detail=exception["msg"], context=exception["ctx"]["error"] + ) + for exception in exc.errors() + ] + ) + + return ORJSONResponse( + content=message.model_dump(), status_code=status.HTTP_400_BAD_REQUEST + ) ### Server-related utilities diff --git a/server/utils/errors.py b/server/utils/errors.py index 92b5a31..8d4dcd6 100644 --- a/server/utils/errors.py +++ b/server/utils/errors.py @@ -12,3 +12,18 @@ class NotFoundException(HTTPException): def __init__(self, detail: str = HTTP_404_DETAIL): self.status_code = 404 self.detail = detail + + +class RequestValidationErrorDetails(BaseModel, frozen=True): + detail: str + context: str + + +class RequestValidationErrorMessage(BaseModel, frozen=True): + result: str = "error" + errors: list[RequestValidationErrorDetails] + + +class HTTPExceptionMessage(BaseModel, frozen=True): + result: str = "error" + detail: str