Skip to content

Commit

Permalink
Add TorchScript schema parser (#990)
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Oct 27, 2022
1 parent 13d208e commit 43953cb
Show file tree
Hide file tree
Showing 5 changed files with 328 additions and 4 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ update: install
@./tools/om schema
@./tools/rknn schema
@./tools/paddle sync schema
@./tools/pytorch sync schema
@./tools/pytorch sync schema metadata
@./tools/sklearn sync install metadata
@./tools/tf sync install schema metadata
@./tools/uff schema
Expand Down
3 changes: 1 addition & 2 deletions source/pytorch-metadata.json
Original file line number Diff line number Diff line change
Expand Up @@ -1185,8 +1185,7 @@
{ "name": "stride", "type": "int64[]" },
{ "name": "padding", "type": "int64[]", "default": 0 },
{ "name": "ceil_mode", "type": "boolean", "default": false },
{ "name": "count_include_pad", "type": "boolean", "default": true },
{ "name": "divisor_override" }
{ "name": "count_include_pad", "type": "boolean", "default": true }
],
"outputs": [
{ "name": "output", "type": "Tensor" }
Expand Down
221 changes: 221 additions & 0 deletions source/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ def argument(value):
return arguments_map[value]

for _ in graph.inputs():
# if len(_.uses()) == 0:
# continue
json_graph['inputs'].append({
'name': _.debugName(),
'arguments': [ argument(_) ]
Expand All @@ -79,6 +81,15 @@ def argument(value):
'arguments': [ argument(_) ]
})
for node in graph.nodes():
# if node.kind() == 'prim::ListConstruct':
# continue
# if node.kind() == 'prim::Constant':
# continue
# if node.kind() == 'prim::GetAttr':
# continue
schema = node.schema() if hasattr(node, 'schema') else None
if schema and schema != '(no schema)':
schema = Schema(schema)
json_node = {
'type': { 'name': node.kind() },
'inputs': [],
Expand Down Expand Up @@ -113,3 +124,213 @@ def argument(value):
'arguments': [ argument(output_value) ]
})
return json_graph

