]> piware.de Git - learn-rust.git/blobdiff - axum-server/src/main.rs
concepts: rustfmt
[learn-rust.git] / axum-server / src / main.rs
index 1225f67614184b5b34cc0bb3a9703bb19d225433..fa44ac1695c12888da8f52dfa8f3ee38d7f16f88 100644 (file)
@@ -2,35 +2,144 @@ use std::io;
 
 use axum::{
     routing::{get, get_service},
-    extract::Path,
+    extract::{Path, TypedHeader, ws},
     http::{StatusCode},
     response,
     Router};
 
-async fn hello(Path(name): Path<String>) -> impl response::IntoResponse {
-    (StatusCode::OK, format!("Hello {}", name))
+async fn hello(Path(name): Path<String>, user_agent: Option<TypedHeader<axum::headers::UserAgent>>) -> impl response::IntoResponse {
+    if let Some(TypedHeader(user_agent)) = user_agent {
+        (StatusCode::OK, format!("Hello {} from {}", name, user_agent))
+    } else {
+        (StatusCode::OK, format!("Hello {}", name))
+    }
 }
 
-#[tokio::main]
-async fn main() {
-    tracing_subscriber::fmt::init();
-    let app = Router::new()
+async fn ws_echo(mut socket: ws::WebSocket) {
+    while let Some(msg) = socket.recv().await {
+        if let Ok(msg) = msg {
+            tracing::debug!("websocket got message: {:?}", msg);
+
+            let reply = match msg  {
+                ws::Message::Text(t) => ws::Message::Text(t),
+                ws::Message::Binary(b) => ws::Message::Binary(b),
+                // axum handles Ping/Pong by itself
+                ws::Message::Ping(_) => { continue },
+                ws::Message::Pong(_) => { continue },
+                ws::Message::Close(_) => { break }
+            };
+
+            if socket.send(reply).await
+                .is_err() {
+                    tracing::info!("websocket client disconnected");
+                    break;
+                }
+        }
+        else {
+            tracing::info!("websocket client disconnected");
+            break;
+        }
+    }
+}
+
+fn app() -> Router {
+    Router::new()
         .route("/hello/:name", get(hello))
-        .route("/static",
-               get_service(tower_http::services::ServeFile::new("Cargo.toml").precompressed_gzip())
+        .nest("/dir",
+               get_service(tower_http::services::ServeDir::new("../static").precompressed_gzip())
                    .handle_error(|e: io::Error| async move {
                        (StatusCode::INTERNAL_SERVER_ERROR, format!("Unhandled internal error: {}", e))
                    })
         )
+        .route("/ws-echo", get(|ws: ws::WebSocketUpgrade| async {ws.on_upgrade(ws_echo)}))
         .layer(
             tower::ServiceBuilder::new()
-                .layer(tower_http::trace::TraceLayer::new_for_http())
-        );
+                .layer(
+                    tower_http::trace::TraceLayer::new_for_http()
+                         .make_span_with(tower_http::trace::DefaultMakeSpan::default().include_headers(true)),
+                )
+        )
+}
 
-    let addr = std::net::SocketAddr::from(([127, 0, 0, 1], 3000));
+#[tokio::main]
+async fn main() {
+    tracing_subscriber::fmt::init();
+
+    let addr = std::net::SocketAddr::from(([127, 0, 0, 1], 3030));
     tracing::info!("listening on {}", addr);
     axum::Server::bind(&addr)
-        .serve(app.into_make_service())
+        .serve(app().into_make_service())
         .await
         .unwrap();
 }
+
+#[cfg(test)]
+mod tests {
+    use axum::{
+        http::{Request, StatusCode},
+        response::Response,
+        body::Body
+    };
+    use tower::ServiceExt; // for `oneshot`
+
+    async fn assert_res_ok_body(res: Response, expected_body: &[u8]) {
+        assert_eq!(res.status(), StatusCode::OK);
+        assert_eq!(hyper::body::to_bytes(res.into_body()).await.unwrap(), expected_body);
+    }
+
+    #[tokio::test]
+    async fn test_hello() {
+        // no user-agent
+        let res = super::app()
+            .oneshot(Request::builder().uri("/hello/rust").body(Body::empty()).unwrap())
+            .await
+            .unwrap();
+        assert_res_ok_body(res, b"Hello rust").await;
+
+        // with user-agent
+        let res = super::app()
+            .oneshot(Request::builder()
+                        .uri("/hello/rust")
+                        .header("user-agent", "TestBrowser 0.1")
+                        .body(Body::empty()).unwrap())
+            .await
+            .unwrap();
+        assert_res_ok_body(res, b"Hello rust from TestBrowser 0.1").await;
+    }
+
+    #[tokio::test]
+    async fn test_static_dir() {
+        let res = super::app()
+            .oneshot(Request::builder().uri("/dir/plain.txt").body(Body::empty()).unwrap())
+            .await
+            .unwrap();
+        assert_res_ok_body(res, b"Hello world! This is uncompressed text.\n").await;
+
+        // transparent .gz lookup, without gzip transfer encoding
+        let res = super::app()
+            .oneshot(Request::builder()
+                        .uri("/dir/dir1/optzip.txt")
+                        .header("accept-encoding", "deflate")
+                        .body(Body::empty()).unwrap())
+            .await
+            .unwrap();
+        assert_eq!(res.status(), StatusCode::OK);
+        // that returns the uncompressed file
+        assert_res_ok_body(res, b"This file is available uncompressed or compressed\n\
+                                  AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA\n").await;
+
+        // transparent .gz lookup, with gzip transfer encoding
+        let res = super::app()
+            .oneshot(Request::builder()
+                        .uri("/dir/dir1/optzip.txt")
+                        .header("accept-encoding", "deflate, gzip")
+                        .body(Body::empty()).unwrap())
+            .await
+            .unwrap();
+        assert_eq!(res.status(), StatusCode::OK);
+        let res_bytes: &[u8] = &hyper::body::to_bytes(res.into_body()).await.unwrap();
+        // that returns the compressed file
+        assert_eq!(res_bytes.len(), 63); // file size of ../static/dir1/optzip.txt.gz
+        assert_eq!(res_bytes[0], 31);
+    }
+}