openzeppelin_relayer/repositories/transaction_counter/
transaction_counter_redis.rs

1//! Redis implementation of the transaction counter.
2//!
3//! This module provides a Redis-based implementation of the `TransactionCounterTrait`,
4//! allowing transaction counters to be stored and retrieved from a Redis database.
5//! The implementation includes comprehensive error handling, logging, and atomic operations
6//! to ensure consistency when incrementing and decrementing counters.
7
8use super::TransactionCounterTrait;
9use crate::models::RepositoryError;
10use crate::repositories::redis_base::RedisRepository;
11use crate::utils::RedisConnections;
12use async_trait::async_trait;
13use redis::AsyncCommands;
14use std::fmt;
15use std::sync::Arc;
16use tracing::debug;
17
/// Fixed key segment placed between the configurable prefix and the relayer/address parts,
/// giving keys of the form `{key_prefix}:transaction_counter:{relayer_id}:{address}`.
const COUNTER_PREFIX: &str = "transaction_counter";
19
20#[derive(Clone)]
21pub struct RedisTransactionCounter {
22    pub connections: Arc<RedisConnections>,
23    pub key_prefix: String,
24}
25
26impl RedisRepository for RedisTransactionCounter {}
27
28impl fmt::Debug for RedisTransactionCounter {
29    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
30        f.debug_struct("RedisTransactionCounter")
31            .field("key_prefix", &self.key_prefix)
32            .finish()
33    }
34}
35
36impl RedisTransactionCounter {
37    pub fn new(
38        connections: Arc<RedisConnections>,
39        key_prefix: String,
40    ) -> Result<Self, RepositoryError> {
41        if key_prefix.is_empty() {
42            return Err(RepositoryError::InvalidData(
43                "Redis key prefix cannot be empty".to_string(),
44            ));
45        }
46
47        Ok(Self {
48            connections,
49            key_prefix,
50        })
51    }
52
53    /// Generate key for transaction counter: {prefix}:transaction_counter:{relayer_id}:{address}
54    fn counter_key(&self, relayer_id: &str, address: &str) -> String {
55        format!(
56            "{}:{}:{}:{}",
57            self.key_prefix, COUNTER_PREFIX, relayer_id, address
58        )
59    }
60}
61
62#[async_trait]
63impl TransactionCounterTrait for RedisTransactionCounter {
64    async fn get(&self, relayer_id: &str, address: &str) -> Result<Option<u64>, RepositoryError> {
65        if relayer_id.is_empty() {
66            return Err(RepositoryError::InvalidData(
67                "Relayer ID cannot be empty".to_string(),
68            ));
69        }
70
71        if address.is_empty() {
72            return Err(RepositoryError::InvalidData(
73                "Address cannot be empty".to_string(),
74            ));
75        }
76
77        let key = self.counter_key(relayer_id, address);
78        debug!(relayer_id = %relayer_id, address = %address, "getting counter for relayer and address");
79
80        let mut conn = self
81            .get_connection(self.connections.reader(), "get")
82            .await?;
83
84        let value: Option<u64> = conn
85            .get(&key)
86            .await
87            .map_err(|e| self.map_redis_error(e, "get_counter"))?;
88
89        debug!(value = ?value, "retrieved counter value");
90        Ok(value)
91    }
92
93    async fn get_and_increment(
94        &self,
95        relayer_id: &str,
96        address: &str,
97    ) -> Result<u64, RepositoryError> {
98        if relayer_id.is_empty() {
99            return Err(RepositoryError::InvalidData(
100                "Relayer ID cannot be empty".to_string(),
101            ));
102        }
103
104        if address.is_empty() {
105            return Err(RepositoryError::InvalidData(
106                "Address cannot be empty".to_string(),
107            ));
108        }
109
110        let key = self.counter_key(relayer_id, address);
111        debug!(relayer_id = %relayer_id, address = %address, "getting and incrementing counter for relayer and address");
112
113        let mut conn = self
114            .get_connection(self.connections.primary(), "get_and_increment")
115            .await?;
116
117        // Use Redis INCR for atomic increment
118        let new_value: u64 = conn
119            .incr(&key, 1)
120            .await
121            .map_err(|e| self.map_redis_error(e, "get_and_increment"))?;
122
123        let current = new_value.saturating_sub(1);
124
125        debug!(from = %current, to = %(current + 1), "counter incremented");
126        Ok(current)
127    }
128
129    async fn decrement(&self, relayer_id: &str, address: &str) -> Result<u64, RepositoryError> {
130        if relayer_id.is_empty() {
131            return Err(RepositoryError::InvalidData(
132                "Relayer ID cannot be empty".to_string(),
133            ));
134        }
135
136        if address.is_empty() {
137            return Err(RepositoryError::InvalidData(
138                "Address cannot be empty".to_string(),
139            ));
140        }
141
142        let key = self.counter_key(relayer_id, address);
143        debug!(relayer_id = %relayer_id, address = %address, "decrementing counter for relayer and address");
144
145        let mut conn = self
146            .get_connection(self.connections.primary(), "decrement")
147            .await?;
148
149        // Check if counter exists first
150        let exists: bool = conn
151            .exists(&key)
152            .await
153            .map_err(|e| self.map_redis_error(e, "check_counter_exists"))?;
154
155        if !exists {
156            return Err(RepositoryError::NotFound(format!(
157                "Counter not found for relayer {relayer_id} and address {address}"
158            )));
159        }
160
161        // Use Redis DECR and correct if it goes below 0
162        let new_value: i64 = conn
163            .decr(&key, 1)
164            .await
165            .map_err(|e| self.map_redis_error(e, "decrement_counter"))?;
166
167        let new_value = if new_value < 0 {
168            // Correct negative values back to 0
169            let _: () = conn
170                .set(&key, 0)
171                .await
172                .map_err(|e| self.map_redis_error(e, "correct_negative_counter"))?;
173            0u64
174        } else {
175            new_value as u64
176        };
177
178        debug!(new_value = %new_value, "counter decremented");
179        Ok(new_value)
180    }
181
182    async fn set(
183        &self,
184        relayer_id: &str,
185        address: &str,
186        value: u64,
187    ) -> Result<(), RepositoryError> {
188        if relayer_id.is_empty() {
189            return Err(RepositoryError::InvalidData(
190                "Relayer ID cannot be empty".to_string(),
191            ));
192        }
193
194        if address.is_empty() {
195            return Err(RepositoryError::InvalidData(
196                "Address cannot be empty".to_string(),
197            ));
198        }
199
200        let key = self.counter_key(relayer_id, address);
201        debug!(relayer_id = %relayer_id, address = %address, value = %value, "setting counter for relayer and address");
202
203        let mut conn = self
204            .get_connection(self.connections.primary(), "set")
205            .await?;
206
207        let _: () = conn
208            .set(&key, value)
209            .await
210            .map_err(|e| self.map_redis_error(e, "set_counter"))?;
211
212        debug!(value = %value, "counter set");
213        Ok(())
214    }
215
216    async fn drop_all_entries(&self) -> Result<(), RepositoryError> {
217        let mut conn = self
218            .get_connection(self.connections.primary(), "drop_all_entries")
219            .await?;
220
221        let pattern = format!("{}:{}:*", self.key_prefix, COUNTER_PREFIX);
222        debug!(pattern = %pattern, "dropping all transaction counter entries");
223
224        // Phase 1: Collect all matching keys without mutating the keyspace.
225        // Deleting during SCAN can cause hash table rehashing, which may skip keys.
226        let mut cursor: u64 = 0;
227        let mut all_keys: Vec<String> = Vec::new();
228
229        loop {
230            let (next_cursor, keys): (u64, Vec<String>) = redis::cmd("SCAN")
231                .cursor_arg(cursor)
232                .arg("MATCH")
233                .arg(&pattern)
234                .arg("COUNT")
235                .arg(100)
236                .query_async(&mut conn)
237                .await
238                .map_err(|e| self.map_redis_error(e, "drop_all_entries_scan"))?;
239
240            all_keys.extend(keys);
241
242            cursor = next_cursor;
243            if cursor == 0 {
244                break;
245            }
246        }
247
248        // Phase 2: Batch delete all collected keys.
249        if !all_keys.is_empty() {
250            let mut pipe = redis::pipe();
251            pipe.atomic();
252            for key in &all_keys {
253                pipe.del(key);
254            }
255            pipe.exec_async(&mut conn)
256                .await
257                .map_err(|e| self.map_redis_error(e, "drop_all_entries_delete"))?;
258        }
259
260        debug!(total_deleted = %all_keys.len(), "dropped all transaction counter entries");
261        Ok(())
262    }
263}
264
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use uuid::Uuid;

    /// Builds a Redis pool from `REDIS_URL`, defaulting to `redis://127.0.0.1:6379`.
    fn build_test_connections() -> Arc<RedisConnections> {
        let redis_url =
            std::env::var("REDIS_URL").unwrap_or_else(|_| "redis://127.0.0.1:6379".to_string());
        let cfg = deadpool_redis::Config::from_url(&redis_url);
        let pool = Arc::new(
            cfg.builder()
                .expect("Failed to create pool builder")
                .max_size(16)
                .runtime(deadpool_redis::Runtime::Tokio1)
                .build()
                .expect("Failed to build Redis pool"),
        );
        Arc::new(RedisConnections::new_single_pool(pool))
    }

    async fn setup_test_repo() -> RedisTransactionCounter {
        setup_test_repo_with_prefix("test_counter").await
    }

    async fn setup_test_repo_with_prefix(prefix: &str) -> RedisTransactionCounter {
        RedisTransactionCounter::new(build_test_connections(), prefix.to_string())
            .expect("Failed to create Redis transaction counter")
    }

    /// Random identifier so tests sharing the default prefix never touch each other's keys.
    fn unique_id() -> String {
        Uuid::new_v4().to_string()
    }

    #[tokio::test]
    #[ignore = "Requires active Redis instance"]
    async fn test_new_rejects_empty_prefix() {
        let result = RedisTransactionCounter::new(build_test_connections(), String::new());
        assert!(matches!(result, Err(RepositoryError::InvalidData(_))));
    }

    #[tokio::test]
    #[ignore = "Requires active Redis instance"]
    async fn test_get_nonexistent_counter() {
        let repo = setup_test_repo().await;
        let result = repo.get(&unique_id(), "0x1234").await.unwrap();
        assert_eq!(result, None);
    }

    #[tokio::test]
    #[ignore = "Requires active Redis instance"]
    async fn test_set_and_get_counter() {
        let repo = setup_test_repo().await;
        let relayer_id = unique_id();
        let address = unique_id();

        repo.set(&relayer_id, &address, 100).await.unwrap();
        let result = repo.get(&relayer_id, &address).await.unwrap();
        assert_eq!(result, Some(100));
    }

    #[tokio::test]
    #[ignore = "Requires active Redis instance"]
    async fn test_get_and_increment() {
        let repo = setup_test_repo().await;
        let relayer_id = unique_id();
        let address = unique_id();

        // First increment should return 0 and set to 1
        let result = repo.get_and_increment(&relayer_id, &address).await.unwrap();
        assert_eq!(result, 0);

        let current = repo.get(&relayer_id, &address).await.unwrap();
        assert_eq!(current, Some(1));

        // Second increment should return 1 and set to 2
        let result = repo.get_and_increment(&relayer_id, &address).await.unwrap();
        assert_eq!(result, 1);

        let current = repo.get(&relayer_id, &address).await.unwrap();
        assert_eq!(current, Some(2));
    }

    #[tokio::test]
    #[ignore = "Requires active Redis instance"]
    async fn test_decrement() {
        let repo = setup_test_repo().await;
        let relayer_id = unique_id();
        let address = unique_id();

        repo.set(&relayer_id, &address, 5).await.unwrap();

        let result = repo.decrement(&relayer_id, &address).await.unwrap();
        assert_eq!(result, 4);

        let current = repo.get(&relayer_id, &address).await.unwrap();
        assert_eq!(current, Some(4));
    }

    #[tokio::test]
    #[ignore = "Requires active Redis instance"]
    async fn test_decrement_clamps_at_zero() {
        let repo = setup_test_repo().await;
        let relayer_id = unique_id();
        let address = unique_id();

        repo.set(&relayer_id, &address, 0).await.unwrap();

        // Decrementing an existing zero counter must stay at zero, never go negative.
        assert_eq!(repo.decrement(&relayer_id, &address).await.unwrap(), 0);
        assert_eq!(repo.get(&relayer_id, &address).await.unwrap(), Some(0));
    }

    #[tokio::test]
    #[ignore = "Requires active Redis instance"]
    async fn test_decrement_not_found() {
        let repo = setup_test_repo().await;
        let relayer_id = unique_id();

        let result = repo.decrement(&relayer_id, "0x1234").await;
        assert!(matches!(result, Err(RepositoryError::NotFound(_))));

        // A failed decrement must not leave a key behind.
        assert_eq!(repo.get(&relayer_id, "0x1234").await.unwrap(), None);
    }

    #[tokio::test]
    #[ignore = "Requires active Redis instance"]
    async fn test_empty_validation() {
        let repo = setup_test_repo().await;

        assert!(matches!(
            repo.get("", "0x1234").await,
            Err(RepositoryError::InvalidData(_))
        ));
        assert!(matches!(
            repo.get("relayer", "").await,
            Err(RepositoryError::InvalidData(_))
        ));
        assert!(matches!(
            repo.get_and_increment("", "0x1234").await,
            Err(RepositoryError::InvalidData(_))
        ));
        assert!(matches!(
            repo.get_and_increment("relayer", "").await,
            Err(RepositoryError::InvalidData(_))
        ));
        assert!(matches!(
            repo.decrement("", "0x1234").await,
            Err(RepositoryError::InvalidData(_))
        ));
        assert!(matches!(
            repo.decrement("relayer", "").await,
            Err(RepositoryError::InvalidData(_))
        ));
        assert!(matches!(
            repo.set("", "0x1234", 1).await,
            Err(RepositoryError::InvalidData(_))
        ));
        assert!(matches!(
            repo.set("relayer", "", 1).await,
            Err(RepositoryError::InvalidData(_))
        ));
    }

    #[tokio::test]
    #[ignore = "Requires active Redis instance"]
    async fn test_multiple_relayers() {
        let repo = setup_test_repo().await;
        let relayer_1 = unique_id();
        let relayer_2 = unique_id();
        let address_1 = unique_id();
        let address_2 = unique_id();

        // Set different values for different relayer/address combinations
        repo.set(&relayer_1, &address_1, 100).await.unwrap();
        repo.set(&relayer_1, &address_2, 200).await.unwrap();
        repo.set(&relayer_2, &address_1, 300).await.unwrap();

        // Verify independent counters
        assert_eq!(repo.get(&relayer_1, &address_1).await.unwrap(), Some(100));
        assert_eq!(repo.get(&relayer_1, &address_2).await.unwrap(), Some(200));
        assert_eq!(repo.get(&relayer_2, &address_1).await.unwrap(), Some(300));

        // Verify independent increments
        assert_eq!(
            repo.get_and_increment(&relayer_1, &address_1)
                .await
                .unwrap(),
            100
        );
        assert_eq!(
            repo.get_and_increment(&relayer_1, &address_1)
                .await
                .unwrap(),
            101
        );
        assert_eq!(
            repo.get_and_increment(&relayer_1, &address_2)
                .await
                .unwrap(),
            200
        );
        assert_eq!(
            repo.get_and_increment(&relayer_1, &address_2)
                .await
                .unwrap(),
            201
        );
        assert_eq!(repo.get(&relayer_2, &address_1).await.unwrap(), Some(300));
    }

    #[tokio::test]
    #[ignore = "Requires active Redis instance"]
    async fn test_drop_all_entries() {
        let prefix = format!("test_drop_{}", Uuid::new_v4());
        let repo = setup_test_repo_with_prefix(&prefix).await;
        let relayer_1 = unique_id();
        let relayer_2 = unique_id();
        let address_1 = unique_id();
        let address_2 = unique_id();

        repo.set(&relayer_1, &address_1, 100).await.unwrap();
        repo.set(&relayer_1, &address_2, 200).await.unwrap();
        repo.set(&relayer_2, &address_1, 300).await.unwrap();

        assert_eq!(repo.get(&relayer_1, &address_1).await.unwrap(), Some(100));
        assert_eq!(repo.get(&relayer_1, &address_2).await.unwrap(), Some(200));
        assert_eq!(repo.get(&relayer_2, &address_1).await.unwrap(), Some(300));

        repo.drop_all_entries().await.unwrap();

        assert_eq!(repo.get(&relayer_1, &address_1).await.unwrap(), None);
        assert_eq!(repo.get(&relayer_1, &address_2).await.unwrap(), None);
        assert_eq!(repo.get(&relayer_2, &address_1).await.unwrap(), None);
    }

    #[tokio::test]
    #[ignore = "Requires active Redis instance"]
    async fn test_concurrent_get_and_increment() {
        let repo = setup_test_repo().await;
        let relayer_id = unique_id();
        let address = unique_id();

        repo.set(&relayer_id, &address, 100).await.unwrap();

        // Spawn concurrent incrementers against the same counter.
        let handles: Vec<_> = (0..10)
            .map(|_| {
                let repo = repo.clone();
                let relayer_id = relayer_id.clone();
                let address = address.clone();
                tokio::spawn(
                    async move { repo.get_and_increment(&relayer_id, &address).await.unwrap() },
                )
            })
            .collect();

        let mut results = Vec::new();
        for handle in handles {
            results.push(handle.await.unwrap());
        }
        results.sort();

        // Exactly 100..110: no duplicates, no gaps.
        let expected: Vec<u64> = (100..110).collect();
        assert_eq!(results, expected);

        let final_value = repo.get(&relayer_id, &address).await.unwrap();
        assert_eq!(final_value, Some(110));
    }
}