]> piware.de Git - learn-rust.git/blobdiff - actix-server/src/main.rs
actix-server: Add echo websocket route
[learn-rust.git] / actix-server / src / main.rs
index 4a43ef55b36d8b1f2ccf31f84d11979a1efd35fd..58756de28179a96d7d046bff156d2c5231f49662 100644 (file)
@@ -1,9 +1,11 @@
 use std::path::Path;
 
-use actix_web::{get, route, web, App, HttpRequest, HttpServer, Responder, Result};
+use actix::{Actor, ActorContext, StreamHandler};
+use actix_web::{get, route, web, App, Error, HttpRequest, HttpResponse, HttpServer, Responder, Result};
 use actix_web::http::header;
 use actix_web::middleware::Logger;
 use actix_files::{Files, NamedFile};
+use actix_web_actors::ws;
 
 #[route("/hello/{name}", method="GET", method="HEAD")]
 async fn hello(params: web::Path<String>, req: HttpRequest) -> Result<String> {
@@ -37,6 +39,33 @@ async fn static_file(params: web::Path<String>, req: HttpRequest) -> Result<impl
     Ok(NamedFile::open_async(disk_path).await?.customize())
 }
 
+struct WsEcho;
+
+impl Actor for WsEcho {
+    type Context = ws::WebsocketContext<Self>;
+}
+
+impl StreamHandler<Result<ws::Message, ws::ProtocolError>> for WsEcho {
+    fn handle(&mut self, msg: Result<ws::Message, ws::ProtocolError>, ctx: &mut Self::Context) {
+        log::info!("WsEcho got message {:?}", msg);
+        match msg {
+            Ok(ws::Message::Ping(msg)) => ctx.pong(&msg),
+            Ok(ws::Message::Text(text)) => ctx.text(text),
+            Ok(ws::Message::Binary(bin)) => ctx.binary(bin),
+            Ok(ws::Message::Close(reason)) => {
+                ctx.close(reason);
+                ctx.stop();
+            },
+            _ => ctx.stop(),
+        }
+    }
+}
+
+#[get("/ws-echo")]
+async fn ws_echo(req: HttpRequest, stream: web::Payload) -> Result<HttpResponse, Error> {
+    ws::start(WsEcho {}, &req, stream)
+}
+
 #[actix_web::main]
 async fn main() -> std::io::Result<()> {
     env_logger::init_from_env(env_logger::Env::default().default_filter_or("info"));
@@ -46,6 +75,7 @@ async fn main() -> std::io::Result<()> {
             .service(hello)
             .service(static_file)
             .service(Files::new("/dir", "../static"))
+            .service(ws_echo)
             .wrap(Logger::default())
     })
         .bind(("127.0.0.1", 3030))?
@@ -57,8 +87,12 @@ async fn main() -> std::io::Result<()> {
 mod tests {
     use actix_web::{App, body, test, web};
     use actix_web::http::{header, StatusCode};
+    use actix_web_actors::ws;
 
-    use super::{hello, static_file};
+    use futures_util::sink::SinkExt;
+    use futures_util::StreamExt;
+
+    use super::{hello, static_file, ws_echo};
 
     #[actix_web::test]
     async fn test_hello() {
@@ -132,4 +166,21 @@ mod tests {
         assert_eq!(res_bytes.len(), 63); // file size of ../static/dir1/optzip.txt.gz
         assert_eq!(res_bytes[0], 31);
     }
+
+    #[actix_web::test]
+    async fn test_ws_echo() {
+        // FIXME: duplicating the .service() call from main() here is super ugly, but it's hard to move that into a fn
+        let mut srv = actix_test::start(|| App::new().service(ws_echo));
+        let mut client = srv.ws_at("/ws-echo").await.unwrap();
+
+        // text echo
+        client.send(ws::Message::Text("hello".into())).await.unwrap();
+        let received = client.next().await.unwrap().unwrap();
+        assert_eq!(received, ws::Frame::Text("hello".into()));
+
+        // binary echo
+        client.send(ws::Message::Binary(web::Bytes::from_static(&[42, 99]))).await.unwrap();
+        let received = client.next().await.unwrap().unwrap();
+        assert_eq!(received, ws::Frame::Binary(web::Bytes::from_static(&[42, 99])));
+    }
 }