openzeppelin_relayer/repositories/transaction_counter/
mod.rs1pub mod transaction_counter_in_memory;
22pub mod transaction_counter_redis;
23
24use crate::utils::RedisConnections;
25pub use transaction_counter_in_memory::InMemoryTransactionCounter;
26pub use transaction_counter_redis::RedisTransactionCounter;
27
28use async_trait::async_trait;
29use serde::Serialize;
30use std::sync::Arc;
31use thiserror::Error;
32
33#[cfg(test)]
34use mockall::automock;
35
36use crate::models::RepositoryError;
37
/// Error type for transaction-counter lookups.
///
/// Display messages are generated by `thiserror` from the `#[error(...)]` attributes;
/// the type also derives `Serialize`.
///
/// NOTE(review): the `TransactionCounterTrait` methods in this module return
/// `RepositoryError`, not this type — confirm whether this enum is still used elsewhere.
#[derive(Error, Debug, Serialize)]
pub enum TransactionCounterError {
    /// No sequence value exists for the given relayer/address pair.
    #[error("No sequence found for relayer {relayer_id} and address {address}")]
    SequenceNotFound { relayer_id: String, address: String },
    /// No counter exists for the identifier carried in the variant.
    #[error("Counter not found for {0}")]
    NotFound(String),
}
45
/// Async storage interface for transaction counters keyed by a
/// `(relayer_id, address)` pair.
///
/// Every method reports backend failures as a `RepositoryError`. In test builds a
/// `MockTransactionCounterTrait` is generated by `mockall::automock`.
// NOTE(review): `dead_code` is allowed here without a stated reason — presumably some
// methods are not yet called outside tests; confirm and document, or remove the allow.
#[allow(dead_code)]
#[async_trait]
#[cfg_attr(test, automock)]
pub trait TransactionCounterTrait {
    /// Returns the stored counter for `(relayer_id, address)`, or `None` when no value
    /// has been stored for that pair.
    async fn get(&self, relayer_id: &str, address: &str) -> Result<Option<u64>, RepositoryError>;

    /// Returns the current counter for `(relayer_id, address)` and advances the stored
    /// value by one (the in-memory backend returns 100 and then stores 101).
    ///
    /// NOTE(review): behavior when no value exists yet is backend-defined — confirm.
    async fn get_and_increment(
        &self,
        relayer_id: &str,
        address: &str,
    ) -> Result<u64, RepositoryError>;

    /// Decreases the stored counter for `(relayer_id, address)` by one and returns the
    /// new value.
    ///
    /// NOTE(review): behavior for a missing entry or a value of 0 is backend-defined —
    /// confirm.
    async fn decrement(&self, relayer_id: &str, address: &str) -> Result<u64, RepositoryError>;

    /// Stores `value` as the counter for `(relayer_id, address)`.
    async fn set(&self, relayer_id: &str, address: &str, value: u64)
        -> Result<(), RepositoryError>;

