openzeppelin_relayer/services/aws_kms/
mod.rs

1//! # AWS KMS Service Module
2//!
3//! This module provides integration with AWS KMS for secure key management
4//! and cryptographic operations such as public key retrieval and message signing.
5//!
6//! Supports EVM (secp256k1/ECDSA), Solana (Ed25519), and Stellar (Ed25519) networks.
7//!
8//! ## Features
9//!
//! - Authentication through the AWS SDK default credential provider chain
11//! - Public key retrieval from KMS
12//! - Message signing via KMS for multiple key types
13//!
14//! ## Architecture
15//!
16//! ```text
17//! AwsKmsService (implements AwsKmsEvmService, AwsKmsSolanaService, AwsKmsStellarService)
18//!   ├── Authentication (via AwsKmsClient)
19//!   ├── Public Key Retrieval (via AwsKmsClient)
20//!   └── Message Signing (via AwsKmsClient)
21//! ```
//! which is built on top of
23//! ```text
24//! AwsKmsClient (implements AwsKmsK256, AwsKmsEd25519)
25//!   ├── Authentication (via shared credentials)
26//!   ├── Public Key Retrieval in DER Encoding
27//!   └── Message Signing (ECDSA for secp256k1, Ed25519 for EdDSA)
28//! ```
29//! `AwsKmsK256` and `AwsKmsEd25519` are mocked with `mockall` for unit testing
30//! and injected into `AwsKmsService`
31//!
32
33use alloy::primitives::keccak256;
34use async_trait::async_trait;
35use aws_config::{meta::region::RegionProviderChain, BehaviorVersion, Region};
36use aws_sdk_kms::{
37    primitives::Blob,
38    types::{MessageType, SigningAlgorithmSpec},
39    Client,
40};
41use once_cell::sync::Lazy;
42use serde::Serialize;
43use std::{collections::HashMap, sync::Arc};
44use tokio::sync::RwLock;
45
46use crate::{
47    models::{Address, AwsKmsSignerConfig},
48    services::{
49        client_cache::AsyncClientCache, signer::evm::utils::recover_evm_signature_from_der,
50    },
51    utils::{
52        self, derive_ethereum_address_from_der, derive_solana_address_from_der,
53        derive_stellar_address_from_der,
54    },
55};
56use tracing::debug;
57
58#[cfg(test)]
59use mockall::{automock, mock};
60
/// Errors produced by the AWS KMS service layer.
///
/// NOTE(review): the enum derives `Serialize`; the variant order therefore also defines the
/// serde variant index, so avoid reordering variants.
#[derive(Clone, Debug, thiserror::Error, Serialize)]
pub enum AwsKmsError {
    /// Key material or a signature could not be interpreted, e.g. address derivation from a
    /// DER public key or EVM signature recovery failed.
    #[error("AWS KMS response parse error: {0}")]
    ParseError(String),
    /// Configuration failure: the region could not be resolved or the SDK client could not
    /// be constructed.
    #[error("AWS KMS config error: {0}")]
    ConfigError(String),
    /// Fetching a public key from KMS failed or the response carried no key blob.
    #[error("AWS KMS get error: {0}")]
    GetError(String),
    /// A signing request failed, returned no signature, or returned a signature of the
    /// wrong length.
    #[error("AWS KMS signing error: {0}")]
    SignError(String),
    /// Returned by the secp256k1 `sign_digest` path for any SDK-level failure of the
    /// `Sign` request — not only authorization failures.
    #[error("AWS KMS permissions error: {0}")]
    PermissionError(String),
    /// Wraps a `utils::Secp256k1Error`; produced automatically via `From` / `?`.
    #[error("AWS KMS public key error: {0}")]
    RecoveryError(#[from] utils::Secp256k1Error),
    /// Conversion failure. NOTE(review): not constructed in this module; check callers.
    #[error("AWS KMS conversion error: {0}")]
    ConvertError(String),
    /// Catch-all. NOTE(review): not constructed in this module; check callers.
    #[error("AWS KMS Other error: {0}")]
    Other(String),
}

/// Convenience alias for results whose error type is [`AwsKmsError`].
pub type AwsKmsResult<T> = Result<T, AwsKmsError>;
82
/// High-level EVM operations backed by an AWS KMS secp256k1 key.
///
/// Implemented by [`AwsKmsService`]; `automock` generates `MockAwsKmsEvmService` for tests.
#[async_trait]
#[cfg_attr(test, automock)]
pub trait AwsKmsEvmService: Send + Sync {
    /// Returns the EVM address derived from the configured public key.
    async fn get_evm_address(&self) -> AwsKmsResult<Address>;
    /// Signs a payload using the EVM signing scheme (hashes before signing).
    ///
    /// This method applies keccak256 hashing before signing.
    ///
    /// **Use for:**
    /// - Raw transaction data (TxLegacy, TxEip1559)
    /// - EIP-191 personal messages
    ///
    /// **Note:** For EIP-712 typed data, use `sign_hash_evm()` to avoid double-hashing.
    async fn sign_payload_evm(&self, payload: &[u8]) -> AwsKmsResult<Vec<u8>>;

