Skip to content

Commit

Permalink
abstract raw byte readers from strict readers and io::Read
Browse files Browse the repository at this point in the history
  • Loading branch information
dr-orlovsky committed Feb 3, 2024
1 parent e83ccc4 commit df81691
Show file tree
Hide file tree
Showing 6 changed files with 120 additions and 88 deletions.
34 changes: 24 additions & 10 deletions rust/src/embedded.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,16 @@ use amplify::{Array, Wrapper};
use crate::constants::*;
use crate::stl::AsciiSym;
use crate::{
DecodeError, DefineUnion, ReadTuple, ReadUnion, RestrictedCharacter, RestrictedString, Sizing,
StrictDecode, StrictDumb, StrictEncode, StrictProduct, StrictStruct, StrictSum, StrictTuple,
StrictType, StrictUnion, TypeName, TypedRead, TypedWrite, WriteTuple, WriteUnion, LIB_EMBEDDED,
DecodeError, DefineUnion, ReadRaw, ReadTuple, ReadUnion, RestrictedCharacter, RestrictedString,
Sizing, StrictDecode, StrictDumb, StrictEncode, StrictProduct, StrictStruct, StrictSum,
StrictTuple, StrictType, StrictUnion, TypeName, TypedRead, TypedWrite, WriteTuple, WriteUnion,
LIB_EMBEDDED,
};

pub trait DecodeRawLe: Sized {
fn decode_raw_le(reader: &mut (impl ReadRaw + ?Sized)) -> Result<Self, DecodeError>;
}