class Schema: # pylint: disable=too-few-public-methods
    ''' Parser for TorchScript operator schema strings, e.g.
        "aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor".

        Exposes: name, arguments, returns, is_vararg, is_varret, kwarg_only. '''
    def __init__(self, value):
        lexer = Schema.Lexer(value)
        lexer.whitespace(0)
        self._parse_name(lexer)
        lexer.whitespace(0)
        if lexer.kind == '(':
            self._parse_arguments(lexer)
            lexer.whitespace(0)
            lexer.expect('->')
            lexer.whitespace(0)
            self._parse_returns(lexer)
    def _parse_name(self, lexer):
        # "namespace::name" with an optional ".overload" suffix.
        self.name = lexer.expect('id')
        if lexer.eat(':'):
            lexer.expect(':')
            self.name = self.name + '::' + lexer.expect('id')
        if lexer.eat('.'):
            self.name = self.name + '.' + lexer.expect('id')
    def _parse_arguments(self, lexer):
        self.arguments = []
        self.is_vararg = False
        self.kwarg_only = False
        lexer.expect('(')
        if not lexer.eat(')'):
            while True:
                lexer.whitespace(0)
                if self.is_vararg:
                    # '...' must be the last entry in the argument list.
                    raise Exception()
                if lexer.eat('*'):
                    # Everything after the bare '*' is keyword-only.
                    self.kwarg_only = True
                elif lexer.eat('...'):
                    self.is_vararg = True
                else:
                    self.arguments.append(Schema.Argument(lexer, False, self.kwarg_only))
                lexer.whitespace(0)
                if not lexer.eat(','):
                    break
            lexer.expect(')')
    def _parse_returns(self, lexer):
        self.returns = []
        self.is_varret = False
        if lexer.eat('...'):
            self.is_varret = True
        elif lexer.eat('('):
            # Tuple of returns: "(Tensor, Tensor)".
            lexer.whitespace(0)
            if not lexer.eat(')'):
                while True:
                    lexer.whitespace(0)
                    if self.is_varret:
                        raise Exception()
                    if lexer.eat('...'):
                        self.is_varret = True
                    else:
                        self.returns.append(Schema.Argument(lexer, True, False))
                    lexer.whitespace(0)
                    if not lexer.eat(','):
                        break
                lexer.expect(')')
            lexer.whitespace(0)
        else:
            # Single unparenthesized return.
            self.returns.append(Schema.Argument(lexer, True, False))
    class Argument: # pylint: disable=too-few-public-methods
        ''' One argument or return value: type, optional name,
            optional default value and keyword-only flag. '''
        def __init__(self, lexer, is_return, kwarg_only):
            value = Schema.Type(lexer)
            if lexer.eat('('):
                # Skip alias annotations such as "(a!)".
                while not lexer.eat(')'):
                    lexer.next()
            while True:
                if lexer.eat('['):
                    size = None
                    if lexer.kind == '#':
                        size = int(lexer.value)
                        lexer.next()
                    lexer.expect(']')
                    # Wrap the parsed type object, not the builtin 'type'.
                    value = Schema.ListType(value, size)
                elif lexer.eat('?'):
                    value = Schema.OptionalType(value)
                else:
                    break
            self.type = value
            lexer.whitespace(0)
            if is_return:
                self.kwarg_only = False
                # Return values may be anonymous.
                if lexer.kind == 'id':
                    self.name = lexer.expect('id')
            else:
                self.kwarg_only = kwarg_only
                self.name = lexer.expect('id')
                lexer.whitespace(0)
                if lexer.eat('='):
                    lexer.whitespace(0)
                    self.default = self._parse_value(lexer)
        def _parse_value(self, lexer):
            # Parse a default value literal: bool, None, known enum names,
            # number, string or (possibly nested) list.
            if lexer.kind == 'id':
                if lexer.value in ('True', 'False'):
                    # bool('False') would be True; compare the text instead.
                    value = lexer.value == 'True'
                elif lexer.value == 'None':
                    value = None
                elif lexer.value in ('Mean', 'contiguous_format', 'long'):
                    value = lexer.value
                else:
                    raise Exception()
            elif lexer.kind == '#':
                value = float(lexer.value) if \
                    lexer.value.find('.') != -1 or lexer.value.find('e') != -1 else \
                    int(lexer.value)
            elif lexer.kind == 'string':
                value = lexer.value
            elif lexer.eat('['):
                value = []
                if not lexer.eat(']'):
                    while True:
                        lexer.whitespace(0)
                        value.append(self._parse_value(lexer))
                        lexer.whitespace(0)
                        if not lexer.eat(','):
                            break
                    lexer.expect(']')
                return value
            else:
                raise Exception()
            lexer.next()
            return value
    class Type: # pylint: disable=too-few-public-methods
        ''' Named type, possibly dotted (e.g. "torch.classes.xnnpack.Conv2dOpContext"). '''
        def __init__(self, lexer):
            self.name = lexer.expect('id')
            while lexer.eat('.'):
                self.name = self.name + '.' + lexer.expect('id')
    class OptionalType: # pylint: disable=too-few-public-methods
        ''' "T?" -- optional wrapper around an element type. '''
        def __init__(self, element_type):
            self.element_type = element_type
    class ListType: # pylint: disable=too-few-public-methods
        ''' "T[]" or "T[n]" -- list of element_type with optional fixed size. '''
        def __init__(self, element_type, size):
            self.element_type = element_type
            self.size = size
    class Lexer: # pylint: disable=too-few-public-methods
        ''' Tokenizer over the schema string. Token kinds: 'id', '#' (number),
            'string', ' ', '->', '...', single punctuation chars and '\\0' at end. '''
        def __init__(self, buffer):
            self.buffer = buffer
            self.position = 0
            self.value = ''
            self.next()
        def eat(self, kind): # pylint: disable=missing-function-docstring
            # Consume and return the current token if it matches, else None.
            if self.kind != kind:
                return None
            value = self.value
            self.next()
            return value
        def expect(self, kind): # pylint: disable=missing-function-docstring
            # Consume and return the current token; raise on a mismatch.
            if self.kind != kind:
                raise Exception('')
            value = self.value
            self.next()
            return value
        def whitespace(self, count): # pylint: disable=missing-function-docstring
            # Consume one whitespace token if present; otherwise require the
            # current token to be at least 'count' characters long.
            if self.kind != ' ':
                if count > len(self.value):
                    raise Exception('')
                return False
            self.next()
            return True
        def next(self): # pylint: disable=missing-function-docstring,too-many-branches
            self.position += len(self.value)
            i = self.position
            if i >= len(self.buffer):
                self.kind = '\0'
                self.value = ''
            elif self.buffer[i] == ' ':
                # Bounds check avoids an IndexError on trailing spaces.
                while i < len(self.buffer) and self.buffer[i] == ' ':
                    i += 1
                self.kind = ' '
                self.value = self.buffer[self.position:i]
            elif self.buffer[i:i+3] == '...':
                # Must precede the single-character '.' branch below,
                # otherwise '...' is never tokenized.
                self.kind = '...'
                self.value = '...'
            elif self.buffer[i] in ('(', ')', ':', '.', '[', ']', ',', '=', '?', '!', '*'):
                self.kind = self.buffer[i]
                self.value = self.buffer[i]
            elif (self.buffer[i] >= 'a' and self.buffer[i] <= 'z') or \
                (self.buffer[i] >= 'A' and self.buffer[i] <= 'Z') or self.buffer[i] == '_':
                i += 1
                while i < len(self.buffer) and \
                    ((self.buffer[i] >= 'a' and self.buffer[i] <= 'z') or \
                    (self.buffer[i] >= 'A' and self.buffer[i] <= 'Z') or \
                    (self.buffer[i] >= '0' and self.buffer[i] <= '9') or self.buffer[i] == '_'):
                    i += 1
                self.kind = 'id'
                self.value = self.buffer[self.position:i]
            elif self.buffer[i:i+2] == '->':
                self.kind = '->'
                self.value = '->'
            elif (self.buffer[i] >= '0' and self.buffer[i] <= '9') or self.buffer[i] == '-':
                i += 1
                while i < len(self.buffer) and \
                    ((self.buffer[i] >= '0' and self.buffer[i] <= '9') or \
                    self.buffer[i] == '.' or self.buffer[i] == 'e' or self.buffer[i] == '-'):
                    i += 1
                self.kind = '#'
                self.value = self.buffer[self.position:i]
            elif self.buffer[i] in ("'", '"'):
                quote = self.buffer[i]
                i += 1
                while i < len(self.buffer) and self.buffer[i] != quote:
                    i += 2 if self.buffer[i] == '\\' and self.buffer[i+1] in ("'", '"', '\\') else 1
                i += 1
                self.kind = 'string'
                self.value = self.buffer[self.position:i]
            else:
                # str() needed: self.position is an int.
                raise Exception('Unsupported token at ' + str(self.position))
