From d03435df7d86e5893bcb66d1ba79ded8a43ad8d0 Mon Sep 17 00:00:00 2001 From: Timothy Clem Date: Sat, 23 Mar 2024 17:17:30 -0700 Subject: [PATCH] Allow passing around req/resp http::Extensions Co-Authored-By: Nick Presta --- README.md | 4 +- crates/twirp-build/src/lib.rs | 6 +-- crates/twirp/src/context.rs | 42 ++++++++++++++++++ crates/twirp/src/details.rs | 4 +- crates/twirp/src/error.rs | 2 +- crates/twirp/src/lib.rs | 3 ++ crates/twirp/src/server.rs | 74 ++++++++++++++++++++++++++----- crates/twirp/src/test.rs | 45 +++++++++++++++---- example/src/bin/example-client.rs | 14 +++++- example/src/main.rs | 59 +++++++++++++++++++++--- 10 files changed, 219 insertions(+), 34 deletions(-) create mode 100644 crates/twirp/src/context.rs diff --git a/README.md b/README.md index b18dfcc..6231edf 100644 --- a/README.md +++ b/README.md @@ -37,7 +37,7 @@ Add a `build.rs` file to your project to compile the protos and generate Rust co ```rust fn main() { let proto_source_files = ["./service.proto"]; - + // Tell Cargo to rerun this build script if any of the proto files change for entry in &proto_source_files { println!("cargo:rerun-if-changed={}", entry); @@ -82,7 +82,7 @@ struct HaberdasherApiServer; #[async_trait] impl haberdash::HaberdasherApi for HaberdasherApiServer { - async fn make_hat(&self, req: MakeHatRequest) -> Result { + async fn make_hat(&self, ctx: twirp::Context, req: MakeHatRequest) -> Result { todo!() } } diff --git a/crates/twirp-build/src/lib.rs b/crates/twirp-build/src/lib.rs index c08a8a9..9b75a90 100644 --- a/crates/twirp-build/src/lib.rs +++ b/crates/twirp-build/src/lib.rs @@ -29,7 +29,7 @@ impl prost_build::ServiceGenerator for ServiceGenerator { for m in &service.methods { writeln!( buf, - " async fn {}(&self, req: {}) -> Result<{}, twirp::TwirpErrorResponse>;", + " async fn {}(&self, ctx: twirp::Context, req: {}) -> Result<{}, twirp::TwirpErrorResponse>;", m.name, m.input_type, m.output_type, ) .unwrap(); @@ -52,8 +52,8 @@ where let rust_method_name = &m.name; writeln!( buf, - r#" .route("/{uri}", |api: std::sync::Arc, req: {req_type}| async move {{ - api.{rust_method_name}(req).await + r#" .route("/{uri}", |api: std::sync::Arc, ctx: twirp::Context, req: {req_type}| async move {{ + api.{rust_method_name}(ctx, req).await }})"#, ) .unwrap(); diff --git a/crates/twirp/src/context.rs b/crates/twirp/src/context.rs new file mode 100644 index 0000000..8612ea9 --- /dev/null +++ b/crates/twirp/src/context.rs @@ -0,0 +1,42 @@ +use std::sync::{Arc, Mutex}; + +use http::Extensions; + +/// Context allows passing information between twirp rpc handlers and http middleware by providing +/// access to extensions on the `http:Request` and `http:Response`. +/// +/// An example use case is to extract a request id from an http header and use that id in subsequent +/// handler code. +#[derive(Default)] +pub struct Context { + extensions: Extensions, + resp_extensions: Arc>, +} + +impl Context { + pub fn new(extensions: Extensions, resp_extensions: Arc>) -> Self { + Self { + extensions, + resp_extensions, + } + } + + /// Get a request extension. + pub fn get(&self) -> Option<&T> + where + T: Clone + Send + Sync + 'static, + { + self.extensions.get::() + } + + /// Insert a response extension. + pub fn insert(&self, val: T) -> Option + where + T: Clone + Send + Sync + 'static, + { + self.resp_extensions + .lock() + .expect("mutex poisoned") + .insert(val) + } +} diff --git a/crates/twirp/src/details.rs b/crates/twirp/src/details.rs index c82a9ae..2e0b19f 100644 --- a/crates/twirp/src/details.rs +++ b/crates/twirp/src/details.rs @@ -5,7 +5,7 @@ use std::future::Future; use axum::extract::{Request, State}; use axum::Router; -use crate::{server, TwirpErrorResponse}; +use crate::{server, Context, TwirpErrorResponse}; /// Builder object used by generated code to build a Twirp service. /// @@ -33,7 +33,7 @@ where /// `|api: Arc, req: MakeHatRequest| async move { api.make_hat(req) }`. pub fn route(self, url: &str, f: F) -> Self where - F: Fn(S, Req) -> Fut + Clone + Sync + Send + 'static, + F: Fn(S, Context, Req) -> Fut + Clone + Sync + Send + 'static, Fut: Future> + Send, Req: prost::Message + Default + serde::de::DeserializeOwned, Res: prost::Message + serde::Serialize, diff --git a/crates/twirp/src/error.rs b/crates/twirp/src/error.rs index 4c1a98d..592ea10 100644 --- a/crates/twirp/src/error.rs +++ b/crates/twirp/src/error.rs @@ -8,7 +8,7 @@ use http::header::{self, HeaderMap, HeaderValue}; use hyper::{Response, StatusCode}; use serde::{Deserialize, Serialize, Serializer}; -// Alias for a generic error +/// Alias for a generic error pub type GenericError = Box; macro_rules! twirp_error_codes { diff --git a/crates/twirp/src/lib.rs b/crates/twirp/src/lib.rs index 876c0d1..5b66b2b 100644 --- a/crates/twirp/src/lib.rs +++ b/crates/twirp/src/lib.rs @@ -1,4 +1,5 @@ pub mod client; +pub mod context; pub mod error; pub mod headers; pub mod server; @@ -10,7 +11,9 @@ pub mod test; pub mod details; pub use client::{Client, ClientBuilder, ClientError, Middleware, Next, Result}; +pub use context::Context; pub use error::*; // many constructors like `invalid_argument()` +pub use http::Extensions; // Re-export this crate's dependencies that users are likely to code against. These can be used to // import the exact versions of these libraries `twirp` is built with -- useful if your project is diff --git a/crates/twirp/src/server.rs b/crates/twirp/src/server.rs index 67ab69c..2b756d0 100644 --- a/crates/twirp/src/server.rs +++ b/crates/twirp/src/server.rs @@ -4,10 +4,12 @@ //! `twirp-build`. See for details and an example. use std::fmt::Debug; +use std::sync::{Arc, Mutex}; use axum::body::Body; use axum::response::IntoResponse; use futures::Future; +use http::Extensions; use http_body_util::BodyExt; use hyper::{header, Request, Response}; use serde::de::DeserializeOwned; @@ -15,7 +17,7 @@ use serde::Serialize; use tokio::time::{Duration, Instant}; use crate::headers::{CONTENT_TYPE_JSON, CONTENT_TYPE_PROTOBUF}; -use crate::{error, serialize_proto_message, GenericError, TwirpErrorResponse}; +use crate::{error, serialize_proto_message, Context, GenericError, TwirpErrorResponse}; // TODO: Properly implement JsonPb (de)serialization as it is slightly different // than standard JSON. @@ -46,7 +48,7 @@ pub(crate) async fn handle_request( f: F, ) -> Response where - F: FnOnce(S, Req) -> Fut + Clone + Sync + Send + 'static, + F: FnOnce(S, Context, Req) -> Fut + Clone + Sync + Send + 'static, Fut: Future> + Send, Req: prost::Message + Default + serde::de::DeserializeOwned, Resp: prost::Message + serde::Serialize, @@ -57,25 +59,29 @@ where .copied() .unwrap_or_else(|| Timings::new(Instant::now())); - let (req, resp_fmt) = match parse_request(req, &mut timings).await { + let (req, exts, resp_fmt) = match parse_request(req, &mut timings).await { Ok(pair) => pair, Err(err) => { - // This is the only place we use tracing (would be nice to remove) - // tracing::error!(?err, "failed to parse request"); - // TODO: We don't want to lose the underlying error here, but it might not be safe to - // include in the response like this always. + // TODO: Capture original error in the response extensions. E.g.: + // resp_exts + // .lock() + // .expect("mutex poisoned") + // .insert(RequestError(err)); let mut twirp_err = error::malformed("bad request"); twirp_err.insert_meta("error".to_string(), err.to_string()); return twirp_err.into_response(); } }; - let res = f(service, req).await; + let resp_exts = Arc::new(Mutex::new(Extensions::new())); + let ctx = Context::new(exts, resp_exts.clone()); + let res = f(service, ctx, req).await; timings.set_response_handled(); let mut resp = match write_response(res, resp_fmt) { Ok(resp) => resp, Err(err) => { + // TODO: Capture original error in the response extensions. let mut twirp_err = error::unknown("error serializing response"); twirp_err.insert_meta("error".to_string(), err.to_string()); return twirp_err.into_response(); @@ -83,6 +89,8 @@ where }; timings.set_response_written(); + resp.extensions_mut() + .extend(resp_exts.lock().expect("mutex poisoned").clone()); resp.extensions_mut().insert(timings); resp } @@ -90,19 +98,20 @@ where async fn parse_request( req: Request, timings: &mut Timings, -) -> Result<(T, BodyFormat), GenericError> +) -> Result<(T, Extensions, BodyFormat), GenericError> where T: prost::Message + Default + DeserializeOwned, { let format = BodyFormat::from_content_type(&req); - let bytes = req.into_body().collect().await?.to_bytes(); + let (parts, body) = req.into_parts(); + let bytes = body.collect().await?.to_bytes(); timings.set_received(); let request = match format { BodyFormat::Pb => T::decode(&bytes[..])?, BodyFormat::JsonPb => serde_json::from_slice(&bytes)?, }; timings.set_parsed(); - Ok((request, format)) + Ok((request, parts.extensions, format)) } fn write_response( @@ -233,6 +242,7 @@ mod tests { use super::*; use crate::test::*; + use axum::middleware::{self, Next}; use tower::Service; fn timings() -> Timings { @@ -299,4 +309,46 @@ mod tests { let data = read_err_body(resp.into_body()).await; assert_eq!(data, error::internal("boom!")); } + + #[tokio::test] + async fn test_middleware() { + let mut router = test_api_router().layer(middleware::from_fn(request_id_middleware)); + + // no request-id header + let resp = router.call(gen_ping_request("hi")).await.unwrap(); + assert!(resp.status().is_success(), "{:?}", resp); + let data: PingResponse = read_json_body(resp.into_body()).await; + assert_eq!(&data.name, "hi"); + + // now pass a header with x-request-id + let req = Request::post("/twirp/test.TestAPI/Ping") + .header("x-request-id", "abcd") + .body(Body::from( + serde_json::to_string(&PingRequest { + name: "hello".to_string(), + }) + .expect("will always be valid json"), + )) + .expect("always a valid twirp request"); + let resp = router.call(req).await.unwrap(); + assert!(resp.status().is_success(), "{:?}", resp); + let data: PingResponse = read_json_body(resp.into_body()).await; + assert_eq!(&data.name, "hello-abcd"); + } + + async fn request_id_middleware( + mut request: http::Request, + next: Next, + ) -> http::Response { + let rid = request + .headers() + .get("x-request-id") + .and_then(|v| v.to_str().ok()) + .map(|x| RequestId(x.to_string())); + if let Some(rid) = rid { + request.extensions_mut().insert(rid); + } + + next.run(request).await + } } diff --git a/crates/twirp/src/test.rs b/crates/twirp/src/test.rs index f07a338..06cf6f7 100644 --- a/crates/twirp/src/test.rs +++ b/crates/twirp/src/test.rs @@ -13,7 +13,7 @@ use tokio::time::Instant; use crate::details::TwirpRouterBuilder; use crate::server::Timings; -use crate::{error, Client, Result, TwirpErrorResponse}; +use crate::{error, Client, Context, Result, TwirpErrorResponse}; pub async fn run_test_server(port: u16) -> JoinHandle> { let router = test_api_router(); @@ -34,11 +34,15 @@ pub fn test_api_router() -> Router { let test_router = TwirpRouterBuilder::new(api) .route( "/Ping", - |api: Arc, req: PingRequest| async move { api.ping(req).await }, + |api: Arc, ctx: Context, req: PingRequest| async move { + api.ping(ctx, req).await + }, ) .route( "/Boom", - |api: Arc, req: PingRequest| async move { api.boom(req).await }, + |api: Arc, ctx: Context, req: PingRequest| async move { + api.boom(ctx, req).await + }, ) .build(); @@ -81,15 +85,32 @@ pub struct TestApiServer; #[async_trait] impl TestApi for TestApiServer { - async fn ping(&self, req: PingRequest) -> Result { - Ok(PingResponse { name: req.name }) + async fn ping( + &self, + ctx: Context, + req: PingRequest, + ) -> Result { + if let Some(RequestId(rid)) = ctx.get::() { + Ok(PingResponse { + name: format!("{}-{}", req.name, rid), + }) + } else { + Ok(PingResponse { name: req.name }) + } } - async fn boom(&self, _: PingRequest) -> Result { + async fn boom( + &self, + _ctx: Context, + _: PingRequest, + ) -> Result { Err(error::internal("boom!")) } } +#[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Debug, Default)] +pub struct RequestId(pub String); + // Small test twirp services (this would usually be generated with twirp-build) #[async_trait] pub trait TestApiClient { @@ -111,8 +132,16 @@ impl TestApiClient for Client { #[async_trait] pub trait TestApi { - async fn ping(&self, req: PingRequest) -> Result; - async fn boom(&self, req: PingRequest) -> Result; + async fn ping( + &self, + ctx: Context, + req: PingRequest, + ) -> Result; + async fn boom( + &self, + ctx: Context, + req: PingRequest, + ) -> Result; } #[derive(serde::Serialize, serde::Deserialize)] diff --git a/example/src/bin/example-client.rs b/example/src/bin/example-client.rs index 03ebbf6..e580874 100644 --- a/example/src/bin/example-client.rs +++ b/example/src/bin/example-client.rs @@ -28,6 +28,7 @@ pub async fn main() -> Result<(), GenericError> { twirp::reqwest::Client::default(), ) .with(RequestHeaders { hmac_key: None }) + .with(PrintResponseHeaders {}) .build()?; let resp = client .with(hostname("localhost")) @@ -59,7 +60,7 @@ struct RequestHeaders { #[async_trait] impl Middleware for RequestHeaders { async fn handle(&self, mut req: Request, next: Next<'_>) -> twirp::client::Result { - req.headers_mut().append("Request_id", "XYZ".try_into()?); + req.headers_mut().append("x-request-id", "XYZ".try_into()?); if let Some(_hmac_key) = &self.hmac_key { req.headers_mut() .append("Request-HMAC", "example:todo".try_into()?); @@ -69,6 +70,17 @@ impl Middleware for RequestHeaders { } } +struct PrintResponseHeaders; + +#[async_trait] +impl Middleware for PrintResponseHeaders { + async fn handle(&self, req: Request, next: Next<'_>) -> twirp::client::Result { + let res = next.run(req).await?; + eprintln!("Response headers: {res:?}"); + Ok(res) + } +} + #[derive(Debug)] struct MockHaberdasherApiClient; diff --git a/example/src/main.rs b/example/src/main.rs index f134a7a..d1ae7a4 100644 --- a/example/src/main.rs +++ b/example/src/main.rs @@ -3,8 +3,11 @@ use std::sync::Arc; use std::time::UNIX_EPOCH; use twirp::async_trait::async_trait; +use twirp::axum::body::Body; +use twirp::axum::http; +use twirp::axum::middleware::{self, Next}; use twirp::axum::routing::get; -use twirp::{invalid_argument, Router, TwirpErrorResponse}; +use twirp::{invalid_argument, Context, Router, TwirpErrorResponse}; pub mod service { pub mod haberdash { @@ -22,7 +25,11 @@ async fn ping() -> &'static str { #[tokio::main] pub async fn main() { let api_impl = Arc::new(HaberdasherApiServer {}); - let twirp_routes = Router::new().nest(haberdash::SERVICE_FQN, haberdash::router(api_impl)); + let middleware = twirp::tower::builder::ServiceBuilder::new() + .layer(middleware::from_fn(request_id_middleware)); + let twirp_routes = Router::new() + .nest(haberdash::SERVICE_FQN, haberdash::router(api_impl)) + .layer(middleware); let app = Router::new() .nest("/twirp", twirp_routes) .route("/_ping", get(ping)) @@ -42,12 +49,21 @@ struct HaberdasherApiServer; #[async_trait] impl haberdash::HaberdasherApi for HaberdasherApiServer { - async fn make_hat(&self, req: MakeHatRequest) -> Result { + async fn make_hat( + &self, + ctx: Context, + req: MakeHatRequest, + ) -> Result { if req.inches == 0 { return Err(invalid_argument("inches")); } - println!("got {:?}", req); + if let Some(id) = ctx.get::() { + println!("{id:?}"); + }; + + println!("got {req:?}"); + ctx.insert::(ResponseInfo(42)); let ts = std::time::SystemTime::now() .duration_since(UNIX_EPOCH) .unwrap_or_default(); @@ -63,6 +79,35 @@ impl haberdash::HaberdasherApi for HaberdasherApiServer { } } +// Demonstrate sending back custom extensions from the handlers. +#[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Debug, Default)] +struct ResponseInfo(u16); + +/// Demonstrate pulling the request id out of an http header and sharing it with the rpc handlers. +#[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Debug, Default)] +struct RequestId(String); + +async fn request_id_middleware( + mut request: http::Request, + next: Next, +) -> http::Response { + let rid = request + .headers() + .get("x-request-id") + .and_then(|v| v.to_str().ok()) + .map(|x| RequestId(x.to_string())); + if let Some(rid) = rid { + request.extensions_mut().insert(rid); + } + + let mut res = next.run(request).await; + + let info = res.extensions().get::().unwrap().0; + res.headers_mut().insert("x-response-info", info.into()); + + res +} + #[cfg(test)] mod test { use service::haberdash::v1::HaberdasherApiClient; @@ -77,7 +122,8 @@ mod test { #[tokio::test] async fn success() { let api = HaberdasherApiServer {}; - let res = api.make_hat(MakeHatRequest { inches: 1 }).await; + let ctx = twirp::Context::default(); + let res = api.make_hat(ctx, MakeHatRequest { inches: 1 }).await; assert!(res.is_ok()); let res = res.unwrap(); assert_eq!(res.size, 1); @@ -86,7 +132,8 @@ mod test { #[tokio::test] async fn invalid_request() { let api = HaberdasherApiServer {}; - let res = api.make_hat(MakeHatRequest { inches: 0 }).await; + let ctx = twirp::Context::default(); + let res = api.make_hat(ctx, MakeHatRequest { inches: 0 }).await; assert!(res.is_err()); let err = res.unwrap_err(); assert_eq!(err.code, TwirpErrorCode::InvalidArgument);