]> piware.de Git - learn-rust.git/blob - async-http/src/main.rs
async-http: Unit test for handle_connection()
[learn-rust.git] / async-http / src / main.rs
1 use std::fs;
2 use std::time::Duration;
3
4 use async_std::prelude::*;
5 use async_std::io::{ Read, Write };
6 use async_std::net::{ TcpListener };
7 use async_std::task;
8 use futures::stream::StreamExt;
9
10 #[async_std::main]
11 async fn main() {
12     env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init();
13
14     // Listen for incoming TCP connections on localhost port 7878
15     let listener = TcpListener::bind("127.0.0.1:7878").await.unwrap();
16
17     listener.incoming().for_each_concurrent(/* limit */ None, |tcpstream| async move {
18         let tcpstream = tcpstream.unwrap();
19         task::spawn(handle_connection(tcpstream));
20     }).await;
21 }
22
23 async fn handle_connection(mut stream: impl Read + Write + Unpin) {
24     // Read the first 1024 bytes of data from the stream
25     let mut buffer = [0; 1024];
26     assert!(stream.read(&mut buffer).await.unwrap() > 0);
27
28     // Respond with greetings or a 404,
29     // depending on the data in the request
30     let (status_line, filename) = if buffer.starts_with(b"GET / HTTP/1.1\r\n") {
31         ("HTTP/1.1 200 OK", "index.html")
32     } else if buffer.starts_with(b"GET /sleep HTTP/1.1\r\n") {
33         task::sleep(Duration::from_secs(5)).await;
34         // sync version, to demonstrate concurrent async vs. parallel threads
35         // std::thread::sleep(Duration::from_secs(5));
36         ("HTTP/1.1 201 Sleep", "index.html")
37     } else {
38         ("HTTP/1.1 404 NOT FOUND", "404.html")
39     };
40     let contents = fs::read_to_string(filename).unwrap();
41     log::info!("GET {} {}", filename, status_line);
42
43     // Write response back to the stream,
44     // and flush the stream to ensure the response is sent back to the client
45     let response = format!("{status_line}\r\n\r\n{contents}");
46     stream.write_all(response.as_bytes()).await.unwrap();
47     stream.flush().await.unwrap();
48 }
49
50 #[cfg(test)]
51 mod tests {
52     use super::*;
53
54     use std::cmp;
55     use std::pin::Pin;
56
57     use futures::io::Error;
58     use futures::task::{Context, Poll};
59
60     struct MockTcpStream {
61         read_data: Vec<u8>,
62         write_data: Vec<u8>,
63     }
64
65     impl Read for MockTcpStream {
66         fn poll_read(
67             self: Pin<&mut Self>,
68             _: &mut Context,
69             buf: &mut [u8],
70         ) -> Poll<Result<usize, Error>> {
71             let size: usize = cmp::min(self.read_data.len(), buf.len());
72             buf[..size].copy_from_slice(&self.read_data[..size]);
73             Poll::Ready(Ok(size))
74         }
75     }
76
77     impl Write for MockTcpStream {
78         fn poll_write(
79             mut self: Pin<&mut Self>,
80             _: &mut Context,
81             buf: &[u8],
82         ) -> Poll<Result<usize, Error>> {
83             self.write_data = Vec::from(buf);
84
85             Poll::Ready(Ok(buf.len()))
86         }
87
88         fn poll_flush(self: Pin<&mut Self>, _: &mut Context) -> Poll<Result<(), Error>> {
89             Poll::Ready(Ok(()))
90         }
91
92         fn poll_close(self: Pin<&mut Self>, _: &mut Context) -> Poll<Result<(), Error>> {
93             Poll::Ready(Ok(()))
94         }
95     }
96
97     impl Unpin for MockTcpStream {}
98
99     #[async_std::test]
100     async fn test_handle_connection() {
101         let input_bytes = b"GET / HTTP/1.1\r\n";
102         let mut contents = vec![0u8; 1024];
103         contents[..input_bytes.len()].clone_from_slice(input_bytes);
104         let mut stream = MockTcpStream {
105             read_data: contents,
106             write_data: Vec::new(),
107         };
108
109         handle_connection(&mut stream).await;
110
111         let expected_contents = fs::read_to_string("index.html").unwrap();
112         let expected_response = format!("HTTP/1.1 200 OK\r\n\r\n{}", expected_contents);
113         assert!(stream.write_data.starts_with(expected_response.as_bytes()));
114     }
115 }