Skip to content

Commit

Permalink
abstract raw byte writers from strict writers and io::Write
Browse files Browse the repository at this point in the history
  • Loading branch information
dr-orlovsky committed Feb 4, 2024
1 parent f95dbcb commit 9b5a9f2
Show file tree
Hide file tree
Showing 7 changed files with 141 additions and 107 deletions.
38 changes: 24 additions & 14 deletions rust/src/embedded.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ use crate::stl::AsciiSym;
use crate::{
DecodeError, DefineUnion, ReadRaw, ReadTuple, ReadUnion, RestrictedCharacter, RestrictedString,
Sizing, StrictDecode, StrictDumb, StrictEncode, StrictProduct, StrictStruct, StrictSum,
StrictTuple, StrictType, StrictUnion, TypeName, TypedRead, TypedWrite, WriteTuple, WriteUnion,
LIB_EMBEDDED,
StrictTuple, StrictType, StrictUnion, TypeName, TypedRead, TypedWrite, WriteRaw, WriteTuple,
WriteUnion, LIB_EMBEDDED,
};

pub trait DecodeRawLe: Sized {
Expand All @@ -54,8 +54,12 @@ pub trait DecodeRawLe: Sized {
pub struct Byte(u8);

impl StrictEncode for Byte {
fn strict_encode<W: TypedWrite>(&self, writer: W) -> io::Result<W> {
unsafe { writer.register_primitive(BYTE)._write_raw::<1>([self.0]) }
fn strict_encode<W: TypedWrite>(&self, mut writer: W) -> io::Result<W> {
unsafe {
writer = writer.register_primitive(BYTE);
writer.raw_writer().write_raw::<1>([self.0])?;

Check warning on line 60 in rust/src/embedded.rs

View check run for this annotation

Codecov / codecov/patch

rust/src/embedded.rs#L57-L60

Added lines #L57 - L60 were not covered by tests
}
Ok(writer)

Check warning on line 62 in rust/src/embedded.rs

View check run for this annotation

Codecov / codecov/patch

rust/src/embedded.rs#L62

Added line #L62 was not covered by tests
}
}

Expand All @@ -65,12 +69,12 @@ macro_rules! encode_num {
const STRICT_LIB_NAME: &'static str = $crate::LIB_EMBEDDED;
}
impl $crate::StrictEncode for $ty {
fn strict_encode<W: TypedWrite>(&self, writer: W) -> io::Result<W> {
fn strict_encode<W: TypedWrite>(&self, mut writer: W) -> io::Result<W> {
unsafe {
writer
.register_primitive($id)
._write_raw_array(self.to_le_bytes())
writer = writer.register_primitive($id);
writer.raw_writer().write_raw_array(self.to_le_bytes())?;
}
Ok(writer)
}
}
impl $crate::DecodeRawLe for $ty {
Expand All @@ -93,12 +97,14 @@ macro_rules! encode_nonzero {
const STRICT_LIB_NAME: &'static str = $crate::LIB_EMBEDDED;
}
impl $crate::StrictEncode for $ty {
fn strict_encode<W: TypedWrite>(&self, writer: W) -> io::Result<W> {
fn strict_encode<W: TypedWrite>(&self, mut writer: W) -> io::Result<W> {

Check warning on line 100 in rust/src/embedded.rs

View check run for this annotation

Codecov / codecov/patch

rust/src/embedded.rs#L100

Added line #L100 was not covered by tests
unsafe {
writer = writer.register_primitive($id);

Check warning on line 102 in rust/src/embedded.rs

View check run for this annotation

Codecov / codecov/patch

rust/src/embedded.rs#L102

Added line #L102 was not covered by tests
writer
.register_primitive($id)
._write_raw_array(self.get().to_le_bytes())
.raw_writer()
.write_raw_array(self.get().to_le_bytes())?;

Check warning on line 105 in rust/src/embedded.rs

View check run for this annotation

Codecov / codecov/patch

rust/src/embedded.rs#L104-L105

Added lines #L104 - L105 were not covered by tests
}
Ok(writer)

Check warning on line 107 in rust/src/embedded.rs

View check run for this annotation

Codecov / codecov/patch

rust/src/embedded.rs#L107

Added line #L107 was not covered by tests
}
}
impl $crate::StrictDecode for $ty {
Expand All @@ -123,10 +129,14 @@ macro_rules! encode_float {
}
#[cfg(feature = "float")]
impl $crate::StrictEncode for $ty {
fn strict_encode<W: TypedWrite>(&self, writer: W) -> io::Result<W> {
fn strict_encode<W: TypedWrite>(&self, mut writer: W) -> io::Result<W> {

Check warning on line 132 in rust/src/embedded.rs

View check run for this annotation

Codecov / codecov/patch

rust/src/embedded.rs#L132

Added line #L132 was not covered by tests
let mut be = [0u8; $len];
be.copy_from_slice(&self.to_bits().to_le_bytes()[..$len]);
unsafe { writer.register_primitive($id)._write_raw_array(be) }
unsafe {
writer = writer.register_primitive($id);
writer.raw_writer().write_raw_array(be)?;

Check warning on line 137 in rust/src/embedded.rs

View check run for this annotation

Codecov / codecov/patch

rust/src/embedded.rs#L135-L137

Added lines #L135 - L137 were not covered by tests
}
Ok(writer)

Check warning on line 139 in rust/src/embedded.rs

View check run for this annotation

Codecov / codecov/patch

rust/src/embedded.rs#L139

Added line #L139 was not covered by tests
}
}
#[cfg(feature = "float")]
Expand Down Expand Up @@ -549,7 +559,7 @@ impl<
{
fn strict_encode<W: TypedWrite>(&self, mut writer: W) -> io::Result<W> {
unsafe {
writer = writer._write_raw_len::<MAX_LEN>(self.len())?;
writer.raw_writer().write_raw_len::<MAX_LEN>(self.len())?;

Check warning on line 562 in rust/src/embedded.rs

View check run for this annotation

Codecov / codecov/patch

rust/src/embedded.rs#L562

Added line #L562 was not covered by tests
}
for (k, v) in self {
writer = k.strict_encode(writer)?;
Expand Down
4 changes: 3 additions & 1 deletion rust/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,9 @@ pub use stl::{Bool, RestrictedCharacter, RestrictedString, U4, U5};
pub use traits::*;
pub use types::*;
pub use util::{Sizing, Variant};
pub use writer::{SplitParent, StrictParent, StrictWriter, StructWriter, UnionWriter};
pub use writer::{
SplitParent, StreamWriter, StrictParent, StrictWriter, StructWriter, UnionWriter,
};

#[deprecated(since = "2.2.0", note = "use LIB_EMBEDDED")]
pub const NO_LIB: &str = LIB_EMBEDDED;
Expand Down
6 changes: 3 additions & 3 deletions rust/src/reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,19 +123,19 @@ pub struct StrictReader<R: ReadRaw>(R);

impl<T: AsRef<[u8]>> StrictReader<StreamReader<io::Cursor<T>>> {
pub fn in_memory<const MAX: usize>(data: T) -> Self {
StrictReader(StreamReader::new::<MAX>(io::Cursor::new(data)))
Self(StreamReader::new::<MAX>(io::Cursor::new(data)))
}
pub fn into_cursor(self) -> io::Cursor<T> { self.0.unconfine() }
}

impl StrictReader<StreamReader<ReadCounter>> {
pub fn counter<const MAX: usize>() -> Self {
StrictReader(StreamReader::new::<MAX>(ReadCounter::default()))
Self(StreamReader::new::<MAX>(ReadCounter::default()))
}

Check warning on line 134 in rust/src/reader.rs

View check run for this annotation

Codecov / codecov/patch

rust/src/reader.rs#L132-L134

Added lines #L132 - L134 were not covered by tests
}

impl<R: ReadRaw> StrictReader<R> {
pub fn with(reader: R) -> Self { StrictReader(reader) }
pub fn with(reader: R) -> Self { Self(reader) }

Check warning on line 138 in rust/src/reader.rs

View check run for this annotation

Codecov / codecov/patch

rust/src/reader.rs#L138

Added line #L138 was not covered by tests

pub fn unbox(self) -> R { self.0 }

Check warning on line 140 in rust/src/reader.rs

View check run for this annotation

Codecov / codecov/patch

rust/src/reader.rs#L140

Added line #L140 was not covered by tests
}
Expand Down
4 changes: 2 additions & 2 deletions rust/src/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ use crate::{StrictDecode, StrictEncode, StrictReader, StrictWriter};
pub fn encode<T: StrictEncode + Debug + Eq>(val: &T) -> Vec<u8> {
const MAX: usize = u16::MAX as usize;

let ast_data = StrictWriter::in_memory(MAX);
let data = val.strict_encode(ast_data).unwrap().unbox();
let ast_data = StrictWriter::in_memory::<MAX>();
let data = val.strict_encode(ast_data).unwrap().unbox().unconfine();

Check warning on line 33 in rust/src/test.rs

View check run for this annotation

Codecov / codecov/patch

rust/src/test.rs#L32-L33

Added lines #L32 - L33 were not covered by tests
Confined::<Vec<u8>, 0, MAX>::try_from(data)
.unwrap()
.into_inner()
Expand Down
81 changes: 46 additions & 35 deletions rust/src/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,18 +29,48 @@ use amplify::Wrapper;

use super::{DecodeError, DecodeRawLe, VariantName};
use crate::reader::StreamReader;
use crate::writer::StreamWriter;
use crate::{
DeserializeError, FieldName, Primitive, SerializeError, Sizing, StrictDumb, StrictEnum,
StrictReader, StrictStruct, StrictSum, StrictTuple, StrictType, StrictUnion, StrictWriter,
};

pub trait TypedParent: Sized {}

pub trait WriteRaw {
fn write_raw<const MAX_LEN: usize>(&mut self, bytes: impl AsRef<[u8]>) -> io::Result<()>;
fn write_raw_array<const LEN: usize>(&mut self, raw: [u8; LEN]) -> io::Result<()> {
self.write_raw::<LEN>(raw)
}
fn write_raw_len<const MAX_LEN: usize>(&mut self, len: usize) -> io::Result<()> {
match MAX_LEN {
tiny if tiny <= u8::MAX as usize => self.write_raw_array((len as u8).to_le_bytes()),
small if small <= u16::MAX as usize => self.write_raw_array((len as u16).to_le_bytes()),
medium if medium <= u24::MAX.into_usize() => {
self.write_raw_array((u24::with(len as u32)).to_le_bytes())
}
large if large <= u32::MAX as usize => self.write_raw_array((len as u32).to_le_bytes()),
huge if huge <= u64::MAX as usize => self.write_raw_array((len as u64).to_le_bytes()),
_ => unreachable!("confined collections larger than u64::MAX must not exist"),

Check warning on line 54 in rust/src/traits.rs

View check run for this annotation

Codecov / codecov/patch

rust/src/traits.rs#L52-L54

Added lines #L52 - L54 were not covered by tests
}
}

Check warning on line 56 in rust/src/traits.rs

View check run for this annotation

Codecov / codecov/patch

rust/src/traits.rs#L56

Added line #L56 was not covered by tests
}

impl<T: WriteRaw> WriteRaw for &mut T {
fn write_raw<const MAX_LEN: usize>(&mut self, bytes: impl AsRef<[u8]>) -> io::Result<()> {
(*self).write_raw::<MAX_LEN>(bytes)
}

Check warning on line 62 in rust/src/traits.rs

View check run for this annotation

Codecov / codecov/patch

rust/src/traits.rs#L60-L62

Added lines #L60 - L62 were not covered by tests
}

#[allow(unused_variables)]
pub trait TypedWrite: Sized {
type TupleWriter: WriteTuple<Parent = Self>;
type StructWriter: WriteStruct<Parent = Self>;
type UnionDefiner: DefineUnion<Parent = Self>;
type RawWriter: WriteRaw;

#[doc(hidden)]
unsafe fn raw_writer(&mut self) -> &mut Self::RawWriter;

fn write_union<T: StrictUnion>(
self,
Expand Down Expand Up @@ -85,38 +115,16 @@ pub trait TypedWrite: Sized {
self
}

#[doc(hidden)]
unsafe fn _write_raw<const MAX_LEN: usize>(self, bytes: impl AsRef<[u8]>) -> io::Result<Self>;
#[doc(hidden)]
unsafe fn _write_raw_array<const LEN: usize>(self, raw: [u8; LEN]) -> io::Result<Self> {
self._write_raw::<LEN>(raw)
}
#[doc(hidden)]
unsafe fn _write_raw_len<const MAX_LEN: usize>(self, len: usize) -> io::Result<Self> {
match MAX_LEN {
tiny if tiny <= u8::MAX as usize => self._write_raw_array((len as u8).to_le_bytes()),
small if small <= u16::MAX as usize => {
self._write_raw_array((len as u16).to_le_bytes())
}
medium if medium <= u24::MAX.into_usize() => {
self._write_raw_array((u24::with(len as u32)).to_le_bytes())
}
large if large <= u32::MAX as usize => {
self._write_raw_array((len as u32).to_le_bytes())
}
huge if huge <= u64::MAX as usize => self._write_raw_array((len as u64).to_le_bytes()),
_ => unreachable!("confined collections larger than u64::MAX must not exist"),
}
}

/// Used by unicode strings, ASCII strings and restricted char set strings.
#[doc(hidden)]
unsafe fn write_string<const MAX_LEN: usize>(
self,
mut self,

Check warning on line 121 in rust/src/traits.rs

View check run for this annotation

Codecov / codecov/patch

rust/src/traits.rs#L121

Added line #L121 was not covered by tests
bytes: impl AsRef<[u8]>,
) -> io::Result<Self> {
self._write_raw_len::<MAX_LEN>(bytes.as_ref().len())?
._write_raw::<MAX_LEN>(bytes)
self.raw_writer()
.write_raw_len::<MAX_LEN>(bytes.as_ref().len())?;
self.raw_writer().write_raw::<MAX_LEN>(bytes)?;
Ok(self)

Check warning on line 127 in rust/src/traits.rs

View check run for this annotation

Codecov / codecov/patch

rust/src/traits.rs#L124-L127

Added lines #L124 - L127 were not covered by tests
}

/// Vec and sets - excluding strings, written by [`Self::write_string`].
Expand All @@ -129,7 +137,7 @@ pub trait TypedWrite: Sized {
for<'a> &'a C: IntoIterator,
for<'a> <&'a C as IntoIterator>::Item: StrictEncode,
{
self = self._write_raw_len::<MAX_LEN>(col.len())?;
self.raw_writer().write_raw_len::<MAX_LEN>(col.len())?;

Check warning on line 140 in rust/src/traits.rs

View check run for this annotation

Codecov / codecov/patch

rust/src/traits.rs#L140

Added line #L140 was not covered by tests
for item in col {
self = item.strict_encode(self)?;
}
Expand Down Expand Up @@ -336,9 +344,10 @@ pub trait ReadUnion: Sized {

pub trait StrictEncode: StrictType {
fn strict_encode<W: TypedWrite>(&self, writer: W) -> io::Result<W>;
fn strict_write(&self, lim: usize, writer: impl io::Write) -> io::Result<usize> {
let counter = StrictWriter::with(lim, writer);
Ok(self.strict_encode(counter)?.count())
fn strict_write(&self, writer: impl WriteRaw) -> io::Result<()> {
let counter = StrictWriter::with(writer);
self.strict_encode(counter)?;
Ok(())

Check warning on line 350 in rust/src/traits.rs

View check run for this annotation

Codecov / codecov/patch

rust/src/traits.rs#L347-L350

Added lines #L347 - L350 were not covered by tests
}
}

Expand Down Expand Up @@ -367,22 +376,24 @@ impl<T> StrictDecode for PhantomData<T> {
pub trait StrictSerialize: StrictEncode {
fn strict_serialized_len(&self) -> io::Result<usize> {
let counter = StrictWriter::counter();
Ok(self.strict_encode(counter)?.unbox().count)
Ok(self.strict_encode(counter)?.unbox().unconfine().count)

Check warning on line 379 in rust/src/traits.rs

View check run for this annotation

Codecov / codecov/patch

rust/src/traits.rs#L379

Added line #L379 was not covered by tests
}

fn to_strict_serialized<const MAX: usize>(
&self,
) -> Result<Confined<Vec<u8>, 0, MAX>, SerializeError> {
let ast_data = StrictWriter::in_memory(MAX);
let data = self.strict_encode(ast_data)?.unbox();
let ast_data = StrictWriter::in_memory::<MAX>();
let data = self.strict_encode(ast_data)?.unbox().unconfine();
Confined::<Vec<u8>, 0, MAX>::try_from(data).map_err(SerializeError::from)
}

fn strict_serialize_to_file<const MAX: usize>(
&self,
path: impl AsRef<std::path::Path>,
) -> Result<(), SerializeError> {
let file = StrictWriter::with(MAX, fs::File::create(path)?);
let file = fs::File::create(path)?;

Check warning on line 394 in rust/src/traits.rs

View check run for this annotation

Codecov / codecov/patch

rust/src/traits.rs#L394

Added line #L394 was not covered by tests
// TODO: Do FileReader
let file = StrictWriter::with(StreamWriter::new::<MAX>(file));

Check warning on line 396 in rust/src/traits.rs

View check run for this annotation

Codecov / codecov/patch

rust/src/traits.rs#L396

Added line #L396 was not covered by tests
self.strict_encode(file)?;
Ok(())
}
Expand Down
Loading

0 comments on commit 9b5a9f2

Please sign in to comment.