Skip to content

Commit

Permalink
Merge pull request #16 from bruderj15/prelude-like-api
Browse files Browse the repository at this point in the history
Prelude like api
  • Loading branch information
studJBccl authored Jun 12, 2024
2 parents c802539 + 2cb96b2 commit 27e757d
Show file tree
Hide file tree
Showing 12 changed files with 129 additions and 71 deletions.
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@ Therefore, this allows you to use the much richer subset of Haskell than a purel
For instance, to define the addition of two `V3` containing a Real-SMT-Expression:
```haskell
v3Add :: V3 (Expr RealType) -> V3 (Expr RealType) -> V3 (Expr RealType)
v3Add = _
v3Add = liftA2 (+)
```
Even better, the [Expr-GADT](https://github.com/bruderj15/Hasmtlib/blob/master/src/Language/Hasmtlib/Internal/Expr.hs) allows for a polymorph definition:
```haskell
v3Add :: Num (Expr t) => V3 (Expr t) -> V3 (Expr t) -> V3 (Expr t)
v3Add = _
v3Add = liftA2 (+)
```
This looks a lot like the [definition of Num](https://hackage.haskell.org/package/linear-1.23/docs/src/Linear.V3.html#local-6989586621679182277) for `V3 a`:
```haskell
Expand Down Expand Up @@ -66,7 +66,7 @@ May print: `(Sat,Just (V3 (-2.0) (-1.0) 0.0,V3 (-2.0) (-1.0) 0.0))`
- [x] Add your own solvers via the [Solver type](https://github.com/bruderj15/Hasmtlib/blob/master/src/Language/Hasmtlib/Type/Solver.hs)

### Coming
- [ ] Incremental solving
- [ ] Incremental solving (work in progress)
- [ ] Observable sharing
- [ ] Quantifiers `for_all` and `exists`

Expand Down
1 change: 1 addition & 0 deletions hasmtlib.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ library
, Language.Hasmtlib.Variable
, Language.Hasmtlib.Equatable
, Language.Hasmtlib.Orderable
, Language.Hasmtlib.Integraled
, Language.Hasmtlib.Internal.Parser
, Language.Hasmtlib.Internal.Bitvec
, Language.Hasmtlib.Internal.Render
Expand Down
2 changes: 2 additions & 0 deletions src/Language/Hasmtlib.hs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ module Language.Hasmtlib
, module Language.Hasmtlib.Type.Expr
, module Language.Hasmtlib.Type.Solver
, module Language.Hasmtlib.Type.Solution
, module Language.Hasmtlib.Integraled
, module Language.Hasmtlib.Iteable
, module Language.Hasmtlib.Boolean
, module Language.Hasmtlib.Equatable
Expand All @@ -21,6 +22,7 @@ import Language.Hasmtlib.Type.SMT
import Language.Hasmtlib.Type.Expr
import Language.Hasmtlib.Type.Solver
import Language.Hasmtlib.Type.Solution
import Language.Hasmtlib.Integraled
import Language.Hasmtlib.Iteable
import Language.Hasmtlib.Boolean
import Language.Hasmtlib.Equatable
Expand Down
63 changes: 33 additions & 30 deletions src/Language/Hasmtlib/Boolean.hs
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
{-# LANGUAGE NoImplicitPrelude #-}

module Language.Hasmtlib.Boolean where

import Prelude (Bool(..), (.), id, Eq(..))
import Data.Bit
import Data.Coerce
import Data.Bits as Bits
import Data.Foldable (foldl')
import Data.Foldable hiding (and, or)
import qualified Data.Vector.Unboxed.Sized as V
import GHC.TypeNats

Expand All @@ -22,73 +25,73 @@ class Boolean b where
false = bool False

-- | Logical conjunction.
(&&&) :: b -> b -> b
(&&) :: b -> b -> b

-- | Logical disjunction (inclusive or).
(|||) :: b -> b -> b
(||) :: b -> b -> b

-- | Logical implication.
(==>) :: b -> b -> b
x ==> y = not' x ||| y
x ==> y = not x || y

-- | Logical negation
not' :: b -> b
not :: b -> b

-- | Exclusive-or
xor :: b -> b -> b

infixr 3 &&&
infixr 2 |||
infixr 3 &&
infixr 2 ||
infixr 0 ==>

-- | The logical conjunction of several values.
and' :: (Foldable t, Boolean b) => t b -> b
and' = foldl' (&&&) true
and :: (Foldable t, Boolean b) => t b -> b
and = foldl' (&&) true

-- | The logical disjunction of several values.
or' :: (Foldable t, Boolean b) => t b -> b
or' = foldl' (|||) false
or :: (Foldable t, Boolean b) => t b -> b
or = foldl' (||) false

-- | The negated logical conjunction of several values.
--
-- @'nand' = 'neg' . 'and'@
nand :: (Foldable t, Boolean b) => t b -> b
nand = not' . and'
nand = not . and

-- | The negated logical disjunction of several values.
--
-- @'nor' = 'neg' . 'or'@
nor :: (Foldable t, Boolean b) => t b -> b
nor = not' . or'
nor = not . or

-- | The logical conjunction of the mapping of a function over several values.
all' :: (Foldable t, Boolean b) => (a -> b) -> t a -> b
all' p = foldl' (\acc b -> acc &&& p b) true
all :: (Foldable t, Boolean b) => (a -> b) -> t a -> b
all p = foldl' (\acc b -> acc && p b) true

-- | The logical disjunction of the mapping of a function over several values.
any' :: (Foldable t, Boolean b) => (a -> b) -> t a -> b
any' p = foldl' (\acc b -> acc ||| p b) false
any :: (Foldable t, Boolean b) => (a -> b) -> t a -> b
any p = foldl' (\acc b -> acc || p b) false

instance Boolean Bool where
bool = id
true = True
false = False
(&&&) = (&&)
(|||) = (||)
not' = not
(&&) = (&&)
(||) = (||)
not = not
xor = (/=)

instance Boolean Bit where
bool = Bit
(&&&)= (.&.)
(|||) = (.|.)
not' = complement
xor = Bits.xor
bool = Bit
(&&) = (.&.)
(||) = (.|.)
not = complement
xor = Bits.xor

-- | Bitwise operations
instance KnownNat n => Boolean (V.Vector n Bit) where
bool = V.replicate . coerce
(&&&) = V.zipWith (&&&)
(|||) = V.zipWith (|||)
not' = V.map not'
xor = V.zipWith Bits.xor
bool = V.replicate . coerce
(&&) = V.zipWith (&&)
(||) = V.zipWith (||)
not = V.map not
xor = V.zipWith Bits.xor
13 changes: 7 additions & 6 deletions src/Language/Hasmtlib/Codec.hs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

module Language.Hasmtlib.Codec where

import Prelude hiding (not, (&&), (||))
import Language.Hasmtlib.Internal.Bitvec
import Language.Hasmtlib.Internal.Expr
import Language.Hasmtlib.Type.Solution
Expand Down Expand Up @@ -74,9 +75,9 @@ instance KnownSMTRepr t => Codec (Expr t) where
decode sol (Distinct x y) = liftA2 (/=) (decode sol x) (decode sol y)
decode sol (GTHE x y) = liftA2 (>=) (decode sol x) (decode sol y)
decode sol (GTH x y) = liftA2 (>) (decode sol x) (decode sol y)
decode sol (Not x) = fmap not' (decode sol x)
decode sol (And x y) = liftA2 (&&&) (decode sol x) (decode sol y)
decode sol (Or x y) = liftA2 (|||) (decode sol x) (decode sol y)
decode sol (Not x) = fmap not (decode sol x)
decode sol (And x y) = liftA2 (&&) (decode sol x) (decode sol y)
decode sol (Or x y) = liftA2 (||) (decode sol x) (decode sol y)
decode sol (Impl x y) = liftA2 (==>) (decode sol x) (decode sol y)
decode sol (Xor x y) = liftA2 xor (decode sol x) (decode sol y)
decode _ Pi = Just pi
Expand All @@ -92,9 +93,9 @@ instance KnownSMTRepr t => Codec (Expr t) where
decode sol (ToInt x) = fmap truncate (decode sol x)
decode sol (IsInt x) = fmap ((0 ==) . snd . properFraction) (decode sol x)
decode sol (Ite p t f) = liftM3 (\p' t' f' -> if p' then t' else f') (decode sol p) (decode sol t) (decode sol f)
decode sol (BvNot x) = fmap not' (decode sol x)
decode sol (BvAnd x y) = liftA2 (&&&) (decode sol x) (decode sol y)
decode sol (BvOr x y) = liftA2 (|||) (decode sol x) (decode sol y)
decode sol (BvNot x) = fmap not (decode sol x)
decode sol (BvAnd x y) = liftA2 (&&) (decode sol x) (decode sol y)
decode sol (BvOr x y) = liftA2 (||) (decode sol x) (decode sol y)
decode sol (BvXor x y) = liftA2 xor (decode sol x) (decode sol y)
decode sol (BvNand x y) = nand <$> sequenceA [decode sol x, decode sol y]
decode sol (BvNor x y) = nor <$> sequenceA [decode sol x, decode sol y]
Expand Down
5 changes: 3 additions & 2 deletions src/Language/Hasmtlib/Equatable.hs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

module Language.Hasmtlib.Equatable where

import Prelude hiding (not, (&&))
import Language.Hasmtlib.Internal.Expr
import Language.Hasmtlib.Boolean
import GHC.Generics
Expand All @@ -23,7 +24,7 @@ class Equatable a where
a === b = from a ===# from b

(/==) :: a -> a -> Expr BoolType
x /== y = not' $ x === y
x /== y = not $ x === y

infix 4 ===, /==

Expand All @@ -41,7 +42,7 @@ instance GEquatable V1 where
x ===# y = x `seq` y `seq` error "GEquatable[V1].===#"

instance (GEquatable f, GEquatable g) => GEquatable (f :*: g) where
(a :*: b) ===# (c :*: d) = (a ===# c) &&& (b ===# d)
(a :*: b) ===# (c :*: d) = (a ===# c) && (b ===# d)

instance (GEquatable f, GEquatable g) => GEquatable (f :+: g) where
L1 a ===# L1 b = a ===# b
Expand Down
5 changes: 3 additions & 2 deletions src/Language/Hasmtlib/Example/Arith.hs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
module Language.Hasmtlib.Example.Arith where

import Prelude hiding (mod, (&&))
import Language.Hasmtlib

main :: IO ()
Expand All @@ -10,8 +11,8 @@ main = do
x <- var @IntType
y <- var @IntType

assert $ y >? 0
assert $ x `mod'` 42 === y
assert $ y >? 0 && x /== y
assert $ x `mod` 42 === y
assert $ y + x + 1 >=? x + y

return (x,y)
Expand Down
41 changes: 41 additions & 0 deletions src/Language/Hasmtlib/Integraled.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
{-# LANGUAGE DefaultSignatures #-}

module Language.Hasmtlib.Integraled where

import qualified Prelude as P
import Numeric.Natural
import Data.Word
import Data.Functor.Identity
import Data.Functor.Const

class Integraled a where
quot :: a -> a -> a
n `quot` d = q where (q,_) = quotRem n d

rem :: a -> a -> a
n `rem` d = r where (_,r) = quotRem n d

div :: a -> a -> a
n `div` d = q where (q,_) = divMod n d

mod :: a -> a -> a
n `mod` d = r where (_,r) = divMod n d

quotRem :: a -> a -> (a, a)
default quotRem :: P.Integral a => a -> a -> (a, a)
quotRem = P.quotRem

divMod :: a -> a -> (a, a)
default divMod :: P.Integral a => a -> a -> (a, a)
divMod = P.quotRem

instance Integraled P.Int
instance Integraled P.Integer
instance Integraled P.Word
instance Integraled Natural
instance Integraled Word8
instance Integraled Word16
instance Integraled Word32
instance Integraled Word64
instance P.Integral a => Integraled (Identity a)
instance P.Integral a => Integraled (Const a b)
20 changes: 10 additions & 10 deletions src/Language/Hasmtlib/Internal/Expr.hs
Original file line number Diff line number Diff line change
Expand Up @@ -147,18 +147,18 @@ data Expr (t :: SMTType) where
deriving instance Show (Expr t)

instance Boolean (Expr BoolType) where
bool = Constant . BoolValue
(&&&) = And
(|||) = Or
not' = Not
xor = Xor
bool = Constant . BoolValue
(&&) = And
(||) = Or
not = Not
xor = Xor

instance KnownNat n => Boolean (Expr (BvType n)) where
bool = Constant . BvValue . bool
(&&&) = BvAnd
(|||) = BvOr
not' = BvNot
xor = BvXor
bool = Constant . BvValue . bool
(&&) = BvAnd
(||) = BvOr
not = BvNot
xor = BvXor

instance Bounded (Expr BoolType) where
minBound = false
Expand Down
26 changes: 16 additions & 10 deletions src/Language/Hasmtlib/Internal/Expr/Num.hs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
module Language.Hasmtlib.Internal.Expr.Num where

import Prelude hiding (div, mod, quotRem, rem, quot, divMod)
import Language.Hasmtlib.Internal.Expr
import Language.Hasmtlib.Integraled
import Language.Hasmtlib.Iteable
import Language.Hasmtlib.Equatable
import Language.Hasmtlib.Orderable
Expand Down Expand Up @@ -55,17 +57,21 @@ instance Floating (Expr RealType) where
acosh = error "SMT-Solver currently do not support acosh"
atanh = error "SMT-Solver currently do not support atanh"

-- | Integer modulus
mod' :: Expr IntType -> Expr IntType -> Expr IntType
mod' = Mod
instance Integraled (Expr IntType) where
quot = IDiv
rem = Mod
div = IDiv
mod = Mod
quotRem x y = (quot x y, rem x y)
divMod x y = (div x y, mod x y)

-- | Integer division
div' :: Expr IntType -> Expr IntType -> Expr IntType
div' = IDiv

-- | Unsigned bitvector remainder
bvuRem :: KnownNat n => Expr (BvType n) -> Expr (BvType n) -> Expr (BvType n)
bvuRem = BvuRem
instance KnownNat n => Integraled (Expr (BvType n)) where
quot = BvuDiv
rem = BvuRem
div = BvuDiv
mod = BvuRem
quotRem x y = (quot x y, rem x y)
divMod x y = (div x y, mod x y)

-- | Bitvector shift left
bvShL :: KnownNat n => Expr (BvType n) -> Expr (BvType n) -> Expr (BvType n)
Expand Down
9 changes: 5 additions & 4 deletions src/Language/Hasmtlib/Internal/Parser.hs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

module Language.Hasmtlib.Internal.Parser where

import Prelude hiding (not, (&&), (||), and , or)
import Language.Hasmtlib.Internal.Bitvec
import Language.Hasmtlib.Internal.Render
import Language.Hasmtlib.Internal.Expr
Expand Down Expand Up @@ -104,8 +105,8 @@ parseExpr = var <|> constant <|> smtIte
<|> unary "sin" sin <|> unary "cos" cos <|> unary "tan" tan
<|> unary "arcsin" asin <|> unary "arccos" acos <|> unary "arctan" atan
BoolRepr -> isIntFun
<|> unary "not" not'
<|> nary "and" and' <|> nary "or" or' <|> binary "=>" (==>) <|> binary "xor" xor
<|> unary "not" not
<|> nary "and" and <|> nary "or" or <|> binary "=>" (==>) <|> binary "xor" xor
<|> binary @IntType "=" (===) <|> binary @IntType "distinct" (/==)
<|> binary @RealType "=" (===) <|> binary @RealType "distinct" (/==)
<|> binary @BoolType "=" (===) <|> binary @BoolType "distinct" (/==)
Expand All @@ -116,8 +117,8 @@ parseExpr = var <|> constant <|> smtIte
-- TODO: All (?) bv lengths - also for '=' and 'distinct'
-- <|> binary @(BvType 10) "bvult" (<?) <|> binary @(BvType 10) "bvule" (<=?)
-- <|> binary @(BvType 10) "bvuge" (>=?) <|> binary @(BvType 10) "bvugt" (>?)
BvRepr _ -> unary "bvnot" not'
<|> binary "bvand" (&&&) <|> binary "bvor" (|||) <|> binary "bvxor" xor <|> binary "bvnand" BvNand <|> binary "bvnor" BvNor
BvRepr _ -> unary "bvnot" not
<|> binary "bvand" (&&) <|> binary "bvor" (||) <|> binary "bvxor" xor <|> binary "bvnand" BvNand <|> binary "bvnor" BvNor
<|> unary "bvneg" negate
<|> binary "bvadd" (+) <|> binary "bvsub" (-) <|> binary "bvmul" (*)
<|> binary "bvudiv" BvuDiv <|> binary "bvurem" BvuRem
Expand Down
Loading

0 comments on commit 27e757d

Please sign in to comment.