    /// Removes every stored counter, across all relayers and addresses.
    async fn drop_all_entries(&self) -> Result<(), RepositoryError>;
}
67
/// Transaction-counter backend chosen at construction time.
///
/// Implements `TransactionCounterTrait` by forwarding each call to the wrapped counter.
#[derive(Debug, Clone)]
pub enum TransactionCounterRepositoryStorage {
    /// Counters held by an `InMemoryTransactionCounter`.
    InMemory(InMemoryTransactionCounter),
    /// Counters held by a Redis-backed `RedisTransactionCounter`.
    Redis(RedisTransactionCounter),
}
74
75impl TransactionCounterRepositoryStorage {
76 pub fn new_in_memory() -> Self {
77 Self::InMemory(InMemoryTransactionCounter::new())
78 }
79 pub fn new_redis(
80 connections: Arc<RedisConnections>,
81 key_prefix: String,
82 ) -> Result<Self, RepositoryError> {
83 Ok(Self::Redis(RedisTransactionCounter::new(
84 connections,
85 key_prefix,
86 )?))
87 }
88}
89
90#[async_trait]
91impl TransactionCounterTrait for TransactionCounterRepositoryStorage {
92 async fn get(&self, relayer_id: &str, address: &str) -> Result<Option<u64>, RepositoryError> {
93 match self {
94 TransactionCounterRepositoryStorage::InMemory(counter) => {
95 counter.get(relayer_id, address).await
96 }
97 TransactionCounterRepositoryStorage::Redis(counter) => {
98 counter.get(relayer_id, address).await
99 }
100 }
101 }
102
103 async fn get_and_increment(
104 &self,
105 relayer_id: &str,
106 address: &str,
107 ) -> Result<u64, RepositoryError> {
108 match self {
109 TransactionCounterRepositoryStorage::InMemory(counter) => {
110 counter.get_and_increment(relayer_id, address).await
111 }
112 TransactionCounterRepositoryStorage::Redis(counter) => {
113 counter.get_and_increment(relayer_id, address).await
114 }
115 }
116 }
117
118 async fn decrement(&self, relayer_id: &str, address: &str) -> Result<u64, RepositoryError> {
119 match self {
120 TransactionCounterRepositoryStorage::InMemory(counter) => {
121 counter.decrement(relayer_id, address).await
122 }
123 TransactionCounterRepositoryStorage::Redis(counter) => {
124 counter.decrement(relayer_id, address).await
125 }
126 }
127 }
128
129 async fn set(
130 &self,
131 relayer_id: &str,
132 address: &str,
133 value: u64,
134 ) -> Result<(), RepositoryError> {
135 match self {
136 TransactionCounterRepositoryStorage::InMemory(counter) => {
137 counter.set(relayer_id, address, value).await
138 }
139 TransactionCounterRepositoryStorage::Redis(counter) => {
140 counter.set(relayer_id, address, value).await
141 }
142 }
143 }
144
145 async fn drop_all_entries(&self) -> Result<(), RepositoryError> {
146 match self {
147 TransactionCounterRepositoryStorage::InMemory(counter) => {
148 counter.drop_all_entries().await
149 }
150 TransactionCounterRepositoryStorage::Redis(counter) => counter.drop_all_entries().await,
151 }
152 }
153}
154
#[cfg(test)]
mod tests {

    use super::*;

    /// `new_in_memory` must yield the `InMemory` variant.
    #[tokio::test]
    async fn test_in_memory_repository_creation() {
        let repo = TransactionCounterRepositoryStorage::new_in_memory();

        // `matches!` only evaluates to a bool; without `assert!` the result was
        // discarded and this test could never fail.
        assert!(matches!(
            repo,
            TransactionCounterRepositoryStorage::InMemory(_)
        ));
    }

    /// Exercises every delegated operation through the enum wrapper.
    #[tokio::test]
    async fn test_enum_wrapper_delegation() {
        let repo = TransactionCounterRepositoryStorage::new_in_memory();

        // Nothing stored yet for this pair.
        let result = repo.get("test_relayer", "0x1234").await.unwrap();
        assert_eq!(result, None);

        repo.set("test_relayer", "0x1234", 100).await.unwrap();
        let result = repo.get("test_relayer", "0x1234").await.unwrap();
        assert_eq!(result, Some(100));

        // Returns the pre-increment value ...
        let current = repo
            .get_and_increment("test_relayer", "0x1234")
            .await
            .unwrap();
        assert_eq!(current, 100);

        // ... and leaves the incremented value stored.
        let result = repo.get("test_relayer", "0x1234").await.unwrap();
        assert_eq!(result, Some(101));

        // Decrement returns the new value and persists it.
        let new_value = repo.decrement("test_relayer", "0x1234").await.unwrap();
        assert_eq!(new_value, 100);
        let result = repo.get("test_relayer", "0x1234").await.unwrap();
        assert_eq!(result, Some(100));
    }

    /// `drop_all_entries` clears counters for every relayer/address pair.
    #[tokio::test]
    async fn test_enum_wrapper_drop_all_entries() {
        let repo = TransactionCounterRepositoryStorage::new_in_memory();

        repo.set("relayer_1", "0x1234", 100).await.unwrap();
        repo.set("relayer_2", "0x5678", 200).await.unwrap();

        repo.drop_all_entries().await.unwrap();

        assert_eq!(repo.get("relayer_1", "0x1234").await.unwrap(), None);
        assert_eq!(repo.get("relayer_2", "0x5678").await.unwrap(), None);
    }
}