]> piware.de Git - learn-rust.git/blob - axum-server/src/main.rs
31c0a4f619576870762aebe5ab0ca3a7619b169c
[learn-rust.git] / axum-server / src / main.rs
1 use std::io;
2
3 use axum::{
4     routing::{get, get_service},
5     extract::{Path, TypedHeader, ws},
6     http::{StatusCode},
7     response,
8     Router};
9
10 async fn hello(Path(name): Path<String>, user_agent: Option<TypedHeader<axum::headers::UserAgent>>) -> impl response::IntoResponse {
11     if let Some(TypedHeader(user_agent)) = user_agent {
12         (StatusCode::OK, format!("Hello {} from {}", name, user_agent))
13     } else {
14         (StatusCode::OK, format!("Hello {}", name))
15     }
16 }
17
18 async fn ws_echo(mut socket: ws::WebSocket) {
19     while let Some(msg) = socket.recv().await {
20         if let Ok(msg) = msg {
21             tracing::debug!("websocket got message: {:?}", msg);
22
23             let reply = match msg  {
24                 ws::Message::Text(t) => ws::Message::Text(t),
25                 ws::Message::Binary(b) => ws::Message::Binary(b),
26                 // axum handles Ping/Pong by itself
27                 ws::Message::Ping(_) => { continue },
28                 ws::Message::Pong(_) => { continue },
29                 ws::Message::Close(_) => { break }
30             };
31
32             if socket.send(reply).await
33                 .is_err() {
34                     tracing::info!("websocket client disconnected");
35                     break;
36                 }
37         }
38         else {
39             tracing::info!("websocket client disconnected");
40             break;
41         }
42     }
43 }
44
45 #[tokio::main]
46 async fn main() {
47     tracing_subscriber::fmt::init();
48     let app = Router::new()
49         .route("/hello/:name", get(hello))
50         .route("/static",
51                get_service(tower_http::services::ServeFile::new("Cargo.toml").precompressed_gzip())
52                    .handle_error(|e: io::Error| async move {
53                        (StatusCode::INTERNAL_SERVER_ERROR, format!("Unhandled internal error: {}", e))
54                    })
55         )
56         .route("/ws-echo", get(|ws: ws::WebSocketUpgrade| async {ws.on_upgrade(ws_echo)}))
57         .layer(
58             tower::ServiceBuilder::new()
59                 .layer(
60                     tower_http::trace::TraceLayer::new_for_http()
61                          .make_span_with(tower_http::trace::DefaultMakeSpan::default().include_headers(true)),
62                 )
63         );
64
65     let addr = std::net::SocketAddr::from(([127, 0, 0, 1], 3000));
66     tracing::info!("listening on {}", addr);
67     axum::Server::bind(&addr)
68         .serve(app.into_make_service())
69         .await
70         .unwrap();
71 }