    /// Signs a pre-computed hash using the EVM signing scheme (no hashing).
    ///
    /// This method signs the hash directly without applying keccak256.
    ///
    /// **Use for:**
    /// - EIP-712 typed data (already hashed)
    /// - Pre-computed message digests
    ///
    /// **Note:** For raw data, use `sign_payload_evm()` instead.
    async fn sign_hash_evm(&self, hash: &[u8; 32]) -> AwsKmsResult<Vec<u8>>;
}
110
/// Low-level secp256k1 (ECDSA) operations against AWS KMS.
///
/// Implemented by [`AwsKmsClient`]; mocked with `mockall` so it can be injected into
/// [`AwsKmsService`] in tests.
#[async_trait]
#[cfg_attr(test, automock)]
pub trait AwsKmsK256: Send + Sync {
    /// Fetches the DER-encoded public key from AWS KMS.
    ///
    /// The production implementation keeps fetched keys in a process-wide cache.
    async fn get_der_public_key<'a, 'b>(&'a self, key_id: &'b str) -> AwsKmsResult<Vec<u8>>;
    /// Signs a 32-byte pre-computed `digest` using the EcdsaSha256 spec with
    /// `MessageType::Digest` (KMS does not hash it again). Returns DER-encoded signature
    async fn sign_digest<'a, 'b>(
        &'a self,
        key_id: &'b str,
        digest: [u8; 32],
    ) -> AwsKmsResult<Vec<u8>>;
}
123
/// Trait for Ed25519 (EdDSA) operations with AWS KMS.
/// Used for Solana and Stellar signing.
///
/// Implemented by [`AwsKmsClient`]; mocked with `mockall` for unit tests.
#[async_trait]
#[cfg_attr(test, automock)]
pub trait AwsKmsEd25519: Send + Sync {
    /// Fetches the DER-encoded Ed25519 public key from AWS KMS.
    ///
    /// The production implementation keeps fetched keys in a process-wide cache.
    async fn get_ed25519_public_key<'a, 'b>(&'a self, key_id: &'b str) -> AwsKmsResult<Vec<u8>>;
    /// Signs a message using Ed25519. Returns 64-byte signature.
    /// Uses ED25519_SHA_512 algorithm with RAW message type, i.e. the message itself (not a
    /// digest) is sent to KMS.
    async fn sign_ed25519<'a, 'b>(
        &'a self,
        key_id: &'b str,
        message: &'b [u8],
    ) -> AwsKmsResult<Vec<u8>>;
}
139
/// Trait for Solana-specific AWS KMS operations
///
/// Implemented by [`AwsKmsService`] on top of the Ed25519 client operations.
#[async_trait]
#[cfg_attr(test, automock)]
pub trait AwsKmsSolanaService: Send + Sync {
    /// Returns the Solana address derived from the configured Ed25519 public key.
    async fn get_solana_address(&self) -> AwsKmsResult<Address>;
    /// Signs a message using Ed25519 for Solana.
    ///
    /// The message is passed to KMS unchanged; the result is a raw Ed25519 signature.
    async fn sign_solana(&self, message: &[u8]) -> AwsKmsResult<Vec<u8>>;
}

/// Trait for Stellar-specific AWS KMS operations
///
/// Implemented by [`AwsKmsService`] on top of the Ed25519 client operations.
#[async_trait]
#[cfg_attr(test, automock)]
pub trait AwsKmsStellarService: Send + Sync {
    /// Returns the Stellar address derived from the configured Ed25519 public key.
    async fn get_stellar_address(&self) -> AwsKmsResult<Address>;
    /// Signs a message using Ed25519 for Stellar.
    ///
    /// The message is passed to KMS unchanged; the result is a raw Ed25519 signature.
    async fn sign_stellar(&self, message: &[u8]) -> AwsKmsResult<Vec<u8>>;
}
159
// Test-only mock (`MockAwsKmsClient`) implementing both low-level KMS traits plus `Clone`
// on a single type, mirroring `AwsKmsClient`, so it can be injected into `AwsKmsService`
// through `new_for_testing`.
#[cfg(test)]
mock! {
    pub AwsKmsClient { }
    // `AwsKmsService<T>` requires `T: Clone`.
    impl Clone for AwsKmsClient {
        fn clone(&self) -> Self;
    }

    // secp256k1 / EVM operations.
    #[async_trait]
    impl AwsKmsK256 for AwsKmsClient {
        async fn get_der_public_key<'a, 'b>(&'a self, key_id: &'b str) -> AwsKmsResult<Vec<u8>>;
        async fn sign_digest<'a, 'b>(
            &'a self,
            key_id: &'b str,
            digest: [u8; 32],
        ) -> AwsKmsResult<Vec<u8>>;
    }

    // Ed25519 / Solana and Stellar operations.
    #[async_trait]
    impl AwsKmsEd25519 for AwsKmsClient {
        async fn get_ed25519_public_key<'a, 'b>(&'a self, key_id: &'b str) -> AwsKmsResult<Vec<u8>>;
        async fn sign_ed25519<'a, 'b>(
            &'a self,
            key_id: &'b str,
            message: &'b [u8],
        ) -> AwsKmsResult<Vec<u8>>;
    }
}
187
// Process-wide cache of DER-encoded secp256k1 public keys, filled by
// `AwsKmsClient::get_der_public_key`. Keys are strings identifying the KMS key (see that
// method). No eviction is performed in this module.
static KMS_DER_PK_CACHE: Lazy<RwLock<HashMap<String, Vec<u8>>>> =
    Lazy::new(|| RwLock::new(HashMap::new()));

// Process-wide cache of DER-encoded Ed25519 public keys, filled by
// `AwsKmsClient::get_ed25519_public_key`. No eviction is performed in this module.
static KMS_ED25519_PK_CACHE: Lazy<RwLock<HashMap<String, Vec<u8>>>> =
    Lazy::new(|| RwLock::new(HashMap::new()));

/// Cache key for shared SDK clients: exactly one client per resolved AWS region.
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
struct AwsKmsClientKey {
    region: String,
}

// Shared AWS SDK KMS clients, one per region; populated by `get_or_create_kms_client`.
static KMS_CLIENT_CACHE: Lazy<AsyncClientCache<AwsKmsClientKey, Client>> =
    Lazy::new(AsyncClientCache::new);
203
204/// Get or create a shared AWS KMS SDK client for the given signer config.
205/// Keyed by resolved region — one client serves all KMS keys in that region.
206async fn get_or_create_kms_client(config: &AwsKmsSignerConfig) -> AwsKmsResult<Arc<Client>> {
207    let resolved_region = resolve_aws_region(config).await?;
208    let key = AwsKmsClientKey {
209        region: resolved_region.clone(),
210    };
211
212    KMS_CLIENT_CACHE
213        .get_or_try_init(key, || async {
214            debug!(
215                region = %resolved_region,
216                "Creating new AWS KMS client"
217            );
218            let auth_config = aws_config::defaults(BehaviorVersion::latest())
219                .region(Region::new(resolved_region))
220                .load()
221                .await;
222
223            // Client::new() can panic in environments without TLS root certificates
224            // (e.g., stripped containers). Catch the panic and return a typed error.
225            std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| Client::new(&auth_config)))
226                .map_err(|panic| {
227                    let msg = panic
228                        .downcast_ref::<String>()
229                        .map(|s| s.as_str())
230                        .or_else(|| panic.downcast_ref::<&str>().copied())
231                        .unwrap_or("unknown panic");
232                    AwsKmsError::ConfigError(format!(
233                        "Failed to initialize AWS KMS client (check TLS root certificates): {msg}"
234                    ))
235                })
236        })
237        .await
238}
239
240/// Resolve the AWS region from config or the default provider chain.
241async fn resolve_aws_region(config: &AwsKmsSignerConfig) -> AwsKmsResult<String> {
242    if let Some(region) = &config.region {
243        return Ok(region.clone());
244    }
245
246    let provider = RegionProviderChain::default_provider();
247    provider
248        .region()
249        .await
250        .map(|r| r.to_string())
251        .ok_or_else(|| {
252            AwsKmsError::ConfigError(
253                "AWS region not specified and could not be resolved from environment".to_string(),
254            )
255        })
256}
257
258#[derive(Debug, Clone)]
259pub struct AwsKmsClient {
260    inner: Arc<Client>,
261}
262
263#[async_trait]
264impl AwsKmsK256 for AwsKmsClient {
265    async fn get_der_public_key<'a, 'b>(&'a self, key_id: &'b str) -> AwsKmsResult<Vec<u8>> {
266        // Try cache first with minimal lock time
267        let cached = {
268            let cache_read = KMS_DER_PK_CACHE.read().await;
269            cache_read.get(key_id).cloned()
270        };
271        if let Some(cached) = cached {
272            return Ok(cached);
273        }
274
275        // Fetch from AWS KMS
276        let get_output = self
277            .inner
278            .get_public_key()
279            .key_id(key_id)
280            .send()
281            .await
282            .map_err(|e| {
283                AwsKmsError::GetError(format!(
284                    "Failed to get secp256k1 public key for key '{key_id}': {e:?}"
285                ))
286            })?;
287
288        let der_pk_blob = get_output
289            .public_key
290            .ok_or(AwsKmsError::GetError(
291                "No public key blob found".to_string(),
292            ))?
293            .into_inner();
294
295        let mut cache_write = KMS_DER_PK_CACHE.write().await;
296        cache_write.insert(key_id.to_string(), der_pk_blob.clone());
297
298        Ok(der_pk_blob)
299    }
300
301    async fn sign_digest<'a, 'b>(
302        &'a self,
303        key_id: &'b str,
304        digest: [u8; 32],
305    ) -> AwsKmsResult<Vec<u8>> {
306        // Sign the digest with the AWS KMS
307        let sign_result = self
308            .inner
309            .sign()
310            .key_id(key_id)
311            .signing_algorithm(SigningAlgorithmSpec::EcdsaSha256)
312            .message_type(MessageType::Digest)
313            .message(Blob::new(digest))
314            .send()
315            .await;
316
317        // Process the result, extract DER signature
318        let der_signature = sign_result
319            .map_err(|e| AwsKmsError::PermissionError(e.to_string()))?
320            .signature
321            .ok_or(AwsKmsError::SignError(
322                "Signature not found in response".to_string(),
323            ))?
324            .into_inner();
325
326        Ok(der_signature)
327    }
328}
329
330#[async_trait]
331impl AwsKmsEd25519 for AwsKmsClient {
332    async fn get_ed25519_public_key<'a, 'b>(&'a self, key_id: &'b str) -> AwsKmsResult<Vec<u8>> {
333        // Try cache first with minimal lock time
334        let cached = {
335            let cache_read = KMS_ED25519_PK_CACHE.read().await;
336            cache_read.get(key_id).cloned()
337        };
338        if let Some(cached) = cached {
339            return Ok(cached);
340        }
341
342        // Fetch from AWS KMS
343        let get_output = self
344            .inner
345            .get_public_key()
346            .key_id(key_id)
347            .send()
348            .await
349            .map_err(|e| {
350                AwsKmsError::GetError(format!(
351                    "Failed to get Ed25519 public key for key '{key_id}': {e:?}"
352                ))
353            })?;
354
355        let der_pk_blob = get_output
356            .public_key
357            .ok_or(AwsKmsError::GetError(
358                "No public key blob found".to_string(),
359            ))?
360            .into_inner();
361
362        let mut cache_write = KMS_ED25519_PK_CACHE.write().await;
363        cache_write.insert(key_id.to_string(), der_pk_blob.clone());
364
365        Ok(der_pk_blob)
366    }
367
368    async fn sign_ed25519<'a, 'b>(
369        &'a self,
370        key_id: &'b str,
371        message: &'b [u8],
372    ) -> AwsKmsResult<Vec<u8>> {
373        debug!("Signing Ed25519 message with AWS KMS, key_id: {}", key_id);
374
375        // Sign the message with Ed25519 using ED25519_SHA_512 algorithm
376        // Note: ED25519_SHA_512 requires MessageType::Raw - we pass the raw message
377        let sign_result = self
378            .inner
379            .sign()
380            .key_id(key_id)
381            .signing_algorithm(SigningAlgorithmSpec::Ed25519Sha512)
382            .message_type(MessageType::Raw)
383            .message(Blob::new(message))
384            .send()
385            .await;
386
387        // Process the result, extract signature
388        let signature = sign_result
389            .map_err(|e| AwsKmsError::SignError(e.to_string()))?
390            .signature
391            .ok_or(AwsKmsError::SignError(
392                "Signature not found in response".to_string(),
393            ))?
394            .into_inner();
395
396        // Ed25519 signatures should be 64 bytes
397        if signature.len() != 64 {
398            return Err(AwsKmsError::SignError(format!(
399                "Invalid Ed25519 signature length: expected 64 bytes, got {}",
400                signature.len()
401            )));
402        }
403
404        Ok(signature)
405    }
406}
407
/// AWS KMS-backed signer service for EVM (secp256k1), Solana and Stellar (Ed25519).
///
/// Generic over the low-level client `T` so tests can inject a mock; production code uses
/// the default [`AwsKmsClient`].
#[derive(Debug, Clone)]
pub struct AwsKmsService<T: AwsKmsK256 + AwsKmsEd25519 + Clone = AwsKmsClient> {
    /// KMS key identifier passed as `key_id` to every client call.
    pub kms_key_id: String,
    /// Low-level KMS client performing the actual requests.
    client: T,
}
413
414impl AwsKmsService<AwsKmsClient> {
415    pub async fn new(config: AwsKmsSignerConfig) -> AwsKmsResult<Self> {
416        let shared_client = get_or_create_kms_client(&config).await?;
417
418        Ok(Self {
419            kms_key_id: config.key_id,
420            client: AwsKmsClient {
421                inner: shared_client,
422            },
423        })
424    }
425}
426
#[cfg(test)]
impl<T: AwsKmsK256 + AwsKmsEd25519 + Clone> AwsKmsService<T> {
    /// Builds a service around an injected (typically mocked) client; only `config.key_id`
    /// is used.
    pub fn new_for_testing(client: T, config: AwsKmsSignerConfig) -> Self {
        let AwsKmsSignerConfig { key_id, .. } = config;
        Self {
            kms_key_id: key_id,
            client,
        }
    }
}
436
437impl<T: AwsKmsK256 + AwsKmsEd25519 + Clone> AwsKmsService<T> {
438    /// Common signing logic for EVM signatures.
439    ///
440    /// This internal helper eliminates duplication between `sign_payload_evm` and `sign_hash_evm`.
441    ///
442    /// # Parameters
443    /// * `digest` - The 32-byte hash to sign
444    /// * `original_bytes` - The original message bytes for recovery verification (if applicable)
445    /// * `use_prehash_recovery` - If true, recovers using hash directly; if false, uses original bytes
446    async fn sign_and_recover_evm(
447        &self,
448        digest: [u8; 32],
449        original_bytes: &[u8],
450        use_prehash_recovery: bool,
451    ) -> AwsKmsResult<Vec<u8>> {
452        // Sign the digest with AWS KMS
453        let der_signature = self.client.sign_digest(&self.kms_key_id, digest).await?;
454
455        // Get public key
456        let der_pk = self.client.get_der_public_key(&self.kms_key_id).await?;
457
458        // Use shared signature recovery logic
459        recover_evm_signature_from_der(
460            &der_signature,
461            &der_pk,
462            digest,
463            original_bytes,
464            use_prehash_recovery,
465        )
466        .map_err(|e| AwsKmsError::ParseError(e.to_string()))
467    }
468
469    /// Signs a payload using the EVM signing scheme (hashes before signing).
470    ///
471    /// This method applies keccak256 hashing before signing.
472    ///
473    /// **Use for:**
474    /// - Raw transaction data (TxLegacy, TxEip1559)
475    /// - EIP-191 personal messages
476    ///
477    /// **Note:** For EIP-712 typed data, use `sign_hash_evm()` to avoid double-hashing.
478    pub async fn sign_payload_evm(&self, bytes: &[u8]) -> AwsKmsResult<Vec<u8>> {
479        let digest = keccak256(bytes).0;
480        self.sign_and_recover_evm(digest, bytes, false).await
481    }
482
483    /// Signs a pre-computed hash using the EVM signing scheme (no hashing).
484    ///
485    /// This method signs the hash directly without applying keccak256.
486    ///
487    /// **Use for:**
488    /// - EIP-712 typed data (already hashed)
489    /// - Pre-computed message digests
490    ///
491    /// **Note:** For raw data, use `sign_payload_evm()` instead.
492    pub async fn sign_hash_evm(&self, hash: &[u8; 32]) -> AwsKmsResult<Vec<u8>> {
493        self.sign_and_recover_evm(*hash, hash, true).await
494    }
495}
496
497#[async_trait]
498impl<T: AwsKmsK256 + AwsKmsEd25519 + Clone> AwsKmsEvmService for AwsKmsService<T> {
499    async fn get_evm_address(&self) -> AwsKmsResult<Address> {
500        let der = self.client.get_der_public_key(&self.kms_key_id).await?;
501        let eth_address = derive_ethereum_address_from_der(&der)
502            .map_err(|e| AwsKmsError::ParseError(e.to_string()))?;
503        Ok(Address::Evm(eth_address))
504    }
505
506    async fn sign_payload_evm(&self, message: &[u8]) -> AwsKmsResult<Vec<u8>> {
507        let digest = keccak256(message).0;
508        self.sign_and_recover_evm(digest, message, false).await
509    }
510
511    async fn sign_hash_evm(&self, hash: &[u8; 32]) -> AwsKmsResult<Vec<u8>> {
512        // Delegates to the implementation method on AwsKmsService
513        self.sign_and_recover_evm(*hash, hash, true).await
514    }
515}
516
517#[async_trait]
518impl<T: AwsKmsK256 + AwsKmsEd25519 + Clone> AwsKmsSolanaService for AwsKmsService<T> {
519    async fn get_solana_address(&self) -> AwsKmsResult<Address> {
520        let der = self.client.get_ed25519_public_key(&self.kms_key_id).await?;
521        let solana_address = derive_solana_address_from_der(&der)
522            .map_err(|e| AwsKmsError::ParseError(e.to_string()))?;
523        Ok(Address::Solana(solana_address))
524    }
525
526    async fn sign_solana(&self, message: &[u8]) -> AwsKmsResult<Vec<u8>> {
527        self.client.sign_ed25519(&self.kms_key_id, message).await
528    }
529}
530
531#[async_trait]
532impl<T: AwsKmsK256 + AwsKmsEd25519 + Clone> AwsKmsStellarService for AwsKmsService<T> {
533    async fn get_stellar_address(&self) -> AwsKmsResult<Address> {
534        let der = self.client.get_ed25519_public_key(&self.kms_key_id).await?;
535        let stellar_address = derive_stellar_address_from_der(&der)
536            .map_err(|e| AwsKmsError::ParseError(e.to_string()))?;
537        Ok(Address::Stellar(stellar_address))
538    }
539
540    async fn sign_stellar(&self, message: &[u8]) -> AwsKmsResult<Vec<u8>> {
541        self.client.sign_ed25519(&self.kms_key_id, message).await
542    }
543}
544
#[cfg(test)]
pub mod tests {
    use super::*;

