1use std::sync::atomic::{AtomicUsize, Ordering};
9use std::sync::Arc;
10use std::time::Duration;
11use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
12use tokio::net::UnixStream;
13use tokio::sync::Semaphore;
14
15use super::config::get_config;
16use super::protocol::{PoolRequest, PoolResponse};
17use super::PluginError;
18
19pub struct PoolConnection {
21 stream: UnixStream,
22 id: usize,
24}
25
26impl PoolConnection {
27 pub async fn new(socket_path: &str, id: usize) -> Result<Self, PluginError> {
28 let max_attempts = get_config().pool_connect_retries;
29 let mut attempts = 0;
30 let mut delay_ms = 10u64;
31
32 tracing::debug!(connection_id = id, socket_path = %socket_path, "Connecting to pool server");
33
34 loop {
35 match UnixStream::connect(socket_path).await {
36 Ok(stream) => {
37 if attempts > 0 {
38 tracing::debug!(
39 connection_id = id,
40 attempts = attempts,
41 "Connected to pool server after retries"
42 );
43 }
44 return Ok(Self { stream, id });
45 }
46 Err(e) => {
47 attempts += 1;
48
49 if attempts >= max_attempts {
50 return Err(PluginError::SocketError(format!(
51 "Failed to connect to pool after {max_attempts} attempts: {e}. \
52 Consider increasing PLUGIN_POOL_CONNECT_RETRIES or PLUGIN_POOL_MAX_CONNECTIONS."
53 )));
54 }
55
56 if attempts <= 3 || attempts % 5 == 0 {
57 tracing::debug!(
58 connection_id = id,
59 attempt = attempts,
60 max_attempts = max_attempts,
61 delay_ms = delay_ms,
62 "Retrying connection to pool server"
63 );
64 }
65
66 tokio::time::sleep(Duration::from_millis(delay_ms)).await;
67 delay_ms = std::cmp::min(delay_ms * 2, 1000);
68 }
69 }
70 }
71 }
72
73 pub async fn send_request(
74 &mut self,
75 request: &PoolRequest,
76 ) -> Result<PoolResponse, PluginError> {
77 let request_task_id = Self::extract_task_id(request);
79
80 let json = serde_json::to_string(request)
81 .map_err(|e| PluginError::PluginError(format!("Failed to serialize request: {e}")))?;
82
83 if let Err(e) = self.stream.write_all(format!("{json}\n").as_bytes()).await {
84 return Err(PluginError::SocketError(format!(
85 "Failed to send request: {e}"
86 )));
87 }
88
89 if let Err(e) = self.stream.flush().await {
90 return Err(PluginError::SocketError(format!(
91 "Failed to flush request: {e}"
92 )));
93 }
94
95 let mut reader = BufReader::new(&mut self.stream);
96 let mut line = String::new();
97
98 if let Err(e) = reader.read_line(&mut line).await {
99 return Err(PluginError::SocketError(format!(
100 "Failed to read response: {e}"
101 )));
102 }
103
104 tracing::debug!(response_len = line.len(), "Received response from pool");
105
106 let response: PoolResponse = serde_json::from_str(&line)
107 .map_err(|e| PluginError::PluginError(format!("Failed to parse response: {e}")))?;
108
109 if response.task_id != request_task_id {
111 tracing::error!(
112 request_task_id = %request_task_id,
113 response_task_id = %response.task_id,
114 connection_id = self.id,
115 "Response task_id mismatch"
116 );
117 return Err(PluginError::PluginError(
118 "Internal plugin error: response task_id mismatch".to_string(),
119 ));
120 }
121
122 Ok(response)
123 }
124
125 fn extract_task_id(request: &PoolRequest) -> String {
127 match request {
128 PoolRequest::Execute(req) => req.task_id.clone(),
129 PoolRequest::Precompile { task_id, .. } => task_id.clone(),
130 PoolRequest::Cache { task_id, .. } => task_id.clone(),
131 PoolRequest::Invalidate { task_id, .. } => task_id.clone(),
132 PoolRequest::Stats { task_id } => task_id.clone(),
133 PoolRequest::Health { task_id } => task_id.clone(),
134 PoolRequest::Shutdown { task_id } => task_id.clone(),
135 }
136 }
137
138 pub async fn send_request_with_timeout(
139 &mut self,
140 request: &PoolRequest,
141 timeout_secs: u64,
142 ) -> Result<PoolResponse, PluginError> {
143 tokio::time::timeout(
144 Duration::from_secs(timeout_secs),
145 self.send_request(request),
146 )
147 .await
148 .map_err(|_| {
149 PluginError::SocketError(format!("Request timed out after {timeout_secs} seconds"))
150 })?
151 }
152
153 pub fn id(&self) -> usize {
155 self.id
156 }
157}
158
159pub struct ConnectionPool {
164 socket_path: String,
165 #[allow(dead_code)]
167 max_connections: usize,
168 next_id: Arc<AtomicUsize>,
170 pub semaphore: Arc<Semaphore>,
172}
173
174impl ConnectionPool {
175 pub fn new(socket_path: String, max_connections: usize) -> Self {
176 Self {
177 socket_path,
178 max_connections,
179 next_id: Arc::new(AtomicUsize::new(0)),
180 semaphore: Arc::new(Semaphore::new(max_connections)),
181 }
182 }
183
184 pub async fn acquire_with_permit(
188 &self,
189 permit: Option<tokio::sync::OwnedSemaphorePermit>,
190 ) -> Result<PooledConnection<'_>, PluginError> {
191 let permit = match permit {
192 Some(p) => p,
193 None => {
194 let available_permits = self.semaphore.available_permits();
195 if available_permits == 0 {
196 tracing::warn!(
197 max_connections = self.max_connections,
198 "All connection permits exhausted - waiting for connection"
199 );
200 }
201 self.semaphore.clone().acquire_owned().await.map_err(|_| {
202 PluginError::PluginError("Connection semaphore closed".to_string())
203 })?
204 }
205 };
206
207 let id = self.next_id.fetch_add(1, Ordering::Relaxed);
208 tracing::debug!(connection_id = id, "Creating connection");
209
210 let conn = PoolConnection::new(&self.socket_path, id).await?;
211
212 Ok(PooledConnection {
213 conn: Some(conn),
214 pool: self,
215 _permit: permit,
216 })
217 }
218
219 pub async fn acquire(&self) -> Result<PooledConnection<'_>, PluginError> {
221 self.acquire_with_permit(None).await
222 }
223
224 pub fn release(&self, conn: PoolConnection) {
226 let conn_id = conn.id();
227 tracing::debug!(connection_id = conn_id, "Connection closed");
228 drop(conn);
229 }
230
231 pub fn next_connection_id(&self) -> usize {
234 self.next_id.fetch_add(1, Ordering::Relaxed)
235 }
236}
237
238pub struct PooledConnection<'a> {
240 conn: Option<PoolConnection>,
241 pool: &'a ConnectionPool,
242 _permit: tokio::sync::OwnedSemaphorePermit,
244}
245
246impl<'a> PooledConnection<'a> {
247 pub async fn send_request_with_timeout(
248 &mut self,
249 request: &PoolRequest,
250 timeout_secs: u64,
251 ) -> Result<PoolResponse, PluginError> {
252 if let Some(ref mut conn) = self.conn {
253 conn.send_request_with_timeout(request, timeout_secs).await
254 } else {
255 Err(PluginError::PluginError(
256 "Connection already released".to_string(),
257 ))
258 }
259 }
260}
261
262impl<'a> Drop for PooledConnection<'a> {
263 fn drop(&mut self) {
264 if let Some(conn) = self.conn.take() {
265 self.pool.release(conn);
266 }
267 }
268}
269
#[cfg(test)]
mod tests {
    use super::*;
    use crate::services::plugins::protocol::ExecuteRequest;

