Skip to content

Commit

Permalink
add unsubscribe
Browse files Browse the repository at this point in the history
  • Loading branch information
fredszaq committed Aug 7, 2018
1 parent cfc0b80 commit 47b190c
Show file tree
Hide file tree
Showing 4 changed files with 164 additions and 40 deletions.
34 changes: 20 additions & 14 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,24 @@ use mqtt3::{QoS, ToTopicPath, TopicPath};

use mio_more::channel::*;

use std::sync::atomic::{AtomicUsize, Ordering};

use MqttOptions;

#[allow(unused)]
#[derive(DebugStub)]
pub enum Command {
Status(#[debug_stub = ""] ::std::sync::mpsc::Sender<::state::MqttConnectionStatus>),
Subscribe(Subscription),
Unsubscribe(SubscriptionToken),
Publish(Publish),
Connect,
Disconnect,
}

pub struct MqttClient {
nw_request_tx: SyncSender<Command>,
subscription_id_source: AtomicUsize
}

impl MqttClient {
Expand Down Expand Up @@ -73,6 +77,7 @@ impl MqttClient {

Ok(MqttClient {
nw_request_tx: commands_tx,
subscription_id_source: AtomicUsize::new(0),
})
}

Expand All @@ -84,14 +89,18 @@ impl MqttClient {
Ok(SubscriptionBuilder {
client: self,
it: Subscription {
id: None,
id: self.subscription_id_source.fetch_add(1, Ordering::Relaxed),
topic_path: topic_path.to_topic_path()?,
qos: ::mqtt3::QoS::AtMostOnce,
callback,
},
})
}

pub fn unsubscribe(&self, token: SubscriptionToken) -> Result<()> {
self.send_command(Command::Unsubscribe(token))
}

pub fn publish<T: ToTopicPath>(&self, topic_path: T) -> Result<PublishBuilder> {
Ok(PublishBuilder {
client: self,
Expand Down Expand Up @@ -126,7 +135,7 @@ pub type SubscriptionCallback = Box<Fn(&::mqtt3::Publish) + Send>;

#[derive(DebugStub)]
pub struct Subscription {
pub id: Option<String>,
pub id: usize,
pub topic_path: TopicPath,
pub qos: ::mqtt3::QoS,
#[debug_stub = ""] pub callback: SubscriptionCallback,
Expand All @@ -139,28 +148,25 @@ pub struct SubscriptionBuilder<'a> {
}

impl<'a> SubscriptionBuilder<'a> {
pub fn id<S: ToString>(self, s: S) -> SubscriptionBuilder<'a> {
let SubscriptionBuilder { client, it } = self;
SubscriptionBuilder {
client,
it: Subscription {
id: Some(s.to_string()),
..it
},
}
}
pub fn qos(self, qos: QoS) -> SubscriptionBuilder<'a> {
let SubscriptionBuilder { client, it } = self;
SubscriptionBuilder {
client,
it: Subscription { qos, ..it },
}
}
pub fn send(self) -> Result<()> {
self.client.send_command(Command::Subscribe(self.it))
pub fn send(self) -> Result<SubscriptionToken> {
let token = SubscriptionToken { id: self.it.id};
self.client.send_command(Command::Subscribe(self.it))?;
Ok(token)
}
}

#[derive(Debug)]
pub struct SubscriptionToken {
pub id: usize
}

#[derive(Debug)]
pub struct Publish {
pub topic: TopicPath,
Expand Down
6 changes: 6 additions & 0 deletions src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,7 @@ impl ConnectionState {
self.turn_command()?;
}
mqtt3::Packet::Suback(suback) => self.mqtt_state.handle_incoming_suback(suback)?,
mqtt3::Packet::Unsuback(packet_identifier) => self.mqtt_state.handle_incoming_unsuback(packet_identifier)?,
mqtt3::Packet::Publish(publish) => {
let (_, server) = self.mqtt_state.handle_incoming_publish(publish)?;
if let Some(server) = server {
Expand Down Expand Up @@ -437,6 +438,11 @@ impl ConnectionState {
let packet = self.mqtt_state.handle_outgoing_subscribe(vec![sub])?;
self.send_packet(mqtt3::Packet::Subscribe(packet))?
}
Command::Unsubscribe(token) => {
if let Some(packet) = self.mqtt_state.handle_outgoing_unsubscribe(vec![token.id])? {
self.send_packet(mqtt3::Packet::Unsubscribe(packet))?
}
}
Command::Status(tx) => {
let _ = tx.send(self.state().status());
}
Expand Down
91 changes: 65 additions & 26 deletions src/state.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
use std::time::{Duration, Instant};
use std::collections::VecDeque;

use error::*;
use mqtt3;
use MqttOptions;
use error::*;
use std::collections::{HashMap, VecDeque};
use std::time::{Duration, Instant};

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MqttConnectionStatus {
Expand Down Expand Up @@ -41,7 +40,8 @@ pub struct MqttState {
// Even so, if broker crashes, all its state will be lost (most brokers).
// client should resubscribe it comes back up again or else the data will
// be lost
subscriptions: Vec<::client::Subscription>,
subscriptions: HashMap<usize, ::client::Subscription>,
path_usage: HashMap<String, usize>,
}

/// Design: `MqttState` methods will just modify the state of the object
Expand All @@ -60,7 +60,8 @@ impl MqttState {
last_flush: Instant::now(),
last_pkid: mqtt3::PacketIdentifier(0),
outgoing_pub: VecDeque::new(),
subscriptions: Vec::new(),
subscriptions: HashMap::new(),
path_usage: HashMap::new(),
}
}

Expand Down Expand Up @@ -97,13 +98,13 @@ impl MqttState {
use self::MqttConnectionStatus::*;
use ReconnectOptions::*;
match (self.connection_status, self.opts.reconnect) {
(Handshake { initial: true }, Always(d))
| (Handshake {..}, AfterFirstSuccess(d))
(Handshake { initial: true }, Always(d))
| (Handshake { .. }, AfterFirstSuccess(d))
| (Connected, AfterFirstSuccess(d))
| (Connected, Always(d))
| (WantConnect { .. }, AfterFirstSuccess(d))
| (WantConnect { .. }, Always(d))
=> self.connection_status = WantConnect { when: Instant::now()+d },
=> self.connection_status = WantConnect { when: Instant::now() + d },
_ => self.connection_status = Disconnected
}
}
Expand Down Expand Up @@ -131,13 +132,14 @@ impl MqttState {
} else {
let sub = if self.subscriptions.len() > 0 {
Some(mqtt3::Subscribe {
pid: self.next_pkid(),
topics: self.subscriptions.iter().map(|s| {
pid: self.next_pkid(),
topics: self.subscriptions.iter().map(|(_id, s)| {
::mqtt3::SubscribeTopic {
topic_path: s.topic_path.path.clone(),
qos: s.qos,
}}).collect()
})
}
}).collect(),
})
} else {
None
};
Expand Down Expand Up @@ -202,7 +204,7 @@ impl MqttState {
let qos = publish.qos;

let concrete = ::mqtt3::TopicPath::from_str(&publish.topic_name)?;
for sub in &self.subscriptions {
for (_id, sub) in &self.subscriptions {
if sub.topic_path.is_match(&concrete) {
(sub.callback)(&publish);
}
Expand Down Expand Up @@ -289,7 +291,12 @@ impl MqttState {
}
})
.collect();
self.subscriptions.extend(subs);
for s in &subs {
*self.path_usage.entry(s.topic_path.path.clone()).or_insert(0) += 1;
}
self.subscriptions.extend(subs.into_iter().map(|it| {
(it.id, it)
}));

if self.connection_status == MqttConnectionStatus::Connected {
Ok(mqtt3::Subscribe { pid: pkid, topics })
Expand All @@ -302,14 +309,47 @@ impl MqttState {
}
}

pub fn handle_outgoing_unsubscribe(
&mut self,
ids: Vec<usize>,
) -> Result<Option<mqtt3::Unsubscribe>> {
let mut topics = vec![];
for id in ids {
if let Some(sub) = self.subscriptions.remove(&id) {
// we unwrap here because if the value is not there, there is an error in this code
let mut path_count = self.path_usage.get_mut(&sub.topic_path.path).unwrap();
*path_count -= 1;
if *path_count == 0 { topics.push(sub.topic_path.path) }
}
}
if !topics.is_empty() {
let pkid = self.next_pkid();

if self.connection_status == MqttConnectionStatus::Connected {
Ok(Some(mqtt3::Unsubscribe { pid: pkid, topics }))
} else {
error!(
"State = {:?}. Shouldn't unsubscribe in this state",
self.connection_status
);
Err(ErrorKind::InvalidState.into())
}
} else {
Ok(None)
}
}

pub fn handle_incoming_suback(&mut self, ack: mqtt3::Suback) -> Result<()> {
if ack.return_codes
.iter()
.any(|v| *v == ::mqtt3::SubscribeReturnCodes::Failure)
{
Err(format!("rejected subscription"))?
};
{
Err(format!("rejected subscription"))?
};
Ok(())
}

pub fn handle_incoming_unsuback(&mut self, ack: mqtt3::PacketIdentifier) -> Result<()> {
Ok(())
}

Expand Down Expand Up @@ -340,14 +380,13 @@ impl MqttState {

#[cfg(test)]
mod test {
use error::*;
use mqtt3::*;
use options::MqttOptions;
use std::sync::Arc;
use std::thread;
use std::time::Duration;

use super::{MqttConnectionStatus, MqttState};
use mqtt3::*;
use options::MqttOptions;
use error::*;

#[test]
fn next_pkid_roll() {
Expand Down Expand Up @@ -546,7 +585,7 @@ mod test {
mqtt.handle_socket_disconnect();
assert_eq!(mqtt.outgoing_pub.len(), 0);
match mqtt.connection_status {
MqttConnectionStatus::WantConnect { .. } => {},
MqttConnectionStatus::WantConnect { .. } => {}
_ => panic!()
}
assert_eq!(mqtt.await_pingresp, false);
Expand Down Expand Up @@ -574,7 +613,7 @@ mod test {
mqtt.handle_socket_disconnect();
assert_eq!(mqtt.outgoing_pub.len(), 3);
match mqtt.connection_status {
MqttConnectionStatus::WantConnect { .. } => {},
MqttConnectionStatus::WantConnect { .. } => {}
_ => panic!()
}
assert_eq!(mqtt.await_pingresp, false);
Expand All @@ -585,7 +624,7 @@ mod test {
let mut mqtt = MqttState::new(MqttOptions::new("test-id", "127.0.0.1:1883"));

match mqtt.connection_status {
MqttConnectionStatus::WantConnect { .. } => {},
MqttConnectionStatus::WantConnect { .. } => {}
_ => panic!()
}
mqtt.handle_outgoing_connect(true);
Expand All @@ -609,7 +648,7 @@ mod test {

assert!(mqtt.handle_incoming_connack(connack).is_err());
match mqtt.connection_status {
MqttConnectionStatus::WantConnect { .. } => {},
MqttConnectionStatus::WantConnect { .. } => {}
_ => panic!()
}
}
Expand Down
73 changes: 73 additions & 0 deletions tests/testsuite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,79 @@ fn basic_publishes_and_subscribes() {
assert_eq!(3, final_count.load(Ordering::SeqCst));
}

#[test]
fn publishes_and_subscribes_and_unsubscribes() {
// loggerv::init_with_level(log::LogLevel::Debug);
let client_options = MqttOptions::new("pubsubunsub", MOSQUITTO_ADDR);
let count = Arc::new(AtomicUsize::new(0));
let final_count = count.clone();
let count = count.clone();

let count2 = Arc::new(AtomicUsize::new(0));
let final_count2 = count2.clone();
let count2 = count2.clone();

let request = MqttClient::start(client_options).expect("Coudn't start");
let token = request
.subscribe(
"test/pubsubunsub",
Box::new(move |_| {
count.fetch_add(1, Ordering::SeqCst);
}),
)
.unwrap()
.send()
.unwrap();

let token2 = request
.subscribe(
"test/pubsubunsub",
Box::new(move |_| {
count2.fetch_add(1, Ordering::SeqCst);
}),
)
.unwrap()
.send()
.unwrap();

let payload = format!("hello rust");
request
.publish("test/pubsubunsub")
.unwrap()
.payload(payload.clone().into_bytes())
.send()
.unwrap();

thread::sleep(Duration::from_secs(1));
request.unsubscribe(token).unwrap();
thread::sleep(Duration::from_secs(1));

request
.publish("test/pubsubunsub")
.unwrap()
.payload(payload.clone().into_bytes())
.send()
.unwrap();

thread::sleep(Duration::from_secs(1));

request.unsubscribe(token2).unwrap();
thread::sleep(Duration::from_secs(1));

request
.publish("test/pubsubunsub")
.unwrap()
.payload(payload.clone().into_bytes())
.send()
.unwrap();

thread::sleep(Duration::from_secs(1));

assert_eq!(1, final_count.load(Ordering::SeqCst));
assert_eq!(2, final_count2.load(Ordering::SeqCst));
}


#[test]
fn alive() {
// loggerv::init_with_level(log::LogLevel::Debug);
Expand Down

0 comments on commit 47b190c

Please sign in to comment.