Skip to content
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 17 additions & 7 deletions rs/hang-cli/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use axum::{http::Method, routing::get, Router};
use hang::moq_lite;
use std::net::SocketAddr;
use std::path::PathBuf;
use std::sync::{Arc, RwLock};
use tower_http::cors::{Any, CorsLayer};
use tower_http::services::ServeDir;

Expand All @@ -26,17 +27,15 @@ pub async fn server(

let server = config.init()?;

// Get the first certificate's fingerprint.
// TODO serve all of them so we can support multiple signature algorithms.
let fingerprint = server.fingerprints().first().context("missing certificate")?.clone();

// Notify systemd that we're ready.
let _ = sd_notify::notify(true, &[sd_notify::NotifyState::Ready]);

let fingerprints = server.fingerprints();

tokio::select! {
res = accept(server, name, publish.consume()) => res,
res = publish.run() => res,
res = web(listen, fingerprint, public) => res,
res = web(listen, fingerprints, public) => res,
}
}

Expand Down Expand Up @@ -91,13 +90,24 @@ async fn run_session(
}

// Initialize the HTTP server (but don't serve yet).
async fn web(bind: SocketAddr, fingerprint: String, public: Option<PathBuf>) -> anyhow::Result<()> {
async fn web(bind: SocketAddr, fingerprints: Arc<RwLock<Vec<String>>>, public: Option<PathBuf>) -> anyhow::Result<()> {
async fn handle_404() -> impl IntoResponse {
(StatusCode::NOT_FOUND, "Not found")
}

let fingerprint_handler = move || async move {
// Get the first certificate's fingerprint.
// TODO serve all of them so we can support multiple signature algorithms.
fingerprints
.read()
.expect("fingerprints read lock poisoned")
.first()
.cloned()
.ok_or((StatusCode::NOT_FOUND, "missing certificate"))
};

let mut app = Router::new()
.route("/certificate.sha256", get(fingerprint))
.route("/certificate.sha256", get(fingerprint_handler))
.layer(CorsLayer::new().allow_origin(Any).allow_methods([Method::GET]));

// If a public directory is provided, serve it.
Expand Down
120 changes: 84 additions & 36 deletions rs/moq-native/src/server.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use std::path::PathBuf;
use std::{net, sync::Arc, time::Duration};
use std::{net, time::Duration};

use crate::crypto;
use anyhow::Context;
Expand All @@ -8,6 +8,9 @@ use rustls::server::{ClientHello, ResolvesServerCert};
use rustls::sign::CertifiedKey;
use std::fs;
use std::io::{self, Cursor, Read};
use std::sync::{Arc, RwLock};
#[cfg(unix)]
use tokio::signal::unix::{signal, SignalKind};
use url::Url;
use web_transport_quinn::{http, ServerError};

Expand Down Expand Up @@ -81,7 +84,7 @@ impl ServerConfig {
pub struct Server {
quic: quinn::Endpoint,
accept: FuturesUnordered<BoxFuture<'static, anyhow::Result<Request>>>,
fingerprints: Vec<String>,
certs: Arc<ServeCerts>,
}

impl Server {
Expand All @@ -97,28 +100,37 @@ impl Server {

let provider = crypto::provider();

let mut serve = ServeCerts::new(provider.clone());
let certs = ServeCerts::new(provider.clone());

// Load the certificate and key files based on their index.
anyhow::ensure!(
config.tls.cert.len() == config.tls.key.len(),
"must provide both cert and key"
);
certs.load_certs(&config.tls)?;

for (cert, key) in config.tls.cert.iter().zip(config.tls.key.iter()) {
serve.load(cert, key)?;
}
let certs = Arc::new(certs);

if !config.tls.generate.is_empty() {
serve.generate(&config.tls.generate)?;
}
#[cfg(unix)]
{
let certs = certs.clone();
tokio::spawn(async move {
let tls_config = config.tls.clone();

let fingerprints = serve.fingerprints();
match signal(SignalKind::user_defined1()) {
Ok(mut signal) => loop {
if signal.recv().await.is_some() {
tracing::info!("reloading server certificates");

if let Err(err) = certs.load_certs(&tls_config) {
tracing::warn!(%err, "failed to reload server certificates");
}
}
},
Err(err) => tracing::warn!(%err, "failed to setup server certificate reloading"),
}
});
}

let mut tls = rustls::ServerConfig::builder_with_provider(provider)
.with_protocol_versions(&[&rustls::version::TLS13])?
.with_no_client_auth()
.with_cert_resolver(Arc::new(serve));
.with_cert_resolver(certs.clone());

tls.alpn_protocols = vec![
web_transport_quinn::ALPN.as_bytes().to_vec(),
Expand All @@ -145,12 +157,13 @@ impl Server {
Ok(Self {
quic: quic.clone(),
accept: Default::default(),
fingerprints,
certs,
})
}

pub fn fingerprints(&self) -> &[String] {
&self.fingerprints
// Return the SHA256 fingerprints of all our certificates.
pub fn fingerprints(&self) -> Arc<RwLock<Vec<String>>> {
self.certs.fingerprints.clone()
}

/// Returns the next partially established QUIC or WebTransport session.
Expand Down Expand Up @@ -301,21 +314,42 @@ impl QuicRequest {

#[derive(Debug)]
struct ServeCerts {
certs: Vec<Arc<CertifiedKey>>,
certs: RwLock<Vec<Arc<CertifiedKey>>>,
fingerprints: Arc<RwLock<Vec<String>>>,
provider: crypto::Provider,
}

impl ServeCerts {
pub fn new(provider: crypto::Provider) -> Self {
Self {
certs: Vec::new(),
certs: RwLock::new(Vec::new()),
fingerprints: Arc::new(RwLock::new(Vec::new())),
provider,
}
}

// Load a certificate and corresponding key from a file
pub fn load(&mut self, chain: &PathBuf, key: &PathBuf) -> anyhow::Result<()> {
let chain = fs::File::open(chain).context("failed to open cert file")?;
pub fn load_certs(&self, config: &ServerTlsConfig) -> anyhow::Result<()> {
anyhow::ensure!(config.cert.len() == config.key.len(), "must provide both cert and key");

let mut certs = Vec::new();

// Load the certificate and key files based on their index.
for (cert, key) in config.cert.iter().zip(config.key.iter()) {
certs.push(Arc::new(self.load(cert, key)?));
}

// Generate a new certificate if requested.
if !config.generate.is_empty() {
certs.push(Arc::new(self.generate(&config.generate)?));
}

self.set_certs(certs);
Ok(())
}

// Load a certificate and corresponding key from a file, but don't add it to the certs
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not? Why is this is &mut self if it doesn't even mutate ServeCerts?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It was &mut self but I've removed that because on reload it first loads all the certificates/keys and then replaces them. This prevents that certificates get unloaded when there is an issue with the new certificates like a key that doesn't match the certificate.

fn load(&self, chain_path: &PathBuf, key_path: &PathBuf) -> anyhow::Result<CertifiedKey> {
let chain = fs::File::open(chain_path).context("failed to open cert file")?;
let mut chain = io::BufReader::new(chain);

let chain: Vec<CertificateDer> = rustls_pemfile::certs(&mut chain)
Expand All @@ -325,7 +359,7 @@ impl ServeCerts {
anyhow::ensure!(!chain.is_empty(), "could not find certificate");

// Read the PEM private key
let mut keys = fs::File::open(key).context("failed to open key file")?;
let mut keys = fs::File::open(key_path).context("failed to open key file")?;

// Read the keys into a Vec so we can parse it twice.
let mut buf = Vec::new();
Expand All @@ -334,12 +368,18 @@ impl ServeCerts {
let key = rustls_pemfile::private_key(&mut Cursor::new(&buf))?.context("missing private key")?;
let key = self.provider.key_provider.load_private_key(key)?;

self.certs.push(Arc::new(CertifiedKey::new(chain, key)));
let certified_key = CertifiedKey::new(chain, key);

Ok(())
certified_key.keys_match().context(format!(
"private key {} doesn't match certificate {}",
key_path.display(),
chain_path.display()
))?;

Ok(certified_key)
}

pub fn generate(&mut self, hostnames: &[String]) -> anyhow::Result<()> {
fn generate(&self, hostnames: &[String]) -> anyhow::Result<CertifiedKey> {
let key_pair = rcgen::KeyPair::generate()?;

let mut params = rcgen::CertificateParams::new(hostnames)?;
Expand All @@ -358,28 +398,36 @@ impl ServeCerts {
let key = self.provider.key_provider.load_private_key(key_der.into())?;

// Create a rustls::sign::CertifiedKey
self.certs.push(Arc::new(CertifiedKey::new(vec![cert.into()], key)));
Ok(CertifiedKey::new(vec![cert.into()], key))
}

Ok(())
// Replace the certificates
pub fn set_certs(&self, certs: Vec<Arc<CertifiedKey>>) {
*self.certs.write().expect("certs write lock poisened") = certs;
self.update_fingerprints();
}

// Return the SHA256 fingerprints of all our certificates.
pub fn fingerprints(&self) -> Vec<String> {
self.certs
fn update_fingerprints(&self) {
let fingerprints = self
.certs
.read()
.unwrap()
.iter()
.map(|ck| {
let fingerprint = crate::crypto::sha256(&self.provider, ck.cert[0].as_ref());
hex::encode(fingerprint)
})
.collect()
.collect();

*self.fingerprints.write().expect("fingerprints write lock poisened") = fingerprints;
}

// Return the best certificate for the given ClientHello.
fn best_certificate(&self, client_hello: &ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
let server_name = client_hello.server_name()?;
let dns_name = rustls::pki_types::ServerName::try_from(server_name).ok()?;

for ck in &self.certs {
for ck in self.certs.read().expect("certs read lock poisoned").iter() {
let leaf: webpki::EndEntityCert = ck
.end_entity_cert()
.expect("missing certificate")
Expand All @@ -405,6 +453,6 @@ impl ResolvesServerCert for ServeCerts {
// We do our best and return the first certificate.
tracing::warn!(server_name = ?client_hello.server_name(), "no SNI certificate found");

self.certs.first().cloned()
self.certs.read().expect("certs read lock poisoned").first().cloned()
}
}
3 changes: 1 addition & 2 deletions rs/moq-relay/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ async fn main() -> anyhow::Result<()> {
let mut server = config.server.init()?;
let client = config.client.init()?;
let auth = config.auth.init()?;
let fingerprints = server.fingerprints().to_vec();

let cluster = Cluster::new(config.cluster, client);
let cloned = cluster.clone();
Expand All @@ -29,7 +28,7 @@ async fn main() -> anyhow::Result<()> {
WebState {
auth: auth.clone(),
cluster: cluster.clone(),
fingerprints,
fingerprints: server.fingerprints(),
conn_id: Default::default(),
},
config.web,
Expand Down
48 changes: 39 additions & 9 deletions rs/moq-relay/src/web.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ use clap::Parser;
use moq_lite::{OriginConsumer, OriginProducer};
use serde::{Deserialize, Serialize};
use std::future::Future;
#[cfg(unix)]
use tokio::signal::unix::{signal, SignalKind};
use tower_http::cors::{Any, CorsLayer};

use crate::{Auth, Cluster};
Expand Down Expand Up @@ -75,7 +77,7 @@ pub struct HttpsConfig {
pub struct WebState {
pub auth: Auth,
pub cluster: Cluster,
pub fingerprints: Vec<String>,
pub fingerprints: Arc<std::sync::RwLock<Vec<String>>>,
pub conn_id: AtomicU64,
}

Expand All @@ -91,12 +93,8 @@ impl Web {
}

pub async fn run(self) -> anyhow::Result<()> {
// Get the first certificate's fingerprint.
// TODO serve all of them so we can support multiple signature algorithms.
let fingerprint = self.state.fingerprints.first().expect("missing certificate").clone();

let app = Router::new()
.route("/certificate.sha256", get(fingerprint))
.route("/certificate.sha256", get(serve_fingerprint))
.route("/announced", get(serve_announced))
.route("/announced/{*prefix}", get(serve_announced))
.route("/fetch/{*path}", get(serve_fetch));
Expand All @@ -118,10 +116,12 @@ impl Web {
};

let https = if let Some(listen) = self.config.https.listen {
let cert = self.config.https.cert.as_ref().expect("missing certificate");
let key = self.config.https.key.as_ref().expect("missing key");
let cert = self.config.https.cert.expect("missing https.cert");
let key = self.config.https.key.expect("missing https.key");
let config = hyper_serve::tls_rustls::RustlsConfig::from_pem_file(cert.clone(), key.clone()).await?;

let config = hyper_serve::tls_rustls::RustlsConfig::from_pem_file(cert, key).await?;
#[cfg(unix)]
setup_reload(config.clone(), cert, key);

let server = hyper_serve::bind_rustls(listen, config);
Some(server.serve(app))
Expand All @@ -139,6 +139,36 @@ impl Web {
}
}

#[cfg(unix)]
fn setup_reload(config: hyper_serve::tls_rustls::RustlsConfig, cert: PathBuf, key: PathBuf) {
tokio::spawn(async move {
match signal(SignalKind::user_defined1()) {
Ok(mut signal) => loop {
if signal.recv().await.is_some() {
tracing::info!("reloading web certificate");

if let Err(err) = config.reload_from_pem_file(cert.clone(), key.clone()).await {
tracing::warn!(%err, "failed to reload web certificate");
}
}
},
Err(err) => tracing::warn!(%err, "failed to setup web certificate reloading"),
}
});
}

async fn serve_fingerprint(State(state): State<Arc<WebState>>) -> String {
// Get the first certificate's fingerprint.
// TODO serve all of them so we can support multiple signature algorithms.
state
.fingerprints
.read()
.expect("fingerprints lock poisoned")
.first()
.expect("missing certificate")
.clone()
}

async fn serve_ws(
ws: WebSocketUpgrade,
Path(path): Path<String>,
Expand Down
Loading