use axum::{
routing::{get, get_service},
- extract::Path,
+ extract::{Path, TypedHeader, ws},
http::{StatusCode},
response,
Router};
-async fn hello(Path(name): Path<String>) -> impl response::IntoResponse {
- (StatusCode::OK, format!("Hello {}", name))
+async fn hello(Path(name): Path<String>, user_agent: Option<TypedHeader<axum::headers::UserAgent>>) -> impl response::IntoResponse {
+ if let Some(TypedHeader(user_agent)) = user_agent {
+ (StatusCode::OK, format!("Hello {} from {}", name, user_agent))
+ } else {
+ (StatusCode::OK, format!("Hello {}", name))
+ }
}
-#[tokio::main]
-async fn main() {
- tracing_subscriber::fmt::init();
- let app = Router::new()
+async fn ws_echo(mut socket: ws::WebSocket) {
+ while let Some(msg) = socket.recv().await {
+ if let Ok(msg) = msg {
+ tracing::debug!("websocket got message: {:?}", msg);
+
+ let reply = match msg {
+ ws::Message::Text(t) => ws::Message::Text(t),
+ ws::Message::Binary(b) => ws::Message::Binary(b),
+ // axum handles Ping/Pong by itself
+ ws::Message::Ping(_) => { continue },
+ ws::Message::Pong(_) => { continue },
+ ws::Message::Close(_) => { break }
+ };
+
+ if socket.send(reply).await
+ .is_err() {
+ tracing::info!("websocket client disconnected");
+ break;
+ }
+ }
+ else {
+ tracing::info!("websocket client disconnected");
+ break;
+ }
+ }
+}
+
+fn app() -> Router {
+ Router::new()
.route("/hello/:name", get(hello))
- .route("/static",
- get_service(tower_http::services::ServeFile::new("Cargo.toml").precompressed_gzip())
+ .nest("/dir",
+ get_service(tower_http::services::ServeDir::new("../static").precompressed_gzip())
.handle_error(|e: io::Error| async move {
(StatusCode::INTERNAL_SERVER_ERROR, format!("Unhandled internal error: {}", e))
})
)
+ .route("/ws-echo", get(|ws: ws::WebSocketUpgrade| async {ws.on_upgrade(ws_echo)}))
.layer(
tower::ServiceBuilder::new()
- .layer(tower_http::trace::TraceLayer::new_for_http())
- );
+ .layer(
+ tower_http::trace::TraceLayer::new_for_http()
+ .make_span_with(tower_http::trace::DefaultMakeSpan::default().include_headers(true)),
+ )
+ )
+}
- let addr = std::net::SocketAddr::from(([127, 0, 0, 1], 3000));
+#[tokio::main]
+async fn main() {
+ tracing_subscriber::fmt::init();
+
+ let addr = std::net::SocketAddr::from(([127, 0, 0, 1], 3030));
tracing::info!("listening on {}", addr);
axum::Server::bind(&addr)
- .serve(app.into_make_service())
+ .serve(app().into_make_service())
.await
.unwrap();
}
+
+#[cfg(test)]
+mod tests {
+ use axum::{
+ http::{Request, StatusCode},
+ response::Response,
+ body::Body
+ };
+ use tower::ServiceExt; // for `oneshot`
+
+ async fn assert_res_ok_body(res: Response, expected_body: &[u8]) {
+ assert_eq!(res.status(), StatusCode::OK);
+ assert_eq!(hyper::body::to_bytes(res.into_body()).await.unwrap(), expected_body);
+ }
+
+ #[tokio::test]
+ async fn test_hello() {
+ // no user-agent
+ let res = super::app()
+ .oneshot(Request::builder().uri("/hello/rust").body(Body::empty()).unwrap())
+ .await
+ .unwrap();
+ assert_res_ok_body(res, b"Hello rust").await;
+
+ // with user-agent
+ let res = super::app()
+ .oneshot(Request::builder()
+ .uri("/hello/rust")
+ .header("user-agent", "TestBrowser 0.1")
+ .body(Body::empty()).unwrap())
+ .await
+ .unwrap();
+ assert_res_ok_body(res, b"Hello rust from TestBrowser 0.1").await;
+ }
+}