1use alloy::primitives::keccak256;
34use async_trait::async_trait;
35use aws_config::{meta::region::RegionProviderChain, BehaviorVersion, Region};
36use aws_sdk_kms::{
37 primitives::Blob,
38 types::{MessageType, SigningAlgorithmSpec},
39 Client,
40};
41use once_cell::sync::Lazy;
42use serde::Serialize;
43use std::{collections::HashMap, sync::Arc};
44use tokio::sync::RwLock;
45
46use crate::{
47 models::{Address, AwsKmsSignerConfig},
48 services::{
49 client_cache::AsyncClientCache, signer::evm::utils::recover_evm_signature_from_der,
50 },
51 utils::{
52 self, derive_ethereum_address_from_der, derive_solana_address_from_der,
53 derive_stellar_address_from_der,
54 },
55};
56use tracing::debug;
57
58#[cfg(test)]
59use mockall::{automock, mock};
60
/// Errors produced by the AWS KMS signer service.
///
/// Derives `Clone` (needed e.g. by mock `return_const`) and `Serialize`.
#[derive(Clone, Debug, thiserror::Error, Serialize)]
pub enum AwsKmsError {
    /// A value derived from KMS output could not be parsed or converted
    /// (address derivation from DER, EVM signature recovery).
    #[error("AWS KMS response parse error: {0}")]
    ParseError(String),
    /// Client configuration problem: no resolvable region, or the SDK client
    /// failed to initialize.
    #[error("AWS KMS config error: {0}")]
    ConfigError(String),
    /// Fetching a public key from KMS failed or returned no key material.
    #[error("AWS KMS get error: {0}")]
    GetError(String),
    /// A signing request failed, returned no signature, or returned a
    /// signature of unexpected length.
    #[error("AWS KMS signing error: {0}")]
    SignError(String),
    /// Used by the secp256k1 `sign_digest` path for any failed `Sign` call.
    /// NOTE(review): every SDK error is mapped here, so this is not necessarily
    /// a permissions problem.
    #[error("AWS KMS permissions error: {0}")]
    PermissionError(String),
    /// Wraps `utils::Secp256k1Error`, converted automatically via `#[from]`.
    /// NOTE(review): the display text says "public key error" while the variant
    /// is named `RecoveryError` — confirm which wording is intended.
    #[error("AWS KMS public key error: {0}")]
    RecoveryError(#[from] utils::Secp256k1Error),
    /// Not constructed in this file; presumably used by callers.
    #[error("AWS KMS conversion error: {0}")]
    ConvertError(String),
    /// Catch-all variant; not constructed in this file.
    #[error("AWS KMS Other error: {0}")]
    Other(String),
}
80
/// Result alias used by every AWS KMS service operation.
pub type AwsKmsResult<T> = Result<T, AwsKmsError>;
82
/// EVM-facing operations backed by an AWS KMS secp256k1 key.
#[async_trait]
#[cfg_attr(test, automock)]
pub trait AwsKmsEvmService: Send + Sync {
    /// Returns the EVM address derived from the KMS key's public key.
    async fn get_evm_address(&self) -> AwsKmsResult<Address>;
    /// Hashes `payload` and signs the resulting digest (keccak256 in
    /// `AwsKmsService`). The returned byte layout is whatever
    /// `recover_evm_signature_from_der` produces in that implementation.
    async fn sign_payload_evm(&self, payload: &[u8]) -> AwsKmsResult<Vec<u8>>;

