Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 22 additions & 7 deletions rs/hang-cli/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use axum::{http::Method, routing::get, Router};
use hang::moq_lite;
use std::net::SocketAddr;
use std::path::PathBuf;
use std::sync::{Arc, RwLock};
use tower_http::cors::{Any, CorsLayer};
use tower_http::services::ServeDir;

Expand All @@ -26,17 +27,15 @@ pub async fn server(

let server = config.init()?;

// Get the first certificate's fingerprint.
// TODO serve all of them so we can support multiple signature algorithms.
let fingerprint = server.fingerprints().first().context("missing certificate")?.clone();

// Notify systemd that we're ready.
let _ = sd_notify::notify(true, &[sd_notify::NotifyState::Ready]);

let tls_info = server.tls_info();

tokio::select! {
res = accept(server, name, publish.consume()) => res,
res = publish.run() => res,
res = web(listen, fingerprint, public) => res,
res = web(listen, tls_info, public) => res,
}
}

Expand Down Expand Up @@ -91,13 +90,29 @@ async fn run_session(
}

// Initialize the HTTP server (but don't serve yet).
async fn web(bind: SocketAddr, fingerprint: String, public: Option<PathBuf>) -> anyhow::Result<()> {
async fn web(
bind: SocketAddr,
tls_info: Arc<RwLock<moq_native::TlsInfo>>,
public: Option<PathBuf>,
) -> anyhow::Result<()> {
async fn handle_404() -> impl IntoResponse {
(StatusCode::NOT_FOUND, "Not found")
}

let fingerprint_handler = move || async move {
// Get the first certificate's fingerprint.
// TODO serve all of them so we can support multiple signature algorithms.
tls_info
.read()
.expect("tls_info read lock poisoned")
.fingerprints
.first()
.expect("missing certificate")
.clone()
};

let mut app = Router::new()
.route("/certificate.sha256", get(fingerprint))
.route("/certificate.sha256", get(fingerprint_handler))
.layer(CorsLayer::new().allow_origin(Any).allow_methods([Method::GET]));

// If a public directory is provided, serve it.
Expand Down
126 changes: 88 additions & 38 deletions rs/moq-native/src/server.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use std::path::PathBuf;
use std::{net, sync::Arc, time::Duration};
use std::{net, time::Duration};

use crate::crypto;
use anyhow::Context;
Expand All @@ -8,6 +8,7 @@ use rustls::server::{ClientHello, ResolvesServerCert};
use rustls::sign::CertifiedKey;
use std::fs;
use std::io::{self, Cursor, Read};
use std::sync::{Arc, RwLock};
use url::Url;
use web_transport_quinn::{http, ServerError};

Expand Down Expand Up @@ -81,7 +82,7 @@ impl ServerConfig {
pub struct Server {
quic: quinn::Endpoint,
accept: FuturesUnordered<BoxFuture<'static, anyhow::Result<Request>>>,
fingerprints: Vec<String>,
certs: Arc<ServeCerts>,
}

impl Server {
Expand All @@ -97,28 +98,19 @@ impl Server {

let provider = crypto::provider();

let mut serve = ServeCerts::new(provider.clone());
let certs = ServeCerts::new(provider.clone());

// Load the certificate and key files based on their index.
anyhow::ensure!(
config.tls.cert.len() == config.tls.key.len(),
"must provide both cert and key"
);

for (cert, key) in config.tls.cert.iter().zip(config.tls.key.iter()) {
serve.load(cert, key)?;
}
certs.load_certs(&config.tls)?;

if !config.tls.generate.is_empty() {
serve.generate(&config.tls.generate)?;
}
let certs = Arc::new(certs);

let fingerprints = serve.fingerprints();
#[cfg(unix)]
tokio::spawn(Self::reload_certs(certs.clone(), config.tls.clone()));

let mut tls = rustls::ServerConfig::builder_with_provider(provider)
.with_protocol_versions(&[&rustls::version::TLS13])?
.with_no_client_auth()
.with_cert_resolver(Arc::new(serve));
.with_cert_resolver(certs.clone());

tls.alpn_protocols = vec![
web_transport_quinn::ALPN.as_bytes().to_vec(),
Expand All @@ -145,12 +137,29 @@ impl Server {
Ok(Self {
quic: quic.clone(),
accept: Default::default(),
fingerprints,
certs,
})
}

pub fn fingerprints(&self) -> &[String] {
&self.fingerprints
#[cfg(unix)]
async fn reload_certs(certs: Arc<ServeCerts>, tls_config: ServerTlsConfig) {
use tokio::signal::unix::{signal, SignalKind};

// Dunno why we wouldn't be allowed to listen for signals, but just in case.
let mut listener = signal(SignalKind::user_defined1()).expect("failed to listen for signals");

while listener.recv().await.is_some() {
tracing::info!("reloading server certificates");

if let Err(err) = certs.load_certs(&tls_config) {
tracing::warn!(%err, "failed to reload server certificates");
}
}
}

// Return the SHA256 fingerprints of all our certificates.
pub fn tls_info(&self) -> Arc<RwLock<TlsInfo>> {
self.certs.info.clone()
}

/// Returns the next partially established QUIC or WebTransport session.
Expand Down Expand Up @@ -299,23 +308,51 @@ impl QuicRequest {
}
}

#[derive(Debug)]
pub struct TlsInfo {
pub(crate) certs: Vec<Arc<CertifiedKey>>,
pub fingerprints: Vec<String>,
}

#[derive(Debug)]
struct ServeCerts {
certs: Vec<Arc<CertifiedKey>>,
info: Arc<RwLock<TlsInfo>>,
provider: crypto::Provider,
}

impl ServeCerts {
pub fn new(provider: crypto::Provider) -> Self {
Self {
certs: Vec::new(),
info: Arc::new(RwLock::new(TlsInfo {
certs: Vec::new(),
fingerprints: Vec::new(),
})),
provider,
}
}

// Load a certificate and corresponding key from a file
pub fn load(&mut self, chain: &PathBuf, key: &PathBuf) -> anyhow::Result<()> {
let chain = fs::File::open(chain).context("failed to open cert file")?;
pub fn load_certs(&self, config: &ServerTlsConfig) -> anyhow::Result<()> {
anyhow::ensure!(config.cert.len() == config.key.len(), "must provide both cert and key");

let mut certs = Vec::new();

// Load the certificate and key files based on their index.
for (cert, key) in config.cert.iter().zip(config.key.iter()) {
certs.push(Arc::new(self.load(cert, key)?));
}

// Generate a new certificate if requested.
if !config.generate.is_empty() {
certs.push(Arc::new(self.generate(&config.generate)?));
}

self.set_certs(certs);
Ok(())
}

// Load a certificate and corresponding key from a file, but don't add it to the certs
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not? Why is this is &mut self if it doesn't even mutate ServeCerts?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It was &mut self but I've removed that because on reload it first loads all the certificates/keys and then replaces them. This prevents that certificates get unloaded when there is an issue with the new certificates like a key that doesn't match the certificate.

fn load(&self, chain_path: &PathBuf, key_path: &PathBuf) -> anyhow::Result<CertifiedKey> {
let chain = fs::File::open(chain_path).context("failed to open cert file")?;
let mut chain = io::BufReader::new(chain);

let chain: Vec<CertificateDer> = rustls_pemfile::certs(&mut chain)
Expand All @@ -325,7 +362,7 @@ impl ServeCerts {
anyhow::ensure!(!chain.is_empty(), "could not find certificate");

// Read the PEM private key
let mut keys = fs::File::open(key).context("failed to open key file")?;
let mut keys = fs::File::open(key_path).context("failed to open key file")?;

// Read the keys into a Vec so we can parse it twice.
let mut buf = Vec::new();
Expand All @@ -334,12 +371,18 @@ impl ServeCerts {
let key = rustls_pemfile::private_key(&mut Cursor::new(&buf))?.context("missing private key")?;
let key = self.provider.key_provider.load_private_key(key)?;

self.certs.push(Arc::new(CertifiedKey::new(chain, key)));
let certified_key = CertifiedKey::new(chain, key);

Ok(())
certified_key.keys_match().context(format!(
"private key {} doesn't match certificate {}",
key_path.display(),
chain_path.display()
))?;

Ok(certified_key)
}

pub fn generate(&mut self, hostnames: &[String]) -> anyhow::Result<()> {
fn generate(&self, hostnames: &[String]) -> anyhow::Result<CertifiedKey> {
let key_pair = rcgen::KeyPair::generate()?;

let mut params = rcgen::CertificateParams::new(hostnames)?;
Expand All @@ -358,28 +401,30 @@ impl ServeCerts {
let key = self.provider.key_provider.load_private_key(key_der.into())?;

// Create a rustls::sign::CertifiedKey
self.certs.push(Arc::new(CertifiedKey::new(vec![cert.into()], key)));

Ok(())
Ok(CertifiedKey::new(vec![cert.into()], key))
}

// Return the SHA256 fingerprints of all our certificates.
pub fn fingerprints(&self) -> Vec<String> {
self.certs
// Replace the certificates
pub fn set_certs(&self, certs: Vec<Arc<CertifiedKey>>) {
let fingerprints = certs
.iter()
.map(|ck| {
let fingerprint = crate::crypto::sha256(&self.provider, ck.cert[0].as_ref());
hex::encode(fingerprint)
})
.collect()
.collect();

let mut info = self.info.write().expect("info write lock poisoned");
info.certs = certs;
info.fingerprints = fingerprints;
}

// Return the best certificate for the given ClientHello.
fn best_certificate(&self, client_hello: &ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
let server_name = client_hello.server_name()?;
let dns_name = rustls::pki_types::ServerName::try_from(server_name).ok()?;

for ck in &self.certs {
for ck in self.info.read().expect("info read lock poisoned").certs.iter() {
let leaf: webpki::EndEntityCert = ck
.end_entity_cert()
.expect("missing certificate")
Expand All @@ -405,6 +450,11 @@ impl ResolvesServerCert for ServeCerts {
// We do our best and return the first certificate.
tracing::warn!(server_name = ?client_hello.server_name(), "no SNI certificate found");

self.certs.first().cloned()
self.info
.read()
.expect("info read lock poisoned")
.certs
.first()
.cloned()
}
}
3 changes: 1 addition & 2 deletions rs/moq-relay/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ async fn main() -> anyhow::Result<()> {
let mut server = config.server.init()?;
let client = config.client.init()?;
let auth = config.auth.init()?;
let fingerprints = server.fingerprints().to_vec();

let cluster = Cluster::new(config.cluster, client);
let cloned = cluster.clone();
Expand All @@ -29,7 +28,7 @@ async fn main() -> anyhow::Result<()> {
WebState {
auth: auth.clone(),
cluster: cluster.clone(),
fingerprints,
tls_info: server.tls_info(),
conn_id: Default::default(),
},
config.web,
Expand Down
45 changes: 36 additions & 9 deletions rs/moq-relay/src/web.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ pub struct HttpsConfig {
pub struct WebState {
pub auth: Auth,
pub cluster: Cluster,
pub fingerprints: Vec<String>,
pub tls_info: Arc<std::sync::RwLock<moq_native::TlsInfo>>,
pub conn_id: AtomicU64,
}

Expand All @@ -91,12 +91,8 @@ impl Web {
}

pub async fn run(self) -> anyhow::Result<()> {
// Get the first certificate's fingerprint.
// TODO serve all of them so we can support multiple signature algorithms.
let fingerprint = self.state.fingerprints.first().expect("missing certificate").clone();

let app = Router::new()
.route("/certificate.sha256", get(fingerprint))
.route("/certificate.sha256", get(serve_fingerprint))
.route("/announced", get(serve_announced))
.route("/announced/{*prefix}", get(serve_announced))
.route("/fetch/{*path}", get(serve_fetch));
Expand All @@ -118,10 +114,12 @@ impl Web {
};

let https = if let Some(listen) = self.config.https.listen {
let cert = self.config.https.cert.as_ref().expect("missing certificate");
let key = self.config.https.key.as_ref().expect("missing key");
let cert = self.config.https.cert.expect("missing https.cert");
let key = self.config.https.key.expect("missing https.key");
let config = hyper_serve::tls_rustls::RustlsConfig::from_pem_file(cert.clone(), key.clone()).await?;

let config = hyper_serve::tls_rustls::RustlsConfig::from_pem_file(cert, key).await?;
#[cfg(unix)]
tokio::spawn(reload_certs(config.clone(), cert, key));

let server = hyper_serve::bind_rustls(listen, config);
Some(server.serve(app))
Expand All @@ -139,6 +137,35 @@ impl Web {
}
}

#[cfg(unix)]
async fn reload_certs(config: hyper_serve::tls_rustls::RustlsConfig, cert: PathBuf, key: PathBuf) {
use tokio::signal::unix::{signal, SignalKind};

// Dunno why we wouldn't be allowed to listen for signals, but just in case.
let mut listener = signal(SignalKind::user_defined1()).expect("failed to listen for signals");

while listener.recv().await.is_some() {
tracing::info!("reloading web certificate");

if let Err(err) = config.reload_from_pem_file(cert.clone(), key.clone()).await {
tracing::warn!(%err, "failed to reload web certificate");
}
}
}

async fn serve_fingerprint(State(state): State<Arc<WebState>>) -> String {
// Get the first certificate's fingerprint.
// TODO serve all of them so we can support multiple signature algorithms.
state
.tls_info
.read()
.expect("tls_info lock poisoned")
.fingerprints
.first()
.expect("missing certificate")
.clone()
}

async fn serve_ws(
ws: WebSocketUpgrade,
Path(path): Path<String>,
Expand Down
Loading