Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add read_until_slice #6531

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 91 additions & 0 deletions tokio/src/io/util/async_buf_read_ext.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use crate::io::util::fill_buf::{fill_buf, FillBuf};
use crate::io::util::lines::{lines, Lines};
use crate::io::util::read_line::{read_line, ReadLine};
use crate::io::util::read_until::{read_until, ReadUntil};
use crate::io::util::read_until_slice::{read_until_slice, ReadUntilSlice};
use crate::io::util::split::{split, Split};
use crate::io::AsyncBufRead;

Expand Down Expand Up @@ -100,6 +101,96 @@ cfg_io_util! {
read_until(self, byte, buf)
}

/// Reads all bytes into `buf` until the delimiter or EOF is reached.
///
/// Equivalent to:
///
/// ```ignore
/// async fn read_until_slice<'a, 'b>(&'a mut self, delimiter: &'b [u8], buf: &'a mut Vec<u8>) -> io::Result<usize>;
/// ```
///
/// This function will read bytes from the underlying stream until the
/// delimiter or EOF is found. Once found, all bytes up to, and including,
/// the delimiter (if found) will be appended to `buf`.
///
/// If successful, this function will return the total number of bytes read.
///
/// If this function returns `Ok(0)`, the stream has reached EOF.
///
/// # Errors
///
/// This function will ignore all instances of [`ErrorKind::Interrupted`] and
/// will otherwise return any errors returned by [`fill_buf`].
///
/// If an I/O error is encountered then all bytes read so far will be
/// present in `buf` and its length will have been adjusted appropriately.
///
/// [`fill_buf`]: AsyncBufRead::poll_fill_buf
/// [`ErrorKind::Interrupted`]: std::io::ErrorKind::Interrupted
///
/// # Cancel safety
///
/// This method is not cancellation safe. If the method is used as the
/// event in a [`tokio::select!`](crate::select) statement and some
/// other branch completes first, then it is not guaranted to find the
/// delimiter in the next call to `read_until_slice`.
///
/// This function does not behave like [`read_until`] because the
/// delimiter can be split across multiple read calls and thus might
/// haven partially read in a previous call to `read_until_slice`.
///
/// # Examples
///
/// [`std::io::Cursor`][`Cursor`] is a type that implements `BufRead`. In
/// this example, we use [`Cursor`] to read all the bytes in a byte slice
/// in hyphen delimited segments:
///
/// [`Cursor`]: std::io::Cursor
///
/// ```
/// use tokio::io::AsyncBufReadExt;
///
/// use std::io::Cursor;
///
/// #[tokio::main]
/// async fn main() {
/// let mut cursor = Cursor::new(b"lorem\r\nipsum");
/// let mut buf = vec![];
///
/// // cursor is at 'l'
/// let num_bytes = cursor.read_until_slice(b"\r\n", &mut buf)
/// .await
/// .expect("reading from cursor won't fail");
///
/// assert_eq!(num_bytes, 7);
/// assert_eq!(buf, b"lorem\r\n");
/// buf.clear();
///
/// // cursor is at 'i'
/// let num_bytes = cursor.read_until_slice(b"\r\n", &mut buf)
/// .await
/// .expect("reading from cursor won't fail");
///
/// assert_eq!(num_bytes, 5);
/// assert_eq!(buf, b"ipsum");
/// buf.clear();
///
/// // cursor is at EOF
/// let num_bytes = cursor.read_until_slice(b"\r\n", &mut buf)
/// .await
/// .expect("reading from cursor won't fail");
/// assert_eq!(num_bytes, 0);
/// assert_eq!(buf, b"");
/// }
/// ```
fn read_until_slice<'a, 'b>(&'a mut self, delimiter: &'b [u8], buf: &'a mut Vec<u8>) -> ReadUntilSlice<'a, 'b, Self>
where
Self: Unpin,
{
read_until_slice(self, delimiter, buf)
}


/// Reads all bytes until a newline (the 0xA byte) is reached, and append
/// them to the provided buffer.
///
Expand Down
1 change: 1 addition & 0 deletions tokio/src/io/util/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ cfg_io_util! {

mod read_to_string;
mod read_until;
mod read_until_slice;

mod repeat;
pub use repeat::{repeat, Repeat};
Expand Down
157 changes: 157 additions & 0 deletions tokio/src/io/util/read_until_slice.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
use crate::io::AsyncBufRead;

use pin_project_lite::pin_project;
use std::future::Future;
use std::io;
use std::marker::PhantomPinned;
use std::mem;
use std::pin::Pin;
use std::task::{Context, Poll};

pin_project! {
/// Future for the [`read_until_slice`](super::AsyncReadExt::read_to_slice) method.
/// The delimiter is included in the resulting vector.
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct ReadUntilSlice<'a, 'b, R: ?Sized> {
reader: &'a mut R,
delimiter: &'b [u8],
buf: &'a mut Vec<u8>,
// The number of bytes appended to buf. This can be less than buf.len() if
// the buffer was not empty when the operation was started.
read: usize,
// Make this future `!Unpin` for compatibility with async trait methods.
#[pin]
_pin: PhantomPinned,
}
}

pub(crate) fn read_until_slice<'a, 'b, R>(
reader: &'a mut R,
delimiter: &'b [u8],
buf: &'a mut Vec<u8>,
) -> ReadUntilSlice<'a, 'b, R>
where
R: AsyncBufRead + ?Sized + Unpin,
{
ReadUntilSlice {
reader,
delimiter,
buf,
read: 0,
_pin: PhantomPinned,
}
}

pub(crate) fn read_until_slice_internal<R: AsyncBufRead + ?Sized>(
mut reader: Pin<&mut R>,
cx: &mut Context<'_>,
delimiter: &'_ [u8],
buf: &mut Vec<u8>,
read: &mut usize,
) -> Poll<io::Result<usize>> {
let mut match_len = 0usize;
loop {
let (done, used) = {
let available = ready!(reader.as_mut().poll_fill_buf(cx))?;
if let Some(i) = search(delimiter, available, &mut match_len) {
buf.extend_from_slice(&available[..i]);
(true, i)
} else {
buf.extend_from_slice(available);
(false, available.len())
}
};
reader.as_mut().consume(used);
*read += used;
if done || used == 0 {
return Poll::Ready(Ok(mem::replace(read, 0)));
}
}
}

/// Returns the first index matching the `needle` in the `haystack`. `match_len` specifies how
/// many bytes from the needle were already matched during the previous lookup.
/// If we reach the end of the `haystack` with a partial match, then this is a partial match,
/// and we update the `match_len` value accordingly, even though we still return `None`.
fn search(needle: &[u8], haystack: &[u8], match_len: &mut usize) -> Option<usize> {
let haystack_len = haystack.len();
let needle_len = needle.len();
#[allow(clippy::needless_range_loop)]
for i in 0..haystack_len {
if haystack[i] == needle[*match_len] {
*match_len += 1;
if *match_len == needle_len {
return Some(i + 1);
}
} else if *match_len > 0 {
*match_len = 0;
}
}
None
}
Comment on lines +73 to +92
Copy link
Contributor

@Darksonn Darksonn May 5, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wow, this logic is really tricky. It's going to need really good tests. Probably something that uses some loops to try a lot of different situations. For example, for a few different lengths of search strings, try to split some input in all possible places and ensure that you get the same result no matter how the input is split.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed but I am testing it in our app already and I didnt catch a bug yet so I am confident it is correct.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Famous last words. :)


impl<R: AsyncBufRead + ?Sized + Unpin> Future for ReadUntilSlice<'_, '_, R> {
type Output = io::Result<usize>;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let me = self.project();
read_until_slice_internal(Pin::new(*me.reader), cx, me.delimiter, me.buf, me.read)
}
}

#[cfg(test)]
mod tests {
use super::search;

#[test]
fn search_test() {
let haystack = b"123abc456\0\xffabc\n";
let mut match_len = 0;

assert_eq!(search(b"ab", haystack, &mut match_len), Some(5));
assert_eq!(match_len, 2);
match_len = 0;
assert_eq!(search(&[0xff], haystack, &mut match_len), Some(11));
assert_eq!(match_len, 1);
match_len = 0;
assert_eq!(search(b"\n", haystack, &mut match_len), Some(15));
assert_eq!(match_len, 1);
match_len = 0;
assert_eq!(search(b"\r", haystack, &mut match_len), None);
assert_eq!(match_len, 0);
}

#[test]
fn split_needle_test() {
let haystack1 = b"123abc\r";
let haystack2 = b"\n987gfd";
let mut match_len = 0;

assert_eq!(search(b"\r\n", haystack1, &mut match_len), None);
assert_eq!(match_len, 1);
assert_eq!(search(b"\r\n", haystack2, &mut match_len), Some(1));
assert_eq!(match_len, 2);
}

#[test]
fn invalid_needle_test() {
let haystack1 = b"123abc\r";
let haystack2 = b"a\n987gfd";
let mut match_len = 0;

assert_eq!(search(b"\r\n", haystack1, &mut match_len), None);
assert_eq!(match_len, 1);
assert_eq!(search(b"\r\n", haystack2, &mut match_len), None);
assert_eq!(match_len, 0);
}

#[test]
fn small_haystack_test() {
let haystack = b"\r";
let mut match_len = 0;

assert_eq!(search(b"\r\n", haystack, &mut match_len), None);
assert_eq!(match_len, 1);
}
}
Loading
Loading