    /// Signs an already-computed 32-byte hash without hashing it again.
    async fn sign_hash_evm(&self, hash: &[u8; 32]) -> AwsKmsResult<Vec<u8>>;
}
110
/// Low-level secp256k1 operations against KMS, keyed by KMS key id.
///
/// The explicit lifetimes must stay in sync with the `mock!` declaration below.
#[async_trait]
#[cfg_attr(test, automock)]
pub trait AwsKmsK256: Send + Sync {
    /// Returns the DER-encoded public key for `key_id`.
    async fn get_der_public_key<'a, 'b>(&'a self, key_id: &'b str) -> AwsKmsResult<Vec<u8>>;
    /// Signs a precomputed 32-byte `digest` and returns the DER-encoded
    /// ECDSA signature.
    async fn sign_digest<'a, 'b>(
        &'a self,
        key_id: &'b str,
        digest: [u8; 32],
    ) -> AwsKmsResult<Vec<u8>>;
}
123
/// Low-level Ed25519 operations against KMS, keyed by KMS key id.
///
/// The explicit lifetimes must stay in sync with the `mock!` declaration below.
#[async_trait]
#[cfg_attr(test, automock)]
pub trait AwsKmsEd25519: Send + Sync {
    /// Returns the DER-encoded Ed25519 public key for `key_id`.
    async fn get_ed25519_public_key<'a, 'b>(&'a self, key_id: &'b str) -> AwsKmsResult<Vec<u8>>;
    /// Signs the raw `message` bytes; the `AwsKmsClient` implementation
    /// rejects any signature that is not exactly 64 bytes.
    async fn sign_ed25519<'a, 'b>(
        &'a self,
        key_id: &'b str,
        message: &'b [u8],
    ) -> AwsKmsResult<Vec<u8>>;
}
139
/// Solana-facing operations backed by an AWS KMS Ed25519 key.
#[async_trait]
#[cfg_attr(test, automock)]
pub trait AwsKmsSolanaService: Send + Sync {
    /// Returns the Solana address derived from the key's Ed25519 public key.
    async fn get_solana_address(&self) -> AwsKmsResult<Address>;
    /// Signs `message` with the Ed25519 key.
    async fn sign_solana(&self, message: &[u8]) -> AwsKmsResult<Vec<u8>>;
}
149
/// Stellar-facing operations backed by an AWS KMS Ed25519 key.
#[async_trait]
#[cfg_attr(test, automock)]
pub trait AwsKmsStellarService: Send + Sync {
    /// Returns the Stellar address derived from the key's Ed25519 public key.
    async fn get_stellar_address(&self) -> AwsKmsResult<Address>;
    /// Signs `message` with the Ed25519 key.
    async fn sign_stellar(&self, message: &[u8]) -> AwsKmsResult<Vec<u8>>;
}
159
// Test double for `AwsKmsClient`: mockall generates `MockAwsKmsClient`, which
// implements `Clone` plus both low-level KMS traits, so
// `AwsKmsService<MockAwsKmsClient>` can be exercised without AWS.
// The method signatures below must stay identical to the trait definitions.
#[cfg(test)]
mock! {
    pub AwsKmsClient { }
    impl Clone for AwsKmsClient {
        fn clone(&self) -> Self;
    }

    #[async_trait]
    impl AwsKmsK256 for AwsKmsClient {
        async fn get_der_public_key<'a, 'b>(&'a self, key_id: &'b str) -> AwsKmsResult<Vec<u8>>;
        async fn sign_digest<'a, 'b>(
            &'a self,
            key_id: &'b str,
            digest: [u8; 32],
        ) -> AwsKmsResult<Vec<u8>>;
    }

    #[async_trait]
    impl AwsKmsEd25519 for AwsKmsClient {
        async fn get_ed25519_public_key<'a, 'b>(&'a self, key_id: &'b str) -> AwsKmsResult<Vec<u8>>;
        async fn sign_ed25519<'a, 'b>(
            &'a self,
            key_id: &'b str,
            message: &'b [u8],
        ) -> AwsKmsResult<Vec<u8>>;
    }
}
187
188static KMS_DER_PK_CACHE: Lazy<RwLock<HashMap<String, Vec<u8>>>> =
190 Lazy::new(|| RwLock::new(HashMap::new()));
191
192static KMS_ED25519_PK_CACHE: Lazy<RwLock<HashMap<String, Vec<u8>>>> =
194 Lazy::new(|| RwLock::new(HashMap::new()));
195
/// Cache key for shared KMS SDK clients. Only the region is part of the key, so
/// every signer resolved to the same region reuses one client regardless of its
/// KMS key id.
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
struct AwsKmsClientKey {
    region: String,
}

/// Process-wide cache of AWS KMS SDK clients, populated lazily by
/// `get_or_create_kms_client`.
static KMS_CLIENT_CACHE: Lazy<AsyncClientCache<AwsKmsClientKey, Client>> =
    Lazy::new(AsyncClientCache::new);
203
204async fn get_or_create_kms_client(config: &AwsKmsSignerConfig) -> AwsKmsResult<Arc<Client>> {
207 let resolved_region = resolve_aws_region(config).await?;
208 let key = AwsKmsClientKey {
209 region: resolved_region.clone(),
210 };
211
212 KMS_CLIENT_CACHE
213 .get_or_try_init(key, || async {
214 debug!(
215 region = %resolved_region,
216 "Creating new AWS KMS client"
217 );
218 let auth_config = aws_config::defaults(BehaviorVersion::latest())
219 .region(Region::new(resolved_region))
220 .load()
221 .await;
222
223 std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| Client::new(&auth_config)))
226 .map_err(|panic| {
227 let msg = panic
228 .downcast_ref::<String>()
229 .map(|s| s.as_str())
230 .or_else(|| panic.downcast_ref::<&str>().copied())
231 .unwrap_or("unknown panic");
232 AwsKmsError::ConfigError(format!(
233 "Failed to initialize AWS KMS client (check TLS root certificates): {msg}"
234 ))
235 })
236 })
237 .await
238}
239
240async fn resolve_aws_region(config: &AwsKmsSignerConfig) -> AwsKmsResult<String> {
242 if let Some(region) = &config.region {
243 return Ok(region.clone());
244 }
245
246 let provider = RegionProviderChain::default_provider();
247 provider
248 .region()
249 .await
250 .map(|r| r.to_string())
251 .ok_or_else(|| {
252 AwsKmsError::ConfigError(
253 "AWS region not specified and could not be resolved from environment".to_string(),
254 )
255 })
256}
257
/// Production implementation of the low-level KMS traits.
///
/// Cloning is cheap: `inner` is an `Arc`. The SDK client it holds comes from
/// `get_or_create_kms_client`, which caches clients per region.
#[derive(Debug, Clone)]
pub struct AwsKmsClient {
    inner: Arc<Client>,
}
262
263#[async_trait]
264impl AwsKmsK256 for AwsKmsClient {
265 async fn get_der_public_key<'a, 'b>(&'a self, key_id: &'b str) -> AwsKmsResult<Vec<u8>> {
266 let cached = {
268 let cache_read = KMS_DER_PK_CACHE.read().await;
269 cache_read.get(key_id).cloned()
270 };
271 if let Some(cached) = cached {
272 return Ok(cached);
273 }
274
275 let get_output = self
277 .inner
278 .get_public_key()
279 .key_id(key_id)
280 .send()
281 .await
282 .map_err(|e| {
283 AwsKmsError::GetError(format!(
284 "Failed to get secp256k1 public key for key '{key_id}': {e:?}"
285 ))
286 })?;
287
288 let der_pk_blob = get_output
289 .public_key
290 .ok_or(AwsKmsError::GetError(
291 "No public key blob found".to_string(),
292 ))?
293 .into_inner();
294
295 let mut cache_write = KMS_DER_PK_CACHE.write().await;
296 cache_write.insert(key_id.to_string(), der_pk_blob.clone());
297
298 Ok(der_pk_blob)
299 }
300
301 async fn sign_digest<'a, 'b>(
302 &'a self,
303 key_id: &'b str,
304 digest: [u8; 32],
305 ) -> AwsKmsResult<Vec<u8>> {
306 let sign_result = self
308 .inner
309 .sign()
310 .key_id(key_id)
311 .signing_algorithm(SigningAlgorithmSpec::EcdsaSha256)
312 .message_type(MessageType::Digest)
313 .message(Blob::new(digest))
314 .send()
315 .await;
316
317 let der_signature = sign_result
319 .map_err(|e| AwsKmsError::PermissionError(e.to_string()))?
320 .signature
321 .ok_or(AwsKmsError::SignError(
322 "Signature not found in response".to_string(),
323 ))?
324 .into_inner();
325
326 Ok(der_signature)
327 }
328}
329
330#[async_trait]
331impl AwsKmsEd25519 for AwsKmsClient {
332 async fn get_ed25519_public_key<'a, 'b>(&'a self, key_id: &'b str) -> AwsKmsResult<Vec<u8>> {
333 let cached = {
335 let cache_read = KMS_ED25519_PK_CACHE.read().await;
336 cache_read.get(key_id).cloned()
337 };
338 if let Some(cached) = cached {
339 return Ok(cached);
340 }
341
342 let get_output = self
344 .inner
345 .get_public_key()
346 .key_id(key_id)
347 .send()
348 .await
349 .map_err(|e| {
350 AwsKmsError::GetError(format!(
351 "Failed to get Ed25519 public key for key '{key_id}': {e:?}"
352 ))
353 })?;
354
355 let der_pk_blob = get_output
356 .public_key
357 .ok_or(AwsKmsError::GetError(
358 "No public key blob found".to_string(),
359 ))?
360 .into_inner();
361
362 let mut cache_write = KMS_ED25519_PK_CACHE.write().await;
363 cache_write.insert(key_id.to_string(), der_pk_blob.clone());
364
365 Ok(der_pk_blob)
366 }
367
368 async fn sign_ed25519<'a, 'b>(
369 &'a self,
370 key_id: &'b str,
371 message: &'b [u8],
372 ) -> AwsKmsResult<Vec<u8>> {
373 debug!("Signing Ed25519 message with AWS KMS, key_id: {}", key_id);
374
375 let sign_result = self
378 .inner
379 .sign()
380 .key_id(key_id)
381 .signing_algorithm(SigningAlgorithmSpec::Ed25519Sha512)
382 .message_type(MessageType::Raw)
383 .message(Blob::new(message))
384 .send()
385 .await;
386
387 let signature = sign_result
389 .map_err(|e| AwsKmsError::SignError(e.to_string()))?
390 .signature
391 .ok_or(AwsKmsError::SignError(
392 "Signature not found in response".to_string(),
393 ))?
394 .into_inner();
395
396 if signature.len() != 64 {
398 return Err(AwsKmsError::SignError(format!(
399 "Invalid Ed25519 signature length: expected 64 bytes, got {}",
400 signature.len()
401 )));
402 }
403
404 Ok(signature)
405 }
406}
407
/// Signing service bound to a single KMS key (`kms_key_id`).
///
/// Generic over the low-level client `T` so tests can inject a mock; it
/// defaults to the real `AwsKmsClient`. The same key id is used for both the
/// secp256k1 (EVM) and Ed25519 (Solana/Stellar) operations. Presumably only
/// the operations matching the KMS key's spec will succeed — confirm.
#[derive(Debug, Clone)]
pub struct AwsKmsService<T: AwsKmsK256 + AwsKmsEd25519 + Clone = AwsKmsClient> {
    pub kms_key_id: String,
    client: T,
}
413
414impl AwsKmsService<AwsKmsClient> {
415 pub async fn new(config: AwsKmsSignerConfig) -> AwsKmsResult<Self> {
416 let shared_client = get_or_create_kms_client(&config).await?;
417
418 Ok(Self {
419 kms_key_id: config.key_id,
420 client: AwsKmsClient {
421 inner: shared_client,
422 },
423 })
424 }
425}
426
#[cfg(test)]
impl<T: AwsKmsK256 + AwsKmsEd25519 + Clone> AwsKmsService<T> {
    /// Test-only constructor that injects `client` (typically a mock) instead
    /// of building a real AWS client. Only `config.key_id` is used.
    pub fn new_for_testing(client: T, config: AwsKmsSignerConfig) -> Self {
        let kms_key_id = config.key_id;
        Self { kms_key_id, client }
    }
}
436
437impl<T: AwsKmsK256 + AwsKmsEd25519 + Clone> AwsKmsService<T> {
438 async fn sign_and_recover_evm(
447 &self,
448 digest: [u8; 32],
449 original_bytes: &[u8],
450 use_prehash_recovery: bool,
451 ) -> AwsKmsResult<Vec<u8>> {
452 let der_signature = self.client.sign_digest(&self.kms_key_id, digest).await?;
454
455 let der_pk = self.client.get_der_public_key(&self.kms_key_id).await?;
457
458 recover_evm_signature_from_der(
460 &der_signature,
461 &der_pk,
462 digest,
463 original_bytes,
464 use_prehash_recovery,
465 )
466 .map_err(|e| AwsKmsError::ParseError(e.to_string()))
467 }
468
469 pub async fn sign_payload_evm(&self, bytes: &[u8]) -> AwsKmsResult<Vec<u8>> {
479 let digest = keccak256(bytes).0;
480 self.sign_and_recover_evm(digest, bytes, false).await
481 }
482
483 pub async fn sign_hash_evm(&self, hash: &[u8; 32]) -> AwsKmsResult<Vec<u8>> {
493 self.sign_and_recover_evm(*hash, hash, true).await
494 }
495}
496
497#[async_trait]
498impl<T: AwsKmsK256 + AwsKmsEd25519 + Clone> AwsKmsEvmService for AwsKmsService<T> {
499 async fn get_evm_address(&self) -> AwsKmsResult<Address> {
500 let der = self.client.get_der_public_key(&self.kms_key_id).await?;
501 let eth_address = derive_ethereum_address_from_der(&der)
502 .map_err(|e| AwsKmsError::ParseError(e.to_string()))?;
503 Ok(Address::Evm(eth_address))
504 }
505
506 async fn sign_payload_evm(&self, message: &[u8]) -> AwsKmsResult<Vec<u8>> {
507 let digest = keccak256(message).0;
508 self.sign_and_recover_evm(digest, message, false).await
509 }
510
511 async fn sign_hash_evm(&self, hash: &[u8; 32]) -> AwsKmsResult<Vec<u8>> {
512 self.sign_and_recover_evm(*hash, hash, true).await
514 }
515}
516
517#[async_trait]
518impl<T: AwsKmsK256 + AwsKmsEd25519 + Clone> AwsKmsSolanaService for AwsKmsService<T> {
519 async fn get_solana_address(&self) -> AwsKmsResult<Address> {
520 let der = self.client.get_ed25519_public_key(&self.kms_key_id).await?;
521 let solana_address = derive_solana_address_from_der(&der)
522 .map_err(|e| AwsKmsError::ParseError(e.to_string()))?;
523 Ok(Address::Solana(solana_address))
524 }
525
526 async fn sign_solana(&self, message: &[u8]) -> AwsKmsResult<Vec<u8>> {
527 self.client.sign_ed25519(&self.kms_key_id, message).await
528 }
529}
530
531#[async_trait]
532impl<T: AwsKmsK256 + AwsKmsEd25519 + Clone> AwsKmsStellarService for AwsKmsService<T> {
533 async fn get_stellar_address(&self) -> AwsKmsResult<Address> {
534 let der = self.client.get_ed25519_public_key(&self.kms_key_id).await?;
535 let stellar_address = derive_stellar_address_from_der(&der)
536 .map_err(|e| AwsKmsError::ParseError(e.to_string()))?;
537 Ok(Address::Stellar(stellar_address))
538 }
539
540 async fn sign_stellar(&self, message: &[u8]) -> AwsKmsResult<Vec<u8>> {
541 self.client.sign_ed25519(&self.kms_key_id, message).await
542 }
543}
544
#[cfg(test)]
pub mod tests {
    use super::*;

