openzeppelin_relayer/repositories/transaction_counter/
transaction_counter_redis.rs1use super::TransactionCounterTrait;
9use crate::models::RepositoryError;
10use crate::repositories::redis_base::RedisRepository;
11use crate::utils::RedisConnections;
12use async_trait::async_trait;
13use redis::AsyncCommands;
14use std::fmt;
15use std::sync::Arc;
16use tracing::debug;
17
18const COUNTER_PREFIX: &str = "transaction_counter";
19
20#[derive(Clone)]
21pub struct RedisTransactionCounter {
22 pub connections: Arc<RedisConnections>,
23 pub key_prefix: String,
24}
25
26impl RedisRepository for RedisTransactionCounter {}
27
28impl fmt::Debug for RedisTransactionCounter {
29 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
30 f.debug_struct("RedisTransactionCounter")
31 .field("key_prefix", &self.key_prefix)
32 .finish()
33 }
34}
35
36impl RedisTransactionCounter {
37 pub fn new(
38 connections: Arc<RedisConnections>,
39 key_prefix: String,
40 ) -> Result<Self, RepositoryError> {
41 if key_prefix.is_empty() {
42 return Err(RepositoryError::InvalidData(
43 "Redis key prefix cannot be empty".to_string(),
44 ));
45 }
46
47 Ok(Self {
48 connections,
49 key_prefix,
50 })
51 }
52
53 fn counter_key(&self, relayer_id: &str, address: &str) -> String {
55 format!(
56 "{}:{}:{}:{}",
57 self.key_prefix, COUNTER_PREFIX, relayer_id, address
58 )
59 }
60}
61
62#[async_trait]
63impl TransactionCounterTrait for RedisTransactionCounter {
64 async fn get(&self, relayer_id: &str, address: &str) -> Result<Option<u64>, RepositoryError> {
65 if relayer_id.is_empty() {
66 return Err(RepositoryError::InvalidData(
67 "Relayer ID cannot be empty".to_string(),
68 ));
69 }
70
71 if address.is_empty() {
72 return Err(RepositoryError::InvalidData(
73 "Address cannot be empty".to_string(),
74 ));
75 }
76
77 let key = self.counter_key(relayer_id, address);
78 debug!(relayer_id = %relayer_id, address = %address, "getting counter for relayer and address");
79
80 let mut conn = self
81 .get_connection(self.connections.reader(), "get")
82 .await?;
83
84 let value: Option<u64> = conn
85 .get(&key)
86 .await
87 .map_err(|e| self.map_redis_error(e, "get_counter"))?;
88
89 debug!(value = ?value, "retrieved counter value");
90 Ok(value)
91 }
92
93 async fn get_and_increment(
94 &self,
95 relayer_id: &str,
96 address: &str,
97 ) -> Result<u64, RepositoryError> {
98 if relayer_id.is_empty() {
99 return Err(RepositoryError::InvalidData(
100 "Relayer ID cannot be empty".to_string(),
101 ));
102 }
103
104 if address.is_empty() {
105 return Err(RepositoryError::InvalidData(
106 "Address cannot be empty".to_string(),
107 ));
108 }
109
110 let key = self.counter_key(relayer_id, address);
111 debug!(relayer_id = %relayer_id, address = %address, "getting and incrementing counter for relayer and address");
112
113 let mut conn = self
114 .get_connection(self.connections.primary(), "get_and_increment")
115 .await?;
116
117 let new_value: u64 = conn
119 .incr(&key, 1)
120 .await
121 .map_err(|e| self.map_redis_error(e, "get_and_increment"))?;
122
123 let current = new_value.saturating_sub(1);
124
125 debug!(from = %current, to = %(current + 1), "counter incremented");
126 Ok(current)
127 }
128
129 async fn decrement(&self, relayer_id: &str, address: &str) -> Result<u64, RepositoryError> {
130 if relayer_id.is_empty() {
131 return Err(RepositoryError::InvalidData(
132 "Relayer ID cannot be empty".to_string(),
133 ));
134 }
135
136 if address.is_empty() {
137 return Err(RepositoryError::InvalidData(
138 "Address cannot be empty".to_string(),
139 ));
140 }
141
142 let key = self.counter_key(relayer_id, address);
143 debug!(relayer_id = %relayer_id, address = %address, "decrementing counter for relayer and address");
144
145 let mut conn = self
146 .get_connection(self.connections.primary(), "decrement")
147 .await?;
148
149 let exists: bool = conn
151 .exists(&key)
152 .await
153 .map_err(|e| self.map_redis_error(e, "check_counter_exists"))?;
154
155 if !exists {
156 return Err(RepositoryError::NotFound(format!(
157 "Counter not found for relayer {relayer_id} and address {address}"
158 )));
159 }
160
161 let new_value: i64 = conn
163 .decr(&key, 1)
164 .await
165 .map_err(|e| self.map_redis_error(e, "decrement_counter"))?;
166
167 let new_value = if new_value < 0 {
168 let _: () = conn
170 .set(&key, 0)
171 .await
172 .map_err(|e| self.map_redis_error(e, "correct_negative_counter"))?;
173 0u64
174 } else {
175 new_value as u64
176 };
177
178 debug!(new_value = %new_value, "counter decremented");
179 Ok(new_value)
180 }
181
182 async fn set(
183 &self,
184 relayer_id: &str,
185 address: &str,
186 value: u64,
187 ) -> Result<(), RepositoryError> {
188 if relayer_id.is_empty() {
189 return Err(RepositoryError::InvalidData(
190 "Relayer ID cannot be empty".to_string(),
191 ));
192 }
193
194 if address.is_empty() {
195 return Err(RepositoryError::InvalidData(
196 "Address cannot be empty".to_string(),
197 ));
198 }
199
200 let key = self.counter_key(relayer_id, address);
201 debug!(relayer_id = %relayer_id, address = %address, value = %value, "setting counter for relayer and address");
202
203 let mut conn = self
204 .get_connection(self.connections.primary(), "set")
205 .await?;
206
207 let _: () = conn
208 .set(&key, value)
209 .await
210 .map_err(|e| self.map_redis_error(e, "set_counter"))?;
211
212 debug!(value = %value, "counter set");
213 Ok(())
214 }
215
216 async fn drop_all_entries(&self) -> Result<(), RepositoryError> {
217 let mut conn = self
218 .get_connection(self.connections.primary(), "drop_all_entries")
219 .await?;
220
221 let pattern = format!("{}:{}:*", self.key_prefix, COUNTER_PREFIX);
222 debug!(pattern = %pattern, "dropping all transaction counter entries");
223
224 let mut cursor: u64 = 0;
227 let mut all_keys: Vec<String> = Vec::new();
228
229 loop {
230 let (next_cursor, keys): (u64, Vec<String>) = redis::cmd("SCAN")
231 .cursor_arg(cursor)
232 .arg("MATCH")
233 .arg(&pattern)
234 .arg("COUNT")
235 .arg(100)
236 .query_async(&mut conn)
237 .await
238 .map_err(|e| self.map_redis_error(e, "drop_all_entries_scan"))?;
239
240 all_keys.extend(keys);
241
242 cursor = next_cursor;
243 if cursor == 0 {
244 break;
245 }
246 }
247
248 if !all_keys.is_empty() {
250 let mut pipe = redis::pipe();
251 pipe.atomic();
252 for key in &all_keys {
253 pipe.del(key);
254 }
255 pipe.exec_async(&mut conn)
256 .await
257 .map_err(|e| self.map_redis_error(e, "drop_all_entries_delete"))?;
258 }
259
260 debug!(total_deleted = %all_keys.len(), "dropped all transaction counter entries");
261 Ok(())
262 }
263}
264
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use uuid::Uuid;

    /// Builds a repository using the shared `test_counter` key prefix.
    async fn setup_test_repo() -> RedisTransactionCounter {
        setup_test_repo_with_prefix("test_counter").await
    }

    /// Builds a repository against `REDIS_URL` (default `redis://127.0.0.1:6379`)
    /// using the given key prefix.
    async fn setup_test_repo_with_prefix(prefix: &str) -> RedisTransactionCounter {
        let redis_url =
            std::env::var("REDIS_URL").unwrap_or_else(|_| "redis://127.0.0.1:6379".to_string());
        let cfg = deadpool_redis::Config::from_url(&redis_url);
        let pool = Arc::new(
            cfg.builder()
                .expect("Failed to create pool builder")
                .max_size(16)
                .runtime(deadpool_redis::Runtime::Tokio1)
                .build()
                .expect("Failed to build Redis pool"),
        );
        let connections = Arc::new(RedisConnections::new_single_pool(pool));

        RedisTransactionCounter::new(connections, prefix.to_string())
            .expect("Failed to create Redis transaction counter")
    }

    #[tokio::test]
    #[ignore = "Requires active Redis instance"]
    async fn test_get_nonexistent_counter() {
        let repo = setup_test_repo().await;
        let random_id = Uuid::new_v4().to_string();
        let result = repo.get(&random_id, "0x1234").await.unwrap();
        assert_eq!(result, None);
    }

    #[tokio::test]
    #[ignore = "Requires active Redis instance"]
    async fn test_set_and_get_counter() {
        let repo = setup_test_repo().await;
        let relayer_id = Uuid::new_v4().to_string();
        let address = Uuid::new_v4().to_string();

        repo.set(&relayer_id, &address, 100).await.unwrap();
        let result = repo.get(&relayer_id, &address).await.unwrap();
        assert_eq!(result, Some(100));
    }

    #[tokio::test]
    #[ignore = "Requires active Redis instance"]
    async fn test_get_and_increment() {
        let repo = setup_test_repo().await;
        let relayer_id = Uuid::new_v4().to_string();
        let address = Uuid::new_v4().to_string();

        let result = repo.get_and_increment(&relayer_id, &address).await.unwrap();
        assert_eq!(result, 0);

        let current = repo.get(&relayer_id, &address).await.unwrap();
        assert_eq!(current, Some(1));

        let result = repo.get_and_increment(&relayer_id, &address).await.unwrap();
        assert_eq!(result, 1);

        let current = repo.get(&relayer_id, &address).await.unwrap();
        assert_eq!(current, Some(2));
    }

    #[tokio::test]
    #[ignore = "Requires active Redis instance"]
    async fn test_decrement() {
        let repo = setup_test_repo().await;
        let relayer_id = Uuid::new_v4().to_string();
        let address = Uuid::new_v4().to_string();

        repo.set(&relayer_id, &address, 5).await.unwrap();

        let result = repo.decrement(&relayer_id, &address).await.unwrap();
        assert_eq!(result, 4);

        let current = repo.get(&relayer_id, &address).await.unwrap();
        assert_eq!(current, Some(4));
    }

    #[tokio::test]
    #[ignore = "Requires active Redis instance"]
    async fn test_decrement_floors_at_zero() {
        let repo = setup_test_repo().await;
        let relayer_id = Uuid::new_v4().to_string();
        let address = Uuid::new_v4().to_string();

        repo.set(&relayer_id, &address, 0).await.unwrap();

        let result = repo.decrement(&relayer_id, &address).await.unwrap();
        assert_eq!(result, 0);

        let current = repo.get(&relayer_id, &address).await.unwrap();
        assert_eq!(current, Some(0));
    }

    #[tokio::test]
    #[ignore = "Requires active Redis instance"]
    async fn test_decrement_not_found() {
        let repo = setup_test_repo().await;
        // Random relayer id so a leftover key from another run cannot mask the error.
        let relayer_id = Uuid::new_v4().to_string();

        let result = repo.decrement(&relayer_id, "0x1234").await;
        assert!(matches!(result, Err(RepositoryError::NotFound(_))));

        // A failed decrement must not create the counter as a side effect.
        assert_eq!(repo.get(&relayer_id, "0x1234").await.unwrap(), None);
    }

    #[tokio::test]
    #[ignore = "Requires active Redis instance"]
    async fn test_empty_validation() {
        let repo = setup_test_repo().await;

        assert!(matches!(
            repo.get("", "0x1234").await,
            Err(RepositoryError::InvalidData(_))
        ));
        assert!(matches!(
            repo.get("relayer", "").await,
            Err(RepositoryError::InvalidData(_))
        ));
        assert!(matches!(
            repo.get_and_increment("", "0x1234").await,
            Err(RepositoryError::InvalidData(_))
        ));
        assert!(matches!(
            repo.get_and_increment("relayer", "").await,
            Err(RepositoryError::InvalidData(_))
        ));
        assert!(matches!(
            repo.decrement("", "0x1234").await,
            Err(RepositoryError::InvalidData(_))
        ));
        assert!(matches!(
            repo.decrement("relayer", "").await,
            Err(RepositoryError::InvalidData(_))
        ));
        assert!(matches!(
            repo.set("", "0x1234", 1).await,
            Err(RepositoryError::InvalidData(_))
        ));
        assert!(matches!(
            repo.set("relayer", "", 1).await,
            Err(RepositoryError::InvalidData(_))
        ));
    }

    #[tokio::test]
    #[ignore = "Requires active Redis instance"]
    async fn test_multiple_relayers() {
        let repo = setup_test_repo().await;
        let relayer_1 = Uuid::new_v4().to_string();
        let relayer_2 = Uuid::new_v4().to_string();
        let address_1 = Uuid::new_v4().to_string();
        let address_2 = Uuid::new_v4().to_string();

        repo.set(&relayer_1, &address_1, 100).await.unwrap();
        repo.set(&relayer_1, &address_2, 200).await.unwrap();
        repo.set(&relayer_2, &address_1, 300).await.unwrap();

        assert_eq!(repo.get(&relayer_1, &address_1).await.unwrap(), Some(100));
        assert_eq!(repo.get(&relayer_1, &address_2).await.unwrap(), Some(200));
        assert_eq!(repo.get(&relayer_2, &address_1).await.unwrap(), Some(300));

        assert_eq!(
            repo.get_and_increment(&relayer_1, &address_1)
                .await
                .unwrap(),
            100
        );
        assert_eq!(
            repo.get_and_increment(&relayer_1, &address_1)
                .await
                .unwrap(),
            101
        );
        assert_eq!(
            repo.get_and_increment(&relayer_1, &address_2)
                .await
                .unwrap(),
            200
        );
        assert_eq!(
            repo.get_and_increment(&relayer_1, &address_2)
                .await
                .unwrap(),
            201
        );
        assert_eq!(repo.get(&relayer_2, &address_1).await.unwrap(), Some(300));
    }

    #[tokio::test]
    #[ignore = "Requires active Redis instance"]
    async fn test_drop_all_entries() {
        let prefix = format!("test_drop_{}", Uuid::new_v4());
        let repo = setup_test_repo_with_prefix(&prefix).await;
        let relayer_1 = Uuid::new_v4().to_string();
        let relayer_2 = Uuid::new_v4().to_string();
        let address_1 = Uuid::new_v4().to_string();
        let address_2 = Uuid::new_v4().to_string();

        repo.set(&relayer_1, &address_1, 100).await.unwrap();
        repo.set(&relayer_1, &address_2, 200).await.unwrap();
        repo.set(&relayer_2, &address_1, 300).await.unwrap();

        assert_eq!(repo.get(&relayer_1, &address_1).await.unwrap(), Some(100));
        assert_eq!(repo.get(&relayer_1, &address_2).await.unwrap(), Some(200));
        assert_eq!(repo.get(&relayer_2, &address_1).await.unwrap(), Some(300));

        repo.drop_all_entries().await.unwrap();

        assert_eq!(repo.get(&relayer_1, &address_1).await.unwrap(), None);
        assert_eq!(repo.get(&relayer_1, &address_2).await.unwrap(), None);
        assert_eq!(repo.get(&relayer_2, &address_1).await.unwrap(), None);
    }

    #[tokio::test]
    #[ignore = "Requires active Redis instance"]
    async fn test_concurrent_get_and_increment() {
        let repo = setup_test_repo().await;
        let relayer_id = Uuid::new_v4().to_string();
        let address = Uuid::new_v4().to_string();

        repo.set(&relayer_id, &address, 100).await.unwrap();

        let handles: Vec<_> = (0..10)
            .map(|_| {
                let repo = repo.clone();
                let relayer_id = relayer_id.clone();
                let address = address.clone();
                tokio::spawn(
                    async move { repo.get_and_increment(&relayer_id, &address).await.unwrap() },
                )
            })
            .collect();

        let mut results = Vec::new();
        for handle in handles {
            results.push(handle.await.unwrap());
        }

        results.sort();

        // Every concurrent caller must receive a distinct, gap-free value.
        let expected: Vec<u64> = (100..110).collect();
        assert_eq!(results, expected);

        let final_value = repo.get(&relayer_id, &address).await.unwrap();
        assert_eq!(final_value, Some(110));
    }
}