]> piware.de Git - learn-rust.git/commitdiff
Closures and Cacher object
authorMartin Pitt <martin@piware.de>
Wed, 25 Aug 2021 06:25:30 +0000 (08:25 +0200)
committerMartin Pitt <martin@piware.de>
Wed, 25 Aug 2021 06:25:30 +0000 (08:25 +0200)
src/lib.rs
src/main.rs
tests/test_lib.rs

index b9ee526f5eeabddfa09e377c639fa41c88ef764c..fe3b71b286640b29919d465a9fd25f55a002b42b 100644 (file)
@@ -1,5 +1,6 @@
 use std::fs::File;
 use std::io::prelude::*;
+use std::collections::HashMap;
 
 pub fn read_file(path: &str) -> Result<String, std::io::Error> {
     let mut s = String::new();
@@ -48,3 +49,39 @@ pub fn longest<'a>(x: &'a str, y: &'a str) -> &'a str {
         y
     }
 }
+
+/// Wrap and cache an expensive calculation
+///
+/// This calls a closure just once for every distinct argument. Any subsequent
+/// call to `.value()` with the same argument uses the cached value.
+pub struct Cacher<T, A, V>
+where
+    T: Fn(A) -> V,
+    A: Eq + Copy + std::hash::Hash,
+    V: Copy,
+{
+    calc: T,
+    values: HashMap<A, V>,
+}
+
+impl<T, A, V> Cacher<T, A, V>
+where
+    T: Fn(A) -> V,
+    A: Eq + Copy + std::hash::Hash,
+    V: Copy,
+{
+    pub fn new(calc: T) -> Cacher<T, A, V> {
+        Cacher { calc, values: HashMap::new() }
+    }
+
+    pub fn value(&mut self, arg: A) -> V {
+        match self.values.get(&arg) {
+            Some(v) => *v,
+            None => {
+                let v = (self.calc)(arg);
+                self.values.insert(arg, v);
+                v
+            }
+        }
+    }
+}
index 718ecb5f22028762b437c4b87a60dac13f99c90c..0ef1b16f4f9b19c7e9e1c71377abd4273acc0d1b 100644 (file)
@@ -137,10 +137,31 @@ fn test_generics() {
     println!("longest string: {}", l);
 }
 
+fn test_closures() {
+    let mut expensive_int_result = Cacher::new(|x| {
+        println!("calculating expensive int result for {}", x);
+        2 * x
+    });
+
+    println!("1st int call for value 1: {}", expensive_int_result.value(1));
+    println!("2nd int call for value 1: {}", expensive_int_result.value(1));
+    println!("1st int call for value 2: {}", expensive_int_result.value(2));
+
+    let mut expensive_str_result = Cacher::new(|x: &str| {
+        println!("calculating expensive str result for {}", x);
+        x.len()
+    });
+
+    println!("1st int call for value abc: {}", expensive_str_result.value("abc"));
+    println!("2nd int call for value abc: {}", expensive_str_result.value("abc"));
+    println!("1st int call for value defg: {}", expensive_str_result.value("defg"));
+}
+
 fn main() {
     test_strings();
     test_vectors();
     test_hashmaps();
     test_files();
     test_generics();
+    test_closures();
 }
index 44d7f87cee54405631e9412ced6542eec01eaed6..186ffd7dcf3756d8262f9aefdd6bfb0ff192525b 100644 (file)
@@ -6,3 +6,46 @@ fn test_longest() {
     assert_eq!(longest("abc", "def"), "def");
     assert_eq!(longest("abc", "defg"), "defg");
 }
+
+// FIXME: How to make this not unsafe?
+static mut CALLED: u32 = 0;
+
+#[test]
+fn test_cacher_int_int() {
+    unsafe { CALLED = 0; }
+    let mut cacher = Cacher::new(|x| {
+        unsafe { CALLED += 1; }
+        2 * x
+    });
+    assert_eq!(cacher.value(1), 2);
+    unsafe { assert_eq!(CALLED, 1); }
+    // second time cached
+    assert_eq!(cacher.value(1), 2);
+    unsafe { assert_eq!(CALLED, 1); }
+    // re-evaluated for new value
+    assert_eq!(cacher.value(-2), -4);
+    unsafe { assert_eq!(CALLED, 2); }
+    // old arg still cached
+    assert_eq!(cacher.value(1), 2);
+    unsafe { assert_eq!(CALLED, 2); }
+}
+
+#[test]
+fn test_cacher_str_usize() {
+    unsafe { CALLED = 0; }
+    let mut cacher = Cacher::new(|x: &str| {
+        unsafe { CALLED += 1; }
+        x.len()
+    });
+    assert_eq!(cacher.value("abc"), 3);
+    unsafe { assert_eq!(CALLED, 1); }
+    // second time cached
+    assert_eq!(cacher.value("abc"), 3);
+    unsafe { assert_eq!(CALLED, 1); }
+    // re-evaluated for new value
+    assert_eq!(cacher.value("defg"), 4);
+    unsafe { assert_eq!(CALLED, 2); }
+    // old arg still cached
+    assert_eq!(cacher.value("abc"), 3);
+    unsafe { assert_eq!(CALLED, 2); }
+}