Skip to content

Commit

Permalink
Merge pull request #116 from bruderj15/107-smart-ite
Browse files Browse the repository at this point in the history
107 smart ite
  • Loading branch information
bruderj15 authored Oct 5, 2024
2 parents 78d8355 + db6a53d commit 4242278
Show file tree
Hide file tree
Showing 11 changed files with 245 additions and 62 deletions.
9 changes: 9 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,15 @@ file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [PVP versioning](https://pvp.haskell.org/).

## v2.7.1 _(2024-10-05)_

### Changed
- Cardinality constraints in `Language.Hasmtlib.Counting` now use specialized and more efficient encodings for a few cases.
- Debugging with debugger `statistically` now prints a more comprehensive overview about the problem size.
- Fixed bug where setting multiple custom `SMTOption`s would only set the most recent.
- Fixed bug where timeout for `SMT`/`OMT` would not work.
- Added smart constructors for `ite`.

## v2.7.0 _(2024-09-12)_

### Added
Expand Down
2 changes: 1 addition & 1 deletion hasmtlib.cabal
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
cabal-version: 3.0

name: hasmtlib
version: 2.7.0
version: 2.7.1
synopsis: A monad for interfacing with external SMT solvers
description: Hasmtlib is a library for generating SMTLib2-problems using a monad.
It takes care of encoding your problem, marshaling the data to an external solver and parsing and interpreting the result into Haskell types.
Expand Down
99 changes: 80 additions & 19 deletions src/Language/Hasmtlib/Counting.hs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ Therefore additional information for the temporal summation may need to be provi
E.g. if your logic is \"QF_LIA\" you would want @'count'' \@'IntSort' $ ...@
It is worth noting that some cardinality constraints use optimized encodings, such as @'atLeast' 1 ≡ 'or'@.
It is worth noting that some cardinality constraints use optimized encodings for special cases.
-}
module Language.Hasmtlib.Counting
(
Expand All @@ -27,48 +27,109 @@ module Language.Hasmtlib.Counting

-- * At-Most
, atMost
, amoSqrt
, amoQuad
)
where

import Prelude hiding (not, (&&), (||), or)
import Prelude hiding (not, (&&), (||), and, or, all, any)
import Language.Hasmtlib.Type.SMTSort
import Language.Hasmtlib.Type.Expr
import Language.Hasmtlib.Boolean
import Data.Foldable (toList)
import Data.List (transpose)
import Data.Proxy
import Control.Lens