    #[test]
    fn test_connection_pool_creation() {
        let pool = ConnectionPool::new("/tmp/test.sock".to_string(), 10);
        assert_eq!(pool.semaphore.available_permits(), 10);
    }

    #[test]
    fn test_connection_pool_creation_single_connection() {
        let pool = ConnectionPool::new("/tmp/single.sock".to_string(), 1);
        assert_eq!(pool.semaphore.available_permits(), 1);
    }

    #[test]
    fn test_connection_pool_creation_large_pool() {
        let pool = ConnectionPool::new("/tmp/large.sock".to_string(), 1000);
        assert_eq!(pool.semaphore.available_permits(), 1000);
    }

    #[test]
    fn test_connection_pool_stores_socket_path() {
        let path = "/var/run/custom.sock";
        let pool = ConnectionPool::new(path.to_string(), 5);
        assert_eq!(pool.socket_path, path);
    }

    #[test]
    fn test_connection_pool_stores_max_connections() {
        let pool = ConnectionPool::new("/tmp/test.sock".to_string(), 42);
        assert_eq!(pool.max_connections, 42);
    }

    #[tokio::test]
    async fn test_connection_pool_semaphore_limits() {
        let pool = ConnectionPool::new("/tmp/test.sock".to_string(), 2);

        let permit1 = pool.semaphore.clone().try_acquire_owned();
        assert!(permit1.is_ok());

        let permit2 = pool.semaphore.clone().try_acquire_owned();
        assert!(permit2.is_ok());

        // Both permits are still held, so a third must be refused.
        let permit3 = pool.semaphore.clone().try_acquire_owned();
        assert!(permit3.is_err());
    }

