Fix torch listconstruct errors when dependent on inputs' flexible shapes #2050

Open · wants to merge 3 commits into main
79 changes: 68 additions & 11 deletions coremltools/converters/mil/frontend/torch/ops.py
@@ -626,23 +626,74 @@ def unflatten(context, node):
def _array_construct(context, node, array_type):
    assert len(node.outputs) == 1
    inputs = _get_inputs(context, node)

    is_all_const = all(
        isinstance(inp, Var) and inp.can_be_folded_to_const() and len(inp.shape) == 0
        for inp in inputs
    )
    if is_all_const:
        # All the list items are compile-time scalar constants, so create
        # a new const that concatenates them.
        val = array_type([inp.val for inp in inputs])
        const = mb.const(val=val, name=node.name)
        context.add(const)
        return

    nodes = {n.name: n for n in context.torch_graph.nodes}
    is_known_name = lambda name: name in nodes
    # -1 marks "not visited yet"; after a node is visited, its entry holds the
    # size of the accumulated graph-input set at that moment, so a value > 0
    # means a graph input had been reached on some dependency chain.
    inheriting_bookkeeping = {name: -1 for name in nodes.keys()}

    def dfs_graph_input_dependent(inputs, non_const=None):
        '''
        Walk up the producers of `inputs` all the way to the graph roots;
        `inputs` is empty for a constant leaf, which is the DFS base case.

        Any name missing from context.torch_graph.nodes must be a (possibly
        symbolic) graph input, so collect those names in `non_const`.
        '''
        if non_const is None:
            # Outermost call only: initialize the accumulator shared by all
            # recursion levels.
            non_const = set()

        for i in inputs:
            if is_known_name(i):
                if inheriting_bookkeeping[i] == -1:
                    inheriting = dfs_graph_input_dependent(nodes[i].inputs, non_const)
                    inheriting_bookkeeping[i] = len(inheriting)
            else:
                non_const.add(i)
        return non_const

    any_inheriting = dfs_graph_input_dependent(node.inputs)
    dependent_on_graph_input = len(any_inheriting) > 0

    if dependent_on_graph_input:
        to_concat = []
        for input_name in node.inputs:
            inheriting = inheriting_bookkeeping[input_name]
            if inheriting <= 0:
                # Const element: wrap its folded value as a 1-element list.
                to_concat.append([context[input_name].val])
            else:
                # Non-const element: walk up the first-input chain until we
                # reach the node that consumes a graph input directly.
                iter_node = nodes[input_name]
                while all(is_known_name(i) for i in iter_node.inputs):
                    iter_node = nodes[iter_node.inputs[0]]

                if context[iter_node.name].op.op_type == 'gather':
                    # A `gather` on a `shape` output is the pattern that
                    # torch.Tensor.size(dim) lowers to: inputs[0] is the tensor
                    # whose shape was taken and inputs[1] is the gathered
                    # dimension index, so extract that dimension dynamically.
                    non_const = iter_node.inputs[0]
                    non_const_name = iter_node.inputs[1]
                    non_const_idx = context[non_const_name].val
                    to_concat.append(
                        mb.slice_by_size(x=mb.shape(x=context[non_const]), begin=[non_const_idx], size=[1])
                    )
                else:
                    # Unrecognized producer: give up and fall through to the
                    # generic list handling below.
                    to_concat = []
                    break

        if len(to_concat) > 0:
            context.add(mb.concat(values=to_concat, axis=0), node.name)
            return

    # If at least one input to the construct op is neither const nor a value
    # derived from the graph input's shape, collect the inputs and add them
    # directly to the context. Ops that use this node's output will take the
    # list directly as input.
    context.add(array_type(inputs), node.name)
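
A minimal sketch of the case this branch targets (hypothetical model and shapes, not taken from this PR): `x.size(0)` lowers to a `shape` op followed by a `gather`, so with a flexible batch dimension the ListConstruct feeding `view` inherits from the graph input, and the DFS above rebuilds it as a single `concat` of a dynamic slice and the folded constants.

import torch
import coremltools as ct

class Reshaper(torch.nn.Module):
    def forward(self, x):
        # x.size(0) becomes shape -> gather in the traced graph, so the
        # ListConstruct feeding view() depends on the graph input.
        return x.view(x.size(0), -1)

traced = torch.jit.trace(Reshaper().eval(), torch.rand(2, 3, 4))
mlmodel = ct.convert(
    traced,
    inputs=[ct.TensorType(name="x", shape=ct.Shape(shape=(ct.RangeDim(1, 8), 3, 4)))],
)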


@register_torch_op
@@ -1595,6 +1646,12 @@ def pad(context, node):
        pad = pad.val.reshape((-1, 2))[::-1].reshape(-1).tolist()
        missing_dims = x.rank - (len(pad) // 2)
        pad = [0, 0] * missing_dims + pad
    else:
        # `pad` is a runtime tensor: append zeros for the missing leading dims,
        # then reorder from torch's last-dim-first pairs to MIL's
        # first-dim-first layout using MIL ops.
        missing_dims = (x.rank * 2 - pad.shape[0]) // 2
        pad = mb.concat(values=[pad, [0, 0] * missing_dims], axis=0)
        pad = mb.reshape(x=pad, shape=[-1, 2])
        pad = mb.reverse(x=pad, axes=[0])
        pad = mb.reshape(x=pad, shape=[-1])  # mil.ops.defs.iOS15.pad asserts a 1-D tensor

    if len(inputs) == 4:
        mode = inputs[2].val
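
The dynamic branch mirrors the static reorder `pad.val.reshape((-1, 2))[::-1].reshape(-1)`: appending the zeros before the reversal is equivalent to prepending them after it. A quick sanity sketch in plain NumPy (illustrative only, not part of the PR):

import numpy as np

torch_pad = np.array([1, 2, 3, 4])  # torch order: pairs for the last dim come first
x_rank = 3
missing_dims = (x_rank * 2 - len(torch_pad)) // 2
# Dynamic-branch order of operations: append zeros, then reverse the pairs.
mil_pad = np.concatenate([torch_pad, [0, 0] * missing_dims]).reshape((-1, 2))[::-1].reshape(-1)
print(mil_pad.tolist())  # [0, 0, 3, 4, 1, 2]: leading dims first, as MIL's pad expects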