X-Git-Url: https://piware.de/gitweb/?p=learn-rust.git;a=blobdiff_plain;f=axum-server%2Fsrc%2Fmain.rs;h=31c0a4f619576870762aebe5ab0ca3a7619b169c;hp=aa2a26337a2551d44d1a77d3cc598214d97fffa0;hb=ac5838f9b78c6894751cfacaf2f238e09a4a7f51;hpb=d154738629ef14b2141a01aac34a7b39c263eed1 diff --git a/axum-server/src/main.rs b/axum-server/src/main.rs index aa2a263..31c0a4f 100644 --- a/axum-server/src/main.rs +++ b/axum-server/src/main.rs @@ -1,12 +1,45 @@ +use std::io; + use axum::{ - routing::{get}, - extract::Path, - http, + routing::{get, get_service}, + extract::{Path, TypedHeader, ws}, + http::{StatusCode}, response, Router}; -async fn hello(Path(name): Path) -> impl response::IntoResponse { - (http::StatusCode::OK, format!("Hello {}", name)) +async fn hello(Path(name): Path, user_agent: Option>) -> impl response::IntoResponse { + if let Some(TypedHeader(user_agent)) = user_agent { + (StatusCode::OK, format!("Hello {} from {}", name, user_agent)) + } else { + (StatusCode::OK, format!("Hello {}", name)) + } +} + +async fn ws_echo(mut socket: ws::WebSocket) { + while let Some(msg) = socket.recv().await { + if let Ok(msg) = msg { + tracing::debug!("websocket got message: {:?}", msg); + + let reply = match msg { + ws::Message::Text(t) => ws::Message::Text(t), + ws::Message::Binary(b) => ws::Message::Binary(b), + // axum handles Ping/Pong by itself + ws::Message::Ping(_) => { continue }, + ws::Message::Pong(_) => { continue }, + ws::Message::Close(_) => { break } + }; + + if socket.send(reply).await + .is_err() { + tracing::info!("websocket client disconnected"); + break; + } + } + else { + tracing::info!("websocket client disconnected"); + break; + } + } } #[tokio::main] @@ -14,9 +47,19 @@ async fn main() { tracing_subscriber::fmt::init(); let app = Router::new() .route("/hello/:name", get(hello)) + .route("/static", + get_service(tower_http::services::ServeFile::new("Cargo.toml").precompressed_gzip()) + .handle_error(|e: io::Error| async move { + (StatusCode::INTERNAL_SERVER_ERROR, format!("Unhandled internal error: {}", e)) + }) + ) + .route("/ws-echo", get(|ws: ws::WebSocketUpgrade| async {ws.on_upgrade(ws_echo)})) .layer( tower::ServiceBuilder::new() - .layer(tower_http::trace::TraceLayer::new_for_http()) + .layer( + tower_http::trace::TraceLayer::new_for_http() + .make_span_with(tower_http::trace::DefaultMakeSpan::default().include_headers(true)), + ) ); let addr = std::net::SocketAddr::from(([127, 0, 0, 1], 3000));