    #[tokio::test]
    async fn test_semaphore_permit_release_restores_capacity() {
        let pool = ConnectionPool::new("/tmp/test.sock".to_string(), 2);

        let permit1 = pool.semaphore.clone().try_acquire_owned().unwrap();
        let permit2 = pool.semaphore.clone().try_acquire_owned().unwrap();
        assert_eq!(pool.semaphore.available_permits(), 0);

        drop(permit1);
        assert_eq!(pool.semaphore.available_permits(), 1);

        let permit3 = pool.semaphore.clone().try_acquire_owned();
        assert!(permit3.is_ok());

        drop(permit2);
        drop(permit3.unwrap());
        assert_eq!(pool.semaphore.available_permits(), 2);
    }

    #[tokio::test]
    async fn test_semaphore_async_acquire() {
        let pool = ConnectionPool::new("/tmp/test.sock".to_string(), 1);

        let permit = pool.semaphore.clone().acquire_owned().await;
        assert!(permit.is_ok());
        let _permit = permit.unwrap();

        assert_eq!(pool.semaphore.available_permits(), 0);
    }

    #[test]
    fn test_connection_id_increment() {
        let pool = ConnectionPool::new("/tmp/test.sock".to_string(), 10);
        assert_eq!(pool.next_connection_id(), 0);
        assert_eq!(pool.next_connection_id(), 1);
        assert_eq!(pool.next_connection_id(), 2);
    }

    #[test]
    fn test_connection_id_starts_at_zero() {
        let pool = ConnectionPool::new("/tmp/test.sock".to_string(), 10);
        assert_eq!(pool.next_connection_id(), 0);
    }

    #[test]
    fn test_connection_id_monotonically_increasing() {
        let pool = ConnectionPool::new("/tmp/test.sock".to_string(), 10);

        let mut last_id = pool.next_connection_id();
        for _ in 0..100 {
            let current_id = pool.next_connection_id();
            assert!(
                current_id > last_id,
                "IDs should be monotonically increasing"
            );
            last_id = current_id;
        }
    }

    #[test]
    fn test_connection_id_thread_safe() {
        use std::collections::HashSet;
        use std::thread;

        const THREADS: usize = 10;
        const IDS_PER_THREAD: usize = 100;

        let pool = Arc::new(ConnectionPool::new("/tmp/test.sock".to_string(), 100));

        let handles: Vec<_> = (0..THREADS)
            .map(|_| {
                let pool = Arc::clone(&pool);
                thread::spawn(move || {
                    (0..IDS_PER_THREAD)
                        .map(|_| pool.next_connection_id())
                        .collect::<Vec<_>>()
                })
            })
            .collect();

        let all_ids: Vec<usize> = handles
            .into_iter()
            .flat_map(|h| h.join().unwrap())
            .collect();
        assert_eq!(all_ids.len(), THREADS * IDS_PER_THREAD);

        let unique: HashSet<usize> = all_ids.iter().copied().collect();
        assert_eq!(unique.len(), all_ids.len(), "All IDs should be unique");
    }

    #[test]
    fn test_extract_task_id_from_execute_request() {
        let request = PoolRequest::Execute(Box::new(ExecuteRequest {
            task_id: "execute-task-123".to_string(),
            plugin_id: "test-plugin".to_string(),
            compiled_code: None,
            plugin_path: None,
            params: serde_json::json!({}),
            headers: None,
            socket_path: "/tmp/test.sock".to_string(),
            http_request_id: None,
            timeout: Some(30000),
            route: None,
            config: None,
            method: None,
            query: None,
        }));

        let task_id = PoolConnection::extract_task_id(&request);
        assert_eq!(task_id, "execute-task-123");
    }