    use alloy::primitives::utils::eip191_message;
    use k256::{
        ecdsa::SigningKey,
        elliptic_curve::rand_core::OsRng,
        pkcs8::{der::Encode, EncodePublicKey},
    };
    use mockall::predicate::{eq, ne};

    /// Fixed Ed25519 public key returned by the mocked `get_ed25519_public_key`,
    /// in both SubjectPublicKeyInfo DER form and raw 32-byte form.
    pub struct TestEd25519Keys {
        pub public_key_der: Vec<u8>,
        pub public_key_raw: [u8; 32],
    }

    impl Default for TestEd25519Keys {
        fn default() -> Self {
            Self::new()
        }
    }

    impl TestEd25519Keys {
        pub fn new() -> Self {
            let public_key_raw: [u8; 32] = [
                0x9d, 0x45, 0x7e, 0x45, 0xe4, 0x16, 0xc4, 0xc6, 0x77, 0x67, 0x6a, 0x42, 0xff, 0x96,
                0x8e, 0x3c, 0xf8, 0xdc, 0x73, 0xc8, 0xf3, 0x3a, 0x8d, 0x19, 0x81, 0x29, 0x7b, 0xfa,
                0x3e, 0x00, 0x30, 0xba,
            ];

            // SubjectPublicKeyInfo header for Ed25519 (OID 1.3.101.112, encoded
            // 06 03 2b 65 70), then a 33-byte BIT STRING holding the raw key.
            let mut public_key_der = vec![
                0x30, 0x2a, 0x30, 0x05, 0x06, 0x03, 0x2b, 0x65, 0x70, 0x03, 0x21, 0x00,
            ];
            public_key_der.extend_from_slice(&public_key_raw);

            Self {
                public_key_der,
                public_key_raw,
            }
        }
    }

