Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove explicit status from return code, instead relying on TypedError's status() method. #1

Open
wants to merge 38 commits into
base: typed-catchers
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
56e7fa6
Initial brush
the10thWiz Jun 20, 2024
c0ad038
Working example
the10thWiz Jun 26, 2024
99e2109
Add error type to logs
the10thWiz Jun 29, 2024
09c56c7
Major improvements
the10thWiz Jun 29, 2024
b68900f
Fix whitespace
the10thWiz Jun 30, 2024
4cb3a3a
Revert local changes to scripts dir
the10thWiz Jun 30, 2024
ac3a7fa
Ensure examples pass CI
the10thWiz Jun 30, 2024
eaea6f6
Add Transient impl for serde::json::Error
the10thWiz Jun 30, 2024
1308c19
tmp
the10thWiz Jul 1, 2024
f8c8bb8
Rework catch attribute
the10thWiz Jul 2, 2024
dea224f
Update tests to use new #[catch] macro
the10thWiz Jul 2, 2024
7b8689c
Update transient and use new features in examples
the10thWiz Jul 13, 2024
fb796fc
Update guide
the10thWiz Jul 13, 2024
af68f5e
Major changes
the10thWiz Aug 17, 2024
a59cb04
Update core server code to use new error trait
the10thWiz Aug 17, 2024
6427db2
Updates to improve many aspects
the10thWiz Aug 24, 2024
6d06ac7
Update code to work properly with borrowed errors
the10thWiz Sep 3, 2024
04ae827
Fix formatting issues
the10thWiz Sep 3, 2024
99bba53
Update codegen with many of the new changes
the10thWiz Sep 3, 2024
f0f2342
Major fixes for matching and responder
the10thWiz Sep 3, 2024
84ba0b7
Update to pass tests
the10thWiz Sep 3, 2024
6ab2d13
Add FromError
the10thWiz Sep 4, 2024
b55b9c7
Implement TypedError for form errors
the10thWiz Sep 4, 2024
61a4b44
Add Fairing support, and update examples to match new APIs
the10thWiz Sep 4, 2024
61cd326
Fix safety issues & comments
the10thWiz Sep 7, 2024
3fbf1b4
Add derive macro for `TypedError`
the10thWiz Sep 8, 2024
c263a6c
Update Fairings types to fix issues
the10thWiz Sep 8, 2024
8266603
Add ui-fail tests
the10thWiz Sep 8, 2024
541efe3
Add safety comments for `Drop` impls
the10thWiz Sep 8, 2024
6a0c57a
Add and update ui tests
the10thWiz Sep 8, 2024
ee6a829
Add intermediate types and fix issues
the10thWiz Sep 8, 2024
6ea0c9d
Updates to tests
the10thWiz Sep 8, 2024
7390ac5
Clean up TODO comments
the10thWiz Sep 9, 2024
7871a0c
Fix TODOs
the10thWiz Sep 9, 2024
ad9b26c
Improve responder and route macros
the10thWiz Sep 13, 2024
cc50b84
Parial Update for docs
the10thWiz Sep 13, 2024
c787f64
Remove explicit Status parameter
the10thWiz Sep 20, 2024
546eea0
Remove explicit status parameter in most cases
the10thWiz Nov 17, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions contrib/db_pools/codegen/src/database.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,15 +60,16 @@ pub fn derive_database(input: TokenStream) -> TokenStream {

#[rocket::async_trait]
impl<'r> rocket::request::FromRequest<'r> for &'r #decorated_type {
type Error = ();
type Error = rocket::http::Status;

async fn from_request(
req: &'r rocket::request::Request<'_>
) -> rocket::request::Outcome<Self, Self::Error> {
match #db_ty::fetch(req.rocket()) {
Some(db) => rocket::outcome::Outcome::Success(db),
None => rocket::outcome::Outcome::Error((
rocket::http::Status::InternalServerError, ()))
None => rocket::outcome::Outcome::Error(
rocket::http::Status::InternalServerError
)
}
}
}
Expand Down
20 changes: 15 additions & 5 deletions contrib/db_pools/lib/src/database.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
use std::marker::PhantomData;
use std::ops::{Deref, DerefMut};

use rocket::TypedError;
use rocket::{error, Build, Ignite, Phase, Rocket, Sentinel, Orbit};
use rocket::fairing::{self, Fairing, Info, Kind};
use rocket::request::{FromRequest, Outcome, Request};
use rocket::figment::providers::Serialized;
use rocket::http::Status;

use crate::Pool;

Expand Down Expand Up @@ -278,17 +278,27 @@ impl<D: Database> Fairing for Initializer<D> {
}
}

#[derive(Debug, TypedError)]
pub enum ConnectionError<E> {
#[error(status = 503)]
ServiceUnavailable(E),
#[error(status = 500)]
InternalServerError,
}

