diff --git a/cel-rs/src/context.rs b/cel-rs/src/context.rs index a43cc3a..abd06ad 100644 --- a/cel-rs/src/context.rs +++ b/cel-rs/src/context.rs @@ -1,21 +1,42 @@ +use crate::{function::Function, value::value::Val}; use std::{collections::HashMap, rc::Rc}; -use crate::value::value::{Val}; -#[derive(Default)] pub struct Context { par: Option>, - vars: HashMap<&'static str, Val> + variables: HashMap<&'static str, Val>, + funtions: HashMap<&'static str, Function>, +} + +impl Default for Context { + fn default() -> Self { + Self { + par: Default::default(), + variables: Default::default(), + funtions: HashMap::from([ + ("dyn", crate::std::new_dyn()) + ]), + } + } } impl Context { - pub fn add_variable(mut self, name: &'static str, val: Val) -> Self { - self.vars.insert(name, val); + pub fn add_variable(&mut self, name: &'static str, val: Val) -> &mut Self { + self.variables.insert(name, val); self } - pub fn resolve(&self, name: &String) -> Option<&Val> { - self.vars.get(name.as_str()) + pub fn resolve_variable(&self, name: &String) -> Option<&Val> { + self.variables.get(name.as_str()) } + + pub fn add_function(&mut self, name: &'static str, func: Function) -> &mut Self { + self.funtions.insert(name, func); + self + } + pub fn resolve_function(&self, name: &String) -> Option<&Function> { + self.funtions.get(name.as_str()) + } + pub fn parent(&self) -> Option> { - self.par.clone() + self.par.clone() } -} \ No newline at end of file +} diff --git a/cel-rs/src/eval.rs b/cel-rs/src/eval.rs index fb34991..fc8d64c 100644 --- a/cel-rs/src/eval.rs +++ b/cel-rs/src/eval.rs @@ -1,7 +1,7 @@ use std::collections::HashMap; use std::rc::Rc; -use crate::parser::{Atom, Expression, RelationOp}; +use crate::parser::{Atom, Expression, Member, RelationOp}; use crate::value::value::Val; use crate::Context; @@ -9,51 +9,45 @@ use crate::Context; pub struct Eval {} impl Eval { - // fn eval_member(&self, expr: Expression, member: Member, ctx: &mut Context) -> impl Bag { - // let v = self.eval(expr, ctx).unpack(); - // match member { - // crate::parser::Member::Attribute(attr) => { - // if let Value::Map(v) = v { - // let value = v.get(&Value::String(attr)).expect("TODO: unknown map key"); - // return value.to_owned() - // } - // if let Some(val) = ctx.resolve(&attr) { - // return val - // } - // panic!("unknown attribute {}", attr) - // }, - // crate::parser::Member::FunctionCall(name, mut rargs) => { - // let mut args = Vec::with_capacity(rargs.len()); - // rargs.reverse(); - // for arg in rargs { - // args.push(self.eval(arg, ctx).unpack()); - // } + fn eval_function( + &self, + name: Rc, + receiver: Option, + argexprs: Vec, + ctx: &mut Context, + ) -> Val { + let mut args = Vec::with_capacity(argexprs.len() + 1); - // if let Some(val) = ctx.resolve(&name) { - // args.push(v.clone()); - // args.reverse(); - // if let Value::Function(f) = val { - // return (f.overloads.first().unwrap().func)(args) - // } - // } + if let Some(expr) = receiver { + args.push(expr) + } - // panic!("is not a func") - // }, - // crate::parser::Member::Index(i) => { - // let i = self.eval(*i, ctx).unpack(); - // if let Value::Map(v) = v { - // let value = v.get(&i).expect("TODO: unknown map key"); - // return value.to_owned() - // } - // Value::Null - // }, - // crate::parser::Member::Fields(_) => todo!("Fields"), - // } - // } + for expr in argexprs { + args.push(self.eval(expr, ctx)); + } + if let Some(func) = ctx.resolve_function(&name) { + return (func.overloads.first().unwrap().func)(args) + } + Val::new_error("unknown func".to_string()) + } + fn eval_member(&self, expr: Box, member: Box, ctx: &mut Context) -> Val { + let v = self.eval(*expr, ctx); + match *member { + crate::parser::Member::Attribute(attr) => todo!(), + crate::parser::Member::FunctionCall(name, argexprs) => { + self.eval_function(name, Some(v), argexprs, ctx) + } + crate::parser::Member::Index(i) => todo!(), + crate::parser::Member::Fields(_) => todo!(), + } + } pub fn eval(&self, expr: Expression, ctx: &mut Context) -> Val { match expr { + Expression::GlobalFunctionCall(name, argexprs) => { + self.eval_function(name, None, argexprs, ctx) + } Expression::Arithmetic(_, _, _) => todo!(), Expression::Relation(left, op, right) => { let l = self.eval(*left, ctx); @@ -72,13 +66,12 @@ impl Eval { Expression::Or(_, _) => todo!(), Expression::And(_, _) => todo!(), Expression::Unary(_, _) => todo!(), - Expression::Member(_, _) => todo!(), - Expression::FunctionCall(_) => todo!(), + Expression::Member(expr, member) => self.eval_member(expr, member, ctx), Expression::List(values) => self.eval_list(values, ctx), Expression::Map(entries) => self.eval_map(entries, ctx), Expression::Atom(atom) => self.eval_atom(atom, ctx), Expression::Ident(ident) => ctx - .resolve(&ident) + .resolve_variable(&ident) .unwrap_or(&Val::new_error(format!("unknown variable {}", ident))) .to_owned(), } @@ -94,7 +87,6 @@ impl Eval { Val::new_map(Rc::new(map)) } - fn eval_list(&self, elems: Vec, ctx: &mut Context) -> Val { let mut list = Vec::with_capacity(elems.len()); for expr in elems { diff --git a/cel-rs/src/function.rs b/cel-rs/src/function.rs new file mode 100644 index 0000000..a46c30f --- /dev/null +++ b/cel-rs/src/function.rs @@ -0,0 +1,15 @@ +use crate::Val; + +#[derive(PartialEq, Eq, Debug, Clone)] +pub struct Function { + pub name: &'static str, + pub overloads: &'static [Overload], +} + +type Func = fn(args: Vec) -> Val; + +#[derive(PartialEq, Eq, Debug, Clone)] +pub struct Overload { + pub key: &'static str, + pub func: Func +} \ No newline at end of file diff --git a/cel-rs/src/lib.rs b/cel-rs/src/lib.rs index a7268ba..d712884 100644 --- a/cel-rs/src/lib.rs +++ b/cel-rs/src/lib.rs @@ -3,6 +3,8 @@ mod value; mod program; mod eval; mod parser; +mod function; +mod std; // public api pub use crate::program::Program; diff --git a/cel-rs/src/parser/ast.rs b/cel-rs/src/parser/ast.rs index 0ba30e8..7c8d210 100644 --- a/cel-rs/src/parser/ast.rs +++ b/cel-rs/src/parser/ast.rs @@ -40,7 +40,7 @@ pub enum Expression { Member(Box, Box), - FunctionCall(Box), + GlobalFunctionCall(Rc, Vec), List(Vec), Map(Vec<(Expression, Expression)>), diff --git a/cel-rs/src/parser/cel.lalrpop b/cel-rs/src/parser/cel.lalrpop index dbf4767..63231cf 100644 --- a/cel-rs/src/parser/cel.lalrpop +++ b/cel-rs/src/parser/cel.lalrpop @@ -36,7 +36,7 @@ pub Member: Expression = { pub Primary: Expression = { "."? => Expression::Ident(<>.into()).into(), "."? "(" > ")" => { - Expression::FunctionCall(Member::FunctionCall(identifier.into(), arguments).into()).into() + Expression::GlobalFunctionCall(identifier.into(), arguments) }, Atom => Expression::Atom(<>).into(), "[" > "]" => Expression::List(<>).into(), diff --git a/cel-rs/src/program.rs b/cel-rs/src/program.rs index 7cd5566..1441901 100644 --- a/cel-rs/src/program.rs +++ b/cel-rs/src/program.rs @@ -69,7 +69,8 @@ pub mod tests { #[test] fn test_bool() { - let mut ctx = program::Context::default().add_variable("a", Val::new_bool(true)); + let mut ctx = program::Context::default(); + ctx.add_variable("a", Val::new_bool(true)); assert_eq!(eval_program!(r#"a == true"#, &mut ctx), Val::new_bool(true)); assert_eq!(eval_program!(r#"a == false"#, &mut ctx), Val::new_bool(false)); } diff --git a/cel-rs/src/std.rs b/cel-rs/src/std.rs new file mode 100644 index 0000000..53ea91b --- /dev/null +++ b/cel-rs/src/std.rs @@ -0,0 +1,18 @@ +use crate::{function::{Function, Overload}, Val}; + + +fn invoke_dyn(args: Vec) -> Val { + args.first().unwrap().clone() +} + +pub fn new_dyn() -> Function { + Function { + name: "dyn", + overloads: &[Overload { + key: "dyn", + func: invoke_dyn, + }], + } +} + + diff --git a/cel-rs/src/value/bool.rs b/cel-rs/src/value/bool.rs index e7082b1..28baae2 100644 --- a/cel-rs/src/value/bool.rs +++ b/cel-rs/src/value/bool.rs @@ -36,9 +36,11 @@ impl Value for Bool { } fn equals(&self, other: &Val) -> Val { - other - .as_bool() - .map(|f| Val::new_bool(&self.0 == f)) - .unwrap_or(Val::new_bool(false)) + Val::new_bool( + other + .native_value() + .downcast_ref::() + .is_some_and(|f| f.eq(&self.0)), + ) } } diff --git a/cel-rs/src/value/bytes.rs b/cel-rs/src/value/bytes.rs index 9a547bc..0c3c805 100644 --- a/cel-rs/src/value/bytes.rs +++ b/cel-rs/src/value/bytes.rs @@ -19,6 +19,15 @@ impl Value for Bytes { &self.0 } + fn equals(&self, other: &Val) -> Val { + Val::new_bool( + other + .native_value() + .downcast_ref::>>() + .is_some_and(|f| f.eq(&self.0)), + ) + } + fn compare(&self, other: &Val) -> Option { other.native_value().downcast_ref::>>().map(|ob| { (&self.0).cmp(ob).into() diff --git a/cel-rs/src/value/double.rs b/cel-rs/src/value/double.rs index c1b38c4..b3d3ab3 100644 --- a/cel-rs/src/value/double.rs +++ b/cel-rs/src/value/double.rs @@ -18,6 +18,15 @@ impl Value for Double { &self.0 } + fn equals(&self, other: &Val) -> Val { + Val::new_bool( + other + .native_value() + .downcast_ref::() + .is_some_and(|f| f.eq(&self.0)), + ) + } + fn compare(&self, other: &Val) -> Option { let vl = other.native_value().downcast_ref::(); if vl.is_some() { diff --git a/cel-rs/src/value/error.rs b/cel-rs/src/value/error.rs index fef55d7..fb47f5e 100644 --- a/cel-rs/src/value/error.rs +++ b/cel-rs/src/value/error.rs @@ -1,6 +1,9 @@ use core::fmt; -use super::{value::{Value, Val}, ty::Ty}; +use super::{ + ty::Ty, + value::{Val, Value}, +}; #[derive(Eq, PartialEq)] pub struct Error { @@ -34,8 +37,22 @@ impl Value for Error { fn ty(&self) -> Ty { Ty::Error } - + fn native_value(&self) -> &dyn std::any::Any { self } + + fn equals(&self, other: &Val) -> Val { + if other.ty() != Ty::Error { + return Val::new_bool(false); + } + + Val::new_bool( + other + .native_value() + .downcast_ref::() + .map(|oerr| oerr.eq(self)) + .is_some_and(|f| f) + ) + } } diff --git a/cel-rs/src/value/int.rs b/cel-rs/src/value/int.rs index 1fd89c5..2abb780 100644 --- a/cel-rs/src/value/int.rs +++ b/cel-rs/src/value/int.rs @@ -17,12 +17,19 @@ impl Value for Int { &self.0 } + fn equals(&self, other: &Val) -> Val { + Val::new_bool( + other + .native_value() + .downcast_ref::() + .is_some_and(|f| f.eq(&self.0)), + ) + } + fn compare(&self, other: &Val) -> Option { other .native_value() .downcast_ref::() - .map(|oi| { - (&self.0).cmp(oi).into() - }) + .map(|oi| (&self.0).cmp(oi).into()) } } diff --git a/cel-rs/src/value/null.rs b/cel-rs/src/value/null.rs index 82df404..50ba0ba 100644 --- a/cel-rs/src/value/null.rs +++ b/cel-rs/src/value/null.rs @@ -22,6 +22,11 @@ impl Value for Null { fn native_value(&self) -> &dyn std::any::Any { &() } + + fn equals(&self, other: &Val) -> Val { + Val::new_bool(other.ty() == Ty::Null) + } + fn compare(&self, other: &Val) -> Option { if other.ty() == Ty::Null { return Some(Val::from(Ordering::Equal)) diff --git a/cel-rs/src/value/string.rs b/cel-rs/src/value/string.rs index 824da13..8a96cf3 100644 --- a/cel-rs/src/value/string.rs +++ b/cel-rs/src/value/string.rs @@ -20,6 +20,14 @@ impl Value for String { &self.0 } + fn equals(&self, other: &Val) -> Val { + Val::new_bool( + other + .native_value() + .downcast_ref::() + .is_some_and(|f| f.eq(&self.0)), + ) + } fn compare(&self, other: &Val) -> Option { other.native_value().downcast_ref::().map(|oths| { Val::from((&self.0).cmp(oths)) diff --git a/cel-rs/src/value/ty.rs b/cel-rs/src/value/ty.rs index 79b2e24..958d3fa 100644 --- a/cel-rs/src/value/ty.rs +++ b/cel-rs/src/value/ty.rs @@ -1,3 +1,5 @@ +use crate::Val; + use super::value::Value; // https://github.com/google/cel-spec/blob/master/doc/langdef.md#values @@ -48,4 +50,13 @@ impl Value for Ty { fn native_value(&self) -> &dyn std::any::Any { self } + + fn equals(&self, other: &Val) -> Val { + Val::new_bool( + other + .native_value() + .downcast_ref::() + .is_some_and(|f| f.eq(&self)), + ) + } } \ No newline at end of file diff --git a/cel-rs/src/value/uint.rs b/cel-rs/src/value/uint.rs index 6557ec6..c49e12d 100644 --- a/cel-rs/src/value/uint.rs +++ b/cel-rs/src/value/uint.rs @@ -20,6 +20,15 @@ impl Value for Uint { &self.0 } + fn equals(&self, other: &Val) -> Val { + Val::new_bool( + other + .native_value() + .downcast_ref::() + .is_some_and(|f| f.eq(&self.0)), + ) + } + fn compare(&self, other: &Val) -> Option { other .native_value() diff --git a/cel-rs/src/value/value.rs b/cel-rs/src/value/value.rs index 3a32a99..93d2344 100644 --- a/cel-rs/src/value/value.rs +++ b/cel-rs/src/value/value.rs @@ -53,13 +53,7 @@ impl Eq for Val {} impl PartialEq for Val { fn eq(&self, other: &Self) -> bool { - // TODO: switch other types to use equals instead. - if self.ty() == Ty::Map { - eprintln!("equals map"); - return self.equals(other).as_bool().expect("equals did not return bool").to_owned(); - } - self.partial_cmp(other) - .is_some_and(|ord| ord == cmp::Ordering::Equal) + return self.equals(other).as_bool().expect("equals did not return bool").to_owned(); } } diff --git a/cel-rs/tests/test.rs b/cel-rs/tests/test.rs index e03d16b..ecb5d11 100644 --- a/cel-rs/tests/test.rs +++ b/cel-rs/tests/test.rs @@ -10,3 +10,12 @@ cel_spec::suite!( skip_test = "self_eval_bytes_invalid_utf8", skip_test = "self_eval_unicode_escape_eight" ); + +cel_spec::suite!( + name = "comparisons", + + skip_section = "eq_wrapper", + skip_section = "in_list_literal", + skip_section = "in_map_literal", + skip_section = "bound", +); diff --git a/cel-spec/src/lib.rs b/cel-spec/src/lib.rs index 6d87e3d..54b769b 100644 --- a/cel-spec/src/lib.rs +++ b/cel-spec/src/lib.rs @@ -21,26 +21,34 @@ include!(concat!(env!("OUT_DIR"), "/tests.rs")); use darling::ast::NestedMeta; use darling::{Error, FromMeta}; use google::api::expr::test::v1::{simple_test::ResultMatcher, SimpleTestFile}; -use google::api::expr::v1alpha1::value::Kind; -use google::api::expr::v1alpha1::Value; +use google::api::expr::v1alpha1::{Value, value}; +use google::api::expr::v1alpha1::{ExprValue, expr_value}; use proc_macro::TokenStream; use prost::Message; +fn expand_expr_value(val: ExprValue) -> String { + match val.kind.unwrap() { + expr_value::Kind::Value(val) => expand_value(val), + expr_value::Kind::Error(_) => String::from("TODO: ExprValue::Error"), + expr_value::Kind::Unknown(_) => String::from("TODO: ExprValue::Unknown"), + } +} + fn expand_value(val: Value) -> String { match val.kind.unwrap() { - Kind::NullValue(_) => "cel_rs::Val::new_null()".to_string(), - Kind::BoolValue(b) => format!("cel_rs::Val::new_bool({})", b), - Kind::Int64Value(i) => format!("cel_rs::Val::new_int({})", i), - Kind::Uint64Value(ui) => format!("cel_rs::Val::new_uint({})", ui), - Kind::DoubleValue(db) => format!("cel_rs::Val::new_double({}f64)", db), - Kind::StringValue(str) => format!( + value::Kind::NullValue(_) => "cel_rs::Val::new_null()".to_string(), + value::Kind::BoolValue(b) => format!("cel_rs::Val::new_bool({})", b), + value::Kind::Int64Value(i) => format!("cel_rs::Val::new_int({})", i), + value::Kind::Uint64Value(ui) => format!("cel_rs::Val::new_uint({})", ui), + value::Kind::DoubleValue(db) => format!("cel_rs::Val::new_double({}f64)", db), + value::Kind::StringValue(str) => format!( "cel_rs::Val::new_string(std::rc::Rc::new(String::from_utf8({:?}.to_vec()).unwrap()))", str.as_bytes() ), - Kind::BytesValue(bytes) => { + value::Kind::BytesValue(bytes) => { format!("cel_rs::Val::new_bytes(Vec::from({:?}).into())", bytes) } - Kind::MapValue(map) => format!( + value::Kind::MapValue(map) => format!( "cel_rs::Val::new_map(std::collections::HashMap::::from([{}]).into())", map.entries.iter().map(|entry| { let key = entry.key.clone().unwrap(); @@ -49,10 +57,10 @@ fn expand_value(val: Value) -> String { format!("({}, {}),", expand_value(key), expand_value(value)) }).collect::>().join("\n") ), - Kind::ListValue(list) => format!("cel_rs::Val::new_list({})", "Vec::new().into()"), - Kind::EnumValue(en) => "TODO".to_string(), - Kind::ObjectValue(obj) => "TODO".to_string(), - Kind::TypeValue(ty) => "TODO".to_string(), + value::Kind::ListValue(list) => format!("cel_rs::Val::new_list({})", "Vec::new().into()"), + value::Kind::EnumValue(en) => "TODO: EnumValue".to_string(), + value::Kind::ObjectValue(obj) => "TODO: ObjectValue".to_string(), + value::Kind::TypeValue(ty) => "TODO: TypeValue".to_string(), } } @@ -60,13 +68,33 @@ fn expand_result_matcher(rm: Option) -> String { if rm.is_none() { panic!("result matcher is none."); } - if let ResultMatcher::Value(val) = rm.unwrap() { - expand_value(val) - } else { - String::from("TODO") + + match rm.unwrap() { + ResultMatcher::Value(val) => expand_value(val), + ResultMatcher::EvalError(err) => format!("cel_rs::Val::new_error({:?}.into())", err.errors[0].message), + ResultMatcher::AnyEvalErrors(eval) => format!("TODO: AnyEvalErrors: {:#?}", eval), + ResultMatcher::Unknown(unk) => format!("TODO: Unknown: {:#?}", unk), + ResultMatcher::AnyUnknowns(anyunk) => format!("TODO: AnyUnknowns: {:#?}", anyunk), } } +fn expand_bindings(bindings: HashMap) -> String { + let mut exp = String::new(); + + for (k, v) in bindings.iter() { + exp.push_str( + format!( + r#"ctx.add_variable("{k}", {v});"#, + k = k, + v = expand_expr_value(v.clone()) + ) + .as_str(), + ) + } + + exp +} + #[derive(Debug, FromMeta)] struct MacroArgs { name: String, @@ -98,8 +126,7 @@ pub fn suite(rargs: TokenStream) -> TokenStream { let mut ast = String::new(); for section in testfile.section { - - if args.skip_sections.contains(§ion.name){ + if args.skip_sections.contains(§ion.name) { continue; } @@ -108,13 +135,17 @@ pub fn suite(rargs: TokenStream) -> TokenStream { ast.push_str("{"); for test in section.test { - if args.skip_tests.contains(&test.name){ + if args.skip_tests.contains(&test.name) { continue; } let expected_value = expand_result_matcher(test.result_matcher); - ast.push_str(&format!(r##" + let bindings = expand_bindings(test.bindings); + + ast.push_str( + &format!( + r##" #[test] fn r#{name}() {{ let expr = r#"{expr}"#; @@ -122,11 +153,19 @@ pub fn suite(rargs: TokenStream) -> TokenStream { assert!(program.is_ok(), "failed to parse '{{}}'", expr); let program = program.unwrap(); let mut ctx = cel_rs::Context::default(); + {bindings} let value = program.eval(&mut ctx); let expected_value = {expected_value}; assert_eq!(value, expected_value); }} - "##, name = test.name, expr = test.expr, expected_value = expected_value ).to_string()); + "##, + name = test.name, + expr = test.expr, + expected_value = expected_value, + bindings = bindings + ) + .to_string(), + ); } ast.push_str("}");