]> piware.de Git - learn-rust.git/blob - tokio-tutorial-mini-redis/src/bin/server.rs
tokio-tutorial-mini-redis: Add proper error handling
[learn-rust.git] / tokio-tutorial-mini-redis / src / bin / server.rs
1 use std::collections::HashMap;
2 use std::error::Error;
3 use std::sync::{Arc, Mutex};
4
5 use bytes::Bytes;
6 use mini_redis::{Connection, Frame};
7 use mini_redis::Command::{self, Get, Set};
8 use tokio::net::{TcpListener, TcpStream};
9
10 type Db = Arc<Mutex<HashMap<String, Bytes>>>;
11
12 const LISTEN: &str = "127.0.0.1:6379";
13
14 #[tokio::main]
15 async fn main() {
16     env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("debug")).init();
17
18     let listener = TcpListener::bind(LISTEN).await.unwrap();
19     log::info!("Listening on {}", LISTEN);
20     let db: Db = Arc::new(Mutex::new(HashMap::new()));
21
22     loop {
23         match listener.accept().await {
24             Ok((socket, addr)) => {
25                 log::debug!("got connection from {:?}", addr);
26                 let db_i = db.clone();
27                 tokio::spawn(async move {
28                     if let Err(e) = process(socket, db_i).await {
29                         log::warn!("failed: {:?}", e);
30                     }
31                 });
32             },
33             Err(e) => log::warn!("Failed to accept connection: {}", e),
34         };
35     }
36 }
37
38 async fn process(socket: TcpStream, db: Db) -> Result<(), Box<dyn Error + Send + Sync>> {
39     let mut connection = Connection::new(socket);
40
41     while let Some(frame) = connection.read_frame().await? {
42         let response = match Command::from_frame(frame)? {
43             Set(cmd) => {
44                 // The value is stored as `Vec<u8>`
45                 db.lock().unwrap().insert(cmd.key().to_string(), cmd.value().clone());
46                 log::debug!("Set {} → {:?}", &cmd.key(), &cmd.value());
47                 Frame::Simple("OK".to_string())
48             }
49             Get(cmd) => {
50                 if let Some(value) = db.lock().unwrap().get(cmd.key()) {
51                     log::debug!("Get {} → {:?}", &cmd.key(), &value);
52                     Frame::Bulk(value.clone())
53                 } else {
54                     log::debug!("Get {} unknown key", &cmd.key());
55                     Frame::Null
56                 }
57             }
58             cmd => panic!("unimplemented {:?}", cmd),
59         };
60
61         // Write the response to the client
62         connection.write_frame(&response).await?;
63     }
64     Ok(())
65 }