Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor I/O traits #31

Merged
merged 6 commits into from
Feb 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 46 additions & 22 deletions rust/src/embedded.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,16 @@
use crate::constants::*;
use crate::stl::AsciiSym;
use crate::{
DecodeError, DefineUnion, ReadTuple, ReadUnion, RestrictedCharacter, RestrictedString, Sizing,
StrictDecode, StrictDumb, StrictEncode, StrictProduct, StrictStruct, StrictSum, StrictTuple,
StrictType, StrictUnion, TypeName, TypedRead, TypedWrite, WriteTuple, WriteUnion, LIB_EMBEDDED,
DecodeError, DefineUnion, ReadRaw, ReadTuple, ReadUnion, RestrictedCharacter, RestrictedString,
Sizing, StrictDecode, StrictDumb, StrictEncode, StrictProduct, StrictStruct, StrictSum,
StrictTuple, StrictType, StrictUnion, TypeName, TypedRead, TypedWrite, WriteRaw, WriteTuple,
WriteUnion, LIB_EMBEDDED,
};

pub trait DecodeRawLe: Sized {
fn decode_raw_le(reader: &mut (impl ReadRaw + ?Sized)) -> Result<Self, DecodeError>;
}

#[derive(
Wrapper, WrapperMut, Copy, Clone, Ord, PartialOrd, Eq, PartialEq, Hash, Debug, Default, From
)]
Expand All @@ -49,8 +54,12 @@
pub struct Byte(u8);

