Skip to content

Commit

Permalink
Support deserialization of C-style enums (#140)
Browse files Browse the repository at this point in the history
  • Loading branch information
knutwalker authored Nov 12, 2023
1 parent c0239ce commit fce4b46
Showing 1 changed file with 110 additions and 9 deletions.
119 changes: 110 additions & 9 deletions lib/src/types/serde/typ.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,11 @@ use std::{fmt, result::Result};
use bytes::Bytes;
use serde::{
de::{
value::{BorrowedStrDeserializer, MapDeserializer, SeqDeserializer},
DeserializeSeed, Deserializer, EnumAccess, Error, IntoDeserializer, Unexpected as Unexp,
VariantAccess, Visitor,
value::{
BorrowedBytesDeserializer, BorrowedStrDeserializer, MapDeserializer, SeqDeserializer,
},
DeserializeSeed, Deserializer, EnumAccess, Error, Expected, IntoDeserializer,
Unexpected as Unexp, VariantAccess, Visitor,
},
forward_to_deserialize_any, Deserialize,
};
Expand Down Expand Up @@ -636,11 +638,13 @@ impl<'de> Deserializer<'de> for BoltTypeDeserializer<'de> {
where
V: Visitor<'de>,
{
if name != std::any::type_name::<BoltType>() {
return Err(DeError::invalid_type(Unexp::Enum, &"BoltType"));
if name == std::any::type_name::<BoltType>() {
visitor.visit_enum(BoltEnum { value: self.value })
} else {
visitor.visit_enum(CStyleEnum {
variant: self.value,
})
}

visitor.visit_enum(BoltEnum { value: self.value })
}

fn deserialize_ignored_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
Expand Down Expand Up @@ -727,7 +731,16 @@ impl<'de> BoltTypeDeserializer<'de> {
where
V: Visitor<'de>,
{
let typ = match self.value {
self.value.unexpected(&visitor)
}
}

impl BoltType {
fn unexpected<T, E>(&self, expected: &E) -> Result<T, DeError>
where
E: Expected,
{
let typ = match self {
BoltType::String(v) => Unexp::Str(&v.value),
BoltType::Boolean(v) => Unexp::Bool(v.value),
BoltType::Map(_) => Unexp::Map,
Expand All @@ -751,7 +764,7 @@ impl<'de> BoltTypeDeserializer<'de> {
BoltType::DateTimeZoneId(_) => Unexp::Other("DateTimeZoneId"),
};

Err(DeError::invalid_type(typ, &visitor))
Err(DeError::invalid_type(typ, expected))
}
}

Expand Down Expand Up @@ -859,6 +872,70 @@ impl<'de> VariantAccess<'de> for BoltEnum<'de> {
}
}

struct CStyleEnum<'de> {
variant: &'de BoltType,
}

impl<'de> EnumAccess<'de> for CStyleEnum<'de> {
type Error = DeError;

type Variant = UnitVariant;

fn variant_seed<V>(self, seed: V) -> Result<(V::Value, Self::Variant), Self::Error>
where
V: DeserializeSeed<'de>,
{
let val = match self.variant {
BoltType::String(variant) => seed.deserialize(variant.into_deserializer())?,
BoltType::Bytes(variant) => seed.deserialize(variant.into_deserializer())?,
BoltType::Integer(variant) => seed.deserialize(variant.value.into_deserializer())?,
otherwise => {
otherwise.unexpected(&"string, bytes, or integer (valid enum identifier)")?
}
};

Ok((val, UnitVariant))
}
}

struct UnitVariant;

impl<'de> VariantAccess<'de> for UnitVariant {
type Error = DeError;

fn unit_variant(self) -> Result<(), Self::Error> {
Ok(())
}

fn newtype_variant_seed<T>(self, _seed: T) -> Result<T::Value, Self::Error>
where
T: DeserializeSeed<'de>,
{
Err(DeError::invalid_type(
Unexp::NewtypeVariant,
&"unit variant",
))
}

fn tuple_variant<V>(self, _len: usize, _visitorr: V) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
Err(DeError::invalid_type(Unexp::TupleVariant, &"unit variant"))
}

fn struct_variant<V>(
self,
_fields: &'static [&'static str],
_visitor: V,
) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
Err(DeError::invalid_type(Unexp::StructVariant, &"unit variant"))
}
}

impl<'de> IntoDeserializer<'de, DeError> for &'de BoltType {
type Deserializer = BoltTypeDeserializer<'de>;

Expand All @@ -875,6 +952,14 @@ impl<'de> IntoDeserializer<'de, DeError> for &'de BoltString {
}
}

impl<'de> IntoDeserializer<'de, DeError> for &'de BoltBytes {
type Deserializer = BorrowedBytesDeserializer<'de, DeError>;

fn into_deserializer(self) -> Self::Deserializer {
BorrowedBytesDeserializer::new(&self.value)
}
}

trait FromFloat {
fn from_float(f: f64) -> Self;
}
Expand Down Expand Up @@ -2068,4 +2153,20 @@ mod tests {

assert_eq!(actual.addr, SocketAddr::from(([127, 0, 0, 1], 4242)));
}

#[test]
fn deserialize_some_c_style_enum() {
#[derive(Debug, Copy, Clone, Deserialize, PartialEq, Eq)]
enum Frobnicate {
Foo,
Moo,
Bing,
}

let value = BoltType::from("Foo");

let actual = value.to::<Frobnicate>().unwrap();

assert_eq!(actual, Frobnicate::Foo);
}
}

0 comments on commit fce4b46

Please sign in to comment.