openzeppelin_relayer/repositories/transaction_counter/
transaction_counter_in_memory.rs1use async_trait::async_trait;
8use dashmap::DashMap;
9
10use crate::repositories::{RepositoryError, TransactionCounterTrait};
11
12#[derive(Debug, Default, Clone)]
13pub struct InMemoryTransactionCounter {
14 store: DashMap<(String, String), u64>, }
16
17impl InMemoryTransactionCounter {
18 pub fn new() -> Self {
19 Self {
20 store: DashMap::new(),
21 }
22 }
23}
24
25#[async_trait]
26impl TransactionCounterTrait for InMemoryTransactionCounter {
27 async fn get(&self, relayer_id: &str, address: &str) -> Result<Option<u64>, RepositoryError> {
28 Ok(self
29 .store
30 .get(&(relayer_id.to_string(), address.to_string()))
31 .map(|n| *n))
32 }
33
34 async fn get_and_increment(
35 &self,
36 relayer_id: &str,
37 address: &str,
38 ) -> Result<u64, RepositoryError> {
39 let mut entry = self
40 .store
41 .entry((relayer_id.to_string(), address.to_string()))
42 .or_insert(0);
43 let current = *entry;
44 *entry += 1;
45 Ok(current)
46 }
47
48 async fn decrement(&self, relayer_id: &str, address: &str) -> Result<u64, RepositoryError> {
49 let mut entry = self
50 .store
51 .get_mut(&(relayer_id.to_string(), address.to_string()))
52 .ok_or_else(|| RepositoryError::NotFound(format!("Counter not found for {address}")))?;
53 if *entry > 0 {
54 *entry -= 1;
55 }
56 Ok(*entry)
57 }
58
59 async fn set(
60 &self,
61 relayer_id: &str,
62 address: &str,
63 value: u64,
64 ) -> Result<(), RepositoryError> {
65 self.store
66 .insert((relayer_id.to_string(), address.to_string()), value);
67 Ok(())
68 }
69
70 async fn drop_all_entries(&self) -> Result<(), RepositoryError> {
71 self.store.clear();
72 Ok(())
73 }
74}
75
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_decrement_not_found() {
        let store = InMemoryTransactionCounter::new();
        let result = store.decrement("nonexistent", "0x1234").await;
        assert!(matches!(result, Err(RepositoryError::NotFound(_))));
    }

    /// A counter keyed by one relayer must not be visible to another relayer,
    /// even for the same address.
    #[tokio::test]
    async fn test_decrement_is_scoped_per_relayer() {
        let store = InMemoryTransactionCounter::new();
        store.set("relayer_1", "0x1234", 5).await.unwrap();

        let result = store.decrement("relayer_2", "0x1234").await;
        assert!(matches!(result, Err(RepositoryError::NotFound(_))));
        assert_eq!(store.get("relayer_1", "0x1234").await.unwrap(), Some(5));
    }

    #[tokio::test]
    async fn test_nonce_store() {
        let store = InMemoryTransactionCounter::new();
        let relayer_id = "relayer_1";
        let address = "0x1234";

        assert_eq!(store.get(relayer_id, address).await.unwrap(), None);

        store.set(relayer_id, address, 100).await.unwrap();
        assert_eq!(store.get(relayer_id, address).await.unwrap(), Some(100));

        assert_eq!(
            store.get_and_increment(relayer_id, address).await.unwrap(),
            100
        );
        assert_eq!(store.get(relayer_id, address).await.unwrap(), Some(101));

        assert_eq!(store.decrement(relayer_id, address).await.unwrap(), 100);
        assert_eq!(store.get(relayer_id, address).await.unwrap(), Some(100));
    }

    /// A missing counter is created at zero; the pre-increment value is returned.
    #[tokio::test]
    async fn test_get_and_increment_initializes_missing_counter() {
        let store = InMemoryTransactionCounter::new();

        assert_eq!(store.get_and_increment("relayer_1", "0x1234").await.unwrap(), 0);
        assert_eq!(store.get("relayer_1", "0x1234").await.unwrap(), Some(1));
        assert_eq!(store.get_and_increment("relayer_1", "0x1234").await.unwrap(), 1);
        assert_eq!(store.get("relayer_1", "0x1234").await.unwrap(), Some(2));
    }

    /// Decrementing a zero counter leaves it at zero instead of underflowing.
    #[tokio::test]
    async fn test_decrement_saturates_at_zero() {
        let store = InMemoryTransactionCounter::new();
        store.set("relayer_1", "0x1234", 1).await.unwrap();

        assert_eq!(store.decrement("relayer_1", "0x1234").await.unwrap(), 0);
        assert_eq!(store.decrement("relayer_1", "0x1234").await.unwrap(), 0);
        assert_eq!(store.get("relayer_1", "0x1234").await.unwrap(), Some(0));
    }

    #[tokio::test]
    async fn test_set_overwrites_existing_value() {
        let store = InMemoryTransactionCounter::new();

        store.set("relayer_1", "0x1234", 100).await.unwrap();
        store.set("relayer_1", "0x1234", 7).await.unwrap();

        assert_eq!(store.get("relayer_1", "0x1234").await.unwrap(), Some(7));
    }

    #[tokio::test]
    async fn test_drop_all_entries() {
        let store = InMemoryTransactionCounter::new();

        store.set("relayer_1", "0x1234", 100).await.unwrap();
        store.set("relayer_1", "0x5678", 200).await.unwrap();
        store.set("relayer_2", "0x1234", 300).await.unwrap();

        assert_eq!(store.get("relayer_1", "0x1234").await.unwrap(), Some(100));

        store.drop_all_entries().await.unwrap();

        assert_eq!(store.get("relayer_1", "0x1234").await.unwrap(), None);
        assert_eq!(store.get("relayer_1", "0x5678").await.unwrap(), None);
        assert_eq!(store.get("relayer_2", "0x1234").await.unwrap(), None);
    }

    #[tokio::test]
    async fn test_multiple_relayers() {
        let store = InMemoryTransactionCounter::new();

        store.set("relayer_1", "0x1234", 100).await.unwrap();
        store.set("relayer_1", "0x5678", 200).await.unwrap();
        store.set("relayer_2", "0x1234", 300).await.unwrap();

        assert_eq!(store.get("relayer_1", "0x1234").await.unwrap(), Some(100));
        assert_eq!(store.get("relayer_1", "0x5678").await.unwrap(), Some(200));
        assert_eq!(store.get("relayer_2", "0x1234").await.unwrap(), Some(300));

        assert_eq!(
            store
                .get_and_increment("relayer_1", "0x1234")
                .await
                .unwrap(),
            100
        );
        assert_eq!(
            store
                .get_and_increment("relayer_1", "0x1234")
                .await
                .unwrap(),
            101
        );
        assert_eq!(
            store
                .get_and_increment("relayer_1", "0x5678")
                .await
                .unwrap(),
            200
        );
        assert_eq!(
            store
                .get_and_increment("relayer_1", "0x5678")
                .await
                .unwrap(),
            201
        );
        // Incrementing relayer_1's counters must not touch relayer_2's.
        assert_eq!(store.get("relayer_2", "0x1234").await.unwrap(), Some(300));
    }
}