1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
use crate::crypto::error::CryptoError;
use dotenv;
use rand::rngs::OsRng;
use rsa::pkcs8::{DecodePublicKey, EncodePublicKey, LineEnding};
use rsa::{RsaPrivateKey, RsaPublicKey};
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::time::{SystemTime, UNIX_EPOCH};

/// An RSA key pair tagged with a key identifier (`kid`) and an expiry timestamp.
///
/// Only the public half takes part in serialization: `public_key` is written as a
/// PEM string and `private_key` is skipped entirely, so a deserialized `KeyPair`
/// always has `private_key == None`.
#[derive(Serialize, Deserialize)]
pub struct KeyPair {
    /// Identifier for the key pair, as supplied to `KeyPair::new`.
    pub kid: String,
    /// The RSA public key. Serialized to a PEM string (CRLF line endings) by
    /// `serialize_rsa_public_key` and parsed back by `deserialize_rsa_public_key`.
    #[serde(
        serialize_with = "serialize_rsa_public_key",
        deserialize_with = "deserialize_rsa_public_key"
    )]
    pub public_key: RsaPublicKey,
    /// The RSA private key. Excluded from serialization and deserialization via
    /// `#[serde(skip)]`, so it is `None` after deserialization; `KeyPair::new`
    /// always populates it with `Some`.
    #[serde(skip)]
    pub private_key: Option<RsaPrivateKey>,
    /// Expiry time in whole seconds since the UNIX epoch (compared against the
    /// current time by `KeyPair::is_expired`).
    pub expiry: u64,
}

/// Serializes an `RsaPublicKey` to a PEM format string for storage or transmission.
fn serialize_rsa_public_key<S>(key: &RsaPublicKey, serializer: S) -> Result<S::Ok, S::Error>
where
    S: Serializer,
{
    let pem = key
        .to_public_key_pem(LineEnding::CRLF)
        .map_err(serde::ser::Error::custom)?;
    serializer.serialize_str(&pem)
}

/// Deserializes an `RsaPublicKey` from a PEM format string.
fn deserialize_rsa_public_key<'de, D>(deserializer: D) -> Result<RsaPublicKey, D::Error>
where
    D: Deserializer<'de>,
{
    let pem = String::deserialize(deserializer)?;
    RsaPublicKey::from_public_key_pem(&pem).map_err(serde::de::Error::custom)
}

impl KeyPair {
    /// Generates a fresh RSA `KeyPair` with the given identifier and expiry offset.
    ///
    /// The RSA key size (in bits) is read from the `KEY_SIZE` environment variable
    /// via `dotenv::var`, and the key is generated with `OsRng`. The expiry
    /// timestamp is the current UNIX time (in seconds) offset by `expiry_duration`,
    /// clamped to `0` for large negative offsets and to `u64::MAX` for large
    /// positive ones.
    ///
    /// # Parameters
    ///
    /// * `kid` - Identifier for the key pair (typically a UUID).
    /// * `expiry_duration` - Offset in seconds from now after which the key pair is
    ///   considered expired. A negative value produces an expiry in the past.
    ///
    /// # Errors
    ///
    /// Returns a `CryptoError` if:
    ///
    /// - `KEY_SIZE` cannot be read;
    /// - `KEY_SIZE` does not parse as a `usize`;
    /// - RSA key generation fails (for example, because of an unsupported key size).
    ///
    /// # Panics
    ///
    /// Panics if the system clock reports a time before the UNIX epoch.
    pub fn new(kid: &str, expiry_duration: i64) -> Result<Self, CryptoError> {
        // `?` converts both the env-var error and the parse error into `CryptoError`.
        let key_size = dotenv::var("KEY_SIZE")?.parse::<usize>()?;

        let mut rng = OsRng;
        let private_key = RsaPrivateKey::new(&mut rng, key_size)?;
        let public_key = RsaPublicKey::from(&private_key);

        let now = Self::unix_now_secs();
        // `unsigned_abs` is defined for every i64, whereas `i64::MIN.abs()` overflows
        // (panicking whenever overflow checks are enabled, e.g. in debug/test builds).
        let offset = expiry_duration.unsigned_abs();
        let expiry = if expiry_duration < 0 {
            now.saturating_sub(offset)
        } else {
            now.saturating_add(offset)
        };

        Ok(Self {
            kid: kid.to_owned(),
            public_key,
            private_key: Some(private_key),
            expiry,
        })
    }

    /// Reports whether the key pair has expired.
    ///
    /// # Returns
    ///
    /// `true` once the current UNIX time (in seconds) is strictly greater than
    /// `expiry`, `false` otherwise.
    ///
    /// # Panics
    ///
    /// Panics if the system clock reports a time before the UNIX epoch.
    pub fn is_expired(&self) -> bool {
        self.expiry < Self::unix_now_secs()
    }

    /// Returns the current system time in whole seconds since the UNIX epoch.
    ///
    /// # Panics
    ///
    /// Panics if the system clock reports a time before the UNIX epoch.
    fn unix_now_secs() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("Time went backwards")
            .as_secs()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Current UNIX time in whole seconds, used to bracket key creation.
    fn unix_now() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("Time went backwards")
            .as_secs()
    }

    // NOTE: these tests call `KeyPair::new`, which reads `KEY_SIZE` from the
    // environment (or a `.env` file); it must be set for them to pass.

    #[test]
    fn key_pair_generation() {
        let kid = "test_key";
        let expiry_duration: i64 = 3600;
        let offset = expiry_duration.unsigned_abs();

        // Take timestamps on both sides of the call so the expected expiry window
        // is exact even if key generation crosses a second boundary.
        let before = unix_now();
        let key_pair = KeyPair::new(kid, expiry_duration).unwrap();
        let after = unix_now();

        assert_eq!(key_pair.kid, kid);
        assert!(key_pair.private_key.is_some());
        assert!(
            (before + offset..=after + offset).contains(&key_pair.expiry),
            "expiry {} outside [{}, {}]",
            key_pair.expiry,
            before + offset,
            after + offset
        );
        assert!(!key_pair.is_expired());
    }

    #[test]
    fn key_pair_expiry() {
        // A negative duration places the expiry in the past, so the key is
        // expired immediately; no sleeping is needed.
        let key_pair = KeyPair::new("expired_key", -1).unwrap();
        assert!(key_pair.is_expired());
    }
}