11 changes: 10 additions & 1 deletion tools/pytorch
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,20 @@ schema() {
fi
}

metadata() {
    echo "pytorch metadata"
    # Remember whether the checked-out file uses CRLF line endings
    # (grep -U disables text-mode stripping so \x0D is matchable).
    [[ $(grep -U $'\x0D' ./source/pytorch-metadata.json) ]] && crlf=1
    ${python} ./tools/pytorch_metadata.py
    if [[ -n ${crlf} ]]; then
        # The Python script writes LF; restore the original CRLF endings.
        unix2dos --quiet --newfile ./source/pytorch-metadata.json ./source/pytorch-metadata.json
    fi
}

# Dispatch each command-line argument to the shell function of the same name.
while [ "$#" != 0 ]; do
    command="$1" && shift
    case "${command}" in
        "clean") clean;;
        "sync") sync;;
        "schema") schema;;
        "metadata") metadata;;
    esac
done
95 changes: 95 additions & 0 deletions tools/pytorch_metadata.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
''' TorchScript metadata script '''

import collections
import json
import os
import re
import sys

# Make the repository root importable so 'source.pytorch' resolves below.
root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(root_dir)
# Keep generated bytecode out of the source tree.
sys.pycache_prefix = os.path.join(root_dir, 'dist', 'pycache', 'test', 'backend')
# __import__ returns the top-level 'source' package; grab its pytorch submodule.
pytorch = __import__('source.pytorch').pytorch

source_dir = os.path.join(root_dir, 'source')
third_party_dir = os.path.join(root_dir, 'third_party')
metadata_file = os.path.join(source_dir, 'pytorch-metadata.json')
# Local checkout of the PyTorch sources that are scanned for operator schemas.
pytorch_source_dir = os.path.join(third_party_dir, 'source', 'pytorch')

def _read(path):
with open(path, 'r', encoding='utf-8') as file:
return file.read()

def _write(path, content):
with open(path, 'w', encoding='utf-8') as file:
file.write(content)

