piware.de Git - learn-rust.git/blob - axum-server/src/main.rs
axum-server: Add unit test for /hello route
[learn-rust.git] / axum-server / src / main.rs
1 use std::io;
2
3 use axum::{
4     routing::{get, get_service},
5     extract::{Path, TypedHeader, ws},
6     http::{StatusCode},
7     response,
8     Router};
9
10 async fn hello(Path(name): Path<String>, user_agent: Option<TypedHeader<axum::headers::UserAgent>>) -> impl response::IntoResponse {
11     if let Some(TypedHeader(user_agent)) = user_agent {
12         (StatusCode::OK, format!("Hello {} from {}", name, user_agent))
13     } else {
14         (StatusCode::OK, format!("Hello {}", name))
15     }
16 }
17
18 async fn ws_echo(mut socket: ws::WebSocket) {
19     while let Some(msg) = socket.recv().await {
20         if let Ok(msg) = msg {
21             tracing::debug!("websocket got message: {:?}", msg);
22
23             let reply = match msg  {
24                 ws::Message::Text(t) => ws::Message::Text(t),
25                 ws::Message::Binary(b) => ws::Message::Binary(b),
26                 // axum handles Ping/Pong by itself
27                 ws::Message::Ping(_) => { continue },
28                 ws::Message::Pong(_) => { continue },
29                 ws::Message::Close(_) => { break }
30             };
31
32             if socket.send(reply).await
33                 .is_err() {
34                     tracing::info!("websocket client disconnected");
35                     break;
36                 }
37         }
38         else {
39             tracing::info!("websocket client disconnected");
40             break;
41         }
42     }
43 }
44
45 fn app() -> Router {
46     Router::new()
47         .route("/hello/:name", get(hello))
48         .nest("/dir",
49                get_service(tower_http::services::ServeDir::new("../static").precompressed_gzip())
50                    .handle_error(|e: io::Error| async move {
51                        (StatusCode::INTERNAL_SERVER_ERROR, format!("Unhandled internal error: {}", e))
52                    })
53         )
54         .route("/ws-echo", get(|ws: ws::WebSocketUpgrade| async {ws.on_upgrade(ws_echo)}))
55         .layer(
56             tower::ServiceBuilder::new()
57                 .layer(
58                     tower_http::trace::TraceLayer::new_for_http()
59                          .make_span_with(tower_http::trace::DefaultMakeSpan::default().include_headers(true)),
60                 )
61         )
62 }
63
64 #[tokio::main]
65 async fn main() {
66     tracing_subscriber::fmt::init();
67
68     let addr = std::net::SocketAddr::from(([127, 0, 0, 1], 3030));
69     tracing::info!("listening on {}", addr);
70     axum::Server::bind(&addr)
71         .serve(app().into_make_service())
72         .await
73         .unwrap();
74 }
75
#[cfg(test)]
mod tests {
    use axum::{
        body::Body,
        http::{Request, StatusCode},
        response::Response,
    };
    use tower::ServiceExt; // provides `oneshot`

    /// Assert that `res` is `200 OK` and that its complete body equals `expected_body`.
    async fn assert_res_ok_body(res: Response, expected_body: &[u8]) {
        assert_eq!(res.status(), StatusCode::OK);
        let body = hyper::body::to_bytes(res.into_body()).await.unwrap();
        assert_eq!(body, expected_body);
    }

    /// Dispatch `req` through a freshly built router and return the response.
    async fn send(req: Request<Body>) -> Response {
        super::app().oneshot(req).await.unwrap()
    }

    #[tokio::test]
    async fn test_hello() {
        // Without a User-Agent header only the name is greeted.
        let plain = Request::builder()
            .uri("/hello/rust")
            .body(Body::empty())
            .unwrap();
        assert_res_ok_body(send(plain).await, b"Hello rust").await;

        // With a User-Agent header the agent string is appended.
        let with_agent = Request::builder()
            .uri("/hello/rust")
            .header("user-agent", "TestBrowser 0.1")
            .body(Body::empty())
            .unwrap();
        assert_res_ok_body(send(with_agent).await, b"Hello rust from TestBrowser 0.1").await;
    }
}