#[derive(
Wrapper, WrapperMut, Copy, Clone, Ord, PartialOrd, Eq, PartialEq, Hash, Debug, Default, From
)]
Expand Down Expand Up @@ -68,10 +73,15 @@ macro_rules! encode_num {
}
}
}
impl $crate::DecodeRawLe for $ty {
fn decode_raw_le(reader: &mut (impl ReadRaw + ?Sized)) -> Result<Self, DecodeError> {
let buf = reader.read_raw_array::<{ Self::BITS as usize / 8 }>()?;
Ok(Self::from_le_bytes(buf))
}
}
impl $crate::StrictDecode for $ty {
fn strict_decode(reader: &mut impl TypedRead) -> Result<Self, DecodeError> {
let buf = unsafe { reader._read_raw_array::<{ Self::BITS as usize / 8 }>()? };
Ok(Self::from_le_bytes(buf))
Self::decode_raw_le(unsafe { reader.raw_reader() })
}
}
};
Expand All @@ -93,7 +103,11 @@ macro_rules! encode_nonzero {
}
impl $crate::StrictDecode for $ty {
fn strict_decode(reader: &mut impl TypedRead) -> Result<Self, DecodeError> {
let buf = unsafe { reader._read_raw_array::<{ Self::BITS as usize / 8 }>()? };
let buf = unsafe {
reader
.raw_reader()
.read_raw_array::<{ Self::BITS as usize / 8 }>()?

Check warning on line 109 in rust/src/embedded.rs

View check run for this annotation

Codecov / codecov/patch

rust/src/embedded.rs#L106-L109

Added lines #L106 - L109 were not covered by tests
};
let v = <$p>::from_le_bytes(buf);
Self::new(v).ok_or(DecodeError::ZeroNatural)
}
Expand All @@ -120,7 +134,7 @@ macro_rules! encode_float {
fn strict_decode(reader: &mut impl TypedRead) -> Result<Self, DecodeError> {
const BYTES: usize = <$ty>::BITS / 8;
let mut inner = [0u8; 32];
let buf = unsafe { reader._read_raw_array::<BYTES>()? };
let buf = unsafe { reader.raw_reader().read_raw_array::<BYTES>()? };

Check warning on line 137 in rust/src/embedded.rs

View check run for this annotation

Codecov / codecov/patch

rust/src/embedded.rs#L137

Added line #L137 was not covered by tests
inner[..BYTES].copy_from_slice(&buf[..]);
let bits = u256::from_le_bytes(inner);
Ok(Self::from_bits(bits))
Expand Down Expand Up @@ -474,7 +488,7 @@ impl<T: StrictDecode, const MIN_LEN: usize, const MAX_LEN: usize> StrictDecode
for Confined<Vec<T>, MIN_LEN, MAX_LEN>
{
fn strict_decode(reader: &mut impl TypedRead) -> Result<Self, DecodeError> {
let len = unsafe { reader._read_raw_len::<MAX_LEN>()? };
let len = unsafe { reader.raw_reader().read_raw_len::<MAX_LEN>()? };

Check warning on line 491 in rust/src/embedded.rs

View check run for this annotation

Codecov / codecov/patch

rust/src/embedded.rs#L491

Added line #L491 was not covered by tests
let mut col = Vec::<T>::with_capacity(len);
for _ in 0..len {
col.push(StrictDecode::strict_decode(reader)?);
Expand Down Expand Up @@ -505,7 +519,7 @@ impl<T: StrictDecode + Ord, const MIN_LEN: usize, const MAX_LEN: usize> StrictDe
for Confined<BTreeSet<T>, MIN_LEN, MAX_LEN>
{
fn strict_decode(reader: &mut impl TypedRead) -> Result<Self, DecodeError> {
let len = unsafe { reader._read_raw_len::<MAX_LEN>()? };
let len = unsafe { reader.raw_reader().read_raw_len::<MAX_LEN>()? };

Check warning on line 522 in rust/src/embedded.rs

View check run for this annotation

Codecov / codecov/patch

rust/src/embedded.rs#L522

Added line #L522 was not covered by tests
let mut col = BTreeSet::<T>::new();
for _ in 0..len {
let item = StrictDecode::strict_decode(reader)?;
Expand Down Expand Up @@ -558,7 +572,7 @@ impl<
> StrictDecode for Confined<BTreeMap<K, V>, MIN_LEN, MAX_LEN>
{
fn strict_decode(reader: &mut impl TypedRead) -> Result<Self, DecodeError> {
let len = unsafe { reader._read_raw_len::<MAX_LEN>()? };
let len = unsafe { reader.raw_reader().read_raw_len::<MAX_LEN>()? };

Check warning on line 575 in rust/src/embedded.rs

View check run for this annotation

Codecov / codecov/patch

rust/src/embedded.rs#L575

Added line #L575 was not covered by tests
let mut col = BTreeMap::new();
for _ in 0..len {
let key = StrictDecode::strict_decode(reader)?;
Expand Down
4 changes: 2 additions & 2 deletions rust/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,11 @@ pub mod stl;
#[cfg(test)]
pub(crate) mod test;

pub use embedded::Byte;
pub use embedded::{Byte, DecodeRawLe};
pub use error::{DecodeError, DeserializeError, SerializeError};
pub use ident::{FieldName, Ident, InvalidIdent, LibName, TypeName, VariantName};
pub use primitives::{constants, NumCls, NumInfo, NumSize, Primitive};
pub use reader::StrictReader;
pub use reader::{ConfinedReader, StrictReader};
pub use stl::{Bool, RestrictedCharacter, RestrictedString, U4, U5};
pub use traits::*;
pub use types::*;
Expand Down
88 changes: 51 additions & 37 deletions rust/src/reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
use std::io;

use crate::{
DecodeError, FieldName, ReadStruct, ReadTuple, ReadUnion, StrictDecode, StrictEnum,
DecodeError, FieldName, ReadRaw, ReadStruct, ReadTuple, ReadUnion, StrictDecode, StrictEnum,
StrictStruct, StrictSum, StrictTuple, StrictUnion, TypedRead, VariantName,
};

Expand All @@ -46,13 +46,13 @@ impl io::Read for ReadCounter {

// TODO: Move to amplify crate
#[derive(Clone, Debug)]
pub struct CountingReader<R: io::Read> {
pub struct ConfinedReader<R: io::Read> {
count: usize,
limit: usize,
reader: R,
}

impl<R: io::Read> From<R> for CountingReader<R> {
impl<R: io::Read> From<R> for ConfinedReader<R> {
fn from(reader: R) -> Self {
Self {
count: 0,
Expand All @@ -62,7 +62,7 @@ impl<R: io::Read> From<R> for CountingReader<R> {
}
}

impl<R: io::Read> CountingReader<R> {
impl<R: io::Read> ConfinedReader<R> {
pub fn with(limit: usize, reader: R) -> Self {
Self {
count: 0,
Expand All @@ -73,10 +73,10 @@ impl<R: io::Read> CountingReader<R> {

pub fn count(&self) -> usize { self.count }

pub fn unbox(self) -> R { self.reader }
pub fn unconfine(self) -> R { self.reader }
}

impl<R: io::Read> io::Read for CountingReader<R> {
impl<R: io::Read> io::Read for ConfinedReader<R> {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
let len = self.reader.read(buf)?;
match self.count.checked_add(len) {
Expand All @@ -88,31 +88,59 @@ impl<R: io::Read> io::Read for CountingReader<R> {
}
}

#[derive(Clone, Debug, From)]
pub struct StrictReader<R: io::Read>(CountingReader<R>);
#[derive(Clone, Debug)]

Check warning on line 91 in rust/src/reader.rs

View check run for this annotation

Codecov / codecov/patch

rust/src/reader.rs#L91

Added line #L91 was not covered by tests
pub struct StreamReader<R: io::Read>(ConfinedReader<R>);

impl StrictReader<io::Cursor<Vec<u8>>> {
pub fn in_memory(data: Vec<u8>, limit: usize) -> Self {
StrictReader(CountingReader::with(limit, io::Cursor::new(data)))
impl<R: io::Read> StreamReader<R> {
pub fn new<const MAX: usize>(inner: R) -> Self { Self(ConfinedReader::with(MAX, inner)) }
pub fn unconfine(self) -> R { self.0.unconfine() }
}

impl<R: io::Read> ReadRaw for StreamReader<R> {
fn read_raw<const MAX_LEN: usize>(&mut self, len: usize) -> io::Result<Vec<u8>> {
use io::Read;
let mut buf = vec![0u8; len];
self.0.read_exact(&mut buf)?;
Ok(buf)
}

Check warning on line 105 in rust/src/reader.rs

View check run for this annotation

Codecov / codecov/patch

rust/src/reader.rs#L100-L105

Added lines #L100 - L105 were not covered by tests

fn read_raw_array<const LEN: usize>(&mut self) -> io::Result<[u8; LEN]> {
use io::Read;
let mut buf = [0u8; LEN];
self.0.read_exact(&mut buf)?;
Ok(buf)
}
}

impl StrictReader<ReadCounter> {
pub fn counter() -> Self { StrictReader(CountingReader::from(ReadCounter::default())) }
#[derive(Clone, Debug, From)]

Check warning on line 115 in rust/src/reader.rs

View check run for this annotation

Codecov / codecov/patch

rust/src/reader.rs#L115

Added line #L115 was not covered by tests
pub struct StrictReader<R: ReadRaw>(R);

impl<T: AsRef<[u8]>> StrictReader<StreamReader<io::Cursor<T>>> {
pub fn in_memory<const MAX: usize>(data: T) -> Self {
StrictReader(StreamReader::new::<MAX>(io::Cursor::new(data)))
}
pub fn into_cursor(self) -> io::Cursor<T> { self.0.unconfine() }
}

impl<R: io::Read> StrictReader<R> {
pub fn with(limit: usize, reader: R) -> Self {
StrictReader(CountingReader::with(limit, reader))
impl StrictReader<StreamReader<ReadCounter>> {
pub fn counter<const MAX: usize>() -> Self {
StrictReader(StreamReader::new::<MAX>(ReadCounter::default()))

Check warning on line 127 in rust/src/reader.rs

View check run for this annotation

Codecov / codecov/patch

rust/src/reader.rs#L126-L127

Added lines #L126 - L127 were not covered by tests
}
}

pub fn unbox(self) -> R { self.0.unbox() }
impl<R: ReadRaw> StrictReader<R> {
pub fn with(reader: R) -> Self { StrictReader(reader) }

Check warning on line 132 in rust/src/reader.rs

View check run for this annotation

Codecov / codecov/patch

rust/src/reader.rs#L132

Added line #L132 was not covered by tests

pub fn unbox(self) -> R { self.0 }

Check warning on line 134 in rust/src/reader.rs

View check run for this annotation

Codecov / codecov/patch

rust/src/reader.rs#L134

Added line #L134 was not covered by tests
}

impl<R: io::Read> TypedRead for StrictReader<R> {
impl<R: ReadRaw> TypedRead for StrictReader<R> {
type TupleReader<'parent> = TupleReader<'parent, R> where Self: 'parent;
type StructReader<'parent> = StructReader<'parent, R> where Self: 'parent;
type UnionReader = Self;
type RawReader = R;

unsafe fn raw_reader(&mut self) -> &mut Self::RawReader { &mut self.0 }

fn read_union<T: StrictUnion>(
&mut self,
Expand Down Expand Up @@ -183,49 +211,35 @@ impl<R: io::Read> TypedRead for StrictReader<R> {
assert!(reader.named_fields.is_empty(), "excessive fields are read for {}", name);
Ok(res)
}

unsafe fn _read_raw<const MAX_LEN: usize>(&mut self, len: usize) -> io::Result<Vec<u8>> {
use io::Read;
let mut buf = vec![0u8; len];
self.0.read_exact(&mut buf)?;
Ok(buf)
}

unsafe fn _read_raw_array<const LEN: usize>(&mut self) -> io::Result<[u8; LEN]> {
use io::Read;
let mut buf = [0u8; LEN];
self.0.read_exact(&mut buf)?;
Ok(buf)
}
}

#[derive(Debug)]
pub struct TupleReader<'parent, R: io::Read> {
pub struct TupleReader<'parent, R: ReadRaw> {
read_fields: u8,
parent: &'parent mut StrictReader<R>,
}

impl<'parent, R: io::Read> ReadTuple for TupleReader<'parent, R> {
impl<'parent, R: ReadRaw> ReadTuple for TupleReader<'parent, R> {
fn read_field<T: StrictDecode>(&mut self) -> Result<T, DecodeError> {
self.read_fields += 1;
T::strict_decode(self.parent)
}
}

#[derive(Debug)]
pub struct StructReader<'parent, R: io::Read> {
pub struct StructReader<'parent, R: ReadRaw> {
named_fields: Vec<FieldName>,
parent: &'parent mut StrictReader<R>,
}

impl<'parent, R: io::Read> ReadStruct for StructReader<'parent, R> {
impl<'parent, R: ReadRaw> ReadStruct for StructReader<'parent, R> {
fn read_field<T: StrictDecode>(&mut self, field: FieldName) -> Result<T, DecodeError> {
self.named_fields.push(field);
T::strict_decode(self.parent)
}
}

impl<R: io::Read> ReadUnion for StrictReader<R> {
impl<R: ReadRaw> ReadUnion for StrictReader<R> {
type TupleReader<'parent> = TupleReader<'parent, R> where Self: 'parent;
type StructReader<'parent> = StructReader<'parent, R> where Self: 'parent;

Expand Down
6 changes: 2 additions & 4 deletions rust/src/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
// limitations under the License.

use std::fmt::Debug;
use std::io;
use std::io::BufRead;

use amplify::confinement::Confined;
Expand All @@ -40,10 +39,9 @@ pub fn encode<T: StrictEncode + Debug + Eq>(val: &T) -> Vec<u8> {
pub fn decode<T: StrictDecode + Debug + Eq>(data: impl AsRef<[u8]>) -> T {
const MAX: usize = u16::MAX as usize;

let cursor = io::Cursor::new(data);
let mut reader = StrictReader::with(MAX, cursor);
let mut reader = StrictReader::in_memory::<MAX>(data);

Check warning on line 42 in rust/src/test.rs

View check run for this annotation

Codecov / codecov/patch

rust/src/test.rs#L42

Added line #L42 was not covered by tests
let val2 = T::strict_decode(&mut reader).unwrap();
let mut cursor = reader.unbox();
let mut cursor = reader.into_cursor();

Check warning on line 44 in rust/src/test.rs

View check run for this annotation

Codecov / codecov/patch

rust/src/test.rs#L44

Added line #L44 was not covered by tests
assert!(!cursor.fill_buf().unwrap().is_empty(), "data not entirely consumed");

val2
Expand Down
Loading

0 comments on commit df81691

Please sign in to comment.