Skip to content

Commit

Permalink
cleanups for flake
Browse files Browse the repository at this point in the history
  • Loading branch information
SCHREIBER Martin committed Nov 13, 2024
1 parent 0926461 commit b9b24a3
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 44 deletions.
66 changes: 41 additions & 25 deletions src/psyclone/psyir/nodes/call.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,7 +418,6 @@ def node_str(self, colour=True):
def __str__(self):
return self.node_str(False)


def copy(self):
'''Return a copy of this node. This is a bespoke implementation for
a Call node that ensures that any internal id's are
Expand All @@ -442,7 +441,6 @@ def copy(self):

return new_copy


def get_callees(self):
'''
Searches for the implementation(s) of all potential target routines
Expand Down Expand Up @@ -613,15 +611,17 @@ def get_argument_routine_match(self, routine: Routine):

# Create a copy of the list
# Once an argument has been successfully matched, set it to 'None'
routine_argument_list: List[DataNode] = routine.symbol_table.argument_list[:]
routine_argument_list: List[DataNode] = \
routine.symbol_table.argument_list[:]

# Find matching argument list
#if len(self.arguments) != len(routine_argument_list):
# if len(self.arguments) != len(routine_argument_list):
# return None

if len(self.arguments) > len(routine.symbol_table.argument_list):
raise self.MatchingArgumentsNotFound(
f"More arguments in callee (call '{self.routine.name}') than caller (routine '{routine.name}')"
f"More arguments in callee (call '{self.routine.name}')"
f" than caller (routine '{routine.name}')"
)

assert len(self.arguments) == len(self.argument_names)
Expand All @@ -632,19 +632,25 @@ def get_argument_routine_match(self, routine: Routine):
call_arg_idx: int
call_arg: DataSymbol

# If None, it's a positional argument => Just return the index if the types match
# If None, it's a positional argument => Just return the index if
# the types match
if self.argument_names[call_arg_idx] is None:
routine_arg = routine_argument_list[call_arg_idx]
routine_arg: DataSymbol

# Do the types of arguments match?
#
# TODO #759: If optional is used, it's an unsupported Fortran type and we need to use the following workaround
# TODO #759: If optional is used, it's an unsupported Fortran
# type and we need to use the following workaround
# Once this issue is resolved, simply remove this if branch
if not isinstance(routine_arg.datatype, UnsupportedFortranType):
if not isinstance(
routine_arg.datatype,
UnsupportedFortranType):
if call_arg.datatype != routine_arg.datatype:
raise self.MatchingArgumentsNotFound(
f"Argument type mismatch of call argument '{call_arg}' and routine argument '{routine_arg}'"
f"Argument type mismatch of call argument "
f"'{call_arg}' and routine argument "
f"'{routine_arg}'"
)

ret_arg_idx_list.append(call_arg_idx)
Expand All @@ -657,20 +663,26 @@ def get_argument_routine_match(self, routine: Routine):
arg_name = self.argument_names[call_arg_idx]
named_arg_found = False
routine_arg_idx = None
for routine_arg_idx, routine_arg in enumerate(routine_argument_list):
for routine_arg_idx, routine_arg in enumerate(
routine_argument_list):
routine_arg: DataSymbol

# Check if argument was already processed
if routine_arg is None:
continue

if arg_name == routine_arg.name:
# TODO #759: If optional is used, it's an unsupported Fortran type and we need to use the following workaround
# TODO #759: If optional is used, it's an unsupported
# Fortran type and we need to use the following workaround
# Once this issue is resolved, simply remove this if branch
if not isinstance(routine_arg.datatype, UnsupportedFortranType):
if not isinstance(
routine_arg.datatype,
UnsupportedFortranType):
if call_arg.datatype != routine_arg.datatype:
raise self.MatchingArgumentsNotFound(
f"Argument type mismatch of call argument '{call_arg}' and routine argument '{routine_arg}'"
f"Argument type mismatch of call argument "
f"'{call_arg}' and routine argument "
f"'{routine_arg}'"
)

ret_arg_idx_list.append(routine_arg_idx)
Expand All @@ -686,8 +698,6 @@ def get_argument_routine_match(self, routine: Routine):

routine_argument_list[routine_arg_idx] = None



#
# Finally, we check if all left-over arguments are optional arguments
#
Expand All @@ -703,18 +713,22 @@ def get_argument_routine_match(self, routine: Routine):
continue

raise self.MatchingArgumentsNotFound(
f"Argument '{routine_arg}' in subroutine '{routine.name}' not handled"
f"Argument '{routine_arg}' in subroutine"
f" '{routine.name}' not handled"
)

return ret_arg_idx_list


def get_callee(self, ret_arg_match_list:List[int] = None, check_matching_arguments: bool=True):
def get_callee(
self,
ret_arg_match_list: List[int] = None,
check_matching_arguments: bool = True):
'''
Searches for the implementation(s) of the target routine for this Call
including argument checks.
:param ret_arg_match_list: List in which the matching argument indices will be returned
:param ret_arg_match_list: List in which the matching argument
indices will be returned
:returns: the Routine(s) that this call targets.
:rtype: list[:py:class:`psyclone.psyir.nodes.Routine`]
Expand All @@ -736,13 +750,15 @@ def get_callee(self, ret_arg_match_list:List[int] = None, check_matching_argumen

if ret_arg_match_list is not None:
ret_arg_match_list[:] = arg_match_list

return routine

# If we didn't find any routine, return some routine if no matching arguments have been found.
# This is handy for the transition phase until optional argument matching is supported.

# If we didn't find any routine, return some routine if no matching
# arguments have been found.
# This is handy for the transition phase until optional argument
# matching is supported.
if not check_matching_arguments:
return routine_list[0]


raise NotImplementedError(f"No matching routine for call '{self.routine.name}' found")
raise NotImplementedError(f"No matching routine for call "
f"'{self.routine.name}' found")
42 changes: 23 additions & 19 deletions src/psyclone/tests/psyir/nodes/call_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -628,7 +628,6 @@ def test_call_get_callees_local(fortran_reader):
assert result == [psyir.walk(Routine)[1]]



def test_call_get_callee_1_simple_match(fortran_reader):
'''
Check that right routine has been found.
Expand Down Expand Up @@ -741,12 +740,13 @@ def test_call_get_callee_3_trigger_error(fortran_reader):
arg_idx_list = []

