Skip to content

Commit

Permalink
Merge pull request #34 from TheGreatfpmK/new-master
Browse files Browse the repository at this point in the history
Added MDP CEs to policy search
  • Loading branch information
TheGreatfpmK authored Dec 20, 2023
2 parents 009f819 + 879d44e commit dc4f7b0
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 18 deletions.
15 changes: 15 additions & 0 deletions paynt/quotient/mdp_family.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,21 @@ def fix_and_apply_policy_to_family(self, family, policy):
policy_fixed[state] = policy[state]

return policy_fixed,mdp


def apply_policy_to_family(self, family, policy):
    """Build the model obtained by restricting the family to a policy.

    ``policy`` is iterated with ``enumerate``, so its position is used as
    the state index and its value as the action prescribed in that state.
    A value of ``None`` means no action is prescribed: in that case the
    choices of every entry of ``self.state_action_choices[state]`` are
    kept. Otherwise only the entry indexed by the prescribed action is
    kept.

    The collected choices are combined with ``family.selected_choices``
    via ``stormpy.synthesis.policyToChoicesForFamily`` and the result is
    handed to ``self.build_from_choice_mask``.

    :param family: object exposing ``selected_choices``
    :param policy: per-state sequence of action indices or ``None``
    :return: whatever ``self.build_from_choice_mask`` returns for the
        resulting choice selection (the original code names it ``mdp``)
    """
    collected_choices = []
    for state, action in enumerate(policy):
        choices_per_action = self.state_action_choices[state]
        if action is not None:
            # a concrete action is prescribed: keep only its choices
            collected_choices.extend(choices_per_action[action])
            continue
        # no action prescribed: keep the choices of all actions in this state
        for action_choices in choices_per_action:
            collected_choices.extend(action_choices)

    choice_mask = stormpy.synthesis.policyToChoicesForFamily(
        collected_choices, family.selected_choices
    )
    return self.build_from_choice_mask(choice_mask)


def assert_mdp_is_deterministic(self, mdp, family):
Expand Down
66 changes: 48 additions & 18 deletions paynt/synthesizer/policy_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -820,6 +820,9 @@ def synthesize_policy_for_family_using_ceg(self, family, prop):

smt_solver = paynt.family.smt.SmtSolver(family)

unsat_conflict_generator = paynt.synthesizer.conflict_generator.mdp.ConflictGeneratorMdp(self.quotient)
unsat_conflict_generator.initialize()

mdp_subfamily = smt_solver.pick_assignment(family)

while mdp_subfamily is not None:
Expand All @@ -830,13 +833,31 @@ def synthesize_policy_for_family_using_ceg(self, family, prop):
self.stat.iteration(mdp_subfamily.mdp)

if not result.sat:
# Potential for MDP CEs here
pruned = smt_solver.exclude_conflicts(family, mdp_subfamily, [list(range(family.num_holes))])
# MDP CE
requests = [(0, self.quotient.specification.all_properties()[0], None)]
choices = self.quotient.coloring.selectCompatibleChoices(mdp_subfamily.family)
model,state_map,choice_map = self.quotient.restrict_quotient(choices)
model = paynt.quotient.models.MDP(model,self.quotient,state_map,choice_map,mdp_subfamily)
conflicts = unsat_conflict_generator.construct_conflicts(family, mdp_subfamily, model, requests)

# conflicts = [list(range(family.num_holes))] # UNSAT without CE generalization

pruned = smt_solver.exclude_conflicts(family, mdp_subfamily, conflicts)
self.explored += pruned
unsat_mdp_families.append(mdp_subfamily)

# MDP CE
unsat_family = family.copy()
for hole_index in range(self.quotient.design_space.num_holes):
if hole_index in conflicts[0]:
unsat_family.hole_set_options(hole_index, mdp_subfamily.hole_options(hole_index))

mdp_subfamily.mdp = None
unsat_family.mdp = None
unsat_mdp_families.append(unsat_family)
else:
policy = self.quotient.scheduler_to_policy(result.result.scheduler, mdp_subfamily.mdp)
policy_fixed, policy_quotient_mdp = self.quotient.fix_and_apply_policy_to_family(family, policy)
policy, policy_quotient_mdp = self.quotient.fix_and_apply_policy_to_family(family, policy) # DTMC CE
# policy_quotient_mdp = self.quotient.apply_policy_to_family(family, policy) # MDP SAT CE
quotient_assignment = self.quotient.coloring.getChoiceToAssignment()
choice_to_hole_options = []
for choice in range(policy_quotient_mdp.choices):
Expand All @@ -845,16 +866,17 @@ def synthesize_policy_for_family_using_ceg(self, family, prop):

