Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Forking solver implementation #123

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
Open
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
package io.ksmt.solver.bitwuzla

import it.unimi.dsi.fastutil.longs.Long2ObjectOpenHashMap
import it.unimi.dsi.fastutil.objects.Object2IntOpenHashMap
import it.unimi.dsi.fastutil.objects.Object2LongOpenHashMap
import io.ksmt.KContext
import io.ksmt.decl.KDecl
import io.ksmt.decl.KFuncDecl
Expand All @@ -15,13 +12,10 @@ import io.ksmt.expr.KExistentialQuantifier
import io.ksmt.expr.KExpr
import io.ksmt.expr.KFunctionApp
import io.ksmt.expr.KFunctionAsArray
import io.ksmt.expr.KUninterpretedSortValue
import io.ksmt.expr.KUniversalQuantifier
import io.ksmt.expr.transformer.KNonRecursiveTransformer
import io.ksmt.solver.KSolverException
import org.ksmt.solver.bitwuzla.bindings.BitwuzlaNativeException
import org.ksmt.solver.bitwuzla.bindings.BitwuzlaSort
import org.ksmt.solver.bitwuzla.bindings.BitwuzlaTerm
import org.ksmt.solver.bitwuzla.bindings.Native
import io.ksmt.solver.util.KExprLongInternalizerBase.Companion.NOT_INTERNALIZED
import io.ksmt.sort.KArray2Sort
import io.ksmt.sort.KArray3Sort
Expand All @@ -37,6 +31,13 @@ import io.ksmt.sort.KRealSort
import io.ksmt.sort.KSort
import io.ksmt.sort.KSortVisitor
import io.ksmt.sort.KUninterpretedSort
import it.unimi.dsi.fastutil.longs.Long2ObjectOpenHashMap
import it.unimi.dsi.fastutil.objects.Object2IntOpenHashMap
import it.unimi.dsi.fastutil.objects.Object2LongOpenHashMap
import org.ksmt.solver.bitwuzla.bindings.BitwuzlaNativeException
import org.ksmt.solver.bitwuzla.bindings.BitwuzlaSort
import org.ksmt.solver.bitwuzla.bindings.BitwuzlaTerm
import org.ksmt.solver.bitwuzla.bindings.Native

