]> piware.de Git - learn-rust.git/commitdiff
tokio-tutorial-mini-redis: Shared global state
authorMartin Pitt <martin@piware.de>
Fri, 16 Sep 2022 12:02:30 +0000 (14:02 +0200)
committerMartin Pitt <martin@piware.de>
Fri, 16 Sep 2022 12:02:30 +0000 (14:02 +0200)
tokio-tutorial-mini-redis/Cargo.toml
tokio-tutorial-mini-redis/src/main.rs

index eae82dba457479e7fe9cb21dbf56eece912b457f..4a82df3fbcaa6e4285e78f9b8aae8eea3ad85968 100644 (file)
@@ -6,5 +6,6 @@ edition = "2021"
 # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
 
 [dependencies]
+bytes = "1"
 tokio = { version = "1", features = ["full"] }
 mini-redis = "0.4"
index d255ab36c270eb741ebc9f22c18d0a02aff6448a..2488f033ead6e2808c08fe6f4499e3b0818f9d15 100644 (file)
@@ -1,37 +1,39 @@
 use std::collections::HashMap;
+use std::sync::{Arc, Mutex};
 
+use bytes::Bytes;
 use mini_redis::{Connection, Frame};
 use mini_redis::Command::{self, Get, Set};
 use tokio::net::{TcpListener, TcpStream};
 
+type Db = Arc<Mutex<HashMap<String, Bytes>>>;
+
 #[tokio::main]
 async fn main() {
     let listener = TcpListener::bind("127.0.0.1:6379").await.unwrap();
+    let db: Db = Arc::new(Mutex::new(HashMap::new()));
 
     loop {
         // The second item contains the IP and port of the new connection
         let (socket, _) = listener.accept().await.unwrap();
-        tokio::spawn(async move { process(socket).await });
+        let db_i = db.clone();
+        tokio::spawn(async move { process(socket, db_i).await });
     }
 }
 
-async fn process(socket: TcpStream) {
-    let mut db = HashMap::new();
+async fn process(socket: TcpStream, db: Db) {
     let mut connection = Connection::new(socket);
 
     while let Some(frame) = connection.read_frame().await.unwrap() {
         let response = match Command::from_frame(frame).unwrap() {
             Set(cmd) => {
                 // The value is stored as `Vec<u8>`
-                db.insert(cmd.key().to_string(), cmd.value().to_vec());
+                db.lock().unwrap().insert(cmd.key().to_string(), cmd.value().clone());
                 Frame::Simple("OK".to_string())
             }
             Get(cmd) => {
-                if let Some(value) = db.get(cmd.key()) {
-                    // `Frame::Bulk` expects data to be of type `Bytes`. This
-                    // type will be covered later in the tutorial. For now,
-                    // `&Vec<u8>` is converted to `Bytes` using `into()`.
-                    Frame::Bulk(value.clone().into())
+                if let Some(value) = db.lock().unwrap().get(cmd.key()) {
+                    Frame::Bulk(value.clone())
                 } else {
                     Frame::Null
                 }