]> piware.de Git - learn-rust.git/blobdiff - axum-server/src/main.rs
axum-server: Add websocket route
[learn-rust.git] / axum-server / src / main.rs
index aa2a26337a2551d44d1a77d3cc598214d97fffa0..31c0a4f619576870762aebe5ab0ca3a7619b169c 100644 (file)
@@ -1,12 +1,45 @@
+use std::io;
+
 use axum::{
-    routing::{get},
-    extract::Path,
-    http,
+    routing::{get, get_service},
+    extract::{Path, TypedHeader, ws},
+    http::{StatusCode},
     response,
     Router};
 
-async fn hello(Path(name): Path<String>) -> impl response::IntoResponse {
-    (http::StatusCode::OK, format!("Hello {}", name))
+async fn hello(Path(name): Path<String>, user_agent: Option<TypedHeader<axum::headers::UserAgent>>) -> impl response::IntoResponse {
+    if let Some(TypedHeader(user_agent)) = user_agent {
+        (StatusCode::OK, format!("Hello {} from {}", name, user_agent))
+    } else {
+        (StatusCode::OK, format!("Hello {}", name))
+    }
+}
+
+async fn ws_echo(mut socket: ws::WebSocket) {
+    while let Some(msg) = socket.recv().await {
+        if let Ok(msg) = msg {
+            tracing::debug!("websocket got message: {:?}", msg);
+
+            let reply = match msg  {
+                ws::Message::Text(t) => ws::Message::Text(t),
+                ws::Message::Binary(b) => ws::Message::Binary(b),
+                // axum handles Ping/Pong by itself
+                ws::Message::Ping(_) => { continue },
+                ws::Message::Pong(_) => { continue },
+                ws::Message::Close(_) => { break }
+            };
+
+            if socket.send(reply).await
+                .is_err() {
+                    tracing::info!("websocket client disconnected");
+                    break;
+                }
+        }
+        else {
+            tracing::info!("websocket client disconnected");
+            break;
+        }
+    }
 }
 
 #[tokio::main]
@@ -14,9 +47,19 @@ async fn main() {
     tracing_subscriber::fmt::init();
     let app = Router::new()
         .route("/hello/:name", get(hello))
+        .route("/static",
+               get_service(tower_http::services::ServeFile::new("Cargo.toml").precompressed_gzip())
+                   .handle_error(|e: io::Error| async move {
+                       (StatusCode::INTERNAL_SERVER_ERROR, format!("Unhandled internal error: {}", e))
+                   })
+        )
+        .route("/ws-echo", get(|ws: ws::WebSocketUpgrade| async {ws.on_upgrade(ws_echo)}))
         .layer(
             tower::ServiceBuilder::new()
-                .layer(tower_http::trace::TraceLayer::new_for_http())
+                .layer(
+                    tower_http::trace::TraceLayer::new_for_http()
+                         .make_span_with(tower_http::trace::DefaultMakeSpan::default().include_headers(true)),
+                )
         );
 
     let addr = std::net::SocketAddr::from(([127, 0, 0, 1], 3000));