impl StrictEncode for Byte {
fn strict_encode<W: TypedWrite>(&self, writer: W) -> io::Result<W> {
unsafe { writer.register_primitive(BYTE)._write_raw::<1>([self.0]) }
fn strict_encode<W: TypedWrite>(&self, mut writer: W) -> io::Result<W> {
unsafe {
writer = writer.register_primitive(BYTE);
writer.raw_writer().write_raw::<1>([self.0])?;

Check warning on line 60 in rust/src/embedded.rs

View check run for this annotation

Codecov / codecov/patch

rust/src/embedded.rs#L57-L60

Added lines #L57 - L60 were not covered by tests
}
Ok(writer)

Check warning on line 62 in rust/src/embedded.rs

View check run for this annotation

Codecov / codecov/patch

rust/src/embedded.rs#L62

Added line #L62 was not covered by tests
}
}

Expand All @@ -60,18 +69,23 @@
const STRICT_LIB_NAME: &'static str = $crate::LIB_EMBEDDED;
}
impl $crate::StrictEncode for $ty {
fn strict_encode<W: TypedWrite>(&self, writer: W) -> io::Result<W> {
fn strict_encode<W: TypedWrite>(&self, mut writer: W) -> io::Result<W> {
unsafe {
writer
.register_primitive($id)
._write_raw_array(self.to_le_bytes())
writer = writer.register_primitive($id);
writer.raw_writer().write_raw_array(self.to_le_bytes())?;
}
Ok(writer)
}
}
impl $crate::DecodeRawLe for $ty {
fn decode_raw_le(reader: &mut (impl ReadRaw + ?Sized)) -> Result<Self, DecodeError> {
let buf = reader.read_raw_array::<{ Self::BITS as usize / 8 }>()?;
Ok(Self::from_le_bytes(buf))
}
}
impl $crate::StrictDecode for $ty {
fn strict_decode(reader: &mut impl TypedRead) -> Result<Self, DecodeError> {
let buf = unsafe { reader._read_raw_array::<{ Self::BITS as usize / 8 }>()? };
Ok(Self::from_le_bytes(buf))
Self::decode_raw_le(unsafe { reader.raw_reader() })
}
}
};
Expand All @@ -83,17 +97,23 @@
const STRICT_LIB_NAME: &'static str = $crate::LIB_EMBEDDED;
}
impl $crate::StrictEncode for $ty {
fn strict_encode<W: TypedWrite>(&self, writer: W) -> io::Result<W> {
fn strict_encode<W: TypedWrite>(&self, mut writer: W) -> io::Result<W> {

Check warning on line 100 in rust/src/embedded.rs

View check run for this annotation

Codecov / codecov/patch

rust/src/embedded.rs#L100

Added line #L100 was not covered by tests
unsafe {
writer = writer.register_primitive($id);

Check warning on line 102 in rust/src/embedded.rs

View check run for this annotation

Codecov / codecov/patch

rust/src/embedded.rs#L102

Added line #L102 was not covered by tests
writer
.register_primitive($id)
._write_raw_array(self.get().to_le_bytes())
.raw_writer()
.write_raw_array(self.get().to_le_bytes())?;

Check warning on line 105 in rust/src/embedded.rs

View check run for this annotation

Codecov / codecov/patch

rust/src/embedded.rs#L104-L105

Added lines #L104 - L105 were not covered by tests
}
Ok(writer)

Check warning on line 107 in rust/src/embedded.rs

View check run for this annotation

Codecov / codecov/patch

rust/src/embedded.rs#L107

Added line #L107 was not covered by tests
}
}
impl $crate::StrictDecode for $ty {
fn strict_decode(reader: &mut impl TypedRead) -> Result<Self, DecodeError> {
let buf = unsafe { reader._read_raw_array::<{ Self::BITS as usize / 8 }>()? };
let buf = unsafe {
reader
.raw_reader()
.read_raw_array::<{ Self::BITS as usize / 8 }>()?

Check warning on line 115 in rust/src/embedded.rs

View check run for this annotation

Codecov / codecov/patch

rust/src/embedded.rs#L112-L115

Added lines #L112 - L115 were not covered by tests
};
let v = <$p>::from_le_bytes(buf);
Self::new(v).ok_or(DecodeError::ZeroNatural)
}
Expand All @@ -109,18 +129,22 @@
}
#[cfg(feature = "float")]
impl $crate::StrictEncode for $ty {
fn strict_encode<W: TypedWrite>(&self, writer: W) -> io::Result<W> {
fn strict_encode<W: TypedWrite>(&self, mut writer: W) -> io::Result<W> {

Check warning on line 132 in rust/src/embedded.rs

View check run for this annotation

Codecov / codecov/patch

rust/src/embedded.rs#L132

Added line #L132 was not covered by tests
let mut be = [0u8; $len];
be.copy_from_slice(&self.to_bits().to_le_bytes()[..$len]);
unsafe { writer.register_primitive($id)._write_raw_array(be) }
unsafe {
writer = writer.register_primitive($id);
writer.raw_writer().write_raw_array(be)?;

Check warning on line 137 in rust/src/embedded.rs

View check run for this annotation

Codecov / codecov/patch

rust/src/embedded.rs#L135-L137

Added lines #L135 - L137 were not covered by tests
}
Ok(writer)

Check warning on line 139 in rust/src/embedded.rs

View check run for this annotation

Codecov / codecov/patch

rust/src/embedded.rs#L139

Added line #L139 was not covered by tests
}
}
#[cfg(feature = "float")]
impl $crate::StrictDecode for $ty {
fn strict_decode(reader: &mut impl TypedRead) -> Result<Self, DecodeError> {
const BYTES: usize = <$ty>::BITS / 8;
let mut inner = [0u8; 32];
let buf = unsafe { reader._read_raw_array::<BYTES>()? };
let buf = unsafe { reader.raw_reader().read_raw_array::<BYTES>()? };

Check warning on line 147 in rust/src/embedded.rs

View check run for this annotation

Codecov / codecov/patch

rust/src/embedded.rs#L147

Added line #L147 was not covered by tests
inner[..BYTES].copy_from_slice(&buf[..]);
let bits = u256::from_le_bytes(inner);
Ok(Self::from_bits(bits))
Expand Down Expand Up @@ -474,7 +498,7 @@
for Confined<Vec<T>, MIN_LEN, MAX_LEN>
{
fn strict_decode(reader: &mut impl TypedRead) -> Result<Self, DecodeError> {
let len = unsafe { reader._read_raw_len::<MAX_LEN>()? };
let len = unsafe { reader.raw_reader().read_raw_len::<MAX_LEN>()? };

Check warning on line 501 in rust/src/embedded.rs

View check run for this annotation

Codecov / codecov/patch

rust/src/embedded.rs#L501

Added line #L501 was not covered by tests
let mut col = Vec::<T>::with_capacity(len);
for _ in 0..len {
col.push(StrictDecode::strict_decode(reader)?);
Expand Down Expand Up @@ -505,7 +529,7 @@
for Confined<BTreeSet<T>, MIN_LEN, MAX_LEN>
{
fn strict_decode(reader: &mut impl TypedRead) -> Result<Self, DecodeError> {
let len = unsafe { reader._read_raw_len::<MAX_LEN>()? };
let len = unsafe { reader.raw_reader().read_raw_len::<MAX_LEN>()? };

Check warning on line 532 in rust/src/embedded.rs

View check run for this annotation

Codecov / codecov/patch

rust/src/embedded.rs#L532

Added line #L532 was not covered by tests
let mut col = BTreeSet::<T>::new();
for _ in 0..len {
let item = StrictDecode::strict_decode(reader)?;
Expand Down Expand Up @@ -535,7 +559,7 @@
{
fn strict_encode<W: TypedWrite>(&self, mut writer: W) -> io::Result<W> {
unsafe {
writer = writer._write_raw_len::<MAX_LEN>(self.len())?;
writer.raw_writer().write_raw_len::<MAX_LEN>(self.len())?;

Check warning on line 562 in rust/src/embedded.rs

View check run for this annotation

Codecov / codecov/patch

rust/src/embedded.rs#L562

Added line #L562 was not covered by tests
}
for (k, v) in self {
writer = k.strict_encode(writer)?;
Expand All @@ -558,7 +582,7 @@
> StrictDecode for Confined<BTreeMap<K, V>, MIN_LEN, MAX_LEN>
{
fn strict_decode(reader: &mut impl TypedRead) -> Result<Self, DecodeError> {
let len = unsafe { reader._read_raw_len::<MAX_LEN>()? };
let len = unsafe { reader.raw_reader().read_raw_len::<MAX_LEN>()? };

Check warning on line 585 in rust/src/embedded.rs

View check run for this annotation

Codecov / codecov/patch

rust/src/embedded.rs#L585

Added line #L585 was not covered by tests
let mut col = BTreeMap::new();
for _ in 0..len {
let key = StrictDecode::strict_decode(reader)?;
Expand Down
8 changes: 5 additions & 3 deletions rust/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,16 +60,18 @@ pub mod stl;
#[cfg(test)]
pub(crate) mod test;

pub use embedded::Byte;
pub use embedded::{Byte, DecodeRawLe};
pub use error::{DecodeError, DeserializeError, SerializeError};
pub use ident::{FieldName, Ident, InvalidIdent, LibName, TypeName, VariantName};
pub use primitives::{constants, NumCls, NumInfo, NumSize, Primitive};
pub use reader::StrictReader;
pub use reader::{ConfinedReader, StreamReader, StrictReader};
pub use stl::{Bool, RestrictedCharacter, RestrictedString, U4, U5};
pub use traits::*;
pub use types::*;
pub use util::{Sizing, Variant};
pub use writer::{SplitParent, StrictParent, StrictWriter, StructWriter, UnionWriter};
pub use writer::{
SplitParent, StreamWriter, StrictParent, StrictWriter, StructWriter, UnionWriter,
};

#[deprecated(since = "2.2.0", note = "use LIB_EMBEDDED")]
pub const NO_LIB: &str = LIB_EMBEDDED;
Expand Down
101 changes: 64 additions & 37 deletions rust/src/reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
use std::io;

use crate::{
DecodeError, FieldName, ReadStruct, ReadTuple, ReadUnion, StrictDecode, StrictEnum,
DecodeError, FieldName, ReadRaw, ReadStruct, ReadTuple, ReadUnion, StrictDecode, StrictEnum,
StrictStruct, StrictSum, StrictTuple, StrictUnion, TypedRead, VariantName,
};

Expand All @@ -46,13 +46,13 @@

// TODO: Move to amplify crate
#[derive(Clone, Debug)]
pub struct CountingReader<R: io::Read> {
pub struct ConfinedReader<R: io::Read> {
count: usize,
limit: usize,
reader: R,
}

impl<R: io::Read> From<R> for CountingReader<R> {
impl<R: io::Read> From<R> for ConfinedReader<R> {
fn from(reader: R) -> Self {
Self {
count: 0,
Expand All @@ -62,7 +62,7 @@
}
}

impl<R: io::Read> CountingReader<R> {
impl<R: io::Read> ConfinedReader<R> {
pub fn with(limit: usize, reader: R) -> Self {
Self {
count: 0,
Expand All @@ -73,10 +73,10 @@

pub fn count(&self) -> usize { self.count }

pub fn unbox(self) -> R { self.reader }
pub fn unconfine(self) -> R { self.reader }
}

impl<R: io::Read> io::Read for CountingReader<R> {
impl<R: io::Read> io::Read for ConfinedReader<R> {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
let len = self.reader.read(buf)?;
match self.count.checked_add(len) {
Expand All @@ -88,31 +88,72 @@
}
}

#[derive(Clone, Debug, From)]
pub struct StrictReader<R: io::Read>(CountingReader<R>);
#[derive(Clone, Debug)]

Check warning on line 91 in rust/src/reader.rs

View check run for this annotation

Codecov / codecov/patch

rust/src/reader.rs#L91

Added line #L91 was not covered by tests
pub struct StreamReader<R: io::Read>(ConfinedReader<R>);

impl<R: io::Read> StreamReader<R> {
pub fn new<const MAX: usize>(inner: R) -> Self { Self(ConfinedReader::with(MAX, inner)) }
pub fn unconfine(self) -> R { self.0.unconfine() }

Check warning on line 96 in rust/src/reader.rs

View check run for this annotation

Codecov / codecov/patch

rust/src/reader.rs#L96

Added line #L96 was not covered by tests
}

impl<T: AsRef<[u8]>> StreamReader<io::Cursor<T>> {
pub fn cursor<const MAX: usize>(inner: T) -> Self {
Self(ConfinedReader::with(MAX, io::Cursor::new(inner)))
}

Check warning on line 102 in rust/src/reader.rs

View check run for this annotation

Codecov / codecov/patch

rust/src/reader.rs#L100-L102

Added lines #L100 - L102 were not covered by tests
}

impl<R: io::Read> ReadRaw for StreamReader<R> {
fn read_raw<const MAX_LEN: usize>(&mut self, len: usize) -> io::Result<Vec<u8>> {
use io::Read;
let mut buf = vec![0u8; len];
self.0.read_exact(&mut buf)?;
Ok(buf)
}

Check warning on line 111 in rust/src/reader.rs

View check run for this annotation

Codecov / codecov/patch

rust/src/reader.rs#L106-L111

Added lines #L106 - L111 were not covered by tests

impl StrictReader<io::Cursor<Vec<u8>>> {
pub fn in_memory(data: Vec<u8>, limit: usize) -> Self {
StrictReader(CountingReader::with(limit, io::Cursor::new(data)))
fn read_raw_array<const LEN: usize>(&mut self) -> io::Result<[u8; LEN]> {
use io::Read;
let mut buf = [0u8; LEN];
self.0.read_exact(&mut buf)?;
Ok(buf)
}
}

impl StrictReader<ReadCounter> {
pub fn counter() -> Self { StrictReader(CountingReader::from(ReadCounter::default())) }
impl<T: AsRef<[u8]>> StreamReader<io::Cursor<T>> {
pub fn in_memory<const MAX: usize>(data: T) -> Self { Self::new::<MAX>(io::Cursor::new(data)) }
pub fn into_cursor(self) -> io::Cursor<T> { self.0.unconfine() }
}

impl<R: io::Read> StrictReader<R> {
pub fn with(limit: usize, reader: R) -> Self {
StrictReader(CountingReader::with(limit, reader))
impl StreamReader<ReadCounter> {
pub fn counter<const MAX: usize>() -> Self { Self::new::<MAX>(ReadCounter::default()) }

Check warning on line 127 in rust/src/reader.rs

View check run for this annotation

Codecov / codecov/patch

rust/src/reader.rs#L127

Added line #L127 was not covered by tests
}

#[derive(Clone, Debug, From)]

Check warning on line 130 in rust/src/reader.rs

View check run for this annotation

Codecov / codecov/patch

rust/src/reader.rs#L130

Added line #L130 was not covered by tests
pub struct StrictReader<R: ReadRaw>(R);

impl<T: AsRef<[u8]>> StrictReader<StreamReader<io::Cursor<T>>> {
pub fn in_memory<const MAX: usize>(data: T) -> Self {
Self(StreamReader::in_memory::<MAX>(data))
}
pub fn into_cursor(self) -> io::Cursor<T> { self.0.into_cursor() }
}

impl StrictReader<StreamReader<ReadCounter>> {
pub fn counter<const MAX: usize>() -> Self { Self(StreamReader::counter::<MAX>()) }

Check warning on line 141 in rust/src/reader.rs

View check run for this annotation

Codecov / codecov/patch

rust/src/reader.rs#L141

Added line #L141 was not covered by tests
}

pub fn unbox(self) -> R { self.0.unbox() }
impl<R: ReadRaw> StrictReader<R> {
pub fn with(reader: R) -> Self { Self(reader) }

Check warning on line 145 in rust/src/reader.rs

View check run for this annotation

Codecov / codecov/patch

rust/src/reader.rs#L145

Added line #L145 was not covered by tests

pub fn unbox(self) -> R { self.0 }

Check warning on line 147 in rust/src/reader.rs

View check run for this annotation

Codecov / codecov/patch

rust/src/reader.rs#L147

Added line #L147 was not covered by tests
}

impl<R: io::Read> TypedRead for StrictReader<R> {
impl<R: ReadRaw> TypedRead for StrictReader<R> {
type TupleReader<'parent> = TupleReader<'parent, R> where Self: 'parent;
type StructReader<'parent> = StructReader<'parent, R> where Self: 'parent;
type UnionReader = Self;
type RawReader = R;

unsafe fn raw_reader(&mut self) -> &mut Self::RawReader { &mut self.0 }

fn read_union<T: StrictUnion>(
&mut self,
Expand Down Expand Up @@ -183,49 +224,35 @@
assert!(reader.named_fields.is_empty(), "excessive fields are read for {}", name);
Ok(res)
}

unsafe fn _read_raw<const MAX_LEN: usize>(&mut self, len: usize) -> io::Result<Vec<u8>> {
use io::Read;
let mut buf = vec![0u8; len];
self.0.read_exact(&mut buf)?;
Ok(buf)
}

unsafe fn _read_raw_array<const LEN: usize>(&mut self) -> io::Result<[u8; LEN]> {
use io::Read;
let mut buf = [0u8; LEN];
self.0.read_exact(&mut buf)?;
Ok(buf)
}
}

#[derive(Debug)]
pub struct TupleReader<'parent, R: io::Read> {
pub struct TupleReader<'parent, R: ReadRaw> {
read_fields: u8,
parent: &'parent mut StrictReader<R>,
}

impl<'parent, R: io::Read> ReadTuple for TupleReader<'parent, R> {
impl<'parent, R: ReadRaw> ReadTuple for TupleReader<'parent, R> {
fn read_field<T: StrictDecode>(&mut self) -> Result<T, DecodeError> {
self.read_fields += 1;
T::strict_decode(self.parent)
}
}

#[derive(Debug)]
pub struct StructReader<'parent, R: io::Read> {
pub struct StructReader<'parent, R: ReadRaw> {
named_fields: Vec<FieldName>,
parent: &'parent mut StrictReader<R>,
}

impl<'parent, R: io::Read> ReadStruct for StructReader<'parent, R> {
impl<'parent, R: ReadRaw> ReadStruct for StructReader<'parent, R> {
fn read_field<T: StrictDecode>(&mut self, field: FieldName) -> Result<T, DecodeError> {
self.named_fields.push(field);
T::strict_decode(self.parent)
}
}

impl<R: io::Read> ReadUnion for StrictReader<R> {
impl<R: ReadRaw> ReadUnion for StrictReader<R> {
type TupleReader<'parent> = TupleReader<'parent, R> where Self: 'parent;
type StructReader<'parent> = StructReader<'parent, R> where Self: 'parent;

Expand Down
10 changes: 4 additions & 6 deletions rust/src/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
// limitations under the License.

use std::fmt::Debug;
use std::io;
use std::io::BufRead;

use amplify::confinement::Confined;
Expand All @@ -30,8 +29,8 @@
pub fn encode<T: StrictEncode + Debug + Eq>(val: &T) -> Vec<u8> {
const MAX: usize = u16::MAX as usize;

let ast_data = StrictWriter::in_memory(MAX);
let data = val.strict_encode(ast_data).unwrap().unbox();
let ast_data = StrictWriter::in_memory::<MAX>();
let data = val.strict_encode(ast_data).unwrap().unbox().unconfine();

Check warning on line 33 in rust/src/test.rs

View check run for this annotation

Codecov / codecov/patch

rust/src/test.rs#L32-L33

Added lines #L32 - L33 were not covered by tests
Confined::<Vec<u8>, 0, MAX>::try_from(data)
.unwrap()
.into_inner()
Expand All @@ -40,16 +39,15 @@
pub fn decode<T: StrictDecode + Debug + Eq>(data: impl AsRef<[u8]>) -> T {
const MAX: usize = u16::MAX as usize;

let cursor = io::Cursor::new(data);
let mut reader = StrictReader::with(MAX, cursor);
let mut reader = StrictReader::in_memory::<MAX>(data);

Check warning on line 42 in rust/src/test.rs

View check run for this annotation

Codecov / codecov/patch

rust/src/test.rs#L42

Added line #L42 was not covered by tests
let val2 = T::strict_decode(&mut reader).unwrap();
let mut cursor = reader.unbox();
let mut cursor = reader.into_cursor();
assert!(!cursor.fill_buf().unwrap().is_empty(), "data not entirely consumed");

val2
}

#[allow(dead_code)]

Check warning on line 50 in rust/src/test.rs

View check run for this annotation

Codecov / codecov/patch

rust/src/test.rs#L44-L50

Added lines #L44 - L50 were not covered by tests
pub fn encoding_roundtrip<T: StrictEncode + StrictDecode + Debug + Eq>(val: &T) {
let data = encode(val);
let val2: T = decode(data);
Expand Down
Loading
Loading