Skip to content

Commit

Permalink
integrate policy search
Browse files Browse the repository at this point in the history
  • Loading branch information
Roman Andriushchenko committed Dec 8, 2023
1 parent 18bc4c8 commit eedf67f
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 68 deletions.
18 changes: 14 additions & 4 deletions paynt/quotient/mdp_family.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,17 @@ def __init__(self, quotient_mdp, coloring, specification):

self.design_space = paynt.quotient.holes.DesignSpace(coloring.holes)

# a list of action labels
self.action_labels = None
# for each choice of the quotient, the executed action
self.choice_to_action = None
# for each state of the quotient and for each action, a list of choices that execute this action
self.state_action_choices = None
# for each state of the quotient, a list of available actions
self.state_to_actions = None
# for each choice of the quotient, a list of its state-destinations
self.choice_destinations = None

self.action_labels,self.choice_to_action = MdpFamilyQuotientContainer.extract_choice_labels(self.quotient_mdp)
self.state_action_choices = MdpFamilyQuotientContainer.map_state_action_to_choices(
self.quotient_mdp, self.num_actions, self.choice_to_action)
Expand Down Expand Up @@ -119,7 +130,7 @@ def choices_to_hole_selection(self, choice_mask):
return hole_selection

def empty_policy(self):
return [self.num_actions] * self.quotient_mdp.nr_states
return [None] * self.quotient_mdp.nr_states

