From c5233c520ab09608cabb8492b5da2292589f23b8 Mon Sep 17 00:00:00 2001 From: Martin Pitt Date: Fri, 16 Sep 2022 14:02:30 +0200 Subject: [PATCH] tokio-tutorial-mini-redis: Shared global state --- tokio-tutorial-mini-redis/Cargo.toml | 1 + tokio-tutorial-mini-redis/src/main.rs | 20 +++++++++++--------- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/tokio-tutorial-mini-redis/Cargo.toml b/tokio-tutorial-mini-redis/Cargo.toml index eae82db..4a82df3 100644 --- a/tokio-tutorial-mini-redis/Cargo.toml +++ b/tokio-tutorial-mini-redis/Cargo.toml @@ -6,5 +6,6 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] +bytes = "1" tokio = { version = "1", features = ["full"] } mini-redis = "0.4" diff --git a/tokio-tutorial-mini-redis/src/main.rs b/tokio-tutorial-mini-redis/src/main.rs index d255ab3..2488f03 100644 --- a/tokio-tutorial-mini-redis/src/main.rs +++ b/tokio-tutorial-mini-redis/src/main.rs @@ -1,37 +1,39 @@ use std::collections::HashMap; +use std::sync::{Arc, Mutex}; +use bytes::Bytes; use mini_redis::{Connection, Frame}; use mini_redis::Command::{self, Get, Set}; use tokio::net::{TcpListener, TcpStream}; +type Db = Arc>>; + #[tokio::main] async fn main() { let listener = TcpListener::bind("127.0.0.1:6379").await.unwrap(); + let db: Db = Arc::new(Mutex::new(HashMap::new())); loop { // The second item contains the IP and port of the new connection let (socket, _) = listener.accept().await.unwrap(); - tokio::spawn(async move { process(socket).await }); + let db_i = db.clone(); + tokio::spawn(async move { process(socket, db_i).await }); } } -async fn process(socket: TcpStream) { - let mut db = HashMap::new(); +async fn process(socket: TcpStream, db: Db) { let mut connection = Connection::new(socket); while let Some(frame) = connection.read_frame().await.unwrap() { let response = match Command::from_frame(frame).unwrap() { Set(cmd) => { // The value is stored as `Vec` - db.insert(cmd.key().to_string(), cmd.value().to_vec()); + db.lock().unwrap().insert(cmd.key().to_string(), cmd.value().clone()); Frame::Simple("OK".to_string()) } Get(cmd) => { - if let Some(value) = db.get(cmd.key()) { - // `Frame::Bulk` expects data to be of type `Bytes`. This - // type will be covered later in the tutorial. For now, - // `&Vec` is converted to `Bytes` using `into()`. - Frame::Bulk(value.clone().into()) + if let Some(value) = db.lock().unwrap().get(cmd.key()) { + Frame::Bulk(value.clone()) } else { Frame::Null } -- 2.39.5