    #[test]
    fn test_extract_task_id_from_precompile_request() {
        let request = PoolRequest::Precompile {
            task_id: "precompile-task-456".to_string(),
            plugin_id: "test-plugin".to_string(),
            plugin_path: Some("/path/to/plugin.ts".to_string()),
            source_code: None,
        };

        let task_id = PoolConnection::extract_task_id(&request);
        assert_eq!(task_id, "precompile-task-456");
    }

    #[test]
    fn test_extract_task_id_from_cache_request() {
        let request = PoolRequest::Cache {
            task_id: "cache-task-789".to_string(),
            plugin_id: "test-plugin".to_string(),
            compiled_code: "compiled code".to_string(),
        };

        let task_id = PoolConnection::extract_task_id(&request);
        assert_eq!(task_id, "cache-task-789");
    }

    #[test]
    fn test_extract_task_id_from_invalidate_request() {
        let request = PoolRequest::Invalidate {
            task_id: "invalidate-task-abc".to_string(),
            plugin_id: "test-plugin".to_string(),
        };

        let task_id = PoolConnection::extract_task_id(&request);
        assert_eq!(task_id, "invalidate-task-abc");
    }

    #[test]
    fn test_extract_task_id_from_stats_request() {
        let request = PoolRequest::Stats {
            task_id: "stats-task-def".to_string(),
        };

        let task_id = PoolConnection::extract_task_id(&request);
        assert_eq!(task_id, "stats-task-def");
    }

    #[test]
    fn test_extract_task_id_from_health_request() {
        let request = PoolRequest::Health {
            task_id: "health-task-ghi".to_string(),
        };

        let task_id = PoolConnection::extract_task_id(&request);
        assert_eq!(task_id, "health-task-ghi");
    }

    #[test]
    fn test_extract_task_id_from_shutdown_request() {
        let request = PoolRequest::Shutdown {
            task_id: "shutdown-task-jkl".to_string(),
        };

        let task_id = PoolConnection::extract_task_id(&request);
        assert_eq!(task_id, "shutdown-task-jkl");
    }

    #[test]
    fn test_extract_task_id_preserves_special_characters() {
        let request = PoolRequest::Stats {
            task_id: "task-with-special_chars.and/slashes:colons".to_string(),
        };

        let task_id = PoolConnection::extract_task_id(&request);
        assert_eq!(task_id, "task-with-special_chars.and/slashes:colons");
    }

    #[test]
    fn test_extract_task_id_handles_empty_string() {
        let request = PoolRequest::Health {
            task_id: "".to_string(),
        };

        let task_id = PoolConnection::extract_task_id(&request);
        assert_eq!(task_id, "");
    }

    #[test]
    fn test_extract_task_id_handles_uuid_format() {
        let uuid = "550e8400-e29b-41d4-a716-446655440000";
        let request = PoolRequest::Stats {
            task_id: uuid.to_string(),
        };

        let task_id = PoolConnection::extract_task_id(&request);
        assert_eq!(task_id, uuid);
    }

    #[tokio::test]
    async fn test_acquire_without_server_fails() {
        let pool = ConnectionPool::new("/tmp/nonexistent_socket_12345.sock".to_string(), 10);

        let result = pool.acquire().await;
        assert!(result.is_err());

        match result {
            Err(PluginError::SocketError(msg)) => {
                assert!(msg.contains("Failed to connect"));
            }
            _ => panic!("Expected SocketError"),
        }
    }

    #[tokio::test]
    async fn test_acquire_with_pre_acquired_permit() {
        let pool = ConnectionPool::new("/tmp/nonexistent_socket_67890.sock".to_string(), 10);

        let permit = pool.semaphore.clone().acquire_owned().await.unwrap();
        assert_eq!(pool.semaphore.available_permits(), 9);

        let result = pool.acquire_with_permit(Some(permit)).await;
        assert!(result.is_err());
        drop(result);

        // A failed connect must not leak the caller-supplied permit.
        assert_eq!(pool.semaphore.available_permits(), 10);
    }

    #[tokio::test]
    async fn test_pooled_connection_cannot_be_used_after_release() {
        let pool = ConnectionPool::new("/tmp/released_conn.sock".to_string(), 1);
        let permit = pool.semaphore.clone().acquire_owned().await.unwrap();

        // Build a guard whose connection has already been taken.
        let mut pooled = PooledConnection {
            conn: None,
            pool: &pool,
            _permit: permit,
        };

        let request = PoolRequest::Health {
            task_id: "released-task".to_string(),
        };

        match pooled.send_request_with_timeout(&request, 1).await {
            Err(PluginError::PluginError(msg)) => {
                assert_eq!(msg, "Connection already released");
            }
            _ => panic!("Expected PluginError for a released connection"),
        }

        // Dropping the guard must still return its permit.
        drop(pooled);
        assert_eq!(pool.semaphore.available_permits(), 1);
    }

