Skip to content

Commit

Permalink
Update standard library
Browse files Browse the repository at this point in the history
This updates to the original standard library used by corset.  In order
to get this to work, some fixes related to function overloading were
required.
  • Loading branch information
DavePearce committed Dec 19, 2024
1 parent cacd2c6 commit c9aed70
Show file tree
Hide file tree
Showing 6 changed files with 85 additions and 84 deletions.
47 changes: 29 additions & 18 deletions pkg/corset/binding.go
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ func (p *LocalVariableBinding) Finalise(index uint) {
// function bindings.
type OverloadedBinding struct {
// Available specialisations
overloads []FunctionSignature
overloads []*DefunBinding
}

// IsPure checks whether this is a defpurefun or not
Expand All @@ -289,15 +289,20 @@ func (p *OverloadedBinding) IsPure() bool {

// IsFinalised checks whether this binding has been finalised yet or not.
func (p *OverloadedBinding) IsFinalised() bool {
// Unclear to me whether or not this really makes sense.
for _, binding := range p.overloads {
if !binding.IsFinalised() {
return false
}
}
//
return true
}

// HasArity checks whether this function accepts a given number of arguments (or
// not).
func (p *OverloadedBinding) HasArity(arity uint) bool {
for _, sig := range p.overloads {
if sig.NumParameters() == arity {
for _, binding := range p.overloads {
if binding.HasArity(arity) {
// match
return true
}
Expand All @@ -315,8 +320,12 @@ func (p *OverloadedBinding) Select(args []Type) *FunctionSignature {
var selected *FunctionSignature
// Attempt to select the Greated Lower Bound (GLB). This can fail if there
// is no unique GLB.
for _, sig := range p.overloads {
for _, binding := range p.overloads {
// Extract its function signature
sig := binding.Signature()
// Check whether its applicable to the given argument types.
applicable := sig.Accepts(args)
// If it is applicable, then update the current selection as necessary.
if applicable && selected == nil {
selected = &sig
} else if applicable && sig.SubtypeOf(selected) {
Expand All @@ -336,21 +345,20 @@ func (p *OverloadedBinding) Select(args []Type) *FunctionSignature {
// (e.g. intrinsics) cannot be overloaded; (2) duplicate overloadings are
// not permitted; (3) combinding pure and impure overloadings is also not
// permitted.
func (p *OverloadedBinding) Overload(binding *DefunBinding) (FunctionBinding, bool) {
func (p *OverloadedBinding) Overload(overload *DefunBinding) (FunctionBinding, bool) {
// Check matches purity
if binding.IsPure() != p.IsPure() {
if overload.IsPure() != p.IsPure() {
return nil, false
}
// Check overload does not already exist
for _, sig := range p.overloads {
if reflect.DeepEqual(sig.parameters, binding.paramTypes) {
for _, binding := range p.overloads {
if reflect.DeepEqual(binding.paramTypes, overload.paramTypes) {
// Already declared
return nil, false
}
}
// Otherwise, looks good.
sig := FunctionSignature{binding.pure, binding.paramTypes, binding.returnType, binding.body}
p.overloads = append(p.overloads, sig)
p.overloads = append(p.overloads, overload)
//
return p, true
}
Expand Down Expand Up @@ -398,6 +406,12 @@ func (p *DefunBinding) HasArity(arity uint) bool {
return arity == uint(len(p.paramTypes))
}

// Signature returns the corresponding function signature for this user-defined
// function.
func (p *DefunBinding) Signature() FunctionSignature {
return FunctionSignature{p.pure, p.paramTypes, p.returnType, p.body}
}

// Finalise this binding by providing the necessary missing information.
func (p *DefunBinding) Finalise(bodyType Type) {
p.bodyType = bodyType
Expand All @@ -421,17 +435,14 @@ func (p *DefunBinding) Select(args []Type) *FunctionSignature {
// (e.g. intrinsics) cannot be overloaded; (2) duplicate overloadings are
// not permitted; (3) combinding pure and impure overloadings is also not
// permitted.
func (p *DefunBinding) Overload(binding *DefunBinding) (FunctionBinding, bool) {
if p.IsPure() != binding.IsPure() {
func (p *DefunBinding) Overload(overload *DefunBinding) (FunctionBinding, bool) {
if p.IsPure() != overload.IsPure() {
// Purity is misaligned
return nil, false
} else if reflect.DeepEqual(p.paramTypes, binding.paramTypes) {
} else if reflect.DeepEqual(p.paramTypes, overload.paramTypes) {
// Specialisation already exists!
return nil, false
}
// Construct initial overloadings
first := FunctionSignature{p.pure, p.paramTypes, p.returnType, p.body}
second := FunctionSignature{binding.pure, binding.paramTypes, binding.returnType, binding.body}
//
return &OverloadedBinding{[]FunctionSignature{first, second}}, true
return &OverloadedBinding{[]*DefunBinding{p, overload}}, true
}
6 changes: 4 additions & 2 deletions pkg/corset/resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -533,7 +533,9 @@ func (r *resolver) finaliseExpressionInModule(scope LocalScope, expr Expr) (Type
types, errs := r.finaliseExpressionsInModule(scope, v.Args)
return GreatestLowerBoundAll(types), errs
} else if v, ok := expr.(*Normalise); ok {
return r.finaliseExpressionInModule(scope, v.Arg)
_, errs := r.finaliseExpressionInModule(scope, v.Arg)
// Normalise guaranteed to return either 0 or 1.
return NewUintType(1), errs
} else if v, ok := expr.(*Reduce); ok {
return r.finaliseReduceInModule(scope, v)
} else if v, ok := expr.(*Shift); ok {
Expand Down Expand Up @@ -628,7 +630,7 @@ func (r *resolver) finaliseInvokeInModule(scope LocalScope, expr *Invoke) (Type,
expected := signature.Parameter(uint(i))
actual := argTypes[i]
// subtype check
if !actual.SubtypeOf(expected) {
if actual != nil && !actual.SubtypeOf(expected) {
msg := fmt.Sprintf("expected type %s (found %s)", expected, actual)
errors = append(errors, *r.srcmap.SyntaxError(expr.args[i], msg))
}
Expand Down
58 changes: 23 additions & 35 deletions pkg/corset/stdlib.lisp
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
;; [TODO] (defunalias debug-assert debug)

(defpurefun (if-zero cond then) (if (vanishes! cond) then))
;; [TODO] (defpurefun (if-zero cond then else) (if (vanishes! cond) then else))
(defpurefun (if-zero-else cond then else) (if (vanishes! cond) then else))
(defpurefun (if-zero cond then else) (if (vanishes! cond) then else))

(defpurefun (if-not-zero cond then) (if (force-bool cond) then))
;; [TODO] (defpurefun (if-not-zero cond then else) (if (force-bool cond) then else))
(defpurefun (if-not-zero cond then else) (if (force-bool cond) then else))

(defpurefun ((force-bool :@bool :force) x) x)
(defpurefun ((is-binary :@loob :force) e0) (* e0 (- 1 e0)))
Expand All @@ -18,23 +15,24 @@
;; !-suffix denotes loobean algebra (i.e. 0 == true)
;; ~-prefix denotes normalized-functions (i.e. output is 0/1)
(defpurefun (and a b) (* a b))
;; [TODO] (defpurefun ((~and :binary@bool) a b) (~ (and a b)))
(defpurefun ((~and :binary@bool) a b) (~ (and a b)))
(defpurefun ((or! :@loob) a b) (* a b))
;; [TODO] (defpurefun ((~or! :binary@loob) a b) (~ (or! a b)))
(defpurefun ((~or! :binary@loob) a b) (~ (or! a b)))

;; [TODO] (defpurefun ((not :binary@bool :force) (x :binary)) (- 1 x))
(defpurefun ((not :binary@bool :force) (x :binary)) (- 1 x))

(defpurefun ((eq! :binary@loob :force) (x :binary) (y :binary)) (^ (- x y) 2))
(defpurefun ((eq! :@loob) x y) (- x y))
;; [TODO] (defpurefun ((neq! :binary@loob :force) x y) (not (~ (eq! x y))))
;; [TODO] (defunalias = eq!)
(defpurefun ((neq! :binary@loob :force) x y) (not (~ (eq! x y))))
(defunalias = eq!)

;; [TODO] (defpurefun ((eq :binary@bool :force) (x :binary) (y :binary)) (^ (- x y) 2))
(defpurefun ((eq :binary@bool :force) (x :binary) (y :binary)) (- 1 (^ (- x y) 2)))
(defpurefun ((eq :binary@bool :force) x y) (- 1 (~ (eq! x y))))
(defpurefun ((neq :binary@bool :force) x y) (eq! x y))

;; Variadic variations on and/or
;; [TODO] (defunalias any! *)
;; [TODO] (defunalias all *)
(defunalias any! *)
(defunalias all *)

;; Boolean functions
(defpurefun ((is-not-zero :binary@bool) x) (~ x))
Expand Down Expand Up @@ -67,8 +65,8 @@

;; Ensure (in loobean logic) that e0 has changed (resp. will change) its value
;; with regards to the previous (resp. next) row.
;; [TODO] (defpurefun (did-change! e0) (neq! e0 (prev e0)))
;; [TODO] (defpurefun (will-change! e0) (neq! e0 (next e0)))
(defpurefun (did-change! e0) (neq! e0 (prev e0)))
(defpurefun (will-change! e0) (neq! e0 (next e0)))

(defpurefun (did-change e0) (neq e0 (prev e0)))
(defpurefun (will-change e0) (neq e0 (next e0)))
Expand Down Expand Up @@ -99,7 +97,7 @@

;; base-X decomposition constraints
(defpurefun (base-X-decomposition ct base acc digits)
(if-zero-else ct
(if-zero ct
(eq! acc digits)
(eq! acc (+ (* base (prev acc)) digits))))

Expand All @@ -110,28 +108,18 @@
(defpurefun (bit-decomposition ct acc bits) (base-X-decomposition ct 2 acc bits))

;; plateau constraints
;; [TODO] (defpurefun (plateau-constraint CT (X :binary) C)
;; (begin (debug-assert (stamp-constancy CT C))
;; (if-zero C
;; (eq! X 1)
;; (if (eq! CT 0)
;; (vanishes! X)
;; (if (eq! CT C)
;; (did-inc! X 1)
;; (remained-constant! X))))))
(defpurefun (plateau-constraint CT (X :binary) C)
(begin (debug (stamp-constancy CT C))
(if-zero C
(eq! X 1)
(if (eq! CT 0)
(vanishes! X)
(if (eq! CT C)
(did-inc! X 1)
(remained-constant! X))))))

;; stamp constancy imposes that the column C may only
;; change at rows where the STAMP column changes.
(defpurefun (stamp-constancy STAMP C)
(if (will-remain-constant! STAMP)
(will-remain-constant! C)))

;; =============================================================================
;; Add
;; =============================================================================
(defpurefun (if-not-eq X Y then)
(if (eq! X Y)
;; True branch
(vanishes! 0)
;; False branch
then))
28 changes: 14 additions & 14 deletions testdata/bin-dynamic.lisp
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@
LLARGEMO 15)

(defpurefun (if-eq-else A B THEN ELSE)
(if-zero-else (- A B)
(if-zero (- A B)
THEN
ELSE))

Expand All @@ -81,7 +81,7 @@

;; 2.3 Instruction decoding
(defconstraint no-bin-no-flag ()
(if-zero-else STAMP
(if-zero STAMP
(vanishes! (flag-sum))
(eq! (flag-sum) 1)))

Expand All @@ -104,7 +104,7 @@

(defconstraint isbyte-ctmax ()
(if-eq (+ IS_BYTE IS_SIGNEXTEND) 1
(if-zero-else ARG_1_HI
(if-zero ARG_1_HI
(eq! CT_MAX LLARGEMO)
(vanishes! CT_MAX))))

Expand Down Expand Up @@ -207,23 +207,23 @@
;; 2.9 pivot constraints
(defconstraint pivot (:guard CT_MAX)
(begin (if-eq IS_BYTE 1
(if-zero-else LOW_4
(if-zero LOW_4
(if-zero CT
(if-zero-else BIT_B_4
(if-zero BIT_B_4
(eq! PIVOT BYTE_3)
(eq! PIVOT BYTE_4)))
(if-zero (+ (prev BIT_1) (- 1 BIT_1))
(if-zero-else BIT_B_4
(if-zero BIT_B_4
(eq! PIVOT BYTE_3)
(eq! PIVOT BYTE_4)))))
(if-eq IS_SIGNEXTEND 1
(if-eq-else LOW_4 LLARGEMO
(if-zero CT
(if-zero-else BIT_B_4
(if-zero BIT_B_4
(eq! PIVOT BYTE_4)
(eq! PIVOT BYTE_3)))
(if-zero (+ (prev BIT_1) (- 1 BIT_1))
(if-zero-else BIT_B_4
(if-zero BIT_B_4
(eq! PIVOT BYTE_4)
(eq! PIVOT BYTE_3)))))))

Expand All @@ -233,31 +233,31 @@
;; ;;
;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
(defconstraint is-byte-result (:guard IS_BYTE)
(if-zero-else CT_MAX
(if-zero CT_MAX
(begin (vanishes! RES_HI)
(vanishes! RES_LO))
(begin (vanishes! RES_HI)
(eq! RES_LO (* SMALL PIVOT)))))

(defconstraint is-signextend-result (:guard IS_SIGNEXTEND)
(if-zero-else CT_MAX
(if-zero CT_MAX
(begin (eq! RES_HI ARG_2_HI)
(eq! RES_LO ARG_2_LO))
(if-zero-else SMALL
(if-zero SMALL
;; SMALL == 0
(begin (eq! RES_HI ARG_2_HI)
(eq! RES_LO ARG_2_LO))
;; SMALL == 1
(begin (if-zero-else BIT_B_4
(begin (if-zero BIT_B_4
;; b4 == 0
(begin (eq! BYTE_5 (* NEG 255))
(if-zero-else BIT_1
(if-zero BIT_1
;; [[1]] == 0
(eq! BYTE_6 (* NEG 255))
;; [[1]] == 1
(eq! BYTE_6 BYTE_4)))
;; b4 == 1
(begin (if-zero-else BIT_1
(begin (if-zero BIT_1
;; [[1]] == 0
(eq! BYTE_5 (* NEG 255))
;; [[1]] == 1
Expand Down
Loading

0 comments on commit c9aed70

Please sign in to comment.