openzeppelin_relayer/repositories/transaction_counter/
transaction_counter_in_memory.rs

//! This module provides an in-memory implementation of a transaction counter.
//!
//! The `InMemoryTransactionCounter` struct tracks and manages transaction nonces
//! for different relayers and addresses. It supports operations to get, increment, decrement,
//! set, and clear nonce values. This implementation uses a `DashMap` for concurrent access and
//! modification of the nonce values.
7use async_trait::async_trait;
8use dashmap::DashMap;
9
10use crate::repositories::{RepositoryError, TransactionCounterTrait};
11
12#[derive(Debug, Default, Clone)]
13pub struct InMemoryTransactionCounter {
14    store: DashMap<(String, String), u64>, // (relayer_id, address) -> nonce/sequence
15}
16
17impl InMemoryTransactionCounter {
18    pub fn new() -> Self {
19        Self {
20            store: DashMap::new(),
21        }
22    }
23}
24
25#[async_trait]
26impl TransactionCounterTrait for InMemoryTransactionCounter {
27    async fn get(&self, relayer_id: &str, address: &str) -> Result<Option<u64>, RepositoryError> {
28        Ok(self
29            .store
30            .get(&(relayer_id.to_string(), address.to_string()))
31            .map(|n| *n))
32    }
33
34    async fn get_and_increment(
35        &self,
36        relayer_id: &str,
37        address: &str,
38    ) -> Result<u64, RepositoryError> {
39        let mut entry = self
40            .store
41            .entry((relayer_id.to_string(), address.to_string()))
42            .or_insert(0);
43        let current = *entry;
44        *entry += 1;
45        Ok(current)
46    }
47
48    async fn decrement(&self, relayer_id: &str, address: &str) -> Result<u64, RepositoryError> {
49        let mut entry = self
50            .store
51            .get_mut(&(relayer_id.to_string(), address.to_string()))
52            .ok_or_else(|| RepositoryError::NotFound(format!("Counter not found for {address}")))?;
53        if *entry > 0 {
54            *entry -= 1;
55        }
56        Ok(*entry)
57    }
58
59    async fn set(
60        &self,
61        relayer_id: &str,
62        address: &str,
63        value: u64,
64    ) -> Result<(), RepositoryError> {
65        self.store
66            .insert((relayer_id.to_string(), address.to_string()), value);
67        Ok(())
68    }
69
70    async fn drop_all_entries(&self) -> Result<(), RepositoryError> {
71        self.store.clear();
72        Ok(())
73    }
74}
75
#[cfg(test)]
mod tests {
    use super::*;

    /// Decrementing a counter that was never created reports `NotFound`.
    #[tokio::test]
    async fn test_decrement_not_found() {
        let store = InMemoryTransactionCounter::new();
        let result = store.decrement("nonexistent", "0x1234").await;
        assert!(matches!(result, Err(RepositoryError::NotFound(_))));
    }

    /// Decrementing a counter that is already zero leaves it at zero.
    #[tokio::test]
    async fn test_decrement_stops_at_zero() {
        let store = InMemoryTransactionCounter::new();
        let relayer_id = "relayer_1";
        let address = "0x1234";

        store.set(relayer_id, address, 0).await.unwrap();

        assert_eq!(store.decrement(relayer_id, address).await.unwrap(), 0);
        assert_eq!(store.get(relayer_id, address).await.unwrap(), Some(0));
    }

    /// The first `get_and_increment` for an unknown pair creates the counter at zero.
    #[tokio::test]
    async fn test_get_and_increment_creates_missing_counter() {
        let store = InMemoryTransactionCounter::new();
        let relayer_id = "relayer_1";
        let address = "0x1234";

        assert_eq!(store.get(relayer_id, address).await.unwrap(), None);
        assert_eq!(
            store.get_and_increment(relayer_id, address).await.unwrap(),
            0
        );
        assert_eq!(store.get(relayer_id, address).await.unwrap(), Some(1));
    }

    /// Walks one counter through set, increment and decrement.
    #[tokio::test]
    async fn test_nonce_store() {
        let store = InMemoryTransactionCounter::new();
        let relayer_id = "relayer_1";
        let address = "0x1234";

        // Initially should be None
        assert_eq!(store.get(relayer_id, address).await.unwrap(), None);

        // Set a value explicitly
        store.set(relayer_id, address, 100).await.unwrap();
        assert_eq!(store.get(relayer_id, address).await.unwrap(), Some(100));

        // Increment
        assert_eq!(
            store.get_and_increment(relayer_id, address).await.unwrap(),
            100
        );
        assert_eq!(store.get(relayer_id, address).await.unwrap(), Some(101));

        // Decrement
        assert_eq!(store.decrement(relayer_id, address).await.unwrap(), 100);
        assert_eq!(store.get(relayer_id, address).await.unwrap(), Some(100));
    }

    /// `drop_all_entries` removes counters for every relayer/address pair.
    #[tokio::test]
    async fn test_drop_all_entries() {
        let store = InMemoryTransactionCounter::new();

        store.set("relayer_1", "0x1234", 100).await.unwrap();
        store.set("relayer_1", "0x5678", 200).await.unwrap();
        store.set("relayer_2", "0x1234", 300).await.unwrap();

        assert_eq!(store.get("relayer_1", "0x1234").await.unwrap(), Some(100));

        store.drop_all_entries().await.unwrap();

        assert_eq!(store.get("relayer_1", "0x1234").await.unwrap(), None);
        assert_eq!(store.get("relayer_1", "0x5678").await.unwrap(), None);
        assert_eq!(store.get("relayer_2", "0x1234").await.unwrap(), None);
    }

    /// Counters for different relayer/address pairs do not affect each other.
    #[tokio::test]
    async fn test_multiple_relayers() {
        let store = InMemoryTransactionCounter::new();

        // Setup different relayer/address combinations
        store.set("relayer_1", "0x1234", 100).await.unwrap();
        store.set("relayer_1", "0x5678", 200).await.unwrap();
        store.set("relayer_2", "0x1234", 300).await.unwrap();

        // Verify independent counters
        assert_eq!(store.get("relayer_1", "0x1234").await.unwrap(), Some(100));
        assert_eq!(store.get("relayer_1", "0x5678").await.unwrap(), Some(200));
        assert_eq!(store.get("relayer_2", "0x1234").await.unwrap(), Some(300));

        // Verify independent increments
        assert_eq!(
            store
                .get_and_increment("relayer_1", "0x1234")
                .await
                .unwrap(),
            100
        );
        assert_eq!(
            store
                .get_and_increment("relayer_1", "0x1234")
                .await
                .unwrap(),
            101
        );
        assert_eq!(
            store
                .get_and_increment("relayer_1", "0x5678")
                .await
                .unwrap(),
            200
        );
        assert_eq!(
            store
                .get_and_increment("relayer_1", "0x5678")
                .await
                .unwrap(),
            201
        );
        // The untouched relayer keeps its original value.
        assert_eq!(store.get("relayer_2", "0x1234").await.unwrap(), Some(300));
    }
}