    #[tokio::test]
    async fn test_acquire_error_message_contains_helpful_info() {
        let pool = ConnectionPool::new("/tmp/no_server_here_xyz.sock".to_string(), 10);

        match pool.acquire().await {
            Err(PluginError::SocketError(msg)) => {
                assert!(
                    msg.contains("PLUGIN_POOL_CONNECT_RETRIES")
                        && msg.contains("PLUGIN_POOL_MAX_CONNECTIONS"),
                    "Error message should point at the relevant settings: {msg}"
                );
            }
            _ => panic!("Expected SocketError"),
        }
    }

    #[test]
    fn test_multiple_pools_independent() {
        let pool1 = ConnectionPool::new("/tmp/pool1.sock".to_string(), 5);
        let pool2 = ConnectionPool::new("/tmp/pool2.sock".to_string(), 10);

        assert_eq!(pool1.semaphore.available_permits(), 5);
        assert_eq!(pool2.semaphore.available_permits(), 10);

        // Each pool has its own id counter.
        assert_eq!(pool1.next_connection_id(), 0);
        assert_eq!(pool2.next_connection_id(), 0);
        assert_eq!(pool1.next_connection_id(), 1);
        assert_eq!(pool2.next_connection_id(), 1);
    }

    #[tokio::test]
    async fn test_concurrent_semaphore_acquire() {
        let pool = Arc::new(ConnectionPool::new("/tmp/concurrent.sock".to_string(), 3));

        let mut handles = vec![];

        for i in 0..3 {
            let pool_clone = pool.clone();
            handles.push(tokio::spawn(async move {
                let permit = pool_clone.semaphore.clone().acquire_owned().await;
                assert!(permit.is_ok(), "Task {i} should acquire permit");
                tokio::time::sleep(Duration::from_millis(10)).await;
            }));
        }

        for handle in handles {
            handle.await.unwrap();
        }

        assert_eq!(pool.semaphore.available_permits(), 3);
    }

    #[tokio::test]
    async fn test_semaphore_fairness() {
        use std::sync::atomic::AtomicU32;

        let pool = Arc::new(ConnectionPool::new("/tmp/fairness.sock".to_string(), 1));
        let counter = Arc::new(AtomicU32::new(0));

        let permit = pool.semaphore.clone().acquire_owned().await.unwrap();

        let mut handles = vec![];

        for _ in 0..3 {
            let pool_clone = pool.clone();
            let counter_clone = counter.clone();
            handles.push(tokio::spawn(async move {
                let _permit = pool_clone.semaphore.clone().acquire_owned().await.unwrap();
                counter_clone.fetch_add(1, Ordering::SeqCst);
            }));
        }

        tokio::time::sleep(Duration::from_millis(50)).await;

        // No waiter may proceed while the single permit is held.
        assert_eq!(counter.load(Ordering::SeqCst), 0);

        drop(permit);

        for handle in handles {
            handle.await.unwrap();
        }

        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    #[test]
    fn test_zero_max_connections_has_no_permits() {
        let pool = ConnectionPool::new("/tmp/zero.sock".to_string(), 0);
        assert_eq!(pool.semaphore.available_permits(), 0);

        // The semaphore is open but empty, so a non-blocking acquire fails.
        let permit = pool.semaphore.clone().try_acquire_owned();
        assert!(permit.is_err());
    }

    #[test]
    fn test_socket_path_with_spaces() {
        let path = "/tmp/path with spaces/test.sock";
        let pool = ConnectionPool::new(path.to_string(), 5);
        assert_eq!(pool.socket_path, path);
    }

    #[test]
    fn test_socket_path_with_unicode() {
        let path = "/tmp/тест/套接字.sock";
        let pool = ConnectionPool::new(path.to_string(), 5);
        assert_eq!(pool.socket_path, path);
    }

    #[test]
    fn test_very_long_socket_path() {
        let path = format!("/tmp/{}/test.sock", "a".repeat(200));
        let pool = ConnectionPool::new(path.clone(), 5);
        assert_eq!(pool.socket_path, path);
    }
}