// openzeppelin_relayer/utils/polling.rs

use std::future::Future;
use std::time::Duration;

use color_eyre::Result;
use tracing::{debug, warn};

7/// Polls until a condition is met or timeout is reached.
8///
9/// This helper provides a reusable abstraction for waiting on async conditions
10/// with configurable timeout and polling interval.
11///
12/// # Arguments
13/// * `check` - Closure that returns `Ok(true)` when condition is met, `Ok(false)` to continue polling
14/// * `max_wait` - Maximum time to wait before giving up
15/// * `poll_interval` - Time to sleep between polls
16/// * `operation_name` - Name of the operation for logging
17///
18/// # Returns
19/// * `Ok(true)` - Condition was met within timeout
20/// * `Ok(false)` - Timeout reached without condition being met (errors are logged and polling continues)
21pub async fn poll_until<F, Fut>(
22    check: F,
23    max_wait: Duration,
24    poll_interval: Duration,
25    operation_name: &str,
26) -> Result<bool>
27where
28    F: Fn() -> Fut,
29    Fut: Future<Output = Result<bool>>,
30{
31    let start = std::time::Instant::now();
32
33    loop {
34        match check().await {
35            Ok(true) => {
36                debug!("{} completed", operation_name);
37                return Ok(true);
38            }
39            Ok(false) => {}
40            Err(e) => {
41                warn!(error = %e, "Error checking {} status while waiting", operation_name);
42            }
43        }
44
45        if start.elapsed() > max_wait {
46            warn!(
47                "Timed out waiting for {} to complete, proceeding anyway",
48                operation_name
49            );
50            return Ok(false);
51        }
52
53        tokio::time::sleep(poll_interval).await;
54    }
55}
56
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;

    #[tokio::test]
    async fn test_poll_until_condition_met_immediately() {
        let result = poll_until(
            || async { Ok(true) },
            Duration::from_millis(100),
            Duration::from_millis(10),
            "immediate_test",
        )
        .await;

        assert!(matches!(result, Ok(true)));
    }

    #[tokio::test]
    async fn test_poll_until_condition_met_after_multiple_polls() {
        let poll_count = Arc::new(AtomicU32::new(0));
        let poll_count_clone = Arc::clone(&poll_count);

        let result = poll_until(
            move || {
                let count = poll_count_clone.fetch_add(1, Ordering::SeqCst);
                async move {
                    // Return true on the 3rd poll (count == 2)
                    Ok(count >= 2)
                }
            },
            Duration::from_secs(1),
            Duration::from_millis(10),
            "delayed_condition_test",
        )
        .await;

        assert!(matches!(result, Ok(true)));
        // Polling stops as soon as the condition first holds (3rd poll).
        assert_eq!(poll_count.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_poll_until_timeout_reached() {
        let start = std::time::Instant::now();

        let result = poll_until(
            || async { Ok(false) },
            Duration::from_millis(50),
            Duration::from_millis(10),
            "timeout_test",
        )
        .await;

        assert!(matches!(result, Ok(false)));
        // The full time budget must have been spent before giving up.
        assert!(start.elapsed() >= Duration::from_millis(50));
    }

    #[tokio::test]
    async fn test_poll_until_zero_max_wait_still_checks_once() {
        let result = poll_until(
            || async { Ok(true) },
            Duration::ZERO,
            Duration::from_millis(10),
            "zero_wait_test",
        )
        .await;

        // The condition is evaluated before the timeout check, so an already-satisfied
        // condition wins even with no time budget at all.
        assert!(matches!(result, Ok(true)));
    }

    #[tokio::test]
    async fn test_poll_until_continues_polling_after_errors() {
        let poll_count = Arc::new(AtomicU32::new(0));
        let poll_count_clone = Arc::clone(&poll_count);

        let result = poll_until(
            move || {
                let count = poll_count_clone.fetch_add(1, Ordering::SeqCst);
                async move {
                    if count < 2 {
                        // Return error on first two polls
                        Err(color_eyre::eyre::eyre!("temporary error"))
                    } else {
                        // Return success on 3rd poll
                        Ok(true)
                    }
                }
            },
            Duration::from_secs(1),
            Duration::from_millis(10),
            "error_recovery_test",
        )
        .await;

        assert!(matches!(result, Ok(true)));
        // Two failed polls followed by the successful one.
        assert_eq!(poll_count.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_poll_until_timeout_after_persistent_errors() {
        let poll_count = Arc::new(AtomicU32::new(0));
        let poll_count_clone = Arc::clone(&poll_count);

        let result = poll_until(
            move || {
                poll_count_clone.fetch_add(1, Ordering::SeqCst);
                async { Err(color_eyre::eyre::eyre!("persistent error")) }
            },
            Duration::from_millis(50),
            Duration::from_millis(10),
            "persistent_error_test",
        )
        .await;

        // Should time out (return Ok(false)) since errors don't stop polling
        assert!(matches!(result, Ok(false)));
        // Should have polled multiple times
        assert!(poll_count.load(Ordering::SeqCst) >= 2);
    }

    #[tokio::test]
    async fn test_poll_until_respects_poll_interval() {
        let start = std::time::Instant::now();
        let poll_count = Arc::new(AtomicU32::new(0));
        let poll_count_clone = Arc::clone(&poll_count);

        let result = poll_until(
            move || {
                let count = poll_count_clone.fetch_add(1, Ordering::SeqCst);
                async move { Ok(count >= 3) }
            },
            Duration::from_secs(1),
            Duration::from_millis(50),
            "interval_test",
        )
        .await;

        let elapsed = start.elapsed();

        assert!(matches!(result, Ok(true)));
        assert_eq!(poll_count.load(Ordering::SeqCst), 4);
        // 4 polls (counts 0..=3) are separated by 3 sleeps of 50ms each, and a tokio
        // sleep never completes early, so at least 150ms must have elapsed.
        assert!(elapsed >= Duration::from_millis(150));
    }
}