Skip to content

Commit

Permalink
Support MQTT 3.1.1 and store version on the Session
Browse files Browse the repository at this point in the history
  • Loading branch information
bschwind committed Jan 13, 2020
1 parent 25cafbd commit fc1900f
Show file tree
Hide file tree
Showing 6 changed files with 134 additions and 77 deletions.
23 changes: 15 additions & 8 deletions src/broker.rs
Original file line number Diff line number Diff line change
@@ -1,28 +1,34 @@
use crate::{
client::ClientMessage,
types::{
properties::AssignedClientIdentifier, ConnectAckPacket, ConnectReason, SubscribeAckPacket,
SubscribeAckReason, SubscribePacket, SubscriptionTopic,
properties::AssignedClientIdentifier, ConnectAckPacket, ConnectReason, ProtocolVersion,
SubscribeAckPacket, SubscribeAckReason, SubscribePacket, SubscriptionTopic,
},
};
use std::collections::{HashMap, HashSet};
use tokio::sync::mpsc::{self, Receiver, Sender};

pub struct Session {
pub protocol_version: ProtocolVersion,
pub subscriptions: HashSet<SubscriptionTopic>,
pub shared_subscriptions: HashSet<SubscriptionTopic>,
pub client_sender: Sender<ClientMessage>,
}

