From aca53acddc200ff53350760d99966fb0b87afaa5 Mon Sep 17 00:00:00 2001 From: Martin Thomson Date: Wed, 23 Oct 2024 11:49:33 +1100 Subject: [PATCH 01/16] First stage of streaming implementation --- bhttp-convert/Cargo.toml | 2 +- bhttp/Cargo.toml | 14 +-- bhttp/src/err.rs | 12 +- bhttp/src/lib.rs | 73 ++++------- bhttp/src/parse.rs | 23 ++-- bhttp/src/rw.rs | 58 ++++----- bhttp/src/stream/mod.rs | 233 ++++++++++++++++++++++++++++++++++++ bhttp/tests/test.rs | 2 +- ohttp-client-cli/Cargo.toml | 2 +- ohttp-client/Cargo.toml | 2 +- ohttp-server/Cargo.toml | 2 +- ohttp/build.rs | 20 ++-- 12 files changed, 307 insertions(+), 136 deletions(-) create mode 100644 bhttp/src/stream/mod.rs diff --git a/bhttp-convert/Cargo.toml b/bhttp-convert/Cargo.toml index 8a82101..47e11f1 100644 --- a/bhttp-convert/Cargo.toml +++ b/bhttp-convert/Cargo.toml @@ -9,4 +9,4 @@ structopt = "0.3" [dependencies.bhttp] path= "../bhttp" -features = ["bhttp", "http"] +features = ["http"] diff --git a/bhttp/Cargo.toml b/bhttp/Cargo.toml index 8c89536..14aff2b 100644 --- a/bhttp/Cargo.toml +++ b/bhttp/Cargo.toml @@ -9,17 +9,15 @@ description = "Binary HTTP messages (RFC 9292)" repository = "https://github.com/martinthomson/ohttp" [features] -default = ["bhttp"] -bhttp = ["read-bhttp", "write-bhttp"] -http = ["read-http", "write-http"] -read-bhttp = [] -write-bhttp = [] -read-http = ["url"] -write-http = [] +default = ["stream"] +http = ["url"] +stream = ["futures", "pin-project"] [dependencies] +futures = {version = "0.3", optional = true} +pin-project = {version = "1.1", optional = true} thiserror = "1" url = {version = "2", optional = true} [dev-dependencies] -hex = "0.4" +hex = "0.4" \ No newline at end of file diff --git a/bhttp/src/err.rs b/bhttp/src/err.rs index 19d5455..3a457da 100644 --- a/bhttp/src/err.rs +++ b/bhttp/src/err.rs @@ -1,6 +1,4 @@ -use thiserror::Error; - -#[derive(Error, Debug)] +#[derive(thiserror::Error, Debug)] pub enum Error { #[error("a request used the CONNECT method")] ConnectUnsupported, @@ -34,14 +32,8 @@ pub enum Error { #[error("a message included the Upgrade field")] UpgradeUnsupported, #[error("a URL could not be parsed into components: {0}")] - #[cfg(feature = "read-http")] + #[cfg(feature = "http")] UrlParse(#[from] url::ParseError), } -#[cfg(any( - feature = "read-http", - feature = "write-http", - feature = "read-bhttp", - feature = "write-bhttp" -))] pub type Res = Result; diff --git a/bhttp/src/lib.rs b/bhttp/src/lib.rs index 3c8fbde..f92b2c4 100644 --- a/bhttp/src/lib.rs +++ b/bhttp/src/lib.rs @@ -1,46 +1,26 @@ #![deny(warnings, clippy::pedantic)] #![allow(clippy::missing_errors_doc)] // Too lazy to document these. 
-#[cfg(feature = "read-bhttp")] -use std::convert::TryFrom; -#[cfg(any( - feature = "read-http", - feature = "write-http", - feature = "read-bhttp", - feature = "write-bhttp" -))] -use std::io; - -#[cfg(feature = "read-http")] +use std::{borrow::BorrowMut, io}; + +#[cfg(feature = "http")] use url::Url; mod err; mod parse; -#[cfg(any(feature = "read-bhttp", feature = "write-bhttp"))] mod rw; - -#[cfg(any(feature = "read-http", feature = "read-bhttp",))] -use std::borrow::BorrowMut; +#[cfg(feature = "stream")] +pub mod stream; pub use err::Error; -#[cfg(any( - feature = "read-http", - feature = "write-http", - feature = "read-bhttp", - feature = "write-bhttp" -))] use err::Res; -#[cfg(feature = "read-http")] +#[cfg(feature = "http")] use parse::{downcase, is_ows, read_line, split_at, COLON, SEMICOLON, SLASH, SP}; use parse::{index_of, trim_ows, COMMA}; -#[cfg(feature = "read-bhttp")] -use rw::{read_varint, read_vec}; -#[cfg(feature = "write-bhttp")] -use rw::{write_len, write_varint, write_vec}; +use rw::{read_varint, read_vec, write_len, write_varint, write_vec}; -#[cfg(feature = "read-http")] +#[cfg(feature = "http")] const CONTENT_LENGTH: &[u8] = b"content-length"; -#[cfg(feature = "read-bhttp")] const COOKIE: &[u8] = b"cookie"; const TRANSFER_ENCODING: &[u8] = b"transfer-encoding"; const CHUNKED: &[u8] = b"chunked"; @@ -93,7 +73,6 @@ impl ReadSeek for io::Cursor where T: AsRef<[u8]> {} impl ReadSeek for io::BufReader where T: io::Read + io::Seek {} #[derive(Clone, Copy, Debug, PartialEq, Eq)] -#[cfg(any(feature = "read-bhttp", feature = "write-bhttp"))] pub enum Mode { KnownLength, IndeterminateLength, @@ -120,7 +99,7 @@ impl Field { &self.value } - #[cfg(feature = "write-http")] + #[cfg(feature = "http")] pub fn write_http(&self, w: &mut impl io::Write) -> Res<()> { w.write_all(&self.name)?; w.write_all(b": ")?; @@ -129,14 +108,13 @@ impl Field { Ok(()) } - #[cfg(feature = "write-bhttp")] pub fn write_bhttp(&self, w: &mut impl io::Write) -> Res<()> { write_vec(&self.name, w)?; write_vec(&self.value, w)?; Ok(()) } - #[cfg(feature = "read-http")] + #[cfg(feature = "http")] pub fn obs_fold(&mut self, extra: &[u8]) { self.value.push(SP); self.value.extend(trim_ows(extra)); @@ -192,7 +170,7 @@ impl FieldSection { /// As required by the HTTP specification, remove the Connection header /// field, everything it refers to, and a few extra fields. 
- #[cfg(feature = "read-http")] + #[cfg(feature = "http")] fn strip_connection_headers(&mut self) { const CONNECTION: &[u8] = b"connection"; const PROXY_CONNECTION: &[u8] = b"proxy-connection"; @@ -232,7 +210,7 @@ impl FieldSection { }); } - #[cfg(feature = "read-http")] + #[cfg(feature = "http")] fn parse_line(fields: &mut Vec, line: Vec) -> Res<()> { // obs-fold is helpful in specs, so support it here too let f = if is_ows(line[0]) { @@ -251,7 +229,7 @@ impl FieldSection { Ok(()) } - #[cfg(feature = "read-http")] + #[cfg(feature = "http")] pub fn read_http(r: &mut T) -> Res where T: BorrowMut + ?Sized, @@ -267,7 +245,6 @@ impl FieldSection { } } - #[cfg(feature = "read-bhttp")] fn read_bhttp_fields(terminator: bool, r: &mut T) -> Res> where T: BorrowMut + ?Sized, @@ -302,7 +279,6 @@ impl FieldSection { } } - #[cfg(feature = "read-bhttp")] pub fn read_bhttp(mode: Mode, r: &mut T) -> Res where T: BorrowMut + ?Sized, @@ -320,7 +296,6 @@ impl FieldSection { Ok(Self(fields)) } - #[cfg(feature = "write-bhttp")] fn write_bhttp_headers(&self, w: &mut impl io::Write) -> Res<()> { for f in &self.0 { f.write_bhttp(w)?; @@ -328,7 +303,6 @@ impl FieldSection { Ok(()) } - #[cfg(feature = "write-bhttp")] pub fn write_bhttp(&self, mode: Mode, w: &mut impl io::Write) -> Res<()> { if mode == Mode::KnownLength { let mut buf = Vec::new(); @@ -341,7 +315,7 @@ impl FieldSection { Ok(()) } - #[cfg(feature = "write-http")] + #[cfg(feature = "http")] pub fn write_http(&self, w: &mut impl io::Write) -> Res<()> { for f in &self.0 { f.write_http(w)?; @@ -420,7 +394,7 @@ impl ControlData { } } - #[cfg(feature = "read-http")] + #[cfg(feature = "http")] pub fn read_http(line: Vec) -> Res { // request-line = method SP request-target SP HTTP-version // status-line = HTTP-version SP status-code SP [reason-phrase] @@ -467,7 +441,6 @@ impl ControlData { } } - #[cfg(feature = "read-bhttp")] pub fn read_bhttp(request: bool, r: &mut T) -> Res where T: BorrowMut + ?Sized, @@ -493,7 +466,6 @@ impl ControlData { } /// If this is an informational response. 
- #[cfg(any(feature = "read-bhttp", feature = "read-http"))] #[must_use] fn informational(&self) -> Option { match self { @@ -502,7 +474,6 @@ impl ControlData { } } - #[cfg(feature = "write-bhttp")] #[must_use] fn code(&self, mode: Mode) -> u64 { match (self, mode) { @@ -513,7 +484,6 @@ impl ControlData { } } - #[cfg(feature = "write-bhttp")] pub fn write_bhttp(&self, w: &mut impl io::Write) -> Res<()> { match self { Self::Request { @@ -532,7 +502,7 @@ impl ControlData { Ok(()) } - #[cfg(feature = "write-http")] + #[cfg(feature = "http")] pub fn write_http(&self, w: &mut impl io::Write) -> Res<()> { match self { Self::Request { @@ -581,7 +551,6 @@ impl InformationalResponse { &self.fields } - #[cfg(feature = "write-bhttp")] fn write_bhttp(&self, mode: Mode, w: &mut impl io::Write) -> Res<()> { write_varint(self.status.code(), w)?; self.fields.write_bhttp(mode, w)?; @@ -662,7 +631,7 @@ impl Message { &self.trailer } - #[cfg(feature = "read-http")] + #[cfg(feature = "http")] fn read_chunked(r: &mut T) -> Res> where T: BorrowMut + ?Sized, @@ -686,7 +655,7 @@ impl Message { } } - #[cfg(feature = "read-http")] + #[cfg(feature = "http")] #[allow(clippy::read_zero_byte_vec)] // https://github.com/rust-lang/rust-clippy/issues/9274 pub fn read_http(r: &mut T) -> Res where @@ -741,7 +710,7 @@ impl Message { }) } - #[cfg(feature = "write-http")] + #[cfg(feature = "http")] pub fn write_http(&self, w: &mut impl io::Write) -> Res<()> { for info in &self.informational { ControlData::Response(info.status()).write_http(w)?; @@ -770,7 +739,6 @@ impl Message { } /// Read a BHTTP message. - #[cfg(feature = "read-bhttp")] pub fn read_bhttp(r: &mut T) -> Res where T: BorrowMut + ?Sized, @@ -815,7 +783,6 @@ impl Message { }) } - #[cfg(feature = "write-bhttp")] pub fn write_bhttp(&self, mode: Mode, w: &mut impl io::Write) -> Res<()> { write_varint(self.control.code(mode), w)?; for info in &self.informational { @@ -833,7 +800,7 @@ impl Message { } } -#[cfg(feature = "write-http")] +#[cfg(feature = "http")] impl std::fmt::Debug for Message { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> { let mut buf = Vec::new(); diff --git a/bhttp/src/parse.rs b/bhttp/src/parse.rs index ee52493..06165fc 100644 --- a/bhttp/src/parse.rs +++ b/bhttp/src/parse.rs @@ -1,20 +1,21 @@ -#[cfg(feature = "read-http")] -use crate::{Error, ReadSeek, Res}; -#[cfg(feature = "read-http")] +#[cfg(feature = "http")] use std::borrow::BorrowMut; +#[cfg(feature = "http")] +use crate::{Error, ReadSeek, Res}; + pub const HTAB: u8 = 0x09; -#[cfg(feature = "read-http")] +#[cfg(feature = "http")] pub const NL: u8 = 0x0a; -#[cfg(feature = "read-http")] +#[cfg(feature = "http")] pub const CR: u8 = 0x0d; pub const SP: u8 = 0x20; pub const COMMA: u8 = 0x2c; -#[cfg(feature = "read-http")] +#[cfg(feature = "http")] pub const SLASH: u8 = 0x2f; -#[cfg(feature = "read-http")] +#[cfg(feature = "http")] pub const COLON: u8 = 0x3a; -#[cfg(feature = "read-http")] +#[cfg(feature = "http")] pub const SEMICOLON: u8 = 0x3b; pub fn is_ows(x: u8) -> bool { @@ -34,7 +35,7 @@ pub fn trim_ows(v: &[u8]) -> &[u8] { &v[..0] } -#[cfg(feature = "read-http")] +#[cfg(feature = "http")] pub fn downcase(n: &mut [u8]) { for i in n { if *i >= 0x41 && *i <= 0x5a { @@ -52,7 +53,7 @@ pub fn index_of(v: u8, line: &[u8]) -> Option { None } -#[cfg(feature = "read-http")] +#[cfg(feature = "http")] pub fn split_at(v: u8, mut line: Vec) -> Option<(Vec, Vec)> { index_of(v, &line).map(|i| { let tail = line.split_off(i + 1); @@ -61,7 +62,7 @@ pub fn split_at(v: u8, 
mut line: Vec) -> Option<(Vec, Vec)> { }) } -#[cfg(feature = "read-http")] +#[cfg(feature = "http")] pub fn read_line(r: &mut T) -> Res> where T: BorrowMut + ?Sized, diff --git a/bhttp/src/rw.rs b/bhttp/src/rw.rs index 92009ed..fa7c717 100644 --- a/bhttp/src/rw.rs +++ b/bhttp/src/rw.rs @@ -1,79 +1,66 @@ -#[cfg(feature = "read-bhttp")] -use std::borrow::BorrowMut; -use std::{convert::TryFrom, io}; +use std::{borrow::BorrowMut, convert::TryFrom, io}; -use crate::err::Res; -#[cfg(feature = "read-bhttp")] -use crate::{err::Error, ReadSeek}; +use crate::{ + err::{Error, Res}, + ReadSeek, +}; -#[cfg(feature = "write-bhttp")] #[allow(clippy::cast_possible_truncation)] -fn write_uint(n: u8, v: impl Into, w: &mut impl io::Write) -> Res<()> { - let v = v.into(); - assert!(n > 0 && usize::from(n) < std::mem::size_of::()); - for i in 0..n { - w.write_all(&[((v >> (8 * (n - i - 1))) & 0xff) as u8])?; - } +pub(crate) fn write_uint(v: impl Into, w: &mut impl io::Write) -> Res<()> { + let v = v.into().to_be_bytes(); + assert!((1..=std::mem::size_of::()).contains(&N)); + w.write_all(&v[8 - N..])?; Ok(()) } -#[cfg(feature = "write-bhttp")] pub fn write_varint(v: impl Into, w: &mut impl io::Write) -> Res<()> { let v = v.into(); match () { - () if v < (1 << 6) => write_uint(1, v, w), - () if v < (1 << 14) => write_uint(2, v | (1 << 14), w), - () if v < (1 << 30) => write_uint(4, v | (2 << 30), w), - () if v < (1 << 62) => write_uint(8, v | (3 << 62), w), + () if v < (1 << 6) => write_uint::<1>(v, w), + () if v < (1 << 14) => write_uint::<2>(v | (1 << 14), w), + () if v < (1 << 30) => write_uint::<4>(v | (2 << 30), w), + () if v < (1 << 62) => write_uint::<8>(v | (3 << 62), w), () => panic!("Varint value too large"), } } -#[cfg(feature = "write-bhttp")] pub fn write_len(len: usize, w: &mut impl io::Write) -> Res<()> { write_varint(u64::try_from(len).unwrap(), w) } -#[cfg(feature = "write-bhttp")] pub fn write_vec(v: &[u8], w: &mut impl io::Write) -> Res<()> { write_len(v.len(), w)?; w.write_all(v)?; Ok(()) } -#[cfg(feature = "read-bhttp")] -fn read_uint(n: usize, r: &mut T) -> Res> +fn read_uint(r: &mut T) -> Res> where T: BorrowMut + ?Sized, R: ReadSeek + ?Sized, { - let mut buf = [0; 7]; - let count = r.borrow_mut().read(&mut buf[..n])?; + let mut buf = [0; 8]; + let count = r.borrow_mut().read(&mut buf[(8 - N)..])?; if count == 0 { Ok(None) - } else if count < n { + } else if count < N { Err(Error::Truncated) } else { - let mut v = 0; - for i in &buf[..n] { - v = (v << 8) | u64::from(*i); - } - Ok(Some(v)) + Ok(Some(u64::from_be_bytes(buf))) } } -#[cfg(feature = "read-bhttp")] pub fn read_varint(r: &mut T) -> Res> where T: BorrowMut + ?Sized, R: ReadSeek + ?Sized, { - if let Some(b1) = read_uint(1, r)? { + if let Some(b1) = read_uint::<_, _, 1>(r)? 
{ Ok(Some(match b1 >> 6 { 0 => b1 & 0x3f, - 1 => ((b1 & 0x3f) << 8) | read_uint(1, r)?.ok_or(Error::Truncated)?, - 2 => ((b1 & 0x3f) << 24) | read_uint(3, r)?.ok_or(Error::Truncated)?, - 3 => ((b1 & 0x3f) << 56) | read_uint(7, r)?.ok_or(Error::Truncated)?, + 1 => ((b1 & 0x3f) << 8) | read_uint::<_, _, 1>(r)?.ok_or(Error::Truncated)?, + 2 => ((b1 & 0x3f) << 24) | read_uint::<_, _, 3>(r)?.ok_or(Error::Truncated)?, + 3 => ((b1 & 0x3f) << 56) | read_uint::<_, _, 7>(r)?.ok_or(Error::Truncated)?, _ => unreachable!(), })) } else { @@ -81,7 +68,6 @@ where } } -#[cfg(feature = "read-bhttp")] pub fn read_vec(r: &mut T) -> Res>> where T: BorrowMut + ?Sized, diff --git a/bhttp/src/stream/mod.rs b/bhttp/src/stream/mod.rs new file mode 100644 index 0000000..982eda7 --- /dev/null +++ b/bhttp/src/stream/mod.rs @@ -0,0 +1,233 @@ +use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, +}; + +use futures::io::AsyncRead; + +use crate::{Error, Res}; + +#[pin_project::pin_project] +pub struct ReadUint<'a, S, const N: usize> { + /// The source of data. + src: Pin<&'a mut S>, + /// A buffer that holds the bytes that have been read so far. + v: [u8; 8], + /// A counter of the number of bytes that are already in place. + /// This starts out at `8-N`. + read: usize, +} + +impl<'a, S, const N: usize> Future for ReadUint<'a, S, N> +where + S: AsyncRead, +{ + type Output = Res; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + match this.src.as_mut().poll_read(cx, &mut this.v[*this.read..]) { + Poll::Pending => Poll::Pending, + Poll::Ready(Ok(count)) => { + if count == 0 { + return Poll::Ready(Err(Error::Truncated)); + } + *this.read += count; + if *this.read == 8 { + Poll::Ready(Ok(u64::from_be_bytes(*this.v))) + } else { + Poll::Pending + } + } + Poll::Ready(Err(e)) => Poll::Ready(Err(Error::from(e))), + } + } +} + +pub fn read_uint(src: &mut S) -> ReadUint<'_, S, N> { + ReadUint { + src: Pin::new(src), + v: [0; 8], + read: 8 - N, + } +} + +#[pin_project::pin_project(project = ReadVariantProj)] +pub enum ReadVarint<'a, S> { + First(Option>), + Extra1(#[pin] ReadUint<'a, S, 8>), + Extra3(#[pin] ReadUint<'a, S, 8>), + Extra7(#[pin] ReadUint<'a, S, 8>), +} + +impl<'a, S> Future for ReadVarint<'a, S> +where + S: AsyncRead, +{ + type Output = Res>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.as_mut(); + if let Self::First(src) = this.get_mut() { + let mut src = src.take().unwrap(); + let mut buf = [0; 1]; + if let Poll::Ready(Ok(c)) = src.as_mut().poll_read(cx, &mut buf[..]) { + if c == 0 { + return Poll::Ready(Ok(None)); + } + let b1 = buf[0]; + let mut v = [0; 8]; + let next = match b1 >> 6 { + 0 => return Poll::Ready(Ok(Some(u64::from(b1)))), + 1 => { + v[6] = b1 & 0x3f; + Self::Extra1(ReadUint { src, v, read: 7 }) + } + 2 => { + v[4] = b1 & 0x3f; + Self::Extra3(ReadUint { src, v, read: 5 }) + } + 3 => { + v[0] = b1 & 0x3f; + Self::Extra7(ReadUint { src, v, read: 1 }) + } + _ => unreachable!(), + }; + + self.set(next); + } + } + let extra = match self.project() { + ReadVariantProj::Extra1(s) + | ReadVariantProj::Extra3(s) + | ReadVariantProj::Extra7(s) => s.poll(cx), + ReadVariantProj::First(_) => return Poll::Pending, + }; + if let Poll::Ready(v) = extra { + Poll::Ready(v.map(Some)) + } else { + Poll::Pending + } + } +} + +pub fn read_varint(src: &mut S) -> ReadVarint<'_, S> { + ReadVarint::First(Some(Pin::new(src))) +} + +#[cfg(test)] +mod test { + use std::task::{Context, Poll}; + + use futures::{Future, FutureExt}; 
+ + use crate::{ + rw::{write_uint as sync_write_uint, write_varint as sync_write_varint}, + stream::{read_uint as stream_read_uint, read_varint as stream_read_varint}, + }; + + pub fn noop_context() -> Context<'static> { + use std::{ + ptr::null, + task::{RawWaker, RawWakerVTable, Waker}, + }; + + const fn noop_raw_waker() -> RawWaker { + unsafe fn noop_clone(_data: *const ()) -> RawWaker { + noop_raw_waker() + } + + unsafe fn noop(_data: *const ()) {} + + const NOOP_WAKER_VTABLE: RawWakerVTable = + RawWakerVTable::new(noop_clone, noop, noop, noop); + RawWaker::new(null(), &NOOP_WAKER_VTABLE) + } + + pub fn noop_waker_ref() -> &'static Waker { + struct SyncRawWaker(RawWaker); + unsafe impl Sync for SyncRawWaker {} + + static NOOP_WAKER_INSTANCE: SyncRawWaker = SyncRawWaker(noop_raw_waker()); + + // SAFETY: `Waker` is #[repr(transparent)] over its `RawWaker`. + unsafe { &*(std::ptr::addr_of!(NOOP_WAKER_INSTANCE.0).cast()) } + } + + Context::from_waker(noop_waker_ref()) + } + + fn assert_unpin(v: T) -> T { + v + } + + fn read_uint(mut buf: &[u8]) -> u64 { + println!("{buf:?}"); + let mut cx = noop_context(); + let mut fut = assert_unpin(stream_read_uint::<_, N>(&mut buf)); + let mut v = fut.poll_unpin(&mut cx); + while v.is_pending() { + v = fut.poll_unpin(&mut cx); + } + if let Poll::Ready(Ok(v)) = v { + v + } else { + panic!("v is not OK: {v:?}"); + } + } + + #[test] + fn read_uint_values() { + macro_rules! validate_uint_range { + (@ $n:expr) => { + let m = u64::MAX >> (64 - 8 * $n); + for v in [0, 1, m] { + println!("{n} byte encoding of 0x{v:x}", n = $n); + let mut buf = Vec::with_capacity($n); + sync_write_uint::<$n>(v, &mut buf).unwrap(); + assert_eq!(v, read_uint::<$n>(&buf[..])); + } + }; + ($($n:expr),+ $(,)?) => { + $( + validate_uint_range!(@ $n); + )+ + } + } + validate_uint_range!(1, 2, 3, 4, 5, 6, 7, 8); + } + + fn read_varint(mut buf: &[u8]) -> u64 { + let mut cx = noop_context(); + let mut fut = assert_unpin(stream_read_varint(&mut buf)); + let mut v = fut.poll_unpin(&mut cx); + while v.is_pending() { + v = fut.poll_unpin(&mut cx); + } + if let Poll::Ready(Ok(Some(v))) = v { + v + } else { + panic!("v is not OK: {v:?}"); + } + } + + #[test] + fn read_varint_values() { + for i in [ + 0, + 1, + 63, + 64, + (1 << 14) - 1, + 1 << 14, + (1 << 30) - 1, + 1 << 30, + (1 << 62) - 1, + ] { + let mut buf = Vec::new(); + sync_write_varint(i, &mut buf).unwrap(); + assert_eq!(i, read_varint(&buf[..])); + } + } +} diff --git a/bhttp/tests/test.rs b/bhttp/tests/test.rs index c6729c6..9ed9731 100644 --- a/bhttp/tests/test.rs +++ b/bhttp/tests/test.rs @@ -1,5 +1,5 @@ // Rather than grapple with #[cfg(...)] for every variable and import. 
-#![cfg(all(feature = "http", feature = "bhttp"))] +#![cfg(feature = "http")] use std::{io::Cursor, mem::drop}; diff --git a/ohttp-client-cli/Cargo.toml b/ohttp-client-cli/Cargo.toml index 4f530f6..f40b198 100644 --- a/ohttp-client-cli/Cargo.toml +++ b/ohttp-client-cli/Cargo.toml @@ -15,7 +15,7 @@ hex = "0.4" [dependencies.bhttp] path= "../bhttp" -features = ["bhttp", "http"] +features = ["http"] [dependencies.ohttp] path= "../ohttp" diff --git a/ohttp-client/Cargo.toml b/ohttp-client/Cargo.toml index 7677608..1fa27fc 100644 --- a/ohttp-client/Cargo.toml +++ b/ohttp-client/Cargo.toml @@ -19,7 +19,7 @@ tokio = { version = "1", features = ["full"] } [dependencies.bhttp] path= "../bhttp" -features = ["bhttp", "http"] +features = ["http"] [dependencies.ohttp] path= "../ohttp" diff --git a/ohttp-server/Cargo.toml b/ohttp-server/Cargo.toml index 026df48..87b595e 100644 --- a/ohttp-server/Cargo.toml +++ b/ohttp-server/Cargo.toml @@ -18,7 +18,7 @@ warp = { version = "0.3", features = ["tls"] } [dependencies.bhttp] path= "../bhttp" -features = ["bhttp", "write-http"] +features = ["http"] [dependencies.ohttp] path= "../ohttp" diff --git a/ohttp/build.rs b/ohttp/build.rs index 1c01e3f..312cce2 100644 --- a/ohttp/build.rs +++ b/ohttp/build.rs @@ -8,8 +8,6 @@ #[cfg(feature = "nss")] mod nss { - use bindgen::Builder; - use serde_derive::Deserialize; use std::{ collections::HashMap, env, fs, @@ -17,6 +15,9 @@ mod nss { process::Command, }; + use bindgen::Builder; + use serde_derive::Deserialize; + const BINDINGS_DIR: &str = "bindings"; const BINDINGS_CONFIG: &str = "bindings.toml"; @@ -114,7 +115,6 @@ mod nss { let mut build_nss = vec![ String::from("./build.sh"), String::from("-Ddisable_tests=1"), - String::from("-Denable_draft_hpke=1"), ]; if is_debug() { build_nss.push(String::from("--static")); @@ -191,16 +191,8 @@ mod nss { } fn static_link(nsslibdir: &Path, use_static_softoken: bool, use_static_nspr: bool) { - let mut static_libs = vec![ - "certdb", - "certhi", - "cryptohi", - "nss_static", - "nssb", - "nssdev", - "nsspki", - "nssutil", - ]; + // The ordering of these libraries is critical for the linker. 
+ let mut static_libs = vec!["cryptohi", "nss_static"]; let mut dynamic_libs = vec![]; if use_static_softoken { @@ -211,6 +203,8 @@ mod nss { static_libs.push("pk11wrap"); } + static_libs.extend_from_slice(&["nsspki", "nssdev", "nssb", "certhi", "certdb", "nssutil"]); + if use_static_nspr { static_libs.append(&mut nspr_libs()); } else { From a9d76f7f1b4603b956f54bc76b36030c0a6afdb5 Mon Sep 17 00:00:00 2001 From: Martin Thomson Date: Wed, 23 Oct 2024 16:15:06 +1100 Subject: [PATCH 02/16] Adding vector reading capabilities --- bhttp/src/err.rs | 3 + bhttp/src/stream/context.rs | 62 +++++++++ bhttp/src/stream/int.rs | 261 ++++++++++++++++++++++++++++++++++++ bhttp/src/stream/mod.rs | 236 +------------------------------- bhttp/src/stream/vec.rs | 229 +++++++++++++++++++++++++++++++ 5 files changed, 559 insertions(+), 232 deletions(-) create mode 100644 bhttp/src/stream/context.rs create mode 100644 bhttp/src/stream/int.rs create mode 100644 bhttp/src/stream/vec.rs diff --git a/bhttp/src/err.rs b/bhttp/src/err.rs index 3a457da..e53e0a1 100644 --- a/bhttp/src/err.rs +++ b/bhttp/src/err.rs @@ -19,6 +19,9 @@ pub enum Error { InvalidStatus, #[error("IO error {0}")] Io(#[from] std::io::Error), + #[cfg(feature = "stream")] + #[error("the size of a vector exceeded the limit that was set")] + LimitExceeded, #[error("a field or line was missing a necessary character 0x{0:x}")] Missing(u8), #[error("a URL was missing a key component")] diff --git a/bhttp/src/stream/context.rs b/bhttp/src/stream/context.rs new file mode 100644 index 0000000..c6987b0 --- /dev/null +++ b/bhttp/src/stream/context.rs @@ -0,0 +1,62 @@ +use std::{ + future::Future, + task::{Context, Poll}, +}; + +use futures::FutureExt; + +fn noop_context() -> Context<'static> { + use std::{ + ptr::null, + task::{RawWaker, RawWakerVTable, Waker}, + }; + + const fn noop_raw_waker() -> RawWaker { + unsafe fn noop_clone(_data: *const ()) -> RawWaker { + noop_raw_waker() + } + + unsafe fn noop(_data: *const ()) {} + + const NOOP_WAKER_VTABLE: RawWakerVTable = RawWakerVTable::new(noop_clone, noop, noop, noop); + RawWaker::new(null(), &NOOP_WAKER_VTABLE) + } + + pub fn noop_waker_ref() -> &'static Waker { + struct SyncRawWaker(RawWaker); + unsafe impl Sync for SyncRawWaker {} + + static NOOP_WAKER_INSTANCE: SyncRawWaker = SyncRawWaker(noop_raw_waker()); + + // SAFETY: `Waker` is #[repr(transparent)] over its `RawWaker`. + unsafe { &*(std::ptr::addr_of!(NOOP_WAKER_INSTANCE.0).cast()) } + } + + Context::from_waker(noop_waker_ref()) +} + +fn assert_unpin(v: F) -> F { + v +} + +/// Drives the given future (`f`) until it resolves. +/// Executes the indicated function (`p`) each time the +/// poll returned `Poll::Pending`. +pub fn sync_resolve_with(f: F, p: P) -> F::Output { + let mut cx = noop_context(); + let mut fut = assert_unpin(f); + let mut v = fut.poll_unpin(&mut cx); + while v.is_pending() { + p(&mut fut); + v = fut.poll_unpin(&mut cx); + } + if let Poll::Ready(v) = v { + v + } else { + unreachable!(); + } +} + +pub fn sync_resolve(f: F) -> F::Output { + sync_resolve_with(f, |_| {}) +} diff --git a/bhttp/src/stream/int.rs b/bhttp/src/stream/int.rs new file mode 100644 index 0000000..02b05df --- /dev/null +++ b/bhttp/src/stream/int.rs @@ -0,0 +1,261 @@ +use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, +}; + +use futures::io::AsyncRead; + +use crate::{Error, Res}; + +#[pin_project::pin_project] +pub struct ReadUint<'a, S, const N: usize> { + /// The source of data. 
+ src: Pin<&'a mut S>, + /// A buffer that holds the bytes that have been read so far. + v: [u8; 8], + /// A counter of the number of bytes that are already in place. + /// This starts out at `8-N`. + read: usize, +} + +impl<'a, S, const N: usize> ReadUint<'a, S, N> { + pub fn stream(self) -> Pin<&'a mut S> { + self.src + } +} + +impl<'a, S, const N: usize> Future for ReadUint<'a, S, N> +where + S: AsyncRead, +{ + type Output = Res; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + match this.src.as_mut().poll_read(cx, &mut this.v[*this.read..]) { + Poll::Pending => Poll::Pending, + Poll::Ready(Ok(count)) => { + if count == 0 { + return Poll::Ready(Err(Error::Truncated)); + } + *this.read += count; + if *this.read == 8 { + Poll::Ready(Ok(u64::from_be_bytes(*this.v))) + } else { + Poll::Pending + } + } + Poll::Ready(Err(e)) => Poll::Ready(Err(Error::from(e))), + } + } +} + +pub fn read_uint(src: &mut S) -> ReadUint<'_, S, N> { + ReadUint { + src: Pin::new(src), + v: [0; 8], + read: 8 - N, + } +} + +#[pin_project::pin_project(project = ReadVarintProj)] +pub enum ReadVarint<'a, S> { + // Invariant: this Option always contains Some. + First(Option>), + Extra1(#[pin] ReadUint<'a, S, 8>), + Extra3(#[pin] ReadUint<'a, S, 8>), + Extra7(#[pin] ReadUint<'a, S, 8>), +} + +impl<'a, S> ReadVarint<'a, S> { + pub fn stream(self) -> Pin<&'a mut S> { + match self { + Self::Extra1(s) | Self::Extra3(s) | Self::Extra7(s) => s.stream(), + Self::First(mut s) => s.take().unwrap(), + } + } +} + +impl<'a, S> Future for ReadVarint<'a, S> +where + S: AsyncRead, +{ + type Output = Res>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.as_mut(); + if let Self::First(ref mut src) = this.get_mut() { + let mut buf = [0; 1]; + let src_ref = src.as_mut().unwrap().as_mut(); + if let Poll::Ready(res) = src_ref.poll_read(cx, &mut buf[..]) { + match res { + Ok(0) => return Poll::Ready(Ok(None)), + Ok(_) => (), + Err(e) => return Poll::Ready(Err(Error::from(e))), + } + + let b1 = buf[0]; + let mut v = [0; 8]; + let next = match b1 >> 6 { + 0 => return Poll::Ready(Ok(Some(u64::from(b1)))), + 1 => { + let src = src.take().unwrap(); + v[6] = b1 & 0x3f; + Self::Extra1(ReadUint { src, v, read: 7 }) + } + 2 => { + let src = src.take().unwrap(); + v[4] = b1 & 0x3f; + Self::Extra3(ReadUint { src, v, read: 5 }) + } + 3 => { + let src = src.take().unwrap(); + v[0] = b1 & 0x3f; + Self::Extra7(ReadUint { src, v, read: 1 }) + } + _ => unreachable!(), + }; + + self.set(next); + } + } + let extra = match self.project() { + ReadVarintProj::Extra1(s) | ReadVarintProj::Extra3(s) | ReadVarintProj::Extra7(s) => { + s.poll(cx) + } + ReadVarintProj::First(_) => return Poll::Pending, + }; + if let Poll::Ready(v) = extra { + Poll::Ready(v.map(Some)) + } else { + Poll::Pending + } + } +} + +pub fn read_varint(src: &mut S) -> ReadVarint<'_, S> { + ReadVarint::First(Some(Pin::new(src))) +} + +#[cfg(test)] +mod test { + use crate::{ + err::Error, + rw::{write_uint as sync_write_uint, write_varint as sync_write_varint}, + stream::{ + context::sync_resolve, + int::{read_uint, read_varint}, + }, + }; + + const VARINTS: &[u64] = &[ + 0, + 1, + 63, + 64, + (1 << 14) - 1, + 1 << 14, + (1 << 30) - 1, + 1 << 30, + (1 << 62) - 1, + ]; + + #[test] + fn read_uint_values() { + macro_rules! 
validate_uint_range { + (@ $n:expr) => { + let m = u64::MAX >> (64 - 8 * $n); + for v in [0, 1, m] { + println!("{n} byte encoding of 0x{v:x}", n = $n); + let mut buf = Vec::with_capacity($n); + sync_write_uint::<$n>(v, &mut buf).unwrap(); + let mut buf_ref = &buf[..]; + let mut fut = read_uint::<_, $n>(&mut buf_ref); + assert_eq!(v, sync_resolve(&mut fut).unwrap()); + let s = fut.stream(); + assert!(s.is_empty()); + } + }; + ($($n:expr),+ $(,)?) => { + $( + validate_uint_range!(@ $n); + )+ + } + } + validate_uint_range!(1, 2, 3, 4, 5, 6, 7, 8); + } + + #[test] + fn read_uint_truncated() { + macro_rules! validate_uint_truncated { + (@ $n:expr) => { + let m = u64::MAX >> (64 - 8 * $n); + for v in [0, 1, m] { + println!("{n} byte encoding of 0x{v:x}", n = $n); + let mut buf = Vec::with_capacity($n); + sync_write_uint::<$n>(v, &mut buf).unwrap(); + for i in 1..buf.len() { + let err = sync_resolve(read_uint::<_, $n>(&mut &buf[..i])).unwrap_err(); + assert!(matches!(err, Error::Truncated)); + } + } + }; + ($($n:expr),+ $(,)?) => { + $( + validate_uint_truncated!(@ $n); + )+ + } + } + validate_uint_truncated!(1, 2, 3, 4, 5, 6, 7, 8); + } + + #[test] + fn read_varint_values() { + for &v in VARINTS { + let mut buf = Vec::new(); + sync_write_varint(v, &mut buf).unwrap(); + let mut buf_ref = &buf[..]; + let mut fut = read_varint(&mut buf_ref); + assert_eq!(Some(v), sync_resolve(&mut fut).unwrap()); + let s = fut.stream(); + assert!(s.is_empty()); + } + } + + #[test] + fn read_varint_none() { + assert!(sync_resolve(read_varint(&mut &[][..])).unwrap().is_none()); + } + + #[test] + fn read_varint_truncated() { + for &v in VARINTS { + let mut buf = Vec::new(); + sync_write_varint(v, &mut buf).unwrap(); + for i in 1..buf.len() { + let err = { + let mut buf: &[u8] = &buf[..i]; + sync_resolve(read_varint(&mut buf)) + } + .unwrap_err(); + assert!(matches!(err, Error::Truncated)); + } + } + } + + #[test] + fn read_varint_extra() { + const EXTRA: &[u8] = &[161, 2, 49]; + for &v in VARINTS { + let mut buf = Vec::new(); + sync_write_varint(v, &mut buf).unwrap(); + buf.extend_from_slice(EXTRA); + let mut buf_ref = &buf[..]; + let mut fut = read_varint(&mut buf_ref); + assert_eq!(Some(v), sync_resolve(&mut fut).unwrap()); + let s = fut.stream(); + assert_eq!(&s[..], EXTRA); + } + } +} diff --git a/bhttp/src/stream/mod.rs b/bhttp/src/stream/mod.rs index 982eda7..f00b008 100644 --- a/bhttp/src/stream/mod.rs +++ b/bhttp/src/stream/mod.rs @@ -1,233 +1,5 @@ -use std::{ - future::Future, - pin::Pin, - task::{Context, Poll}, -}; - -use futures::io::AsyncRead; - -use crate::{Error, Res}; - -#[pin_project::pin_project] -pub struct ReadUint<'a, S, const N: usize> { - /// The source of data. - src: Pin<&'a mut S>, - /// A buffer that holds the bytes that have been read so far. - v: [u8; 8], - /// A counter of the number of bytes that are already in place. - /// This starts out at `8-N`. 
- read: usize, -} - -impl<'a, S, const N: usize> Future for ReadUint<'a, S, N> -where - S: AsyncRead, -{ - type Output = Res; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let this = self.project(); - match this.src.as_mut().poll_read(cx, &mut this.v[*this.read..]) { - Poll::Pending => Poll::Pending, - Poll::Ready(Ok(count)) => { - if count == 0 { - return Poll::Ready(Err(Error::Truncated)); - } - *this.read += count; - if *this.read == 8 { - Poll::Ready(Ok(u64::from_be_bytes(*this.v))) - } else { - Poll::Pending - } - } - Poll::Ready(Err(e)) => Poll::Ready(Err(Error::from(e))), - } - } -} - -pub fn read_uint(src: &mut S) -> ReadUint<'_, S, N> { - ReadUint { - src: Pin::new(src), - v: [0; 8], - read: 8 - N, - } -} - -#[pin_project::pin_project(project = ReadVariantProj)] -pub enum ReadVarint<'a, S> { - First(Option>), - Extra1(#[pin] ReadUint<'a, S, 8>), - Extra3(#[pin] ReadUint<'a, S, 8>), - Extra7(#[pin] ReadUint<'a, S, 8>), -} - -impl<'a, S> Future for ReadVarint<'a, S> -where - S: AsyncRead, -{ - type Output = Res>; - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let this = self.as_mut(); - if let Self::First(src) = this.get_mut() { - let mut src = src.take().unwrap(); - let mut buf = [0; 1]; - if let Poll::Ready(Ok(c)) = src.as_mut().poll_read(cx, &mut buf[..]) { - if c == 0 { - return Poll::Ready(Ok(None)); - } - let b1 = buf[0]; - let mut v = [0; 8]; - let next = match b1 >> 6 { - 0 => return Poll::Ready(Ok(Some(u64::from(b1)))), - 1 => { - v[6] = b1 & 0x3f; - Self::Extra1(ReadUint { src, v, read: 7 }) - } - 2 => { - v[4] = b1 & 0x3f; - Self::Extra3(ReadUint { src, v, read: 5 }) - } - 3 => { - v[0] = b1 & 0x3f; - Self::Extra7(ReadUint { src, v, read: 1 }) - } - _ => unreachable!(), - }; - - self.set(next); - } - } - let extra = match self.project() { - ReadVariantProj::Extra1(s) - | ReadVariantProj::Extra3(s) - | ReadVariantProj::Extra7(s) => s.poll(cx), - ReadVariantProj::First(_) => return Poll::Pending, - }; - if let Poll::Ready(v) = extra { - Poll::Ready(v.map(Some)) - } else { - Poll::Pending - } - } -} - -pub fn read_varint(src: &mut S) -> ReadVarint<'_, S> { - ReadVarint::First(Some(Pin::new(src))) -} - +#![allow(dead_code)] // TODO #[cfg(test)] -mod test { - use std::task::{Context, Poll}; - - use futures::{Future, FutureExt}; - - use crate::{ - rw::{write_uint as sync_write_uint, write_varint as sync_write_varint}, - stream::{read_uint as stream_read_uint, read_varint as stream_read_varint}, - }; - - pub fn noop_context() -> Context<'static> { - use std::{ - ptr::null, - task::{RawWaker, RawWakerVTable, Waker}, - }; - - const fn noop_raw_waker() -> RawWaker { - unsafe fn noop_clone(_data: *const ()) -> RawWaker { - noop_raw_waker() - } - - unsafe fn noop(_data: *const ()) {} - - const NOOP_WAKER_VTABLE: RawWakerVTable = - RawWakerVTable::new(noop_clone, noop, noop, noop); - RawWaker::new(null(), &NOOP_WAKER_VTABLE) - } - - pub fn noop_waker_ref() -> &'static Waker { - struct SyncRawWaker(RawWaker); - unsafe impl Sync for SyncRawWaker {} - - static NOOP_WAKER_INSTANCE: SyncRawWaker = SyncRawWaker(noop_raw_waker()); - - // SAFETY: `Waker` is #[repr(transparent)] over its `RawWaker`. 
- unsafe { &*(std::ptr::addr_of!(NOOP_WAKER_INSTANCE.0).cast()) } - } - - Context::from_waker(noop_waker_ref()) - } - - fn assert_unpin(v: T) -> T { - v - } - - fn read_uint(mut buf: &[u8]) -> u64 { - println!("{buf:?}"); - let mut cx = noop_context(); - let mut fut = assert_unpin(stream_read_uint::<_, N>(&mut buf)); - let mut v = fut.poll_unpin(&mut cx); - while v.is_pending() { - v = fut.poll_unpin(&mut cx); - } - if let Poll::Ready(Ok(v)) = v { - v - } else { - panic!("v is not OK: {v:?}"); - } - } - - #[test] - fn read_uint_values() { - macro_rules! validate_uint_range { - (@ $n:expr) => { - let m = u64::MAX >> (64 - 8 * $n); - for v in [0, 1, m] { - println!("{n} byte encoding of 0x{v:x}", n = $n); - let mut buf = Vec::with_capacity($n); - sync_write_uint::<$n>(v, &mut buf).unwrap(); - assert_eq!(v, read_uint::<$n>(&buf[..])); - } - }; - ($($n:expr),+ $(,)?) => { - $( - validate_uint_range!(@ $n); - )+ - } - } - validate_uint_range!(1, 2, 3, 4, 5, 6, 7, 8); - } - - fn read_varint(mut buf: &[u8]) -> u64 { - let mut cx = noop_context(); - let mut fut = assert_unpin(stream_read_varint(&mut buf)); - let mut v = fut.poll_unpin(&mut cx); - while v.is_pending() { - v = fut.poll_unpin(&mut cx); - } - if let Poll::Ready(Ok(Some(v))) = v { - v - } else { - panic!("v is not OK: {v:?}"); - } - } - - #[test] - fn read_varint_values() { - for i in [ - 0, - 1, - 63, - 64, - (1 << 14) - 1, - 1 << 14, - (1 << 30) - 1, - 1 << 30, - (1 << 62) - 1, - ] { - let mut buf = Vec::new(); - sync_write_varint(i, &mut buf).unwrap(); - assert_eq!(i, read_varint(&buf[..])); - } - } -} +mod context; +mod int; +mod vec; diff --git a/bhttp/src/stream/vec.rs b/bhttp/src/stream/vec.rs new file mode 100644 index 0000000..8b02474 --- /dev/null +++ b/bhttp/src/stream/vec.rs @@ -0,0 +1,229 @@ +use std::{ + future::Future, + mem, + pin::Pin, + task::{Context, Poll}, +}; + +use futures::{io::AsyncRead, FutureExt}; + +use super::int::{read_varint, ReadVarint}; +use crate::{Error, Res}; + +#[pin_project::pin_project(project = ReadVecProj)] +#[allow(clippy::module_name_repetitions)] +pub enum ReadVec<'a, S> { + // Invariant: This Option is always Some. + ReadLen { + src: Option>, + cap: u64, + }, + ReadBody { + src: Pin<&'a mut S>, + buf: Vec, + remaining: usize, + }, +} + +impl<'a, S> ReadVec<'a, S> { + /// # Panics + /// If `limit` is more than `usize::MAX` or + /// if this is called after the length is read. + fn limit(&mut self, limit: u64) { + usize::try_from(limit).expect("cannot set a limit larger than usize::MAX"); + if let Self::ReadLen { ref mut cap, .. } = self { + *cap = limit; + } else { + panic!("cannot set a limit once the size has been read"); + } + } + + fn stream(self) -> Pin<&'a mut S> { + match self { + Self::ReadLen { mut src, .. } => src.take().unwrap().stream(), + Self::ReadBody { src, .. } => src, + } + } +} + +impl<'a, S> Future for ReadVec<'a, S> +where + S: AsyncRead, +{ + type Output = Res>>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.as_mut(); + if let Self::ReadLen { src, cap } = this.get_mut() { + match src.as_mut().unwrap().poll_unpin(cx) { + Poll::Ready(Ok(None)) => return Poll::Ready(Ok(None)), + Poll::Ready(Ok(Some(0))) => return Poll::Ready(Ok(Some(Vec::new()))), + Poll::Ready(Ok(Some(sz))) => { + if sz > *cap { + return Poll::Ready(Err(Error::LimitExceeded)); + } + // `cap` cannot exceed min(usize::MAX, u64::MAX). 
+ let sz = usize::try_from(sz).unwrap(); + let body = Self::ReadBody { + src: src.take().unwrap().stream(), + buf: vec![0; sz], + remaining: sz, + }; + self.set(body); + } + Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), + Poll::Pending => return Poll::Pending, + } + } + + let ReadVecProj::ReadBody { + src, + buf, + remaining, + } = self.project() + else { + return Poll::Pending; + }; + + let offset = buf.len() - *remaining; + match src.as_mut().poll_read(cx, &mut buf[offset..]) { + Poll::Pending => Poll::Pending, + Poll::Ready(Err(e)) => Poll::Ready(Err(Error::from(e))), + Poll::Ready(Ok(0)) => Poll::Ready(Err(Error::Truncated)), + Poll::Ready(Ok(c)) => { + *remaining -= c; + if *remaining > 0 { + Poll::Pending + } else { + Poll::Ready(Ok(Some(mem::take(buf)))) + } + } + } + } +} + +#[allow(clippy::module_name_repetitions)] +pub fn read_vec(src: &mut S) -> ReadVec<'_, S> { + ReadVec::ReadLen { + src: Some(read_varint(src)), + cap: u64::try_from(usize::MAX).unwrap_or(u64::MAX), + } +} + +#[cfg(test)] +mod test { + + use std::{ + cmp, + fmt::Debug, + io::Result, + pin::Pin, + task::{Context, Poll}, + }; + + use futures::AsyncRead; + + use crate::{ + rw::write_varint as sync_write_varint, + stream::{ + context::{sync_resolve, sync_resolve_with}, + vec::read_vec, + }, + Error, + }; + + const FILL_VALUE: u8 = 90; + + fn fill(len: T) -> Vec + where + u64: TryFrom, + >::Error: Debug, + usize: TryFrom, + >::Error: Debug, + T: Debug + Copy, + { + let mut buf = Vec::new(); + sync_write_varint(u64::try_from(len).unwrap(), &mut buf).unwrap(); + buf.resize(buf.len() + usize::try_from(len).unwrap(), FILL_VALUE); + buf + } + + #[test] + fn read_vecs() { + for len in [0, 1, 2, 3, 64] { + let buf = fill(len); + let mut buf_ref = &buf[..]; + let mut fut = read_vec(&mut buf_ref); + if let Ok(Some(out)) = sync_resolve(&mut fut) { + assert_eq!(len, out.len()); + assert!(out.iter().all(|&v| v == FILL_VALUE)); + + assert!(fut.stream().is_empty()); + } + } + } + + #[test] + fn exceed_cap() { + const LEN: u64 = 20; + let buf = fill(LEN); + let mut buf_ref = &buf[..]; + let mut fut = read_vec(&mut buf_ref); + fut.limit(LEN - 1); + assert!(matches!(sync_resolve(&mut fut), Err(Error::LimitExceeded))); + } + + /// This class implements `AsyncRead`, but + /// always blocks after returning a fixed value. 
+ #[derive(Default)] + struct IncompleteRead<'a> { + data: &'a [u8], + consumed: usize, + } + + impl<'a> IncompleteRead<'a> { + fn new(data: &'a [u8]) -> Self { + Self { data, consumed: 0 } + } + } + + impl AsyncRead for IncompleteRead<'_> { + fn poll_read( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + let remaining = &self.data[self.consumed..]; + if remaining.is_empty() { + Poll::Pending + } else { + let copied = cmp::min(buf.len(), remaining.len()); + buf[..copied].copy_from_slice(&remaining[..copied]); + self.as_mut().consumed += copied; + Poll::Ready(std::io::Result::Ok(copied)) + } + } + } + + #[test] + #[should_panic(expected = "cannot set a limit once the size has been read")] + fn late_cap() { + let mut buf = IncompleteRead::new(&[2, 1]); + _ = sync_resolve_with(read_vec(&mut buf), |f| { + println!("pending"); + f.limit(100); + }); + } + + #[test] + #[cfg(target_pointer_width = "32")] + #[should_panic(expected = "cannot set a limit larger than usize::MAX")] + fn too_large_cap() { + const LEN: u64 = 20; + let buf = fill(LEN); + + let mut buf_ref = &buf[..]; + let mut fut = read_vec(&mut buf_ref); + fut.limit(u64::try_from(usize::MAX).unwrap() + 1); + } +} From 30b7ef53cb04338acfd18958b31bc83737e88147 Mon Sep 17 00:00:00 2001 From: Martin Thomson Date: Wed, 23 Oct 2024 16:40:58 +1100 Subject: [PATCH 03/16] Add 16-bit arch as well, I guess --- bhttp/src/stream/vec.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bhttp/src/stream/vec.rs b/bhttp/src/stream/vec.rs index 8b02474..d71040a 100644 --- a/bhttp/src/stream/vec.rs +++ b/bhttp/src/stream/vec.rs @@ -216,7 +216,7 @@ mod test { } #[test] - #[cfg(target_pointer_width = "32")] + #[cfg(any(target_pointer_width = "32", target_pointer_width = "16"))] #[should_panic(expected = "cannot set a limit larger than usize::MAX")] fn too_large_cap() { const LEN: u64 = 20; From c432a05a2b11b13a21f7b85c4b73dfcfd5766a08 Mon Sep 17 00:00:00 2001 From: Martin Thomson Date: Fri, 25 Oct 2024 10:41:22 +1100 Subject: [PATCH 04/16] Checkpoint --- bhttp/src/err.rs | 3 + bhttp/src/lib.rs | 96 +++++--- bhttp/src/stream/{context.rs => future.rs} | 62 +++-- bhttp/src/stream/int.rs | 14 +- bhttp/src/stream/mod.rs | 260 ++++++++++++++++++++- bhttp/src/stream/vec.rs | 13 +- 6 files changed, 386 insertions(+), 62 deletions(-) rename bhttp/src/stream/{context.rs => future.rs} (50%) diff --git a/bhttp/src/err.rs b/bhttp/src/err.rs index e53e0a1..eb14acc 100644 --- a/bhttp/src/err.rs +++ b/bhttp/src/err.rs @@ -17,6 +17,9 @@ pub enum Error { InvalidMode, #[error("the status code of a response needs to be in 100..=599")] InvalidStatus, + #[cfg(feature = "stream")] + #[error("a method was called when the message was in the wrong state")] + InvalidState, #[error("IO error {0}")] Io(#[from] std::io::Error), #[cfg(feature = "stream")] diff --git a/bhttp/src/lib.rs b/bhttp/src/lib.rs index f92b2c4..be299a2 100644 --- a/bhttp/src/lib.rs +++ b/bhttp/src/lib.rs @@ -25,7 +25,7 @@ const COOKIE: &[u8] = b"cookie"; const TRANSFER_ENCODING: &[u8] = b"transfer-encoding"; const CHUNKED: &[u8] = b"chunked"; -#[derive(Clone, Copy, PartialEq, Eq)] +#[derive(Clone, Copy, PartialEq, Eq, Debug)] pub struct StatusCode(u16); impl StatusCode { @@ -78,6 +78,17 @@ pub enum Mode { IndeterminateLength, } +impl TryFrom for Mode { + type Error = Error; + fn try_from(t: u64) -> Result { + match t { + 0 | 1 => Ok(Self::KnownLength), + 2 | 3 => Ok(Self::IndeterminateLength), + _ => Err(Error::InvalidMode), + } + } +} + pub struct Field { 
name: Vec, value: Vec, @@ -558,10 +569,49 @@ impl InformationalResponse { } } +pub struct Header { + control: ControlData, + fields: FieldSection, +} + +impl Header { + #[must_use] + pub fn control(&self) -> &ControlData { + &self.control + } +} + +impl From for Header { + fn from(control: ControlData) -> Self { + Self { + control, + fields: FieldSection::default(), + } + } +} + +impl From<(ControlData, FieldSection)> for Header { + fn from((control, fields): (ControlData, FieldSection)) -> Self { + Self { control, fields } + } +} + +impl std::ops::Deref for Header { + type Target = FieldSection; + fn deref(&self) -> &Self::Target { + &self.fields + } +} + +impl std::ops::DerefMut for Header { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.fields + } +} + pub struct Message { informational: Vec, - control: ControlData, - header: FieldSection, + header: Header, content: Vec, trailer: FieldSection, } @@ -571,13 +621,12 @@ impl Message { pub fn request(method: Vec, scheme: Vec, authority: Vec, path: Vec) -> Self { Self { informational: Vec::new(), - control: ControlData::Request { + header: Header::from(ControlData::Request { method, scheme, authority, path, - }, - header: FieldSection::default(), + }), content: Vec::new(), trailer: FieldSection::default(), } @@ -587,8 +636,7 @@ impl Message { pub fn response(status: StatusCode) -> Self { Self { informational: Vec::new(), - control: ControlData::Response(status), - header: FieldSection::default(), + header: Header::from(ControlData::Response(status)), content: Vec::new(), trailer: FieldSection::default(), } @@ -613,11 +661,11 @@ impl Message { #[must_use] pub fn control(&self) -> &ControlData { - &self.control + self.header.control() } #[must_use] - pub fn header(&self) -> &FieldSection { + pub fn header(&self) -> &Header { &self.header } @@ -672,20 +720,20 @@ impl Message { control = ControlData::read_http(line)?; } - let mut header = FieldSection::read_http(r)?; + let mut hfields = FieldSection::read_http(r)?; let (content, trailer) = if matches!(control.status().map(StatusCode::code), Some(204 | 304)) { // 204 and 304 have no body, no matter what Content-Length says. // Unfortunately, we can't do the same for responses to HEAD. 
(Vec::new(), FieldSection::default()) - } else if header.is_chunked() { + } else if hfields.is_chunked() { let content = Self::read_chunked(r)?; let trailer = FieldSection::read_http(r)?; (content, trailer) } else { let mut content = Vec::new(); - if let Some(cl) = header.get(CONTENT_LENGTH) { + if let Some(cl) = hfields.get(CONTENT_LENGTH) { let cl_str = String::from_utf8(Vec::from(cl))?; let cl_int = cl_str.parse::()?; if cl_int > 0 { @@ -700,11 +748,10 @@ impl Message { (content, FieldSection::default()) }; - header.strip_connection_headers(); + hfields.strip_connection_headers(); Ok(Self { informational, - control, - header, + header: Header::from((control, hfields)), content, trailer, }) @@ -716,7 +763,7 @@ impl Message { ControlData::Response(info.status()).write_http(w)?; info.fields().write_http(w)?; } - self.control.write_http(w)?; + self.header.control.write_http(w)?; if !self.content.is_empty() { if self.trailer.is_empty() { write!(w, "Content-Length: {}\r\n", self.content.len())?; @@ -746,11 +793,7 @@ impl Message { { let t = read_varint(r)?.ok_or(Error::Truncated)?; let request = t == 0 || t == 2; - let mode = match t { - 0 | 1 => Mode::KnownLength, - 2 | 3 => Mode::IndeterminateLength, - _ => return Err(Error::InvalidMode), - }; + let mode = Mode::try_from(t)?; let mut control = ControlData::read_bhttp(request, r)?; let mut informational = Vec::new(); @@ -759,7 +802,7 @@ impl Message { informational.push(InformationalResponse::new(status, fields)); control = ControlData::read_bhttp(request, r)?; } - let header = FieldSection::read_bhttp(mode, r)?; + let hfields = FieldSection::read_bhttp(mode, r)?; let mut content = read_vec(r)?.unwrap_or_default(); if mode == Mode::IndeterminateLength && !content.is_empty() { @@ -776,19 +819,18 @@ impl Message { Ok(Self { informational, - control, - header, + header: Header::from((control, hfields)), content, trailer, }) } pub fn write_bhttp(&self, mode: Mode, w: &mut impl io::Write) -> Res<()> { - write_varint(self.control.code(mode), w)?; + write_varint(self.header.control.code(mode), w)?; for info in &self.informational { info.write_bhttp(mode, w)?; } - self.control.write_bhttp(w)?; + self.header.control.write_bhttp(w)?; self.header.write_bhttp(mode, w)?; write_vec(&self.content, w)?; diff --git a/bhttp/src/stream/context.rs b/bhttp/src/stream/future.rs similarity index 50% rename from bhttp/src/stream/context.rs rename to bhttp/src/stream/future.rs index c6987b0..ee26fc9 100644 --- a/bhttp/src/stream/context.rs +++ b/bhttp/src/stream/future.rs @@ -1,9 +1,12 @@ use std::{ future::Future, + pin::{pin, Pin}, task::{Context, Poll}, }; -use futures::FutureExt; +use futures::{TryStream, TryStreamExt}; + +use crate::Error; fn noop_context() -> Context<'static> { use std::{ @@ -35,28 +38,51 @@ fn noop_context() -> Context<'static> { Context::from_waker(noop_waker_ref()) } -fn assert_unpin(v: F) -> F { - v -} - /// Drives the given future (`f`) until it resolves. /// Executes the indicated function (`p`) each time the /// poll returned `Poll::Pending`. 
-pub fn sync_resolve_with(f: F, p: P) -> F::Output { - let mut cx = noop_context(); - let mut fut = assert_unpin(f); - let mut v = fut.poll_unpin(&mut cx); - while v.is_pending() { - p(&mut fut); - v = fut.poll_unpin(&mut cx); +pub trait SyncResolve { + type Output; + + fn sync_resolve(&mut self) -> Self::Output { + self.sync_resolve_with(|_| {}) } - if let Poll::Ready(v) = v { - v - } else { - unreachable!(); + + fn sync_resolve_with)>(&mut self, p: P) -> Self::Output; +} + +impl SyncResolve for F { + type Output = F::Output; + + fn sync_resolve_with)>(&mut self, p: P) -> Self::Output { + let mut cx = noop_context(); + let mut fut = Pin::new(self); + let mut v = fut.as_mut().poll(&mut cx); + while v.is_pending() { + p(fut.as_mut()); + v = fut.as_mut().poll(&mut cx); + } + if let Poll::Ready(v) = v { + v + } else { + unreachable!(); + } } } -pub fn sync_resolve(f: F) -> F::Output { - sync_resolve_with(f, |_| {}) +pub trait SyncCollect { + type Item; + + fn sync_collect(self) -> Result, Error>; +} + +impl SyncCollect for S +where + S: TryStream, +{ + type Item = S::Ok; + + fn sync_collect(self) -> Result, Error> { + pin!(self.try_collect::>()).sync_resolve() + } } diff --git a/bhttp/src/stream/int.rs b/bhttp/src/stream/int.rs index 02b05df..460e3f1 100644 --- a/bhttp/src/stream/int.rs +++ b/bhttp/src/stream/int.rs @@ -144,7 +144,7 @@ mod test { err::Error, rw::{write_uint as sync_write_uint, write_varint as sync_write_varint}, stream::{ - context::sync_resolve, + future::SyncResolve, int::{read_uint, read_varint}, }, }; @@ -172,7 +172,7 @@ mod test { sync_write_uint::<$n>(v, &mut buf).unwrap(); let mut buf_ref = &buf[..]; let mut fut = read_uint::<_, $n>(&mut buf_ref); - assert_eq!(v, sync_resolve(&mut fut).unwrap()); + assert_eq!(v, fut.sync_resolve().unwrap()); let s = fut.stream(); assert!(s.is_empty()); } @@ -196,7 +196,7 @@ mod test { let mut buf = Vec::with_capacity($n); sync_write_uint::<$n>(v, &mut buf).unwrap(); for i in 1..buf.len() { - let err = sync_resolve(read_uint::<_, $n>(&mut &buf[..i])).unwrap_err(); + let err = read_uint::<_, $n>(&mut &buf[..i]).sync_resolve().unwrap_err(); assert!(matches!(err, Error::Truncated)); } } @@ -217,7 +217,7 @@ mod test { sync_write_varint(v, &mut buf).unwrap(); let mut buf_ref = &buf[..]; let mut fut = read_varint(&mut buf_ref); - assert_eq!(Some(v), sync_resolve(&mut fut).unwrap()); + assert_eq!(Some(v), fut.sync_resolve().unwrap()); let s = fut.stream(); assert!(s.is_empty()); } @@ -225,7 +225,7 @@ mod test { #[test] fn read_varint_none() { - assert!(sync_resolve(read_varint(&mut &[][..])).unwrap().is_none()); + assert!(read_varint(&mut &[][..]).sync_resolve().unwrap().is_none()); } #[test] @@ -236,7 +236,7 @@ mod test { for i in 1..buf.len() { let err = { let mut buf: &[u8] = &buf[..i]; - sync_resolve(read_varint(&mut buf)) + read_varint(&mut buf).sync_resolve() } .unwrap_err(); assert!(matches!(err, Error::Truncated)); @@ -253,7 +253,7 @@ mod test { buf.extend_from_slice(EXTRA); let mut buf_ref = &buf[..]; let mut fut = read_varint(&mut buf_ref); - assert_eq!(Some(v), sync_resolve(&mut fut).unwrap()); + assert_eq!(Some(v), fut.sync_resolve().unwrap()); let s = fut.stream(); assert_eq!(&s[..], EXTRA); } diff --git a/bhttp/src/stream/mod.rs b/bhttp/src/stream/mod.rs index f00b008..011e0d1 100644 --- a/bhttp/src/stream/mod.rs +++ b/bhttp/src/stream/mod.rs @@ -1,5 +1,261 @@ -#![allow(dead_code)] // TODO +#![allow(dead_code)] + +use std::{ + io::{Cursor, Result as IoResult}, + mem, + pin::Pin, + task::{Context, Poll}, +}; + +use 
futures::{stream::unfold, AsyncRead, Stream, TryStreamExt}; +use int::ReadVarint; + +use crate::{ + err::Res, + stream::{int::read_varint, vec::read_vec}, + ControlData, Error, Field, FieldSection, Header, InformationalResponse, Message, Mode, COOKIE, +}; #[cfg(test)] -mod context; +mod future; mod int; mod vec; + +trait AsyncReadControlData: Sized { + async fn async_read(request: bool, src: &mut S) -> Res; +} + +impl AsyncReadControlData for ControlData { + async fn async_read(request: bool, src: &mut S) -> Res { + let v = if request { + let method = read_vec(src).await?.ok_or(Error::Truncated)?; + let scheme = read_vec(src).await?.ok_or(Error::Truncated)?; + let authority = read_vec(src).await?.ok_or(Error::Truncated)?; + let path = read_vec(src).await?.ok_or(Error::Truncated)?; + Self::Request { + method, + scheme, + authority, + path, + } + } else { + Self::Response(crate::StatusCode::try_from( + read_varint(src).await?.ok_or(Error::Truncated)?, + )?) + }; + Ok(v) + } +} + +trait AsyncReadFieldSection: Sized { + async fn async_read(mode: Mode, src: &mut S) -> Res; +} + +impl AsyncReadFieldSection for FieldSection { + async fn async_read(mode: Mode, src: &mut S) -> Res { + let fields = if mode == Mode::KnownLength { + // Known-length fields can just be read into a buffer. + if let Some(buf) = read_vec(src).await? { + Self::read_bhttp_fields(false, &mut Cursor::new(&buf[..]))? + } else { + Vec::new() + } + } else { + // The async version needs to be implemented directly. + let mut fields: Vec = Vec::new(); + let mut cookie_index: Option = None; + loop { + if let Some(n) = read_vec(src).await? { + if n.is_empty() { + break fields; + } + let mut v = read_vec(src).await?.ok_or(Error::Truncated)?; + if n == COOKIE { + if let Some(i) = &cookie_index { + fields[*i].value.extend_from_slice(b"; "); + fields[*i].value.append(&mut v); + continue; + } + cookie_index = Some(fields.len()); + } + fields.push(Field::new(n, v)); + } else { + return Err(Error::Truncated); + } + } + }; + Ok(Self(fields)) + } +} + +enum BodyState<'a, S> { + // When reading the length, use this. + ReadLength(ReadVarint<'a, S>), + // When reading the data, track how much is left. + ReadData { + remaining: usize, + src: Pin<&'a mut S>, + }, +} + +#[pin_project::pin_project] +struct Body<'a, S> { + mode: Mode, + state: BodyState<'a, S>, +} + +impl<'a, S: AsyncRead> AsyncRead for Body<'a, S> { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + self.project().src.as_mut().poll_read(cx, buf) + } +} + +enum AsyncMessageState { + // Processing Informational responses (or before that). + Informational, + // Having obtained the control data for the header, this is it. + Header(ControlData), + // Processing the Body. + Body, + // Processing the trailer. + Trailer, +} + +struct AsyncMessage<'a, S> { + // Whether this is a request and which mode. + framing: Option<(bool, Mode)>, + state: AsyncMessageState, + src: Pin<&'a mut S>, +} + +impl<'a, S: AsyncRead> AsyncMessage<'a, S> { + /// Get the mode. This panics if the header hasn't been read yet. 
+ fn mode(&self) -> Mode { + self.framing.unwrap().1 + } + + async fn next_info(&mut self) -> Res> { + if !matches!(self.state, AsyncMessageState::Informational) { + return Ok(None); + } + + let (request, mode) = if let Some((request, mode)) = self.framing { + (request, mode) + } else { + let t = read_varint(&mut self.src).await?.ok_or(Error::Truncated)?; + let request = t == 0 || t == 2; + let mode = Mode::try_from(t)?; + self.framing = Some((request, mode)); + (request, mode) + }; + + let control = ControlData::async_read(request, &mut self.src).await?; + if let Some(status) = control.informational() { + let fields = FieldSection::async_read(mode, &mut self.src).await?; + Ok(Some(InformationalResponse::new(status, fields))) + } else { + self.state = AsyncMessageState::Header(control); + Ok(None) + } + } + + /// Produces a stream of informational responses from a fresh message. + /// Returns an empty stream if called at other times. + /// Error values on the stream indicate failures. + /// + /// There is no need to call this method to read a request, though + /// doing so is harmless. + /// + /// You can discard the stream that this function returns + /// without affecting the message. You can then either call this + /// method again to get any additional informational responses or + /// call `header()` to get the message header. + pub fn informational( + &mut self, + ) -> impl Stream> + use<'_, 'a, S> { + unfold(self, |this| async move { + this.next_info().await.transpose().map(|info| (info, this)) + }) + } + + /// This reads the header. If you have not called `informational` + /// and drained the resulting stream, this will do that for you. + pub async fn header(&mut self) -> Res
{ + if matches!(self.state, AsyncMessageState::Informational) { + // Need to scrub for errors, + // so that this can abort properly if there is one. + // The `try_any` usage is there to ensure that the stream is fully drained. + _ = self.informational().try_any(|_| async { false }).await?; + } + if matches!(self.state, AsyncMessageState::Header(_)) { + let AsyncMessageState::Header(control) = + mem::replace(&mut self.state, AsyncMessageState::Body) + else { + unreachable!(); + }; + let mode = self.mode(); + let hfields = FieldSection::async_read(mode, &mut self.src).await?; + Ok(Header::from((control, hfields))) + } else { + Err(Error::InvalidState) + } + } + + pub fn body<'s>(&'s mut self) -> Res> + where + 'a: 's, + { + if matches!(self.state, AsyncMessageState::Body) { + Ok(Body { + mode: self.mode(), + state: BodyState::ReadLength(read_varint(self.src.as_mut())), + }) + } else { + Err(Error::InvalidState) + } + } +} + +trait AsyncReadMessage: Sized { + fn async_read(src: &mut S) -> AsyncMessage<'_, S>; +} + +impl AsyncReadMessage for Message { + fn async_read(src: &mut S) -> AsyncMessage<'_, S> { + AsyncMessage { + framing: None, + state: AsyncMessageState::Informational, + src: Pin::new(src), + } + } +} + +#[cfg(test)] +mod test { + use std::pin::pin; + + use crate::{ + stream::{ + future::{SyncCollect, SyncResolve}, + AsyncReadMessage, + }, + Message, + }; + + #[test] + fn informational() { + const INFO: &[u8] = &[1, 64, 100, 0, 64, 200, 0]; + let mut buf_alias = INFO; + let mut msg = Message::async_read(&mut buf_alias); + let info = msg.informational().sync_collect().unwrap(); + assert_eq!(info.len(), 1); + let info = msg.informational().sync_collect().unwrap(); + assert!(info.is_empty()); + let hdr = pin!(msg.header()).sync_resolve().unwrap(); + assert_eq!(hdr.control().status().unwrap().code(), 200); + assert!(hdr.is_empty()); + } +} diff --git a/bhttp/src/stream/vec.rs b/bhttp/src/stream/vec.rs index d71040a..16dd433 100644 --- a/bhttp/src/stream/vec.rs +++ b/bhttp/src/stream/vec.rs @@ -125,10 +125,7 @@ mod test { use crate::{ rw::write_varint as sync_write_varint, - stream::{ - context::{sync_resolve, sync_resolve_with}, - vec::read_vec, - }, + stream::{future::SyncResolve, vec::read_vec}, Error, }; @@ -154,7 +151,7 @@ mod test { let buf = fill(len); let mut buf_ref = &buf[..]; let mut fut = read_vec(&mut buf_ref); - if let Ok(Some(out)) = sync_resolve(&mut fut) { + if let Ok(Some(out)) = fut.sync_resolve() { assert_eq!(len, out.len()); assert!(out.iter().all(|&v| v == FILL_VALUE)); @@ -170,7 +167,7 @@ mod test { let mut buf_ref = &buf[..]; let mut fut = read_vec(&mut buf_ref); fut.limit(LEN - 1); - assert!(matches!(sync_resolve(&mut fut), Err(Error::LimitExceeded))); + assert!(matches!(fut.sync_resolve(), Err(Error::LimitExceeded))); } /// This class implements `AsyncRead`, but @@ -209,9 +206,9 @@ mod test { #[should_panic(expected = "cannot set a limit once the size has been read")] fn late_cap() { let mut buf = IncompleteRead::new(&[2, 1]); - _ = sync_resolve_with(read_vec(&mut buf), |f| { + _ = read_vec(&mut buf).sync_resolve_with(|mut f| { println!("pending"); - f.limit(100); + f.as_mut().limit(100); }); } From 7d78bc7b48fd7f15f27072cd18db509b95273363 Mon Sep 17 00:00:00 2001 From: Martin Thomson Date: Mon, 28 Oct 2024 11:51:49 +1100 Subject: [PATCH 05/16] Checkpoint, fuck lifetimes --- bhttp/src/lib.rs | 153 ++++++++++++++++- bhttp/src/stream/future.rs | 28 ++- bhttp/src/stream/int.rs | 49 +++--- bhttp/src/stream/mod.rs | 342 +++++++++++++++++++++++++++++++------ 
bhttp/src/stream/vec.rs | 25 ++- 5 files changed, 495 insertions(+), 102 deletions(-) diff --git a/bhttp/src/lib.rs b/bhttp/src/lib.rs index be299a2..2082b0f 100644 --- a/bhttp/src/lib.rs +++ b/bhttp/src/lib.rs @@ -1,7 +1,11 @@ #![deny(warnings, clippy::pedantic)] #![allow(clippy::missing_errors_doc)] // Too lazy to document these. -use std::{borrow::BorrowMut, io}; +use std::{ + borrow::BorrowMut, + io, + ops::{Deref, DerefMut}, +}; #[cfg(feature = "http")] use url::Url; @@ -25,7 +29,7 @@ const COOKIE: &[u8] = b"cookie"; const TRANSFER_ENCODING: &[u8] = b"transfer-encoding"; const CHUNKED: &[u8] = b"chunked"; -#[derive(Clone, Copy, PartialEq, Eq, Debug)] +#[derive(Clone, Copy, Debug)] pub struct StatusCode(u16); impl StatusCode { @@ -68,6 +72,26 @@ impl From for u16 { } } +#[cfg(test)] +impl PartialEq for StatusCode +where + Self: TryFrom, + T: Copy, +{ + fn eq(&self, other: &T) -> bool { + StatusCode::try_from(*other).map_or(false, |o| o.0 == self.0) + } +} + +#[cfg(not(test))] +impl PartialEq for StatusCode { + fn eq(&self, other: &Self) -> bool { + self.0 == other.0 + } +} + +impl Eq for StatusCode {} + pub trait ReadSeek: io::BufRead + io::Seek {} impl ReadSeek for io::Cursor where T: AsRef<[u8]> {} impl ReadSeek for io::BufReader where T: io::Read + io::Seek {} @@ -132,6 +156,18 @@ impl Field { } } +#[cfg(test)] +impl std::fmt::Debug for Field { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> { + write!( + f, + "{n}: {v}", + n = String::from_utf8_lossy(&self.name), + v = String::from_utf8_lossy(&self.value), + ) + } +} + #[derive(Default)] pub struct FieldSection(Vec); impl FieldSection { @@ -140,15 +176,26 @@ impl FieldSection { self.0.is_empty() } + #[must_use] + pub fn len(&self) -> usize { + self.0.len() + } + /// Gets the value from the first instance of the field. #[must_use] pub fn get(&self, n: &[u8]) -> Option<&[u8]> { - for f in &self.0 { + self.get_all(n).next() + } + + /// Gets all of the values of the named field. + pub fn get_all<'a, 'b>(&'a self, n: &'b [u8]) -> impl Iterator + use<'a, 'b> { + self.0.iter().filter_map(move |f| { if &f.name[..] == n { - return Some(&f.value); + Some(&f.value[..]) + } else { + None } - } - None + }) } pub fn put(&mut self, name: impl Into>, value: impl Into>) { @@ -336,6 +383,16 @@ impl FieldSection { } } +#[cfg(test)] +impl std::fmt::Debug for FieldSection { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> { + for fv in self.fields() { + fv.fmt(f)?; + } + Ok(()) + } +} + pub enum ControlData { Request { method: Vec, @@ -541,6 +598,68 @@ impl ControlData { } } +#[cfg(test)] +impl PartialEq<(M, S, A, P)> for ControlData +where + M: AsRef<[u8]>, + S: AsRef<[u8]>, + A: AsRef<[u8]>, + P: AsRef<[u8]>, +{ + fn eq(&self, other: &(M, S, A, P)) -> bool { + match self { + Self::Request { + method, + scheme, + authority, + path, + } => { + method == other.0.as_ref() + && scheme == other.1.as_ref() + && authority == other.2.as_ref() + && path == other.3.as_ref() + } + Self::Response(_) => false, + } + } +} + +#[cfg(test)] +impl PartialEq for ControlData +where + StatusCode: TryFrom, + T: Copy, +{ + fn eq(&self, other: &T) -> bool { + match self { + Self::Request { .. 
} => false, + Self::Response(code) => code == other, + } + } +} + +#[cfg(test)] +impl std::fmt::Debug for ControlData { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> { + match self { + Self::Request { + method, + scheme, + authority, + path, + } => write!( + f, + "{m} {s}://{a}{p}", + m = String::from_utf8_lossy(method), + s = String::from_utf8_lossy(scheme), + a = String::from_utf8_lossy(authority), + p = String::from_utf8_lossy(path), + ), + Self::Response(code) => write!(f, "{code:?}"), + } + } +} + pub struct InformationalResponse { status: StatusCode, fields: FieldSection, @@ -569,6 +688,20 @@ impl InformationalResponse { } } +impl Deref for InformationalResponse { + type Target = FieldSection; + + fn deref(&self) -> &Self::Target { + &self.fields + } +} + +impl DerefMut for InformationalResponse { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.fields + } +} + pub struct Header { control: ControlData, fields: FieldSection, @@ -609,6 +742,14 @@ impl std::ops::DerefMut for Header { } } +#[cfg(test)] +impl std::fmt::Debug for Header { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> { + self.control.fmt(f)?; + self.fields.fmt(f) + } +} + pub struct Message { informational: Vec, header: Header, diff --git a/bhttp/src/stream/future.rs b/bhttp/src/stream/future.rs index ee26fc9..3eae2bb 100644 --- a/bhttp/src/stream/future.rs +++ b/bhttp/src/stream/future.rs @@ -4,7 +4,7 @@ use std::{ task::{Context, Poll}, }; -use futures::{TryStream, TryStreamExt}; +use futures::{AsyncRead, AsyncReadExt, TryStream, TryStreamExt}; use crate::Error; @@ -76,13 +76,31 @@ pub trait SyncCollect { fn sync_collect(self) -> Result, Error>; } -impl SyncCollect for S -where - S: TryStream, -{ +impl> SyncCollect for S { type Item = S::Ok; fn sync_collect(self) -> Result, Error> { pin!(self.try_collect::>()).sync_resolve() } } + +pub trait SyncRead { + fn sync_read_exact(&mut self, amount: usize) -> Vec; + fn sync_read_to_end(&mut self) -> Vec; +} + +impl SyncRead for S { + fn sync_read_exact(&mut self, amount: usize) -> Vec { + let mut buf = vec![0; amount]; + let res = self.read_exact(&mut buf[..]); + pin!(res).sync_resolve().unwrap(); + buf + } + + fn sync_read_to_end(&mut self) -> Vec { + let mut buf = Vec::new(); + let res = self.read_to_end(&mut buf); + pin!(res).sync_resolve().unwrap(); + buf + } +} diff --git a/bhttp/src/stream/int.rs b/bhttp/src/stream/int.rs index 460e3f1..5b248df 100644 --- a/bhttp/src/stream/int.rs +++ b/bhttp/src/stream/int.rs @@ -1,6 +1,6 @@ use std::{ future::Future, - pin::Pin, + pin::{pin, Pin}, task::{Context, Poll}, }; @@ -9,9 +9,9 @@ use futures::io::AsyncRead; use crate::{Error, Res}; #[pin_project::pin_project] -pub struct ReadUint<'a, S, const N: usize> { +pub struct ReadUint { /// The source of data. - src: Pin<&'a mut S>, + src: S, /// A buffer that holds the bytes that have been read so far. v: [u8; 8], /// A counter of the number of bytes that are already in place. 
@@ -19,21 +19,18 @@ pub struct ReadUint<'a, S, const N: usize> { read: usize, } -impl<'a, S, const N: usize> ReadUint<'a, S, N> { - pub fn stream(self) -> Pin<&'a mut S> { +impl ReadUint { + pub fn stream(self) -> S { self.src } } -impl<'a, S, const N: usize> Future for ReadUint<'a, S, N> -where - S: AsyncRead, -{ +impl Future for ReadUint { type Output = Res; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.project(); - match this.src.as_mut().poll_read(cx, &mut this.v[*this.read..]) { + match pin!(this.src).poll_read(cx, &mut this.v[*this.read..]) { Poll::Pending => Poll::Pending, Poll::Ready(Ok(count)) => { if count == 0 { @@ -51,25 +48,26 @@ where } } -pub fn read_uint(src: &mut S) -> ReadUint<'_, S, N> { +#[cfg(test)] +fn read_uint(src: S) -> ReadUint { ReadUint { - src: Pin::new(src), + src, v: [0; 8], read: 8 - N, } } #[pin_project::pin_project(project = ReadVarintProj)] -pub enum ReadVarint<'a, S> { +pub enum ReadVarint { // Invariant: this Option always contains Some. - First(Option>), - Extra1(#[pin] ReadUint<'a, S, 8>), - Extra3(#[pin] ReadUint<'a, S, 8>), - Extra7(#[pin] ReadUint<'a, S, 8>), + First(Option), + Extra1(#[pin] ReadUint), + Extra3(#[pin] ReadUint), + Extra7(#[pin] ReadUint), } -impl<'a, S> ReadVarint<'a, S> { - pub fn stream(self) -> Pin<&'a mut S> { +impl ReadVarint { + pub fn stream(self) -> S { match self { Self::Extra1(s) | Self::Extra3(s) | Self::Extra7(s) => s.stream(), Self::First(mut s) => s.take().unwrap(), @@ -77,18 +75,15 @@ impl<'a, S> ReadVarint<'a, S> { } } -impl<'a, S> Future for ReadVarint<'a, S> -where - S: AsyncRead, -{ +impl Future for ReadVarint { type Output = Res>; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.as_mut(); if let Self::First(ref mut src) = this.get_mut() { let mut buf = [0; 1]; - let src_ref = src.as_mut().unwrap().as_mut(); - if let Poll::Ready(res) = src_ref.poll_read(cx, &mut buf[..]) { + let src_ref = src.as_mut().unwrap(); + if let Poll::Ready(res) = pin!(src_ref).poll_read(cx, &mut buf[..]) { match res { Ok(0) => return Poll::Ready(Ok(None)), Ok(_) => (), @@ -134,8 +129,8 @@ where } } -pub fn read_varint(src: &mut S) -> ReadVarint<'_, S> { - ReadVarint::First(Some(Pin::new(src))) +pub fn read_varint(src: S) -> ReadVarint { + ReadVarint::First(Some(src)) } #[cfg(test)] diff --git a/bhttp/src/stream/mod.rs b/bhttp/src/stream/mod.rs index 011e0d1..9839059 100644 --- a/bhttp/src/stream/mod.rs +++ b/bhttp/src/stream/mod.rs @@ -1,13 +1,14 @@ -#![allow(dead_code)] +#![allow(clippy::incompatible_msrv)] // This module uses features from rust 1.82 use std::{ - io::{Cursor, Result as IoResult}, + cmp::min, + io::{Cursor, Error as IoError, Result as IoResult}, mem, - pin::Pin, + pin::{pin, Pin}, task::{Context, Poll}, }; -use futures::{stream::unfold, AsyncRead, Stream, TryStreamExt}; +use futures::{stream::unfold, AsyncRead, FutureExt, Stream, TryStreamExt}; use int::ReadVarint; use crate::{ @@ -21,16 +22,16 @@ mod int; mod vec; trait AsyncReadControlData: Sized { - async fn async_read(request: bool, src: &mut S) -> Res; + async fn async_read(request: bool, src: S) -> Res; } impl AsyncReadControlData for ControlData { - async fn async_read(request: bool, src: &mut S) -> Res { + async fn async_read(request: bool, mut src: S) -> Res { let v = if request { - let method = read_vec(src).await?.ok_or(Error::Truncated)?; - let scheme = read_vec(src).await?.ok_or(Error::Truncated)?; - let authority = read_vec(src).await?.ok_or(Error::Truncated)?; - let path = 
read_vec(src).await?.ok_or(Error::Truncated)?; + let method = read_vec(&mut src).await?.ok_or(Error::Truncated)?; + let scheme = read_vec(&mut src).await?.ok_or(Error::Truncated)?; + let authority = read_vec(&mut src).await?.ok_or(Error::Truncated)?; + let path = read_vec(&mut src).await?.ok_or(Error::Truncated)?; Self::Request { method, scheme, @@ -38,23 +39,22 @@ impl AsyncReadControlData for ControlData { path, } } else { - Self::Response(crate::StatusCode::try_from( - read_varint(src).await?.ok_or(Error::Truncated)?, - )?) + let code = read_varint(&mut src).await?.ok_or(Error::Truncated)?; + Self::Response(crate::StatusCode::try_from(code)?) }; Ok(v) } } trait AsyncReadFieldSection: Sized { - async fn async_read(mode: Mode, src: &mut S) -> Res; + async fn async_read(mode: Mode, src: S) -> Res; } impl AsyncReadFieldSection for FieldSection { - async fn async_read(mode: Mode, src: &mut S) -> Res { + async fn async_read(mode: Mode, mut src: S) -> Res { let fields = if mode == Mode::KnownLength { // Known-length fields can just be read into a buffer. - if let Some(buf) = read_vec(src).await? { + if let Some(buf) = read_vec(&mut src).await? { Self::read_bhttp_fields(false, &mut Cursor::new(&buf[..]))? } else { Vec::new() @@ -64,11 +64,11 @@ impl AsyncReadFieldSection for FieldSection { let mut fields: Vec = Vec::new(); let mut cookie_index: Option = None; loop { - if let Some(n) = read_vec(src).await? { + if let Some(n) = read_vec(&mut src).await? { if n.is_empty() { break fields; } - let mut v = read_vec(src).await?.ok_or(Error::Truncated)?; + let mut v = read_vec(&mut src).await?.ok_or(Error::Truncated)?; if n == COOKIE { if let Some(i) = &cookie_index { fields[*i].value.extend_from_slice(b"; "); @@ -78,6 +78,8 @@ impl AsyncReadFieldSection for FieldSection { cookie_index = Some(fields.len()); } fields.push(Field::new(n, v)); + } else if fields.is_empty() { + break fields; } else { return Err(Error::Truncated); } @@ -87,51 +89,113 @@ impl AsyncReadFieldSection for FieldSection { } } -enum BodyState<'a, S> { +#[allow(clippy::mut_mut)] // TODO look into this more. +enum BodyState<'a, 'b, S> { // When reading the length, use this. - ReadLength(ReadVarint<'a, S>), + // Invariant: This is always `Some`. + ReadLength(Option>), // When reading the data, track how much is left. + // Invariant: `src` is always `Some`. 
ReadData { remaining: usize, - src: Pin<&'a mut S>, + src: Option<&'b mut &'a mut S>, }, } -#[pin_project::pin_project] -struct Body<'a, S> { +pub struct Body<'a, 'b, S> { mode: Mode, - state: BodyState<'a, S>, + state: &'b mut AsyncMessageState<'a, 'b, S>, } -impl<'a, S: AsyncRead> AsyncRead for Body<'a, S> { +impl<'a, 'b, S> Body<'a, 'b, S> { + fn set_state(&mut self, s: BodyState<'a, 'b, S>) { + *self.state = AsyncMessageState::Body(s); + } + + fn done(&mut self) { + *self.state = AsyncMessageState::Trailer; + } +} + +impl<'a, 'b, S: AsyncRead + Unpin> AsyncRead for Body<'a, 'b, S> { fn poll_read( - self: Pin<&mut Self>, + mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8], ) -> Poll> { - self.project().src.as_mut().poll_read(cx, buf) + fn poll_error(e: Error) -> Poll> { + Poll::Ready(Err(IoError::other(e))) + } + + let mode = self.mode; + if let AsyncMessageState::Body(BodyState::ReadLength(r)) = &mut self.state { + match r.as_mut().unwrap().poll_unpin(cx) { + Poll::Pending => return Poll::Pending, + Poll::Ready(Ok(Some(0) | None)) => { + self.done(); + return Poll::Ready(Ok(0)); + } + Poll::Ready(Ok(Some(len))) => { + match usize::try_from(len) { + Ok(remaining) => { + let src = r.take().map(ReadVarint::stream); + self.set_state(BodyState::ReadData { remaining, src }); + // fall through to maybe read the body + } + Err(e) => return poll_error(Error::IntRange(e)), + } + } + Poll::Ready(Err(e)) => return poll_error(e), + } + } + + if let AsyncMessageState::Body(BodyState::ReadData { remaining, src }) = &mut self.state { + let amount = min(*remaining, buf.len()); + let res = pin!(src.as_mut().unwrap()).poll_read(cx, &mut buf[..amount]); + match res { + Poll::Pending => Poll::Pending, + Poll::Ready(Ok(0)) => poll_error(Error::Truncated), + Poll::Ready(Ok(len)) => { + *remaining -= len; + if *remaining == 0 { + if mode == Mode::IndeterminateLength { + let src = src.take().map(read_varint); + self.set_state(BodyState::ReadLength(src)); + } else { + self.done(); + } + } + Poll::Ready(Ok(len)) + } + Poll::Ready(Err(e)) => Poll::Ready(Err(e)), + } + } else { + Poll::Pending + } } } -enum AsyncMessageState { +enum AsyncMessageState<'a, 'b, S> { // Processing Informational responses (or before that). Informational, // Having obtained the control data for the header, this is it. Header(ControlData), // Processing the Body. - Body, + Body(BodyState<'a, 'b, S>), // Processing the trailer. Trailer, } -struct AsyncMessage<'a, S> { +pub struct AsyncMessage<'a, 'b, S> { // Whether this is a request and which mode. framing: Option<(bool, Mode)>, - state: AsyncMessageState, - src: Pin<&'a mut S>, + state: AsyncMessageState<'a, 'b, S>, + src: &'a mut S, } -impl<'a, S: AsyncRead> AsyncMessage<'a, S> { +unsafe impl Send for AsyncMessage<'_, '_, S> {} + +impl<'a, 'b, S: AsyncRead + Unpin> AsyncMessage<'a, 'b, S> { /// Get the mode. This panics if the header hasn't been read yet. fn mode(&self) -> Mode { self.framing.unwrap().1 @@ -175,7 +239,7 @@ impl<'a, S: AsyncRead> AsyncMessage<'a, S> { /// call `header()` to get the message header. pub fn informational( &mut self, - ) -> impl Stream> + use<'_, 'a, S> { + ) -> impl Stream> + use<'_, 'a, 'b, S> { unfold(self, |this| async move { this.next_info().await.transpose().map(|info| (info, this)) }) @@ -183,7 +247,7 @@ impl<'a, S: AsyncRead> AsyncMessage<'a, S> { /// This reads the header. If you have not called `informational` /// and drained the resulting stream, this will do that for you. - pub async fn header(&mut self) -> Res
{
+    pub async fn header(&'b mut self) -> Res<Header>
{ if matches!(self.state, AsyncMessageState::Informational) { // Need to scrub for errors, // so that this can abort properly if there is one. @@ -191,44 +255,62 @@ impl<'a, S: AsyncRead> AsyncMessage<'a, S> { _ = self.informational().try_any(|_| async { false }).await?; } if matches!(self.state, AsyncMessageState::Header(_)) { + let mode = self.mode(); + let hfields = FieldSection::async_read(mode, &mut self.src).await?; + + let bs: BodyState<'a, 'b, S> = BodyState::ReadLength(Some(read_varint(&mut self.src))); let AsyncMessageState::Header(control) = - mem::replace(&mut self.state, AsyncMessageState::Body) + mem::replace(&mut self.state, AsyncMessageState::Body(bs)) else { unreachable!(); }; - let mode = self.mode(); - let hfields = FieldSection::async_read(mode, &mut self.src).await?; Ok(Header::from((control, hfields))) } else { Err(Error::InvalidState) } } - pub fn body<'s>(&'s mut self) -> Res> - where - 'a: 's, - { - if matches!(self.state, AsyncMessageState::Body) { + /// Read the body. + /// This produces an implementation of `AsyncRead` that filters out + /// the framing from the message body. + /// # Errors + /// This errors when the header has not been read. + /// Any IO errors are generated by the returned `Body` instance. + pub fn body(&'b mut self) -> Res> { + if matches!(self.state, AsyncMessageState::Body(_)) { + let mode = self.mode(); Ok(Body { - mode: self.mode(), - state: BodyState::ReadLength(read_varint(self.src.as_mut())), + mode, + state: &mut self.state, }) } else { Err(Error::InvalidState) } } + + /// Read any trailer. + /// This might be empty. + /// # Errors + /// This errors when the body has not been read. + pub async fn trailer(&mut self) -> Res { + if matches!(self.state, AsyncMessageState::Trailer) { + Ok(FieldSection::async_read(self.mode(), &mut self.src).await?) + } else { + Err(Error::InvalidState) + } + } } -trait AsyncReadMessage: Sized { - fn async_read(src: &mut S) -> AsyncMessage<'_, S>; +pub trait AsyncReadMessage: Sized { + fn async_read<'b, S: AsyncRead + Unpin>(src: &mut S) -> AsyncMessage<'_, 'b, S>; } impl AsyncReadMessage for Message { - fn async_read(src: &mut S) -> AsyncMessage<'_, S> { + fn async_read<'b, S: AsyncRead + Unpin>(src: &mut S) -> AsyncMessage<'_, 'b, S> { AsyncMessage { framing: None, state: AsyncMessageState::Informational, - src: Pin::new(src), + src, } } } @@ -237,14 +319,41 @@ impl AsyncReadMessage for Message { mod test { use std::pin::pin; + use futures::TryStreamExt; + use crate::{ stream::{ - future::{SyncCollect, SyncResolve}, + future::{SyncCollect, SyncRead, SyncResolve}, AsyncReadMessage, }, - Message, + Error, Message, }; + // Example from Section 5.1 of RFC 9292. 
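+    // REQUEST1 uses the known-length framing (first byte 0x00); REQUEST2 below is
+    // the indeterminate-length encoding (first byte 0x02) of the same request, followed by padding.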
+ const REQUEST1: &[u8] = &[ + 0x00, 0x03, 0x47, 0x45, 0x54, 0x05, 0x68, 0x74, 0x74, 0x70, 0x73, 0x00, 0x0a, 0x2f, 0x68, + 0x65, 0x6c, 0x6c, 0x6f, 0x2e, 0x74, 0x78, 0x74, 0x40, 0x6c, 0x0a, 0x75, 0x73, 0x65, 0x72, + 0x2d, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x34, 0x63, 0x75, 0x72, 0x6c, 0x2f, 0x37, 0x2e, 0x31, + 0x36, 0x2e, 0x33, 0x20, 0x6c, 0x69, 0x62, 0x63, 0x75, 0x72, 0x6c, 0x2f, 0x37, 0x2e, 0x31, + 0x36, 0x2e, 0x33, 0x20, 0x4f, 0x70, 0x65, 0x6e, 0x53, 0x53, 0x4c, 0x2f, 0x30, 0x2e, 0x39, + 0x2e, 0x37, 0x6c, 0x20, 0x7a, 0x6c, 0x69, 0x62, 0x2f, 0x31, 0x2e, 0x32, 0x2e, 0x33, 0x04, + 0x68, 0x6f, 0x73, 0x74, 0x0f, 0x77, 0x77, 0x77, 0x2e, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, + 0x65, 0x2e, 0x63, 0x6f, 0x6d, 0x0f, 0x61, 0x63, 0x63, 0x65, 0x70, 0x74, 0x2d, 0x6c, 0x61, + 0x6e, 0x67, 0x75, 0x61, 0x67, 0x65, 0x06, 0x65, 0x6e, 0x2c, 0x20, 0x6d, 0x69, 0x00, 0x00, + ]; + const REQUEST2: &[u8] = &[ + 0x02, 0x03, 0x47, 0x45, 0x54, 0x05, 0x68, 0x74, 0x74, 0x70, 0x73, 0x00, 0x0a, 0x2f, 0x68, + 0x65, 0x6c, 0x6c, 0x6f, 0x2e, 0x74, 0x78, 0x74, 0x0a, 0x75, 0x73, 0x65, 0x72, 0x2d, 0x61, + 0x67, 0x65, 0x6e, 0x74, 0x34, 0x63, 0x75, 0x72, 0x6c, 0x2f, 0x37, 0x2e, 0x31, 0x36, 0x2e, + 0x33, 0x20, 0x6c, 0x69, 0x62, 0x63, 0x75, 0x72, 0x6c, 0x2f, 0x37, 0x2e, 0x31, 0x36, 0x2e, + 0x33, 0x20, 0x4f, 0x70, 0x65, 0x6e, 0x53, 0x53, 0x4c, 0x2f, 0x30, 0x2e, 0x39, 0x2e, 0x37, + 0x6c, 0x20, 0x7a, 0x6c, 0x69, 0x62, 0x2f, 0x31, 0x2e, 0x32, 0x2e, 0x33, 0x04, 0x68, 0x6f, + 0x73, 0x74, 0x0f, 0x77, 0x77, 0x77, 0x2e, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x2e, + 0x63, 0x6f, 0x6d, 0x0f, 0x61, 0x63, 0x63, 0x65, 0x70, 0x74, 0x2d, 0x6c, 0x61, 0x6e, 0x67, + 0x75, 0x61, 0x67, 0x65, 0x06, 0x65, 0x6e, 0x2c, 0x20, 0x6d, 0x69, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + ]; + #[test] fn informational() { const INFO: &[u8] = &[1, 64, 100, 0, 64, 200, 0]; @@ -258,4 +367,135 @@ mod test { assert_eq!(hdr.control().status().unwrap().code(), 200); assert!(hdr.is_empty()); } + + #[test] + fn sample_requests() { + fn validate_sample_request(mut buf: &[u8]) { + let mut msg = Message::async_read(&mut buf); + let info = msg.informational().sync_collect().unwrap(); + assert!(info.is_empty()); + + let hdr = pin!(msg.header()).sync_resolve().unwrap(); + assert_eq!(hdr.control(), &(b"GET", b"https", b"", b"/hello.txt")); + assert_eq!( + hdr.get(b"user-agent"), + Some(&b"curl/7.16.3 libcurl/7.16.3 OpenSSL/0.9.7l zlib/1.2.3"[..]), + ); + assert_eq!(hdr.get(b"host"), Some(&b"www.example.com"[..])); + assert_eq!(hdr.get(b"accept-language"), Some(&b"en, mi"[..])); + assert_eq!(hdr.len(), 3); + + let body = pin!(msg.body().unwrap()).sync_read_to_end(); + assert!(body.is_empty()); + + let trailer = pin!(msg.trailer()).sync_resolve().unwrap(); + assert!(trailer.is_empty()); + } + + validate_sample_request(REQUEST1); + validate_sample_request(REQUEST2); + validate_sample_request(&REQUEST2[..REQUEST2.len() - 12]); + } + + #[test] + fn truncated_header() { + // The indefinite-length request example includes 10 bytes of padding. + // The three additional zero values at the end represent: + // 1. The terminating zero for the header field section. + // 2. The terminating zero for the (empty) body. + // 3. The terminating zero for the (absent) trailer field section. + // The latter two (body and trailer) can be cut and the message will still work. + // The first is not optional; dropping it means that the message is truncated. 
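+        // Cutting 13 bytes strips the 10 bytes of padding, the trailer and body
+        // terminators, and the field-section terminator itself, so the header is incomplete.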
+ let mut buf = &mut &REQUEST2[..REQUEST2.len() - 13]; + let mut msg = Message::async_read(&mut buf); + // Use this test to test skipping a few things. + let err = pin!(msg.header()).sync_resolve().unwrap_err(); + assert!(matches!(err, Error::Truncated)); + } + + #[test] + fn sample_responses() { + const RESPONSE: &[u8] = &[ + 0x03, 0x40, 0x66, 0x07, 0x72, 0x75, 0x6e, 0x6e, 0x69, 0x6e, 0x67, 0x0a, 0x22, 0x73, + 0x6c, 0x65, 0x65, 0x70, 0x20, 0x31, 0x35, 0x22, 0x00, 0x40, 0x67, 0x04, 0x6c, 0x69, + 0x6e, 0x6b, 0x23, 0x3c, 0x2f, 0x73, 0x74, 0x79, 0x6c, 0x65, 0x2e, 0x63, 0x73, 0x73, + 0x3e, 0x3b, 0x20, 0x72, 0x65, 0x6c, 0x3d, 0x70, 0x72, 0x65, 0x6c, 0x6f, 0x61, 0x64, + 0x3b, 0x20, 0x61, 0x73, 0x3d, 0x73, 0x74, 0x79, 0x6c, 0x65, 0x04, 0x6c, 0x69, 0x6e, + 0x6b, 0x24, 0x3c, 0x2f, 0x73, 0x63, 0x72, 0x69, 0x70, 0x74, 0x2e, 0x6a, 0x73, 0x3e, + 0x3b, 0x20, 0x72, 0x65, 0x6c, 0x3d, 0x70, 0x72, 0x65, 0x6c, 0x6f, 0x61, 0x64, 0x3b, + 0x20, 0x61, 0x73, 0x3d, 0x73, 0x63, 0x72, 0x69, 0x70, 0x74, 0x00, 0x40, 0xc8, 0x04, + 0x64, 0x61, 0x74, 0x65, 0x1d, 0x4d, 0x6f, 0x6e, 0x2c, 0x20, 0x32, 0x37, 0x20, 0x4a, + 0x75, 0x6c, 0x20, 0x32, 0x30, 0x30, 0x39, 0x20, 0x31, 0x32, 0x3a, 0x32, 0x38, 0x3a, + 0x35, 0x33, 0x20, 0x47, 0x4d, 0x54, 0x06, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x06, + 0x41, 0x70, 0x61, 0x63, 0x68, 0x65, 0x0d, 0x6c, 0x61, 0x73, 0x74, 0x2d, 0x6d, 0x6f, + 0x64, 0x69, 0x66, 0x69, 0x65, 0x64, 0x1d, 0x57, 0x65, 0x64, 0x2c, 0x20, 0x32, 0x32, + 0x20, 0x4a, 0x75, 0x6c, 0x20, 0x32, 0x30, 0x30, 0x39, 0x20, 0x31, 0x39, 0x3a, 0x31, + 0x35, 0x3a, 0x35, 0x36, 0x20, 0x47, 0x4d, 0x54, 0x04, 0x65, 0x74, 0x61, 0x67, 0x14, + 0x22, 0x33, 0x34, 0x61, 0x61, 0x33, 0x38, 0x37, 0x2d, 0x64, 0x2d, 0x31, 0x35, 0x36, + 0x38, 0x65, 0x62, 0x30, 0x30, 0x22, 0x0d, 0x61, 0x63, 0x63, 0x65, 0x70, 0x74, 0x2d, + 0x72, 0x61, 0x6e, 0x67, 0x65, 0x73, 0x05, 0x62, 0x79, 0x74, 0x65, 0x73, 0x0e, 0x63, + 0x6f, 0x6e, 0x74, 0x65, 0x6e, 0x74, 0x2d, 0x6c, 0x65, 0x6e, 0x67, 0x74, 0x68, 0x02, + 0x35, 0x31, 0x04, 0x76, 0x61, 0x72, 0x79, 0x0f, 0x41, 0x63, 0x63, 0x65, 0x70, 0x74, + 0x2d, 0x45, 0x6e, 0x63, 0x6f, 0x64, 0x69, 0x6e, 0x67, 0x0c, 0x63, 0x6f, 0x6e, 0x74, + 0x65, 0x6e, 0x74, 0x2d, 0x74, 0x79, 0x70, 0x65, 0x0a, 0x74, 0x65, 0x78, 0x74, 0x2f, + 0x70, 0x6c, 0x61, 0x69, 0x6e, 0x00, 0x33, 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x20, 0x57, + 0x6f, 0x72, 0x6c, 0x64, 0x21, 0x20, 0x4d, 0x79, 0x20, 0x63, 0x6f, 0x6e, 0x74, 0x65, + 0x6e, 0x74, 0x20, 0x69, 0x6e, 0x63, 0x6c, 0x75, 0x64, 0x65, 0x73, 0x20, 0x61, 0x20, + 0x74, 0x72, 0x61, 0x69, 0x6c, 0x69, 0x6e, 0x67, 0x20, 0x43, 0x52, 0x4c, 0x46, 0x2e, + 0x0d, 0x0a, 0x00, 0x00, + ]; + + let mut buf = RESPONSE; + let mut msg = Message::async_read(&mut buf); + + { + // Need to scope access to `info` or it will hold the reference to `msg`. 
+ let mut info = pin!(msg.informational()); + + let info1 = info.try_next().sync_resolve().unwrap().unwrap(); + assert_eq!(info1.status(), 102_u16); + assert_eq!(info1.len(), 1); + assert_eq!(info1.get(b"running"), Some(&b"\"sleep 15\""[..])); + + let info2 = info.try_next().sync_resolve().unwrap().unwrap(); + assert_eq!(info2.status(), 103_u16); + assert_eq!(info2.len(), 2); + let links = info2.get_all(b"link").collect::>(); + assert_eq!( + &links, + &[ + &b"; rel=preload; as=style"[..], + &b"; rel=preload; as=script"[..], + ] + ); + + assert!(info.try_next().sync_resolve().unwrap().is_none()); + } + + let hdr = pin!(msg.header()).sync_resolve().unwrap(); + assert_eq!(hdr.control(), &200_u16); + assert_eq!(hdr.len(), 8); + assert_eq!(hdr.get(b"vary"), Some(&b"Accept-Encoding"[..])); + assert_eq!(hdr.get(b"etag"), Some(&b"\"34aa387-d-1568eb00\""[..])); + + { + let mut body = pin!(msg.body().unwrap()); + assert_eq!(body.sync_read_exact(12), b"Hello World!"); + } + // Attempting to read the trailer before finishing the body should fail. + assert!(matches!( + pin!(msg.trailer()).sync_resolve(), + Err(Error::InvalidState) + )); + { + // Picking up the body again should work fine. + let mut body = pin!(msg.body().unwrap()); + assert_eq!( + body.sync_read_to_end(), + b" My content includes a trailing CRLF.\r\n" + ); + } + let trailer = pin!(msg.trailer()).sync_resolve().unwrap(); + assert!(trailer.is_empty()); + } } diff --git a/bhttp/src/stream/vec.rs b/bhttp/src/stream/vec.rs index 16dd433..05a4e24 100644 --- a/bhttp/src/stream/vec.rs +++ b/bhttp/src/stream/vec.rs @@ -1,7 +1,7 @@ use std::{ future::Future, mem, - pin::Pin, + pin::{pin, Pin}, task::{Context, Poll}, }; @@ -12,24 +12,26 @@ use crate::{Error, Res}; #[pin_project::pin_project(project = ReadVecProj)] #[allow(clippy::module_name_repetitions)] -pub enum ReadVec<'a, S> { +pub enum ReadVec { // Invariant: This Option is always Some. ReadLen { - src: Option>, + src: Option>, cap: u64, }, ReadBody { - src: Pin<&'a mut S>, + src: S, buf: Vec, remaining: usize, }, } -impl<'a, S> ReadVec<'a, S> { +impl ReadVec { + #![allow(dead_code)] // TODO these really need to be used. + /// # Panics /// If `limit` is more than `usize::MAX` or /// if this is called after the length is read. - fn limit(&mut self, limit: u64) { + pub fn limit(&mut self, limit: u64) { usize::try_from(limit).expect("cannot set a limit larger than usize::MAX"); if let Self::ReadLen { ref mut cap, .. } = self { *cap = limit; @@ -38,7 +40,7 @@ impl<'a, S> ReadVec<'a, S> { } } - fn stream(self) -> Pin<&'a mut S> { + pub fn stream(self) -> S { match self { Self::ReadLen { mut src, .. } => src.take().unwrap().stream(), Self::ReadBody { src, .. 
} => src, @@ -46,10 +48,7 @@ impl<'a, S> ReadVec<'a, S> { } } -impl<'a, S> Future for ReadVec<'a, S> -where - S: AsyncRead, -{ +impl Future for ReadVec { type Output = Res>>; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { @@ -86,7 +85,7 @@ where }; let offset = buf.len() - *remaining; - match src.as_mut().poll_read(cx, &mut buf[offset..]) { + match pin!(src).poll_read(cx, &mut buf[offset..]) { Poll::Pending => Poll::Pending, Poll::Ready(Err(e)) => Poll::Ready(Err(Error::from(e))), Poll::Ready(Ok(0)) => Poll::Ready(Err(Error::Truncated)), @@ -103,7 +102,7 @@ where } #[allow(clippy::module_name_repetitions)] -pub fn read_vec(src: &mut S) -> ReadVec<'_, S> { +pub fn read_vec(src: S) -> ReadVec { ReadVec::ReadLen { src: Some(read_varint(src)), cap: u64::try_from(usize::MAX).unwrap_or(u64::MAX), From 6f4ec9a616fee6d139f3d2b3a5a2247b02e22782 Mon Sep 17 00:00:00 2001 From: Martin Thomson Date: Mon, 28 Oct 2024 18:20:11 +1100 Subject: [PATCH 06/16] Working enough for now --- bhttp/src/stream/future.rs | 20 +++ bhttp/src/stream/mod.rs | 340 +++++++++++++++++++++++-------------- 2 files changed, 237 insertions(+), 123 deletions(-) diff --git a/bhttp/src/stream/future.rs b/bhttp/src/stream/future.rs index 3eae2bb..9d0362c 100644 --- a/bhttp/src/stream/future.rs +++ b/bhttp/src/stream/future.rs @@ -104,3 +104,23 @@ impl SyncRead for S { buf } } + +pub struct Dribble { + src: S, +} + +impl Dribble { + pub fn new(src: S) -> Self { + Self { src } + } +} + +impl AsyncRead for Dribble { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + pin!(&mut self.src).poll_read(cx, &mut buf[..1]) + } +} diff --git a/bhttp/src/stream/mod.rs b/bhttp/src/stream/mod.rs index 9839059..7496d8f 100644 --- a/bhttp/src/stream/mod.rs +++ b/bhttp/src/stream/mod.rs @@ -1,3 +1,4 @@ +#![allow(dead_code)] #![allow(clippy::incompatible_msrv)] // This module uses features from rust 1.82 use std::{ @@ -8,8 +9,7 @@ use std::{ task::{Context, Poll}, }; -use futures::{stream::unfold, AsyncRead, FutureExt, Stream, TryStreamExt}; -use int::ReadVarint; +use futures::{stream::unfold, AsyncRead, Stream, TryStreamExt}; use crate::{ err::Res, @@ -89,135 +89,95 @@ impl AsyncReadFieldSection for FieldSection { } } -#[allow(clippy::mut_mut)] // TODO look into this more. -enum BodyState<'a, 'b, S> { +#[derive(Default)] +enum BodyState { + // The starting state. + #[default] + Init, // When reading the length, use this. - // Invariant: This is always `Some`. - ReadLength(Option>), + ReadLength { + buf: [u8; 8], + read: usize, + }, // When reading the data, track how much is left. - // Invariant: `src` is always `Some`. 
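+    // The source stream is no longer stored here; `Body` now borrows the whole
+    // `AsyncMessage`, which owns `src`.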
ReadData { remaining: usize, - src: Option<&'b mut &'a mut S>, }, } -pub struct Body<'a, 'b, S> { - mode: Mode, - state: &'b mut AsyncMessageState<'a, 'b, S>, -} - -impl<'a, 'b, S> Body<'a, 'b, S> { - fn set_state(&mut self, s: BodyState<'a, 'b, S>) { - *self.state = AsyncMessageState::Body(s); +impl BodyState { + fn read_len() -> Self { + Self::ReadLength { + buf: [0; 8], + read: 0, + } } +} - fn done(&mut self) { - *self.state = AsyncMessageState::Trailer; - } +pub struct Body<'b, S> { + msg: &'b mut AsyncMessage, } -impl<'a, 'b, S: AsyncRead + Unpin> AsyncRead for Body<'a, 'b, S> { +impl<'b, S> Body<'b, S> {} + +impl<'b, S: AsyncRead + Unpin> AsyncRead for Body<'b, S> { fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8], ) -> Poll> { - fn poll_error(e: Error) -> Poll> { - Poll::Ready(Err(IoError::other(e))) - } - - let mode = self.mode; - if let AsyncMessageState::Body(BodyState::ReadLength(r)) = &mut self.state { - match r.as_mut().unwrap().poll_unpin(cx) { - Poll::Pending => return Poll::Pending, - Poll::Ready(Ok(Some(0) | None)) => { - self.done(); - return Poll::Ready(Ok(0)); - } - Poll::Ready(Ok(Some(len))) => { - match usize::try_from(len) { - Ok(remaining) => { - let src = r.take().map(ReadVarint::stream); - self.set_state(BodyState::ReadData { remaining, src }); - // fall through to maybe read the body - } - Err(e) => return poll_error(Error::IntRange(e)), - } - } - Poll::Ready(Err(e)) => return poll_error(e), - } - } - - if let AsyncMessageState::Body(BodyState::ReadData { remaining, src }) = &mut self.state { - let amount = min(*remaining, buf.len()); - let res = pin!(src.as_mut().unwrap()).poll_read(cx, &mut buf[..amount]); - match res { - Poll::Pending => Poll::Pending, - Poll::Ready(Ok(0)) => poll_error(Error::Truncated), - Poll::Ready(Ok(len)) => { - *remaining -= len; - if *remaining == 0 { - if mode == Mode::IndeterminateLength { - let src = src.take().map(read_varint); - self.set_state(BodyState::ReadLength(src)); - } else { - self.done(); - } - } - Poll::Ready(Ok(len)) - } - Poll::Ready(Err(e)) => Poll::Ready(Err(e)), - } - } else { - Poll::Pending - } + self.msg.read_body(cx, buf).map_err(IoError::other) } } -enum AsyncMessageState<'a, 'b, S> { +/// A helper function for the more complex body-reading code. +fn poll_error(e: Error) -> Poll> { + Poll::Ready(Err(IoError::other(e))) +} + +enum AsyncMessageState { + Init, // Processing Informational responses (or before that). - Informational, + Informational(bool), // Having obtained the control data for the header, this is it. Header(ControlData), // Processing the Body. - Body(BodyState<'a, 'b, S>), + Body(BodyState), // Processing the trailer. Trailer, + // All done. + Done, } -pub struct AsyncMessage<'a, 'b, S> { +pub struct AsyncMessage { // Whether this is a request and which mode. - framing: Option<(bool, Mode)>, - state: AsyncMessageState<'a, 'b, S>, - src: &'a mut S, + mode: Option, + state: AsyncMessageState, + src: S, } -unsafe impl Send for AsyncMessage<'_, '_, S> {} - -impl<'a, 'b, S: AsyncRead + Unpin> AsyncMessage<'a, 'b, S> { - /// Get the mode. This panics if the header hasn't been read yet. 
- fn mode(&self) -> Mode { - self.framing.unwrap().1 - } +unsafe impl Send for AsyncMessage {} +impl AsyncMessage { async fn next_info(&mut self) -> Res> { - if !matches!(self.state, AsyncMessageState::Informational) { - return Ok(None); - } - - let (request, mode) = if let Some((request, mode)) = self.framing { - (request, mode) - } else { + let request = if matches!(self.state, AsyncMessageState::Init) { + // Read control data ... let t = read_varint(&mut self.src).await?.ok_or(Error::Truncated)?; let request = t == 0 || t == 2; - let mode = Mode::try_from(t)?; - self.framing = Some((request, mode)); - (request, mode) + self.mode = Some(Mode::try_from(t)?); + self.state = AsyncMessageState::Informational(request); + request + } else { + // ... or recover it. + let AsyncMessageState::Informational(request) = self.state else { + return Err(Error::InvalidState); + }; + request }; let control = ControlData::async_read(request, &mut self.src).await?; if let Some(status) = control.informational() { + let mode = self.mode.unwrap(); let fields = FieldSection::async_read(mode, &mut self.src).await?; Ok(Some(InformationalResponse::new(status, fields))) } else { @@ -227,7 +187,7 @@ impl<'a, 'b, S: AsyncRead + Unpin> AsyncMessage<'a, 'b, S> { } /// Produces a stream of informational responses from a fresh message. - /// Returns an empty stream if called at other times. + /// Returns an empty stream if passed a request (or if there are no informational responses). /// Error values on the stream indicate failures. /// /// There is no need to call this method to read a request, though @@ -237,9 +197,7 @@ impl<'a, 'b, S: AsyncRead + Unpin> AsyncMessage<'a, 'b, S> { /// without affecting the message. You can then either call this /// method again to get any additional informational responses or /// call `header()` to get the message header. - pub fn informational( - &mut self, - ) -> impl Stream> + use<'_, 'a, 'b, S> { + pub fn informational(&mut self) -> impl Stream> + use<'_, S> { unfold(self, |this| async move { this.next_info().await.transpose().map(|info| (info, this)) }) @@ -247,21 +205,27 @@ impl<'a, 'b, S: AsyncRead + Unpin> AsyncMessage<'a, 'b, S> { /// This reads the header. If you have not called `informational` /// and drained the resulting stream, this will do that for you. - pub async fn header(&'b mut self) -> Res
{
-        if matches!(self.state, AsyncMessageState::Informational) {
+    /// # Panics
+    /// Never.
+    pub async fn header(&mut self) -> Res<Header>
{ + if matches!( + self.state, + AsyncMessageState::Init | AsyncMessageState::Informational(_) + ) { // Need to scrub for errors, // so that this can abort properly if there is one. // The `try_any` usage is there to ensure that the stream is fully drained. _ = self.informational().try_any(|_| async { false }).await?; } + if matches!(self.state, AsyncMessageState::Header(_)) { - let mode = self.mode(); + let mode = self.mode.unwrap(); let hfields = FieldSection::async_read(mode, &mut self.src).await?; - let bs: BodyState<'a, 'b, S> = BodyState::ReadLength(Some(read_varint(&mut self.src))); - let AsyncMessageState::Header(control) = - mem::replace(&mut self.state, AsyncMessageState::Body(bs)) - else { + let AsyncMessageState::Header(control) = mem::replace( + &mut self.state, + AsyncMessageState::Body(BodyState::default()), + ) else { unreachable!(); }; Ok(Header::from((control, hfields))) @@ -270,21 +234,146 @@ impl<'a, 'b, S: AsyncRead + Unpin> AsyncMessage<'a, 'b, S> { } } + fn body_state(&mut self, s: BodyState) { + self.state = AsyncMessageState::Body(s); + } + + fn body_done(&mut self) { + self.state = AsyncMessageState::Trailer; + } + + /// Read the length of a body chunk. + /// This updates the values of `read` and `buf` to track the portion of the length + /// that was successfully read. + /// Returns `Some` with the error code that should be used if the reading + /// resulted in a conclusive outcome. + fn read_body_len( + cx: &mut Context<'_>, + mut src: &mut S, + first: bool, + read: &mut usize, + buf: &mut [u8; 8], + ) -> Option>> { + let mut src = pin!(src); + if *read == 0 { + let mut b = [0; 1]; + match src.as_mut().poll_read(cx, &mut b[..]) { + Poll::Pending => return Some(Poll::Pending), + Poll::Ready(Ok(0)) => { + return if first { + // It's OK for the first length to be absent. + // Just skip to the end. + *read = 8; + None + } else { + // ...it's not OK to drop length when continuing. + Some(poll_error(Error::Truncated)) + }; + } + Poll::Ready(Ok(1)) => match b[0] >> 6 { + 0 => { + buf[7] = b[0] & 0x3f; + *read = 8; + } + 1 => { + buf[6] = b[0] & 0x3f; + *read = 7; + } + 2 => { + buf[4] = b[0] & 0x3f; + *read = 5; + } + 3 => { + buf[0] = b[0] & 0x3f; + *read = 1; + } + _ => unreachable!(), + }, + Poll::Ready(Ok(_)) => unreachable!(), + Poll::Ready(Err(e)) => return Some(Poll::Ready(Err(e))), + } + } + if *read < 8 { + match src.as_mut().poll_read(cx, &mut buf[*read..]) { + Poll::Pending => return Some(Poll::Pending), + Poll::Ready(Ok(0)) => return Some(poll_error(Error::Truncated)), + Poll::Ready(Ok(len)) => { + *read += len; + } + Poll::Ready(Err(e)) => return Some(Poll::Ready(Err(e))), + } + } + None + } + + fn read_body(&mut self, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll> { + // The length that precedes the first chunk can be absent. + // Only allow that for the first chunk (if indeterminate length). + let first = if let AsyncMessageState::Body(BodyState::Init) = &self.state { + self.body_state(BodyState::read_len()); + true + } else { + false + }; + + // Read the length. This uses `read_body_len` to track the state of this reading. + // This doesn't use `ReadVarint` or any convenience functions because we + // need to track the state and we don't want the borrow checker to flip out. 
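+        // `read_body_len` accumulates the QUIC-style variable-length integer into the
+        // tail of the 8-byte buffer; once `read == 8`, `u64::from_be_bytes(buf)` is the chunk length.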
+ if let AsyncMessageState::Body(BodyState::ReadLength { buf, read }) = &mut self.state { + if let Some(res) = Self::read_body_len(cx, &mut self.src, first, read, buf) { + return res; + } + if *read == 8 { + match usize::try_from(u64::from_be_bytes(*buf)) { + Ok(0) => { + self.body_done(); + return Poll::Ready(Ok(0)); + } + Ok(remaining) => { + self.body_state(BodyState::ReadData { remaining }); + } + Err(e) => return poll_error(Error::IntRange(e)), + } + } + } + + match &mut self.state { + AsyncMessageState::Body(BodyState::ReadData { remaining }) => { + let amount = min(*remaining, buf.len()); + let res = pin!(&mut self.src).poll_read(cx, &mut buf[..amount]); + match res { + Poll::Pending => Poll::Pending, + Poll::Ready(Ok(0)) => poll_error(Error::Truncated), + Poll::Ready(Ok(len)) => { + *remaining -= len; + if *remaining == 0 { + let mode = self.mode.unwrap(); + if mode == Mode::IndeterminateLength { + self.body_state(BodyState::read_len()); + } else { + self.body_done(); + } + } + Poll::Ready(Ok(len)) + } + Poll::Ready(Err(e)) => Poll::Ready(Err(e)), + } + } + AsyncMessageState::Trailer => Poll::Ready(Ok(0)), + _ => Poll::Pending, + } + } + /// Read the body. /// This produces an implementation of `AsyncRead` that filters out /// the framing from the message body. /// # Errors /// This errors when the header has not been read. /// Any IO errors are generated by the returned `Body` instance. - pub fn body(&'b mut self) -> Res> { - if matches!(self.state, AsyncMessageState::Body(_)) { - let mode = self.mode(); - Ok(Body { - mode, - state: &mut self.state, - }) - } else { - Err(Error::InvalidState) + pub fn body(&mut self) -> Res> { + match self.state { + AsyncMessageState::Body(_) => Ok(Body { msg: self }), + _ => Err(Error::InvalidState), } } @@ -292,9 +381,13 @@ impl<'a, 'b, S: AsyncRead + Unpin> AsyncMessage<'a, 'b, S> { /// This might be empty. /// # Errors /// This errors when the body has not been read. + /// # Panics + /// Never. pub async fn trailer(&mut self) -> Res { if matches!(self.state, AsyncMessageState::Trailer) { - Ok(FieldSection::async_read(self.mode(), &mut self.src).await?) 
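+            // Reading the trailer consumes it; the state moves to `Done`, so a
+            // second call reports `InvalidState`.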
+ let trailer = FieldSection::async_read(self.mode.unwrap(), &mut self.src).await?; + self.state = AsyncMessageState::Done; + Ok(trailer) } else { Err(Error::InvalidState) } @@ -302,14 +395,14 @@ impl<'a, 'b, S: AsyncRead + Unpin> AsyncMessage<'a, 'b, S> { } pub trait AsyncReadMessage: Sized { - fn async_read<'b, S: AsyncRead + Unpin>(src: &mut S) -> AsyncMessage<'_, 'b, S>; + fn async_read(src: S) -> AsyncMessage; } impl AsyncReadMessage for Message { - fn async_read<'b, S: AsyncRead + Unpin>(src: &mut S) -> AsyncMessage<'_, 'b, S> { + fn async_read(src: S) -> AsyncMessage { AsyncMessage { - framing: None, - state: AsyncMessageState::Informational, + mode: None, + state: AsyncMessageState::Init, src, } } @@ -323,7 +416,7 @@ mod test { use crate::{ stream::{ - future::{SyncCollect, SyncRead, SyncResolve}, + future::{Dribble, SyncCollect, SyncRead, SyncResolve}, AsyncReadMessage, }, Error, Message, @@ -361,8 +454,8 @@ mod test { let mut msg = Message::async_read(&mut buf_alias); let info = msg.informational().sync_collect().unwrap(); assert_eq!(info.len(), 1); - let info = msg.informational().sync_collect().unwrap(); - assert!(info.is_empty()); + let err = msg.informational().sync_collect(); + assert!(matches!(err, Err(Error::InvalidState))); let hdr = pin!(msg.header()).sync_resolve().unwrap(); assert_eq!(hdr.control().status().unwrap().code(), 200); assert!(hdr.is_empty()); @@ -413,8 +506,9 @@ mod test { assert!(matches!(err, Error::Truncated)); } + /// This test is crazy. It reads a byte at a time and checks the state constantly. #[test] - fn sample_responses() { + fn sample_response() { const RESPONSE: &[u8] = &[ 0x03, 0x40, 0x66, 0x07, 0x72, 0x75, 0x6e, 0x6e, 0x69, 0x6e, 0x67, 0x0a, 0x22, 0x73, 0x6c, 0x65, 0x65, 0x70, 0x20, 0x31, 0x35, 0x22, 0x00, 0x40, 0x67, 0x04, 0x6c, 0x69, @@ -446,7 +540,7 @@ mod test { ]; let mut buf = RESPONSE; - let mut msg = Message::async_read(&mut buf); + let mut msg = Message::async_read(Dribble::new(&mut buf)); { // Need to scope access to `info` or it will hold the reference to `msg`. 
From aac02a98a940153ab8fe1a8109612f79931f51f2 Mon Sep 17 00:00:00 2001 From: Martin Thomson Date: Tue, 29 Oct 2024 11:32:55 +1100 Subject: [PATCH 07/16] Update to use dep: syntax for dependencies --- bhttp/Cargo.toml | 6 +++--- ohttp/Cargo.toml | 10 +++++----- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/bhttp/Cargo.toml b/bhttp/Cargo.toml index 14aff2b..21776f9 100644 --- a/bhttp/Cargo.toml +++ b/bhttp/Cargo.toml @@ -9,9 +9,9 @@ description = "Binary HTTP messages (RFC 9292)" repository = "https://github.com/martinthomson/ohttp" [features] -default = ["stream"] -http = ["url"] -stream = ["futures", "pin-project"] +default = [] +http = ["dep:url"] +stream = ["dep:futures", "dep:pin-project"] [dependencies] futures = {version = "0.3", optional = true} diff --git a/ohttp/Cargo.toml b/ohttp/Cargo.toml index ffaabd5..237e3b1 100644 --- a/ohttp/Cargo.toml +++ b/ohttp/Cargo.toml @@ -14,11 +14,11 @@ default = ["client", "server", "rust-hpke"] app-svc = ["nss"] client = [] external-sqlite = [] -gecko = ["nss", "mozbuild"] -nss = ["bindgen", "regex-mess"] -pq = ["hpke-pq"] -regex-mess = ["regex", "regex-automata", "regex-syntax"] -rust-hpke = ["rand", "aead", "aes-gcm", "chacha20poly1305", "hkdf", "sha2", "hpke"] +gecko = ["nss", "dep:mozbuild"] +nss = ["dep:bindgen", "regex-mess"] +pq = ["dep:hpke-pq"] +regex-mess = ["dep:regex", "dep:regex-automata", "dep:regex-syntax"] +rust-hpke = ["dep:rand", "dep:aead", "dep:aes-gcm", "dep:chacha20poly1305", "dep:hkdf", "dep:sha2", "dep:hpke"] server = [] [dependencies] From ab494ddca7491ad2047779d8dd8ebfbe0432c830 Mon Sep 17 00:00:00 2001 From: Martin Thomson Date: Tue, 29 Oct 2024 13:24:20 +1100 Subject: [PATCH 08/16] Fix NSS dependency ordering; remove dead conditional code --- bhttp/src/err.rs | 3 --- ohttp/build.rs | 20 +++++++------------- 2 files changed, 7 insertions(+), 16 deletions(-) diff --git a/bhttp/src/err.rs b/bhttp/src/err.rs index 19d5455..d9d9b6f 100644 --- a/bhttp/src/err.rs +++ b/bhttp/src/err.rs @@ -6,9 +6,6 @@ pub enum Error { ConnectUnsupported, #[error("a field contained invalid Unicode: {0}")] CharacterEncoding(#[from] std::string::FromUtf8Error), - #[error("a chunk of data of {0} bytes is too large")] - #[cfg(feature = "stream")] - ChunkTooLarge(u64), #[error("read a response when expecting a request")] ExpectedRequest, #[error("read a request when expecting a response")] diff --git a/ohttp/build.rs b/ohttp/build.rs index 1c01e3f..312cce2 100644 --- a/ohttp/build.rs +++ b/ohttp/build.rs @@ -8,8 +8,6 @@ #[cfg(feature = "nss")] mod nss { - use bindgen::Builder; - use serde_derive::Deserialize; use std::{ collections::HashMap, env, fs, @@ -17,6 +15,9 @@ mod nss { process::Command, }; + use bindgen::Builder; + use serde_derive::Deserialize; + const BINDINGS_DIR: &str = "bindings"; const BINDINGS_CONFIG: &str = "bindings.toml"; @@ -114,7 +115,6 @@ mod nss { let mut build_nss = vec![ String::from("./build.sh"), String::from("-Ddisable_tests=1"), - String::from("-Denable_draft_hpke=1"), ]; if is_debug() { build_nss.push(String::from("--static")); @@ -191,16 +191,8 @@ mod nss { } fn static_link(nsslibdir: &Path, use_static_softoken: bool, use_static_nspr: bool) { - let mut static_libs = vec![ - "certdb", - "certhi", - "cryptohi", - "nss_static", - "nssb", - "nssdev", - "nsspki", - "nssutil", - ]; + // The ordering of these libraries is critical for the linker. 
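+        // Static archives are searched in the order they are given to the linker and
+        // only pull in objects that satisfy symbols which are already undefined, so
+        // each library must be listed before the libraries it depends on.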
+ let mut static_libs = vec!["cryptohi", "nss_static"]; let mut dynamic_libs = vec![]; if use_static_softoken { @@ -211,6 +203,8 @@ mod nss { static_libs.push("pk11wrap"); } + static_libs.extend_from_slice(&["nsspki", "nssdev", "nssb", "certhi", "certdb", "nssutil"]); + if use_static_nspr { static_libs.append(&mut nspr_libs()); } else { From 69a22b95843b8f50f831eeaad620e670664940f6 Mon Sep 17 00:00:00 2001 From: Martin Thomson Date: Tue, 29 Oct 2024 13:24:42 +1100 Subject: [PATCH 09/16] Improve varint codec test coverage --- bhttp/src/rw.rs | 84 ++++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 73 insertions(+), 11 deletions(-) diff --git a/bhttp/src/rw.rs b/bhttp/src/rw.rs index 92009ed..b081dea 100644 --- a/bhttp/src/rw.rs +++ b/bhttp/src/rw.rs @@ -8,12 +8,10 @@ use crate::{err::Error, ReadSeek}; #[cfg(feature = "write-bhttp")] #[allow(clippy::cast_possible_truncation)] -fn write_uint(n: u8, v: impl Into, w: &mut impl io::Write) -> Res<()> { - let v = v.into(); - assert!(n > 0 && usize::from(n) < std::mem::size_of::()); - for i in 0..n { - w.write_all(&[((v >> (8 * (n - i - 1))) & 0xff) as u8])?; - } +pub(crate) fn write_uint(v: impl Into, w: &mut impl io::Write) -> Res<()> { + let v = v.into().to_be_bytes(); + assert!((1..=std::mem::size_of::()).contains(&N)); + w.write_all(&v[std::mem::size_of::() - N..])?; Ok(()) } @@ -21,11 +19,11 @@ fn write_uint(n: u8, v: impl Into, w: &mut impl io::Write) -> Res<()> { pub fn write_varint(v: impl Into, w: &mut impl io::Write) -> Res<()> { let v = v.into(); match () { - () if v < (1 << 6) => write_uint(1, v, w), - () if v < (1 << 14) => write_uint(2, v | (1 << 14), w), - () if v < (1 << 30) => write_uint(4, v | (2 << 30), w), - () if v < (1 << 62) => write_uint(8, v | (3 << 62), w), - () => panic!("Varint value too large"), + () if v < (1 << 6) => write_uint::<1>(v, w), + () if v < (1 << 14) => write_uint::<2>(v | (1 << 14), w), + () if v < (1 << 30) => write_uint::<4>(v | (2 << 30), w), + () if v < (1 << 62) => write_uint::<8>(v | (3 << 62), w), + () => panic!("varint value too large"), } } @@ -106,3 +104,67 @@ where Ok(None) } } + +#[cfg(test)] +mod test { + use std::io::Cursor; + + use super::{read_varint, write_varint}; + use crate::{rw::read_vec, Error}; + + #[test] + fn basics() { + for i in [ + 0_u64, + 1, + 17, + 63, + 64, + 100, + 0x3fff, + 0x4000, + 0x1_0002, + 0x3fff_ffff, + 0x4000_0000, + 0x3456_dead_beef, + 0x3fff_ffff_ffff_ffff, + ] { + let mut buf = Vec::new(); + write_varint(i, &mut buf).unwrap(); + let sz_bytes = (64 - i.leading_zeros() + 2 + 7) / 8; // +2 size bits, +7 to round up + assert_eq!( + buf.len(), + usize::try_from(sz_bytes.next_power_of_two()).unwrap() + ); + + let o = read_varint(&mut Cursor::new(buf.clone())).unwrap(); + assert_eq!(Some(i), o); + + for cut in 1..buf.len() { + let e = read_varint(&mut Cursor::new(buf[..cut].to_vec())).unwrap_err(); + assert!(matches!(e, Error::Truncated)); + } + } + } + + #[test] + fn read_nothing() { + let o = read_varint(&mut Cursor::new(Vec::new())).unwrap(); + assert!(o.is_none()); + } + + #[test] + #[should_panic(expected = "varint value too large")] + fn too_big() { + _ = write_varint(0x4000_0000_0000_0000_u64, &mut Vec::new()); + } + + #[test] + fn too_big_vec() { + let mut buf = Vec::new(); + write_varint(10_u64, &mut buf).unwrap(); + buf.resize(10, 0); // Not enough extra for the promised length. 
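+        // write_varint(10) emits a single length byte, so the resized buffer holds
+        // only 9 of the 10 promised content bytes.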
+ let e = read_vec(&mut Cursor::new(buf.clone())).unwrap_err(); + assert!(matches!(e, Error::Truncated)); + } +} From fc800e300e17fe18b7ae7eacae9c3c2f98e789a4 Mon Sep 17 00:00:00 2001 From: Martin Thomson Date: Tue, 29 Oct 2024 13:25:45 +1100 Subject: [PATCH 10/16] Add mutants output to gitignore --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 200abe7..39ac3d3 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,4 @@ *~ *.swp /.vscode/ +/mutants.out*/ From 00cfe4f018b52c52f8ae5b00273bda4af12ef8d8 Mon Sep 17 00:00:00 2001 From: Martin Thomson Date: Tue, 29 Oct 2024 15:00:05 +1100 Subject: [PATCH 11/16] Merge main --- .gitignore | 1 + bhttp/src/err.rs | 3 --- bhttp/src/rw.rs | 68 ++++++++++++++++++++++++++++++++++++++++++++++-- 3 files changed, 67 insertions(+), 5 deletions(-) diff --git a/.gitignore b/.gitignore index 200abe7..39ac3d3 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,4 @@ *~ *.swp /.vscode/ +/mutants.out*/ diff --git a/bhttp/src/err.rs b/bhttp/src/err.rs index eb14acc..45d6fab 100644 --- a/bhttp/src/err.rs +++ b/bhttp/src/err.rs @@ -4,9 +4,6 @@ pub enum Error { ConnectUnsupported, #[error("a field contained invalid Unicode: {0}")] CharacterEncoding(#[from] std::string::FromUtf8Error), - #[error("a chunk of data of {0} bytes is too large")] - #[cfg(feature = "stream")] - ChunkTooLarge(u64), #[error("read a response when expecting a request")] ExpectedRequest, #[error("read a request when expecting a response")] diff --git a/bhttp/src/rw.rs b/bhttp/src/rw.rs index fa7c717..4659dd4 100644 --- a/bhttp/src/rw.rs +++ b/bhttp/src/rw.rs @@ -9,7 +9,7 @@ use crate::{ pub(crate) fn write_uint(v: impl Into, w: &mut impl io::Write) -> Res<()> { let v = v.into().to_be_bytes(); assert!((1..=std::mem::size_of::()).contains(&N)); - w.write_all(&v[8 - N..])?; + w.write_all(&v[std::mem::size_of::() - N..])?; Ok(()) } @@ -20,7 +20,7 @@ pub fn write_varint(v: impl Into, w: &mut impl io::Write) -> Res<()> { () if v < (1 << 14) => write_uint::<2>(v | (1 << 14), w), () if v < (1 << 30) => write_uint::<4>(v | (2 << 30), w), () if v < (1 << 62) => write_uint::<8>(v | (3 << 62), w), - () => panic!("Varint value too large"), + () => panic!("varint value too large"), } } @@ -92,3 +92,67 @@ where Ok(None) } } + +#[cfg(test)] +mod test { + use std::io::Cursor; + + use super::{read_varint, write_varint}; + use crate::{rw::read_vec, Error}; + + #[test] + fn basics() { + for i in [ + 0_u64, + 1, + 17, + 63, + 64, + 100, + 0x3fff, + 0x4000, + 0x1_0002, + 0x3fff_ffff, + 0x4000_0000, + 0x3456_dead_beef, + 0x3fff_ffff_ffff_ffff, + ] { + let mut buf = Vec::new(); + write_varint(i, &mut buf).unwrap(); + let sz_bytes = (64 - i.leading_zeros() + 2 + 7) / 8; // +2 size bits, +7 to round up + assert_eq!( + buf.len(), + usize::try_from(sz_bytes.next_power_of_two()).unwrap() + ); + + let o = read_varint(&mut Cursor::new(buf.clone())).unwrap(); + assert_eq!(Some(i), o); + + for cut in 1..buf.len() { + let e = read_varint(&mut Cursor::new(buf[..cut].to_vec())).unwrap_err(); + assert!(matches!(e, Error::Truncated)); + } + } + } + + #[test] + fn read_nothing() { + let o = read_varint(&mut Cursor::new(Vec::new())).unwrap(); + assert!(o.is_none()); + } + + #[test] + #[should_panic(expected = "varint value too large")] + fn too_big() { + _ = write_varint(0x4000_0000_0000_0000_u64, &mut Vec::new()); + } + + #[test] + fn too_big_vec() { + let mut buf = Vec::new(); + write_varint(10_u64, &mut buf).unwrap(); + buf.resize(10, 0); // Not enough extra for the promised length. 
+ let e = read_vec(&mut Cursor::new(buf.clone())).unwrap_err(); + assert!(matches!(e, Error::Truncated)); + } +} From e19de296dff3e35cae5b828f512af7bc2863fa1b Mon Sep 17 00:00:00 2001 From: Martin Thomson Date: Tue, 29 Oct 2024 15:40:19 +1100 Subject: [PATCH 12/16] Better formatting --- bhttp-convert/src/main.rs | 3 +- ohttp-client-cli/src/main.rs | 3 +- ohttp-client/src/main.rs | 3 +- ohttp/src/config.rs | 17 +++++----- ohttp/src/lib.rs | 29 ++++++++-------- ohttp/src/nss/aead.rs | 14 ++++---- ohttp/src/nss/err.rs | 3 +- ohttp/src/nss/hkdf.rs | 6 ++-- ohttp/src/nss/hpke.rs | 15 +++++---- ohttp/src/nss/mod.rs | 6 ++-- ohttp/src/nss/p11.rs | 5 +-- ohttp/src/rh/aead.rs | 8 +++-- ohttp/src/rh/hkdf.rs | 7 ++-- ohttp/src/rh/hpke.rs | 23 ++++++------- pre-commit | 65 +++++++++++++++++++++++++----------- 15 files changed, 123 insertions(+), 84 deletions(-) diff --git a/bhttp-convert/src/main.rs b/bhttp-convert/src/main.rs index 763fadf..054c64f 100644 --- a/bhttp-convert/src/main.rs +++ b/bhttp-convert/src/main.rs @@ -1,11 +1,12 @@ #![deny(warnings, clippy::pedantic)] -use bhttp::{Message, Mode}; use std::{ fs::File, io::{self, Read}, path::PathBuf, }; + +use bhttp::{Message, Mode}; use structopt::StructOpt; #[derive(Debug, StructOpt)] diff --git a/ohttp-client-cli/src/main.rs b/ohttp-client-cli/src/main.rs index 7acc0f1..37a948d 100644 --- a/ohttp-client-cli/src/main.rs +++ b/ohttp-client-cli/src/main.rs @@ -1,8 +1,9 @@ #![deny(warnings, clippy::pedantic)] +use std::io::{self, BufRead, Write}; + use bhttp::{Message, Mode}; use ohttp::{init, ClientRequest}; -use std::io::{self, BufRead, Write}; fn main() { init(); diff --git a/ohttp-client/src/main.rs b/ohttp-client/src/main.rs index 1d71dc2..4e8a431 100644 --- a/ohttp-client/src/main.rs +++ b/ohttp-client/src/main.rs @@ -1,7 +1,8 @@ #![deny(warnings, clippy::pedantic)] -use bhttp::{Message, Mode}; use std::{fs::File, io, io::Read, ops::Deref, path::PathBuf, str::FromStr}; + +use bhttp::{Message, Mode}; use structopt::StructOpt; type Res = Result>; diff --git a/ohttp/src/config.rs b/ohttp/src/config.rs index 9b55755..ebadb3c 100644 --- a/ohttp/src/config.rs +++ b/ohttp/src/config.rs @@ -1,24 +1,24 @@ -use crate::{ - err::{Error, Res}, - hpke::{Aead as AeadId, Kdf, Kem}, - KeyId, -}; -use byteorder::{NetworkEndian, ReadBytesExt, WriteBytesExt}; use std::{ convert::TryFrom, io::{BufRead, BufReader, Cursor, Read}, }; +use byteorder::{NetworkEndian, ReadBytesExt, WriteBytesExt}; + #[cfg(feature = "nss")] use crate::nss::{ hpke::{generate_key_pair, Config as HpkeConfig, HpkeR}, PrivateKey, PublicKey, }; - #[cfg(feature = "rust-hpke")] use crate::rh::hpke::{ derive_key_pair, generate_key_pair, Config as HpkeConfig, HpkeR, PrivateKey, PublicKey, }; +use crate::{ + err::{Error, Res}, + hpke::{Aead as AeadId, Kdf, Kem}, + KeyId, +}; /// A tuple of KDF and AEAD identifiers. 
#[derive(Debug, Copy, Clone, PartialEq, Eq)] @@ -270,11 +270,12 @@ impl AsRef for KeyConfig { #[cfg(test)] mod test { + use std::iter::zip; + use crate::{ hpke::{Aead, Kdf, Kem}, init, Error, KeyConfig, KeyId, SymmetricSuite, }; - use std::iter::zip; const KEY_ID: KeyId = 1; const KEM: Kem = Kem::X25519Sha256; diff --git a/ohttp/src/lib.rs b/ohttp/src/lib.rs index 38e3666..8c07fa3 100644 --- a/ohttp/src/lib.rs +++ b/ohttp/src/lib.rs @@ -15,17 +15,6 @@ mod rand; #[cfg(feature = "rust-hpke")] mod rh; -pub use crate::{ - config::{KeyConfig, SymmetricSuite}, - err::Error, -}; - -use crate::{ - err::Res, - hpke::{Aead as AeadId, Kdf, Kem}, -}; -use byteorder::{NetworkEndian, ReadBytesExt, WriteBytesExt}; -use log::trace; use std::{ cmp::max, convert::TryFrom, @@ -33,6 +22,9 @@ use std::{ mem::size_of, }; +use byteorder::{NetworkEndian, ReadBytesExt, WriteBytesExt}; +use log::trace; + #[cfg(feature = "nss")] use crate::nss::random; #[cfg(feature = "nss")] @@ -41,7 +33,6 @@ use crate::nss::{ hkdf::{Hkdf, KeyMechanism}, hpke::{Config as HpkeConfig, Exporter, HpkeR, HpkeS}, }; - #[cfg(feature = "rust-hpke")] use crate::rand::random; #[cfg(feature = "rust-hpke")] @@ -50,6 +41,14 @@ use crate::rh::{ hkdf::{Hkdf, KeyMechanism}, hpke::{Config as HpkeConfig, Exporter, HpkeR, HpkeS}, }; +pub use crate::{ + config::{KeyConfig, SymmetricSuite}, + err::Error, +}; +use crate::{ + err::Res, + hpke::{Aead as AeadId, Kdf, Kem}, +}; /// The request header is a `KeyId` and 2 each for KEM, KDF, and AEAD identifiers const REQUEST_HEADER_LEN: usize = size_of::() + 6; @@ -312,14 +311,16 @@ impl ClientResponse { #[cfg(all(test, feature = "client", feature = "server"))] mod test { + use std::{fmt::Debug, io::ErrorKind}; + + use log::trace; + use crate::{ config::SymmetricSuite, err::Res, hpke::{Aead, Kdf, Kem}, ClientRequest, Error, KeyConfig, KeyId, Server, }; - use log::trace; - use std::{fmt::Debug, io::ErrorKind}; const KEY_ID: KeyId = 1; const KEM: Kem = Kem::X25519Sha256; diff --git a/ohttp/src/nss/aead.rs b/ohttp/src/nss/aead.rs index 18f0b66..aecfc90 100644 --- a/ohttp/src/nss/aead.rs +++ b/ohttp/src/nss/aead.rs @@ -1,3 +1,11 @@ +use std::{ + convert::{TryFrom, TryInto}, + mem, + os::raw::c_int, +}; + +use log::trace; + use super::{ err::secstatus_to_res, p11::{ @@ -13,12 +21,6 @@ use crate::{ err::{Error, Res}, hpke::Aead as AeadId, }; -use log::trace; -use std::{ - convert::{TryFrom, TryInto}, - mem, - os::raw::c_int, -}; /// All the nonces are the same length. Exploit that. 
pub const NONCE_LEN: usize = 12; diff --git a/ohttp/src/nss/err.rs b/ohttp/src/nss/err.rs index af85066..bc7c86d 100644 --- a/ohttp/src/nss/err.rs +++ b/ohttp/src/nss/err.rs @@ -10,9 +10,10 @@ clippy::module_name_repetitions )] +use std::os::raw::c_char; + use super::{SECStatus, SECSuccess}; use crate::err::Res; -use std::os::raw::c_char; include!(concat!(env!("OUT_DIR"), "/nspr_error.rs")); mod codes { diff --git a/ohttp/src/nss/hkdf.rs b/ohttp/src/nss/hkdf.rs index 470b1dd..af38ba9 100644 --- a/ohttp/src/nss/hkdf.rs +++ b/ohttp/src/nss/hkdf.rs @@ -1,3 +1,7 @@ +use std::{convert::TryFrom, os::raw::c_int, ptr::null_mut}; + +use log::trace; + use super::{ super::hpke::{Aead, Kdf}, p11::{ @@ -10,8 +14,6 @@ use super::{ }, }; use crate::err::Res; -use log::trace; -use std::{convert::TryFrom, os::raw::c_int, ptr::null_mut}; #[derive(Clone, Copy)] pub enum KeyMechanism { diff --git a/ohttp/src/nss/hpke.rs b/ohttp/src/nss/hpke.rs index b7ef845..16d018d 100644 --- a/ohttp/src/nss/hpke.rs +++ b/ohttp/src/nss/hpke.rs @@ -1,10 +1,3 @@ -use super::{ - super::hpke::{Aead, Kdf, Kem}, - err::{sec::SEC_ERROR_INVALID_ARGS, secstatus_to_res, Error}, - p11::{sys, Item, PrivateKey, PublicKey, Slot, SymKey}, -}; -use crate::err::Res; -use log::{log_enabled, trace}; use std::{ convert::TryFrom, ops::Deref, @@ -12,8 +5,16 @@ use std::{ ptr::{addr_of_mut, null, null_mut}, }; +use log::{log_enabled, trace}; pub use sys::{HpkeAeadId as AeadId, HpkeKdfId as KdfId, HpkeKemId as KemId}; +use super::{ + super::hpke::{Aead, Kdf, Kem}, + err::{sec::SEC_ERROR_INVALID_ARGS, secstatus_to_res, Error}, + p11::{sys, Item, PrivateKey, PublicKey, Slot, SymKey}, +}; +use crate::err::Res; + /// Configuration for `Hpke`. #[derive(Clone, Copy)] pub struct Config { diff --git a/ohttp/src/nss/mod.rs b/ohttp/src/nss/mod.rs index 1b60c9e..c9d4c7e 100644 --- a/ohttp/src/nss/mod.rs +++ b/ohttp/src/nss/mod.rs @@ -11,11 +11,13 @@ pub mod aead; pub mod hkdf; pub mod hpke; -pub use self::p11::{random, PrivateKey, PublicKey}; +use std::ptr::null; + use err::secstatus_to_res; pub use err::Error; use lazy_static::lazy_static; -use std::ptr::null; + +pub use self::p11::{random, PrivateKey, PublicKey}; #[allow(clippy::pedantic, non_upper_case_globals, clippy::upper_case_acronyms)] mod nss_init { diff --git a/ohttp/src/nss/p11.rs b/ohttp/src/nss/p11.rs index 9bddef6..19d9420 100644 --- a/ohttp/src/nss/p11.rs +++ b/ohttp/src/nss/p11.rs @@ -4,8 +4,6 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -use super::err::{secstatus_to_res, Error}; -use crate::err::Res; use std::{ convert::TryFrom, marker::PhantomData, @@ -14,6 +12,9 @@ use std::{ ptr::null_mut, }; +use super::err::{secstatus_to_res, Error}; +use crate::err::Res; + #[allow( clippy::pedantic, clippy::upper_case_acronyms, diff --git a/ohttp/src/rh/aead.rs b/ohttp/src/rh/aead.rs index f209086..9bfe13a 100644 --- a/ohttp/src/rh/aead.rs +++ b/ohttp/src/rh/aead.rs @@ -1,11 +1,13 @@ #![allow(dead_code)] // TODO: remove -use super::SymKey; -use crate::{err::Res, hpke::Aead as AeadId}; +use std::convert::TryFrom; + use aead::{AeadMut, Key, NewAead, Nonce, Payload}; use aes_gcm::{Aes128Gcm, Aes256Gcm}; use chacha20poly1305::ChaCha20Poly1305; -use std::convert::TryFrom; + +use super::SymKey; +use crate::{err::Res, hpke::Aead as AeadId}; /// All the nonces are the same length. Exploit that. 
pub const NONCE_LEN: usize = 12; diff --git a/ohttp/src/rh/hkdf.rs b/ohttp/src/rh/hkdf.rs index aeb3a8d..a88f60e 100644 --- a/ohttp/src/rh/hkdf.rs +++ b/ohttp/src/rh/hkdf.rs @@ -1,13 +1,14 @@ #![allow(dead_code)] // TODO: remove +use hkdf::Hkdf as HkdfImpl; +use log::trace; +use sha2::{Sha256, Sha384, Sha512}; + use super::SymKey; use crate::{ err::{Error, Res}, hpke::{Aead, Kdf}, }; -use hkdf::Hkdf as HkdfImpl; -use log::trace; -use sha2::{Sha256, Sha384, Sha512}; #[derive(Clone, Copy)] pub enum KeyMechanism { diff --git a/ohttp/src/rh/hpke.rs b/ohttp/src/rh/hpke.rs index 4b81152..ce8ee78 100644 --- a/ohttp/src/rh/hpke.rs +++ b/ohttp/src/rh/hpke.rs @@ -1,15 +1,13 @@ -use super::SymKey; -use crate::{ - hpke::{Aead, Kdf, Kem}, - Error, Res, -}; +use std::ops::Deref; #[cfg(not(feature = "pq"))] use ::hpke as rust_hpke; - #[cfg(feature = "pq")] use ::hpke_pq as rust_hpke; - +use ::rand::thread_rng; +use log::trace; +#[cfg(feature = "pq")] +use rust_hpke::kem::X25519Kyber768Draft00; use rust_hpke::{ aead::{AeadCtxR, AeadCtxS, AeadTag, AesGcm128, ChaCha20Poly1305}, kdf::HkdfSha256, @@ -17,12 +15,11 @@ use rust_hpke::{ setup_receiver, setup_sender, Deserializable, OpModeR, OpModeS, Serializable, }; -#[cfg(feature = "pq")] -use rust_hpke::kem::X25519Kyber768Draft00; - -use ::rand::thread_rng; -use log::trace; -use std::ops::Deref; +use super::SymKey; +use crate::{ + hpke::{Aead, Kdf, Kem}, + Error, Res, +}; /// Configuration for `Hpke`. #[derive(Clone, Copy)] diff --git a/pre-commit b/pre-commit index 758b923..18e917d 100755 --- a/pre-commit +++ b/pre-commit @@ -6,10 +6,12 @@ # $ ln -s ../../hooks/pre-commit .git/hooks/pre-commit root="$(git rev-parse --show-toplevel 2>/dev/null)" +RUST_FMT_CFG="imports_granularity=Crate,group_imports=StdExternalCrate" # Some sanity checking. -hash cargo || exit 1 -[[ -n "$root" ]] || exit 1 +set -e +hash cargo +[[ -n "$root" ]] # Installation. if [[ "$1" == "install" ]]; then @@ -23,31 +25,54 @@ if [[ "$1" == "install" ]]; then exit fi -# Check formatting. +# Stash unstaged changes if [[ "$1" != "all" ]]; then - msg="pre-commit stash @$(git rev-parse --short @) $RANDOM" - trap 'git stash list -1 --format="format:%s" | grep -q "'"$msg"'" && git stash pop -q' EXIT - git stash push -k -u -q -m "$msg" + stashdir="$(mktemp -d "$root"/.pre-commit.stashXXXXXX)" + msg="pre-commit stash @$(git rev-parse --short @) ${stashdir##*.stash}" + gitdir="$(git rev-parse --git-dir 2>/dev/null)" + + stash() { + # Move MERGE_[HEAD|MODE|MSG] files to the root directory, and let `git stash push` save them. + find "$gitdir" -maxdepth 1 -name 'MERGE_*' -exec mv \{\} "$stashdir" \; + git stash push -k -u -q -m "$msg" + } + + unstash() { + git stash list -1 --format="format:%s" | grep -q "$msg" && git stash pop -q + # Moves MERGE files restored by `git stash pop` back into .git/ directory. + if [[ -d "$stashdir" ]]; then + find "$stashdir" -exec mv -n \{\} "$gitdir" \; + rmdir "$stashdir" + fi + } + + trap unstash EXIT + stash fi -if ! errors=($(cargo fmt -- --check --config imports_granularity=crate -l)); then - echo "Formatting errors found." - echo "Run \`cargo fmt\` to fix the following files:" + +# Check formatting +if ! errors=($(cargo fmt -- --check --config "$RUST_FMT_CFG" -l)); then + echo "Formatting errors found in:" for err in "${errors[@]}"; do echo " $err" done + echo "To fix, run \`cargo fmt -- --config $RUST_FMT_CFG\`" exit 1 fi -if ! cargo clippy --tests; then - exit 1 -fi -if ! cargo test; then - exit 1 -fi -if [[ -n "$NSS_DIR" ]]; then - if ! 
cargo clippy --tests --no-default-features --features nss; then - exit 1 - fi - if ! cargo test --no-default-features --features nss; then + +check() { + msg="$1" + shift + if ! echo "$@"; then + echo "${msg}: Failed command:" + echo " ${@@Q}" exit 1 fi +} + +check "clippy" cargo clippy --tests +check "test" cargo test +if [[ -n "$NSS_DIR" ]]; then + check "clippy(NSS)" cargo clippy --tests --no-default-features --features nss + check "test(NSS)" cargo test --no-default-features --features nss fi From c91933c7c21e066ed313b40d87e83af0a63612eb Mon Sep 17 00:00:00 2001 From: Martin Thomson Date: Sat, 2 Nov 2024 10:34:24 +0000 Subject: [PATCH 13/16] Refactor stream helper functions --- Cargo.toml | 2 +- bhttp/Cargo.toml | 5 ++++- bhttp/src/stream/int.rs | 7 +++---- sync-async/Cargo.toml | 12 ++++++++++++ bhttp/src/stream/future.rs => sync-async/src/lib.rs | 11 ++++++----- 5 files changed, 26 insertions(+), 11 deletions(-) create mode 100644 sync-async/Cargo.toml rename bhttp/src/stream/future.rs => sync-async/src/lib.rs (92%) diff --git a/Cargo.toml b/Cargo.toml index 0621518..fb7825e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,5 +6,5 @@ members = [ "ohttp", "ohttp-client", "ohttp-client-cli", - "ohttp-server", + "ohttp-server", "sync-async", ] diff --git a/bhttp/Cargo.toml b/bhttp/Cargo.toml index 21776f9..fe3d325 100644 --- a/bhttp/Cargo.toml +++ b/bhttp/Cargo.toml @@ -20,4 +20,7 @@ thiserror = "1" url = {version = "2", optional = true} [dev-dependencies] -hex = "0.4" \ No newline at end of file +hex = "0.4" + +[dev-dependencies.sync-async] +path= "../sync-async" diff --git a/bhttp/src/stream/int.rs b/bhttp/src/stream/int.rs index 5b248df..f70231b 100644 --- a/bhttp/src/stream/int.rs +++ b/bhttp/src/stream/int.rs @@ -135,13 +135,12 @@ pub fn read_varint(src: S) -> ReadVarint { #[cfg(test)] mod test { + use sync_async::SyncResolve; + use crate::{ err::Error, rw::{write_uint as sync_write_uint, write_varint as sync_write_varint}, - stream::{ - future::SyncResolve, - int::{read_uint, read_varint}, - }, + stream::int::{read_uint, read_varint}, }; const VARINTS: &[u64] = &[ diff --git a/sync-async/Cargo.toml b/sync-async/Cargo.toml new file mode 100644 index 0000000..c419df1 --- /dev/null +++ b/sync-async/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "sync-async" +version = "0.5.3" +authors = ["Martin Thomson "] +edition = "2021" +license = "MIT OR Apache-2.0" +description = "Synchronous Helpers for Async Code" +repository = "https://github.com/martinthomson/ohttp" + +[dependencies] +futures = "0.3" +pin-project = "1.1" \ No newline at end of file diff --git a/bhttp/src/stream/future.rs b/sync-async/src/lib.rs similarity index 92% rename from bhttp/src/stream/future.rs rename to sync-async/src/lib.rs index 9d0362c..e21ca07 100644 --- a/bhttp/src/stream/future.rs +++ b/sync-async/src/lib.rs @@ -6,8 +6,6 @@ use std::{ use futures::{AsyncRead, AsyncReadExt, TryStream, TryStreamExt}; -use crate::Error; - fn noop_context() -> Context<'static> { use std::{ ptr::null, @@ -26,6 +24,7 @@ fn noop_context() -> Context<'static> { } pub fn noop_waker_ref() -> &'static Waker { + #[repr(transparent)] struct SyncRawWaker(RawWaker); unsafe impl Sync for SyncRawWaker {} @@ -72,14 +71,16 @@ impl SyncResolve for F { pub trait SyncCollect { type Item; + type Error; - fn sync_collect(self) -> Result, Error>; + fn sync_collect(self) -> Result, Self::Error>; } -impl> SyncCollect for S { +impl SyncCollect for S { type Item = S::Ok; + type Error = S::Error; - fn sync_collect(self) -> Result, Error> { + fn 
sync_collect(self) -> Result, Self::Error> { pin!(self.try_collect::>()).sync_resolve() } } From 880d958e94341eefa01e7cb38cffd98ed8630c28 Mon Sep 17 00:00:00 2001 From: Martin Thomson Date: Sat, 2 Nov 2024 10:51:42 +0000 Subject: [PATCH 14/16] Move to use OnceLock --- ohttp/Cargo.toml | 3 ++- ohttp/src/nss/mod.rs | 23 ++++++++++------------- 2 files changed, 12 insertions(+), 14 deletions(-) diff --git a/ohttp/Cargo.toml b/ohttp/Cargo.toml index 237e3b1..5e8c484 100644 --- a/ohttp/Cargo.toml +++ b/ohttp/Cargo.toml @@ -20,6 +20,7 @@ pq = ["dep:hpke-pq"] regex-mess = ["dep:regex", "dep:regex-automata", "dep:regex-syntax"] rust-hpke = ["dep:rand", "dep:aead", "dep:aes-gcm", "dep:chacha20poly1305", "dep:hkdf", "dep:sha2", "dep:hpke"] server = [] +stream = [] [dependencies] aead = {version = "0.4", optional = true, features = ["std"]} @@ -29,7 +30,6 @@ chacha20poly1305 = {version = "0.8", optional = true} hex = "0.4" hkdf = {version = "0.11", optional = true} hpke = {version = "0.11.0", optional = true, default-features = false, features = ["std", "x25519"]} -lazy_static = "1.4" log = {version = "0.4", default-features = false} rand = {version = "0.8", optional = true} # bindgen uses regex and friends, which have been updated past our MSRV @@ -64,3 +64,4 @@ features = ["runtime"] [dev-dependencies] env_logger = {version = "0.10", default-features = false} +sync-async = {path = "../sync-async"} diff --git a/ohttp/src/nss/mod.rs b/ohttp/src/nss/mod.rs index c9d4c7e..c282a13 100644 --- a/ohttp/src/nss/mod.rs +++ b/ohttp/src/nss/mod.rs @@ -15,7 +15,6 @@ use std::ptr::null; use err::secstatus_to_res; pub use err::Error; -use lazy_static::lazy_static; pub use self::p11::{random, PrivateKey, PublicKey}; @@ -47,17 +46,7 @@ impl Drop for NssLoaded { } } -lazy_static! { - static ref INITIALIZED: NssLoaded = { - if already_initialized() { - return NssLoaded::External; - } - - secstatus_to_res(unsafe { nss_init::NSS_NoDB_Init(null()) }).expect("NSS_NoDB_Init failed"); - - NssLoaded::NoDb - }; -} +static INITIALIZED: OnceLock = OnceLock::new(); fn already_initialized() -> bool { unsafe { nss_init::NSS_IsInitialized() != 0 } @@ -65,5 +54,13 @@ fn already_initialized() -> bool { /// Initialize NSS. This only executes the initialization routines once. 
pub fn init() { - lazy_static::initialize(&INITIALIZED); + INITIALIZED.get_or_init(|| { + if already_initialized() { + NssLoaded::External + } else { + secstatus_to_res(unsafe { nss_init::NSS_NoDB_Init(null()) }) + .expect("NSS_NoDB_Init failed"); + NssLoaded::NoDb + } + }); } From 69402f463e8d041aecbb6a08d6196e35620febd6 Mon Sep 17 00:00:00 2001 From: Martin Thomson Date: Mon, 11 Nov 2024 12:01:39 +0000 Subject: [PATCH 15/16] Checkpoint for request streaming --- ohttp/Cargo.toml | 6 +++-- ohttp/src/lib.rs | 68 ++++++++++++++++++++++++++++++++---------------- 2 files changed, 50 insertions(+), 24 deletions(-) diff --git a/ohttp/Cargo.toml b/ohttp/Cargo.toml index 5e8c484..01a2027 100644 --- a/ohttp/Cargo.toml +++ b/ohttp/Cargo.toml @@ -10,7 +10,7 @@ description = "Oblivious HTTP" repository = "https://github.com/martinthomson/ohttp" [features] -default = ["client", "server", "rust-hpke"] +default = ["client", "server", "rust-hpke", "stream"] app-svc = ["nss"] client = [] external-sqlite = [] @@ -20,17 +20,19 @@ pq = ["dep:hpke-pq"] regex-mess = ["dep:regex", "dep:regex-automata", "dep:regex-syntax"] rust-hpke = ["dep:rand", "dep:aead", "dep:aes-gcm", "dep:chacha20poly1305", "dep:hkdf", "dep:sha2", "dep:hpke"] server = [] -stream = [] +stream = ["dep:futures", "dep:pin-project"] [dependencies] aead = {version = "0.4", optional = true, features = ["std"]} aes-gcm = {version = "0.9", optional = true} byteorder = "1.4" chacha20poly1305 = {version = "0.8", optional = true} +futures = {version = "0.3", optional = true} hex = "0.4" hkdf = {version = "0.11", optional = true} hpke = {version = "0.11.0", optional = true, default-features = false, features = ["std", "x25519"]} log = {version = "0.4", default-features = false} +pin-project = {version = "1.1", optional = true} rand = {version = "0.8", optional = true} # bindgen uses regex and friends, which have been updated past our MSRV # however, the cargo resolver happily resolves versions that it can't compile diff --git a/ohttp/src/lib.rs b/ohttp/src/lib.rs index 8c07fa3..0ce3e79 100644 --- a/ohttp/src/lib.rs +++ b/ohttp/src/lib.rs @@ -14,6 +14,8 @@ mod nss; mod rand; #[cfg(feature = "rust-hpke")] mod rh; +#[cfg(feature = "stream")] +mod stream; use std::{ cmp::max, @@ -23,7 +25,10 @@ use std::{ }; use byteorder::{NetworkEndian, ReadBytesExt, WriteBytesExt}; +#[cfg(feature = "stream")] +use futures::AsyncRead; use log::trace; +use rh::hpke::PublicKey; #[cfg(feature = "nss")] use crate::nss::random; @@ -41,6 +46,8 @@ use crate::rh::{ hkdf::{Hkdf, KeyMechanism}, hpke::{Config as HpkeConfig, Exporter, HpkeR, HpkeS}, }; +#[cfg(feature = "stream")] +use crate::stream::ClientRequestStream; pub use crate::{ config::{KeyConfig, SymmetricSuite}, err::Error, @@ -53,8 +60,6 @@ use crate::{ /// The request header is a `KeyId` and 2 each for KEM, KDF, and AEAD identifiers const REQUEST_HEADER_LEN: usize = size_of::() + 6; const INFO_REQUEST: &[u8] = b"message/bhttp request"; -/// The info used for HPKE export is `INFO_REQUEST`, a zero byte, and the header. -const INFO_LEN: usize = INFO_REQUEST.len() + 1 + REQUEST_HEADER_LEN; const LABEL_RESPONSE: &[u8] = b"message/bhttp response"; const INFO_KEY: &[u8] = b"key"; const INFO_NONCE: &[u8] = b"nonce"; @@ -68,9 +73,9 @@ pub fn init() { } /// Construct the info parameter we use to initialize an `HpkeS` instance. 
-fn build_info(key_id: KeyId, config: HpkeConfig) -> Res> { - let mut info = Vec::with_capacity(INFO_LEN); - info.extend_from_slice(INFO_REQUEST); +fn build_info(label: &[u8], key_id: KeyId, config: HpkeConfig) -> Res> { + let mut info = Vec::with_capacity(label.len() + 1 + REQUEST_HEADER_LEN); + info.extend_from_slice(label); info.push(0); info.write_u8(key_id)?; info.write_u16::(u16::from(config.kem()))?; @@ -84,8 +89,9 @@ fn build_info(key_id: KeyId, config: HpkeConfig) -> Res> { /// This might not be necessary if we agree on a format. #[cfg(feature = "client")] pub struct ClientRequest { - hpke: HpkeS, - header: Vec, + key_id: KeyId, + config: HpkeConfig, + pk: PublicKey, } #[cfg(feature = "client")] @@ -94,14 +100,11 @@ impl ClientRequest { pub fn from_config(config: &mut KeyConfig) -> Res { // TODO(mt) choose the best config, not just the first. let selected = config.select(config.symmetric[0])?; - - // Build the info, which contains the message header. - let info = build_info(config.key_id, selected)?; - let hpke = HpkeS::new(selected, &mut config.pk, &info)?; - - let header = Vec::from(&info[INFO_REQUEST.len() + 1..]); - debug_assert_eq!(header.len(), REQUEST_HEADER_LEN); - Ok(Self { hpke, header }) + Ok(Self { + key_id: config.key_id, + config: selected, + pk: config.pk.clone(), + }) } /// Reads an encoded configuration and constructs a single use client sender. @@ -126,21 +129,41 @@ impl ClientRequest { /// Encapsulate a request. This consumes this object. /// This produces a response handler and the bytes of an encapsulated request. pub fn encapsulate(mut self, request: &[u8]) -> Res<(Vec, ClientResponse)> { - let extra = - self.hpke.config().kem().n_enc() + self.hpke.config().aead().n_t() + request.len(); - let expected_len = self.header.len() + extra; + // Build the info, which contains the message header. 
+ let info = build_info(INFO_REQUEST, self.key_id, self.config)?; + let mut hpke = HpkeS::new(self.config, &mut self.pk, &info)?; - let mut enc_request = self.header; + let header = Vec::from(&info[INFO_REQUEST.len() + 1..]); + debug_assert_eq!(header.len(), REQUEST_HEADER_LEN); + + let extra = hpke.config().kem().n_enc() + hpke.config().aead().n_t() + request.len(); + let expected_len = header.len() + extra; + + let mut enc_request = header; enc_request.reserve_exact(extra); - let enc = self.hpke.enc()?; + let enc = hpke.enc()?; enc_request.extend_from_slice(&enc); - let mut ct = self.hpke.seal(&[], request)?; + let mut ct = hpke.seal(&[], request)?; enc_request.append(&mut ct); debug_assert_eq!(expected_len, enc_request.len()); - Ok((enc_request, ClientResponse::new(self.hpke, enc))) + Ok((enc_request, ClientResponse::new(hpke, enc))) + } + + #[cfg(feature = "stream")] + pub fn encapsulate_stream(mut self, src: S) -> Res> { + let info = build_info(crate::stream::INFO_REQUEST, self.key_id, self.config)?; + let hpke = HpkeS::new(self.config, &mut self.pk, &info)?; + + let mut header = Vec::from(&info[crate::stream::INFO_REQUEST.len() + 1..]); + debug_assert_eq!(header.len(), REQUEST_HEADER_LEN); + + let mut e = hpke.enc()?; + header.append(&mut e); + + Ok(ClientRequestStream::new(src, hpke, header)) } } @@ -191,6 +214,7 @@ impl Server { let sym = SymmetricSuite::new(kdf_id, aead_id); let info = build_info( + INFO_REQUEST, key_id, HpkeConfig::new(self.config.kem, sym.kdf(), sym.aead()), )?; From 7222a26f7ddbedd53690252893d933963501da2d Mon Sep 17 00:00:00 2001 From: Martin Thomson Date: Fri, 13 Dec 2024 05:41:28 +0000 Subject: [PATCH 16/16] Add missing file --- ohttp/src/stream.rs | 110 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 110 insertions(+) create mode 100644 ohttp/src/stream.rs diff --git a/ohttp/src/stream.rs b/ohttp/src/stream.rs new file mode 100644 index 0000000..1f9bd90 --- /dev/null +++ b/ohttp/src/stream.rs @@ -0,0 +1,110 @@ +use std::{ + cmp::min, + io::{Error as IoError, Result as IoResult}, + pin::Pin, + task::{Context, Poll}, +}; + +use futures::AsyncRead; + +use crate::HpkeS; + +pub(crate) const INFO_REQUEST: &[u8] = b"message/bhttp chunked request"; + +fn write_len(w: &mut [u8], len: usize) -> usize { + let v: u64 = len.try_into().unwrap(); + let (v, len) = match () { + () if v < (1 << 6) => (v, 1), + () if v < (1 << 14) => (v | 1 << 14, 2), + () if v < (1 << 30) => (v | (2 << 30), 4), + () if v < (1 << 62) => (v | (3 << 62), 8), + () => panic!("varint value too large"), + }; + w[..len].copy_from_slice(&v.to_be_bytes()[(8 - len)..]); + len +} + +#[pin_project::pin_project] +pub struct ClientRequestStream { + #[pin] + src: S, + hpke: HpkeS, + buf: Vec, +} + +impl ClientRequestStream { + pub fn new(src: S, hpke: HpkeS, header: Vec) -> Self { + Self { + src, + hpke, + buf: header, + } + } +} + +impl AsyncRead for ClientRequestStream { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + mut buf: &mut [u8], + ) -> Poll> { + let this = self.project(); + // We have buffered data, so dump it into the output directly. + let mut written = if this.buf.is_empty() { + 0 + } else { + let amnt = min(this.buf.len(), buf.len()); + buf[..amnt].copy_from_slice(&this.buf[..amnt]); + buf = &mut buf[amnt..]; + *this.buf = this.buf.split_off(amnt); + if buf.is_empty() { + return Poll::Ready(Ok(amnt)); + } + amnt + }; + + // Now read into the buffer. 
+ // Because we are expanding the data, when the buffer we are provided is too small, + // we have to use a temporary buffer so that we can save some bytes. + let mut tmp = [0; 64]; + let read_buf = if buf.len() < tmp.len() { + // Use the provided buffer, but leave room for AEAD tag and a varint. + let read_len = min(buf.len(), 1 << 62) - this.hpke.aead().n_t(); + &mut buf[8..read_len] + } else { + &mut tmp[..] + }; + let (aad, len): (&[u8], _) = match this.src.poll_read(cx, read_buf) { + Poll::Pending => return Poll::Pending, + Poll::Ready(Ok(0)) => (&b"final"[..], 0), + Poll::Ready(Ok(len)) => (&[], len), + Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), + }; + + let ct = this + .hpke + .seal(aad, &mut read_buf[..len]) + .map_err(IoError::other)?; + + // Now we need to write the length of the chunk. + let len_len = write_len(&mut tmp, ct.len()); + if len_len <= buf.len() { + // If the length fits in the buffer, that's easy. + buf[..len_len].copy_from_slice(&tmp[..len_len]); + written += len_len; + buf = &mut buf[len_len..]; + } else { + // Otherwise, we need to save any remainder in our own buffer. + buf.copy_from_slice(&tmp[..buf.len()]); + this.buf.extend_from_slice(&tmp[buf.len()..len_len]); + let amnt = buf.len(); + written += amnt; + buf = &mut buf[amnt..]; + } + + let amnt = min(ct.len(), buf.len()); + buf[..amnt].copy_from_slice(&ct[..amnt]); + this.buf.extend_from_slice(&ct[amnt..]); + Poll::Ready(Ok(amnt + written)) + } +}
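For orientation, here is a minimal sketch (not part of any patch above) of how the streaming request path introduced in patches 15 and 16 might be driven end to end. It assumes the `ClientRequest::from_config` and `encapsulate_stream` entry points shown above; the `encapsulate_chunked` helper, the `futures` executor, the `Cursor` source, and the error handling are all illustrative choices, not code from this series.

// Hypothetical usage sketch for the streaming encapsulation API added above.
// Assumes `key_config` is a valid ohttp::KeyConfig and that the "client" and
// "stream" features are enabled; names and error boxing are placeholders.
use futures::{executor::block_on, io::Cursor, AsyncReadExt};
use ohttp::{ClientRequest, KeyConfig};

fn encapsulate_chunked(
    mut key_config: KeyConfig,
    request: &[u8],
) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
    // Select a symmetric suite and capture the key configuration, as in patch 15.
    let client = ClientRequest::from_config(&mut key_config)?;
    // Any AsyncRead source works; a Cursor stands in for a real request body stream.
    let enc = client.encapsulate_stream(Cursor::new(request.to_vec()))?;
    // The returned reader yields the request header followed by length-prefixed,
    // sealed chunks; drain it synchronously here just for the sketch.
    let mut enc_request = Vec::new();
    block_on(enc.read_to_end(&mut enc_request))?;
    Ok(enc_request)
}

In a real caller the source would be an asynchronous body (for example, a network stream) and the output would be forwarded chunk by chunk rather than collected into a single Vec, which is the point of the streaming variant over `encapsulate`.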