1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155
use crate::crypto::error::CryptoError;
use dotenv;
use rand::rngs::OsRng;
use rsa::pkcs8::{DecodePublicKey, EncodePublicKey, LineEnding};
use rsa::{RsaPrivateKey, RsaPublicKey};
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::time::{SystemTime, UNIX_EPOCH};
/// An RSA key pair tagged with a key identifier and an expiry timestamp.
///
/// Only `kid`, `public_key` and `expiry` take part in (de)serialization; the
/// private half is skipped, so a serialized `KeyPair` never carries secret
/// key material.
#[derive(Serialize, Deserialize)]
pub struct KeyPair {
/// A unique identifier for the key pair.
pub kid: String,
/// The RSA public key. It is written out as a PEM string by
/// `serialize_rsa_public_key` and parsed back from a PEM string by
/// `deserialize_rsa_public_key`.
#[serde(
serialize_with = "serialize_rsa_public_key",
deserialize_with = "deserialize_rsa_public_key"
)]
pub public_key: RsaPublicKey,
/// The RSA private key. Marked `#[serde(skip)]`: it is never serialized,
/// and per serde's `skip` semantics a deserialized `KeyPair` gets the
/// field's `Default` value, i.e. `None`.
#[serde(skip)]
pub private_key: Option<RsaPrivateKey>,
/// The moment the key pair expires, in whole seconds since the UNIX epoch.
pub expiry: u64,
}
/// Serializes an `RsaPublicKey` to a PEM format string for storage or transmission.
fn serialize_rsa_public_key<S>(key: &RsaPublicKey, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let pem = key
.to_public_key_pem(LineEnding::CRLF)
.map_err(serde::ser::Error::custom)?;
serializer.serialize_str(&pem)
}
/// Deserializes an `RsaPublicKey` from a PEM format string.
fn deserialize_rsa_public_key<'de, D>(deserializer: D) -> Result<RsaPublicKey, D::Error>
where
D: Deserializer<'de>,
{
let pem = String::deserialize(deserializer)?;
RsaPublicKey::from_public_key_pem(&pem).map_err(serde::de::Error::custom)
}
impl KeyPair {
    /// Generates a new RSA `KeyPair` identified by `kid` that expires
    /// `expiry_duration` seconds from the current time.
    ///
    /// The key size (in bits) is not a parameter: it is read from the
    /// `KEY_SIZE` variable through `dotenv::var` on every call.
    ///
    /// # Parameters
    ///
    /// * `kid` - A unique identifier for the key pair (e.g. a UUID); it is copied into the result.
    /// * `expiry_duration` - Offset in seconds relative to the current time at
    ///   which the key pair expires. A negative value places the expiry in the
    ///   past, producing a key pair that is already expired. The resulting
    ///   timestamp is clamped to the range `0..=u64::MAX`.
    ///
    /// # Returns
    ///
    /// - `Ok(KeyPair)` - the generated key pair, with `private_key` set to `Some`.
    /// - `Err(CryptoError)` - if any step listed under *Errors* fails.
    ///
    /// # Errors
    ///
    /// Returns an error if:
    ///
    /// - `KEY_SIZE` cannot be read through `dotenv::var`.
    /// - The value of `KEY_SIZE` does not parse as a `usize`.
    /// - RSA key generation fails.
    ///
    /// # Panics
    ///
    /// Panics if the system clock reports a time before the UNIX epoch.
    pub fn new(kid: &str, expiry_duration: i64) -> Result<Self, CryptoError> {
        let key_size = dotenv::var("KEY_SIZE")?
            .parse::<usize>()
            .map_err(CryptoError::from)?;

        let mut rng = OsRng;
        let private_key = RsaPrivateKey::new(&mut rng, key_size)?;
        let public_key = RsaPublicKey::from(&private_key);

        // The clock is read after key generation so the expiry is measured
        // from the moment the key actually exists.
        let expiry = Self::offset_unix_secs(Self::current_unix_secs(), expiry_duration);

        Ok(Self {
            kid: kid.to_owned(),
            public_key,
            private_key: Some(private_key),
            expiry,
        })
    }

    /// Checks whether the key pair has expired based on the current system time.
    ///
    /// # Returns
    ///
    /// `true` if `expiry` lies strictly before the current time (in whole
    /// seconds since the UNIX epoch), `false` otherwise. A key whose expiry
    /// equals the current second is not yet considered expired.
    ///
    /// # Panics
    ///
    /// Panics if the system clock reports a time before the UNIX epoch.
    pub fn is_expired(&self) -> bool {
        self.expiry < Self::current_unix_secs()
    }

    /// Returns the current wall-clock time as whole seconds since the UNIX epoch.
    ///
    /// Panics if the system clock is set before the UNIX epoch.
    fn current_unix_secs() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("Time went backwards")
            .as_secs()
    }

    /// Shifts `now` by a signed number of seconds, saturating at `0` and `u64::MAX`.
    fn offset_unix_secs(now: u64, offset_secs: i64) -> u64 {
        // `unsigned_abs` is well-defined for every input, including `i64::MIN`,
        // where `abs()` would overflow.
        let magnitude = offset_secs.unsigned_abs();
        if offset_secs < 0 {
            now.saturating_sub(magnitude)
        } else {
            now.saturating_add(magnitude)
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Current wall-clock time as whole seconds since the UNIX epoch.
    fn unix_now() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("Time went backwards")
            .as_secs()
    }

    /// A freshly generated key carries the requested `kid`, holds a private
    /// key, and expires exactly `expiry_duration` seconds after its creation.
    #[test]
    fn key_pair_generation() {
        let kid = "test_key";
        let expiry_duration: i64 = 3600;

        // Bracket the creation time: key generation can take a noticeable
        // while, so the expiry must fall inside [before, after] + duration.
        let before = unix_now();
        let key_pair = KeyPair::new(kid, expiry_duration).unwrap();
        let after = unix_now();

        assert_eq!(key_pair.kid, kid);
        assert!(key_pair.private_key.is_some());

        let offset = expiry_duration.unsigned_abs();
        assert!(key_pair.expiry >= before + offset);
        assert!(key_pair.expiry <= after + offset);
        assert!(!key_pair.is_expired());
    }

    /// A negative expiry duration yields a key that is already expired.
    ///
    /// This checks expiry deterministically without sleeping, and exercises
    /// the branch of `KeyPair::new` that subtracts from the current time.
    #[test]
    fn key_pair_expiry() {
        let kid = "expired_key";
        let expiry_duration = -1; // expired one second before creation

        let before = unix_now();
        let key_pair = KeyPair::new(kid, expiry_duration).unwrap();

        assert!(key_pair.expiry < unix_now());
        assert!(key_pair.expiry + 1 >= before);
        assert!(key_pair.is_expired());
    }
}