    /// Builds a mock client where key id `test-key-id` is valid and every other
    /// id fails. The valid id is backed by a freshly generated secp256k1 key and
    /// by `TestEd25519Keys`.
    ///
    /// Returns the mock together with the secp256k1 signing key so tests can
    /// compute expected values.
    pub fn setup_mock_kms_client() -> (MockAwsKmsClient, SigningKey) {
        let mut client = MockAwsKmsClient::new();
        let signing_key = SigningKey::random(&mut OsRng);
        let s = signing_key
            .verifying_key()
            .to_public_key_der()
            .unwrap()
            .to_der()
            .unwrap();

        client
            .expect_get_der_public_key()
            .with(eq("test-key-id"))
            .return_const(Ok(s));
        client
            .expect_get_der_public_key()
            .with(ne("test-key-id"))
            .return_const(Err(AwsKmsError::GetError("Key does not exist".to_string())));

        client
            .expect_sign_digest()
            .withf(|key_id, _| key_id.ne("test-key-id"))
            .return_const(Err(AwsKmsError::SignError(
                "Key does not exist".to_string(),
            )));

        let key = signing_key.clone();
        client
            .expect_sign_digest()
            .withf(|key_id, _| key_id.eq("test-key-id"))
            .returning(move |_, digest| {
                let (signature, _) = signing_key
                    .sign_prehash_recoverable(&digest)
                    .map_err(|e| AwsKmsError::SignError(e.to_string()))?;
                let der_signature = signature.to_der().as_bytes().to_vec();
                Ok(der_signature)
            });

        let test_ed25519_keys = TestEd25519Keys::new();
        client
            .expect_get_ed25519_public_key()
            .with(eq("test-key-id"))
            .return_const(Ok(test_ed25519_keys.public_key_der.clone()));
        client
            .expect_get_ed25519_public_key()
            .with(ne("test-key-id"))
            .return_const(Err(AwsKmsError::GetError("Key does not exist".to_string())));

        // The mock returns an all-zero 64-byte signature; only its length is meaningful.
        client
            .expect_sign_ed25519()
            .withf(|key_id, _| key_id.eq("test-key-id"))
            .returning(|_, _| Ok(vec![0u8; 64]));
        client
            .expect_sign_ed25519()
            .withf(|key_id, _| key_id.ne("test-key-id"))
            .return_const(Err(AwsKmsError::SignError(
                "Key does not exist".to_string(),
            )));

        client.expect_clone().return_once(MockAwsKmsClient::new);

        (client, key)
    }

