diff --git a/src/lib.rs b/src/lib.rs index 8a68435..710137d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -74,6 +74,8 @@ //! - If we don't know what to do with an argument we use [`return Err(arg.unexpected())`][Arg::unexpected] to turn it into an error message. //! - Strings can be promoted to errors for custom error messages. +#![feature(slice_range)] +#![deny(unsafe_code)] #![warn(missing_docs, missing_debug_implementations, elided_lifetimes_in_paths)] #![allow(clippy::should_implement_trait)] @@ -84,6 +86,10 @@ use std::{ str::{FromStr, Utf8Error}, }; +mod os_str_slice; + +use os_str_slice::OsStrSlice; + type InnerIter = std::vec::IntoIter; fn make_iter(iter: impl Iterator) -> InnerIter { @@ -109,10 +115,9 @@ enum State { PendingValue(OsString), /// We're in the middle of -abc. /// - /// In order to satisfy OsString::from_encoded_bytes_unchecked() we make - /// sure that the usize always point to the end of a valid UTF-8 substring. - /// This is a safety invariant! - Shorts(Vec, usize), + /// In order to satisfy OsStr::slice_encoded_bytes() we make sure that the + /// usize always points to the end of a valid UTF-8 substring. + Shorts(OsString, usize), /// We saw -- and know no more options are coming. FinishedOpts, } @@ -170,7 +175,7 @@ impl Parser { State::Shorts(ref arg, ref mut pos) => { // We're somewhere inside a -abc chain. Because we're in .next(), // not .value(), we can assume that the next character is another option. - match first_codepoint(&arg[*pos..]) { + match first_codepoint(&arg.as_encoded_bytes()[*pos..]) { Ok(None) => { self.state = State::None; } @@ -186,14 +191,13 @@ impl Parser { }); } Ok(Some(ch)) => { - // SAFETY: pos still points to the end of a valid UTF-8 codepoint. *pos += ch.len_utf8(); self.last_option = LastOption::Short(ch); return Ok(Some(Arg::Short(ch))); } Err(_) => { // Skip the rest of the argument. 
This makes it easy to maintain the - // OsString invariants, and the caller is almost certainly going to + // OsString invariant, and the caller is almost certainly going to // abort anyway. self.state = State::None; self.last_option = LastOption::Short('�'); @@ -212,7 +216,7 @@ impl Parser { ref state => panic!("unexpected state {:?}", state), } - let arg = match self.source.next() { + let mut arg = match self.source.next() { Some(arg) => arg, None => return Ok(None), }; @@ -222,33 +226,19 @@ impl Parser { return self.next(); } - if arg.as_encoded_bytes().starts_with(b"--") { - let mut arg = arg.into_encoded_bytes(); - + let arg_bytes = arg.as_encoded_bytes(); + if arg_bytes.starts_with(b"--") { // Long options have two forms: --option and --option=value. - if let Some(ind) = arg.iter().position(|&b| b == b'=') { + if let Some(ind) = arg_bytes.iter().position(|&b| b == b'=') { // The value can be an OsString... - let value = arg[ind + 1..].to_vec(); - - // SAFETY: this substring comes immediately after a valid UTF-8 sequence - // (i.e. the equals sign), and it originates from bytes we obtained from - // an OsString just now. - let value = unsafe { OsString::from_encoded_bytes_unchecked(value) }; + let value = arg.slice_encoded_bytes(ind + 1..).to_owned(); self.state = State::PendingValue(value); - arg.truncate(ind); + arg = arg.slice_encoded_bytes(..ind).to_owned(); } // ...but the option has to be a string. - // Transform arg back into an OsString so we can use the platform-specific - // to_string_lossy() implementation. - // (In particular: String::from_utf8_lossy() turns a WTF-8 lone surrogate - // into three replacement characters instead of one.) - // SAFETY: arg is either an unmodified OsString or one we truncated - // right before a valid UTF-8 sequence ("="). 
- let arg = unsafe { OsString::from_encoded_bytes_unchecked(arg) }; - // Calling arg.to_string_lossy().into_owned() would work, but because // the return type is Cow this would perform an unnecessary copy in // the common case where arg is already UTF-8. @@ -259,9 +249,7 @@ impl Parser { Err(arg) => arg.to_string_lossy().into_owned(), }; Ok(Some(self.set_long(option))) - } else if arg.as_encoded_bytes().len() > 1 && arg.as_encoded_bytes()[0] == b'-' { - let arg = arg.into_encoded_bytes(); - // SAFETY: 1 points at the end of the dash. + } else if arg_bytes.len() > 1 && arg_bytes[0] == b'-' { self.state = State::Shorts(arg, 1); self.next() } else { @@ -528,24 +516,19 @@ impl Parser { fn raw_optional_value(&mut self) -> Option<(OsString, bool)> { match replace(&mut self.state, State::None) { State::PendingValue(value) => Some((value, true)), - State::Shorts(mut arg, mut pos) => { + State::Shorts(arg, mut pos) => { if pos >= arg.len() { return None; } let mut had_eq_sign = false; - if arg[pos] == b'=' { + if arg.as_encoded_bytes()[pos] == b'=' { // -o=value. // clap actually strips out all leading '='s, but that seems silly. // We allow `-xo=value`. Python's argparse doesn't strip the = in that case. - // SAFETY: pos now points to the end of the '='. pos += 1; had_eq_sign = true; } - arg.drain(..pos); // Reuse allocation - - // SAFETY: arg originates from an OsString. We ensure that pos always - // points to a valid UTF-8 boundary. 
- let value = unsafe { OsString::from_encoded_bytes_unchecked(arg) }; + let value = arg.slice_encoded_bytes(pos..).to_owned(); Some((value, had_eq_sign)) } State::FinishedOpts => { diff --git a/src/os_str_slice.rs b/src/os_str_slice.rs new file mode 100644 index 0000000..ed351e5 --- /dev/null +++ b/src/os_str_slice.rs @@ -0,0 +1,77 @@ +#![allow(unsafe_code)] +use std::ffi::OsStr; +use std::ops::RangeBounds; + +pub(crate) trait OsStrSlice { + /// Takes a substring based on a range that corresponds to the return value of + /// [`OsStr::as_encoded_bytes`]. + /// + /// The range's start and end must lie on valid `OsStr` boundaries, meaning one of: + /// - The start of the string + /// - The end of the string + /// - Immediately before a valid non-empty UTF-8 substring + /// - Immediately after a valid non-empty UTF-8 substring + /// + /// This requirement holds even on platforms where the underlying encoding is more + /// permissive. + /// + /// # Panics + /// + /// Panics if the range does not lie on valid `OsStr` boundaries. + /// The `OsStr` implementation in this file first resolves the range with + /// `std::slice::range`, which is documented to panic on an out-of-bounds range. + /// + /// # Example + /// + /// ```ignore + /// use std::ffi::OsStr; + /// + /// let os_str = OsStr::new("foo=bar"); + /// let bytes = os_str.as_encoded_bytes(); + /// if let Some(index) = bytes.iter().position(|b| *b == b'=') { + /// let key = os_str.slice_encoded_bytes(..index); + /// let value = os_str.slice_encoded_bytes(index + 1..); + /// assert_eq!(key, "foo"); + /// assert_eq!(value, "bar"); + /// } + /// ``` + // NOTE(review): the generic parameter list looks truncated in this copy of the patch + // (stray `>`, and `R` is never declared); presumably `R` is bounded by `RangeBounds` + // over `usize`, given the `RangeBounds` import above. Confirm against the original commit. + fn slice_encoded_bytes>(&self, range: R) -> &Self; +} + +impl OsStrSlice for OsStr { + fn slice_encoded_bytes>(&self, range: R) -> &Self { + /// Reports whether `index` is an acceptable cut point in `bytes`: offset 0, + /// `bytes.len()`, or a position with valid non-empty UTF-8 directly after it + /// (checked first) or directly before it (the 1 to 4 bytes ending at `index`). + fn is_valid_boundary(bytes: &[u8], index: usize) -> bool { + if index == 0 || index == bytes.len() { + return true; + } + + let (before, after) = bytes.split_at(index); + + // UTF-8 takes at most 4 bytes per codepoint, so we don't + // need to check more than that. 
+ let after = after.get(..4).unwrap_or(after); + match std::str::from_utf8(after) { + Ok(_) => return true, + Err(err) if err.valid_up_to() != 0 => return true, + Err(_) => (), + } + + for len in 1..=4.min(index) { + let before = &before[index - len..]; + if std::str::from_utf8(before).is_ok() { + return true; + } + } + + false + } + + let bytes = self.as_encoded_bytes(); + let range = std::slice::range(range, ..bytes.len()); + assert!(is_valid_boundary(bytes, range.start)); + assert!(is_valid_boundary(bytes, range.end)); + + // SAFETY: bytes was obtained from an OsStr just now, and we validated + // that we only slice immediately before or after a valid non-empty + // UTF-8 substring. + // (`is_valid_boundary` also accepts the very start and end of `bytes`.) + unsafe { Self::from_encoded_bytes_unchecked(&bytes[range]) } + } +}