try:
result: Routine = call_foo.get_callee(ret_arg_match_list=arg_idx_list)
except:
call_foo.get_callee(ret_arg_match_list=arg_idx_list)
except Exception:
print("Success! Exception triggered (as expected)")
return

assert False, "This should have triggered an error since there are more arguments in the call than in the routine"
assert False, ("This should have triggered an error since there"
"are more arguments in the call than in the routine")


def test_call_get_callee_4_named_arguments(fortran_reader):
Expand Down Expand Up @@ -914,7 +914,8 @@ def test_call_get_callee_6_interfaces(fortran_reader):
assert call_foo_a.routine.name == "foo"

arg_idx_list = []
result: Routine = call_foo_a.get_callee(ret_arg_match_list=arg_idx_list)
result: Routine = call_foo_a.get_callee(
ret_arg_match_list=arg_idx_list)

print(f" - Found matching argument list: {arg_idx_list}")

Expand All @@ -932,7 +933,8 @@ def test_call_get_callee_6_interfaces(fortran_reader):
assert call_foo_a.routine.name == "foo"

arg_idx_list = []
result: Routine = call_foo_a.get_callee(ret_arg_match_list=arg_idx_list)
result: Routine = call_foo_a.get_callee(
ret_arg_match_list=arg_idx_list)

print(f" - Found matching argument list: {arg_idx_list}")

Expand All @@ -955,7 +957,8 @@ def test_call_get_callee_6_interfaces(fortran_reader):
assert call_foo_b.routine.name == "foo"

arg_idx_list = []
result: Routine = call_foo_b.get_callee(ret_arg_match_list=arg_idx_list)
result: Routine = call_foo_b.get_callee(
ret_arg_match_list=arg_idx_list)

print(f" - Found matching argument list: {arg_idx_list}")

Expand All @@ -973,7 +976,8 @@ def test_call_get_callee_6_interfaces(fortran_reader):
assert call_foo_b.routine.name == "foo"

arg_idx_list = []
result: Routine = call_foo_b.get_callee(ret_arg_match_list=arg_idx_list)
result: Routine = call_foo_b.get_callee(
ret_arg_match_list=arg_idx_list)

print(f" - Found matching argument list: {arg_idx_list}")

Expand All @@ -992,7 +996,8 @@ def test_call_get_callee_6_interfaces(fortran_reader):
assert call_foo_b.routine.name == "foo"

arg_idx_list = []
result: Routine = call_foo_b.get_callee(ret_arg_match_list=arg_idx_list)
result: Routine = call_foo_b.get_callee(
ret_arg_match_list=arg_idx_list)

print(f" - Found matching argument list: {arg_idx_list}")

Expand All @@ -1004,8 +1009,6 @@ def test_call_get_callee_6_interfaces(fortran_reader):
assert result is routine_foo_b
print(" - Passed subtest foo_b[2]")



if 1:
routine_foo_c: Routine = root_node.walk(Routine)[3]
assert routine_foo_c.name == "foo_c"
Expand All @@ -1015,9 +1018,10 @@ def test_call_get_callee_6_interfaces(fortran_reader):

call_foo_c: Call = routine_main.walk(Call)[5]
assert call_foo_c.routine.name == "foo"

arg_idx_list = []
result: Routine = call_foo_c.get_callee(ret_arg_match_list=arg_idx_list)
result: Routine = call_foo_c.get_callee(
ret_arg_match_list=arg_idx_list)

print(f" - Found matching argument list: {arg_idx_list}")

Expand All @@ -1034,9 +1038,10 @@ def test_call_get_callee_6_interfaces(fortran_reader):

call_foo_c: Call = routine_main.walk(Call)[6]
assert call_foo_c.routine.name == "foo"

arg_idx_list = []
result: Routine = call_foo_c.get_callee(ret_arg_match_list=arg_idx_list)
result: Routine = call_foo_c.get_callee(
ret_arg_match_list=arg_idx_list)

print(f" - Found matching argument list: {arg_idx_list}")

Expand All @@ -1052,9 +1057,10 @@ def test_call_get_callee_6_interfaces(fortran_reader):

call_foo_c: Call = routine_main.walk(Call)[7]
assert call_foo_c.routine.name == "foo"

arg_idx_list = []
result: Routine = call_foo_c.get_callee(ret_arg_match_list=arg_idx_list)
result: Routine = call_foo_c.get_callee(
ret_arg_match_list=arg_idx_list)

print(f" - Found matching argument list: {arg_idx_list}")

Expand All @@ -1067,8 +1073,6 @@ def test_call_get_callee_6_interfaces(fortran_reader):
print(" - Passed subtest foo_c[2]")




@pytest.mark.usefixtures("clear_module_manager_instance")
def test_call_get_callees_unresolved(fortran_reader, tmpdir, monkeypatch):
'''
Expand Down

0 comments on commit b9b24a3

Please sign in to comment.