    use alloy::primitives::utils::eip191_message;
    use k256::{
        ecdsa::SigningKey,
        elliptic_curve::rand_core::OsRng,
        pkcs8::{der::Encode, EncodePublicKey},
    };
    use mockall::predicate::{eq, ne};

    /// Test Ed25519 key pair for mocking AWS KMS Ed25519 operations
    pub struct TestEd25519Keys {
        pub public_key_der: Vec<u8>,
        pub public_key_raw: [u8; 32],
    }

    impl Default for TestEd25519Keys {
        fn default() -> Self {
            Self::new()
        }
    }

    impl TestEd25519Keys {
        pub fn new() -> Self {
            // Well-known test Ed25519 public key (32 bytes)
            let public_key_raw: [u8; 32] = [
                0x9d, 0x45, 0x7e, 0x45, 0xe4, 0x16, 0xc4, 0xc6, 0x77, 0x67, 0x6a, 0x42, 0xff, 0x96,
                0x8e, 0x3c, 0xf8, 0xdc, 0x73, 0xc8, 0xf3, 0x3a, 0x8d, 0x19, 0x81, 0x29, 0x7b, 0xfa,
                0x3e, 0x00, 0x30, 0xba,
            ];

            // Ed25519 SPKI format: 12-byte header + 32-byte key
            let mut public_key_der = vec![
                0x30, 0x2a, // SEQUENCE, 42 bytes
                0x30, 0x05, // SEQUENCE, 5 bytes
                0x06, 0x03, 0x2b, 0x65, 0x70, // OID 1.3.101.112 (Ed25519)
                0x03, 0x21, // BIT STRING, 33 bytes
                0x00, // zero unused bits
            ];
            public_key_der.extend_from_slice(&public_key_raw);

            Self {
                public_key_der,
                public_key_raw,
            }
        }
    }

