diff --git a/rs/hang-cli/src/server.rs b/rs/hang-cli/src/server.rs index cb46d42ec..a1e491100 100644 --- a/rs/hang-cli/src/server.rs +++ b/rs/hang-cli/src/server.rs @@ -6,6 +6,7 @@ use axum::{http::Method, routing::get, Router}; use hang::moq_lite; use std::net::SocketAddr; use std::path::PathBuf; +use std::sync::{Arc, RwLock}; use tower_http::cors::{Any, CorsLayer}; use tower_http::services::ServeDir; @@ -26,17 +27,15 @@ pub async fn server( let server = config.init()?; - // Get the first certificate's fingerprint. - // TODO serve all of them so we can support multiple signature algorithms. - let fingerprint = server.fingerprints().first().context("missing certificate")?.clone(); - // Notify systemd that we're ready. let _ = sd_notify::notify(true, &[sd_notify::NotifyState::Ready]); + let tls_info = server.tls_info(); + tokio::select! { res = accept(server, name, publish.consume()) => res, res = publish.run() => res, - res = web(listen, fingerprint, public) => res, + res = web(listen, tls_info, public) => res, } } @@ -91,13 +90,29 @@ async fn run_session( } // Initialize the HTTP server (but don't serve yet). -async fn web(bind: SocketAddr, fingerprint: String, public: Option) -> anyhow::Result<()> { +async fn web( + bind: SocketAddr, + tls_info: Arc>, + public: Option, +) -> anyhow::Result<()> { async fn handle_404() -> impl IntoResponse { (StatusCode::NOT_FOUND, "Not found") } + let fingerprint_handler = move || async move { + // Get the first certificate's fingerprint. + // TODO serve all of them so we can support multiple signature algorithms. + tls_info + .read() + .expect("tls_info read lock poisoned") + .fingerprints + .first() + .expect("missing certificate") + .clone() + }; + let mut app = Router::new() - .route("/certificate.sha256", get(fingerprint)) + .route("/certificate.sha256", get(fingerprint_handler)) .layer(CorsLayer::new().allow_origin(Any).allow_methods([Method::GET])); // If a public directory is provided, serve it. diff --git a/rs/moq-native/src/server.rs b/rs/moq-native/src/server.rs index cfee06a72..6d8845d18 100644 --- a/rs/moq-native/src/server.rs +++ b/rs/moq-native/src/server.rs @@ -1,5 +1,5 @@ use std::path::PathBuf; -use std::{net, sync::Arc, time::Duration}; +use std::{net, time::Duration}; use crate::crypto; use anyhow::Context; @@ -8,6 +8,7 @@ use rustls::server::{ClientHello, ResolvesServerCert}; use rustls::sign::CertifiedKey; use std::fs; use std::io::{self, Cursor, Read}; +use std::sync::{Arc, RwLock}; use url::Url; use web_transport_quinn::{http, ServerError}; @@ -81,7 +82,7 @@ impl ServerConfig { pub struct Server { quic: quinn::Endpoint, accept: FuturesUnordered>>, - fingerprints: Vec, + certs: Arc, } impl Server { @@ -97,28 +98,19 @@ impl Server { let provider = crypto::provider(); - let mut serve = ServeCerts::new(provider.clone()); + let certs = ServeCerts::new(provider.clone()); - // Load the certificate and key files based on their index. - anyhow::ensure!( - config.tls.cert.len() == config.tls.key.len(), - "must provide both cert and key" - ); - - for (cert, key) in config.tls.cert.iter().zip(config.tls.key.iter()) { - serve.load(cert, key)?; - } + certs.load_certs(&config.tls)?; - if !config.tls.generate.is_empty() { - serve.generate(&config.tls.generate)?; - } + let certs = Arc::new(certs); - let fingerprints = serve.fingerprints(); + #[cfg(unix)] + tokio::spawn(Self::reload_certs(certs.clone(), config.tls.clone())); let mut tls = rustls::ServerConfig::builder_with_provider(provider) .with_protocol_versions(&[&rustls::version::TLS13])? .with_no_client_auth() - .with_cert_resolver(Arc::new(serve)); + .with_cert_resolver(certs.clone()); tls.alpn_protocols = vec![ web_transport_quinn::ALPN.as_bytes().to_vec(), @@ -145,12 +137,29 @@ impl Server { Ok(Self { quic: quic.clone(), accept: Default::default(), - fingerprints, + certs, }) } - pub fn fingerprints(&self) -> &[String] { - &self.fingerprints + #[cfg(unix)] + async fn reload_certs(certs: Arc, tls_config: ServerTlsConfig) { + use tokio::signal::unix::{signal, SignalKind}; + + // Dunno why we wouldn't be allowed to listen for signals, but just in case. + let mut listener = signal(SignalKind::user_defined1()).expect("failed to listen for signals"); + + while listener.recv().await.is_some() { + tracing::info!("reloading server certificates"); + + if let Err(err) = certs.load_certs(&tls_config) { + tracing::warn!(%err, "failed to reload server certificates"); + } + } + } + + // Return the SHA256 fingerprints of all our certificates. + pub fn tls_info(&self) -> Arc> { + self.certs.info.clone() } /// Returns the next partially established QUIC or WebTransport session. @@ -299,23 +308,51 @@ impl QuicRequest { } } +#[derive(Debug)] +pub struct TlsInfo { + pub(crate) certs: Vec>, + pub fingerprints: Vec, +} + #[derive(Debug)] struct ServeCerts { - certs: Vec>, + info: Arc>, provider: crypto::Provider, } impl ServeCerts { pub fn new(provider: crypto::Provider) -> Self { Self { - certs: Vec::new(), + info: Arc::new(RwLock::new(TlsInfo { + certs: Vec::new(), + fingerprints: Vec::new(), + })), provider, } } - // Load a certificate and corresponding key from a file - pub fn load(&mut self, chain: &PathBuf, key: &PathBuf) -> anyhow::Result<()> { - let chain = fs::File::open(chain).context("failed to open cert file")?; + pub fn load_certs(&self, config: &ServerTlsConfig) -> anyhow::Result<()> { + anyhow::ensure!(config.cert.len() == config.key.len(), "must provide both cert and key"); + + let mut certs = Vec::new(); + + // Load the certificate and key files based on their index. + for (cert, key) in config.cert.iter().zip(config.key.iter()) { + certs.push(Arc::new(self.load(cert, key)?)); + } + + // Generate a new certificate if requested. + if !config.generate.is_empty() { + certs.push(Arc::new(self.generate(&config.generate)?)); + } + + self.set_certs(certs); + Ok(()) + } + + // Load a certificate and corresponding key from a file, but don't add it to the certs + fn load(&self, chain_path: &PathBuf, key_path: &PathBuf) -> anyhow::Result { + let chain = fs::File::open(chain_path).context("failed to open cert file")?; let mut chain = io::BufReader::new(chain); let chain: Vec = rustls_pemfile::certs(&mut chain) @@ -325,7 +362,7 @@ impl ServeCerts { anyhow::ensure!(!chain.is_empty(), "could not find certificate"); // Read the PEM private key - let mut keys = fs::File::open(key).context("failed to open key file")?; + let mut keys = fs::File::open(key_path).context("failed to open key file")?; // Read the keys into a Vec so we can parse it twice. let mut buf = Vec::new(); @@ -334,12 +371,18 @@ impl ServeCerts { let key = rustls_pemfile::private_key(&mut Cursor::new(&buf))?.context("missing private key")?; let key = self.provider.key_provider.load_private_key(key)?; - self.certs.push(Arc::new(CertifiedKey::new(chain, key))); + let certified_key = CertifiedKey::new(chain, key); - Ok(()) + certified_key.keys_match().context(format!( + "private key {} doesn't match certificate {}", + key_path.display(), + chain_path.display() + ))?; + + Ok(certified_key) } - pub fn generate(&mut self, hostnames: &[String]) -> anyhow::Result<()> { + fn generate(&self, hostnames: &[String]) -> anyhow::Result { let key_pair = rcgen::KeyPair::generate()?; let mut params = rcgen::CertificateParams::new(hostnames)?; @@ -358,20 +401,22 @@ impl ServeCerts { let key = self.provider.key_provider.load_private_key(key_der.into())?; // Create a rustls::sign::CertifiedKey - self.certs.push(Arc::new(CertifiedKey::new(vec![cert.into()], key))); - - Ok(()) + Ok(CertifiedKey::new(vec![cert.into()], key)) } - // Return the SHA256 fingerprints of all our certificates. - pub fn fingerprints(&self) -> Vec { - self.certs + // Replace the certificates + pub fn set_certs(&self, certs: Vec>) { + let fingerprints = certs .iter() .map(|ck| { let fingerprint = crate::crypto::sha256(&self.provider, ck.cert[0].as_ref()); hex::encode(fingerprint) }) - .collect() + .collect(); + + let mut info = self.info.write().expect("info write lock poisoned"); + info.certs = certs; + info.fingerprints = fingerprints; } // Return the best certificate for the given ClientHello. @@ -379,7 +424,7 @@ impl ServeCerts { let server_name = client_hello.server_name()?; let dns_name = rustls::pki_types::ServerName::try_from(server_name).ok()?; - for ck in &self.certs { + for ck in self.info.read().expect("info read lock poisoned").certs.iter() { let leaf: webpki::EndEntityCert = ck .end_entity_cert() .expect("missing certificate") @@ -405,6 +450,11 @@ impl ResolvesServerCert for ServeCerts { // We do our best and return the first certificate. tracing::warn!(server_name = ?client_hello.server_name(), "no SNI certificate found"); - self.certs.first().cloned() + self.info + .read() + .expect("info read lock poisoned") + .certs + .first() + .cloned() } } diff --git a/rs/moq-relay/src/main.rs b/rs/moq-relay/src/main.rs index 11010f819..83678e460 100644 --- a/rs/moq-relay/src/main.rs +++ b/rs/moq-relay/src/main.rs @@ -18,7 +18,6 @@ async fn main() -> anyhow::Result<()> { let mut server = config.server.init()?; let client = config.client.init()?; let auth = config.auth.init()?; - let fingerprints = server.fingerprints().to_vec(); let cluster = Cluster::new(config.cluster, client); let cloned = cluster.clone(); @@ -29,7 +28,7 @@ async fn main() -> anyhow::Result<()> { WebState { auth: auth.clone(), cluster: cluster.clone(), - fingerprints, + tls_info: server.tls_info(), conn_id: Default::default(), }, config.web, diff --git a/rs/moq-relay/src/web.rs b/rs/moq-relay/src/web.rs index 07863cc03..9f7fb1715 100644 --- a/rs/moq-relay/src/web.rs +++ b/rs/moq-relay/src/web.rs @@ -75,7 +75,7 @@ pub struct HttpsConfig { pub struct WebState { pub auth: Auth, pub cluster: Cluster, - pub fingerprints: Vec, + pub tls_info: Arc>, pub conn_id: AtomicU64, } @@ -91,12 +91,8 @@ impl Web { } pub async fn run(self) -> anyhow::Result<()> { - // Get the first certificate's fingerprint. - // TODO serve all of them so we can support multiple signature algorithms. - let fingerprint = self.state.fingerprints.first().expect("missing certificate").clone(); - let app = Router::new() - .route("/certificate.sha256", get(fingerprint)) + .route("/certificate.sha256", get(serve_fingerprint)) .route("/announced", get(serve_announced)) .route("/announced/{*prefix}", get(serve_announced)) .route("/fetch/{*path}", get(serve_fetch)); @@ -118,10 +114,12 @@ impl Web { }; let https = if let Some(listen) = self.config.https.listen { - let cert = self.config.https.cert.as_ref().expect("missing certificate"); - let key = self.config.https.key.as_ref().expect("missing key"); + let cert = self.config.https.cert.expect("missing https.cert"); + let key = self.config.https.key.expect("missing https.key"); + let config = hyper_serve::tls_rustls::RustlsConfig::from_pem_file(cert.clone(), key.clone()).await?; - let config = hyper_serve::tls_rustls::RustlsConfig::from_pem_file(cert, key).await?; + #[cfg(unix)] + tokio::spawn(reload_certs(config.clone(), cert, key)); let server = hyper_serve::bind_rustls(listen, config); Some(server.serve(app)) @@ -139,6 +137,35 @@ impl Web { } } +#[cfg(unix)] +async fn reload_certs(config: hyper_serve::tls_rustls::RustlsConfig, cert: PathBuf, key: PathBuf) { + use tokio::signal::unix::{signal, SignalKind}; + + // Dunno why we wouldn't be allowed to listen for signals, but just in case. + let mut listener = signal(SignalKind::user_defined1()).expect("failed to listen for signals"); + + while listener.recv().await.is_some() { + tracing::info!("reloading web certificate"); + + if let Err(err) = config.reload_from_pem_file(cert.clone(), key.clone()).await { + tracing::warn!(%err, "failed to reload web certificate"); + } + } +} + +async fn serve_fingerprint(State(state): State>) -> String { + // Get the first certificate's fingerprint. + // TODO serve all of them so we can support multiple signature algorithms. + state + .tls_info + .read() + .expect("tls_info lock poisoned") + .fingerprints + .first() + .expect("missing certificate") + .clone() +} + async fn serve_ws( ws: WebSocketUpgrade, Path(path): Path,