Skip to content

Commit

Permalink
Implement OsStr::slice_encoded_bytes() proof of concept
Browse files Browse the repository at this point in the history
  • Loading branch information
blyxxyz committed Nov 28, 2023
1 parent 2ed9095 commit 8077851
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 56 deletions.
83 changes: 27 additions & 56 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@
//! - If we don't know what to do with an argument we use [`return Err(arg.unexpected())`][Arg::unexpected] to turn it into an error message.
//! - Strings can be promoted to errors for custom error messages.

#![feature(slice_range)]
#![deny(unsafe_code)]
#![warn(missing_docs, missing_debug_implementations, elided_lifetimes_in_paths)]
#![allow(clippy::should_implement_trait)]

Expand All @@ -84,6 +86,10 @@ use std::{
str::{FromStr, Utf8Error},
};

mod os_str_slice;

use os_str_slice::OsStrSlice;

type InnerIter = std::vec::IntoIter<OsString>;

fn make_iter(iter: impl Iterator<Item = OsString>) -> InnerIter {
Expand All @@ -109,10 +115,9 @@ enum State {
PendingValue(OsString),
/// We're in the middle of -abc.
///
/// In order to satisfy OsString::from_encoded_bytes_unchecked() we make
/// sure that the usize always point to the end of a valid UTF-8 substring.
/// This is a safety invariant!
Shorts(Vec<u8>, usize),
/// In order to satisfy OsStr::slice_encoded_bytes() we make sure that the
/// usize always points to the end of a valid UTF-8 substring.
Shorts(OsString, usize),
/// We saw -- and know no more options are coming.
FinishedOpts,
}
Expand Down Expand Up @@ -170,7 +175,7 @@ impl Parser {
State::Shorts(ref arg, ref mut pos) => {
// We're somewhere inside a -abc chain. Because we're in .next(),
// not .value(), we can assume that the next character is another option.
match first_codepoint(&arg[*pos..]) {
match first_codepoint(&arg.as_encoded_bytes()[*pos..]) {
Ok(None) => {
self.state = State::None;
}
Expand All @@ -186,14 +191,13 @@ impl Parser {
});
}
Ok(Some(ch)) => {
// SAFETY: pos still points to the end of a valid UTF-8 codepoint.
*pos += ch.len_utf8();
self.last_option = LastOption::Short(ch);
return Ok(Some(Arg::Short(ch)));
}
Err(_) => {
// Skip the rest of the argument. This makes it easy to maintain the
// OsString invariants, and the caller is almost certainly going to
// OsString invariant, and the caller is almost certainly going to
// abort anyway.
self.state = State::None;
self.last_option = LastOption::Short('�');
Expand Down Expand Up @@ -222,46 +226,27 @@ impl Parser {
return self.next();
}

if arg.as_encoded_bytes().starts_with(b"--") {
let mut arg = arg.into_encoded_bytes();

let arg_bytes = arg.as_encoded_bytes();
if arg_bytes.starts_with(b"--") {
let mut arg = arg.as_os_str();
// Long options have two forms: --option and --option=value.
if let Some(ind) = arg.iter().position(|&b| b == b'=') {
if let Some(ind) = arg_bytes.iter().position(|&b| b == b'=') {
// The value can be an OsString...
let value = arg[ind + 1..].to_vec();

// SAFETY: this substring comes immediately after a valid UTF-8 sequence
// (i.e. the equals sign), and it originates from bytes we obtained from
// an OsString just now.
let value = unsafe { OsString::from_encoded_bytes_unchecked(value) };
let value = arg.slice_encoded_bytes(ind + 1..).to_owned();

self.state = State::PendingValue(value);
arg.truncate(ind);
arg = arg.slice_encoded_bytes(..ind)
}

// ...but the option has to be a string.

// Transform arg back into an OsString so we can use the platform-specific
// to_string_lossy() implementation.
// (In particular: String::from_utf8_lossy() turns a WTF-8 lone surrogate
// into three replacement characters instead of one.)
// SAFETY: arg is either an unmodified OsString or one we truncated
// right before a valid UTF-8 sequence ("=").
let arg = unsafe { OsString::from_encoded_bytes_unchecked(arg) };

// Calling arg.to_string_lossy().into_owned() would work, but because
// the return type is Cow this would perform an unnecessary copy in
// the common case where arg is already UTF-8.
// reqwest does a similar maneuver more efficiently with unsafe:
// https://github.com/seanmonstar/reqwest/blob/e6a1a09f0904e06de4ff1317278798c4ed28af66/src/async_impl/response.rs#L194
let option = match arg.into_string() {
Ok(text) => text,
Err(arg) => arg.to_string_lossy().into_owned(),
let arg = arg.to_string_lossy().into_owned();
self.last_option = LastOption::Long(arg);
let long = match self.last_option {
LastOption::Long(ref option) => &option[2..],
_ => unreachable!(),
};
Ok(Some(self.set_long(option)))
} else if arg.as_encoded_bytes().len() > 1 && arg.as_encoded_bytes()[0] == b'-' {
let arg = arg.into_encoded_bytes();
// SAFETY: 1 points at the end of the dash.
Ok(Some(Arg::Long(long)))
} else if arg_bytes.len() > 1 && arg_bytes[0] == b'-' {
self.state = State::Shorts(arg, 1);
self.next()
} else {
Expand Down Expand Up @@ -528,24 +513,19 @@ impl Parser {
fn raw_optional_value(&mut self) -> Option<(OsString, bool)> {
match replace(&mut self.state, State::None) {
State::PendingValue(value) => Some((value, true)),
State::Shorts(mut arg, mut pos) => {
State::Shorts(arg, mut pos) => {
if pos >= arg.len() {
return None;
}
let mut had_eq_sign = false;
if arg[pos] == b'=' {
if arg.as_encoded_bytes()[pos] == b'=' {
// -o=value.
// clap actually strips out all leading '='s, but that seems silly.
// We allow `-xo=value`. Python's argparse doesn't strip the = in that case.
// SAFETY: pos now points to the end of the '='.
pos += 1;
had_eq_sign = true;
}
arg.drain(..pos); // Reuse allocation

// SAFETY: arg originates from an OsString. We ensure that pos always
// points to a valid UTF-8 boundary.
let value = unsafe { OsString::from_encoded_bytes_unchecked(arg) };
let value = arg.slice_encoded_bytes(pos..).to_owned();
Some((value, had_eq_sign))
}
State::FinishedOpts => {
Expand Down Expand Up @@ -619,15 +599,6 @@ impl Parser {
{
Parser::new(None, make_iter(args.into_iter().map(Into::into)))
}

/// Store a long option so the caller can borrow it.
fn set_long(&mut self, option: String) -> Arg<'_> {
self.last_option = LastOption::Long(option);
match self.last_option {
LastOption::Long(ref option) => Arg::Long(&option[2..]),
_ => unreachable!(),
}
}
}

impl Arg<'_> {
Expand Down
87 changes: 87 additions & 0 deletions src/os_str_slice.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
#![allow(unsafe_code)]
use std::ffi::OsStr;
use std::ops::RangeBounds;

/// Extension trait that adds byte-range slicing to [`OsStr`].
pub(crate) trait OsStrSlice {
    /// Takes a substring based on a range that corresponds to the return value of
    /// [`OsStr::as_encoded_bytes`].
    ///
    /// The range's start and end must lie on valid `OsStr` boundaries.
    ///
    /// On Unix any boundaries are valid, as OS strings may contain arbitrary bytes.
    ///
    /// On other platforms such as Windows the internal encoding is currently
    /// unspecified, and a valid `OsStr` boundary is one of:
    /// - The start of the string
    /// - The end of the string
    /// - Immediately before a valid non-empty UTF-8 substring
    /// - Immediately after a valid non-empty UTF-8 substring
    ///
    /// # Panics
    ///
    /// Panics if `range` does not lie on valid `OsStr` boundaries, if its start
    /// is greater than its end, or if it exceeds the end of the string.
    ///
    /// # Example
    ///
    /// ```ignore
    /// use std::ffi::OsStr;
    ///
    /// let os_str = OsStr::new("foo=bar");
    /// let bytes = os_str.as_encoded_bytes();
    /// if let Some(index) = bytes.iter().position(|b| *b == b'=') {
    ///     let key = os_str.slice_encoded_bytes(..index);
    ///     let value = os_str.slice_encoded_bytes(index + 1..);
    ///     assert_eq!(key, "foo");
    ///     assert_eq!(value, "bar");
    /// }
    /// ```
    fn slice_encoded_bytes<R: RangeBounds<usize>>(&self, range: R) -> &Self;
}

impl OsStrSlice for OsStr {
    fn slice_encoded_bytes<R: RangeBounds<usize>>(&self, range: R) -> &Self {
        /// Turns arbitrary range bounds into a concrete `start..end` that is
        /// guaranteed to satisfy `start <= end <= len`.
        ///
        /// This is written by hand so that only stable standard-library APIs
        /// are needed. It panics when a bound overflows `usize`, when the
        /// start lies after the end, or when the end exceeds `len`.
        fn resolve_range<R: RangeBounds<usize>>(range: &R, len: usize) -> std::ops::Range<usize> {
            use std::ops::Bound;

            let start = match range.start_bound() {
                Bound::Included(&start) => start,
                Bound::Excluded(&start) => start
                    .checked_add(1)
                    .expect("attempted to index slice from after maximum usize"),
                Bound::Unbounded => 0,
            };
            let end = match range.end_bound() {
                Bound::Included(&end) => end
                    .checked_add(1)
                    .expect("attempted to index slice up to maximum usize"),
                Bound::Excluded(&end) => end,
                Bound::Unbounded => len,
            };

            assert!(
                start <= end,
                "slice index starts at {} but ends at {}",
                start,
                end
            );
            assert!(
                end <= len,
                "range end index {} out of range for slice of length {}",
                end,
                len
            );

            start..end
        }

        let bytes = self.as_encoded_bytes();
        let range = resolve_range(&range, bytes.len());

        // On Unix an OsStr is an arbitrary byte string, so every index is a
        // valid boundary and the slice can be reinterpreted directly.
        #[cfg(unix)]
        return std::os::unix::ffi::OsStrExt::from_bytes(&bytes[range]);

        #[cfg(not(unix))]
        {
            /// Reports whether `index` is the start or end of `bytes`, or sits
            /// directly before or after a non-empty run of valid UTF-8.
            fn is_valid_boundary(bytes: &[u8], index: usize) -> bool {
                if index == 0 || index == bytes.len() {
                    return true;
                }

                let (before, after) = bytes.split_at(index);

                // UTF-8 takes at most 4 bytes per codepoint, so we don't
                // need to check more than that.
                let after = after.get(..4).unwrap_or(after);
                match std::str::from_utf8(after) {
                    Ok(_) => return true,
                    // The 4-byte window may cut a later codepoint in half; a
                    // non-empty valid prefix is all that is required.
                    Err(err) if err.valid_up_to() != 0 => return true,
                    Err(_) => (),
                }

                // Otherwise accept the index if some codepoint ends exactly here.
                (1..=index.min(4)).any(|len| std::str::from_utf8(&before[index - len..]).is_ok())
            }

            assert!(
                is_valid_boundary(bytes, range.start),
                "byte index {} is not an OsStr boundary",
                range.start
            );
            assert!(
                is_valid_boundary(bytes, range.end),
                "byte index {} is not an OsStr boundary",
                range.end
            );

            // SAFETY: bytes was obtained from an OsStr just now, and we validated
            // that we only slice immediately before or after a valid non-empty
            // UTF-8 substring.
            unsafe { Self::from_encoded_bytes_unchecked(&bytes[range]) }
        }
    }
}

0 comments on commit 8077851

Please sign in to comment.