impl Session {
pub fn new(client_sender: Sender<ClientMessage>) -> Self {
Self { subscriptions: HashSet::new(), shared_subscriptions: HashSet::new(), client_sender }
pub fn new(protocol_version: ProtocolVersion, client_sender: Sender<ClientMessage>) -> Self {
Self {
protocol_version,
subscriptions: HashSet::new(),
shared_subscriptions: HashSet::new(),
client_sender,
}
}
}

#[derive(Debug)]
pub enum BrokerMessage {
NewClient(String, Sender<ClientMessage>),
NewClient(String, ProtocolVersion, Sender<ClientMessage>),
Publish,
Subscribe(String, SubscribePacket), // TODO - replace string client_id with int
Disconnect(String),
Expand All @@ -48,7 +54,7 @@ impl Broker {
pub async fn run(mut self) {
while let Some(msg) = self.receiver.recv().await {
match msg {
BrokerMessage::NewClient(client_id, mut client_msg_sender) => {
BrokerMessage::NewClient(client_id, protocol_version, mut client_msg_sender) => {
let mut session_present = false;

if let Some(mut session) = self.sessions.remove(&client_id) {
Expand All @@ -58,7 +64,7 @@ impl Broker {
let _ = session.client_sender.try_send(ClientMessage::Disconnect);
}

println!("Client ID {} connected", client_id);
println!("Client ID {} connected (Version: {:?})", client_id, protocol_version);

let connect_ack = ConnectAckPacket {
// Variable header
Expand Down Expand Up @@ -89,7 +95,8 @@ impl Broker {

let _ = client_msg_sender.try_send(ClientMessage::ConnectAck(connect_ack));

self.sessions.insert(client_id, Session::new(client_msg_sender));
self.sessions
.insert(client_id, Session::new(protocol_version, client_msg_sender));
},
BrokerMessage::Subscribe(client_id, packet) => {
if let Some(session) = self.sessions.get_mut(&client_id) {
Expand Down
18 changes: 14 additions & 4 deletions src/client.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::{
broker::BrokerMessage,
types::{ConnectAckPacket, Packet, ProtocolError, SubscribeAckPacket},
types::{ConnectAckPacket, Packet, ProtocolError, ProtocolVersion, SubscribeAckPacket},
MqttCodec,
};
use futures::{
Expand Down Expand Up @@ -43,14 +43,22 @@ impl<T: AsyncRead + AsyncWrite + Unpin> UnconnectedClient<T> {
connect_packet.client_id
};

let protocol_version = connect_packet.protocol_version;
let self_tx = sender.clone();

self.broker_tx
.send(BrokerMessage::NewClient(client_id.clone(), sender))
.send(BrokerMessage::NewClient(client_id.clone(), protocol_version, sender))
.await
.expect("Couldn't send NewClient message to broker");

Ok(Client::new(client_id, self.framed_stream, self.broker_tx, receiver, self_tx))
Ok(Client::new(
client_id,
protocol_version,
self.framed_stream,
self.broker_tx,
receiver,
self_tx,
))
},
Some(Ok(_)) => Err(ProtocolError::FirstPacketNotConnect),
Some(Err(e)) => Err(ProtocolError::MalformedPacket(e)),
Expand All @@ -71,6 +79,7 @@ pub enum ClientMessage {

pub struct Client<T: AsyncRead + AsyncWrite + Unpin> {
id: String,
protocol_version: ProtocolVersion,
framed_stream: Framed<T, MqttCodec>,
broker_tx: Sender<BrokerMessage>,
broker_rx: Receiver<ClientMessage>,
Expand All @@ -80,12 +89,13 @@ pub struct Client<T: AsyncRead + AsyncWrite + Unpin> {
impl<T: AsyncRead + AsyncWrite + Unpin> Client<T> {
pub fn new(
id: String,
protocol_version: ProtocolVersion,
framed_stream: Framed<T, MqttCodec>,
broker_tx: Sender<BrokerMessage>,
broker_rx: Receiver<ClientMessage>,
self_tx: Sender<ClientMessage>,
) -> Self {
Self { id, framed_stream, broker_tx, broker_rx, self_tx }
Self { id, protocol_version, framed_stream, broker_tx, broker_rx, self_tx }
}

async fn handle_socket_reads(
Expand Down
97 changes: 56 additions & 41 deletions src/decoder.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
use crate::types::{
properties::*, AuthenticatePacket, AuthenticateReason, ConnectAckPacket, ConnectPacket,
ConnectReason, DecodeError, DisconnectPacket, DisconnectReason, FinalWill, Packet, PacketType,
PublishAckPacket, PublishAckReason, PublishCompletePacket, PublishCompleteReason,
PublishPacket, PublishReceivedPacket, PublishReceivedReason, PublishReleasePacket,
PublishReleaseReason, QoS, RetainHandling, SubscribeAckPacket, SubscribeAckReason,
SubscribePacket, SubscriptionTopic, UnsubscribeAckPacket, UnsubscribeAckReason,
UnsubscribePacket, VariableByteInt,
ProtocolVersion, PublishAckPacket, PublishAckReason, PublishCompletePacket,
PublishCompleteReason, PublishPacket, PublishReceivedPacket, PublishReceivedReason,
PublishReleasePacket, PublishReleaseReason, QoS, RetainHandling, SubscribeAckPacket,
SubscribeAckReason, SubscribePacket, SubscriptionTopic, UnsubscribeAckPacket,
UnsubscribeAckReason, UnsubscribePacket, VariableByteInt,
};
use bytes::{Buf, BytesMut};
use std::{convert::TryFrom, io::Cursor};
Expand Down Expand Up @@ -325,6 +325,9 @@ fn decode_connect(bytes: &mut Cursor<&mut BytesMut>) -> Result<Option<Packet>, D
let connect_flags = read_u8!(bytes);
let keep_alive = read_u16!(bytes);

let protocol_version = ProtocolVersion::try_from(protocol_level)
.map_err(|_| DecodeError::InvalidProtocolVersion)?;

let mut session_expiry_interval = None;
let mut receive_maximum = None;
let mut maximum_packet_size = None;
Expand All @@ -335,20 +338,22 @@ fn decode_connect(bytes: &mut Cursor<&mut BytesMut>) -> Result<Option<Packet>, D
let mut authentication_method = None;
let mut authentication_data = None;

return_if_none!(decode_properties(bytes, |property| {
match property {
Property::SessionExpiryInterval(p) => session_expiry_interval = Some(p),
Property::ReceiveMaximum(p) => receive_maximum = Some(p),
Property::MaximumPacketSize(p) => maximum_packet_size = Some(p),
Property::TopicAliasMaximum(p) => topic_alias_maximum = Some(p),
Property::RequestResponseInformation(p) => request_response_information = Some(p),
Property::RequestProblemInformation(p) => request_problem_information = Some(p),
Property::UserProperty(p) => user_properties.push(p),
Property::AuthenticationMethod(p) => authentication_method = Some(p),
Property::AuthenticationData(p) => authentication_data = Some(p),
_ => {}, // Invalid property for packet
}
})?);
if protocol_version == ProtocolVersion::V500 {
return_if_none!(decode_properties(bytes, |property| {
match property {
Property::SessionExpiryInterval(p) => session_expiry_interval = Some(p),
Property::ReceiveMaximum(p) => receive_maximum = Some(p),
Property::MaximumPacketSize(p) => maximum_packet_size = Some(p),
Property::TopicAliasMaximum(p) => topic_alias_maximum = Some(p),
Property::RequestResponseInformation(p) => request_response_information = Some(p),
Property::RequestProblemInformation(p) => request_problem_information = Some(p),
Property::UserProperty(p) => user_properties.push(p),
Property::AuthenticationMethod(p) => authentication_method = Some(p),
Property::AuthenticationData(p) => authentication_data = Some(p),
_ => {}, // Invalid property for packet
}
})?);
}

// Start payload
let clean_start = connect_flags & 0b0000_0010 == 0b0000_0010;
Expand All @@ -370,18 +375,20 @@ fn decode_connect(bytes: &mut Cursor<&mut BytesMut>) -> Result<Option<Packet>, D
let mut correlation_data = None;
let mut user_properties = vec![];

return_if_none!(decode_properties(bytes, |property| {
match property {
Property::WillDelayInterval(p) => will_delay_interval = Some(p),
Property::PayloadFormatIndicator(p) => payload_format_indicator = Some(p),
Property::MessageExpiryInterval(p) => message_expiry_interval = Some(p),
Property::ContentType(p) => content_type = Some(p),
Property::ResponseTopic(p) => response_topic = Some(p),
Property::CorrelationData(p) => correlation_data = Some(p),
Property::UserProperty(p) => user_properties.push(p),
_ => {}, // Invalid property for packet
}
})?);
if protocol_version == ProtocolVersion::V500 {
return_if_none!(decode_properties(bytes, |property| {
match property {
Property::WillDelayInterval(p) => will_delay_interval = Some(p),
Property::PayloadFormatIndicator(p) => payload_format_indicator = Some(p),
Property::MessageExpiryInterval(p) => message_expiry_interval = Some(p),
Property::ContentType(p) => content_type = Some(p),
Property::ResponseTopic(p) => response_topic = Some(p),
Property::CorrelationData(p) => correlation_data = Some(p),
Property::UserProperty(p) => user_properties.push(p),
_ => {}, // Invalid property for packet
}
})?);
}

let topic = read_string!(bytes);
let payload = read_binary_data!(bytes);
Expand Down Expand Up @@ -416,7 +423,7 @@ fn decode_connect(bytes: &mut Cursor<&mut BytesMut>) -> Result<Option<Packet>, D

let packet = ConnectPacket {
protocol_name,
protocol_level,
protocol_version,
clean_start,
keep_alive,
session_expiry_interval,
Expand Down Expand Up @@ -736,6 +743,7 @@ fn decode_publish_complete(
fn decode_subscribe(
bytes: &mut Cursor<&mut BytesMut>,
remaining_packet_length: u32,
protocol_version: &ProtocolVersion,
) -> Result<Option<Packet>, DecodeError> {
let start_cursor_pos = bytes.position();

Expand All @@ -744,13 +752,15 @@ fn decode_subscribe(
let mut subscription_identifier = None;
let mut user_properties = vec![];

return_if_none!(decode_properties(bytes, |property| {
match property {
Property::SubscriptionIdentifier(p) => subscription_identifier = Some(p),
Property::UserProperty(p) => user_properties.push(p),
_ => {}, // Invalid property for packet
}
})?);
if *protocol_version == ProtocolVersion::V500 {
return_if_none!(decode_properties(bytes, |property| {
match property {
Property::SubscriptionIdentifier(p) => subscription_identifier = Some(p),
Property::UserProperty(p) => user_properties.push(p),
_ => {}, // Invalid property for packet
}
})?);
}

let variable_header_size = bytes.position() - start_cursor_pos;
let payload_size = remaining_packet_length as u64 - variable_header_size;
Expand Down Expand Up @@ -1006,6 +1016,7 @@ fn decode_authenticate(
}

fn decode_packet(
protocol_version: &ProtocolVersion,
packet_type: &PacketType,
bytes: &mut Cursor<&mut BytesMut>,
remaining_packet_length: u32,
Expand All @@ -1019,7 +1030,7 @@ fn decode_packet(
PacketType::PublishReceived => decode_publish_received(bytes, remaining_packet_length),
PacketType::PublishRelease => decode_publish_release(bytes, remaining_packet_length),
PacketType::PublishComplete => decode_publish_complete(bytes, remaining_packet_length),
PacketType::Subscribe => decode_subscribe(bytes, remaining_packet_length),
PacketType::Subscribe => decode_subscribe(bytes, remaining_packet_length, protocol_version),
PacketType::SubscribeAck => decode_subscribe_ack(bytes, remaining_packet_length),
PacketType::Unsubscribe => decode_unsubscribe(bytes, remaining_packet_length),
PacketType::UnsubscribeAck => decode_unsubscribe_ack(bytes, remaining_packet_length),
Expand All @@ -1030,7 +1041,10 @@ fn decode_packet(
}
}

pub fn decode_mqtt(bytes: &mut BytesMut) -> Result<Option<Packet>, DecodeError> {
pub fn decode_mqtt(
bytes: &mut BytesMut,
protocol_version: &ProtocolVersion,
) -> Result<Option<Packet>, DecodeError> {
let mut bytes = Cursor::new(bytes);
let first_byte = read_u8!(bytes);

Expand All @@ -1048,6 +1062,7 @@ pub fn decode_mqtt(bytes: &mut BytesMut) -> Result<Option<Packet>, DecodeError>
}

let packet = return_if_none!(decode_packet(
protocol_version,
&packet_type,
&mut bytes,
remaining_packet_length,
Expand Down
Loading

0 comments on commit fc1900f

Please sign in to comment.