From be0bc1ec08ef4784a355f250dc772f0da1d8a79f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Ant=C3=B4nio=20Cardoso?= Date: Wed, 10 Jan 2024 16:37:50 -0300 Subject: [PATCH] src: connection: Fix panic when DNS lookup fails --- src/connection/mod.rs | 16 ++++++++++++++++ src/connection/tcp.rs | 15 +++++---------- src/connection/udp.rs | 22 ++++++---------------- 3 files changed, 27 insertions(+), 26 deletions(-) diff --git a/src/connection/mod.rs b/src/connection/mod.rs index 10dacbe976..7aa4ab78ec 100644 --- a/src/connection/mod.rs +++ b/src/connection/mod.rs @@ -102,3 +102,19 @@ pub fn connect(address: &str) -> io::Result protocol_err } } + +/// Returns the socket address for the given address. +pub(crate) fn get_socket_addr( + address: T, +) -> Result { + let addr = match address.to_socket_addrs()?.next() { + Some(addr) => addr, + None => { + return Err(io::Error::new( + io::ErrorKind::Other, + "Host address lookup failed", + )); + } + }; + Ok(addr) +} diff --git a/src/connection/tcp.rs b/src/connection/tcp.rs index 491dff0129..b34eb5f091 100644 --- a/src/connection/tcp.rs +++ b/src/connection/tcp.rs @@ -6,6 +6,8 @@ use std::net::{TcpListener, TcpStream}; use std::sync::Mutex; use std::time::Duration; +use super::get_socket_addr; + /// TCP MAVLink connection pub fn select_protocol( @@ -26,11 +28,8 @@ pub fn select_protocol( } pub fn tcpout(address: T) -> io::Result { - let addr = address - .to_socket_addrs() - .unwrap() - .next() - .expect("Host address lookup failed."); + let addr = get_socket_addr(address)?; + let socket = TcpStream::connect(addr)?; socket.set_read_timeout(Some(Duration::from_millis(100)))?; @@ -45,11 +44,7 @@ pub fn tcpout(address: T) -> io::Result { } pub fn tcpin(address: T) -> io::Result { - let addr = address - .to_socket_addrs() - .unwrap() - .next() - .expect("Invalid address"); + let addr = get_socket_addr(address)?; let listener = TcpListener::bind(addr)?; //For now we only accept one incoming stream: this blocks until we get one diff --git a/src/connection/udp.rs b/src/connection/udp.rs index 7357ca99c3..1522ed20df 100644 --- a/src/connection/udp.rs +++ b/src/connection/udp.rs @@ -6,6 +6,8 @@ use std::net::ToSocketAddrs; use std::net::{SocketAddr, UdpSocket}; use std::sync::Mutex; +use super::get_socket_addr; + /// UDP MAVLink connection pub fn select_protocol( @@ -28,12 +30,8 @@ pub fn select_protocol( } pub fn udpbcast(address: T) -> io::Result { - let addr = address - .to_socket_addrs() - .unwrap() - .next() - .expect("Invalid address"); - let socket = UdpSocket::bind("0.0.0.0:0").unwrap(); + let addr = get_socket_addr(address)?; + let socket = UdpSocket::bind("0.0.0.0:0")?; socket .set_broadcast(true) .expect("Couldn't bind to broadcast address."); @@ -41,21 +39,13 @@ pub fn udpbcast(address: T) -> io::Result { } pub fn udpout(address: T) -> io::Result { - let addr = address - .to_socket_addrs() - .unwrap() - .next() - .expect("Invalid address"); + let addr = get_socket_addr(address)?; let socket = UdpSocket::bind("0.0.0.0:0")?; UdpConnection::new(socket, false, Some(addr)) } pub fn udpin(address: T) -> io::Result { - let addr = address - .to_socket_addrs() - .unwrap() - .next() - .expect("Invalid address"); + let addr = get_socket_addr(address)?; let socket = UdpSocket::bind(addr)?; UdpConnection::new(socket, true, None) }