    /// Builds a `MockAwsKmsClient` for which `"test-key-id"` is the only existing key.
    ///
    /// secp256k1 signing is backed by a freshly generated `SigningKey` (also returned so
    /// tests can derive expected values); Ed25519 uses `TestEd25519Keys` and a fixed
    /// all-zero 64-byte signature. Any other key id yields `GetError` / `SignError`.
    pub fn setup_mock_kms_client() -> (MockAwsKmsClient, SigningKey) {
        let mut client = MockAwsKmsClient::new();
        let signing_key = SigningKey::random(&mut OsRng);
        let s = signing_key
            .verifying_key()
            .to_public_key_der()
            .unwrap()
            .to_der()
            .unwrap();

        client
            .expect_get_der_public_key()
            .with(eq("test-key-id"))
            .return_const(Ok(s));
        client
            .expect_get_der_public_key()
            .with(ne("test-key-id"))
            .return_const(Err(AwsKmsError::GetError("Key does not exist".to_string())));

        client
            .expect_sign_digest()
            .withf(|key_id, _| key_id.ne("test-key-id"))
            .return_const(Err(AwsKmsError::SignError(
                "Key does not exist".to_string(),
            )));

        let key = signing_key.clone();
        client
            .expect_sign_digest()
            .withf(|key_id, _| key_id.eq("test-key-id"))
            .returning(move |_, digest| {
                let (signature, _) = signing_key
                    .sign_prehash_recoverable(&digest)
                    .map_err(|e| AwsKmsError::SignError(e.to_string()))?;
                let der_signature = signature.to_der().as_bytes().to_vec();
                Ok(der_signature)
            });

        // Setup Ed25519 mock expectations
        let test_ed25519_keys = TestEd25519Keys::new();
        client
            .expect_get_ed25519_public_key()
            .with(eq("test-key-id"))
            .return_const(Ok(test_ed25519_keys.public_key_der.clone()));
        client
            .expect_get_ed25519_public_key()
            .with(ne("test-key-id"))
            .return_const(Err(AwsKmsError::GetError("Key does not exist".to_string())));

        // Mock Ed25519 signing - return a fixed 64-byte signature
        client
            .expect_sign_ed25519()
            .withf(|key_id, _| key_id.eq("test-key-id"))
            .returning(|_, _| Ok(vec![0u8; 64]));
        client
            .expect_sign_ed25519()
            .withf(|key_id, _| key_id.ne("test-key-id"))
            .return_const(Err(AwsKmsError::SignError(
                "Key does not exist".to_string(),
            )));

        client.expect_clone().return_once(MockAwsKmsClient::new);

        (client, key)
    }