#[rocket::async_trait]
impl<'r, D: Database> FromRequest<'r> for Connection<D> {
type Error = Option<<D::Pool as Pool>::Error>;
impl<'r, D: Database> FromRequest<'r> for Connection<D>
where <D::Pool as Pool>::Error: Send + Sync,
{
type Error = ConnectionError<<D::Pool as Pool>::Error>;

async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> {
match D::fetch(req.rocket()) {
Some(db) => match db.get().await {
Ok(conn) => Outcome::Success(Connection(conn)),
Err(e) => Outcome::Error((Status::ServiceUnavailable, Some(e))),
Err(e) => Outcome::Error(ConnectionError::ServiceUnavailable(e)),
},
None => Outcome::Error((Status::InternalServerError, None)),
None => Outcome::Error(ConnectionError::InternalServerError),
}
}
}
Expand Down
40 changes: 34 additions & 6 deletions contrib/db_pools/lib/src/diesel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
//! # #[macro_use] extern crate rocket;
//! # #[cfg(feature = "diesel_mysql")] {
//! use rocket_db_pools::{Database, Connection};
//! use rocket_db_pools::diesel::{QueryResult, MysqlPool, prelude::*};
//! use rocket_db_pools::diesel::{QueryResult, DieselError, MysqlPool, prelude::*};
//!
//! #[derive(Database)]
//! #[database("diesel_mysql")]
Expand Down Expand Up @@ -58,6 +58,11 @@
//!
//! Ok(format!("{post_ids:?}"))
//! }
//!
//! #[catch(500, error = "<e>")]
//! fn catch_diesel_error(e: &DieselError) -> String {
//! format!("{e:?}")
//! }
//! # }
//! ```

Expand Down Expand Up @@ -92,17 +97,40 @@ pub use diesel_async::AsyncMysqlConnection;
#[cfg(feature = "diesel_postgres")]
pub use diesel_async::AsyncPgConnection;

/// Alias of a `Result` with an error type of [`Debug`] for a `diesel::Error`.
use rocket::TypedError;

/// Wrapper type for diesel errors
#[derive(Debug, TypedError)]
pub struct DieselError(diesel::result::Error);

impl std::ops::Deref for DieselError {
type Target = diesel::result::Error;
fn deref(&self) -> &Self::Target {
&self.0
}
}

impl AsRef<diesel::result::Error> for DieselError {
fn as_ref(&self) -> &diesel::result::Error {
&self.0
}
}

impl From<diesel::result::Error> for DieselError {
fn from(e: diesel::result::Error) -> Self {
Self(e)
}
}

/// Alias of a `Result` with an error type of `diesel::Error`.
///
/// `QueryResult` is a [`Responder`](rocket::response::Responder) when `T` (the
/// `Ok` value) is a `Responder`. By using this alias as a route handler's
/// return type, the `?` operator can be applied to fallible `diesel` functions
/// in the route handler while still providing a valid `Responder` return type.
///
/// See the [module level docs](self#example) for a usage example.
///
/// [`Debug`]: rocket::response::Debug
pub type QueryResult<T, E = rocket::response::Debug<diesel::result::Error>> = Result<T, E>;
/// See module level docs for usage, and catching the error type.
pub type QueryResult<T, E = DieselError> = Result<T, E>;

/// Type alias for an `async` pool of MySQL connections for `async` [diesel].
///
Expand Down
6 changes: 3 additions & 3 deletions contrib/dyn_templates/src/metadata.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,9 +152,9 @@ impl Sentinel for Metadata<'_> {
/// (`500`) is returned.
#[rocket::async_trait]
impl<'r> FromRequest<'r> for Metadata<'r> {
type Error = ();
type Error = Status;

async fn from_request(request: &'r Request<'_>) -> request::Outcome<Self, ()> {
async fn from_request(request: &'r Request<'_>) -> request::Outcome<Self, Self::Error> {
request.rocket().state::<ContextManager>()
.map(|cm| request::Outcome::Success(Metadata(cm)))
.unwrap_or_else(|| {
Expand All @@ -163,7 +163,7 @@ impl<'r> FromRequest<'r> for Metadata<'r> {
To use templates, you must attach `Template::fairing()`."
);

request::Outcome::Error((Status::InternalServerError, ()))
request::Outcome::Error(Status::InternalServerError)
})
}
}
26 changes: 14 additions & 12 deletions contrib/dyn_templates/src/template.rs
Original file line number Diff line number Diff line change
Expand Up @@ -265,19 +265,21 @@ impl Template {
/// extension and a fixed-size body containing the rendered template. If
/// rendering fails, an `Err` of `Status::InternalServerError` is returned.
impl<'r> Responder<'r, 'static> for Template {
fn respond_to(self, req: &'r Request<'_>) -> response::Result<'static> {
let ctxt = req.rocket()
.state::<ContextManager>()
.ok_or_else(|| {
error!(
"uninitialized template context: missing `Template::fairing()`.\n\
To use templates, you must attach `Template::fairing()`."
);

Status::InternalServerError
})?;
type Error = Status;
fn respond_to(self, req: &'r Request<'_>) -> response::Result<'static, Self::Error> {
if let Some(ctxt) = req.rocket().state::<ContextManager>() {
match self.finalize(&ctxt.context()) {
Ok(v) => v.respond_to(req).map_err(|e| match e {}),
Err(s) => Err(s),
}
} else {
error!(
"uninitialized template context: missing `Template::fairing()`.\n\
To use templates, you must attach `Template::fairing()`."
);

self.finalize(&ctxt.context())?.respond_to(req)
Err(Status::InternalServerError)
}
}
}