    /// Wraps `mock_client` in an `AwsKmsService` configured for `key_id`.
    fn service_with_key(
        mock_client: MockAwsKmsClient,
        key_id: &str,
    ) -> AwsKmsService<MockAwsKmsClient> {
        AwsKmsService::new_for_testing(
            mock_client,
            AwsKmsSignerConfig {
                region: Some("us-east-1".to_string()),
                key_id: key_id.to_string(),
            },
        )
    }

    #[tokio::test]
    async fn test_get_public_key() {
        let (mock_client, key) = setup_mock_kms_client();
        let kms = service_with_key(mock_client, "test-key-id");

        let expected_address = derive_ethereum_address_from_der(
            key.verifying_key().to_public_key_der().unwrap().as_bytes(),
        )
        .unwrap();

        // Every non-EVM outcome must fail the test; the previous `if let`
        // silently passed when a different `Address` variant came back.
        match kms.get_evm_address().await {
            Ok(Address::Evm(evm_address)) => assert_eq!(expected_address, evm_address),
            Ok(_) => panic!("Expected EVM address"),
            Err(e) => panic!("Expected EVM address, got error: {e:?}"),
        }
    }

    #[tokio::test]
    async fn test_get_public_key_fail() {
        let (mock_client, _) = setup_mock_kms_client();
        let kms = service_with_key(mock_client, "invalid-key-id");

        let result = kms.get_evm_address().await;
        assert!(matches!(result, Err(AwsKmsError::GetError(_))));
    }