    /// Wraps a fresh mock client in an `AwsKmsService` configured with `key_id`.
    ///
    /// Returns the service together with the secp256k1 key the mock signs with.
    fn mock_service(key_id: &str) -> (AwsKmsService<MockAwsKmsClient>, SigningKey) {
        let (mock_client, signing_key) = setup_mock_kms_client();
        let kms = AwsKmsService::new_for_testing(
            mock_client,
            AwsKmsSignerConfig {
                region: Some("us-east-1".to_string()),
                key_id: key_id.to_string(),
            },
        );
        (kms, signing_key)
    }

    #[tokio::test]
    async fn test_get_public_key() {
        let (kms, key) = mock_service("test-key-id");

        // Any result other than an EVM address must fail the test (not pass silently).
        match kms.get_evm_address().await {
            Ok(Address::Evm(evm_address)) => {
                let expected_address = derive_ethereum_address_from_der(
                    key.verifying_key().to_public_key_der().unwrap().as_bytes(),
                )
                .unwrap();
                assert_eq!(expected_address, evm_address);
            }
            _ => panic!("Expected EVM address"),
        }
    }

    #[tokio::test]
    async fn test_get_public_key_fail() {
        let (kms, _) = mock_service("invalid-key-id");

        let result = kms.get_evm_address().await;
        assert!(matches!(result, Err(AwsKmsError::GetError(_))));
    }

