forked from onnx/onnx
-
Notifications
You must be signed in to change notification settings - Fork 0
/
checker.py
85 lines (60 loc) · 2.59 KB
/
checker.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
"""onnx checker
This implements graphalities that allows us to check whether a serialized
proto is legal.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import functools
from onnx import (ValueInfoProto,
AttributeProto,
TensorProto,
NodeProto,
ModelProto,
GraphProto,
IR_VERSION)
import onnx.onnx_cpp2py_export.checker as C
import onnx.defs
from google.protobuf.message import Message
from typing import TypeVar, Callable, Any, Type, cast
# TODO: This thing where we reserialize the protobuf back into the
# string, only to deserialize it at the call site, is really goofy.
# Stop doing that.

# Shared checker context: it is the default value of the ``ctx`` parameter of
# the check_* functions in this module. Because default arguments are bound
# once, mutating this object would silently change every later call that
# relies on the default.
# NB: Please don't edit this context!
DEFAULT_CONTEXT = C.CheckerContext()
# Check against the IR version this onnx package was built with.
DEFAULT_CONTEXT.ir_version = IR_VERSION
# TODO: Maybe ONNX-ML should also be defaulted?
# Only the default ('' = empty-string) domain is imported, at the version
# reported by onnx.defs.onnx_opset_version() (presumably the newest opset
# this build supports -- confirm in onnx.defs).
DEFAULT_CONTEXT.opset_imports = {'': onnx.defs.onnx_opset_version()}
# Static type of the stub functions being decorated. Bounding it to callables
# lets the decorator return "the same type it was given" to type checkers.
FuncType = TypeVar('FuncType', bound=Callable[..., Any])


# TODO: This really doesn't seem worth the metaprogramming...
def _create_checker(proto_type):  # type: (Type[Message]) -> Callable[[FuncType], FuncType]
    """Return a decorator that turns a Python stub into a native checker call.

    The decorated function's body is never run. The wrapper that replaces it
    first verifies that its argument is an instance of ``proto_type`` (raising
    ``RuntimeError`` otherwise). It then looks up the function with the stub's
    name in the native ``onnx.onnx_cpp2py_export.checker`` module and calls it
    with the serialized proto and the checker context.

    Note: the wrapper's first parameter is always named ``proto``, whatever the
    stub itself calls it.
    """
    def decorator(py_func):  # type: (FuncType) -> FuncType
        @functools.wraps(py_func)
        def checker(proto, ctx=DEFAULT_CONTEXT):  # type: (Message, C.CheckerContext) -> Any
            if isinstance(proto, proto_type):
                # Resolve the native function by name on each call, then hand
                # it the wire-format bytes.
                native_check = getattr(C, py_func.__name__)
                return native_check(proto.SerializeToString(), ctx)
            # Wrong message type: refuse rather than let the native side parse
            # bytes that belong to some other message.
            raise RuntimeError(
                'You cannot pass an object that is not of type {}'.format(
                    proto_type.__name__))

        return cast(FuncType, checker)

    return decorator
@_create_checker(ValueInfoProto)
def check_value_info(value_info, ctx=DEFAULT_CONTEXT):  # type: (ValueInfoProto, C.CheckerContext) -> None
    """Check that a ``ValueInfoProto`` is legal.

    This body is never executed. The ``_create_checker`` decorator replaces it
    with a call to the native ``check_value_info``, passing the serialized proto
    and ``ctx``. A non-``ValueInfoProto`` argument raises ``RuntimeError``.
    Validation failures come from the native checker (presumably as
    ``ValidationError``).
    """
@_create_checker(TensorProto)
def check_tensor(tensor, ctx=DEFAULT_CONTEXT):  # type: (TensorProto, C.CheckerContext) -> None
    """Check that a ``TensorProto`` is legal.

    This body is never executed. The ``_create_checker`` decorator replaces it
    with a call to the native ``check_tensor``, passing the serialized proto and
    ``ctx``. A non-``TensorProto`` argument raises ``RuntimeError``. Validation
    failures come from the native checker (presumably as ``ValidationError``).
    """
@_create_checker(AttributeProto)
def check_attribute(attr, ctx=DEFAULT_CONTEXT):  # type: (AttributeProto, C.CheckerContext) -> None
    """Check that an ``AttributeProto`` is legal.

    This body is never executed. The ``_create_checker`` decorator replaces it
    with a call to the native ``check_attribute``, passing the serialized proto
    and ``ctx``. A non-``AttributeProto`` argument raises ``RuntimeError``.
    Validation failures come from the native checker (presumably as
    ``ValidationError``).
    """
@_create_checker(NodeProto)
def check_node(node, ctx=DEFAULT_CONTEXT):  # type: (NodeProto, C.CheckerContext) -> None
    """Check that a ``NodeProto`` is legal.

    This body is never executed. The ``_create_checker`` decorator replaces it
    with a call to the native ``check_node``, passing the serialized proto and
    ``ctx``. A non-``NodeProto`` argument raises ``RuntimeError``. Validation
    failures come from the native checker (presumably as ``ValidationError``).
    """
@_create_checker(GraphProto)
def check_graph(graph, ctx=DEFAULT_CONTEXT):  # type: (GraphProto, C.CheckerContext) -> None
    """Check that a ``GraphProto`` is legal.

    This body is never executed. The ``_create_checker`` decorator replaces it
    with a call to the native ``check_graph``, passing the serialized proto and
    ``ctx``. A non-``GraphProto`` argument raises ``RuntimeError``. Validation
    failures come from the native checker (presumably as ``ValidationError``).
    """
def check_model(model):  # type: (ModelProto) -> None
    """Check that a whole ONNX model is legal.

    Unlike the other check_* functions, this one takes no checker context; the
    native ``check_model`` is called with the serialized model only.

    Args:
        model: a ``ModelProto``, or the ``bytes`` produced by serializing one.

    Raises:
        RuntimeError: if ``model`` is neither a ``ModelProto`` nor ``bytes``.
            This mirrors the type check every other checker in this module
            performs, instead of letting another message type (e.g. a
            ``GraphProto``) be serialized and parsed as a model.
    """
    if isinstance(model, bytes):
        # Already in wire format; no need to round-trip through a proto.
        serialized = model
    elif isinstance(model, ModelProto):
        serialized = model.SerializeToString()
    else:
        raise RuntimeError(
            'You cannot pass an object that is not of type {}'.format(
                ModelProto.__name__))
    C.check_model(serialized)
# Public alias for the native checker module's ValidationError (presumably the
# exception the C++ checks raise on an illegal proto -- confirm in the
# extension). It is re-exported here so callers can use
# onnx.checker.ValidationError without importing
# onnx.onnx_cpp2py_export.checker themselves.
ValidationError = C.ValidationError