Expand Down
4 changes: 2 additions & 2 deletions contrib/sync_db_pools/codegen/src/database.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,11 +112,11 @@ pub fn database_attr(attr: TokenStream, input: TokenStream) -> Result<TokenStrea

#[#rocket::async_trait]
impl<'r> #rocket::request::FromRequest<'r> for #guard_type {
type Error = ();
type Error = #rocket::http::Status;

async fn from_request(
__r: &'r #rocket::request::Request<'_>
) -> #rocket::request::Outcome<Self, ()> {
) -> #rocket::request::Outcome<Self, Self::Error> {
<#conn>::from_request(__r).await.map(Self)
}
}
Expand Down
9 changes: 4 additions & 5 deletions contrib/sync_db_pools/lib/src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ use std::marker::PhantomData;
use rocket::{Phase, Rocket, Ignite, Sentinel};
use rocket::fairing::{AdHoc, Fairing};
use rocket::request::{Request, Outcome, FromRequest};
use rocket::outcome::IntoOutcome;
use rocket::http::Status;
use rocket::trace::Trace;

Expand Down Expand Up @@ -212,17 +211,17 @@ impl<K, C: Poolable> Drop for ConnectionPool<K, C> {

#[rocket::async_trait]
impl<'r, K: 'static, C: Poolable> FromRequest<'r> for Connection<K, C> {
type Error = ();
type Error = Status;

#[inline]
async fn from_request(request: &'r Request<'_>) -> Outcome<Self, ()> {
async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
match request.rocket().state::<ConnectionPool<K, C>>() {
Some(c) => c.get().await.or_error((Status::ServiceUnavailable, ())),
Some(c) => c.get().await.ok_or(Status::ServiceUnavailable).into(),
None => {
let conn = std::any::type_name::<K>();
error!("`{conn}::fairing()` is not attached\n\
the fairing must be attached to use `{conn} in routes.");
Outcome::Error((Status::InternalServerError, ()))
Outcome::Error(Status::InternalServerError)
}
}
}
Expand Down
6 changes: 4 additions & 2 deletions contrib/ws/src/websocket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,8 @@ impl<'r> FromRequest<'r> for WebSocket {
}

impl<'r, 'o: 'r> Responder<'r, 'o> for Channel<'o> {
fn respond_to(self, _: &'r Request<'_>) -> response::Result<'o> {
type Error = std::convert::Infallible;
fn respond_to(self, _: &'r Request<'_>) -> response::Result<'o, Self::Error> {
Response::build()
.raw_header("Sec-Websocket-Version", "13")
.raw_header("Sec-WebSocket-Accept", self.ws.key.clone())
Expand All @@ -250,7 +251,8 @@ impl<'r, 'o: 'r> Responder<'r, 'o> for Channel<'o> {
impl<'r, 'o: 'r, S> Responder<'r, 'o> for MessageStream<'o, S>
where S: futures::Stream<Item = Result<Message>> + Send + 'o
{
fn respond_to(self, _: &'r Request<'_>) -> response::Result<'o> {
type Error = std::convert::Infallible;
fn respond_to(self, _: &'r Request<'_>) -> response::Result<'o, Self::Error> {
Response::build()
.raw_header("Sec-Websocket-Version", "13")
.raw_header("Sec-WebSocket-Accept", self.ws.key.clone())
Expand Down
98 changes: 71 additions & 27 deletions core/codegen/src/attribute/catch/mod.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,64 @@
mod parse;

use devise::ext::SpanDiagnosticExt;
use devise::{Spanned, Result};
use devise::{Result, Spanned};
use proc_macro2::{TokenStream, Span};

use crate::http_codegen::Optional;
use crate::syn_ext::ReturnTypeExt;
use crate::syn_ext::{IdentExt, ReturnTypeExt};
use crate::exports::*;

use self::parse::ErrorGuard;

use super::param::Guard;

fn error_type(guard: &ErrorGuard) -> TokenStream {
let ty = &guard.ty;
quote! {
#_catcher::type_id_of::<#ty>()
}
}

fn error_guard_decl(guard: &ErrorGuard) -> TokenStream {
let (ident, ty) = (guard.ident.rocketized(), &guard.ty);
quote_spanned! { ty.span() =>
let #ident: &#ty = match #_catcher::downcast(__error_init) {
Some(v) => v,
None => return #_Result::Err(#__status),
};
}
}

fn request_guard_decl(guard: &Guard) -> TokenStream {
let (ident, ty) = (guard.fn_ident.rocketized(), &guard.ty);
quote_spanned! { ty.span() =>
let #ident: #ty = match <#ty as #FromError>::from_error(
#__status,
#__req,
__error_init
).await {
#_Result::Ok(__v) => __v,
#_Result::Err(__e) => {
::rocket::trace::info!(
name: "forward",
target: concat!("rocket::codegen::catch::", module_path!()),
parameter = stringify!(#ident),
type_name = stringify!(#ty),
status = __e.code,
"error guard forwarding; trying next catcher"
);

return #_Err(#__status);
},
};
}
}

pub fn _catch(
args: proc_macro::TokenStream,
input: proc_macro::TokenStream
) -> Result<TokenStream> {
// Parse and validate all of the user's input.
let catch = parse::Attribute::parse(args.into(), input)?;
let catch = parse::Attribute::parse(args.into(), input.into())?;

// Gather everything we'll need to generate the catcher.
let user_catcher_fn = &catch.function;
Expand All @@ -22,35 +67,28 @@ pub fn _catch(
let status_code = Optional(catch.status.map(|s| s.code));
let deprecated = catch.function.attrs.iter().find(|a| a.path().is_ident("deprecated"));

// Determine the number of parameters that will be passed in.
if catch.function.sig.inputs.len() > 2 {
return Err(catch.function.sig.paren_token.span.join()
.error("invalid number of arguments: must be zero, one, or two")
.help("catchers optionally take `&Request` or `Status, &Request`"));
}

// This ensures that "Responder not implemented" points to the return type.
let return_type_span = catch.function.sig.output.ty()
.map(|ty| ty.span())
.unwrap_or_else(Span::call_site);

// Set the `req` and `status` spans to that of their respective function
// arguments for a more correct `wrong type` error span. `rev` to be cute.
let codegen_args = &[__req, __status];
let inputs = catch.function.sig.inputs.iter().rev()
.zip(codegen_args.iter())
.map(|(fn_arg, codegen_arg)| match fn_arg {
syn::FnArg::Receiver(_) => codegen_arg.respanned(fn_arg.span()),
syn::FnArg::Typed(a) => codegen_arg.respanned(a.ty.span())
}).rev();
let error_guard = catch.error_guard.as_ref().map(error_guard_decl);
let error_type = Optional(catch.error_guard.as_ref().map(error_type));
let request_guards = catch.request_guards.iter().map(request_guard_decl);
let parameter_names = catch.arguments.map.values()
.map(|(ident, _)| ident.rocketized());

// We append `.await` to the function call if this is `async`.
let dot_await = catch.function.sig.asyncness
.map(|a| quote_spanned!(a.span() => .await));

let catcher_response = quote_spanned!(return_type_span => {
let ___responder = #user_catcher_fn_name(#(#inputs),*) #dot_await;
#_response::Responder::respond_to(___responder, #__req)?
let ___responder = #user_catcher_fn_name(#(#parameter_names),*) #dot_await;
match #_response::Responder::respond_to(___responder, #__req) {
#_Ok(v) => v,
// If the responder fails, we drop any typed error, and convert to 500
#_Err(_) => return #_Err(#Status::InternalServerError),
}
});

// Generate the catcher, keeping the user's input around.
Expand All @@ -68,20 +106,26 @@ pub fn _catch(
fn into_info(self) -> #_catcher::StaticInfo {
fn monomorphized_function<'__r>(
#__status: #Status,
#__req: &'__r #Request<'_>
#__req: &'__r #Request<'_>,
__error_init: #_Option<&'__r (dyn #TypedError<'__r> + '__r)>,
) -> #_catcher::BoxFuture<'__r> {
#_Box::pin(async move {
#error_guard
#(#request_guards)*
let __response = #catcher_response;
#Response::build()
.status(#__status)
.merge(__response)
.ok()
#_Result::Ok(
#Response::build()
.status(#__status)
.merge(__response)
.finalize()
)
})
}

#_catcher::StaticInfo {
name: ::core::stringify!(#user_catcher_fn_name),
code: #status_code,
error_type: #error_type,
handler: monomorphized_function,
location: (::core::file!(), ::core::line!(), ::core::column!()),
}
Expand Down
Loading
Loading