def scheduler_to_policy(self, scheduler, mdp):
policy = self.empty_policy()
Expand All @@ -141,10 +152,9 @@ def fix_policy_for_family(self, family, policy):
:return fixed policy
:return choice mask from which Q-MDP x policy can be constructed
'''
invalid_action = self.num_actions

choice_mask = stormpy.BitVector(self.quotient_mdp.nr_choices,False)
policy_fixed = [invalid_action] * self.quotient_mdp.nr_states
policy_fixed = self.empty_policy()

initial_state = list(self.quotient_mdp.initial_states)[0]
tm = self.quotient_mdp.transition_matrix
Expand All @@ -155,7 +165,7 @@ def fix_policy_for_family(self, family, policy):
while state_queue:
state = state_queue.pop()
action = policy[state]
if action == invalid_action:
if action is None:
action = self.state_to_actions[state][0]
policy_fixed[state] = action
for choice in self.state_action_choices[state][action]:
Expand Down
131 changes: 67 additions & 64 deletions paynt/synthesizer/policy_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,14 @@ def __str__(self):
return str(self.sat)


def merge_policies(policy1, policy2, dont_care_action):
def policies_are_compatible(policy1, policy2):
    """Check whether two (partial) policies agree wherever both are defined.

    An entry of ``None`` stands for an unspecified ("don't care") action and
    is compatible with anything; two concrete actions are compatible only if
    they are equal.

    :param policy1: list of actions (or ``None``) indexed by state
    :param policy2: list of actions (or ``None``) indexed by state
    :return: True iff no state maps to two distinct concrete actions
    """
    for state in range(len(policy1)):
        action1, action2 = policy1[state], policy2[state]
        # an undefined action on either side never causes a conflict
        if action1 is None or action2 is None:
            continue
        if action1 != action2:
            return False
    return True

def merge_policies(policy1, policy2):

num_states = len(policy1)
# agree_mask = stormpy.storage.BitVector(num_states,False)
Expand All @@ -36,21 +43,19 @@ def merge_policies(policy1, policy2, dont_care_action):

policy12 = policy1.copy()
policy21 = policy2.copy()
for state in range(num_states):
a1 = policy1[state]
for state,a1 in enumerate(policy1):
a2 = policy2[state]
if a1 == dont_care_action:
if a1 is None:
policy12[state] = a2
if a2 == dont_care_action:
if a2 is None:
policy21[state] = a1
return policy12,policy21


def test_nodes(quotient, prop, node1, node2):
dont_care_action = quotient.num_actions
policy1 = node1.policy
policy2 = node2.policy
policy12,policy21 = merge_policies(policy1,policy2, dont_care_action)
policy12,policy21 = merge_policies(policy1,policy2)

# try policy1 for family2
policy,choice_mask = quotient.fix_policy_for_family(node2.family, policy12)
Expand Down Expand Up @@ -220,7 +225,7 @@ def collect_solved(self):
return solved


def double_check_all_families(self, quotient, prop):
def double_check(self, quotient, prop):
leaves = self.collect_leaves()
logger.info("double-checking {} families...".format(len(leaves)))
for leaf in leaves:
Expand Down Expand Up @@ -382,7 +387,12 @@ def solve_game_abstraction(self, family, prop, game_solver):
game_solver.solve(family.selected_actions_bv, prop.maximizing, prop.minimizing)
self.stat.iteration_game(family.mdp.states)
# logger.debug("game solved, value is {}".format(game_solver.solution_value))
policy = game_solver.solution_state_to_player1_action
game_policy = game_solver.solution_state_to_player1_action
# fix irrelevant choices
policy = [None] * self.quotient.quotient_mdp.nr_states
for state,action in enumerate(game_policy):
if action is not None:
policy[state] = action
policy_sat = prop.satisfies_threshold(game_solver.solution_value)

if False:
Expand All @@ -399,9 +409,17 @@ def solve_game_abstraction(self, family, prop, game_solver):


def verify_family(self, family, game_solver, prop, reference_policy=None):
# logger.info("investigating family of size {}".format(family.size))
self.quotient.build(family)
mdp_family_result = MdpFamilyResult()

if family.size <= 8 and False:
policy_is_unique, unsat_mdps, sat_mdps, sat_policies = self.synthesize_policy_for_family(family, prop)
if policy_is_unique:
for policy in sat_policies:
assert policies_are_compatible(policy,sat_policies[0])
exit()

if False and reference_policy is not None:
# try reference policy
reference_policy_sat = self.verify_policy(family,prop,reference_policy)
Expand Down Expand Up @@ -471,12 +489,6 @@ def verify_family(self, family, game_solver, prop, reference_policy=None):
return mdp_family_result


def choose_splitter_round_robin(self, family, prop, scheduler_choices, state_values, hole_selection):
splitter = (self.last_splitter+1) % family.num_holes
while family[splitter].size == 1:
splitter = (splitter+1) % family.num_holes
return splitter

def choose_splitter(self, family, prop, scheduler_choices, state_values, hole_selection):
splitter = None
inconsistent_assignments = {hole_index:options for hole_index,options in enumerate(hole_selection) if len(options) > 1}
Expand Down Expand Up @@ -536,39 +548,35 @@ def split(self, family, prop, hole_selection, splitter):
return suboptions,subfamilies


def create_action_coloring(self, quotient_mdp):
def create_action_coloring(self):

quotient_mdp = self.quotient.quotient_mdp
holes = paynt.quotient.holes.Holes()
action_to_hole_options = []
for state in quotient_mdp.states:
choice_to_hole_options = [{} for choice in range(quotient_mdp.nr_choices)]

state_actions = self.quotient.state_to_actions[int(state)]
for state in range(quotient_mdp.nr_states):

state_actions = self.quotient.state_to_actions[state]
if len(state_actions) <= 1:
for action in range(quotient_mdp.get_nr_available_actions(state)):
action_to_hole_options.append({})
# hole is not needed
continue

# create fresh hole
hole_index = holes.num_holes
name = f'state_{state}'
options = list(range(len(state_actions)))
option_labels = [self.quotient.action_labels[action] for action in state_actions]
hole = paynt.quotient.holes.Hole(name, options, option_labels)
holes.append(hole)

for action in range(quotient_mdp.get_nr_available_actions(state)):
choice = quotient_mdp.get_choice_index(state, action)
choice_index = -1
for index, action_list in enumerate(list(self.quotient.state_action_choices[int(state)])):
if choice in action_list:
choice_index = index
break
assert choice_index != -1

hole_options = {len(holes)-1: state_actions.index(choice_index)}
action_to_hole_options.append(hole_options)
for action_index,action in enumerate(state_actions):
color = {hole_index: action_index}
for choice in self.quotient.state_action_choices[state][action]:
choice_to_hole_options[choice] = color

coloring = paynt.quotient.coloring.Coloring(quotient_mdp, holes, action_to_hole_options)

return coloring
coloring = paynt.quotient.coloring.Coloring(quotient_mdp, holes, choice_to_hole_options)
self.action_coloring = coloring
return


def update_scores(self, score_lists, selection):
Expand All @@ -578,20 +586,19 @@ def update_scores(self, score_lists, selection):
score_list.append(choice)


# synthesize one policy for family of MDPs (if such policy exists)
# set all_sat=True if all MDPs in the family are sat
# returns - True, unsat_families, sat_families, policy
# - False, unsat_families, sat_families, sat_policies
def synthesize_policy_for_tree_node(self, family, prop, all_sat=False, iteration_limit=0):
def synthesize_policy_for_family(self, family, prop, all_sat=False, iteration_limit=0):
'''
Synthesize one policy for family of MDPs (if such policy exists).
:param all_sat if True, it is assumed that all MDPs are SAT
:returns whether all SAT MDPs are solved using a single policy
:returns a list of UNSAT MDPs
:returns a list of SAT MDPs
:returns to each SAT MDP, a corresponding policy
'''
sat_mdp_families = []
sat_mdp_policies = []
unsat_mdp_families = []

# coloring for MDP choices
# TODO move this outside of the function since this needs to be created only once or create coloring on family.mdp
action_coloring = self.create_action_coloring(self.quotient.quotient_mdp)
action_family = paynt.quotient.holes.DesignSpace(action_coloring.holes)

# create MDP subfamilies
for hole_assignment in family.all_combinations():
subfamily = family.copy()
Expand All @@ -602,6 +609,7 @@ def synthesize_policy_for_tree_node(self, family, prop, all_sat=False, iteration
if not all_sat:
self.quotient.build(subfamily)
primary_result = subfamily.mdp.model_check_property(prop)
assert primary_result.result.has_scheduler
self.stat.iteration_mdp(subfamily.mdp.states)

if primary_result.sat == False:
Expand All @@ -616,11 +624,12 @@ def synthesize_policy_for_tree_node(self, family, prop, all_sat=False, iteration

# no sat mdps
if len(sat_mdp_families) == 0:
return False, unsat_mdp_families, sat_mdp_families, None
return False, unsat_mdp_families, sat_mdp_families, sat_mdp_policies

if len(sat_mdp_policies) == 0:
sat_mdp_policies = [None for _ in sat_mdp_families]

action_family = paynt.quotient.holes.DesignSpace(self.action_coloring.holes)
action_family_stack = [action_family]
iter = 0

Expand All @@ -638,7 +647,7 @@ def synthesize_policy_for_tree_node(self, family, prop, all_sat=False, iteration
# try to find controller inconsistency across the MDPs
# if the controllers are consistent, return True
for index, mdp_subfamily in enumerate(sat_mdp_families):
self.quotient.build_with_second_coloring(mdp_subfamily, action_coloring, current_action_family) # maybe copy to new family?
self.quotient.build_with_second_coloring(mdp_subfamily, self.action_coloring, current_action_family) # maybe copy to new family?

mc_result = stormpy.model_checking(
current_action_family.mdp.model, prop.formula, extract_scheduler=True, environment=Property.environment)
Expand All @@ -653,12 +662,12 @@ def synthesize_policy_for_tree_node(self, family, prop, all_sat=False, iteration

# add policy if current mdp doesn't have one yet
# TODO maybe this can be done after some number of controllers are consistent?
if sat_mdp_policies[index] == None:
if sat_mdp_policies[index] is None:
policy = self.quotient.scheduler_to_policy(primary_result.result.scheduler, mdp_subfamily.mdp)
sat_mdp_policies[index] = policy

current_results.append(primary_result)
selection = self.quotient.scheduler_selection_with_coloring(current_action_family.mdp, primary_result.result.scheduler, action_coloring)
selection = self.quotient.scheduler_selection_with_coloring(current_action_family.mdp, primary_result.result.scheduler, self.action_coloring)
self.update_scores(score_lists, selection)

scores = {hole:len(score_list) for hole, score_list in score_lists.items()}
Expand Down Expand Up @@ -705,6 +714,7 @@ def synthesize_policy_for_tree_node(self, family, prop, all_sat=False, iteration
for mdp_index in mdps_without_policy:
self.quotient.build(sat_mdp_families[mdp_index])
primary_result = sat_mdp_families[mdp_index].mdp.model_check_property(prop)
assert primary_result.result.has_scheduler
self.stat.iteration_mdp(sat_mdp_families[mdp_index].mdp.states)
policy = self.quotient.scheduler_to_policy(primary_result.result.scheduler, sat_mdp_families[mdp_index].mdp)
sat_mdp_policies[mdp_index] = policy
Expand All @@ -715,33 +725,27 @@ def synthesize_policy_for_tree_node(self, family, prop, all_sat=False, iteration

def synthesize_policy_tree(self, family):

self.last_splitter = -1
prop = self.quotient.specification.constraints[0]
prop = self.quotient.get_property()
game_solver = self.quotient.build_game_abstraction_solver(prop)
# game_solver.enable_profiling(True)
policy_tree = PolicyTree(family)
self.create_action_coloring()

# self.quotient.build(policy_tree.root.family)
# policy_exists,_,_,_ = self.synthesize_policy_for_tree_node(policy_tree.root.family, prop, all_sat=True)
# print("Policy exists: ", policy_exists)
# self.stat.finished(None)
# self.stat.print()
# exit()
if False:
self.quotient.build(policy_tree.root.family)
policy_exists,_,_,_ = self.synthesize_policy_for_family(policy_tree.root.family, prop, all_sat=True)
print("Policy exists: ", policy_exists)
return None

reference_policy = None
policy_tree_leaves = [policy_tree.root]
while policy_tree_leaves:

policy_tree_node = policy_tree_leaves.pop(-1)
family = policy_tree_node.family
# logger.info("investigating family of size {}".format(family.size))
result = self.verify_family(family,game_solver,prop,reference_policy)
policy_tree_node.policy = result.policy
policy_tree_node.policy_source = result.policy_source

# if family.size < 8:
# policy_exists, unsat, sat, policy = self.synthesize_policy_for_tree_node(family, prop)

if result.policy == False:
reference_policy = None
self.explore(family)
Expand All @@ -759,8 +763,7 @@ def synthesize_policy_tree(self, family):
policy_tree_node.split(result.splitter,suboptions,subfamilies)
policy_tree_leaves = policy_tree_leaves + policy_tree_node.child_nodes

# game_solver.print_profiling()
policy_tree.double_check_all_families(self.quotient, prop)
policy_tree.double_check(self.quotient, prop)
policy_tree.print_stats()
policy_tree.postprocess(self.quotient, prop)
return policy_tree
Expand Down

0 comments on commit eedf67f

Please sign in to comment.