open class KBitwuzlaContext(val ctx: KContext) : AutoCloseable {
private var isClosed = false
Expand Down Expand Up @@ -433,6 +434,11 @@ open class KBitwuzlaContext(val ctx: KContext) : AutoCloseable {
return super.transform(expr)
}

override fun transform(expr: KUninterpretedSortValue): KExpr<KUninterpretedSort> {
registerDeclIfNotIgnored(expr.decl)
return super.transform(expr)
}

private val quantifiedVarsScopeOwner = arrayListOf<KExpr<*>>()
private val quantifiedVarsScope = arrayListOf<Set<KDecl<*>>?>()

Expand Down Expand Up @@ -474,7 +480,7 @@ open class KBitwuzlaContext(val ctx: KContext) : AutoCloseable {
override fun transform(expr: KExistentialQuantifier): KExpr<KBoolSort> =
expr.transformQuantifier(expr.bounds, expr.body)

override fun transform(expr: KUniversalQuantifier): KExpr<KBoolSort> =
override fun transform(expr: KUniversalQuantifier): KExpr<KBoolSort> =
expr.transformQuantifier(expr.bounds, expr.body)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,17 +151,11 @@ import io.ksmt.expr.KUnaryMinusArithExpr
import io.ksmt.expr.KUninterpretedSortValue
import io.ksmt.expr.KUniversalQuantifier
import io.ksmt.expr.KXorExpr
import io.ksmt.expr.rewrite.simplify.rewriteBvAddNoUnderflowExpr
import io.ksmt.expr.rewrite.simplify.rewriteBvMulNoUnderflowExpr
import io.ksmt.expr.rewrite.simplify.rewriteBvNegNoOverflowExpr
import io.ksmt.expr.rewrite.simplify.rewriteBvSubNoUnderflowExpr
import io.ksmt.solver.KSolverUnsupportedFeatureException
import org.ksmt.solver.bitwuzla.bindings.Bitwuzla
import org.ksmt.solver.bitwuzla.bindings.BitwuzlaKind
import org.ksmt.solver.bitwuzla.bindings.BitwuzlaRoundingMode
import org.ksmt.solver.bitwuzla.bindings.BitwuzlaSort
import org.ksmt.solver.bitwuzla.bindings.BitwuzlaTerm
import org.ksmt.solver.bitwuzla.bindings.Native
import io.ksmt.solver.bitwuzla.KBitwuzlaExprInternalizer.BvOverflowCheckMode.OVERFLOW
import io.ksmt.solver.bitwuzla.KBitwuzlaExprInternalizer.BvOverflowCheckMode.UNDERFLOW
import io.ksmt.solver.util.KExprLongInternalizerBase
import io.ksmt.sort.KArithSort
import io.ksmt.sort.KArray2Sort
Expand All @@ -186,7 +180,13 @@ import io.ksmt.sort.KRealSort
import io.ksmt.sort.KSort
import io.ksmt.sort.KSortVisitor
import io.ksmt.sort.KUninterpretedSort
import org.ksmt.solver.bitwuzla.bindings.Bitwuzla
import org.ksmt.solver.bitwuzla.bindings.BitwuzlaKind
import org.ksmt.solver.bitwuzla.bindings.BitwuzlaRoundingMode
import org.ksmt.solver.bitwuzla.bindings.BitwuzlaSort
import org.ksmt.solver.bitwuzla.bindings.BitwuzlaTerm
import org.ksmt.solver.bitwuzla.bindings.BitwuzlaTermArray
import org.ksmt.solver.bitwuzla.bindings.Native
import java.math.BigInteger

@Suppress("LargeClass")
Expand Down Expand Up @@ -726,7 +726,7 @@ open class KBitwuzlaExprInternalizer(val bitwuzlaCtx: KBitwuzlaContext) : KExprL
override fun <T : KBvSort> transform(expr: KBvAddNoOverflowExpr<T>) = with(expr) {
transform(arg0, arg1) { a0: BitwuzlaTerm, a1: BitwuzlaTerm ->
if (isSigned) {
mkBvAddSignedNoOverflowTerm(arg0.sort.sizeBits.toInt(), a0, a1, BvOverflowCheckMode.OVERFLOW)
mkBvAddSignedNoOverflowTerm(arg0.sort.sizeBits.toInt(), a0, a1, OVERFLOW)
} else {
val overflowCheck = Native.bitwuzlaMkTerm2(
bitwuzla, BitwuzlaKind.BITWUZLA_KIND_BV_UADD_OVERFLOW, a0, a1
Expand All @@ -738,20 +738,20 @@ open class KBitwuzlaExprInternalizer(val bitwuzlaCtx: KBitwuzlaContext) : KExprL

override fun <T : KBvSort> transform(expr: KBvAddNoUnderflowExpr<T>) = with(expr) {
transform(arg0, arg1) { a0: BitwuzlaTerm, a1: BitwuzlaTerm ->
mkBvAddSignedNoOverflowTerm(arg0.sort.sizeBits.toInt(), a0, a1, BvOverflowCheckMode.UNDERFLOW)
mkBvAddSignedNoOverflowTerm(arg0.sort.sizeBits.toInt(), a0, a1, UNDERFLOW)
}
}

override fun <T : KBvSort> transform(expr: KBvSubNoOverflowExpr<T>) = with(expr) {
transform(arg0, arg1) { a0: BitwuzlaTerm, a1: BitwuzlaTerm ->
mkBvSubSignedNoOverflowTerm(arg0.sort.sizeBits.toInt(), a0, a1, BvOverflowCheckMode.OVERFLOW)
mkBvSubSignedNoOverflowTerm(arg0.sort.sizeBits.toInt(), a0, a1, OVERFLOW)
}
}

override fun <T : KBvSort> transform(expr: KBvSubNoUnderflowExpr<T>) = with(expr) {
if (isSigned) {
transform(arg0, arg1) { a0: BitwuzlaTerm, a1: BitwuzlaTerm ->
mkBvSubSignedNoOverflowTerm(arg0.sort.sizeBits.toInt(), a0, a1, BvOverflowCheckMode.UNDERFLOW)
mkBvSubSignedNoOverflowTerm(arg0.sort.sizeBits.toInt(), a0, a1, UNDERFLOW)
}
} else {
transform {
Expand All @@ -776,7 +776,7 @@ open class KBitwuzlaExprInternalizer(val bitwuzlaCtx: KBitwuzlaContext) : KExprL
override fun <T : KBvSort> transform(expr: KBvMulNoOverflowExpr<T>) = with(expr) {
transform(arg0, arg1) { a0: BitwuzlaTerm, a1: BitwuzlaTerm ->
if (isSigned) {
mkBvMulSignedNoOverflowTerm(arg0.sort.sizeBits.toInt(), a0, a1, BvOverflowCheckMode.OVERFLOW)
mkBvMulSignedNoOverflowTerm(arg0.sort.sizeBits.toInt(), a0, a1, OVERFLOW)
} else {
val overflowCheck = Native.bitwuzlaMkTerm2(
bitwuzla, BitwuzlaKind.BITWUZLA_KIND_BV_UMUL_OVERFLOW, a0, a1
Expand All @@ -788,7 +788,7 @@ open class KBitwuzlaExprInternalizer(val bitwuzlaCtx: KBitwuzlaContext) : KExprL

override fun <T : KBvSort> transform(expr: KBvMulNoUnderflowExpr<T>) = with(expr) {
transform(arg0, arg1) { a0: BitwuzlaTerm, a1: BitwuzlaTerm ->
mkBvMulSignedNoOverflowTerm(arg0.sort.sizeBits.toInt(), a0, a1, BvOverflowCheckMode.UNDERFLOW)
mkBvMulSignedNoOverflowTerm(arg0.sort.sizeBits.toInt(), a0, a1, UNDERFLOW)
}
}

Expand All @@ -813,7 +813,7 @@ open class KBitwuzlaExprInternalizer(val bitwuzlaCtx: KBitwuzlaContext) : KExprL
a1,
BitwuzlaKind.BITWUZLA_KIND_BV_SADD_OVERFLOW
) { a0Sign, a1Sign ->
if (mode == BvOverflowCheckMode.OVERFLOW) {
if (mode == OVERFLOW) {
// Both positive
mkAndTerm(longArrayOf(mkNotTerm(a0Sign), mkNotTerm(a1Sign)))
} else {
Expand All @@ -833,7 +833,7 @@ open class KBitwuzlaExprInternalizer(val bitwuzlaCtx: KBitwuzlaContext) : KExprL
a1,
BitwuzlaKind.BITWUZLA_KIND_BV_SSUB_OVERFLOW
) { a0Sign, a1Sign ->
if (mode == BvOverflowCheckMode.OVERFLOW) {
if (mode == OVERFLOW) {
// Positive sub negative
mkAndTerm(longArrayOf(mkNotTerm(a0Sign), a1Sign))
} else {
Expand All @@ -853,7 +853,7 @@ open class KBitwuzlaExprInternalizer(val bitwuzlaCtx: KBitwuzlaContext) : KExprL
a1,
BitwuzlaKind.BITWUZLA_KIND_BV_SMUL_OVERFLOW
) { a0Sign, a1Sign ->
if (mode == BvOverflowCheckMode.OVERFLOW) {
if (mode == OVERFLOW) {
// Overflow is possible when sign bits are equal
mkEqTerm(bitwuzlaCtx.ctx.boolSort, a0Sign, a1Sign)
} else {
Expand Down Expand Up @@ -1401,6 +1401,8 @@ open class KBitwuzlaExprInternalizer(val bitwuzlaCtx: KBitwuzlaContext) : KExprL
}

override fun transform(expr: KUninterpretedSortValue): KExpr<KUninterpretedSort> = expr.transform {
// register it for uninterpreted sort universe
bitwuzlaCtx.registerDeclaration(expr.decl)
Native.bitwuzlaMkBvValueUint32(
bitwuzla,
expr.sort.internalizeSort(),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
package io.ksmt.solver.bitwuzla

import io.ksmt.KContext
import io.ksmt.expr.KExpr
import io.ksmt.solver.KForkingSolver
import io.ksmt.solver.KSolverStatus
import io.ksmt.sort.KBoolSort
import kotlin.time.Duration

class KBitwuzlaForkingSolver(
private val ctx: KContext,
private val manager: KBitwuzlaForkingSolverManager,
parent: KBitwuzlaForkingSolver?
) : KBitwuzlaSolverBase(ctx),
KForkingSolver<KBitwuzlaSolverConfiguration> {

private val assertions = ScopedLinkedFrame<MutableList<KExpr<KBoolSort>>>(::ArrayList, ::ArrayList)
private val trackToExprFrames =
ScopedLinkedFrame<MutableList<Pair<KExpr<KBoolSort>, KExpr<KBoolSort>>>>(::ArrayList, ::ArrayList)

private val config: KBitwuzlaForkingSolverConfigurationImpl

init {
if (parent != null) {
config = parent.config.fork(bitwuzlaCtx.bitwuzla)
assertions.fork(parent.assertions)
trackToExprFrames.fork(parent.trackToExprFrames)
} else {
config = KBitwuzlaForkingSolverConfigurationImpl(bitwuzlaCtx.bitwuzla)
}
}

override fun configure(configurator: KBitwuzlaSolverConfiguration.() -> Unit) {
config.configurator()
}

/**
* Creates lazily initiated forked solver (without cache sharing), preserving parental assertions and configuration.
*/
override fun fork(): KForkingSolver<KBitwuzlaSolverConfiguration> = manager.createForkingSolver(this)

private var assertionsInitiated = parent == null

private fun ensureAssertionsInitiated() {
if (assertionsInitiated) return

assertions.stacked().zip(trackToExprFrames.stacked())
.asReversed()
.forEachIndexed { scope, (assertionsFrame, trackedExprsFrame) ->
if (scope > 0) super.push()

assertionsFrame.forEach { assertion ->
internalizeAndAssertWithAxioms(assertion)
}

trackedExprsFrame.forEach { (track, trackedExpr) ->
super.registerTrackForExpr(trackedExpr, track)
}
}
assertionsInitiated = true
}

override fun assert(expr: KExpr<KBoolSort>) = bitwuzlaCtx.bitwuzlaTry {
ctx.ensureContextMatch(expr)
ensureAssertionsInitiated()

internalizeAndAssertWithAxioms(expr)
assertions.currentFrame += expr
}

override fun assertAndTrack(expr: KExpr<KBoolSort>) {
bitwuzlaCtx.bitwuzlaTry { ensureAssertionsInitiated() }
super.assertAndTrack(expr)
}

override fun registerTrackForExpr(expr: KExpr<KBoolSort>, track: KExpr<KBoolSort>) {
super.registerTrackForExpr(expr, track)
trackToExprFrames.currentFrame += track to expr
}

override fun push() {
bitwuzlaCtx.bitwuzlaTry { ensureAssertionsInitiated() }
super.push()
assertions.push()
trackToExprFrames.push()
}

override fun pop(n: UInt) {
bitwuzlaCtx.bitwuzlaTry { ensureAssertionsInitiated() }
super.pop(n)
assertions.pop(n)
trackToExprFrames.pop(n)
}

override fun check(timeout: Duration): KSolverStatus {
bitwuzlaCtx.bitwuzlaTry { ensureAssertionsInitiated() }
return super.check(timeout)
}

override fun checkWithAssumptions(assumptions: List<KExpr<KBoolSort>>, timeout: Duration): KSolverStatus {
bitwuzlaCtx.bitwuzlaTry { ensureAssertionsInitiated() }
return super.checkWithAssumptions(assumptions, timeout)
}

override fun close() {
super.close()
manager.close(this)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package io.ksmt.solver.bitwuzla

import io.ksmt.KContext
import io.ksmt.solver.KForkingSolver
import io.ksmt.solver.KForkingSolverManager
import java.util.concurrent.ConcurrentHashMap

/**
* Responsible for creation and managing of [KBitwuzlaForkingSolver].
*
* Neither native cache is shared between [KBitwuzlaForkingSolver]s
* because cache sharing is not supported in bitwuzla.
*/
class KBitwuzlaForkingSolverManager(private val ctx: KContext) : KForkingSolverManager<KBitwuzlaSolverConfiguration> {
private val solvers = ConcurrentHashMap.newKeySet<KBitwuzlaForkingSolver>()

override fun createForkingSolver(): KForkingSolver<KBitwuzlaSolverConfiguration> {
return KBitwuzlaForkingSolver(ctx, this, null).also {
solvers += it
}
}

internal fun createForkingSolver(parent: KBitwuzlaForkingSolver) = KBitwuzlaForkingSolver(ctx, this, parent)
.also { solvers += it }

internal fun close(solver: KBitwuzlaForkingSolver) {
solvers -= solver
}

override fun close() {
solvers.forEach(KBitwuzlaForkingSolver::close)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,15 @@ package io.ksmt.solver.bitwuzla

import io.ksmt.KContext
import io.ksmt.decl.KDecl
import io.ksmt.decl.KUninterpretedSortValueDecl
import io.ksmt.expr.KExpr
import io.ksmt.expr.KUninterpretedSortValue
import io.ksmt.solver.KModel
import io.ksmt.solver.KSolverUnsupportedFeatureException
import io.ksmt.solver.model.KFuncInterp
import io.ksmt.solver.model.KFuncInterpEntryVarsFree
import io.ksmt.solver.model.KFuncInterpEntryVarsFreeOneAry
import io.ksmt.solver.model.KFuncInterpVarsFree
import io.ksmt.solver.KModel
import io.ksmt.solver.KSolverUnsupportedFeatureException
import org.ksmt.solver.bitwuzla.bindings.BitwuzlaNativeException
import org.ksmt.solver.bitwuzla.bindings.BitwuzlaTerm
import org.ksmt.solver.bitwuzla.bindings.FunValue
import org.ksmt.solver.bitwuzla.bindings.Native
import io.ksmt.solver.model.KFuncInterpWithVars
import io.ksmt.solver.model.KModelEvaluator
import io.ksmt.solver.model.KModelImpl
Expand All @@ -23,6 +20,10 @@ import io.ksmt.sort.KSort
import io.ksmt.sort.KUninterpretedSort
import io.ksmt.utils.mkFreshConstDecl
import io.ksmt.utils.uncheckedCast
import org.ksmt.solver.bitwuzla.bindings.BitwuzlaNativeException
import org.ksmt.solver.bitwuzla.bindings.BitwuzlaTerm
import org.ksmt.solver.bitwuzla.bindings.FunValue
import org.ksmt.solver.bitwuzla.bindings.Native

open class KBitwuzlaModel(
private val ctx: KContext,
Expand Down Expand Up @@ -77,7 +78,12 @@ open class KBitwuzlaModel(
* to ensure that [uninterpretedSortValueContext] contains
* all possible values for the given sort.
* */
sortDependency.forEach { interpretation(it) }
sortDependency.forEach {
if (it is KUninterpretedSortValueDecl) {
val value = ctx.mkUninterpretedSortValue(it.sort, it.valueIdx)
uninterpretedSortValueContext.registerValue(value)
} else interpretation(it)
}

uninterpretedSortValueContext.currentSortUniverse(sort)
}
Expand Down
Loading