Skip to content

Commit

Permalink
#2727 Started to add detection of conditional accesses.
Browse files Browse the repository at this point in the history
  • Loading branch information
hiker committed Nov 14, 2024
1 parent 250fb00 commit 17238a7
Show file tree
Hide file tree
Showing 3 changed files with 212 additions and 8 deletions.
139 changes: 137 additions & 2 deletions src/psyclone/core/variables_access_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
'''This module provides management of variable access information.'''


from psyclone.core.access_type import AccessType
from psyclone.core.component_indices import ComponentIndices
from psyclone.core.signature import Signature
from psyclone.core.single_variable_access_info import SingleVariableAccessInfo
Expand Down Expand Up @@ -167,8 +168,8 @@ def __str__(self):
mode = "READ"
elif self.is_written(signature):
mode = "WRITE"
all_accesses = self[signature]
cond = any(acc.conditional for acc in all_accesses)
all_accesses = self[signature]
cond = any(acc.conditional for acc in all_accesses)
output_list.append(f"{'%' if cond else ''}{signature}: {mode}")
return ", ".join(output_list)

Expand Down Expand Up @@ -382,6 +383,140 @@ def has_read_write(self, signature):
var_access_info = self[signature]
return var_access_info.has_read_write()

def set_conditional_accesses(self, if_branch, else_branch):
'''This function adds the accesses from `if_branch` and `else_branch`,
marking them as conditional if the accesses are already conditional,
or only happen in one of the two branches. While this function is
at the moment only used for if-statements, it can also be used for
e.g. loops by providing None as `else_branch` object.
:param if_branch: the first branch.
:type if_branch: :py:class:`psyclone.psyir.nodes.Node`
:param else_branch: the second branch, which can be None.
:type else_branch: :py:class:`psyclone.psyir.nodes.Node`
'''
var_if = VariablesAccessInfo(if_branch, self.options())
# Create an empty access info object in case that we do not have
# a second branch.
if else_branch:
var_else = VariablesAccessInfo(else_branch, self.options())
else:
var_else = VariablesAccessInfo()

# Get the list of all signatures in the if and else branch:
all_sigs = set(var_if.keys())
all_sigs.update(set(var_else.keys()))

for sig in all_sigs:
if sig not in var_if or sig not in var_else:
# Signature is only in one branch. Mark all existing accesses
# as conditional
var_access = var_if[sig] if sig in var_if else var_else[sig]
for access in var_access.all_accesses:
access.conditional = True
continue

# Now we have a signature that is accessed in both
# the if and else block. In case of array variables, we need to
# distinguish between different indices, e.g. a(i) might be
# written to unconditionally, but a(i+1) might be written
# conditionally. Additionally, we should support mathematically
# equivalent statements (e.g. a(i+1), and a(1+i)).
# As a first step, split all the accesses into equivalence
# classes. Each equivalent class stores two lists as a pair: the
# first one with the accesses from the if branch, the second with
# the accesses from the else branch.
equiv = {}
for access in var_if[sig].all_accesses:
for comp_access in equiv.keys():
if access.component_indices.equal(comp_access):
equiv[comp_access][0].append(access)
break
else:
# New component index:
equiv[access.component_indices] = ([access], [])
# While we know that the signature is used in both branches, the
# accesses for a given equivalence class of indices could still
# be in only in one of them (e.g.
# if () then a(i)=1 else a(i+1)=2 endif). So it is still possible
# that we a new equivalence class in the second branch
for access in var_else[sig].all_accesses:
for comp_access in equiv.keys():
if access.component_indices.equal(comp_access):
equiv[comp_access][1].append(access)
break
else:
# New component index:
equiv[access.component_indices] = ([], [access])

# Now handle each equivalent set of component indices:
for comp_index in equiv.keys():
if_accesses, else_accesses = equiv[comp_index]
# If the access is not in both branches, it is conditional:
if not if_accesses or not else_accesses:
# Only accesses in one section, therefore conditional:
var_access = if_accesses if if_accesses else else_accesses
for access in var_access:
access.conditional = True
continue

# Now we have accesses to the same indices in both branches.
# We still need to distinguish between read and write accesses.
# This can result in incorrect/unexpected results in some rare
# cases:
# if ()
# call kernel(a(i)) ! Assume a(i) is READWRITE
# else
# b = a(i)
# endif
# Now the read access to a(i) is unconditional, but the write
# access to a(i) as part of the readwrite is conditional. But
# since there is only one accesses for the readwrite, we can't
# mark it as both conditional and unconditional
is_conditional = True

for mode in [AccessType.READ, AccessType.WRITE]:
for access in if_accesses:
# Ignore read or write accesses depending on mode
if mode is AccessType.READ and not access.is_read:
continue
if mode is AccessType.WRITE and not access.is_written:
continue
if not access.conditional:
is_conditional = False
break
# If an access is unconditional in the if branch, then we
# need to check the if branch
if not is_conditional:
for access in if_accesses:
# Ignore read or write accesses depending on mode
if mode is AccessType.READ and not access.is_read:
continue
if mode is AccessType.WRITE and \
not access.is_written:
continue
if not access.conditional:
is_conditional = False
break

