From 59e47369120471566da62ab3e2fe8fcdb5d07b55 Mon Sep 17 00:00:00 2001
From: Lutz Roeder
Date: Wed, 26 Oct 2022 18:26:56 -0700
Subject: [PATCH] Add TorchScript schema parser (#990)

---
 Makefile                     |   2 +-
 source/pytorch-metadata.json |   3 +-
 source/pytorch.py            | 241 +++++++++++++++++++++++++++++++++++
 tools/pytorch                |  10 +
 tools/pytorch_metadata.py    |  95 +++++++++++++++
 5 files changed, 348 insertions(+), 3 deletions(-)
 create mode 100644 tools/pytorch_metadata.py

diff --git a/Makefile b/Makefile
index eaa0467a042..a227c6d9555 100644
--- a/Makefile
+++ b/Makefile
@@ -34,7 +34,7 @@ update: install
 	@./tools/om schema
 	@./tools/rknn schema
 	@./tools/paddle sync schema
-	@./tools/pytorch sync schema
+	@./tools/pytorch sync schema metadata
 	@./tools/sklearn sync install metadata
 	@./tools/tf sync install schema metadata
 	@./tools/uff schema
diff --git a/source/pytorch-metadata.json b/source/pytorch-metadata.json
index 0ede132f68d..3b79217878d 100755
--- a/source/pytorch-metadata.json
+++ b/source/pytorch-metadata.json
@@ -1185,8 +1185,7 @@
       { "name": "stride", "type": "int64[]" },
       { "name": "padding", "type": "int64[]", "default": 0 },
       { "name": "ceil_mode", "type": "boolean", "default": false },
-      { "name": "count_include_pad", "type": "boolean", "default": true },
-      { "name": "divisor_override" }
+      { "name": "count_include_pad", "type": "boolean", "default": true }
     ],
     "outputs": [
       { "name": "output", "type": "Tensor" }
diff --git a/source/pytorch.py b/source/pytorch.py
index c45a4d0b14d..2e3521ea4b3 100644
--- a/source/pytorch.py
+++ b/source/pytorch.py
@@ -69,6 +69,8 @@ def argument(value):
             return arguments_map[value]
 
         for _ in graph.inputs():
+            # if len(_.uses()) == 0:
+            #     continue
             json_graph['inputs'].append({
                 'name': _.debugName(),
                 'arguments': [ argument(_) ]
@@ -79,6 +81,15 @@ def argument(value):
                 'arguments': [ argument(_) ]
             })
         for node in graph.nodes():
+            # if node.kind() == 'prim::ListConstruct':
+            #     continue
+            # if node.kind() == 'prim::Constant':
+            #     continue
+            # if node.kind() == 'prim::GetAttr':
+            #     continue
+            schema = node.schema() if hasattr(node, 'schema') else None
+            if schema and schema != '(no schema)':
+                schema = Schema(schema)
             json_node = {
                 'type': { 'name': node.kind() },
                 'inputs': [],
@@ -113,3 +124,233 @@ def argument(value):
                         'arguments': [ argument(output_value) ]
                     })
         return json_graph
