Skip to content

Commit

Permalink
Merge pull request #83 from bruderj15/82-shrink-expr
Browse files Browse the repository at this point in the history
Implemented #80 and #82
  • Loading branch information
bruderj15 authored Aug 25, 2024
2 parents c234cf4 + 3ade8c1 commit 2b9b967
Show file tree
Hide file tree
Showing 12 changed files with 671 additions and 296 deletions.
17 changes: 17 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,23 @@ file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [PVP versioning](https://pvp.haskell.org/).

## v2.5.0 _(2024-08-25)_

### Added
- added instances `Eq`, `Ord`, `GEq` and `GCompare` for `Expr t`
- added instances `Real` and `Enum` for `Expr IntSort`, `Expr RealSort` and `Expr (BvSort n)`
- added instance `Integral` for `Expr IntSort` and `Expr (BvSort n)`
- added instance `Bits` for `Expr BoolSort` and `Expr (BvSort n)`

### Changed
- Removed `Language.Hasmtlib.Integraled`: use the added `Integral` instance instead
- Removed redundant BitVec constructors from `Expr` and replaced usage in instances with the more generic existing ones.
For example: Where `BvNot` was used previously, we now use `Not` which is already used for Expr BoolSort.
Differentiation between such operations now takes place in `Language.Hasmtlib.Internal.Render#render` when rendering expressions,
e.g. rendering `bvnot` for `BvSort` and `not` for `BoolSort`.
Therefore there is no behavioral change for the solver.
- Removed functions `bvRotL` and `bvRotR` from `Language.Hasmtlib.Type.Expr`: use the added `Bits` instance instead with `rotateL` and `rotateR`

## v2.4.0 _(2024-08-21)_

### Added
Expand Down
4 changes: 2 additions & 2 deletions hasmtlib.cabal
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
cabal-version: 3.0

name: hasmtlib
version: 2.4.0
version: 2.5.0
synopsis: A monad for interfacing with external SMT solvers
description: Hasmtlib is a library for generating SMTLib2-problems using a monad.
It takes care of encoding your problem, marshaling the data to an external solver and parsing and interpreting the result into Haskell types.
Expand Down Expand Up @@ -32,7 +32,6 @@ library
, Language.Hasmtlib.Boolean
, Language.Hasmtlib.Variable
, Language.Hasmtlib.Counting
, Language.Hasmtlib.Integraled
, Language.Hasmtlib.Internal.Parser
, Language.Hasmtlib.Internal.Bitvec
, Language.Hasmtlib.Internal.Render
Expand All @@ -47,6 +46,7 @@ library
, Language.Hasmtlib.Solver.Yices
, Language.Hasmtlib.Solver.Z3
, Language.Hasmtlib.Type.Expr
, Language.Hasmtlib.Type.Value
, Language.Hasmtlib.Type.MonadSMT
, Language.Hasmtlib.Type.SMT
, Language.Hasmtlib.Type.OMT
Expand Down
4 changes: 2 additions & 2 deletions src/Language/Hasmtlib.hs
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@ module Language.Hasmtlib
, module Language.Hasmtlib.Type.OMT
, module Language.Hasmtlib.Type.Pipe
, module Language.Hasmtlib.Type.Expr
, module Language.Hasmtlib.Type.Value
, module Language.Hasmtlib.Type.Solver
, module Language.Hasmtlib.Type.Option
, module Language.Hasmtlib.Type.SMTSort
, module Language.Hasmtlib.Type.Solution
, module Language.Hasmtlib.Type.ArrayMap
, module Language.Hasmtlib.Integraled
, module Language.Hasmtlib.Boolean
, module Language.Hasmtlib.Codec
, module Language.Hasmtlib.Counting
Expand All @@ -30,12 +30,12 @@ import Language.Hasmtlib.Type.SMT
import Language.Hasmtlib.Type.OMT
import Language.Hasmtlib.Type.Pipe
import Language.Hasmtlib.Type.Expr
import Language.Hasmtlib.Type.Value
import Language.Hasmtlib.Type.Solver
import Language.Hasmtlib.Type.Option
import Language.Hasmtlib.Type.SMTSort
import Language.Hasmtlib.Type.Solution
import Language.Hasmtlib.Type.ArrayMap
import Language.Hasmtlib.Integraled
import Language.Hasmtlib.Boolean
import Language.Hasmtlib.Codec
import Language.Hasmtlib.Counting
Expand Down
20 changes: 4 additions & 16 deletions src/Language/Hasmtlib/Codec.hs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import Language.Hasmtlib.Boolean
import Data.Kind
import Data.Coerce
import qualified Data.List as List
import Data.Bits hiding (And, Xor, xor)
import Data.Map (Map)
import Data.Sequence (Seq)
import Data.IntMap as IM hiding (foldl)
Expand Down Expand Up @@ -74,6 +75,7 @@ instance KnownSMTSort t => Codec (Expr t) where
return $ unwrapValue val
decode _ (Constant v) = Just $ unwrapValue v
decode sol (Plus x y) = (+) <$> decode sol x <*> decode sol y
decode sol (Minus x y) = (-) <$> decode sol x <*> decode sol y
decode sol (Neg x) = fmap negate (decode sol x)
decode sol (Mul x y) = (*) <$> decode sol x <*> decode sol y
decode sol (Abs x) = fmap abs (decode sol x)
Expand Down Expand Up @@ -111,27 +113,13 @@ instance KnownSMTSort t => Codec (Expr t) where
decode sol (ToInt x) = fmap truncate (decode sol x)
decode sol (IsInt x) = fmap ((0 ==) . snd . properFraction) (decode sol x)
decode sol (Ite p t f) = (\p' t' f' -> if p' then t' else f') <$> decode sol p <*> decode sol t <*> decode sol f
decode sol (BvNot x) = fmap not (decode sol x)
decode sol (BvAnd x y) = (&&) <$> decode sol x <*> decode sol y
decode sol (BvOr x y) = (||) <$> decode sol x <*> decode sol y
decode sol (BvXor x y) = xor <$> decode sol x <*> decode sol y
decode sol (BvNand x y) = nand <$> sequenceA [decode sol x, decode sol y]
decode sol (BvNor x y) = nor <$> sequenceA [decode sol x, decode sol y]
decode sol (BvNeg x) = fmap negate (decode sol x)
decode sol (BvAdd x y) = (+) <$> decode sol x <*> decode sol y
decode sol (BvSub x y) = (-) <$> decode sol x <*> decode sol y
decode sol (BvMul x y) = (*) <$> decode sol x <*> decode sol y
decode sol (BvuDiv x y) = div <$> decode sol x <*> decode sol y
decode sol (BvuRem x y) = rem <$> decode sol x <*> decode sol y
decode sol (BvShL x y) = join $ bvShL <$> decode sol x <*> decode sol y
decode sol (BvLShR x y) = join $ bvLShR <$> decode sol x <*> decode sol y
decode sol (BvConcat x y) = bvConcat <$> decode sol x <*> decode sol y
decode sol (BvRotL i x) = bvRotL i <$> decode sol x
decode sol (BvRotR i x) = bvRotR i <$> decode sol x
decode sol (BvuLT x y) = (<) <$> decode sol x <*> decode sol y
decode sol (BvuLTHE x y) = (<=) <$> decode sol x <*> decode sol y
decode sol (BvuGTHE x y) = (>=) <$> decode sol x <*> decode sol y
decode sol (BvuGT x y) = (>) <$> decode sol x <*> decode sol y
decode sol (BvRotL i x) = rotateL <$> decode sol x <*> pure (fromIntegral i)
decode sol (BvRotR i x) = rotateR <$> decode sol x <*> pure (fromIntegral i)
decode sol (ArrSelect i arr) = arrSelect <$> decode sol i <*> decode sol arr
decode sol (ArrStore i x arr) = arrStore <$> decode sol i <*> decode sol x <*> decode sol arr
decode sol (StrConcat x y) = (<>) <$> decode sol x <*> decode sol y
Expand Down
2 changes: 1 addition & 1 deletion src/Language/Hasmtlib/Example/Arith.hs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module Language.Hasmtlib.Example.Arith where

import Prelude hiding (mod, (&&))
import Prelude hiding ((&&))
import Language.Hasmtlib

main :: IO ()
Expand Down
14 changes: 4 additions & 10 deletions src/Language/Hasmtlib/Example/Bitvector.hs
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,17 @@ module Language.Hasmtlib.Example.Bitvector where

import Language.Hasmtlib
import Data.Default
import Data.Bits

main :: IO ()
main = do
res <- solveWith @SMT (debug bitwuzla def) $ do
setLogic "QF_BV"

xbv8 <- variable
ybv8 <- var @(BvSort 8)
x <- var @(BvSort 8)

assert $ true === (xbv8 `xor` ybv8)
assert $ xbv8 <=? maxBound
assert $ x === clearBit (maxBound `div` 2) 2

assert $ xbv8 >? 0
assert $ ybv8 >? 0

assert $ xbv8 + ybv8 >? xbv8 * ybv8

return (xbv8, ybv8)
return x

print res
43 changes: 0 additions & 43 deletions src/Language/Hasmtlib/Integraled.hs

This file was deleted.

58 changes: 32 additions & 26 deletions src/Language/Hasmtlib/Internal/Bitvec.hs
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@

module Language.Hasmtlib.Internal.Bitvec where

import Prelude hiding ((&&), (||), not)
import Language.Hasmtlib.Boolean
import Language.Hasmtlib.Internal.Render
import Data.ByteString.Builder
import Data.Bit
import Data.Bits
import Data.Coerce
import Data.Finite
import Data.Finite hiding (shift)
import Data.Proxy
import Data.Ratio ((%))
import Data.Bifunctor
Expand All @@ -18,8 +19,21 @@ import GHC.TypeNats

-- | Unsigned and length-indexed bitvector with MSB first.
newtype Bitvec (n :: Nat) = Bitvec { unBitvec :: V.Vector n Bit }
deriving stock (Eq, Ord)
deriving newtype (Boolean)
deriving newtype (Eq, Ord, Boolean)

instance KnownNat n => Bits (Bitvec n) where
(.&.) = (&&)
(.|.) = (||)
xor = Language.Hasmtlib.Boolean.xor
complement = not
shift bv i = coerce $ shift (coerce @_ @(V.Vector n Bit) bv) (negate i)
rotate bv i = coerce $ rotate (coerce @_ @(V.Vector n Bit) bv) (negate i)
bitSize _ = fromIntegral $ natVal (Proxy @n)
bitSizeMaybe _ = Just $ fromIntegral $ natVal (Proxy @n)
isSigned _ = false
testBit bv = testBit (V.reverse (coerce @_ @(V.Vector n Bit) bv))
bit (toInteger -> i) = coerce $ V.reverse $ V.replicate @n (Bit False) V.// [(finite i, Bit True)]
popCount = coerce . popCount . coerce @_ @(V.Vector n Bit)

instance Show (Bitvec n) where
show = V.toList . V.map (\b -> if coerce b then '1' else '0') . coerce @_ @(V.Vector n Bit)
Expand All @@ -29,17 +43,17 @@ instance Render (Bitvec n) where
{-# INLINEABLE render #-}

instance KnownNat n => Num (Bitvec n) where
fromInteger x = coerce . V.reverse $ V.generate @n (coerce . testBit x . fromInteger . getFinite)
fromInteger x = coerce . V.reverse $ V.generate @n (coerce . testBit x . fromInteger . getFinite)
negate = id
abs = id
signum _ = 0
(coerce -> x) + (coerce -> y) = coerce @(V.Vector n Bit) $ x + y
(coerce -> x) - (coerce -> y) = coerce @(V.Vector n Bit) $ x - y
(coerce -> x) * (coerce -> y) = coerce @(V.Vector n Bit) $ x * y
(coerce -> x) * (coerce -> y) = coerce @(V.Vector n Bit) $ x * y

instance KnownNat n => Bounded (Bitvec n) where
minBound = coerce $ V.replicate @n false
maxBound = coerce $ V.replicate @n true
minBound = coerce $ V.replicate @n false
maxBound = coerce $ V.replicate @n true

instance KnownNat n => Enum (Bitvec n) where
succ x = x + 1
Expand Down Expand Up @@ -87,32 +101,24 @@ bvFromListN = coerce . V.fromListN @n
bvFromListN' :: forall n. KnownNat n => Proxy n -> [Bit] -> Maybe (Bitvec n)
bvFromListN' _ = bvFromListN

bvRotL :: forall n i. KnownNat (Mod i n) => Proxy i -> Bitvec n -> Bitvec n
bvRotL _ (coerce -> x) = coerce $ r V.++ l
where
(l, r) = V.splitAt' (Proxy @(Mod i n)) x

bvRotR :: forall n i. KnownNat (Mod i n) => Proxy i -> Bitvec n -> Bitvec n
bvRotR p = bvReverse . bvRotL p . bvReverse

bvShL :: KnownNat n => Bitvec n -> Bitvec n -> Maybe (Bitvec n)
bvShL x y = bvFromListN $ (++ replicate i false) $ drop i $ bvToList x
where
i = fromIntegral y
where
i = fromIntegral y

bvLShR :: KnownNat n => Bitvec n -> Bitvec n -> Maybe (Bitvec n)
bvLShR x y = fmap bvReverse $ bvFromListN $ (++ replicate i false) $ drop i $ bvToList $ bvReverse x
where
i = fromIntegral y
where
i = fromIntegral y

bvZeroExtend :: KnownNat i => Proxy i -> Bitvec n -> Bitvec (n+i)
bvZeroExtend p x = bvConcat x $ bvReplicate' p false
bvExtract :: forall n i j.
bvZeroExtend p x = bvConcat x $ bvReplicate' p false

bvExtract :: forall n i j.
( KnownNat i, KnownNat ((j - i) + 1)
, (i+(n-i)) ~ n
, (((j - i) + 1) + ((n - i)-((j - i) + 1))) ~ (n - i)
) => Proxy i -> Proxy j -> Bitvec n -> Bitvec (( j - i ) + 1)
bvExtract pri _ x = bvTake' @_ @((n-i)-((j-i)+1)) (Proxy @((j-i)+1)) x'
where
x' :: Bitvec (n-i) = bvDrop' pri x
x' :: Bitvec (n-i) = bvDrop' pri x
4 changes: 2 additions & 2 deletions src/Language/Hasmtlib/Internal/Parser.hs
Original file line number Diff line number Diff line change
Expand Up @@ -144,14 +144,14 @@ parseExpr = var <|> constantExpr <|> ternary "ite" (ite @(Expr BoolSort))
<|> binary "bvand" (&&) <|> binary "bvor" (||) <|> binary "bvxor" xor <|> binary "bvnand" BvNand <|> binary "bvnor" BvNor
<|> unary "bvneg" negate
<|> binary "bvadd" (+) <|> binary "bvsub" (-) <|> binary "bvmul" (*)
<|> binary "bvudiv" BvuDiv <|> binary "bvurem" BvuRem
<|> binary "bvudiv" div <|> binary "bvurem" rem
<|> binary "bvshl" BvShL <|> binary "bvlshr" BvLShR
SArraySort _ _ -> ternary "store" ArrStore
-- TODO: Add compare ops for all (?) array-sorts
SStringSort -> binary "str.++" (<>) <|> binary "str.at" strAt <|> ternary "str.substr" StrSubstring
<|> ternary "str.replace" strReplace <|> ternary "str.replace_all" strReplaceAll

var :: Parser (Expr t)
var :: KnownSMTSort t => Parser (Expr t)
var = do
_ <- string "var_"
vId <- decimal @Int
Expand Down
Loading

0 comments on commit 2b9b967

Please sign in to comment.