coloring = stormpy.synthesis.Coloring(family.family, policy_quotient_mdp.model.nondeterministic_choice_indices, choice_to_hole_options)
quotient_container = paynt.quotient.quotient.DtmcFamilyQuotient(policy_quotient_mdp.model, family, coloring, self.quotient.specification.negate())
conflict_generator = paynt.synthesizer.conflict_generator.dtmc.ConflictGeneratorDtmc(quotient_container)
# conflict_generator = paynt.synthesizer.conflict_generator.mdp.ConflictGeneratorMdp(quotient_container)
conflict_generator = paynt.synthesizer.conflict_generator.dtmc.ConflictGeneratorDtmc(quotient_container) # DTMC CE
# conflict_generator = paynt.synthesizer.conflict_generator.mdp.ConflictGeneratorMdp(quotient_container) # MDP SAT CE
conflict_generator.initialize()
mdp_subfamily.constraint_indices = family.constraint_indices
requests = [(0, quotient_container.specification.all_properties()[0], None)]
model = quotient_container.build_assignment(mdp_subfamily)

# choices = coloring.selectCompatibleChoices(mdp_subfamily.family)
# model,state_map,choice_map = quotient_container.restrict_quotient(choices)
# model = paynt.quotient.models.MDP(model,quotient_container,state_map,choice_map,mdp_subfamily)
model = quotient_container.build_assignment(mdp_subfamily) # DTMC CE

# choices = coloring.selectCompatibleChoices(mdp_subfamily.family) # MDP SAT CE
# model,state_map,choice_map = quotient_container.restrict_quotient(choices) # MDP SAT CE
# model = paynt.quotient.models.MDP(model,quotient_container,state_map,choice_map,mdp_subfamily) # MDP SAT CE

conflicts = conflict_generator.construct_conflicts(family, mdp_subfamily, model, requests)
pruned = smt_solver.exclude_conflicts(family, mdp_subfamily, conflicts)
Expand All @@ -865,9 +887,10 @@ def synthesize_policy_for_family_using_ceg(self, family, prop):
if hole_index in conflicts[0]:
sat_family.hole_set_options(hole_index, mdp_subfamily.hole_options(hole_index))

sat_family.mdp = None
sat_mdp_families.append(sat_family)
sat_mdp_to_policy_map.append(len(sat_mdp_policies))
sat_mdp_policies.append(policy_fixed)
sat_mdp_to_policy_map.append(len(sat_mdp_policies))
sat_mdp_policies.append(policy)

mdp_subfamily = smt_solver.pick_assignment(family)

Expand All @@ -882,17 +905,24 @@ def evaluate_all(self, family, prop, keep_value_only=False):
assert not prop.reward, "expecting reachability probability propery"
game_solver = self.quotient.build_game_abstraction_solver(prop)
policy_tree = PolicyTree(family)
self.create_action_coloring()


### POLICY SEARCH TESTING
#self.create_action_coloring()

# choose policy search method
# unsat, sat, policies, policy_map = self.synthesize_policy_for_family_linear(policy_tree.root.family, prop)
# unsat, sat, policies, policy_map = self.synthesize_policy_for_family_using_ceg(policy_tree.root.family, prop)

# self.stat.synthesis_timer.stop()

# unsat_mdps_count = sum([s.size for s in unsat])
# sat_mdps_count = sum([s.size for s in sat])

# print(f'unSAT: {len(unsat)}')
# print(f'SAT: {len(sat)}')
# print(f'policies: {len(policies)}')
# print(self.stat.iterations_mdp)
# print(f'unSAT MDPs: {unsat_mdps_count}\tunSAT families: {len(unsat)}\tavg. unSAT family size: {round(unsat_mdps_count/len(unsat),2) if len(unsat) != 0 else "N/A"}')
# print(f'SAT MDPs: {sat_mdps_count}\tSAT families: {len(sat)}\tavg. SAT family size: {round(sat_mdps_count/len(sat),2) if len(sat) != 0 else "N/A"}')
# print(f'policies: {len(policies)}\tpolicy per SAT MDP: {round(len(policies)/sat_mdps_count,2) if sat_mdps_count != 0 else "N/A"}')
# print(f'iterations: {self.stat.iterations_mdp}')
# print(f'time: {round(self.stat.synthesis_timer.time,2)}s')

# self.double_check_policy_synthesis(unsat, sat, policies, policy_map, prop)
# exit()
Expand Down

0 comments on commit dc4f7b0

Please sign in to comment.