    #[tokio::test]
    async fn test_sign_digest() {
        let (kms, _) = mock_service("test-key-id");

        let message_eip = eip191_message(b"Hello World!");
        let result = kms.sign_payload_evm(&message_eip).await;

        // We just assert for Ok, since the pubkey recovery indicates the validity of signature
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_sign_digest_fail() {
        let (kms, _) = mock_service("invalid-key-id");

        let message_eip = eip191_message(b"Hello World!");
        let result = kms.sign_payload_evm(&message_eip).await;
        assert!(matches!(result, Err(AwsKmsError::SignError(_))));
    }

    #[tokio::test]
    async fn test_sign_hash_evm() {
        let (kms, _) = mock_service("test-key-id");

        // Exercises the prehash-recovery path (e.g. EIP-712 digests).
        let hash = keccak256(b"Hello World!").0;
        let result = kms.sign_hash_evm(&hash).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_sign_hash_evm_fail() {
        let (kms, _) = mock_service("invalid-key-id");

        let hash = keccak256(b"Hello World!").0;
        let result = kms.sign_hash_evm(&hash).await;
        assert!(matches!(result, Err(AwsKmsError::SignError(_))));
    }

    #[tokio::test]
    async fn test_get_solana_address() {
        let (kms, _) = mock_service("test-key-id");

        match kms.get_solana_address().await {
            Ok(Address::Solana(solana_address)) => {
                // Verify it's a valid base58-encoded address
                assert!(!solana_address.is_empty());
                assert!(solana_address.len() >= 32 && solana_address.len() <= 44);
                // Verify it matches the expected address from our test key
                let test_keys = TestEd25519Keys::new();
                let expected_address = bs58::encode(test_keys.public_key_raw).into_string();
                assert_eq!(solana_address, expected_address);
            }
            _ => panic!("Expected Solana address"),
        }
    }

    #[tokio::test]
    async fn test_get_solana_address_fail() {
        let (kms, _) = mock_service("invalid-key-id");

        let result = kms.get_solana_address().await;
        assert!(matches!(result, Err(AwsKmsError::GetError(_))));
    }

    #[tokio::test]
    async fn test_sign_solana() {
        let (kms, _) = mock_service("test-key-id");

        let message = b"Test Solana message";
        let signature = kms.sign_solana(message).await.expect("signing should succeed");
        assert_eq!(signature.len(), 64); // Ed25519 signatures are 64 bytes
    }

    #[tokio::test]
    async fn test_sign_solana_fail() {
        let (kms, _) = mock_service("invalid-key-id");

        let message = b"Test Solana message";
        let result = kms.sign_solana(message).await;
        assert!(matches!(result, Err(AwsKmsError::SignError(_))));
    }

    #[tokio::test]
    async fn test_get_stellar_address() {
        let (kms, _) = mock_service("test-key-id");

        match kms.get_stellar_address().await {
            Ok(Address::Stellar(stellar_address)) => {
                // Stellar addresses start with 'G' for public accounts
                assert!(stellar_address.starts_with('G'));
                // Stellar addresses are 56 characters long
                assert_eq!(stellar_address.len(), 56);
            }
            _ => panic!("Expected Stellar address"),
        }
    }

    #[tokio::test]
    async fn test_get_stellar_address_fail() {
        let (kms, _) = mock_service("invalid-key-id");

        let result = kms.get_stellar_address().await;
        assert!(matches!(result, Err(AwsKmsError::GetError(_))));
    }

    #[tokio::test]
    async fn test_sign_stellar() {
        let (kms, _) = mock_service("test-key-id");

        let message = b"Test Stellar message";
        let signature = kms.sign_stellar(message).await.expect("signing should succeed");
        assert_eq!(signature.len(), 64); // Ed25519 signatures are 64 bytes
    }

    #[tokio::test]
    async fn test_sign_stellar_fail() {
        let (kms, _) = mock_service("invalid-key-id");

        let message = b"Test Stellar message";
        let result = kms.sign_stellar(message).await;
        assert!(matches!(result, Err(AwsKmsError::SignError(_))));
    }

    // Note: Ed25519 DER parsing tests are in utils/ed25519.rs

    #[tokio::test]
    async fn test_kms_client_cache_same_region_shares_client() {
        let config1 = AwsKmsSignerConfig {
            region: Some("us-west-2".to_string()),
            key_id: "key-aaa".to_string(),
        };
        let config2 = AwsKmsSignerConfig {
            region: Some("us-west-2".to_string()),
            key_id: "key-bbb".to_string(),
        };

        let result1 = get_or_create_kms_client(&config1).await;
        let result2 = get_or_create_kms_client(&config2).await;

        match (result1, result2) {
            (Ok(client1), Ok(client2)) => {
                assert!(Arc::ptr_eq(&client1, &client2));
            }
            (Err(AwsKmsError::ConfigError(msg)), _) | (_, Err(AwsKmsError::ConfigError(msg))) => {
                // In environments without TLS roots, the panic is caught as ConfigError
                assert!(
                    msg.contains("TLS root certificates"),
                    "Expected TLS-related config error, got: {msg}"
                );
            }
            (Err(e), _) | (_, Err(e)) => {
                panic!("Unexpected error: {e:?}");
            }
        }
    }

    #[tokio::test]
    async fn test_kms_client_returns_config_error_when_region_missing() {
        let config = AwsKmsSignerConfig {
            region: None,
            key_id: "test-key".to_string(),
        };

        // Covers the missing-region branch in resolve_aws_region().
        // Does not exercise Client::new() panic handling (that requires TLS root absence).
        // NOTE: environment-dependent — fails if a region is discoverable through the default
        // provider chain (AWS_REGION / AWS_DEFAULT_REGION, profile config, ...).
        let result = get_or_create_kms_client(&config).await;
        match result {
            Err(AwsKmsError::ConfigError(_)) => {}
            Ok(_) => panic!(
                "Expected missing-region error; AWS_REGION/AWS_DEFAULT_REGION may be set in env"
            ),
            Err(e) => panic!("Expected ConfigError, got: {e:?}"),
        }
    }
}