diff --git a/paynt/synthesizer/policy_tree.py b/paynt/synthesizer/policy_tree.py index 0684f941f..3fc4b1c6b 100644 --- a/paynt/synthesizer/policy_tree.py +++ b/paynt/synthesizer/policy_tree.py @@ -27,16 +27,34 @@ def __str__(self): return str(self.sat) +def actions_are_compatible(a1, a2): + return a1 is None or a2 is None or a1==a2 + def policies_are_compatible(policy1, policy2): for state,a1 in enumerate(policy1): a2 = policy2[state] - if a1 is not None and a2 is not None and a1 != a2: + if not actions_are_compatible(a1,a2): return False return True -def merge_policies(policy1, policy2): +def merge_policies(policies): + ''' + Attempt to merge multiple policies into one. + :returns one policy or None if some policies were incompatible + ''' + policy = policies[0].copy() + for policy2 in policies[1:]: + for state,a1 in enumerate(policy): + a2 = policy2[state] + if not actions_are_compatible(a1,a2): + return None + policy[state] = a1 or a2 + return policy + + +def merge_policies_exclusively(policy1, policy2): - num_states = len(policy1) + # num_states = len(policy1) # agree_mask = stormpy.storage.BitVector(num_states,False) # for state in range(num_states): # agree_mask.set(policy1[state] == policy2[state],True) @@ -55,7 +73,11 @@ def merge_policies(policy1, policy2): def test_nodes(quotient, prop, node1, node2): policy1 = node1.policy policy2 = node2.policy - policy12,policy21 = merge_policies(policy1,policy2) + policy = merge_policies([policy1,policy2]) + if policy is not None: + return policy + + policy12,policy21 = merge_policies_exclusively(policy1,policy2) # try policy1 for family2 policy,choice_mask = quotient.fix_policy_for_family(node2.family, policy12) @@ -376,6 +398,7 @@ def solve_singleton(self, family, prop): if not result.sat: return False policy = self.quotient.scheduler_to_policy(result.result.scheduler, family.mdp) + # uncomment below to preemptively double-check the policy # SynthesizerPolicyTree.double_check_policy(self.quotient, family, prop, policy) return policy @@ -389,9 +412,9 @@ def solve_game_abstraction(self, family, prop, game_solver): # logger.debug("game solved, value is {}".format(game_solver.solution_value)) game_policy = game_solver.solution_state_to_player1_action # fix irrelevant choices - policy = [None] * self.quotient.quotient_mdp.nr_states + policy = self.quotient.empty_policy() for state,action in enumerate(game_policy): - if action is not None: + if action < self.quotient.num_actions: policy[state] = action policy_sat = prop.satisfies_threshold(game_solver.solution_value) @@ -416,9 +439,8 @@ def verify_family(self, family, game_solver, prop, reference_policy=None): if family.size <= 8 and False: policy_is_unique, unsat_mdps, sat_mdps, sat_policies = self.synthesize_policy_for_family(family, prop) if policy_is_unique: - for policy in sat_policies: - assert policies_are_compatible(policy,sat_policies[0]) - exit() + policy = merge_policies(sat_policies) + assert policy is not None if False and reference_policy is not None: # try reference policy