+
+class Schema: # pylint: disable=too-few-public-methods
+    ''' Function schema parsed from its signature string '''
+    def __init__(self, value):
+        lexer = Schema.Lexer(value)
+        lexer.whitespace(0)
+        self._parse_name(lexer)
+        lexer.whitespace(0)
+        if lexer.kind == '(':
+            self._parse_arguments(lexer)
+        lexer.whitespace(0)
+        lexer.expect('->')
+        lexer.whitespace(0)
+        self._parse_returns(lexer)
+    def _parse_name(self, lexer):
+        ''' Parse the name: an id, an optional :: id part and an optional . id part '''
+        self.name = lexer.expect('id')
+        if lexer.eat(':'):
+            lexer.expect(':')
+            self.name = self.name + '::' + lexer.expect('id')
+        if lexer.eat('.'):
+            self.name = self.name + '.' + lexer.expect('id')
+    def _parse_arguments(self, lexer):
+        ''' Parse the parenthesized argument list, including the * and ... markers '''
+        self.arguments = []
+        self.is_vararg = False
+        self.kwarg_only = False
+        lexer.expect('(')
+        if not lexer.eat(')'):
+            while True:
+                lexer.whitespace(0)
+                if self.is_vararg:
+                    raise Exception()
+                if lexer.eat('*'):
+                    self.kwarg_only = True
+                elif lexer.eat('...'):
+                    self.is_vararg = True
+                else:
+                    self.arguments.append(Schema.Argument(lexer, False, self.kwarg_only))
+                lexer.whitespace(0)
+                if not lexer.eat(','):
+                    break
+            lexer.expect(')')
+    def _parse_returns(self, lexer):
+        ''' Parse the returns: a ... marker, a parenthesized list or a single value '''
+        self.returns = []
+        self.is_varret = False
+        if lexer.eat('...'):
+            self.is_varret = True
+        elif lexer.eat('('):
+            lexer.whitespace(0)
+            if not lexer.eat(')'):
+                while True:
+                    lexer.whitespace(0)
+                    if self.is_varret:
+                        raise Exception()
+                    if lexer.eat('...'):
+                        self.is_varret = True
+                    else:
+                        self.returns.append(Schema.Argument(lexer, True, False))
+                    lexer.whitespace(0)
+                    if not lexer.eat(','):
+                        break
+                lexer.expect(')')
+            lexer.whitespace(0)
+        else:
+            self.returns.append(Schema.Argument(lexer, True, False))
+    class Argument: # pylint: disable=too-few-public-methods
+        ''' Schema argument or return value with its type, optional name and default '''
+        def __init__(self, lexer, is_return, kwarg_only):
+            value = Schema.Type(lexer)
+            if lexer.eat('('):
+                # Skip the parenthesized annotation after the type, failing at end of input.
+                while not lexer.eat(')'):
+                    if lexer.kind == '\0':
+                        raise Exception('Unterminated annotation.')
+                    lexer.next()
+            while True:
+                if lexer.eat('['):
+                    size = None
+                    if lexer.kind == '#':
+                        size = int(lexer.value)
+                        lexer.next()
+                    lexer.expect(']')
+                    # Wrap the type parsed so far, not the type builtin.
+                    value = Schema.ListType(value, size)
+                elif lexer.eat('?'):
+                    value = Schema.OptionalType(value)
+                else:
+                    break
+            self.type = value
+            lexer.whitespace(0)
+            if is_return:
+                self.kwarg_only = False
+                if lexer.kind == 'id':
+                    self.name = lexer.expect('id')
+            else:
+                self.kwarg_only = kwarg_only
+                self.name = lexer.expect('id')
+                lexer.whitespace(0)
+                if lexer.eat('='):
+                    lexer.whitespace(0)
+                    self.default = self._parse_value(lexer)
+        def _parse_value(self, lexer):
+            ''' Parse a default value: a keyword, a number, a string or a list of values '''
+            if lexer.kind == 'id':
+                if lexer.value in ('True', 'False'):
+                    # bool() of any non-empty string is True, so compare the text instead.
+                    value = lexer.value == 'True'
+                elif lexer.value == 'None':
+                    value = None
+                elif lexer.value in ('Mean', 'contiguous_format', 'long'):
+                    value = lexer.value
+                else:
+                    raise Exception()
+            elif lexer.kind == '#':
+                value = float(lexer.value) if \
+                    lexer.value.find('.') != -1 or lexer.value.find('e') != -1 else \
+                    int(lexer.value)
+            elif lexer.kind == 'string':
+                value = lexer.value
+            elif lexer.eat('['):
+                value = []
+                if not lexer.eat(']'):
+                    while True:
+                        lexer.whitespace(0)
+                        value.append(self._parse_value(lexer))
+                        lexer.whitespace(0)
+                        if not lexer.eat(','):
+                            break
+                    lexer.expect(']')
+                return value
+            else:
+                raise Exception()
+            lexer.next()
+            return value
+    class Type: # pylint: disable=too-few-public-methods
+        ''' Named type: one or more identifiers joined by dots '''
+        def __init__(self, lexer):
+            self.name = lexer.expect('id')
+            while lexer.eat('.'):
+                self.name = self.name + '.' + lexer.expect('id')
+    class OptionalType: # pylint: disable=too-few-public-methods
+        ''' Type followed by a ? suffix, wrapping its element type '''
+        def __init__(self, element_type):
+            self.element_type = element_type
+    class ListType: # pylint: disable=too-few-public-methods
+        ''' Type followed by a [] suffix, with the size given in the brackets or None '''
+        def __init__(self, element_type, size):
+            self.element_type = element_type
+            self.size = size
+    class Lexer: # pylint: disable=too-few-public-methods
+        ''' Tokenizer exposing the current token as kind and value '''
+        def __init__(self, buffer):
+            self.buffer = buffer
+            self.position = 0
+            self.value = ''
+            self.next()
+        def eat(self, kind):
+            ''' Consume and return the current token if it is of kind, else return None '''
+            if self.kind != kind:
+                return None
+            value = self.value
+            self.next()
+            return value
+        def expect(self, kind):
+            ''' Consume and return the current token, raising if it is not of kind '''
+            if self.kind != kind:
+                raise Exception('')
+            value = self.value
+            self.next()
+            return value
+        def whitespace(self, count):
+            ''' Consume a run of spaces and return True, else return False or raise if count exceeds the token length '''
+            if self.kind != ' ':
+                if count > len(self.value):
+                    raise Exception('')
+                return False
+            self.next()
+            return True
+        def next(self): # pylint: disable=too-many-branches
+            ''' Advance past the current token and read the next kind and value '''
+            self.position += len(self.value)
+            i = self.position
+            if i >= len(self.buffer):
+                self.kind = '\0'
+                self.value = ''
+            elif self.buffer[i] == ' ':
+                while i < len(self.buffer) and self.buffer[i] == ' ':
+                    i += 1
+                self.kind = ' '
+                self.value = self.buffer[self.position:i]
+            elif self.buffer.startswith('...', i):
+                # Checked before the single character tokens because they include the dot.
+                self.kind = '...'
+                self.value = '...'
+            elif self.buffer[i] in ('(', ')', ':', '.', '[', ']', ',', '=', '?', '!', '*'):
+                self.kind = self.buffer[i]
+                self.value = self.buffer[i]
+            elif (self.buffer[i] >= 'a' and self.buffer[i] <= 'z') or \
+                 (self.buffer[i] >= 'A' and self.buffer[i] <= 'Z') or self.buffer[i] == '_':
+                i += 1
+                while i < len(self.buffer) and \
+                    ((self.buffer[i] >= 'a' and self.buffer[i] <= 'z') or \
+                     (self.buffer[i] >= 'A' and self.buffer[i] <= 'Z') or \
+                     (self.buffer[i] >= '0' and self.buffer[i] <= '9') or self.buffer[i] == '_'):
+                    i += 1
+                self.kind = 'id'
+                self.value = self.buffer[self.position:i]
+            elif self.buffer.startswith('->', i):
+                self.kind = '->'
+                self.value = '->'
+            elif (self.buffer[i] >= '0' and self.buffer[i] <= '9') or self.buffer[i] == '-':
+                i += 1
+                while i < len(self.buffer) and \
+                    ((self.buffer[i] >= '0' and self.buffer[i] <= '9') or \
+                    self.buffer[i] == '.' or self.buffer[i] == 'e' or self.buffer[i] == '-'):
+                    i += 1
+                self.kind = '#'
+                self.value = self.buffer[self.position:i]
+            elif self.buffer[i] in ("'", '"'):
+                quote = self.buffer[i]
+                i += 1
+                while i < len(self.buffer) and self.buffer[i] != quote:
+                    i += 2 if self.buffer[i] == '\\' and self.buffer[i+1:i+2] in ("'", '"', '\\') else 1
+                i += 1
+                self.kind = 'string'
+                self.value = self.buffer[self.position:i]
+            else:
+                raise Exception("Unsupported token at " + str(self.position))
diff --git a/tools/pytorch b/tools/pytorch
index 0fc423984f1..d072c9b449c 100755
--- a/tools/pytorch
+++ b/tools/pytorch
@@ -28,11 +28,21 @@ schema() {
     fi
 }
 
