Skip to content

Commit

Permalink
Implement request validation handling for 422 responses + remove from…
Browse files Browse the repository at this point in the history
… openapi schema
  • Loading branch information
No767 committed Nov 18, 2024
1 parent 0ade06e commit 59a05f6
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 3 deletions.
59 changes: 57 additions & 2 deletions server/core.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,21 @@
from __future__ import annotations

import asyncio
from contextlib import asynccontextmanager
from typing import Literal, NamedTuple, Optional
from typing import Literal, NamedTuple, Optional, Generator, TYPE_CHECKING, Any

from itertools import chain
from collections import OrderedDict
import asyncpg
from fastapi import FastAPI
from fastapi import FastAPI, status
from fastapi.exceptions import RequestValidationError
from fastapi.responses import ORJSONResponse
from typing_extensions import Self
from utils.config import KanaeConfig
from fastapi.openapi.utils import get_openapi

if TYPE_CHECKING:
from utils.request import RouteRequest


class VersionInfo(NamedTuple):
Expand Down Expand Up @@ -47,8 +57,53 @@ def __init__(
lifespan=self.lifespan,
)
self.config = config
self.add_exception_handler(
RequestValidationError,
self.request_validation_error_handler, # type: ignore
)

### Exception Handlers

async def request_validation_error_handler(
self, request: RouteRequest, exc: RequestValidationError
) -> ORJSONResponse:
errors = ", ".join(
OrderedDict.fromkeys(
chain.from_iterable(exception["loc"] for exception in exc.errors())
).keys()
)
return ORJSONResponse(content=f"Field required at: {errors}", status_code=422)

### Server-related utilities

@asynccontextmanager
async def lifespan(self, app: Self):
async with asyncpg.create_pool(dsn=self.config["postgres_uri"]) as app.pool:
yield

def get_db(self) -> Generator[asyncpg.Pool, None, None]:
yield self.pool

def openapi(self) -> dict[str, Any]:
if not self.openapi_schema:
self.openapi_schema = get_openapi(
title=self.title,
version=self.version,
openapi_version=self.openapi_version,
summary=self.summary,
description=self.description,
terms_of_service=self.terms_of_service,
contact=self.contact,
license_info=self.license_info,
routes=self.routes,
webhooks=self.webhooks.routes,
tags=self.openapi_tags,
servers=self.servers,
separate_input_output_schemas=self.separate_input_output_schemas,
)
for path in self.openapi_schema["paths"].values():
for method in path.values():
responses = method.get("responses")
if str(status.HTTP_422_UNPROCESSABLE_ENTITY) in responses:
del responses[str(status.HTTP_422_UNPROCESSABLE_ENTITY)]
return self.openapi_schema
1 change: 0 additions & 1 deletion server/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler) # type: ignore
app.state.limiter = router.limiter


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
Expand Down

0 comments on commit 59a05f6

Please sign in to comment.