diff --git a/src/front/zsharpcurly/interp.rs b/src/front/zsharpcurly/interp.rs index b0dc0638..3056c833 100644 --- a/src/front/zsharpcurly/interp.rs +++ b/src/front/zsharpcurly/interp.rs @@ -45,6 +45,10 @@ pub fn extract( }) .collect::, _>>()?, )), - Ty::Tuple(tys) => Ok(T::new_tuple(tys.iter().map(|ty| extract(name, ty, scalar_input_values)).collect::, _>>()?)), + Ty::Tuple(tys) => Ok(T::new_tuple( + tys.iter() + .map(|ty| extract(name, ty, scalar_input_values)) + .collect::, _>>()?, + )), } } diff --git a/src/front/zsharpcurly/mod.rs b/src/front/zsharpcurly/mod.rs index d23a5ff8..6db621b7 100644 --- a/src/front/zsharpcurly/mod.rs +++ b/src/front/zsharpcurly/mod.rs @@ -20,8 +20,8 @@ use std::collections::HashMap; use std::fmt::Display; use std::path::PathBuf; use std::str::FromStr; -use zokrates_curly_pest_ast as ast; use std::time; +use zokrates_curly_pest_ast as ast; use term::*; use zvisit::{ZConstLiteralRewriter, ZGenericInf, ZStatementWalker, ZVisitorMut}; @@ -37,7 +37,6 @@ pub struct Inputs { pub mode: Mode, } - #[allow(dead_code)] fn const_value_simple(term: &Term) -> Option { match term.op() { @@ -51,7 +50,7 @@ fn const_bool_simple(t: T) -> Option { match const_value_simple(&t.term) { Some(Value::Bool(b)) => Some(b), _ => None, - } + } } #[allow(dead_code)] @@ -132,7 +131,7 @@ struct ZGen<'ast> { } #[derive(Debug, Clone, PartialEq, Hash, Eq)] -struct FnCallImplInput(bool, Vec, Vec<(String,T)>, PathBuf, String); +struct FnCallImplInput(bool, Vec, Vec<(String, T)>, PathBuf, String); impl<'ast> Drop for ZGen<'ast> { fn drop(&mut self) { @@ -570,10 +569,9 @@ impl<'ast> ZGen<'ast> { .map_err(|e| format!("{e}"))? .unwrap_term() }; - let new = - loc_store(old, &zaccs[..], val) - .map(const_fold) - .and_then(|n| if strict { const_val_simple(n) } else { Ok(n) })?; + let new = loc_store(old, &zaccs[..], val) + .map(const_fold) + .and_then(|n| if strict { const_val_simple(n) } else { Ok(n) })?; debug!("Assign: {}", name); if IS_CNST { self.cvar_assign(name, new) @@ -591,7 +589,7 @@ impl<'ast> ZGen<'ast> { // Get the variable name and accesses from the assignee let name = &assign.assignee.id.value; let accs = &assign.assignee.accesses; - + // Convert AST accesses to IR accesses let zaccs = self.zaccs_impl_::(accs)?; // Get the current value @@ -610,9 +608,9 @@ impl<'ast> ZGen<'ast> { let new = loc_store(old, &zaccs[..], val) .map(const_fold) .and_then(|n| if IS_CNST { const_val_simple(n) } else { Ok(n) })?; - + debug!("Assembly Assign: {}", name); - + // Store the result if IS_CNST { self.cvar_assign(name, new) @@ -623,25 +621,28 @@ impl<'ast> ZGen<'ast> { } } - fn assembly_constraint_(&self, c: &ast::AssemblyConstraint) -> Result<(), String> { + fn assembly_constraint_( + &self, + c: &ast::AssemblyConstraint, + ) -> Result<(), String> { // Get expressions for both sides let lhs = self.expr_impl_::(&c.lhs)?; let rhs = self.expr_impl_::(&c.rhs)?; - + // Create equality comparison let eq_expr = term![EQ; lhs.term, rhs.term]; - + // Similar to assertion logic, check if it's a constant expression match const_bool_simple(T::new(Ty::Bool, eq_expr.clone())) { Some(true) => Ok(()), Some(false) => Err(format!( - "Const assembly constraint failed: {} == {} at\n{}", + "Const assembly constraint failed: {} == {} at\n{}", c.lhs.span().as_str(), c.rhs.span().as_str(), span_to_string(&c.span), )), None if IS_CNST => Err(format!( - "Const assembly constraint eval failed at\n{}", + "Const assembly constraint eval failed at\n{}", span_to_string(&c.span), )), _ => { @@ -661,7 +662,10 @@ impl<'ast> ZGen<'ast> { .map(|acc| match acc { ast::AssigneeAccess::Dot(m) => match &m.inner { ast::IdentifierOrDecimal::Identifier(i) => Ok(ZAccess::Member(i.value.clone())), - ast::IdentifierOrDecimal::Decimal(_) => Err(format!("Unsupported access of struct field by value: {}", span_to_string(&m.span))), + ast::IdentifierOrDecimal::Decimal(_) => Err(format!( + "Unsupported access of struct field by value: {}", + span_to_string(&m.span) + )), }, ast::AssigneeAccess::Select(m) => match &m.expression { ast::RangeOrExpression::Expression(e) => { @@ -702,9 +706,11 @@ impl<'ast> ZGen<'ast> { // otherwise we should return an error match Integer::from_str_radix(vstr, 10) { Ok(val) => Ok(field_lit(val)), - Err(_) => Err("Could not infer literal type. Annotation needed.".to_string()) + Err(_) => { + Err("Could not infer literal type. Annotation needed.".to_string()) + } } - }, + } } } ast::LiteralExpression::BooleanLiteral(b) => { @@ -828,11 +834,11 @@ impl<'ast> ZGen<'ast> { let before = time::Instant::now(); let input = FnCallImplInput( - IS_CNST, - args.clone(), - generic_vec.clone(), - f_path.clone(), - f_name.clone() + IS_CNST, + args.clone(), + generic_vec.clone(), + f_path.clone(), + f_name.clone(), ); let cached_value = self.fn_call_memoization.borrow().get(&input).cloned(); @@ -841,12 +847,13 @@ impl<'ast> ZGen<'ast> { } else { debug!("successfully memoized {} {:?}", f_name, f_path); self.function_call_impl_inner_::( - f, - args, - generics, - f_path.clone(), + f, + args, + generics, + f_path.clone(), f_name.clone(), - ).inspect(|v| { + ) + .inspect(|v| { self.fn_call_memoization .borrow_mut() .insert(input, v.clone()); @@ -907,8 +914,8 @@ impl<'ast> ZGen<'ast> { // XXX(unimpl) multi-return unimplemented let ret_ty = f - .return_type - .map(|r| self.type_impl_::(&r)) + .return_type + .map(|r| self.type_impl_::(&r)) .transpose()?; let ret_ty = if IS_CNST { self.cvar_enter_function(); @@ -1074,13 +1081,7 @@ impl<'ast> ZGen<'ast> { let name = "return".to_owned(); let ret_val = r.unwrap_term(); let ret_var_val = self - .circ_declare_input( - name, - ty, - ZVis::Public, - Some(ret_val.clone()), - false, - ) + .circ_declare_input(name, ty, ZVis::Public, Some(ret_val.clone()), false) .expect("circ_declare return"); let ret_eq = eq(ret_val, ret_var_val).unwrap().term; let mut assertions = std::mem::take(&mut *self.assertions.borrow_mut()); @@ -1152,9 +1153,7 @@ impl<'ast> ZGen<'ast> { match visibility { None | Some(ast::Visibility::Public(_)) => ZVis::Public, Some(ast::Visibility::Private(_)) => match self.mode { - Mode::Proof | Mode::Opt | Mode::ProofOfHighValue(_) => { - ZVis::Private(PROVER_ID) - } + Mode::Proof | Mode::Opt | Mode::ProofOfHighValue(_) => ZVis::Private(PROVER_ID), Mode::Mpc(_n_parties) => { // XXX(unimpl) party number panic!("Mpc mode is not implemented"); @@ -1298,11 +1297,10 @@ impl<'ast> ZGen<'ast> { .map_err(|err| format!("{}; context:\n{}", err, span_to_string(e.span()))) } - // XXX(rsw) make Result to give more precise error messages? fn expr_impl_inner_( - &self, - e: &ast::Expression<'ast> + &self, + e: &ast::Expression<'ast>, ) -> Result { if IS_CNST { debug!("Const expr: {}", e.span().as_str()); @@ -1312,7 +1310,11 @@ impl<'ast> ZGen<'ast> { match e { ast::Expression::Ternary(u) => { - match self.expr_impl_::(&u.condition).ok().and_then(const_bool_simple) { + match self + .expr_impl_::(&u.condition) + .ok() + .and_then(const_bool_simple) + { Some(true) => self.expr_impl_::(&u.consequence), Some(false) => self.expr_impl_::(&u.alternative), None if IS_CNST => Err("ternary condition not const bool".to_string()), @@ -1396,8 +1398,10 @@ impl<'ast> ZGen<'ast> { (res, &p.accesses[1..]) } else { match &*p.base { - ast::Expression::Identifier(id) =>(self.identifier_impl_::(id)?, &p.accesses[..]), - _ => panic!("Expected identifier in postfix expression base") + ast::Expression::Identifier(id) => { + (self.identifier_impl_::(id)?, &p.accesses[..]) + } + _ => panic!("Expected identifier in postfix expression base"), } }; accs.iter().try_fold(val, |v, acc| match acc { @@ -1410,9 +1414,11 @@ impl<'ast> ZGen<'ast> { ast::IdentifierOrDecimal::Identifier(id) => field_select(&v, &id.value), ast::IdentifierOrDecimal::Decimal(idx) => { if let Ty::Tuple(tys) = &v.ty { - let idx_val = idx.span.as_str().parse::().map_err(|_| { - "Invalid tuple index".to_string() - })?; + let idx_val = idx + .span + .as_str() + .parse::() + .map_err(|_| "Invalid tuple index".to_string())?; if idx_val < tys.len() { Ok(T::new( tys[idx_val].clone(), @@ -1426,11 +1432,14 @@ impl<'ast> ZGen<'ast> { )) } } else { - Err(format!("Cannot use decimal index on non-tuple type: {:?}", v.ty)) + Err(format!( + "Cannot use decimal index on non-tuple type: {:?}", + v.ty + )) } } } - }, + } ast::Access::Select(s) => self.array_access_impl_::(s, v), }) } @@ -1443,9 +1452,18 @@ impl<'ast> ZGen<'ast> { }) .collect::, String>>() .and_then(|members| Ok(T::new_struct(self.canon_struct(&u.ty.value)?, members))), - ast::Expression::InlineTuple(ite) => Ok(T::new_tuple(ite.elements.iter().map(|e| self.expr_impl_::(e)).collect::, _>>()?)), + ast::Expression::InlineTuple(ite) => Ok(T::new_tuple( + ite.elements + .iter() + .map(|e| self.expr_impl_::(e)) + .collect::, _>>()?, + )), ast::Expression::IfElse(u) => { - match self.expr_impl_::(&u.condition).ok().and_then(const_bool_simple) { + match self + .expr_impl_::(&u.condition) + .ok() + .and_then(const_bool_simple) + { Some(true) => self.expr_impl_::(&u.consequence), Some(false) => self.expr_impl_::(&u.alternative), None if IS_CNST => Err("IfElse condition not const bool".to_string()), @@ -1524,16 +1542,14 @@ impl<'ast> ZGen<'ast> { // XXX(unimpl) condstore, and witness from old zokrates // XXX(unimpl) log from new zokrates match s { - ast::Statement::Return(r) => { - if let Some(e) = r.expression.as_ref() { - self.set_lhs_ty_ret(r); - let ret = self.expr_impl_::(e)?; - self.ret_impl_::(Some(ret)) - } else { - self.ret_impl_::(None) - } - .map_err(|e| format!("{e}")) + ast::Statement::Return(r) => if let Some(e) = r.expression.as_ref() { + self.set_lhs_ty_ret(r); + let ret = self.expr_impl_::(e)?; + self.ret_impl_::(Some(ret)) + } else { + self.ret_impl_::(None) } + .map_err(|e| format!("{e}")), ast::Statement::Assertion(e) => { let expr = self.expr_impl_::(&e.expression)?; match const_bool_simple(expr.clone()) { @@ -1580,7 +1596,12 @@ impl<'ast> ZGen<'ast> { self.decl_impl_::(v_name, &ty)?; for j in s..e { self.enter_scope_impl_::(); - self.assign_impl_::(&i.index.identifier.value, &[][..], ival_cons(j), false)?; + self.assign_impl_::( + &i.index.identifier.value, + &[][..], + ival_cons(j), + false, + )?; for s in &i.statements { self.stmt_impl_::(s)?; } @@ -1605,11 +1626,7 @@ impl<'ast> ZGen<'ast> { "Assignment type mismatch: {decl_ty} annotated vs {ty} actual", )); } - self.declare_init_impl_::( - l.identifier.value.clone(), - decl_ty, - e, - )?; + self.declare_init_impl_::(l.identifier.value.clone(), decl_ty, e)?; Ok(()) } } @@ -1630,9 +1647,7 @@ impl<'ast> ZGen<'ast> { } Ok(()) } - ast::Statement::Log(_) => { - Err("Log statement is not implemented".to_string()) - } + ast::Statement::Log(_) => Err("Log statement is not implemented".to_string()), } .map_err(|err| format!("{}; context:\n{}", err, span_to_string(s.span()))) } @@ -1680,18 +1695,19 @@ impl<'ast> ZGen<'ast> { ast::AssigneeAccess::Dot(sa) => { let id_value = match &sa.inner { ast::IdentifierOrDecimal::Identifier(id) => &id.value, - _ => panic!("Expected an Identifier, but got a non-Identifier value in sa"), - }; + _ => panic!( + "Expected an Identifier, but got a non-Identifier value in sa" + ), + }; match ty { - Ty::Struct(nm, map) => map - .search(id_value) - .map(|r| r.1.clone()) - .ok_or_else(|| { + Ty::Struct(nm, map) => { + map.search(id_value).map(|r| r.1.clone()).ok_or_else(|| { format!("No such member {} of struct {nm}", id_value) - }), + }) + } ty => Err(format!("Attempted member access on non-Struct type {ty}")), } - }, + } }) } TypedIdentifier(t) => self.type_impl_::(&t.ty), @@ -1831,7 +1847,10 @@ impl<'ast> ZGen<'ast> { .unwrap_or(false) { self.err( - format!("Constant {} clashes with import of same name", &c.id.identifier.value), + format!( + "Constant {} clashes with import of same name", + &c.id.identifier.value + ), &c.span, ); } @@ -1877,7 +1896,10 @@ impl<'ast> ZGen<'ast> { .insert(c.id.identifier.value.clone(), (c.id.ty.clone(), value)) .is_some() { - self.err(format!("Constant {} redefined", &c.id.identifier.value), &c.span); + self.err( + format!("Constant {} redefined", &c.id.identifier.value), + &c.span, + ); } } @@ -1945,7 +1967,10 @@ impl<'ast> ZGen<'ast> { sdef.fields .iter() .map::, _>(|f| { - Ok((f.id.identifier.value.clone(), self.type_impl_::(&f.id.ty)?)) + Ok(( + f.id.identifier.value.clone(), + self.type_impl_::(&f.id.ty)?, + )) }) .collect::, _>>()?, ), @@ -1955,13 +1980,12 @@ impl<'ast> ZGen<'ast> { self.file_stack_pop(); Ok(ty) } - ast::Type::Tuple(t) => { - t.elements - .iter() - .map(|element_type| self.type_impl_::(element_type)) - .collect::, _>>() - .map(Ty::Tuple) - } + ast::Type::Tuple(t) => t + .elements + .iter() + .map(|element_type| self.type_impl_::(element_type)) + .collect::, _>>() + .map(Ty::Tuple), } } @@ -1993,38 +2017,39 @@ impl<'ast> ZGen<'ast> { for d in f.declarations.iter() { // XXX(opt) retain() declarations instead? if we don't need them, saves allocs if let ast::SymbolDeclaration::Import(i) = d { - let (src_path, src_names, dst_names, i_span) = match i { - ast::ImportDirective::Main(m) => ( - m.source.raw.value.clone(), - vec!["main".to_owned()], - vec![m - .alias - .as_ref() - .map(|a| a.value.clone()) - .unwrap_or_else(|| { - PathBuf::from(m.source.raw.value.clone()) - .file_stem() - .unwrap_or_else(|| panic!("Bad import: {}", m.source.raw.value)) - .to_string_lossy() - .to_string() - })], - &m.span, - ), - ast::ImportDirective::From(m) => ( - m.source.raw.value.clone(), - m.symbols.iter().map(|s| s.id.value.clone()).collect(), - m.symbols - .iter() - .map(|s| { - s.alias - .as_ref() - .map(|a| a.value.clone()) - .unwrap_or_else(|| s.id.value.clone()) - }) - .collect(), - &m.span, - ), - }; + let (src_path, src_names, dst_names, i_span) = + match i { + ast::ImportDirective::Main(m) => ( + m.source.raw.value.clone(), + vec!["main".to_owned()], + vec![m.alias.as_ref().map(|a| a.value.clone()).unwrap_or_else( + || { + PathBuf::from(m.source.raw.value.clone()) + .file_stem() + .unwrap_or_else(|| { + panic!("Bad import: {}", m.source.raw.value) + }) + .to_string_lossy() + .to_string() + }, + )], + &m.span, + ), + ast::ImportDirective::From(m) => ( + m.source.raw.value.clone(), + m.symbols.iter().map(|s| s.id.value.clone()).collect(), + m.symbols + .iter() + .map(|s| { + s.alias + .as_ref() + .map(|a| a.value.clone()) + .unwrap_or_else(|| s.id.value.clone()) + }) + .collect(), + &m.span, + ), + }; assert!(!src_names.is_empty()); let abs_src_path = self.stdlib.canonicalize(&self.cur_dir(), src_path.as_str()); debug!( @@ -2122,7 +2147,11 @@ impl<'ast> ZGen<'ast> { for d in t.get_mut(&p).unwrap().declarations.iter_mut() { match d { ast::SymbolDeclaration::Constant(c) => { - debug!("processing decl: const {} in {}", c.id.identifier.value, p.display()); + debug!( + "processing decl: const {} in {}", + c.id.identifier.value, + p.display() + ); self.const_decl_(c); } ast::SymbolDeclaration::Struct(s) => { @@ -2190,7 +2219,10 @@ impl<'ast> ZGen<'ast> { // go through stmts typechecking and rewriting literals let mut sw = ZStatementWalker::new( f_ast.parameters.as_ref(), - f_ast.return_type.as_ref().map_or(&[], |ty| std::slice::from_ref(ty)), + f_ast + .return_type + .as_ref() + .map_or(&[], |ty| std::slice::from_ref(ty)), f_ast.generics.as_ref(), self, ); diff --git a/src/front/zsharpcurly/term.rs b/src/front/zsharpcurly/term.rs index 29588880..cb95c0d3 100644 --- a/src/front/zsharpcurly/term.rs +++ b/src/front/zsharpcurly/term.rs @@ -245,8 +245,8 @@ impl T { } pub fn new_integer(v: I) -> Self - where - Integer: From + where + Integer: From, { T::new(Ty::Integer, int_lit(v)) } @@ -383,7 +383,15 @@ fn add_integer(a: Term, b: Term) -> Term { } pub fn add(a: T, b: T) -> Result { - wrap_bin_op("+", Some(add_uint), Some(add_field), None, Some(add_integer), a, b) + wrap_bin_op( + "+", + Some(add_uint), + Some(add_field), + None, + Some(add_integer), + a, + b, + ) } fn sub_uint(a: Term, b: Term) -> Term { @@ -399,7 +407,15 @@ fn sub_integer(a: Term, b: Term) -> Term { } pub fn sub(a: T, b: T) -> Result { - wrap_bin_op("-", Some(sub_uint), Some(sub_field), None, Some(sub_integer), a, b) + wrap_bin_op( + "-", + Some(sub_uint), + Some(sub_field), + None, + Some(sub_integer), + a, + b, + ) } fn mul_uint(a: Term, b: Term) -> Term { @@ -415,7 +431,15 @@ fn mul_integer(a: Term, b: Term) -> Term { } pub fn mul(a: T, b: T) -> Result { - wrap_bin_op("*", Some(mul_uint), Some(mul_field), None, Some(mul_integer), a, b) + wrap_bin_op( + "*", + Some(mul_uint), + Some(mul_field), + None, + Some(mul_integer), + a, + b, + ) } fn div_uint(a: Term, b: Term) -> Term { @@ -431,7 +455,15 @@ fn div_integer(a: Term, b: Term) -> Term { } pub fn div(a: T, b: T) -> Result { - wrap_bin_op("/", Some(div_uint), Some(div_field), None, Some(div_integer), a, b) + wrap_bin_op( + "/", + Some(div_uint), + Some(div_field), + None, + Some(div_integer), + a, + b, + ) } fn to_dflt_f(t: Term) -> Term { @@ -454,7 +486,15 @@ fn rem_integer(a: Term, b: Term) -> Term { } pub fn rem(a: T, b: T) -> Result { - wrap_bin_op("%", Some(rem_uint), Some(rem_field), None, Some(rem_integer), a, b) + wrap_bin_op( + "%", + Some(rem_uint), + Some(rem_field), + None, + Some(rem_integer), + a, + b, + ) } fn bitand_uint(a: Term, b: Term) -> Term { @@ -539,7 +579,15 @@ fn ult_integer(a: Term, b: Term) -> Term { } pub fn ult(a: T, b: T) -> Result { - wrap_bin_pred("<", Some(ult_uint), Some(ult_field), None, Some(ult_integer), a, b) + wrap_bin_pred( + "<", + Some(ult_uint), + Some(ult_field), + None, + Some(ult_integer), + a, + b, + ) } fn ule_uint(a: Term, b: Term) -> Term { @@ -555,7 +603,15 @@ fn ule_integer(a: Term, b: Term) -> Term { } pub fn ule(a: T, b: T) -> Result { - wrap_bin_pred("<=", Some(ule_uint), Some(ule_field), None, Some(ule_integer), a, b) + wrap_bin_pred( + "<=", + Some(ule_uint), + Some(ule_field), + None, + Some(ule_integer), + a, + b, + ) } fn ugt_uint(a: Term, b: Term) -> Term { @@ -571,7 +627,15 @@ fn ugt_integer(a: Term, b: Term) -> Term { } pub fn ugt(a: T, b: T) -> Result { - wrap_bin_pred(">", Some(ugt_uint), Some(ugt_field), None, Some(ugt_integer), a, b) + wrap_bin_pred( + ">", + Some(ugt_uint), + Some(ugt_field), + None, + Some(ugt_integer), + a, + b, + ) } fn uge_uint(a: Term, b: Term) -> Term { @@ -587,18 +651,31 @@ fn uge_integer(a: Term, b: Term) -> Term { } pub fn uge(a: T, b: T) -> Result { - wrap_bin_pred(">=", Some(uge_uint), Some(uge_field), None, Some(uge_integer), a, b) + wrap_bin_pred( + ">=", + Some(uge_uint), + Some(uge_field), + None, + Some(uge_integer), + a, + b, + ) } - pub fn pow(a: T, b: T) -> Result { if (a.ty != Ty::Field && a.ty != Ty::Integer) || b.ty != Ty::Uint(32) { - return Err(format!("Cannot compute {a} ** {b} : must be Field/Integer ** U32")); + return Err(format!( + "Cannot compute {a} ** {b} : must be Field/Integer ** U32" + )); } let b = const_int(b)?; if b == 0 { - return Ok((if a.ty == Ty::Field {T::new_field} else {T::new_integer})(1)) + return Ok((if a.ty == Ty::Field { + T::new_field + } else { + T::new_integer + })(1)); } Ok((0..b.significant_bits() - 1) @@ -644,7 +721,14 @@ fn neg_integer(a: Term) -> Term { // Missing from ZoKrates. pub fn neg(a: T) -> Result { - wrap_un_op("unary-", Some(neg_uint), Some(neg_field), None, Some(neg_integer), a) + wrap_un_op( + "unary-", + Some(neg_uint), + Some(neg_field), + None, + Some(neg_integer), + a, + ) } fn not_bool(a: Term) -> Term { @@ -677,7 +761,7 @@ pub fn const_bool(a: T) -> Option { pub fn const_fold(t: T) -> T { let folded = constant_fold(&t.term, &[]); - return T::new(t.ty, folded) + return T::new(t.ty, folded); } pub fn const_val(a: T) -> Result { @@ -762,7 +846,6 @@ where T::new(Ty::Uint(bits), bv_lit(v, bits)) } - pub fn slice(arr: T, start: Option, end: Option) -> Result { match &arr.ty { Ty::Array(size, _) => { @@ -901,7 +984,10 @@ pub fn uint_to_field(u: T) -> Result { pub fn integer_to_field(u: T) -> Result { match &u.ty { - Ty::Integer => Ok(T::new(Ty::Field, term![Op::IntToPf(default_field()); u.term])), + Ty::Integer => Ok(T::new( + Ty::Field, + term![Op::IntToPf(default_field()); u.term], + )), u => Err(format!("Cannot do int-to-field on {u}")), } } @@ -913,7 +999,6 @@ pub fn field_to_integer(u: T) -> Result { } } - pub fn int_to_bits(i: T, n: usize) -> Result { match &i.ty { Ty::Integer => uint_to_bits(T::new(Ty::Uint(n), term![Op::IntToBv(n); i.term])), @@ -930,7 +1015,10 @@ pub fn int_size(i: T) -> Result { pub fn int_modinv(i: T, m: T) -> Result { match (&i.ty, &m.ty) { - (Ty::Integer, Ty::Integer) => Ok(T::new(Ty::Integer, term![Op::IntBinOp(IntBinOp::ModInv); i.term, m.term])), + (Ty::Integer, Ty::Integer) => Ok(T::new( + Ty::Integer, + term![Op::IntBinOp(IntBinOp::ModInv); i.term, m.term], + )), u => Err(format!("Cannot do modinv on {:?}", u)), } } diff --git a/src/front/zsharpcurly/zvisit/zconstlitrw.rs b/src/front/zsharpcurly/zvisit/zconstlitrw.rs index cad8502d..01e4012c 100644 --- a/src/front/zsharpcurly/zvisit/zconstlitrw.rs +++ b/src/front/zsharpcurly/zvisit/zconstlitrw.rs @@ -75,10 +75,7 @@ impl<'ast> ZVisitorMut<'ast> for ZConstLiteralRewriter { self.visit_span(&mut te.span) } - fn visit_if_else_expression( - &mut self, - ie: &mut ast::IfElseExpression<'ast>, - ) -> ZVisitorResult { + fn visit_if_else_expression(&mut self, ie: &mut ast::IfElseExpression<'ast>) -> ZVisitorResult { // first expression in a ternary should have type bool let to_ty = self.replace(Some(Ty::Bool)); self.visit_expression(&mut ie.condition)?; @@ -91,8 +88,15 @@ impl<'ast> ZVisitorMut<'ast> for ZConstLiteralRewriter { fn visit_binary_expression(&mut self, be: &mut ast::BinaryExpression<'ast>) -> ZVisitorResult { let (ty_l, ty_r) = { match be.op { - ast::BinaryOperator::Pow | ast::BinaryOperator::RightShift | ast::BinaryOperator::LeftShift => (self.to_ty.clone(), Some(Ty::Uint(32))), - ast::BinaryOperator::Eq | ast::BinaryOperator::NotEq | ast::BinaryOperator::Lt | ast::BinaryOperator::Gt | ast::BinaryOperator::Lte | ast::BinaryOperator::Gte => (None, None), + ast::BinaryOperator::Pow + | ast::BinaryOperator::RightShift + | ast::BinaryOperator::LeftShift => (self.to_ty.clone(), Some(Ty::Uint(32))), + ast::BinaryOperator::Eq + | ast::BinaryOperator::NotEq + | ast::BinaryOperator::Lt + | ast::BinaryOperator::Gt + | ast::BinaryOperator::Lte + | ast::BinaryOperator::Gte => (None, None), _ => (self.to_ty.clone(), self.to_ty.clone()), } }; @@ -267,7 +271,7 @@ impl<'ast> ZVisitorMut<'ast> for ZConstLiteralRewriter { ) -> ZVisitorResult { use ast::Expression; match *pe.base { - Expression::Identifier(ref mut id) =>self.visit_identifier_expression(id)?, + Expression::Identifier(ref mut id) => self.visit_identifier_expression(id)?, _ => panic!("Expected identifier in postfix expression base"), } //self.visit_identifier_expression(&mut pe.base.id)?; diff --git a/src/front/zsharpcurly/zvisit/zgenericinf.rs b/src/front/zsharpcurly/zvisit/zgenericinf.rs index 5a25cf31..d5c83a44 100644 --- a/src/front/zsharpcurly/zvisit/zgenericinf.rs +++ b/src/front/zsharpcurly/zvisit/zgenericinf.rs @@ -443,12 +443,8 @@ impl<'ast, 'gen, const IS_CNST: bool> ZGenericInf<'ast, 'gen, IS_CNST> { ArrayInitializer(_) => { Err("ZGenericInf: got ArrayInitializer in array dim expr (unimpl)".into()) } - IfElse(_) => { - Err("ZGenericInf: got IfElse in array dim expr (unimpl)".into()) - }, - InlineTuple(_) => { - Err("ZGenericInf: got InlineTuple in array dim expr (unimpl)".into()) - } + IfElse(_) => Err("ZGenericInf: got IfElse in array dim expr (unimpl)".into()), + InlineTuple(_) => Err("ZGenericInf: got InlineTuple in array dim expr (unimpl)".into()), } } } diff --git a/src/front/zsharpcurly/zvisit/zstmtwalker/mod.rs b/src/front/zsharpcurly/zvisit/zstmtwalker/mod.rs index 85aebc04..b320e120 100644 --- a/src/front/zsharpcurly/zvisit/zstmtwalker/mod.rs +++ b/src/front/zsharpcurly/zvisit/zstmtwalker/mod.rs @@ -157,8 +157,10 @@ impl<'ast, 'ret> ZStatementWalker<'ast, 'ret> { // handle first access, which is special because only this one could be a Call() let acc = &mut pf.accesses; let id = match *pf.base { - ast::Expression::Identifier(ref identifier) => identifier, - _ => panic!("Expected an Expression::Identifier, but found a different expression type"), + ast::Expression::Identifier(ref identifier) => identifier, + _ => { + panic!("Expected an Expression::Identifier, but found a different expression type") + } }; let alen = acc.len(); let (pf_id_ty, acc_offset) = if let Call(ca) = acc.first_mut().unwrap() { @@ -171,12 +173,12 @@ impl<'ast, 'ret> ZStatementWalker<'ast, 'ret> { "ZStatementWalker: fn {} has no return type", &id.value, ))) - }, + } Some(_) => { // Assuming `alen` is the count of arguments and `rty` is defined elsewhere let rty = if alen == 1 { rty } else { None }; Ok((self.get_call_ty(fdef, ca, rty)?, 1)) - }, + } } })? } else { @@ -323,7 +325,7 @@ impl<'ast, 'ret> ZStatementWalker<'ast, 'ret> { span_to_string(&it.span), ))); }; - + // Check if the number of elements in the inline tuple matches the expected tuple type if tt.elements.len() != it.elements.len() { return Err(ZVisitorError(format!( @@ -333,12 +335,14 @@ impl<'ast, 'ret> ZStatementWalker<'ast, 'ret> { span_to_string(&it.span), ))); } - + // Unify each element of the inline tuple with the corresponding type in the tuple type tt.elements .iter() .zip(it.elements.iter_mut()) - .try_for_each(|(expected_ty, element)| self.unify_expression(expected_ty.clone(), element)) + .try_for_each(|(expected_ty, element)| { + self.unify_expression(expected_ty.clone(), element) + }) } fn unify_identifier( @@ -389,7 +393,9 @@ impl<'ast, 'ret> ZStatementWalker<'ast, 'ret> { }; let (lt, rt) = match &be.op { - ast::BinaryOperator::BitXor | ast::BinaryOperator::BitAnd | ast::BinaryOperator::BitOr => match &bt { + ast::BinaryOperator::BitXor + | ast::BinaryOperator::BitAnd + | ast::BinaryOperator::BitOr => match &bt { U8(_) | U16(_) | U32(_) | U64(_) => Ok((Basic(bt.clone()), Basic(bt))), _ => Err(ZVisitorError( "ZStatementWalker: Bit/Rem operators require U* operands".to_owned(), @@ -409,13 +415,22 @@ impl<'ast, 'ret> ZStatementWalker<'ast, 'ret> { "ZStatementWalker: Logical-And/Or operators require Bool operands".to_owned(), )), }, - ast::BinaryOperator::Add | ast::BinaryOperator::Sub | ast::BinaryOperator::Mul | ast::BinaryOperator::Div | ast::BinaryOperator::Rem => match &bt { + ast::BinaryOperator::Add + | ast::BinaryOperator::Sub + | ast::BinaryOperator::Mul + | ast::BinaryOperator::Div + | ast::BinaryOperator::Rem => match &bt { Boolean(_) => Err(ZVisitorError( "ZStatementWalker: +,-,*,/ operators require Field or U* operands".to_owned(), )), _ => Ok((Basic(bt.clone()), Basic(bt))), }, - ast::BinaryOperator::Eq | ast::BinaryOperator::NotEq | ast::BinaryOperator::Lt | ast::BinaryOperator::Gt | ast::BinaryOperator::Lte | ast::BinaryOperator::Gte => match &bt { + ast::BinaryOperator::Eq + | ast::BinaryOperator::NotEq + | ast::BinaryOperator::Lt + | ast::BinaryOperator::Gt + | ast::BinaryOperator::Lte + | ast::BinaryOperator::Gte => match &bt { Boolean(_) => { let mut expr_walker = ZExpressionTyper::new(self); let lty = self.type_expression(&mut be.left, &mut expr_walker)?; @@ -859,7 +874,7 @@ impl<'ast, 'ret> ZVisitorMut<'ast> for ZStatementWalker<'ast, 'ret> { ) -> ZVisitorResult { // XXX(unimpl) no L<-R generic inference right now. // REVISIT: if LHS is generic typed identifier and RHS has complete type, infer L<-R? - self.visit_typed_identifier_or_assignee(&mut def.lhs)?; + self.visit_typed_identifier_or_assignee(&mut def.lhs)?; // unify lhs and rhs let ty_accs = match &def.lhs { @@ -872,7 +887,7 @@ impl<'ast, 'ret> ZVisitorMut<'ast> for ZStatementWalker<'ast, 'ret> { self.lookup_type_varonly(na).map(|t| t.map(|t| (t, acc))) } }; - + if let Ok(Some((ty, accs))) = ty_accs { let ty = self.walk_accesses(ty, accs, aacc_to_msacc)?; self.unify(Some(ty), &mut def.expression)?; diff --git a/src/front/zsharpcurly/zvisit/zstmtwalker/zexprtyper.rs b/src/front/zsharpcurly/zvisit/zstmtwalker/zexprtyper.rs index a9820ba2..49618079 100644 --- a/src/front/zsharpcurly/zvisit/zstmtwalker/zexprtyper.rs +++ b/src/front/zsharpcurly/zvisit/zstmtwalker/zexprtyper.rs @@ -107,10 +107,7 @@ impl<'ast, 'ret, 'wlk> ZVisitorMut<'ast> for ZExpressionTyper<'ast, 'ret, 'wlk> Ok(()) } - fn visit_if_else_expression( - &mut self, - ie: &mut ast::IfElseExpression<'ast>, - ) -> ZVisitorResult { + fn visit_if_else_expression(&mut self, ie: &mut ast::IfElseExpression<'ast>) -> ZVisitorResult { self.visit_expression(&mut ie.consequence)?; let ty2 = self.take()?; self.visit_expression(&mut ie.alternative)?; @@ -131,7 +128,14 @@ impl<'ast, 'ret, 'wlk> ZVisitorMut<'ast> for ZExpressionTyper<'ast, 'ret, 'wlk> use ast::{BasicType::*, Type::*}; assert!(self.ty.is_none()); match &be.op { - ast::BinaryOperator::Or | ast::BinaryOperator::And | ast::BinaryOperator::Eq | ast::BinaryOperator::NotEq | ast::BinaryOperator::Lt | ast::BinaryOperator::Gt | ast::BinaryOperator::Lte | ast::BinaryOperator::Gte => { + ast::BinaryOperator::Or + | ast::BinaryOperator::And + | ast::BinaryOperator::Eq + | ast::BinaryOperator::NotEq + | ast::BinaryOperator::Lt + | ast::BinaryOperator::Gt + | ast::BinaryOperator::Lte + | ast::BinaryOperator::Gte => { self.ty .replace(Basic(Boolean(ast::BooleanType { span: be.span }))); } @@ -139,7 +143,16 @@ impl<'ast, 'ret, 'wlk> ZVisitorMut<'ast> for ZExpressionTyper<'ast, 'ret, 'wlk> self.ty .replace(Basic(Field(ast::FieldType { span: be.span }))); } - ast::BinaryOperator::BitXor | ast::BinaryOperator::BitAnd | ast::BinaryOperator::BitOr | ast::BinaryOperator::RightShift | ast::BinaryOperator::LeftShift | ast::BinaryOperator::Add | ast::BinaryOperator::Sub | ast::BinaryOperator::Mul | ast::BinaryOperator::Div | ast::BinaryOperator::Rem => { + ast::BinaryOperator::BitXor + | ast::BinaryOperator::BitAnd + | ast::BinaryOperator::BitOr + | ast::BinaryOperator::RightShift + | ast::BinaryOperator::LeftShift + | ast::BinaryOperator::Add + | ast::BinaryOperator::Sub + | ast::BinaryOperator::Mul + | ast::BinaryOperator::Div + | ast::BinaryOperator::Rem => { self.visit_expression(&mut be.left)?; let ty_l = self.take()?; self.visit_expression(&mut be.right)?; @@ -164,8 +177,14 @@ impl<'ast, 'ret, 'wlk> ZVisitorMut<'ast> for ZExpressionTyper<'ast, 'ret, 'wlk> .to_string(), )); } - if matches!(&be.op, ast::BinaryOperator::BitXor | ast::BinaryOperator::BitAnd | ast::BinaryOperator::BitOr | ast::BinaryOperator::RightShift | ast::BinaryOperator::LeftShift) - && matches!(&ty, Basic(Field(_))) + if matches!( + &be.op, + ast::BinaryOperator::BitXor + | ast::BinaryOperator::BitAnd + | ast::BinaryOperator::BitOr + | ast::BinaryOperator::RightShift + | ast::BinaryOperator::LeftShift + ) && matches!(&ty, Basic(Field(_))) { return Err(ZVisitorError( "ZExpressionTyper: got Field for a binop that cannot support it" diff --git a/src/front/zsharpcurly/zvisit/zvmut.rs b/src/front/zsharpcurly/zvisit/zvmut.rs index 55b9d76e..6ec2a08e 100644 --- a/src/front/zsharpcurly/zvisit/zvmut.rs +++ b/src/front/zsharpcurly/zvisit/zvmut.rs @@ -113,10 +113,7 @@ pub trait ZVisitorMut<'ast>: Sized { Ok(()) } - fn visit_private_visibility( - &mut self, - _pr: &mut ast::PrivateVisibility, - ) -> ZVisitorResult { + fn visit_private_visibility(&mut self, _pr: &mut ast::PrivateVisibility) -> ZVisitorResult { Ok(()) } @@ -288,10 +285,7 @@ pub trait ZVisitorMut<'ast>: Sized { walk_ternary_expression(self, te) } - fn visit_if_else_expression( - &mut self, - ie: &mut ast::IfElseExpression<'ast>, - ) -> ZVisitorResult { + fn visit_if_else_expression(&mut self, ie: &mut ast::IfElseExpression<'ast>) -> ZVisitorResult { walk_if_else_expression(self, ie) } @@ -327,9 +321,12 @@ pub trait ZVisitorMut<'ast>: Sized { Ok(()) } - fn visit_assign_constrain_operator(&mut self, _aco: &mut ast::AssignConstrainOperator) -> ZVisitorResult { + fn visit_assign_constrain_operator( + &mut self, + _aco: &mut ast::AssignConstrainOperator, + ) -> ZVisitorResult { Ok(()) - } + } fn visit_postfix_expression( &mut self, @@ -384,7 +381,10 @@ pub trait ZVisitorMut<'ast>: Sized { walk_dot_access(self, ma) } - fn visit_identifier_or_decimal(&mut self, ido: &mut ast::IdentifierOrDecimal<'ast>) -> ZVisitorResult { + fn visit_identifier_or_decimal( + &mut self, + ido: &mut ast::IdentifierOrDecimal<'ast>, + ) -> ZVisitorResult { walk_identifier_or_decimal(self, ido) }