diff --git a/Cargo.lock b/Cargo.lock index 6cb32a21d0..f34ed4aea4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -222,6 +222,7 @@ dependencies = [ "rand", "redis", "regex", + "rmp-serde", "schemars", "sea-orm", "serde", @@ -2012,6 +2013,28 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "rmp" +version = "0.8.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "228ed7c16fa39782c3b3468e974aec2795e9089153cd08ee2e9aefb3613334c4" +dependencies = [ + "byteorder", + "num-traits", + "paste", +] + +[[package]] +name = "rmp-serde" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "938a142ab806f18b88a97b0dea523d39e0fd730a064b035726adcfc58a8a5188" +dependencies = [ + "byteorder", + "rmp", + "serde", +] + [[package]] name = "rsa" version = "0.9.6" diff --git a/Cargo.toml b/Cargo.toml index 0bde324447..2beb8e3fe5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -28,6 +28,7 @@ quote = "1.0.36" rand = "0.8.5" redis = "0.25.3" regex = "1.10.4" +rmp-serde = "1.2.0" schemars = "0.8.16" sea-orm = "0.12.15" serde = "1.0.197" diff --git a/packages/backend-rs/Cargo.toml b/packages/backend-rs/Cargo.toml index 93c57dc321..5a18dc92cf 100644 --- a/packages/backend-rs/Cargo.toml +++ b/packages/backend-rs/Cargo.toml @@ -32,6 +32,7 @@ parse-display = { workspace = true } rand = { workspace = true } redis = { workspace = true } regex = { workspace = true } +rmp-serde = { workspace = true } schemars = { workspace = true, features = ["chrono"] } sea-orm = { workspace = true, features = ["sqlx-postgres", "runtime-tokio-rustls"] } serde = { workspace = true, features = ["derive"] } diff --git a/packages/backend-rs/src/misc/mod.rs b/packages/backend-rs/src/misc/mod.rs index 56aae3b552..24cec14969 100644 --- a/packages/backend-rs/src/misc/mod.rs +++ b/packages/backend-rs/src/misc/mod.rs @@ -11,3 +11,4 @@ pub mod meta; pub mod nyaify; pub mod password; pub mod reaction; +pub mod redis_cache; diff --git a/packages/backend-rs/src/misc/redis_cache.rs b/packages/backend-rs/src/misc/redis_cache.rs new file mode 100644 index 0000000000..d4924bb646 --- /dev/null +++ b/packages/backend-rs/src/misc/redis_cache.rs @@ -0,0 +1,84 @@ +use crate::database::{redis_conn, redis_key}; +use redis::{Commands, RedisError}; +use serde::{Deserialize, Serialize}; + +#[derive(thiserror::Error, Debug)] +pub enum Error { + #[error("Redis error: {0}")] + RedisError(#[from] RedisError), + #[error("Data serialization error: {0}")] + SerializeError(#[from] rmp_serde::encode::Error), + #[error("Data deserialization error: {0}")] + DeserializeError(#[from] rmp_serde::decode::Error), +} + +pub fn set_cache Deserialize<'a> + Serialize>( + key: &str, + value: &V, + expire_seconds: u64, +) -> Result<(), Error> { + redis_conn()?.set_ex( + redis_key(key), + rmp_serde::encode::to_vec(&value)?, + expire_seconds, + )?; + Ok(()) +} + +pub fn get_cache Deserialize<'a> + Serialize>(key: &str) -> Result, Error> { + let serialized_value: Option> = redis_conn()?.get(redis_key(key))?; + Ok(match serialized_value { + Some(v) => Some(rmp_serde::from_slice::(v.as_ref())?), + None => None, + }) +} + +#[cfg(test)] +mod unit_test { + use super::{get_cache, set_cache}; + use pretty_assertions::assert_eq; + + #[test] + fn set_get_expire() { + #[derive(serde::Deserialize, serde::Serialize, PartialEq, Debug)] + struct Data { + id: u32, + kind: String, + } + + let key_1 = "CARGO_TEST_CACHE_KEY_1"; + let value_1: Vec = vec![1, 2, 3, 4, 5]; + + let key_2 = "CARGO_TEST_CACHE_KEY_2"; + let value_2 = "Hello fedizens".to_string(); + + let key_3 = "CARGO_TEST_CACHE_KEY_3"; + let value_3 = Data { + id: 1000000007, + kind: "prime number".to_string(), + }; + + set_cache(key_1, &value_1, 1).unwrap(); + set_cache(key_2, &value_2, 1).unwrap(); + set_cache(key_3, &value_3, 1).unwrap(); + + let cached_value_1: Vec = get_cache(key_1).unwrap().unwrap(); + let cached_value_2: String = get_cache(key_2).unwrap().unwrap(); + let cached_value_3: Data = get_cache(key_3).unwrap().unwrap(); + + assert_eq!(value_1, cached_value_1); + assert_eq!(value_2, cached_value_2); + assert_eq!(value_3, cached_value_3); + + // wait for the cache to expire + std::thread::sleep(std::time::Duration::from_millis(1100)); + + let expired_value_1: Option> = get_cache(key_1).unwrap(); + let expired_value_2: Option> = get_cache(key_2).unwrap(); + let expired_value_3: Option> = get_cache(key_3).unwrap(); + + assert!(expired_value_1.is_none()); + assert!(expired_value_2.is_none()); + assert!(expired_value_3.is_none()); + } +}