# If the access to this equivalence class is conditional,
# mark all accesses as conditional:
if is_conditional:
for access in if_accesses + else_accesses:
# Ignore read or write accesses depending on mode
if mode is AccessType.READ and not access.is_read:
continue
if mode is AccessType.WRITE and \
not access.is_written:
continue
access.conditional = True
print("CONDITIONAL", sig.to_language(
component_indices=comp_index))

self.merge(var_if)
self.merge(var_else)


# ---------- Documentation utils -------------------------------------------- #
# The list of module members that we wish AutoAPI to generate
Expand Down
12 changes: 8 additions & 4 deletions src/psyclone/psyir/nodes/if_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,9 +195,13 @@ def reference_accesses(self, var_accesses):
# The first child is the if condition - all variables are read-only
self.condition.reference_accesses(var_accesses)
var_accesses.next_location()
self.if_body.reference_accesses(var_accesses)
var_accesses.next_location()

if self.else_body:
self.else_body.reference_accesses(var_accesses)
if not var_accesses.options("FLATTEN"):
self.if_body.reference_accesses(var_accesses)
var_accesses.next_location()
if self.else_body:
self.else_body.reference_accesses(var_accesses)
var_accesses.next_location()
return

var_accesses.set_conditional_accesses(self.if_body, self.else_body)
69 changes: 67 additions & 2 deletions src/psyclone/tests/core/variables_access_info_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
from psyclone.core import ComponentIndices, Signature, VariablesAccessInfo
from psyclone.core.access_type import AccessType
from psyclone.errors import InternalError
from psyclone.psyir.nodes import Assignment, Node
from psyclone.psyir.nodes import Assignment, IfBlock, Node
from psyclone.tests.utilities import get_invoke


Expand Down Expand Up @@ -395,18 +395,20 @@ def test_variables_access_info_options():
assert vai.options("COLLECT-ARRAY-SHAPE-READS") is True
assert vai.options("USE-ORIGINAL-NAMES") is False
assert vai.options() == {"COLLECT-ARRAY-SHAPE-READS": True,
"FLATTEN": False,
"USE-ORIGINAL-NAMES": False}

vai = VariablesAccessInfo(options={'USE-ORIGINAL-NAMES': True})
assert vai.options("COLLECT-ARRAY-SHAPE-READS") is False
assert vai.options("USE-ORIGINAL-NAMES") is True
assert vai.options() == {"COLLECT-ARRAY-SHAPE-READS": False,
"FLATTEN": False,
"USE-ORIGINAL-NAMES": True}

with pytest.raises(InternalError) as err:
vai.options("invalid")
assert ("Option key 'invalid' is invalid, it must be one of "
"['COLLECT-ARRAY-SHAPE-READS', 'USE-ORIGINAL-NAMES']."
"['COLLECT-ARRAY-SHAPE-READS', 'FLATTEN', 'USE-ORIGINAL-NAMES']."
in str(err.value))


Expand Down Expand Up @@ -477,3 +479,66 @@ def test_lfric_access_info():
"READ, np_xy_qr: READ, np_z_qr: READ, undf_w1: READ, undf_w2: "
"READ, undf_w3: READ, weights_xy_qr: READ, weights_z_qr: READ"
== str(vai))


# -----------------------------------------------------------------------------
def test_variables_access_info_flatten(fortran_reader):
'''Test that flatten works as expected.
'''
code = '''module test
contains
subroutine tmp()
integer :: cond_var
integer :: write_if, write_else, write_if_else
integer :: read_if, read_else, read_if_else
if (cond_var .eq. 1) then
write_if = read_if
write_if_else = read_if_else
else
write_else = read_else
write_if_else = read_if_else
endif
end subroutine tmp
end module test'''
psyir = fortran_reader.psyir_from_source(code)
node1 = psyir.walk(IfBlock)[0]

# By default, array shape accesses are not reads.
vai = VariablesAccessInfo(node1, options={"FLATTEN": True})

print(vai)


# -----------------------------------------------------------------------------
def test_variables_access_info_array_conditional(fortran_reader):
'''Test that flatten works as expected.
'''
code = '''module test
contains
subroutine tmp(i)
integer :: cond_var, i
integer, dimension(10) :: write_if, write_else, write_if_else
integer, dimension(10) :: read_if, read_else, read_if_else
if (cond_var .ne. 1) then
if (cond_var .eq. 2) then
write_if(i) = read_if(i)
write_if_else(i+1) = read_if_else(i)
endif
write_if_else(i+1) = read_if_else(1+i-1)
else
write_if(i) = read_if(i)
write_else(i) = read_else(i)
write_if_else(i) = read_if_else(i)
write_if_else(i+1) = read_if_else(i)
endif
end subroutine tmp
end module test'''
psyir = fortran_reader.psyir_from_source(code)
node1 = psyir.walk(IfBlock)[0]

# By default, array shape accesses are not reads.
vai = VariablesAccessInfo(node1, options={"FLATTEN": True})

print(vai)

0 comments on commit 17238a7

Please sign in to comment.