    #[tokio::test]
    async fn test_sign_digest() {
        let (mock_client, _) = setup_mock_kms_client();
        let kms = service_with_key(mock_client, "test-key-id");

        let message_eip = eip191_message(b"Hello World!");
        let result = kms.sign_payload_evm(&message_eip).await;

        assert!(result.is_ok(), "signing failed: {:?}", result.err());
    }

    #[tokio::test]
    async fn test_sign_digest_fail() {
        let (mock_client, _) = setup_mock_kms_client();
        let kms = service_with_key(mock_client, "invalid-key-id");

        let message_eip = eip191_message(b"Hello World!");
        let result = kms.sign_payload_evm(&message_eip).await;
        assert!(matches!(result, Err(AwsKmsError::SignError(_))));
    }

    #[tokio::test]
    async fn test_get_solana_address() {
        let (mock_client, _) = setup_mock_kms_client();
        let kms = service_with_key(mock_client, "test-key-id");

        match kms.get_solana_address().await {
            Ok(Address::Solana(solana_address)) => {
                assert!(!solana_address.is_empty());
                // Base58 encoding of 32 bytes is 32..=44 characters long.
                assert!(solana_address.len() >= 32 && solana_address.len() <= 44);
                let test_keys = TestEd25519Keys::new();
                let expected_address = bs58::encode(test_keys.public_key_raw).into_string();
                assert_eq!(solana_address, expected_address);
            }
            Ok(_) => panic!("Expected Solana address"),
            Err(e) => panic!("Expected Solana address, got error: {e:?}"),
        }
    }