+metadata() {
+    echo "pytorch metadata"
+    [[ $(grep -U $'\x0D' ./source/pytorch-metadata.json) ]] && crlf=1
+    ${python} ./tools/pytorch_metadata.py
+    if [[ -n ${crlf} ]]; then
+        unix2dos --quiet --newfile ./source/pytorch-metadata.json ./source/pytorch-metadata.json
+    fi
+}
+
 while [ "$#" != 0 ]; do
     command="$1" && shift
     case "${command}" in
         "clean") clean;;
         "sync") sync;;
         "schema") schema;;
+        "metadata") metadata;;
     esac
 done
diff --git a/tools/pytorch_metadata.py b/tools/pytorch_metadata.py
new file mode 100644
index 00000000000..8c48652c8f6
--- /dev/null
+++ b/tools/pytorch_metadata.py
@@ -0,0 +1,95 @@
+''' TorchScript metadata script '''
+
+import collections
+import json
+import os
+import re
+import sys
+
+root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+sys.path.append(root_dir)
+sys.pycache_prefix = os.path.join(root_dir, 'dist', 'pycache', 'test', 'backend')
+pytorch = __import__('source.pytorch').pytorch
+
+source_dir = os.path.join(root_dir, 'source')
+third_party_dir = os.path.join(root_dir, 'third_party')
+metadata_file = os.path.join(source_dir, 'pytorch-metadata.json')
+pytorch_source_dir = os.path.join(third_party_dir, 'source', 'pytorch')
+
+def _read(path):
+    with open(path, 'r', encoding='utf-8') as file:
+        return file.read()
+
+def _write(path, content):
+    with open(path, 'w', encoding='utf-8') as file:
+        file.write(content)
+
+def _read_metadata():
+    metadata = json.loads(_read(metadata_file))
+    return dict(map(lambda item: [ item['name'], item ], metadata))
+
+def _write_metadata(value):
+    metadata = list(collections.OrderedDict(sorted(value.items())).values())
+    content = json.dumps(metadata, indent=2, ensure_ascii=False)
+    content = re.sub(r'\s {8}', ' ', content)
+    content = re.sub(r',\s {8}', ', ', content)
+    content = re.sub(r'\s {6}}', ' }', content)
+    _write(metadata_file, content)
+
+schema_source_files = [
+    ('aten/src/ATen/native/native_functions.yaml',
+        re.compile(r'-\s*func:\s*(.*)', re.MULTILINE), 'aten::'),
+    ('aten/src/ATen/native/quantized/library.cpp',
+        re.compile(r'TORCH_SELECTIVE_SCHEMA\("(.*)"\)', re.MULTILINE)),
+    ('aten/src/ATen/native/xnnpack/RegisterOpContextClass.cpp',
+        re.compile(r'TORCH_SELECTIVE_SCHEMA\("(.*)"', re.MULTILINE)),
+    ('torch/csrc/jit/runtime/register_prim_ops_fulljit.cpp',
+        re.compile(r'(aten::.*->\s*Tensor)', re.MULTILINE)),
+    ('torch/csrc/jit/passes/shape_analysis.cpp',
+        re.compile(r'(aten::.*->\s*Tensor)', re.MULTILINE)),
+    ('caffe2/operators/copy_op.cc',
+        re.compile(r'(_caffe2::.*->\s*Tensor)', re.MULTILINE)),
+    ('caffe2/operators/batch_permutation_op.cc',
+        re.compile(r'(_caffe2::.*->\s*Tensor)', re.MULTILINE)),
+    ('caffe2/operators/collect_and_distribute_fpn_rpn_proposals_op.cc',
+        re.compile(r'"(_caffe2::[\w+]*\([\w"\s\[\],]*\)\s*->\s*\([\w"\s\[\],]*\))"', re.MULTILINE)),
+    ('caffe2/operators/box_with_nms_limit_op.cc',
+        re.compile(r'"(_caffe2::[\w+]*\([\w"\s\[\],]*\)\s*->\s*\([\w"\s\[\],]*\))"', re.MULTILINE)),
+    ('caffe2/operators/bbox_transform_op.cc',
+        re.compile(r'"(_caffe2::[\w+]*\([\w"\s\[\],]*\)\s*->\s*\([\w"\s\[\],]*\))"', re.MULTILINE)),
+    ('caffe2/operators/generate_proposals_op.cc',
+        re.compile(r'"(_caffe2::[\w+]*\([\w"\s\[\],]*\)\s*->\s*\([\w"\s\[\],]*\))"', re.MULTILINE)),
+    ('caffe2/operators/roi_align_op.cc',
+        re.compile(r'"(_caffe2::[\w+]*\([\w"\s\[\],]*\)\s*->.*)"', re.MULTILINE))
+]
+
+def _metadata():
+
+    types = _read_metadata()
+
+    _write_metadata(types)
+
+    for key in list(types.keys()):
+        if key.startswith('torch.nn'):
+            types.pop(key)
+
+    for entry in schema_source_files:
+        path = os.path.join(pytorch_source_dir, entry[0])
+        content = _read(path)
+        for value in entry[1].findall(content):
+            value = re.sub(r'\n|\r|\s*"', '', value) if value.startswith('_caffe2::') else value
+            definition = entry[2] + value if len(entry) > 2 else value
+            schema = pytorch.Schema(definition)
+            if schema.name in types:
+                # value = types[schema.name]
+                # if len(schema.arguments) != len(value['inputs']):
+                #     pass
+                types.pop(schema.name)
+
+    print('\n'.join(list(types.keys())))
+
+def main(): # pylint: disable=missing-function-docstring
+    _metadata()
+
+if __name__ == '__main__':
+    main()