diff --git a/README.md b/README.md index 298134e..26f2ca5 100644 --- a/README.md +++ b/README.md @@ -20,3 +20,9 @@ Testing ------- $ cargo test + +Code Format +----------- +The formatting options currently use nightly-only options. + + $ cargo +nightly fmt diff --git a/src/decoder.rs b/src/decoder.rs index 45d2d40..dc0ad1d 100644 --- a/src/decoder.rs +++ b/src/decoder.rs @@ -743,7 +743,7 @@ fn decode_publish_complete( fn decode_subscribe( bytes: &mut Cursor<&mut BytesMut>, remaining_packet_length: u32, - protocol_version: &ProtocolVersion, + protocol_version: ProtocolVersion, ) -> Result, DecodeError> { let start_cursor_pos = bytes.position(); @@ -752,7 +752,7 @@ fn decode_subscribe( let mut subscription_identifier = None; let mut user_properties = vec![]; - if *protocol_version == ProtocolVersion::V500 { + if protocol_version == ProtocolVersion::V500 { return_if_none!(decode_properties(bytes, |property| { match property { Property::SubscriptionIdentifier(p) => subscription_identifier = Some(p), @@ -1016,7 +1016,7 @@ fn decode_authenticate( } fn decode_packet( - protocol_version: &ProtocolVersion, + protocol_version: ProtocolVersion, packet_type: &PacketType, bytes: &mut Cursor<&mut BytesMut>, remaining_packet_length: u32, @@ -1043,7 +1043,7 @@ fn decode_packet( pub fn decode_mqtt( bytes: &mut BytesMut, - protocol_version: &ProtocolVersion, + protocol_version: ProtocolVersion, ) -> Result, DecodeError> { let mut bytes = Cursor::new(bytes); let first_byte = read_u8!(bytes); diff --git a/src/encoder.rs b/src/encoder.rs index 739db37..5215780 100644 --- a/src/encoder.rs +++ b/src/encoder.rs @@ -530,7 +530,7 @@ mod tests { let mut bytes = BytesMut::new(); encode_mqtt(&packet, &mut bytes); - let decoded = decode_mqtt(&mut bytes, &ProtocolVersion::V500).unwrap().unwrap(); + let decoded = decode_mqtt(&mut bytes, ProtocolVersion::V500).unwrap().unwrap(); assert_eq!(packet, decoded); } @@ -562,7 +562,7 @@ mod tests { let mut bytes = BytesMut::new(); encode_mqtt(&packet, &mut bytes); - let decoded = decode_mqtt(&mut bytes, &ProtocolVersion::V500).unwrap().unwrap(); + let decoded = decode_mqtt(&mut bytes, ProtocolVersion::V500).unwrap().unwrap(); assert_eq!(packet, decoded); } @@ -591,7 +591,7 @@ mod tests { let mut bytes = BytesMut::new(); encode_mqtt(&packet, &mut bytes); - let decoded = decode_mqtt(&mut bytes, &ProtocolVersion::V500).unwrap().unwrap(); + let decoded = decode_mqtt(&mut bytes, ProtocolVersion::V500).unwrap().unwrap(); assert_eq!(packet, decoded); } @@ -608,7 +608,7 @@ mod tests { let mut bytes = BytesMut::new(); encode_mqtt(&packet, &mut bytes); - let decoded = decode_mqtt(&mut bytes, &ProtocolVersion::V500).unwrap().unwrap(); + let decoded = decode_mqtt(&mut bytes, ProtocolVersion::V500).unwrap().unwrap(); assert_eq!(packet, decoded); } @@ -625,7 +625,7 @@ mod tests { let mut bytes = BytesMut::new(); encode_mqtt(&packet, &mut bytes); - let decoded = decode_mqtt(&mut bytes, &ProtocolVersion::V500).unwrap().unwrap(); + let decoded = decode_mqtt(&mut bytes, ProtocolVersion::V500).unwrap().unwrap(); assert_eq!(packet, decoded); } @@ -642,7 +642,7 @@ mod tests { let mut bytes = BytesMut::new(); encode_mqtt(&packet, &mut bytes); - let decoded = decode_mqtt(&mut bytes, &ProtocolVersion::V500).unwrap().unwrap(); + let decoded = decode_mqtt(&mut bytes, ProtocolVersion::V500).unwrap().unwrap(); assert_eq!(packet, decoded); } @@ -659,7 +659,7 @@ mod tests { let mut bytes = BytesMut::new(); encode_mqtt(&packet, &mut bytes); - let decoded = decode_mqtt(&mut bytes, &ProtocolVersion::V500).unwrap().unwrap(); + let decoded = decode_mqtt(&mut bytes, ProtocolVersion::V500).unwrap().unwrap(); assert_eq!(packet, decoded); } @@ -683,7 +683,7 @@ mod tests { let mut bytes = BytesMut::new(); encode_mqtt(&packet, &mut bytes); - let decoded = decode_mqtt(&mut bytes, &ProtocolVersion::V500).unwrap().unwrap(); + let decoded = decode_mqtt(&mut bytes, ProtocolVersion::V500).unwrap().unwrap(); assert_eq!(packet, decoded); } @@ -701,7 +701,7 @@ mod tests { let mut bytes = BytesMut::new(); encode_mqtt(&packet, &mut bytes); - let decoded = decode_mqtt(&mut bytes, &ProtocolVersion::V500).unwrap().unwrap(); + let decoded = decode_mqtt(&mut bytes, ProtocolVersion::V500).unwrap().unwrap(); assert_eq!(packet, decoded); } @@ -718,7 +718,7 @@ mod tests { let mut bytes = BytesMut::new(); encode_mqtt(&packet, &mut bytes); - let decoded = decode_mqtt(&mut bytes, &ProtocolVersion::V500).unwrap().unwrap(); + let decoded = decode_mqtt(&mut bytes, ProtocolVersion::V500).unwrap().unwrap(); assert_eq!(packet, decoded); } @@ -736,7 +736,7 @@ mod tests { let mut bytes = BytesMut::new(); encode_mqtt(&packet, &mut bytes); - let decoded = decode_mqtt(&mut bytes, &ProtocolVersion::V500).unwrap().unwrap(); + let decoded = decode_mqtt(&mut bytes, ProtocolVersion::V500).unwrap().unwrap(); assert_eq!(packet, decoded); } @@ -746,7 +746,7 @@ mod tests { let packet = Packet::PingRequest; let mut bytes = BytesMut::new(); encode_mqtt(&packet, &mut bytes); - let decoded = decode_mqtt(&mut bytes, &ProtocolVersion::V500).unwrap().unwrap(); + let decoded = decode_mqtt(&mut bytes, ProtocolVersion::V500).unwrap().unwrap(); assert_eq!(packet, decoded); } @@ -756,7 +756,7 @@ mod tests { let packet = Packet::PingResponse; let mut bytes = BytesMut::new(); encode_mqtt(&packet, &mut bytes); - let decoded = decode_mqtt(&mut bytes, &ProtocolVersion::V500).unwrap().unwrap(); + let decoded = decode_mqtt(&mut bytes, ProtocolVersion::V500).unwrap().unwrap(); assert_eq!(packet, decoded); } @@ -773,7 +773,7 @@ mod tests { }); let mut bytes = BytesMut::new(); encode_mqtt(&packet, &mut bytes); - let decoded = decode_mqtt(&mut bytes, &ProtocolVersion::V500).unwrap().unwrap(); + let decoded = decode_mqtt(&mut bytes, ProtocolVersion::V500).unwrap().unwrap(); assert_eq!(packet, decoded); } @@ -790,7 +790,7 @@ mod tests { }); let mut bytes = BytesMut::new(); encode_mqtt(&packet, &mut bytes); - let decoded = decode_mqtt(&mut bytes, &ProtocolVersion::V500).unwrap().unwrap(); + let decoded = decode_mqtt(&mut bytes, ProtocolVersion::V500).unwrap().unwrap(); assert_eq!(packet, decoded); } diff --git a/src/main.rs b/src/main.rs index 7470400..23518ed 100644 --- a/src/main.rs +++ b/src/main.rs @@ -15,6 +15,7 @@ mod broker; mod client; mod decoder; mod encoder; +mod topic; mod types; pub struct MqttCodec { @@ -28,7 +29,7 @@ impl MqttCodec { pub fn decode(&mut self, buf: &mut BytesMut) -> Result, DecodeError> { // TODO - Ideally we should keep a state machine to store the data we've read so far. - let packet = decoder::decode_mqtt(buf, &self.version); + let packet = decoder::decode_mqtt(buf, self.version); if let Ok(Some(Packet::Connect(packet))) = &packet { self.version = packet.protocol_version; diff --git a/src/topic/filter.rs b/src/topic/filter.rs new file mode 100644 index 0000000..185769c --- /dev/null +++ b/src/topic/filter.rs @@ -0,0 +1,590 @@ +use crate::topic::{ + MAX_TOPIC_LEN_BYTES, MULTI_LEVEL_WILDCARD, MULTI_LEVEL_WILDCARD_STR, + SHARED_SUBSCRIPTION_PREFIX, SINGLE_LEVEL_WILDCARD, SINGLE_LEVEL_WILDCARD_STR, TOPIC_SEPARATOR, +}; +use std::str::FromStr; + +/// A filter for subscribers to indicate which topics they want +/// to receive messages from. Can contain wildcards. +#[derive(Debug, PartialEq)] +pub enum TopicFilter { + Concrete { filter: String, level_count: u32 }, + Wildcard { filter: String, level_count: u32 }, + SharedConcrete { group_name: String, filter: String, level_count: u32 }, + SharedWildcard { group_name: String, filter: String, level_count: u32 }, +} + +/// A topic name publishers use when sending MQTT messages. +/// Cannot contain wildcards. +#[derive(Debug, PartialEq)] +pub struct Topic { + topic_name: String, + level_count: u32, +} + +#[derive(Debug, PartialEq)] +pub enum TopicLevel<'a> { + Concrete(&'a str), + SingleLevelWildcard, + MultiLevelWildcard, +} + +#[derive(Debug, PartialEq)] +pub enum TopicParseError { + EmptyTopic, + TopicTooLong, + MultilevelWildcardNotAtEnd, + InvalidWildcardLevel, + InvalidSharedGroupName, + EmptySharedGroupName, + WildcardOrNullInTopic, +} + +/// If Ok, returns (level_count, contains_wildcards). +fn process_filter(filter: &str) -> Result<(u32, bool), TopicParseError> { + let mut level_count = 0; + let mut contains_wildcards = false; + for level in filter.split(TOPIC_SEPARATOR) { + let level_contains_wildcard = + level.contains(|x: char| x == SINGLE_LEVEL_WILDCARD || x == MULTI_LEVEL_WILDCARD); + if level_contains_wildcard { + // Any wildcards on a particular level must be specified on their own + if level.len() > 1 { + return Err(TopicParseError::InvalidWildcardLevel); + } + + contains_wildcards = true; + } + + level_count += 1; + } + + Ok((level_count, contains_wildcards)) +} + +impl FromStr for TopicFilter { + type Err = TopicParseError; + + fn from_str(filter: &str) -> Result { + // Filters and topics cannot be empty + if filter.is_empty() { + return Err(TopicParseError::EmptyTopic); + } + + // TODO - assert no null character U+0000 + if filter.contains('\0') { + return Err(TopicParseError::WildcardOrNullInTopic); + } + + // Filters cannot exceed the byte length in the MQTT spec + if filter.len() > MAX_TOPIC_LEN_BYTES { + return Err(TopicParseError::TopicTooLong); + } + + // Multi-level wildcards can only be at the end of the topic + if let Some(pos) = filter.rfind(MULTI_LEVEL_WILDCARD) { + if pos != filter.len() - 1 { + return Err(TopicParseError::MultilevelWildcardNotAtEnd); + } + } + + let mut shared_group = None; + + if filter.starts_with(SHARED_SUBSCRIPTION_PREFIX) { + let filter_rest = &filter[SHARED_SUBSCRIPTION_PREFIX.len()..]; + + if filter_rest.is_empty() { + return Err(TopicParseError::EmptySharedGroupName); + } + + if let Some(slash_pos) = filter_rest.find(TOPIC_SEPARATOR) { + let shared_name = &filter_rest[0..slash_pos]; + + // slash_pos+1 is safe here, we've already validated the string + // has a nonzero length. + let shared_filter = &filter_rest[(slash_pos + 1)..]; + + if shared_name.is_empty() { + return Err(TopicParseError::EmptySharedGroupName); + } + + if shared_name + .contains(|x: char| x == SINGLE_LEVEL_WILDCARD || x == MULTI_LEVEL_WILDCARD) + { + return Err(TopicParseError::InvalidSharedGroupName); + } + + if shared_filter.is_empty() { + return Err(TopicParseError::EmptyTopic); + } + + shared_group = Some((shared_name, shared_filter)) + } else { + return Err(TopicParseError::EmptyTopic); + } + } + + let topic_filter = if let Some((group_name, shared_filter)) = shared_group { + let (level_count, contains_wildcards) = process_filter(shared_filter)?; + + if contains_wildcards { + TopicFilter::SharedWildcard { + group_name: group_name.to_string(), + filter: shared_filter.to_string(), + level_count, + } + } else { + TopicFilter::SharedConcrete { + group_name: group_name.to_string(), + filter: shared_filter.to_string(), + level_count, + } + } + } else { + let (level_count, contains_wildcards) = process_filter(filter)?; + + if contains_wildcards { + TopicFilter::Wildcard { filter: filter.to_string(), level_count } + } else { + TopicFilter::Concrete { filter: filter.to_string(), level_count } + } + }; + + Ok(topic_filter) + } +} + +impl FromStr for Topic { + type Err = TopicParseError; + + fn from_str(topic: &str) -> Result { + // TODO - Consider disallowing leading $ characters + + // Topics cannot be empty + if topic.is_empty() { + return Err(TopicParseError::EmptyTopic); + } + + // Topics cannot exceed the byte length in the MQTT spec + if topic.len() > MAX_TOPIC_LEN_BYTES { + return Err(TopicParseError::TopicTooLong); + } + + // Topics cannot contain wildcards or null characters + if topic.contains(|x: char| { + x == SINGLE_LEVEL_WILDCARD || x == MULTI_LEVEL_WILDCARD || x == '\0' + }) { + return Err(TopicParseError::WildcardOrNullInTopic); + } + + let level_count = topic.split(TOPIC_SEPARATOR).count() as u32; + + let topic = Topic { topic_name: topic.to_string(), level_count }; + + Ok(topic) + } +} + +pub struct TopicLevels<'a> { + levels_iter: std::str::Split<'a, char>, +} + +impl<'a> TopicFilter { + fn filter(&'a self) -> &'a str { + match self { + TopicFilter::Concrete { filter, .. } => filter, + TopicFilter::Wildcard { filter, .. } => filter, + TopicFilter::SharedConcrete { filter, .. } => filter, + TopicFilter::SharedWildcard { filter, .. } => filter, + } + } + + pub fn levels(&'a self) -> TopicLevels<'a> { + TopicLevels { levels_iter: self.filter().split(TOPIC_SEPARATOR) } + } +} + +impl<'a> Topic { + pub fn levels(&'a self) -> TopicLevels<'a> { + TopicLevels { levels_iter: self.topic_name.split(TOPIC_SEPARATOR) } + } +} + +impl<'a> Iterator for TopicLevels<'a> { + type Item = TopicLevel<'a>; + + fn next(&mut self) -> Option { + match self.levels_iter.next() { + Some(MULTI_LEVEL_WILDCARD_STR) => Some(TopicLevel::MultiLevelWildcard), + Some(SINGLE_LEVEL_WILDCARD_STR) => Some(TopicLevel::SingleLevelWildcard), + Some(level) => Some(TopicLevel::Concrete(level)), + None => None, + } + } +} + +#[cfg(test)] +mod tests { + use crate::topic::{Topic, TopicFilter, TopicLevel, TopicParseError, MAX_TOPIC_LEN_BYTES}; + + #[test] + fn test_topic_filter_parse_empty_topic() { + assert_eq!("".parse::().unwrap_err(), TopicParseError::EmptyTopic); + } + + #[test] + fn test_topic_filter_parse_length() { + let just_right_topic = "a".repeat(MAX_TOPIC_LEN_BYTES); + assert!(just_right_topic.parse::().is_ok()); + + let too_long_topic = "a".repeat(MAX_TOPIC_LEN_BYTES + 1); + assert_eq!( + too_long_topic.parse::().unwrap_err(), + TopicParseError::TopicTooLong + ); + } + + #[test] + fn test_topic_filter_parse_concrete() { + assert_eq!( + "/".parse::().unwrap(), + TopicFilter::Concrete { filter: "/".to_string(), level_count: 2 } + ); + + assert_eq!( + "a".parse::().unwrap(), + TopicFilter::Concrete { filter: "a".to_string(), level_count: 1 } + ); + + // $SYS topics can be subscribed to, but can't be published + assert_eq!( + "home/kitchen".parse::().unwrap(), + TopicFilter::Concrete { filter: "home/kitchen".to_string(), level_count: 2 } + ); + + assert_eq!( + "home/kitchen/temperature".parse::().unwrap(), + TopicFilter::Concrete { + filter: "home/kitchen/temperature".to_string(), + level_count: 3, + } + ); + + assert_eq!( + "home/kitchen/temperature/celsius".parse::().unwrap(), + TopicFilter::Concrete { + filter: "home/kitchen/temperature/celsius".to_string(), + level_count: 4, + } + ); + } + + #[test] + fn test_topic_filter_parse_single_level_wildcard() { + assert_eq!( + "+".parse::().unwrap(), + TopicFilter::Wildcard { filter: "+".to_string(), level_count: 1 } + ); + + assert_eq!( + "+/".parse::().unwrap(), + TopicFilter::Wildcard { filter: "+/".to_string(), level_count: 2 } + ); + + assert_eq!( + "sport/+".parse::().unwrap(), + TopicFilter::Wildcard { filter: "sport/+".to_string(), level_count: 2 } + ); + + assert_eq!( + "/+".parse::().unwrap(), + TopicFilter::Wildcard { filter: "/+".to_string(), level_count: 2 } + ); + } + + #[test] + fn test_topic_filter_parse_multi_level_wildcard() { + assert_eq!( + "#".parse::().unwrap(), + TopicFilter::Wildcard { filter: "#".to_string(), level_count: 1 } + ); + + assert_eq!( + "#/".parse::().unwrap_err(), + TopicParseError::MultilevelWildcardNotAtEnd + ); + + assert_eq!( + "/#".parse::().unwrap(), + TopicFilter::Wildcard { filter: "/#".to_string(), level_count: 2 } + ); + + assert_eq!( + "sport/#".parse::().unwrap(), + TopicFilter::Wildcard { filter: "sport/#".to_string(), level_count: 2 } + ); + + assert_eq!( + "home/kitchen/temperature/#".parse::().unwrap(), + TopicFilter::Wildcard { + filter: "home/kitchen/temperature/#".to_string(), + level_count: 4, + } + ); + } + + #[test] + fn test_topic_filter_parse_shared_subscription_concrete() { + assert_eq!( + "$share/group_a/home".parse::().unwrap(), + TopicFilter::SharedConcrete { + group_name: "group_a".to_string(), + filter: "home".to_string(), + level_count: 1, + } + ); + + assert_eq!( + "$share/group_a/home/kitchen/temperature".parse::().unwrap(), + TopicFilter::SharedConcrete { + group_name: "group_a".to_string(), + filter: "home/kitchen/temperature".to_string(), + level_count: 3, + } + ); + + assert_eq!( + "$share/group_a//".parse::().unwrap(), + TopicFilter::SharedConcrete { + group_name: "group_a".to_string(), + filter: "/".to_string(), + level_count: 2, + } + ); + } + + #[test] + fn test_topic_filter_parse_shared_subscription_wildcard() { + assert_eq!( + "$share/group_b/#".parse::().unwrap(), + TopicFilter::SharedWildcard { + group_name: "group_b".to_string(), + filter: "#".to_string(), + level_count: 1, + } + ); + + assert_eq!( + "$share/group_b/+".parse::().unwrap(), + TopicFilter::SharedWildcard { + group_name: "group_b".to_string(), + filter: "+".to_string(), + level_count: 1, + } + ); + + assert_eq!( + "$share/group_b/+/temperature".parse::().unwrap(), + TopicFilter::SharedWildcard { + group_name: "group_b".to_string(), + filter: "+/temperature".to_string(), + level_count: 2, + } + ); + + assert_eq!( + "$share/group_c/+/temperature/+/meta".parse::().unwrap(), + TopicFilter::SharedWildcard { + group_name: "group_c".to_string(), + filter: "+/temperature/+/meta".to_string(), + level_count: 4, + } + ); + } + + #[test] + fn test_topic_filter_parse_invalid_shared_subscription() { + assert_eq!( + "$share/".parse::().unwrap_err(), + TopicParseError::EmptySharedGroupName + ); + assert_eq!("$share/a".parse::().unwrap_err(), TopicParseError::EmptyTopic); + assert_eq!("$share/a/".parse::().unwrap_err(), TopicParseError::EmptyTopic); + assert_eq!( + "$share//".parse::().unwrap_err(), + TopicParseError::EmptySharedGroupName + ); + assert_eq!( + "$share///".parse::().unwrap_err(), + TopicParseError::EmptySharedGroupName + ); + + assert_eq!( + "$share/invalid_group#/#".parse::().unwrap_err(), + TopicParseError::InvalidSharedGroupName + ); + } + + #[test] + fn test_topic_filter_parse_sys_prefix() { + assert_eq!( + "$SYS/stats".parse::().unwrap(), + TopicFilter::Concrete { filter: "$SYS/stats".to_string(), level_count: 2 } + ); + + assert_eq!( + "/$SYS/stats".parse::().unwrap(), + TopicFilter::Concrete { filter: "/$SYS/stats".to_string(), level_count: 3 } + ); + + assert_eq!( + "$SYS/+".parse::().unwrap(), + TopicFilter::Wildcard { filter: "$SYS/+".to_string(), level_count: 2 } + ); + + assert_eq!( + "$SYS/#".parse::().unwrap(), + TopicFilter::Wildcard { filter: "$SYS/#".to_string(), level_count: 2 } + ); + } + + #[test] + fn test_topic_filter_parse_invalid_filters() { + assert_eq!( + "sport/#/stats".parse::().unwrap_err(), + TopicParseError::MultilevelWildcardNotAtEnd + ); + assert_eq!( + "sport/#/stats#".parse::().unwrap_err(), + TopicParseError::InvalidWildcardLevel + ); + assert_eq!( + "sport#/stats#".parse::().unwrap_err(), + TopicParseError::InvalidWildcardLevel + ); + assert_eq!( + "sport/tennis#".parse::().unwrap_err(), + TopicParseError::InvalidWildcardLevel + ); + assert_eq!( + "sport/++".parse::().unwrap_err(), + TopicParseError::InvalidWildcardLevel + ); + } + + #[test] + fn test_topic_name_success() { + assert_eq!( + "/".parse::().unwrap(), + Topic { topic_name: "/".to_string(), level_count: 2 } + ); + + assert_eq!( + "Accounts payable".parse::().unwrap(), + Topic { topic_name: "Accounts payable".to_string(), level_count: 1 } + ); + + assert_eq!( + "home/kitchen".parse::().unwrap(), + Topic { topic_name: "home/kitchen".to_string(), level_count: 2 } + ); + + assert_eq!( + "home/kitchen/temperature".parse::().unwrap(), + Topic { topic_name: "home/kitchen/temperature".to_string(), level_count: 3 } + ); + } + + #[test] + fn test_topic_name_failure() { + assert_eq!("#".parse::().unwrap_err(), TopicParseError::WildcardOrNullInTopic,); + + assert_eq!("+".parse::().unwrap_err(), TopicParseError::WildcardOrNullInTopic,); + + assert_eq!("\0".parse::().unwrap_err(), TopicParseError::WildcardOrNullInTopic,); + + assert_eq!( + "/multi/level/#".parse::().unwrap_err(), + TopicParseError::WildcardOrNullInTopic, + ); + + assert_eq!( + "/single/level/+".parse::().unwrap_err(), + TopicParseError::WildcardOrNullInTopic, + ); + + assert_eq!( + "/null/byte/\0".parse::().unwrap_err(), + TopicParseError::WildcardOrNullInTopic, + ); + } + + #[test] + fn test_topic_filter_level_iterator_simple() { + let filter: TopicFilter = "/".parse().unwrap(); + + let mut levels = filter.levels(); + + assert_eq!(levels.next(), Some(TopicLevel::Concrete(""))); + assert_eq!(levels.next(), Some(TopicLevel::Concrete(""))); + assert_eq!(levels.next(), None); + } + + #[test] + fn test_topic_filter_level_iterator_concrete() { + let filter: TopicFilter = "home/kitchen/temperature".parse().unwrap(); + + let mut levels = filter.levels(); + + assert_eq!(levels.next(), Some(TopicLevel::Concrete("home"))); + assert_eq!(levels.next(), Some(TopicLevel::Concrete("kitchen"))); + assert_eq!(levels.next(), Some(TopicLevel::Concrete("temperature"))); + assert_eq!(levels.next(), None); + } + + #[test] + fn test_topic_filter_level_iterator_single_level_wildcard_1() { + let filter: TopicFilter = "home/+/+/temperature/+".parse().unwrap(); + + let mut levels = filter.levels(); + + assert_eq!(levels.next(), Some(TopicLevel::Concrete("home"))); + assert_eq!(levels.next(), Some(TopicLevel::SingleLevelWildcard)); + assert_eq!(levels.next(), Some(TopicLevel::SingleLevelWildcard)); + assert_eq!(levels.next(), Some(TopicLevel::Concrete("temperature"))); + assert_eq!(levels.next(), Some(TopicLevel::SingleLevelWildcard)); + assert_eq!(levels.next(), None); + } + + #[test] + fn test_topic_filter_level_iterator_single_level_wildcard_2() { + let filter: TopicFilter = "+".parse().unwrap(); + + let mut levels = filter.levels(); + + assert_eq!(levels.next(), Some(TopicLevel::SingleLevelWildcard)); + assert_eq!(levels.next(), None); + } + + #[test] + fn test_topic_filter_level_iterator_mutli_level_wildcard_1() { + let filter: TopicFilter = "home/kitchen/#".parse().unwrap(); + + let mut levels = filter.levels(); + + assert_eq!(levels.next(), Some(TopicLevel::Concrete("home"))); + assert_eq!(levels.next(), Some(TopicLevel::Concrete("kitchen"))); + assert_eq!(levels.next(), Some(TopicLevel::MultiLevelWildcard)); + assert_eq!(levels.next(), None); + } + + #[test] + fn test_topic_filter_level_iterator_mutli_level_wildcard_2() { + let filter: TopicFilter = "#".parse().unwrap(); + + let mut levels = filter.levels(); + + assert_eq!(levels.next(), Some(TopicLevel::MultiLevelWildcard)); + assert_eq!(levels.next(), None); + } +} diff --git a/src/topic/mod.rs b/src/topic/mod.rs new file mode 100644 index 0000000..b342f0d --- /dev/null +++ b/src/topic/mod.rs @@ -0,0 +1,17 @@ +const TOPIC_SEPARATOR: char = '/'; + +const MULTI_LEVEL_WILDCARD: char = '#'; +const MULTI_LEVEL_WILDCARD_STR: &str = "#"; + +const SINGLE_LEVEL_WILDCARD: char = '+'; +const SINGLE_LEVEL_WILDCARD_STR: &str = "+"; + +const SHARED_SUBSCRIPTION_PREFIX: &str = "$share/"; + +pub const MAX_TOPIC_LEN_BYTES: usize = 65_535; + +mod filter; +mod tree; + +pub use filter::*; +pub use tree::*; diff --git a/src/topic/tree.rs b/src/topic/tree.rs new file mode 100644 index 0000000..0187425 --- /dev/null +++ b/src/topic/tree.rs @@ -0,0 +1,353 @@ +use crate::topic::{Topic, TopicFilter, TopicLevel}; +use std::collections::{hash_map::Entry, HashMap}; + +// TODO(bschwind) - Support shared subscriptions + +#[derive(Debug)] +pub struct SubscriptionTreeNode { + subscribers: Vec<(u64, T)>, + single_level_wildcards: Option>>, + multi_level_wildcards: Vec<(u64, T)>, + concrete_topic_levels: HashMap>, +} + +#[derive(Debug)] +pub struct SubscriptionTree { + root: SubscriptionTreeNode, + counter: u64, +} + +impl SubscriptionTree { + pub fn new() -> Self { + Self { root: SubscriptionTreeNode::new(), counter: 0 } + } + + pub fn insert(&mut self, topic_filter: &TopicFilter, value: T) -> u64 { + let counter = self.counter; + self.root.insert(topic_filter, value, counter); + self.counter += 1; + + counter + } + + pub fn matching_subscribers<'a, F: FnMut(&T)>(&'a self, topic: &Topic, sub_fn: F) { + self.root.matching_subscribers(topic, sub_fn) + } + + pub fn remove(&mut self, topic_filter: &TopicFilter, counter: u64) -> Option { + self.root.remove(topic_filter, counter) + } + + fn is_empty(&self) -> bool { + self.root.is_empty() + } +} + +// TODO(bschwind) - All these topic strings need validation before +// operating on them. + +impl SubscriptionTreeNode { + fn new() -> Self { + Self { + subscribers: Vec::new(), + single_level_wildcards: None, + multi_level_wildcards: Vec::new(), + concrete_topic_levels: HashMap::new(), + } + } + + fn is_empty(&self) -> bool { + self.subscribers.is_empty() + && self.single_level_wildcards.is_none() + && self.multi_level_wildcards.is_empty() + && self.concrete_topic_levels.is_empty() + } + + fn insert(&mut self, topic_filter: &TopicFilter, value: T, counter: u64) { + let mut current_tree = self; + let mut multi_level = false; + + for level in topic_filter.levels() { + match level { + TopicLevel::SingleLevelWildcard => { + if current_tree.single_level_wildcards.is_some() { + current_tree = current_tree.single_level_wildcards.as_mut().unwrap(); + } else { + current_tree.single_level_wildcards = + Some(Box::new(SubscriptionTreeNode::new())); + current_tree = current_tree.single_level_wildcards.as_mut().unwrap(); + } + }, + TopicLevel::MultiLevelWildcard => { + multi_level = true; + break; + }, + TopicLevel::Concrete(concrete_topic_level) => { + if current_tree.concrete_topic_levels.contains_key(concrete_topic_level) { + current_tree = current_tree + .concrete_topic_levels + .get_mut(concrete_topic_level) + .unwrap(); + } else { + current_tree + .concrete_topic_levels + .insert(concrete_topic_level.to_string(), SubscriptionTreeNode::new()); + + // TODO - Do this without another hash lookup + current_tree = current_tree + .concrete_topic_levels + .get_mut(concrete_topic_level) + .unwrap(); + } + }, + } + } + + if multi_level { + current_tree.multi_level_wildcards.push((counter, value)); + } else { + current_tree.subscribers.push((counter, value)); + } + } + + fn remove(&mut self, topic_filter: &TopicFilter, counter: u64) -> Option { + let mut current_tree = self; + let mut stack: Vec<(*mut SubscriptionTreeNode, usize)> = vec![]; + + let levels: Vec = topic_filter.levels().collect(); + let mut level_index = 0; + + for level in &levels { + match level { + TopicLevel::SingleLevelWildcard => { + if current_tree.single_level_wildcards.is_some() { + stack.push((&mut *current_tree, level_index)); + level_index += 1; + + current_tree = current_tree.single_level_wildcards.as_mut().unwrap(); + } else { + return None; + } + }, + TopicLevel::MultiLevelWildcard => { + break; + }, + TopicLevel::Concrete(concrete_topic_level) => { + if current_tree.concrete_topic_levels.contains_key(*concrete_topic_level) { + stack.push((&mut *current_tree, level_index)); + level_index += 1; + + current_tree = current_tree + .concrete_topic_levels + .get_mut(*concrete_topic_level) + .unwrap(); + } else { + return None; + } + }, + } + } + + // Get the return value + let return_val = { + let level = &levels[levels.len() - 1]; + + if *level == TopicLevel::MultiLevelWildcard { + if let Some(pos) = + current_tree.multi_level_wildcards.iter().position(|(c, _)| *c == counter) + { + Some(current_tree.multi_level_wildcards.remove(pos)) + } else { + None + } + } else if let Some(pos) = + current_tree.subscribers.iter().position(|(c, _)| *c == counter) + { + Some(current_tree.subscribers.remove(pos)) + } else { + None + } + }; + + // Go up the stack, cleaning up empty nodes + while let Some((stack_val, level_index)) = stack.pop() { + let mut tree = unsafe { &mut *stack_val }; + + let level = &levels[level_index]; + + match level { + TopicLevel::SingleLevelWildcard => { + if tree.single_level_wildcards.as_ref().map(|t| t.is_empty()).unwrap_or(false) { + tree.single_level_wildcards = None; + } + }, + TopicLevel::MultiLevelWildcard => { + // TODO - Ignore this case? + }, + TopicLevel::Concrete(concrete_topic_level) => { + if let Entry::Occupied(o) = + tree.concrete_topic_levels.entry((*concrete_topic_level).to_string()) + { + if o.get().is_empty() { + o.remove_entry(); + } + } + }, + } + } + + return_val.map(|(_, val)| val) + } + + fn matching_subscribers<'a, F: FnMut(&T)>(&'a self, topic: &Topic, mut sub_fn: F) { + let mut tree_stack = vec![]; + let levels: Vec = topic.levels().collect(); + + tree_stack.push((self, 0)); + + while !tree_stack.is_empty() { + let (current_tree, current_level) = tree_stack.pop().unwrap(); + let level = &levels[current_level]; + + for (_, subscriber) in ¤t_tree.multi_level_wildcards { + sub_fn(subscriber); + } + + if let Some(sub_tree) = ¤t_tree.single_level_wildcards { + if current_level + 1 < levels.len() { + tree_stack.push((sub_tree, current_level + 1)); + } else { + for (_, subscriber) in &sub_tree.subscribers { + sub_fn(subscriber); + } + } + } + + if let TopicLevel::Concrete(level) = level { + if current_tree.concrete_topic_levels.contains_key(*level) { + let sub_tree = current_tree.concrete_topic_levels.get(*level).unwrap(); + + if current_level + 1 < levels.len() { + let sub_tree = current_tree.concrete_topic_levels.get(*level).unwrap(); + tree_stack.push((sub_tree, current_level + 1)); + } else { + for (_, subscriber) in &sub_tree.subscribers { + sub_fn(subscriber); + } + } + } + } + } + } +} + +#[cfg(test)] +mod tests { + use crate::topic::SubscriptionTree; + + #[test] + fn test_insert() { + let mut sub_tree = SubscriptionTree::new(); + sub_tree.insert(&"home/kitchen/temperature".parse().unwrap(), 1); + sub_tree.insert(&"home/kitchen/humidity".parse().unwrap(), 2); + sub_tree.insert(&"home/kitchen".parse().unwrap(), 3); + sub_tree.insert(&"home/+/humidity".parse().unwrap(), 4); + sub_tree.insert(&"home/+".parse().unwrap(), 5); + sub_tree.insert(&"home/#".parse().unwrap(), 6); + sub_tree.insert(&"home/+/temperature".parse().unwrap(), 7); + sub_tree.insert(&"office/stairwell/temperature".parse().unwrap(), 8); + sub_tree.insert(&"office/+/+".parse().unwrap(), 9); + sub_tree.insert(&"office/+/+/some_desk/+/fan_speed/+/temperature".parse().unwrap(), 10); + sub_tree.insert(&"office/+/+/some_desk/+/#".parse().unwrap(), 11); + sub_tree.insert(&"sport/tennis/+".parse().unwrap(), 21); + sub_tree.insert(&"#".parse().unwrap(), 12); + + println!("{:#?}", sub_tree); + + sub_tree.matching_subscribers(&"home/kitchen".parse().unwrap(), |s| { + println!("{}", s); + }); + + println!(); + + sub_tree.matching_subscribers(&"home/kitchen/humidity".parse().unwrap(), |s| { + println!("{}", s); + }); + + println!(); + + sub_tree.matching_subscribers(&"office/stairwell/temperature".parse().unwrap(), |s| { + println!("{}", s); + }); + + println!(); + + sub_tree.matching_subscribers( + &"office/tokyo/shibuya/some_desk/cpu_1/fan_speed/blade_4/temperature".parse().unwrap(), + |s| { + println!("{}", s); + }, + ); + + println!(); + + sub_tree.matching_subscribers(&"home".parse().unwrap(), |s| { + println!("{}", s); + }); + + println!(); + + sub_tree.matching_subscribers(&"sport/tennis/player1".parse().unwrap(), |s| { + println!("{}", s); + }); + + println!(); + + sub_tree.matching_subscribers(&"sport/tennis/player2".parse().unwrap(), |s| { + println!("{}", s); + }); + + println!(); + + sub_tree.matching_subscribers(&"sport/tennis/player1/ranking".parse().unwrap(), |s| { + println!("{}", s); + }); + } + + #[test] + fn test_remove() { + let mut sub_tree = SubscriptionTree::new(); + let sub_1 = sub_tree.insert(&"home/kitchen/temperature".parse().unwrap(), "sub_1"); + let sub_2 = sub_tree.insert(&"home/kitchen/temperature".parse().unwrap(), "sub_2"); + let sub_3 = sub_tree.insert(&"home/kitchen/humidity".parse().unwrap(), "sub_3"); + let sub_4 = sub_tree.insert(&"home/kitchen/#".parse().unwrap(), "sub_4"); + let sub_5 = sub_tree.insert(&"home/kitchen/+".parse().unwrap(), "sub_5"); + let sub_6 = sub_tree.insert(&"home/kitchen/+".parse().unwrap(), "sub_6"); + let sub_7 = sub_tree.insert(&"#".parse().unwrap(), "sub_7"); + + assert!(!sub_tree.is_empty()); + + assert!(sub_tree.remove(&"#".parse().unwrap(), sub_1).is_none()); + + assert_eq!( + sub_tree.remove(&"home/kitchen/temperature".parse().unwrap(), sub_1).unwrap(), + "sub_1" + ); + assert_eq!( + sub_tree.remove(&"home/kitchen/temperature".parse().unwrap(), sub_2).unwrap(), + "sub_2" + ); + assert_eq!(sub_tree.remove(&"home/kitchen/#".parse().unwrap(), sub_4).unwrap(), "sub_4"); + assert_eq!(sub_tree.remove(&"home/kitchen/+".parse().unwrap(), sub_5).unwrap(), "sub_5"); + assert_eq!( + sub_tree.remove(&"home/kitchen/humidity".parse().unwrap(), sub_3).unwrap(), + "sub_3" + ); + assert_eq!(sub_tree.remove(&"#".parse().unwrap(), sub_7).unwrap(), "sub_7"); + assert_eq!(sub_tree.remove(&"home/kitchen/+".parse().unwrap(), sub_6).unwrap(), "sub_6"); + + assert!(sub_tree.is_empty()); + + assert!(sub_tree.remove(&"home/kitchen/+".parse().unwrap(), sub_6).is_none()); + } +}