]> piware.de Git - learn-rust.git/commitdiff
axum-server: Add websocket route
authorMartin Pitt <martin@piware.de>
Sat, 12 Nov 2022 09:54:47 +0000 (10:54 +0100)
committerMartin Pitt <martin@piware.de>
Sat, 12 Nov 2022 09:54:47 +0000 (10:54 +0100)
axum-server/Cargo.toml
axum-server/src/main.rs

index 978ee0601206492e21b25c60af632a64f68ffd04..0444b3957dc6e24aeb62b2cb344a5a2bd58c771f 100644 (file)
@@ -6,7 +6,7 @@ edition = "2021"
 # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
 
 [dependencies]
-axum = { version = "0.5", features = ["headers"] }
+axum = { version = "0.5", features = ["ws", "headers"] }
 tokio = { version = "1", features = ["full"] }
 tower = "0.4"
 tower-http = { version = "0.3", features = ["trace", "fs"] }
index 3212d3f250962e2f422419951e1ce0fd0d355b13..31c0a4f619576870762aebe5ab0ca3a7619b169c 100644 (file)
@@ -2,7 +2,7 @@ use std::io;
 
 use axum::{
     routing::{get, get_service},
-    extract::{Path, TypedHeader},
+    extract::{Path, TypedHeader, ws},
     http::{StatusCode},
     response,
     Router};
@@ -15,6 +15,33 @@ async fn hello(Path(name): Path<String>, user_agent: Option<TypedHeader<axum::he
     }
 }
 
+async fn ws_echo(mut socket: ws::WebSocket) {
+    while let Some(msg) = socket.recv().await {
+        if let Ok(msg) = msg {
+            tracing::debug!("websocket got message: {:?}", msg);
+
+            let reply = match msg  {
+                ws::Message::Text(t) => ws::Message::Text(t),
+                ws::Message::Binary(b) => ws::Message::Binary(b),
+                // axum handles Ping/Pong by itself
+                ws::Message::Ping(_) => { continue },
+                ws::Message::Pong(_) => { continue },
+                ws::Message::Close(_) => { break }
+            };
+
+            if socket.send(reply).await
+                .is_err() {
+                    tracing::info!("websocket client disconnected");
+                    break;
+                }
+        }
+        else {
+            tracing::info!("websocket client disconnected");
+            break;
+        }
+    }
+}
+
 #[tokio::main]
 async fn main() {
     tracing_subscriber::fmt::init();
@@ -26,6 +53,7 @@ async fn main() {
                        (StatusCode::INTERNAL_SERVER_ERROR, format!("Unhandled internal error: {}", e))
                    })
         )
+        .route("/ws-echo", get(|ws: ws::WebSocketUpgrade| async {ws.on_upgrade(ws_echo)}))
         .layer(
             tower::ServiceBuilder::new()
                 .layer(