Skip to content

Certs #3

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "tonic_lnd"
version = "0.5.0"
version = "0.25.0"
authors = ["Martin Habovstiak <[email protected]>"]
edition = "2018"
description = "An async library implementing LND RPC via tonic and prost"
Expand All @@ -22,6 +22,7 @@ tokio = { version = "1.7.1", features = ["fs"] }
tracing = { version = "0.1", features = ["log"], optional = true }
rust_decimal = { version = "1.26.1", features = ["db-postgres"] }
rust_decimal_macros = "1.26.1"
webpki-roots = "0.19.0"

rand = "0.8.5"
serde = { version = "1.0.145", features = ["derive"] }
Expand Down
16 changes: 15 additions & 1 deletion build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ fn main() -> std::io::Result<()> {
"faraday.proto",
"looprpc/client.proto",
"invoicesrpc/invoices.proto",
"walletunlocker.proto"
"walletunlocker.proto",
"stateservice.proto",
];

let proto_paths: Vec<_> = protos
Expand Down Expand Up @@ -74,6 +75,19 @@ fn main() -> std::io::Result<()> {
.type_attribute("ChannelBackup", "#[derive(serde::Deserialize, serde::Serialize)]")
.type_attribute("ChannelPoint", "#[derive(serde::Deserialize, serde::Serialize)]")
.type_attribute("funding_txid", "#[derive(serde::Deserialize, serde::Serialize)]")
//StateService fields
.type_attribute("GetState", "#[derive(serde::Deserialize, serde::Serialize)]")
.type_attribute("SubscribeState", "#[derive(serde::Deserialize, serde::Serialize)]")
//PendingChannels fields
.type_attribute("PendingChannelsResponse", "#[derive(serde::Deserialize, serde::Serialize)]")
.type_attribute("Commitments", "#[derive(serde::Deserialize, serde::Serialize)]")
.type_attribute("PendingOpenChannel", "#[derive(serde::Deserialize, serde::Serialize)]")
.type_attribute("PendingChannel", "#[derive(serde::Deserialize, serde::Serialize)]")
.type_attribute("ClosedChannel", "#[derive(serde::Deserialize, serde::Serialize)]")
.type_attribute("ForceClosedChannel", "#[derive(serde::Deserialize, serde::Serialize)]")
.type_attribute("PendingHTLC", "#[derive(serde::Deserialize, serde::Serialize)]")
.type_attribute("AnchorState", "#[derive(serde::Deserialize, serde::Serialize)]")
.type_attribute("WaitingCloseChannel", "#[derive(serde::Deserialize, serde::Serialize)]")
.format(false)
.compile(&proto_paths, &[dir])?;
Ok(())
Expand Down
2 changes: 1 addition & 1 deletion examples/send_via_route.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ async fn main() {
// All calls require at least empty parameter
.send_to_route_v2(tonic_lnd::routerrpc::SendToRouteRequest {
payment_hash: vec![],
route: Some(Route {}),
route: Some(Route { total_time_lock: 1000, total_fees: 100000, total_amt: 1000000, hops: todo!(), total_fees_msat: 100, total_amt_msat: 1000 }),
skip_temp_err: true,
})
.await
Expand Down
29 changes: 29 additions & 0 deletions examples/state_service.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// This example unlocks a locked initialized node.
//
// This program accepts three arguments: address, cert file, macaroon file
// The address must start with `https://`!

#[tokio::main]
async fn main() {
let mut args = std::env::args_os();
args.next().expect("not even zeroth arg given");
let address = args.next().expect("missing arguments: address, cert file, macaroon file");
let cert_file = args.next().expect("missing arguments: cert file, macaroon file");
let macaroon_file = args.next().expect("missing argument: macaroon file");
let address = address.into_string().expect("address is not UTF-8");

// Connecting to LND requires only address, cert file, and macaroon file
let mut client = tonic_lnd::connect(address, cert_file, macaroon_file)
.await
.expect("failed to connect");

let unlock = client
.state()
.get_state(tonic_lnd::lnrpc::GetStateRequest { })
.await
.expect("failed to get info");

// We only print it here, note that in real-life code you may want to call `.into_inner()` on
// the response to get the message.
println!("{:#?}", unlock);
}
53 changes: 43 additions & 10 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ pub type LoopClient = looprpc::swap_client_client::SwapClientClient<InterceptedS
pub type FaradayServerClient = frdrpc::faraday_server_client::FaradayServerClient<InterceptedService<Channel, MacaroonInterceptor>>;
pub type InvoicesClient = invoicesrpc::invoices_client::InvoicesClient<InterceptedService<Channel, MacaroonInterceptor>>;
pub type WalletUnlockerClient = lnrpc::wallet_unlocker_client::WalletUnlockerClient<InterceptedService<Channel, MacaroonInterceptor>>;
pub type StateClient = lnrpc::state_client::StateClient<InterceptedService<Channel, MacaroonInterceptor>>;

/// The client returned by `connect` function
///
Expand All @@ -93,10 +94,11 @@ pub struct Client {
lightning: LightningClient,
wallet: WalletKitClient,
router: RouterClient,
loopclient: LoopClient,
loopclient: LoopClient,
faraday: FaradayServerClient,
invoices: InvoicesClient,
wallet_unlocker: WalletUnlockerClient,
state: StateClient,
}

impl Client {
Expand Down Expand Up @@ -129,6 +131,10 @@ impl Client {
pub fn wallet_unlocker(&mut self) -> &mut WalletUnlockerClient {
&mut self.wallet_unlocker
}

pub fn state(&mut self) -> &mut StateClient {
&mut self.state
}
}

/// [`tonic::Status`] is re-exported as `Error` for convenience.
Expand Down Expand Up @@ -215,7 +221,7 @@ pub async fn connect<A, CP, MP>(address: A, cert_file: CP, macaroon_file: MP) ->
let address_str = address.to_string();
let conn = try_map_err!(address
.try_into(), |error| InternalConnectError::InvalidAddress { address: address_str.clone(), error: Box::new(error), })
.tls_config(tls::config(cert_file).await?)
.tls_config(tls::config(Some(cert_file)).await?)
.map_err(InternalConnectError::TlsConfig)?
.connect()
.await
Expand All @@ -233,6 +239,7 @@ pub async fn connect<A, CP, MP>(address: A, cert_file: CP, macaroon_file: MP) ->
faraday: frdrpc::faraday_server_client::FaradayServerClient::with_interceptor(conn.clone(), interceptor.clone()),
invoices: invoicesrpc::invoices_client::InvoicesClient::with_interceptor(conn.clone(), interceptor.clone()),
wallet_unlocker: lnrpc::wallet_unlocker_client::WalletUnlockerClient::with_interceptor(conn.clone(), interceptor.clone()),
state: lnrpc::state_client::StateClient::with_interceptor(conn.clone(), interceptor.clone()),
};
Ok(client)
}
Expand All @@ -242,7 +249,7 @@ pub async fn in_mem_connect<A>(address: A, cert_file_as_hex: String, macaroon_as
let address_str = address.to_string();
let conn = try_map_err!(address
.try_into(), |error| InternalConnectError::InvalidAddress { address: address_str.clone(), error: Box::new(error), })
.tls_config(tls::config_with_hex(cert_file_as_hex).await?)
.tls_config(tls::config_with_hex(Some(cert_file_as_hex)).await?)
.map_err(InternalConnectError::TlsConfig)?
.connect()
.await
Expand All @@ -260,6 +267,7 @@ pub async fn in_mem_connect<A>(address: A, cert_file_as_hex: String, macaroon_as
faraday: frdrpc::faraday_server_client::FaradayServerClient::with_interceptor(conn.clone(), interceptor.clone()),
invoices: invoicesrpc::invoices_client::InvoicesClient::with_interceptor(conn.clone(), interceptor.clone()),
wallet_unlocker: lnrpc::wallet_unlocker_client::WalletUnlockerClient::with_interceptor(conn.clone(), interceptor.clone()),
state: lnrpc::state_client::StateClient::with_interceptor(conn.clone(), interceptor.clone()),
};
Ok(client)
}
Expand All @@ -269,18 +277,43 @@ mod tls {
use rustls::{RootCertStore, Certificate, TLSError, ServerCertVerified};
use webpki::DNSNameRef;
use crate::error::{ConnectError, InternalConnectError};
use webpki_roots;

pub(crate) async fn config(path: impl AsRef<Path> + Into<PathBuf>) -> Result<tonic::transport::ClientTlsConfig, ConnectError> {
pub(crate) async fn config(path: Option<impl AsRef<Path> + Into<PathBuf>>) -> Result<tonic::transport::ClientTlsConfig, ConnectError> {
let mut tls_config = rustls::ClientConfig::new();
tls_config.dangerous().set_certificate_verifier(std::sync::Arc::new(CertVerifier::load(path).await?));

match path {
Some(cert_path) if cert_path.as_ref().exists() => {
tls_config.dangerous().set_certificate_verifier(std::sync::Arc::new(CertVerifier::load(cert_path).await?));
},
_ => {
tls_config.root_store.add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS);
#[cfg(feature = "tracing")] {
tracing::warn!("Certificate file not provided or does not exist. Using system trust anchors");
}
}
}

tls_config.set_protocols(&["h2".into()]);
Ok(tonic::transport::ClientTlsConfig::new()
.rustls_client_config(tls_config))
}

pub(crate) async fn config_with_hex(file_as_hex: String) -> Result<tonic::transport::ClientTlsConfig, ConnectError> {
pub(crate) async fn config_with_hex(file_as_hex: Option<String>) -> Result<tonic::transport::ClientTlsConfig, ConnectError> {
let mut tls_config = rustls::ClientConfig::new();
tls_config.dangerous().set_certificate_verifier(std::sync::Arc::new(CertVerifier::load_as_hex(file_as_hex).await?));

match file_as_hex {
Some(hex_cert) if !hex_cert.is_empty() => {
tls_config.dangerous().set_certificate_verifier(std::sync::Arc::new(CertVerifier::load_as_hex(hex_cert).await?));
},
_ => {
tls_config.root_store.add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS);
#[cfg(feature = "tracing")] {
tracing::warn!("Certificate hex not provided or empty. Using system trust anchors");
}
}
}

tls_config.set_protocols(&["h2".into()]);
Ok(tonic::transport::ClientTlsConfig::new()
.rustls_client_config(tls_config))
Expand Down Expand Up @@ -326,11 +359,11 @@ mod tls {

impl rustls::ServerCertVerifier for CertVerifier {
fn verify_server_cert(&self, _roots: &RootCertStore, presented_certs: &[Certificate], _dns_name: DNSNameRef<'_>, _ocsp_response: &[u8]) -> Result<ServerCertVerified, TLSError> {

if self.certs.len() != presented_certs.len() {
return Err(TLSError::General(format!("Mismatched number of certificates (Expected: {}, Presented: {})", self.certs.len(), presented_certs.len())));
}

for (c, p) in self.certs.iter().zip(presented_certs.iter()) {
if *p.0 != **c {
return Err(TLSError::General(format!("Server certificates do not match ours")));
Expand All @@ -344,4 +377,4 @@ mod tls {
Ok(ServerCertVerified::assertion())
}
}
}
}
Loading