    #[tokio::test]
    async fn test_get_solana_address_fail() {
        let (mock_client, _) = setup_mock_kms_client();
        let kms = service_with_key(mock_client, "invalid-key-id");

        let result = kms.get_solana_address().await;
        assert!(matches!(result, Err(AwsKmsError::GetError(_))));
    }

    #[tokio::test]
    async fn test_sign_solana() {
        let (mock_client, _) = setup_mock_kms_client();
        let kms = service_with_key(mock_client, "test-key-id");

        let message = b"Test Solana message";
        let signature = kms.sign_solana(message).await.expect("signing should succeed");
        assert_eq!(signature.len(), 64);
    }

    #[tokio::test]
    async fn test_sign_solana_fail() {
        let (mock_client, _) = setup_mock_kms_client();
        let kms = service_with_key(mock_client, "invalid-key-id");

        let message = b"Test Solana message";
        let result = kms.sign_solana(message).await;
        assert!(matches!(result, Err(AwsKmsError::SignError(_))));
    }

    #[tokio::test]
    async fn test_get_stellar_address() {
        let (mock_client, _) = setup_mock_kms_client();
        let kms = service_with_key(mock_client, "test-key-id");

        match kms.get_stellar_address().await {
            Ok(Address::Stellar(stellar_address)) => {
                // Stellar account ids start with 'G' and are 56 characters long.
                assert!(stellar_address.starts_with('G'));
                assert_eq!(stellar_address.len(), 56);
            }
            Ok(_) => panic!("Expected Stellar address"),
            Err(e) => panic!("Expected Stellar address, got error: {e:?}"),
        }
    }