def _read_metadata():
    ''' Load pytorch-metadata.json and index its entries by operator name. '''
    entries = json.loads(_read(metadata_file))
    return { entry['name']: entry for entry in entries }

def _write_metadata(value):
    ''' Serialize the name-to-entry map back to pytorch-metadata.json,
        sorted by operator name, with argument objects collapsed onto
        single lines. '''
    entries = [ value[name] for name in sorted(value) ]
    content = json.dumps(entries, indent=2, ensure_ascii=False)
    # Fold the pretty-printed nested objects back onto one line each.
    content = re.sub(r'\s {8}', ' ', content)
    content = re.sub(r',\s {8}', ', ', content)
    content = re.sub(r'\s {6}}', ' }', content)
    _write(metadata_file, content)

# Each entry: (path relative to the PyTorch checkout, regex whose group 1
# captures a schema string, optional prefix prepended to every match).
schema_source_files = [
    ('aten/src/ATen/native/native_functions.yaml',
        re.compile(r'-\s*func:\s*(.*)', re.MULTILINE), 'aten::'),
    ('aten/src/ATen/native/quantized/library.cpp',
        re.compile(r'TORCH_SELECTIVE_SCHEMA\("(.*)"\)', re.MULTILINE)),
    ('aten/src/ATen/native/xnnpack/RegisterOpContextClass.cpp',
        re.compile(r'TORCH_SELECTIVE_SCHEMA\("(.*)"', re.MULTILINE)),
    ('torch/csrc/jit/runtime/register_prim_ops_fulljit.cpp',
        re.compile(r'(aten::.*->\s*Tensor)', re.MULTILINE)),
    ('torch/csrc/jit/passes/shape_analysis.cpp',
        re.compile(r'(aten::.*->\s*Tensor)', re.MULTILINE)),
    ('caffe2/operators/copy_op.cc',
        re.compile(r'(_caffe2::.*->\s*Tensor)', re.MULTILINE)),
    ('caffe2/operators/batch_permutation_op.cc',
        re.compile(r'(_caffe2::.*->\s*Tensor)', re.MULTILINE)),
    ('caffe2/operators/collect_and_distribute_fpn_rpn_proposals_op.cc',
        re.compile(r'"(_caffe2::[\w+]*\([\w"\s\[\],]*\)\s*->\s*\([\w"\s\[\],]*\))"', re.MULTILINE)),
    ('caffe2/operators/box_with_nms_limit_op.cc',
        re.compile(r'"(_caffe2::[\w+]*\([\w"\s\[\],]*\)\s*->\s*\([\w"\s\[\],]*\))"', re.MULTILINE)),
    ('caffe2/operators/bbox_transform_op.cc',
        re.compile(r'"(_caffe2::[\w+]*\([\w"\s\[\],]*\)\s*->\s*\([\w"\s\[\],]*\))"', re.MULTILINE)),
    ('caffe2/operators/generate_proposals_op.cc',
        re.compile(r'"(_caffe2::[\w+]*\([\w"\s\[\],]*\)\s*->\s*\([\w"\s\[\],]*\))"', re.MULTILINE)),
    ('caffe2/operators/roi_align_op.cc',
        re.compile(r'"(_caffe2::[\w+]*\([\w"\s\[\],]*\)\s*->.*)"', re.MULTILINE))
]

def _metadata():
    ''' Rewrite pytorch-metadata.json, then report which metadata entries
        lack a matching schema: scan the PyTorch sources for operator schema
        strings and drop every metadata entry whose name a schema covers. '''
    types = _read_metadata()
    _write_metadata(types)
    # torch.nn entries are module docs, not operator schemas; exclude them.
    for name in list(types):
        if name.startswith('torch.nn'):
            del types[name]
    for entry in schema_source_files:
        content = _read(os.path.join(pytorch_source_dir, entry[0]))
        for match in entry[1].findall(content):
            # _caffe2 schemas are multi-line C++ string literals; strip
            # the line breaks and quote fragments first.
            if match.startswith('_caffe2::'):
                match = re.sub(r'\n|\r|\s*"', '', match)
            definition = entry[2] + match if len(entry) > 2 else match
            schema = pytorch.Schema(definition)
            if schema.name in types:
                del types[schema.name]

def main(): # pylint: disable=missing-function-docstring
    _metadata()

# Script entry point: rewrites pytorch-metadata.json in place.
if __name__ == '__main__':
    main()

0 comments on commit 43953cb

Please sign in to comment.