Skip to content

Commit

Permalink
Allow passing around req/resp http::Extensions
Browse files Browse the repository at this point in the history
Co-Authored-By: Nick Presta <[email protected]>
  • Loading branch information
tclem and nickpresta committed Mar 27, 2024
1 parent 3198261 commit d03435d
Show file tree
Hide file tree
Showing 10 changed files with 219 additions and 34 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ Add a `build.rs` file to your project to compile the protos and generate Rust co
```rust
fn main() {
let proto_source_files = ["./service.proto"];

// Tell Cargo to rerun this build script if any of the proto files change
for entry in &proto_source_files {
println!("cargo:rerun-if-changed={}", entry);
Expand Down Expand Up @@ -82,7 +82,7 @@ struct HaberdasherApiServer;

#[async_trait]
impl haberdash::HaberdasherApi for HaberdasherApiServer {
async fn make_hat(&self, req: MakeHatRequest) -> Result<MakeHatResponse, TwirpErrorResponse> {
async fn make_hat(&self, ctx: twirp::Context, req: MakeHatRequest) -> Result<MakeHatResponse, TwirpErrorResponse> {
todo!()
}
}
Expand Down
6 changes: 3 additions & 3 deletions crates/twirp-build/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ impl prost_build::ServiceGenerator for ServiceGenerator {
for m in &service.methods {
writeln!(
buf,
" async fn {}(&self, req: {}) -> Result<{}, twirp::TwirpErrorResponse>;",
" async fn {}(&self, ctx: twirp::Context, req: {}) -> Result<{}, twirp::TwirpErrorResponse>;",
m.name, m.input_type, m.output_type,
)
.unwrap();
Expand All @@ -52,8 +52,8 @@ where
let rust_method_name = &m.name;
writeln!(
buf,
r#" .route("/{uri}", |api: std::sync::Arc<T>, req: {req_type}| async move {{
api.{rust_method_name}(req).await
r#" .route("/{uri}", |api: std::sync::Arc<T>, ctx: twirp::Context, req: {req_type}| async move {{
api.{rust_method_name}(ctx, req).await
}})"#,
)
.unwrap();
Expand Down
42 changes: 42 additions & 0 deletions crates/twirp/src/context.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
use std::sync::{Arc, Mutex};

use http::Extensions;

/// Context allows passing information between twirp rpc handlers and http middleware by providing
/// access to extensions on the `http:Request` and `http:Response`.
///
/// An example use case is to extract a request id from an http header and use that id in subsequent
/// handler code.
#[derive(Default)]
pub struct Context {
extensions: Extensions,
resp_extensions: Arc<Mutex<Extensions>>,
}

impl Context {
pub fn new(extensions: Extensions, resp_extensions: Arc<Mutex<Extensions>>) -> Self {
Self {
extensions,
resp_extensions,
}
}

/// Get a request extension.
pub fn get<T>(&self) -> Option<&T>
where
T: Clone + Send + Sync + 'static,
{
self.extensions.get::<T>()
}

/// Insert a response extension.
pub fn insert<T>(&self, val: T) -> Option<T>
where
T: Clone + Send + Sync + 'static,
{
self.resp_extensions
.lock()
.expect("mutex poisoned")
.insert(val)
}
}
4 changes: 2 additions & 2 deletions crates/twirp/src/details.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use std::future::Future;
use axum::extract::{Request, State};
use axum::Router;

use crate::{server, TwirpErrorResponse};
use crate::{server, Context, TwirpErrorResponse};

/// Builder object used by generated code to build a Twirp service.
///
Expand Down Expand Up @@ -33,7 +33,7 @@ where
/// `|api: Arc<HaberdasherApiServer>, req: MakeHatRequest| async move { api.make_hat(req) }`.
pub fn route<F, Fut, Req, Res>(self, url: &str, f: F) -> Self
where
F: Fn(S, Req) -> Fut + Clone + Sync + Send + 'static,
F: Fn(S, Context, Req) -> Fut + Clone + Sync + Send + 'static,
Fut: Future<Output = Result<Res, TwirpErrorResponse>> + Send,
Req: prost::Message + Default + serde::de::DeserializeOwned,
Res: prost::Message + serde::Serialize,
Expand Down
2 changes: 1 addition & 1 deletion crates/twirp/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use http::header::{self, HeaderMap, HeaderValue};
use hyper::{Response, StatusCode};
use serde::{Deserialize, Serialize, Serializer};

// Alias for a generic error
/// Alias for a generic error
pub type GenericError = Box<dyn std::error::Error + Send + Sync>;

macro_rules! twirp_error_codes {
Expand Down
3 changes: 3 additions & 0 deletions crates/twirp/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
pub mod client;
pub mod context;
pub mod error;
pub mod headers;
pub mod server;
Expand All @@ -10,7 +11,9 @@ pub mod test;
pub mod details;

pub use client::{Client, ClientBuilder, ClientError, Middleware, Next, Result};
pub use context::Context;
pub use error::*; // many constructors like `invalid_argument()`
pub use http::Extensions;

// Re-export this crate's dependencies that users are likely to code against. These can be used to
// import the exact versions of these libraries `twirp` is built with -- useful if your project is
Expand Down
74 changes: 63 additions & 11 deletions crates/twirp/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,20 @@
//! `twirp-build`. See <https://github.com/github/twirp-rs#usage> for details and an example.
use std::fmt::Debug;
use std::sync::{Arc, Mutex};

use axum::body::Body;
use axum::response::IntoResponse;
use futures::Future;
use http::Extensions;
use http_body_util::BodyExt;
use hyper::{header, Request, Response};
use serde::de::DeserializeOwned;
use serde::Serialize;
use tokio::time::{Duration, Instant};

use crate::headers::{CONTENT_TYPE_JSON, CONTENT_TYPE_PROTOBUF};
use crate::{error, serialize_proto_message, GenericError, TwirpErrorResponse};
use crate::{error, serialize_proto_message, Context, GenericError, TwirpErrorResponse};

// TODO: Properly implement JsonPb (de)serialization as it is slightly different
// than standard JSON.
Expand Down Expand Up @@ -46,7 +48,7 @@ pub(crate) async fn handle_request<S, F, Fut, Req, Resp>(
f: F,
) -> Response<Body>
where
F: FnOnce(S, Req) -> Fut + Clone + Sync + Send + 'static,
F: FnOnce(S, Context, Req) -> Fut + Clone + Sync + Send + 'static,
Fut: Future<Output = Result<Resp, TwirpErrorResponse>> + Send,
Req: prost::Message + Default + serde::de::DeserializeOwned,
Resp: prost::Message + serde::Serialize,
Expand All @@ -57,52 +59,59 @@ where
.copied()
.unwrap_or_else(|| Timings::new(Instant::now()));

let (req, resp_fmt) = match parse_request(req, &mut timings).await {
let (req, exts, resp_fmt) = match parse_request(req, &mut timings).await {
Ok(pair) => pair,
Err(err) => {
// This is the only place we use tracing (would be nice to remove)
// tracing::error!(?err, "failed to parse request");
// TODO: We don't want to lose the underlying error here, but it might not be safe to
// include in the response like this always.
// TODO: Capture original error in the response extensions. E.g.:
// resp_exts
// .lock()
// .expect("mutex poisoned")
// .insert(RequestError(err));
let mut twirp_err = error::malformed("bad request");
twirp_err.insert_meta("error".to_string(), err.to_string());
return twirp_err.into_response();
}
};

let res = f(service, req).await;
let resp_exts = Arc::new(Mutex::new(Extensions::new()));
let ctx = Context::new(exts, resp_exts.clone());
let res = f(service, ctx, req).await;
timings.set_response_handled();

let mut resp = match write_response(res, resp_fmt) {
Ok(resp) => resp,
Err(err) => {
// TODO: Capture original error in the response extensions.
let mut twirp_err = error::unknown("error serializing response");
twirp_err.insert_meta("error".to_string(), err.to_string());
return twirp_err.into_response();
}
};
timings.set_response_written();

resp.extensions_mut()
.extend(resp_exts.lock().expect("mutex poisoned").clone());
resp.extensions_mut().insert(timings);
resp
}

async fn parse_request<T>(
req: Request<Body>,
timings: &mut Timings,
) -> Result<(T, BodyFormat), GenericError>
) -> Result<(T, Extensions, BodyFormat), GenericError>
where
T: prost::Message + Default + DeserializeOwned,
{
let format = BodyFormat::from_content_type(&req);
let bytes = req.into_body().collect().await?.to_bytes();
let (parts, body) = req.into_parts();
let bytes = body.collect().await?.to_bytes();
timings.set_received();
let request = match format {
BodyFormat::Pb => T::decode(&bytes[..])?,
BodyFormat::JsonPb => serde_json::from_slice(&bytes)?,
};
timings.set_parsed();
Ok((request, format))
Ok((request, parts.extensions, format))
}

fn write_response<T>(
Expand Down Expand Up @@ -233,6 +242,7 @@ mod tests {
use super::*;
use crate::test::*;

use axum::middleware::{self, Next};
use tower::Service;

fn timings() -> Timings {
Expand Down Expand Up @@ -299,4 +309,46 @@ mod tests {
let data = read_err_body(resp.into_body()).await;
assert_eq!(data, error::internal("boom!"));
}

#[tokio::test]
async fn test_middleware() {
let mut router = test_api_router().layer(middleware::from_fn(request_id_middleware));

// no request-id header
let resp = router.call(gen_ping_request("hi")).await.unwrap();
assert!(resp.status().is_success(), "{:?}", resp);
let data: PingResponse = read_json_body(resp.into_body()).await;
assert_eq!(&data.name, "hi");

// now pass a header with x-request-id
let req = Request::post("/twirp/test.TestAPI/Ping")
.header("x-request-id", "abcd")
.body(Body::from(
serde_json::to_string(&PingRequest {
name: "hello".to_string(),
})
.expect("will always be valid json"),
))
.expect("always a valid twirp request");
let resp = router.call(req).await.unwrap();
assert!(resp.status().is_success(), "{:?}", resp);
let data: PingResponse = read_json_body(resp.into_body()).await;
assert_eq!(&data.name, "hello-abcd");
}

async fn request_id_middleware(
mut request: http::Request<Body>,
next: Next,
) -> http::Response<Body> {
let rid = request
.headers()
.get("x-request-id")
.and_then(|v| v.to_str().ok())
.map(|x| RequestId(x.to_string()));
if let Some(rid) = rid {
request.extensions_mut().insert(rid);
}

next.run(request).await
}
}
45 changes: 37 additions & 8 deletions crates/twirp/src/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use tokio::time::Instant;

use crate::details::TwirpRouterBuilder;
use crate::server::Timings;
use crate::{error, Client, Result, TwirpErrorResponse};
use crate::{error, Client, Context, Result, TwirpErrorResponse};

pub async fn run_test_server(port: u16) -> JoinHandle<Result<(), std::io::Error>> {
let router = test_api_router();
Expand All @@ -34,11 +34,15 @@ pub fn test_api_router() -> Router {
let test_router = TwirpRouterBuilder::new(api)
.route(
"/Ping",
|api: Arc<TestApiServer>, req: PingRequest| async move { api.ping(req).await },
|api: Arc<TestApiServer>, ctx: Context, req: PingRequest| async move {
api.ping(ctx, req).await
},
)
.route(
"/Boom",
|api: Arc<TestApiServer>, req: PingRequest| async move { api.boom(req).await },
|api: Arc<TestApiServer>, ctx: Context, req: PingRequest| async move {
api.boom(ctx, req).await
},
)
.build();

Expand Down Expand Up @@ -81,15 +85,32 @@ pub struct TestApiServer;

#[async_trait]
impl TestApi for TestApiServer {
async fn ping(&self, req: PingRequest) -> Result<PingResponse, TwirpErrorResponse> {
Ok(PingResponse { name: req.name })
async fn ping(
&self,
ctx: Context,
req: PingRequest,
) -> Result<PingResponse, TwirpErrorResponse> {
if let Some(RequestId(rid)) = ctx.get::<RequestId>() {
Ok(PingResponse {
name: format!("{}-{}", req.name, rid),
})
} else {
Ok(PingResponse { name: req.name })
}
}

async fn boom(&self, _: PingRequest) -> Result<PingResponse, TwirpErrorResponse> {
async fn boom(
&self,
_ctx: Context,
_: PingRequest,
) -> Result<PingResponse, TwirpErrorResponse> {
Err(error::internal("boom!"))
}
}

#[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Debug, Default)]
pub struct RequestId(pub String);

// Small test twirp services (this would usually be generated with twirp-build)
#[async_trait]
pub trait TestApiClient {
Expand All @@ -111,8 +132,16 @@ impl TestApiClient for Client {

#[async_trait]
pub trait TestApi {
async fn ping(&self, req: PingRequest) -> Result<PingResponse, TwirpErrorResponse>;
async fn boom(&self, req: PingRequest) -> Result<PingResponse, TwirpErrorResponse>;
async fn ping(
&self,
ctx: Context,
req: PingRequest,
) -> Result<PingResponse, TwirpErrorResponse>;
async fn boom(
&self,
ctx: Context,
req: PingRequest,
) -> Result<PingResponse, TwirpErrorResponse>;
}

#[derive(serde::Serialize, serde::Deserialize)]
Expand Down
14 changes: 13 additions & 1 deletion example/src/bin/example-client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ pub async fn main() -> Result<(), GenericError> {
twirp::reqwest::Client::default(),
)
.with(RequestHeaders { hmac_key: None })
.with(PrintResponseHeaders {})
.build()?;
let resp = client
.with(hostname("localhost"))
Expand Down Expand Up @@ -59,7 +60,7 @@ struct RequestHeaders {
#[async_trait]
impl Middleware for RequestHeaders {
async fn handle(&self, mut req: Request, next: Next<'_>) -> twirp::client::Result<Response> {
req.headers_mut().append("Request_id", "XYZ".try_into()?);
req.headers_mut().append("x-request-id", "XYZ".try_into()?);
if let Some(_hmac_key) = &self.hmac_key {
req.headers_mut()
.append("Request-HMAC", "example:todo".try_into()?);
Expand All @@ -69,6 +70,17 @@ impl Middleware for RequestHeaders {
}
}

struct PrintResponseHeaders;

#[async_trait]
impl Middleware for PrintResponseHeaders {
async fn handle(&self, req: Request, next: Next<'_>) -> twirp::client::Result<Response> {
let res = next.run(req).await?;
eprintln!("Response headers: {res:?}");
Ok(res)
}
}

#[derive(Debug)]
struct MockHaberdasherApiClient;

Expand Down
Loading

0 comments on commit d03435d

Please sign in to comment.