    #[tokio::test]
    async fn test_get_stellar_address_fail() {
        let (mock_client, _) = setup_mock_kms_client();
        let kms = service_with_key(mock_client, "invalid-key-id");

        let result = kms.get_stellar_address().await;
        assert!(matches!(result, Err(AwsKmsError::GetError(_))));
    }

    #[tokio::test]
    async fn test_sign_stellar() {
        let (mock_client, _) = setup_mock_kms_client();
        let kms = service_with_key(mock_client, "test-key-id");

        let message = b"Test Stellar message";
        let signature = kms.sign_stellar(message).await.expect("signing should succeed");
        assert_eq!(signature.len(), 64);
    }

    #[tokio::test]
    async fn test_sign_stellar_fail() {
        let (mock_client, _) = setup_mock_kms_client();
        let kms = service_with_key(mock_client, "invalid-key-id");

        let message = b"Test Stellar message";
        let result = kms.sign_stellar(message).await;
        assert!(matches!(result, Err(AwsKmsError::SignError(_))));
    }

    /// Two configs in the same region must receive the same cached client. On
    /// hosts without TLS root certificates, client creation fails with a
    /// TLS-related `ConfigError` instead; that case is accepted.
    #[tokio::test]
    async fn test_kms_client_cache_same_region_shares_client() {
        let config1 = AwsKmsSignerConfig {
            region: Some("us-west-2".to_string()),
            key_id: "key-aaa".to_string(),
        };
        let config2 = AwsKmsSignerConfig {
            region: Some("us-west-2".to_string()),
            key_id: "key-bbb".to_string(),
        };

        let result1 = get_or_create_kms_client(&config1).await;
        let result2 = get_or_create_kms_client(&config2).await;

        match (result1, result2) {
            (Ok(client1), Ok(client2)) => {
                assert!(Arc::ptr_eq(&client1, &client2));
            }
            (Err(AwsKmsError::ConfigError(msg)), _) | (_, Err(AwsKmsError::ConfigError(msg))) => {
                assert!(
                    msg.contains("TLS root certificates"),
                    "Expected TLS-related config error, got: {msg}"
                );
            }
            (Err(e), _) | (_, Err(e)) => {
                panic!("Unexpected error: {e:?}");
            }
        }
    }

    /// Environment-dependent: this test assumes no region is discoverable
    /// (no AWS_REGION / AWS_DEFAULT_REGION, profile, or instance metadata).
    #[tokio::test]
    async fn test_kms_client_returns_config_error_when_region_missing() {
        let config = AwsKmsSignerConfig {
            region: None,
            key_id: "test-key".to_string(),
        };

        let result = get_or_create_kms_client(&config).await;
        match result {
            Err(AwsKmsError::ConfigError(_)) => {}
            Ok(_) => panic!(
                "Expected missing-region error; AWS_REGION/AWS_DEFAULT_REGION may be set in env"
            ),
            Err(e) => panic!("Expected ConfigError, got: {e:?}"),
        }
    }
}