diff --git a/build.rs b/build.rs index cfed7b3..f2066b5 100644 --- a/build.rs +++ b/build.rs @@ -7,7 +7,7 @@ fn main() -> std::io::Result<()> { let mut lnd_rpc_dir = PathBuf::from(lnd_repo_path); lnd_rpc_dir.push("lnrpc"); lnd_rpc_dir - }, + } None => PathBuf::from("vendor"), }; diff --git a/examples/getinfo.rs b/examples/getinfo.rs index 2728e34..b8b7c8f 100644 --- a/examples/getinfo.rs +++ b/examples/getinfo.rs @@ -8,8 +8,12 @@ async fn main() { let mut args = std::env::args_os(); args.next().expect("not even zeroth arg given"); - let address = args.next().expect("missing arguments: address, cert file, macaroon file"); - let cert_file = args.next().expect("missing arguments: cert file, macaroon file"); + let address = args + .next() + .expect("missing arguments: address, cert file, macaroon file"); + let cert_file = args + .next() + .expect("missing arguments: cert file, macaroon file"); let macaroon_file = args.next().expect("missing argument: macaroon file"); let address = address.into_string().expect("address is not UTF-8"); diff --git a/examples/getversion.rs b/examples/getversion.rs index 051f89b..9fbf554 100644 --- a/examples/getversion.rs +++ b/examples/getversion.rs @@ -2,8 +2,12 @@ async fn main() { let mut args = std::env::args_os(); args.next().expect("not even zeroth arg given"); - let address = args.next().expect("missing arguments: address, cert file, macaroon file"); - let cert_file = args.next().expect("missing arguments: cert file, macaroon file"); + let address = args + .next() + .expect("missing arguments: address, cert file, macaroon file"); + let cert_file = args + .next() + .expect("missing arguments: cert file, macaroon file"); let macaroon_file = args.next().expect("missing argument: macaroon file"); let address = address.into_string().expect("address is not UTF-8"); diff --git a/src/error.rs b/src/error.rs index 62f5386..e4c3534 100644 --- a/src/error.rs +++ b/src/error.rs @@ -13,19 +13,29 @@ pub struct ConnectError { impl From for ConnectError { fn from(value: InternalConnectError) -> Self { - ConnectError { - internal: value, - } + ConnectError { internal: value } } } #[derive(Debug)] pub(crate) enum InternalConnectError { - ReadFile { file: PathBuf, error: std::io::Error, }, - ParseCert { file: PathBuf, error: std::io::Error, }, - InvalidAddress { address: String, error: Box, }, + ReadFile { + file: PathBuf, + error: std::io::Error, + }, + ParseCert { + file: PathBuf, + error: std::io::Error, + }, + InvalidAddress { + address: String, + error: Box, + }, TlsConfig(tonic::transport::Error), - Connect { address: String, error: tonic::transport::Error, } + Connect { + address: String, + error: tonic::transport::Error, + }, } impl fmt::Display for ConnectError { diff --git a/src/lib.rs b/src/lib.rs index ce03484..45616d4 100755 --- a/src/lib.rs +++ b/src/lib.rs @@ -70,9 +70,11 @@ MITNFA /// This is part of public interface so it's re-exported. pub extern crate tonic; -use std::path::{Path, PathBuf}; -use std::convert::TryInto; pub use error::ConnectError; + +use std::convert::TryInto; +use std::path::{Path, PathBuf}; + use error::InternalConnectError; use tonic::codegen::InterceptedService; use tonic::transport::Channel; @@ -81,7 +83,8 @@ use tonic::transport::Channel; use tracing; /// Convenience type alias for lightning client. -pub type LightningClient = lnrpc::lightning_client::LightningClient>; +pub type LightningClient = + lnrpc::lightning_client::LightningClient>; /// Convenience type alias for wallet client. pub type WalletKitClient = @@ -96,7 +99,8 @@ pub type VersionerClient = verrpc::versioner_client::VersionerClient>; // Convenience type alias for signer client. -pub type SignerClient = signrpc::signer_client::SignerClient>; +pub type SignerClient = + signrpc::signer_client::SignerClient>; /// The client returned by `connect` function /// @@ -148,7 +152,7 @@ macro_rules! try_map_err { Ok(value) => value, Err(error) => return Err($mapfn(error).into()), } - } + }; } /// Messages and other types generated by `tonic`/`prost` @@ -183,17 +187,25 @@ pub struct MacaroonInterceptor { impl tonic::service::Interceptor for MacaroonInterceptor { fn call(&mut self, mut request: tonic::Request<()>) -> Result, Error> { - request - .metadata_mut() - .insert("macaroon", tonic::metadata::MetadataValue::from_str(&self.macaroon).expect("hex produced non-ascii")); + request.metadata_mut().insert( + "macaroon", + tonic::metadata::MetadataValue::from_str(&self.macaroon) + .expect("hex produced non-ascii"), + ); Ok(request) } } -async fn load_macaroon(path: impl AsRef + Into) -> Result { - let macaroon = tokio::fs::read(&path) - .await - .map_err(|error| InternalConnectError::ReadFile { file: path.into(), error, })?; +async fn load_macaroon( + path: impl AsRef + Into, +) -> Result { + let macaroon = + tokio::fs::read(&path) + .await + .map_err(|error| InternalConnectError::ReadFile { + file: path.into(), + error, + })?; Ok(hex::encode(&macaroon)) } @@ -209,82 +221,132 @@ async fn load_macaroon(path: impl AsRef + Into) -> Result(address: A, cert_file: CP, macaroon_file: MP) -> Result where A: TryInto + std::fmt::Debug + ToString, >::Error: std::error::Error + Send + Sync + 'static, CP: AsRef + Into + std::fmt::Debug, MP: AsRef + Into + std::fmt::Debug { +pub async fn connect( + address: A, + cert_file: CP, + macaroon_file: MP, +) -> Result +where + A: TryInto + std::fmt::Debug + ToString, + >::Error: std::error::Error + Send + Sync + 'static, + CP: AsRef + Into + std::fmt::Debug, + MP: AsRef + Into + std::fmt::Debug, +{ let address_str = address.to_string(); - let conn = try_map_err!(address - .try_into(), |error| InternalConnectError::InvalidAddress { address: address_str.clone(), error: Box::new(error), }) - .tls_config(tls::config(cert_file).await?) - .map_err(InternalConnectError::TlsConfig)? - .connect() - .await - .map_err(|error| InternalConnectError::Connect { address: address_str, error, })?; + let conn = try_map_err!(address.try_into(), |error| { + InternalConnectError::InvalidAddress { + address: address_str.clone(), + error: Box::new(error), + } + }) + .tls_config(tls::config(cert_file).await?) + .map_err(InternalConnectError::TlsConfig)? + .connect() + .await + .map_err(|error| InternalConnectError::Connect { + address: address_str, + error, + })?; let macaroon = load_macaroon(macaroon_file).await?; - let interceptor = MacaroonInterceptor { macaroon, }; + let interceptor = MacaroonInterceptor { macaroon }; let client = Client { - lightning: lnrpc::lightning_client::LightningClient::with_interceptor(conn.clone(), interceptor.clone()), - wallet: walletrpc::wallet_kit_client::WalletKitClient::with_interceptor(conn.clone(), interceptor.clone()), + lightning: lnrpc::lightning_client::LightningClient::with_interceptor( + conn.clone(), + interceptor.clone(), + ), + wallet: walletrpc::wallet_kit_client::WalletKitClient::with_interceptor( + conn.clone(), + interceptor.clone(), + ), peers: peersrpc::peers_client::PeersClient::with_interceptor( conn.clone(), interceptor.clone(), ), - version: verrpc::versioner_client::VersionerClient::with_interceptor(conn.clone(), interceptor.clone()), + version: verrpc::versioner_client::VersionerClient::with_interceptor( + conn.clone(), + interceptor.clone(), + ), signer: signrpc::signer_client::SignerClient::with_interceptor(conn, interceptor), }; Ok(client) } mod tls { + use crate::error::{ConnectError, InternalConnectError}; + use rustls::{Certificate, RootCertStore, ServerCertVerified, TLSError}; use std::path::{Path, PathBuf}; - use rustls::{RootCertStore, Certificate, TLSError, ServerCertVerified}; use webpki::DNSNameRef; - use crate::error::{ConnectError, InternalConnectError}; - pub(crate) async fn config(path: impl AsRef + Into) -> Result { + pub(crate) async fn config( + path: impl AsRef + Into, + ) -> Result { let mut tls_config = rustls::ClientConfig::new(); - tls_config.dangerous().set_certificate_verifier(std::sync::Arc::new(CertVerifier::load(path).await?)); + tls_config + .dangerous() + .set_certificate_verifier(std::sync::Arc::new(CertVerifier::load(path).await?)); tls_config.set_protocols(&["h2".into()]); - Ok(tonic::transport::ClientTlsConfig::new() - .rustls_client_config(tls_config)) + Ok(tonic::transport::ClientTlsConfig::new().rustls_client_config(tls_config)) } pub(crate) struct CertVerifier { - certs: Vec> + certs: Vec>, } impl CertVerifier { - pub(crate) async fn load(path: impl AsRef + Into) -> Result { - let contents = try_map_err!(tokio::fs::read(&path).await, - |error| InternalConnectError::ReadFile { file: path.into(), error }); + pub(crate) async fn load( + path: impl AsRef + Into, + ) -> Result { + let contents = try_map_err!(tokio::fs::read(&path).await, |error| { + InternalConnectError::ReadFile { + file: path.into(), + error, + } + }); let mut reader = &*contents; - let certs = try_map_err!(rustls_pemfile::certs(&mut reader), - |error| InternalConnectError::ParseCert { file: path.into(), error }); + let certs = try_map_err!(rustls_pemfile::certs(&mut reader), |error| { + InternalConnectError::ParseCert { + file: path.into(), + error, + } + }); - #[cfg(feature = "tracing")] { + #[cfg(feature = "tracing")] + { tracing::debug!("Certificates loaded (Count: {})", certs.len()); } - Ok(CertVerifier { - certs: certs, - }) + Ok(CertVerifier { certs: certs }) } } impl rustls::ServerCertVerifier for CertVerifier { - fn verify_server_cert(&self, _roots: &RootCertStore, presented_certs: &[Certificate], _dns_name: DNSNameRef<'_>, _ocsp_response: &[u8]) -> Result { - + fn verify_server_cert( + &self, + _roots: &RootCertStore, + presented_certs: &[Certificate], + _dns_name: DNSNameRef<'_>, + _ocsp_response: &[u8], + ) -> Result { if self.certs.len() != presented_certs.len() { - return Err(TLSError::General(format!("Mismatched number of certificates (Expected: {}, Presented: {})", self.certs.len(), presented_certs.len()))); + return Err(TLSError::General(format!( + "Mismatched number of certificates (Expected: {}, Presented: {})", + self.certs.len(), + presented_certs.len() + ))); } - + for (c, p) in self.certs.iter().zip(presented_certs.iter()) { if *p.0 != **c { - return Err(TLSError::General(format!("Server certificates do not match ours"))); + return Err(TLSError::General(format!( + "Server certificates do not match ours" + ))); } else { - #[cfg(feature = "tracing")] { + #[cfg(feature = "tracing")] + { tracing::trace!("Confirmed certificate match"); } }