fa44ac1695c12888da8f52dfa8f3ee38d7f16f88
[learn-rust.git] / axum-server / src / main.rs
1 use std::io;
2
3 use axum::{
4     routing::{get, get_service},
5     extract::{Path, TypedHeader, ws},
6     http::{StatusCode},
7     response,
8     Router};
9
10 async fn hello(Path(name): Path<String>, user_agent: Option<TypedHeader<axum::headers::UserAgent>>) -> impl response::IntoResponse {
11     if let Some(TypedHeader(user_agent)) = user_agent {
12         (StatusCode::OK, format!("Hello {} from {}", name, user_agent))
13     } else {
14         (StatusCode::OK, format!("Hello {}", name))
15     }
16 }
17
18 async fn ws_echo(mut socket: ws::WebSocket) {
19     while let Some(msg) = socket.recv().await {
20         if let Ok(msg) = msg {
21             tracing::debug!("websocket got message: {:?}", msg);
22
23             let reply = match msg  {
24                 ws::Message::Text(t) => ws::Message::Text(t),
25                 ws::Message::Binary(b) => ws::Message::Binary(b),
26                 // axum handles Ping/Pong by itself
27                 ws::Message::Ping(_) => { continue },
28                 ws::Message::Pong(_) => { continue },
29                 ws::Message::Close(_) => { break }
30             };
31
32             if socket.send(reply).await
33                 .is_err() {
34                     tracing::info!("websocket client disconnected");
35                     break;
36                 }
37         }
38         else {
39             tracing::info!("websocket client disconnected");
40             break;
41         }
42     }
43 }
44
45 fn app() -> Router {
46     Router::new()
47         .route("/hello/:name", get(hello))
48         .nest("/dir",
49                get_service(tower_http::services::ServeDir::new("../static").precompressed_gzip())
50                    .handle_error(|e: io::Error| async move {
51                        (StatusCode::INTERNAL_SERVER_ERROR, format!("Unhandled internal error: {}", e))
52                    })
53         )
54         .route("/ws-echo", get(|ws: ws::WebSocketUpgrade| async {ws.on_upgrade(ws_echo)}))
55         .layer(
56             tower::ServiceBuilder::new()
57                 .layer(
58                     tower_http::trace::TraceLayer::new_for_http()
59                          .make_span_with(tower_http::trace::DefaultMakeSpan::default().include_headers(true)),
60                 )
61         )
62 }
63
64 #[tokio::main]
65 async fn main() {
66     tracing_subscriber::fmt::init();
67
68     let addr = std::net::SocketAddr::from(([127, 0, 0, 1], 3030));
69     tracing::info!("listening on {}", addr);
70     axum::Server::bind(&addr)
71         .serve(app().into_make_service())
72         .await
73         .unwrap();
74 }
75
#[cfg(test)]
mod tests {
    use axum::{
        http::{Request, StatusCode},
        response::Response,
        body::Body
    };
    use tower::ServiceExt; // for `oneshot`

    /// Sends a single bodiless request for `uri` with the given extra
    /// headers to a fresh `app()` instance and returns the response.
    async fn fetch(uri: &str, headers: &[(&str, &str)]) -> Response {
        let mut builder = Request::builder().uri(uri);
        for &(name, value) in headers {
            builder = builder.header(name, value);
        }
        let request = builder.body(Body::empty()).unwrap();
        super::app().oneshot(request).await.unwrap()
    }

    /// Asserts that `res` is `200 OK` and that its body equals `expected_body`.
    async fn assert_res_ok_body(res: Response, expected_body: &[u8]) {
        assert_eq!(res.status(), StatusCode::OK);
        let actual_body = hyper::body::to_bytes(res.into_body()).await.unwrap();
        assert_eq!(actual_body, expected_body);
    }

    #[tokio::test]
    async fn test_hello() {
        // no user-agent
        let plain = fetch("/hello/rust", &[]).await;
        assert_res_ok_body(plain, b"Hello rust").await;

        // with user-agent
        let with_agent = fetch("/hello/rust", &[("user-agent", "TestBrowser 0.1")]).await;
        assert_res_ok_body(with_agent, b"Hello rust from TestBrowser 0.1").await;
    }

    #[tokio::test]
    async fn test_static_dir() {
        let plain = fetch("/dir/plain.txt", &[]).await;
        assert_res_ok_body(plain, b"Hello world! This is uncompressed text.\n").await;

        // transparent .gz lookup, without gzip transfer encoding
        let uncompressed = fetch("/dir/dir1/optzip.txt", &[("accept-encoding", "deflate")]).await;
        assert_eq!(uncompressed.status(), StatusCode::OK);
        // that returns the uncompressed file
        assert_res_ok_body(uncompressed, b"This file is available uncompressed or compressed\n\
                                  AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA\n").await;

        // transparent .gz lookup, with gzip transfer encoding
        let compressed = fetch("/dir/dir1/optzip.txt", &[("accept-encoding", "deflate, gzip")]).await;
        assert_eq!(compressed.status(), StatusCode::OK);
        let gz_bytes = hyper::body::to_bytes(compressed.into_body()).await.unwrap();
        // that returns the compressed file
        assert_eq!(gz_bytes.len(), 63); // file size of ../static/dir1/optzip.txt.gz
        assert_eq!(gz_bytes[0], 31);
    }
}