-- | Wrapper for 'count' which takes a 'Proxy'.
count' :: forall t f. (Functor f, Foldable f, Num (Expr t)) => Proxy t -> f (Expr BoolSort) -> Expr t
count' _ = sum . fmap (\b -> ite b 1 0)
{-# INLINEABLE count' #-}
{-# INLINE count' #-}

-- | Out of many bool-expressions build a formula which encodes how many of them are 'true'.
count :: forall t f. (Functor f, Foldable f, Num (Expr t)) => f (Expr BoolSort) -> Expr t
count = count' (Proxy @t)
{-# INLINE count #-}

-- | Out of many bool-expressions build a formula which encodes that __at most__ @k@ of them are 'true'.
atMost :: forall t f. (Functor f, Foldable f, KnownSMTSort t, Num (HaskellType t), Ord (HaskellType t)) => Expr t -> f (Expr BoolSort) -> Expr BoolSort
atMost (Constant 0) = nand
atMost (Constant 1) = atMostOneLinear
--
-- 'atMost' is defined as follows:
--
-- @
-- 'atMost' 0 = 'nand'
-- 'atMost' 1 = 'amoSqrt'
-- 'atMost' k = ('<=?' k) . 'count'
-- @
atMost :: forall t f. (Functor f, Foldable f, Num (Expr t), Orderable (Expr t)) => Expr t -> f (Expr BoolSort) -> Expr BoolSort
atMost 0 = nand
atMost 1 = amoSqrt
atMost k = (<=? k) . count
{-# INLINE atMost #-}

atMostOneLinear :: (Foldable f, Boolean b) => f b -> b
atMostOneLinear xs =
let (_, sz) = foldr (plus . (, false)) (false, false) xs
in not sz
where
plus (xe, xz) (ye, yz) = (xe || ye, xz || yz || (xe && ye))

-- | Out of many bool-expressions build a formula which encodes that __at least__ @k@ of them are 'true'.
atLeast :: forall t f. (Functor f, Foldable f, KnownSMTSort t, Num (HaskellType t), Ord (HaskellType t)) => Expr t -> f (Expr BoolSort) -> Expr BoolSort
atLeast (Constant 0) = const true
atLeast (Constant 1) = or
--
-- 'atLeast' is defined as follows:
--
-- @
-- 'atLeast' 0 = 'const' 'true'
-- 'atLeast' 1 = 'or'
-- 'atLeast' k = ('>=?' k) . 'count'
-- @
atLeast :: forall t f. (Functor f, Foldable f, Num (Expr t), Orderable (Expr t)) => Expr t -> f (Expr BoolSort) -> Expr BoolSort
atLeast 0 = const true
atLeast 1 = or
atLeast k = (>=? k) . count
{-# INLINE atLeast #-}

-- | Out of many bool-expressions build a formula which encodes that __exactly__ @k@ of them are 'true'.
exactly :: forall t f. (Functor f, Foldable f, KnownSMTSort t, Num (HaskellType t), Ord (HaskellType t)) => Expr t -> f (Expr BoolSort) -> Expr BoolSort
exactly (Constant 0) = nand
exactly k = (=== k) . count
--
-- 'exactly' is defined as follows:
--
-- @
-- 'exactly' 0 xs = 'nand' xs
-- 'exactly' 1 xs = 'atLeast' \@t 1 xs '&&' 'atMost' \@t 1 xs
-- 'exactly' k xs = 'count' xs '===' k
-- @
exactly :: forall t f. (Functor f, Foldable f, Num (Expr t), Orderable (Expr t)) => Expr t -> f (Expr BoolSort) -> Expr BoolSort
exactly 0 xs = nand xs
exactly 1 xs = atLeast @t 1 xs && atMost @t 1 xs
exactly k xs = count xs === k
{-# INLINE exactly #-}

-- | The squareroot-encoding, also called product-encoding, is a special encoding for @atMost 1@.
--
-- The original product-encoding provided by /Jingchao Chen/ in /A New SAT Encoding of the At-Most-One Constraint (2010)/
-- used auxiliary variables and would therefore be monadic.
-- It requires \( 2 \sqrt{n} + \mathcal{O}(\sqrt[4]{n}) \) auxiliary variables and
-- \( 2n + 4\sqrt{n} + \mathcal{O}(\sqrt[4]{n}) \) clauses.
--
-- To make this encoding pure, all auxiliary variables are replaced with a disjunction of size \( \mathcal{O}(\sqrt{n}) \).
-- Therefore zero auxiliary variables are required and technically clause-count remains the same, although the clauses get bigger.
amoSqrt :: (Foldable f, Boolean b) => f b -> b
amoSqrt xs
| length xs < 10 = amoQuad $ toList xs
| otherwise =
let n = toInteger $ length xs
p = ceiling $ sqrt $ fromInteger n
rows = splitEvery (fromInteger p) $ toList xs
columns = transpose rows
vs = or <$> rows
us = or <$> columns
in amoSqrt vs && amoSqrt us &&
and (imap (\j r -> and $ imap (\i x -> (x ==> us !! i) && (x ==> vs !! j)) r) rows)
where
splitEvery n = takeWhile (not . null) . map (take n) . iterate (drop n)

-- | The quadratic-encoding, also called pairwise-encoding, is a special encoding for @atMost 1@.
--
-- It's the naive encoding for the at-most-one-constraint and produces \( \binom{n}{2} \) clauses and no auxiliary variables..
amoQuad :: Boolean b => [b] -> b
amoQuad as = and $ do
ys <- subs 2 as
return $ any not ys
where
subs :: Int -> [a] -> [[a]]
subs 0 _ = [[]]
subs _ [] = []
subs k (x : xs) = map (x :) (subs (k -1) xs) <> subs k xs
{-# INLINE amoQuad #-}
50 changes: 50 additions & 0 deletions src/Language/Hasmtlib/Example/NQueens.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE UndecidableInstances #-}

module Language.Hasmtlib.Example.McCarthyArrayAxioms where

import Language.Hasmtlib
import Control.Monad
import Data.List (groupBy, sortOn, transpose)
import Data.Function (on)

main :: IO ()
main = do
(res, sol) <- solveWith @SMT (solver z3) $ nqueens @IntSort 4
case res of
Sat -> case sol of
Nothing -> putStrLn "No solution"
Just board -> forM_ board print
r -> print r

nqueens :: forall t s m. (MonadSMT s m, KnownSMTSort t, Num (Expr t), Orderable (Expr t)) => Int -> m [[Expr t]]
nqueens n = do
setLogic $ case sortSing @t of
SIntSort -> "QF_LIA"
SRealSort -> "QF_LRA"
SBvSort _ _ -> "QF_BV"
_ -> "ALL"

board <- replicateM n $ replicateM n var

forM_ (concat board) $ assert . queenDomain
forM_ board $ assert . ((=== 1) . sum)
forM_ (transpose board) $ assert . ((=== 1) . sum)
forM_ (diagonals board) $ assert . ((<=? 1) . sum)

return board

queenDomain ::(Equatable (Expr t), Num (Expr t)) => Expr t -> Expr BoolSort
queenDomain f = (f === 0) `xor` (f === 1)

diagonals :: [[a]] -> [[a]]
diagonals mat = diagonals1 mat ++ diagonals2 mat

indexedMatrix :: [[a]] -> [((Int, Int), a)]
indexedMatrix mat = [((i, j), val) | (i, row) <- zip [0..] mat, (j, val) <- zip [0..] row]

diagonals1 :: [[a]] -> [[a]]
diagonals1 mat = map (map snd) . groupBy ((==) `on` (uncurry (-) . fst)) . sortOn (uncurry (-) . fst) $ indexedMatrix mat

diagonals2 :: [[a]] -> [[a]]
diagonals2 mat = map (map snd) . groupBy ((==) `on` (uncurry (+) . fst)) . sortOn (uncurry (+) . fst) $ indexedMatrix mat
11 changes: 6 additions & 5 deletions src/Language/Hasmtlib/Example/OMTOptimization.hs
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
module Language.Hasmtlib.Example.IncrementalOptimization where

import Prelude hiding ((&&))
import Language.Hasmtlib

main :: IO ()
main = do
res <- solveWith @OMT (solver $ debugging statistically z3) $ do
res <- solveWith @OMT (solver $ debugging verbosely z3) $ do
setLogic "QF_LIA"

x <- var @IntSort
y <- var @IntSort

assert $ x >? -2
assertSoftWeighted (x >? -1) 5.0
assert $ x <? 10 && y <? 5 && (y <? 7 ==> x === 1)

minimize x
maximize $ x + y

return x
return (x,y)

print res
5 changes: 5 additions & 0 deletions src/Language/Hasmtlib/Internal/Uniplate1.hs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
module Language.Hasmtlib.Internal.Uniplate1 where

import Language.Hasmtlib.Internal.Constraint
import Data.Functor.Identity
import Data.Kind

type Uniplate1 :: (k -> Type) -> [k -> Constraint] -> Constraint
Expand All @@ -12,5 +13,9 @@ class Uniplate1 f cs | f -> cs where
transformM1 :: (Monad m, Uniplate1 f cs, AllC cs b) => (forall a. AllC cs a => f a -> m (f a)) -> f b -> m (f b)
transformM1 f x = uniplate1 (transformM1 f) x >>= f

transform1 :: (Uniplate1 f cs, AllC cs b) => (forall a. AllC cs a => f a -> f a) -> f b -> f b
transform1 f = runIdentity . transformM1 (Identity . f)
{-# INLINE transform1 #-}

lazyParaM1 :: (Monad m, Uniplate1 f cs, AllC cs b) => (forall a. AllC cs a => f a -> m (f a) -> m (f a)) -> f b -> m (f b)
lazyParaM1 f x = f x (uniplate1 (lazyParaM1 f) x)
18 changes: 14 additions & 4 deletions src/Language/Hasmtlib/Type/Debugger.hs
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,10 @@ where

import Language.Hasmtlib.Type.SMT
import Language.Hasmtlib.Type.OMT
import Data.Sequence as Seq hiding ((|>), filter)
import Data.ByteString.Lazy hiding (singleton)
import Language.Hasmtlib.Type.Expr
import Language.Hasmtlib.Type.SMTSort
import Data.Sequence as Seq hiding ((|>))
import Data.ByteString.Lazy (ByteString, split)
import Data.ByteString.Lazy.UTF8 (toString)
import Data.ByteString.Builder
import qualified Data.ByteString.Lazy.Char8 as ByteString.Char8
Expand Down Expand Up @@ -158,15 +160,23 @@ class StateDebugger s where
instance StateDebugger SMT where
statistically = silently
{ debugState = \s -> do
putStrLn $ "Variables: " ++ show (Seq.length (s^.vars))
putStrLn $ "Bool Vars: " ++ show (Seq.length $ Seq.filter (\(SomeSMTSort v) -> case sortSing' v of SBoolSort -> True ; _ -> False) $ s^.vars)
putStrLn $ "Arith Vars: " ++ show (Seq.length $ Seq.filter (\(SomeSMTSort v) -> case sortSing' v of SBoolSort -> False ; _ -> True) $ s^.vars)
putStrLn $ "Assertions: " ++ show (Seq.length (s^.formulas))
putStrLn $ "Size: " ++ show (sum $ fmap exprSize $ s^.formulas)
}

instance StateDebugger OMT where
statistically = silently
{ debugState = \omt -> do
putStrLn $ "Variables: " ++ show (Seq.length (omt^.smt.vars))
putStrLn $ "Bool Vars: " ++ show (Seq.length $ Seq.filter (\(SomeSMTSort v) -> case sortSing' v of SBoolSort -> True ; _ -> False) $ omt^.smt.vars)
putStrLn $ "Arith Vars: " ++ show (Seq.length $ Seq.filter (\(SomeSMTSort v) -> case sortSing' v of SBoolSort -> False ; _ -> True) $ omt^.smt.vars)
putStrLn $ "Hard assertions: " ++ show (Seq.length (omt^.smt.formulas))
putStrLn $ "Soft assertions: " ++ show (Seq.length (omt^.softFormulas))
putStrLn $ "Optimizations: " ++ show (Seq.length (omt^.targetMinimize) + Seq.length (omt^.targetMaximize))
let omtSize = sum (fmap exprSize $ omt^.smt.formulas)
+ sum (fmap (exprSize . view formula) $ omt^.softFormulas)
+ sum (fmap (\(SomeSMTSort (Minimize expr)) -> exprSize expr) $ omt^.targetMinimize)
+ sum (fmap (\(SomeSMTSort (Maximize expr)) -> exprSize expr) $ omt^.targetMaximize)
putStrLn $ "Size: " ++ show omtSize
}
45 changes: 42 additions & 3 deletions src/Language/Hasmtlib/Type/Expr.hs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ module Language.Hasmtlib.Type.Expr
SMTVar(..), varId

-- * Expr type
, Expr(..), isLeaf
, Expr(..), isLeaf, exprSize

-- * Compare
-- ** Equatable
Expand Down Expand Up @@ -91,13 +91,16 @@ import Data.Void
import qualified Data.Bits as Bits
import Data.Sequence (Seq)
import Data.Tree (Tree)
import Data.STRef
import Data.Monoid (Sum, Product, First, Last, Dual)
import Data.String (IsString(..))
import Data.Text (pack)
import Data.List(genericLength)
import Data.Foldable (toList)
import qualified Data.Vector.Sized as V
import Control.Lens hiding (from, to)
import Control.Monad.ST
import Control.Monad
import GHC.TypeLits hiding (someNatVal)
import GHC.TypeNats (someNatVal)
import GHC.Generics
Expand Down Expand Up @@ -184,6 +187,26 @@ isLeaf Pi = True
isLeaf _ = False
{-# INLINE isLeaf #-}

-- | Size of the expression.
--
-- Counts the amount of operations.
--
-- ==== __Examples__
--
-- >>> nodeSize $ x + y === x + y
-- 3
-- >>> nodeSize $ false
-- 0
exprSize :: KnownSMTSort t => Expr t -> Integer
exprSize expr = runST $ do
nodesRef <- newSTRef 0
_ <- transformM1
(\expr' -> do
unless (isLeaf expr') $ modifySTRef' nodesRef (+1)
return expr')
expr
readSTRef nodesRef

-- | Class that allows branching on predicates of type @b@ on branches of type @a@.
--
-- If predicate (p :: b) then (t :: a) else (f :: a).
Expand All @@ -202,8 +225,24 @@ class Iteable b a where
ite p t f = ite p <$> t <*> f

instance Iteable (Expr BoolSort) (Expr t) where
ite = Ite
{-# INLINE ite #-}
ite (Constant (BoolValue False)) _ f = f
ite (Constant (BoolValue True)) t _ = t
ite p t@(Ite p' t' f') f@(Ite p'' t'' f'')
| p' == p'' && t' == t'' = Ite p' t' (Ite p f' f'')
| p' == p'' && f' == f'' = Ite (not p') f' (Ite p t' t'')
| otherwise = Ite p t f
ite p t f@(Ite p' t' f')
| p == p' = Ite p t f'
| t == t' = Ite (p || p') t f'
| otherwise = Ite p t f
ite p t@(Ite p' t' f') f
| p == p' = Ite p t' f
| f == f' = Ite (p && p') t' f
| otherwise = Ite p t f
ite p t f
| t == f = t
| otherwise = Ite p t f
{-# INLINEABLE ite #-}

instance Iteable Bool a where
ite p t f = if p then t else f
Expand Down
Loading

0 comments on commit 4242278

Please sign in to comment.