From 9477659b43e4121d708dc7289f56b94bb73be1d4 Mon Sep 17 00:00:00 2001 From: Lutz Roeder Date: Sat, 29 Jun 2024 17:04:22 -0700 Subject: [PATCH] Update backend test (#990) --- source/browser.js | 18 ++++++++++++------ source/onnx.py | 40 ++++++++++++++++++++-------------------- source/pytorch.py | 44 +++++++++++++++++++++++--------------------- source/server.js | 32 ++++++++++++++++++-------------- source/server.py | 10 +++++++--- 5 files changed, 80 insertions(+), 64 deletions(-) diff --git a/source/browser.js b/source/browser.js index 5dd8ee0586..96bc3064b9 100644 --- a/source/browser.js +++ b/source/browser.js @@ -148,7 +148,8 @@ host.BrowserHost = class { const [url] = this._meta.file; if (this._view.accept(url)) { const identifier = Array.isArray(this._meta.identifier) && this._meta.identifier.length === 1 ? this._meta.identifier[0] : null; - const status = await this._openModel(this._url(url), identifier || null); + const name = this._meta.name || null; + const status = await this._openModel(this._url(url), identifier || null, name); if (status === '') { return; } @@ -413,7 +414,7 @@ host.BrowserHost = class { return `${location.protocol}//${location.host}${pathname}${file}`; } - async _openModel(url, identifier) { + async _openModel(url, identifier, name) { url = url.startsWith('data:') ? url : `${url + ((/\?/).test(url) ? '&' : '?')}cb=${(new Date()).getTime()}`; this._view.show('welcome spinner'); let context = null; @@ -430,7 +431,7 @@ host.BrowserHost = class { stream = await this._request(url, null, null, progress); } } - context = new host.BrowserHost.Context(this, url, identifier, stream); + context = new host.BrowserHost.Context(this, url, identifier, name, stream); this._telemetry.set('session_engaged', 1); } catch (error) { await this._view.error(error, 'Model load request failed.'); @@ -474,7 +475,7 @@ host.BrowserHost = class { const encoder = new TextEncoder(); const buffer = encoder.encode(file.content); const stream = new base.BinaryStream(buffer); - const context = new host.BrowserHost.Context(this, '', identifier, stream); + const context = new host.BrowserHost.Context(this, '', identifier, null, stream); await this._openContext(context); } catch (error) { await this._view.error(error, 'Error while loading Gist.'); @@ -487,7 +488,7 @@ host.BrowserHost = class { try { const model = await this._view.open(context); if (model) { - this.document.title = context.identifier; + this.document.title = context.name || context.identifier; return ''; } this.document.title = ''; @@ -787,8 +788,9 @@ host.BrowserHost.FileStream = class { host.BrowserHost.Context = class { - constructor(host, url, identifier, stream) { + constructor(host, url, identifier, name, stream) { this._host = host; + this._name = name; this._stream = stream; if (identifier) { this._identifier = identifier; @@ -807,6 +809,10 @@ host.BrowserHost.Context = class { return this._identifier; } + get name() { + return this._name; + } + get stream() { return this._stream; } diff --git a/source/onnx.py b/source/onnx.py index c5f31ce52c..7004e4669a 100644 --- a/source/onnx.py +++ b/source/onnx.py @@ -75,19 +75,19 @@ def _metadata_props(self, metadata_props): # pylint: disable=missing-function-do class _Graph: def __init__(self, graph, metadata): self.metadata = metadata - self.value = graph - self.arguments_index = {} - self.arguments = [] + self.graph = graph + self.values_index = {} + self.values = [] def _tensor(self, tensor): # pylint: disable=unused-argument return {} - def argument(self, name, tensor_type=None, initializer=None): # pylint: disable=missing-function-docstring - if not name in self.arguments_index: - argument = _Argument(name, tensor_type, initializer) - self.arguments_index[name] = len(self.arguments) - self.arguments.append(argument) - index = self.arguments_index[name] + def value(self, name, tensor_type=None, initializer=None): # pylint: disable=missing-function-docstring + if not name in self.values_index: + argument = _Value(name, tensor_type, initializer) + self.values_index[name] = len(self.values) + self.values.append(argument) + index = self.values_index[name] # argument.set_initializer(initializer) return index @@ -138,17 +138,17 @@ def attribute(self, _, op_type): # pylint: disable=missing-function-docstring,to return json_attribute def to_json(self): # pylint: disable=missing-function-docstring - graph = self.value + graph = self.graph json_graph = { 'nodes': [], 'inputs': [], 'outputs': [], - 'arguments': [] + 'values': [] } for value_info in graph.value_info: - self.argument(value_info.name) + self.value(value_info.name) for initializer in graph.initializer: - self.argument(initializer.name, None, initializer) + self.value(initializer.name, None, initializer) for node in graph.node: op_type = node.op_type json_node = {} @@ -164,24 +164,24 @@ def to_json(self): # pylint: disable=missing-function-docstring for value in node.input: json_node['inputs'].append({ 'name': 'X', - 'arguments': [ self.argument(value) ] + 'value': [ self.value(value) ] }) json_node['outputs'] = [] for value in node.output: json_node['outputs'].append({ 'name': 'X', - 'arguments': [ self.argument(value) ] + 'value': [ self.value(value) ] }) json_node['attributes'] = [] for _ in node.attribute: json_attribute = self.attribute(_, op_type) json_node['attributes'].append(json_attribute) json_graph['nodes'].append(json_node) - for _ in self.arguments: - json_graph['arguments'].append(_.to_json()) + for _ in self.values: + json_graph['values'].append(_.to_json()) return json_graph -class _Argument: # pylint: disable=too-few-public-methods +class _Value: # pylint: disable=too-few-public-methods def __init__(self, name, tensor_type=None, initializer=None): self.name = name self.type = tensor_type @@ -190,8 +190,8 @@ def __init__(self, name, tensor_type=None, initializer=None): def to_json(self): # pylint: disable=missing-function-docstring target = {} target['name'] = self.name - if self.initializer: - target['initializer'] = {} + # if self.initializer: + # target['initializer'] = {} return target class _Metadata: # pylint: disable=too-few-public-methods diff --git a/source/pytorch.py b/source/pytorch.py index ee20017908..55953c2779 100644 --- a/source/pytorch.py +++ b/source/pytorch.py @@ -56,7 +56,7 @@ def to_json(self): # pylint: disable=missing-function-docstring,too-many-locals, import torch # pylint: disable=import-outside-toplevel,import-error graph = self.value json_graph = { - 'arguments': [], + 'values': [], 'nodes': [], 'inputs': [], 'outputs': [] @@ -73,59 +73,61 @@ def constant_value(node): selector = node.kindOf('value') return getattr(node, selector)('value') return None - arguments_map = {} + values_index = {} def argument(value): - if not value in arguments_map: - json_argument = {} - json_argument['name'] = str(value.unique()) + if not value in values_index: + json_value = {} + json_value['name'] = str(value.unique()) node = value.node() if node.kind() == "prim::GetAttr": tensor, name = self._getattr(node) if tensor is not None and len(name) > 0 and \ isinstance(tensor, torch.Tensor): - json_argument['name'] = name - json_argument['initializer'] = {} json_tensor_shape = { 'dimensions': list(tensor.shape) } - json_argument['type'] = { + tensor_type = { 'dataType': data_type_map[tensor.dtype], 'shape': json_tensor_shape } + json_value['name'] = name + json_value['type'] = tensor_type + json_value['initializer'] = { 'type': tensor_type } elif node.kind() == "prim::Constant": tensor = constant_value(node) if tensor and isinstance(tensor, torch.Tensor): - json_argument['initializer'] = {} json_tensor_shape = { 'dimensions': list(tensor.shape) } - json_argument['type'] = { + tensor_type = { 'dataType': data_type_map[tensor.dtype], 'shape': json_tensor_shape } + json_value['type'] = tensor_type + json_value['initializer'] = { 'type': tensor_type } elif value.isCompleteTensor(): json_tensor_shape = { 'dimensions': value.type().sizes() } - json_argument['type'] = { + json_value['type'] = { 'dataType': data_type_map[value.type().dtype()], 'shape': json_tensor_shape } - arguments = json_graph['arguments'] - arguments_map[value] = len(arguments) - arguments.append(json_argument) - return arguments_map[value] + values = json_graph['values'] + values_index[value] = len(values) + values.append(json_value) + return values_index[value] for value in graph.inputs(): if len(value.uses()) != 0 and value.type().kind() != 'ClassType': json_graph['inputs'].append({ 'name': value.debugName(), - 'arguments': [ argument(value) ] + 'value': [ argument(value) ] }) for value in graph.outputs(): json_graph['outputs'].append({ 'name': value.debugName(), - 'arguments': [ argument(value) ] + 'value': [ argument(value) ] }) constants = {} for node in graph.nodes(): @@ -163,7 +165,7 @@ def create_node(node): if torch.is_tensor(value): json_node['inputs'].append({ 'name': name, - 'arguments': [] + 'value': [] }) else: json_node['attributes'].append(json_attribute) @@ -177,7 +179,7 @@ def create_node(node): if parameter_type == 'Tensor' or value.type().kind() == 'TensorType': json_node['inputs'].append({ 'name': parameter_name, - 'arguments': [ argument(value) ] + 'value': [ argument(value) ] }) else: json_attribute = { @@ -203,7 +205,7 @@ def create_node(node): continue json_node['inputs'].append({ 'name': parameter_name, - 'arguments': [ argument(value) ] + 'value': [ argument(value) ] }) for i, value in enumerate(node.outputs()): @@ -211,7 +213,7 @@ def create_node(node): name = parameter['name'] if parameter and 'name' in parameter else 'output' json_node['outputs'].append({ 'name': name, - 'arguments': [ argument(value) ] + 'value': [ argument(value) ] }) for node in graph.nodes(): diff --git a/source/server.js b/source/server.js index e6eab129cb..34c23479dc 100644 --- a/source/server.js +++ b/source/server.js @@ -47,25 +47,25 @@ message.Graph = class { this.inputs = []; this.outputs = []; this.nodes = []; - const args = data.arguments ? data.arguments.map((argument) => new message.Value(argument)) : []; - for (const parameter of data.inputs || []) { - parameter.arguments = parameter.arguments.map((index) => args[index]).filter((argument) => !argument.initializer); - if (parameter.arguments.filter((argument) => !argument.initializer).length > 0) { - this.inputs.push(new message.Argument(parameter)); + const values = data.values ? data.values.map((value) => new message.Value(value)) : []; + for (const argument of data.inputs || []) { + argument.value = argument.value.map((index) => values[index]).filter((argument) => !argument.initializer); + if (argument.value.filter((argument) => !argument.initializer).length > 0) { + this.inputs.push(new message.Argument(argument)); } } - for (const parameter of data.outputs || []) { - parameter.arguments = parameter.arguments.map((index) => args[index]); - if (parameter.arguments.filter((argument) => !argument.initializer).length > 0) { - this.outputs.push(new message.Argument(parameter)); + for (const argument of data.outputs || []) { + argument.value = argument.value.map((index) => values[index]); + if (argument.value.filter((argument) => !argument.initializer).length > 0) { + this.outputs.push(new message.Argument(argument)); } } for (const node of data.nodes || []) { - for (const parameter of node.inputs || []) { - parameter.arguments = parameter.arguments.map((index) => args[index]); + for (const argument of node.inputs || []) { + argument.value = argument.value.map((index) => values[index]); } - for (const parameter of node.outputs || []) { - parameter.arguments = parameter.arguments.map((index) => args[index]); + for (const argument of node.outputs || []) { + argument.value = argument.value.map((index) => values[index]); } this.nodes.push(new message.Node(node)); } @@ -76,7 +76,7 @@ message.Argument = class { constructor(data) { this.name = data.name || ''; - this.value = (data.arguments || []); + this.value = data.value || []; this.type = data.type || ''; } }; @@ -129,6 +129,10 @@ message.TensorShape = class { }; message.Tensor = class { + + constructor(data) { + this.type = new message.TensorType(data.type); + } }; message.Error = class extends Error { diff --git a/source/server.py b/source/server.py index 7c1f53ca8e..cb65d6ed14 100755 --- a/source/server.py +++ b/source/server.py @@ -22,9 +22,10 @@ class _ContentProvider: # pylint: disable=too-few-public-methods base_dir = '' base = '' identifier = '' - def __init__(self, data, path, file): + def __init__(self, data, path, file, name): self.data = data if data else bytearray() self.identifier = os.path.basename(file) if file else '' + self.name = name if path: self.dir = os.path.dirname(path) if os.path.dirname(path) else '.' self.base = os.path.basename(path) @@ -95,6 +96,9 @@ def do_GET(self): # pylint: disable=invalid-name base = self.content.base if base: meta.append('') + name = self.content.name + if name: + meta.append('') identifier = self.content.identifier if identifier: meta.append('') @@ -281,14 +285,14 @@ def serve(file, data, address=None, browse=False, verbosity=1): if not data and file and not os.path.exists(file): raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), file) - content = _ContentProvider(data, file, file) + content = _ContentProvider(data, file, file, file) if data and not isinstance(data, bytearray) and isinstance(data.__class__, type): _log(verbosity > 1, 'Experimental\n') model = _open(data) if model: text = json.dumps(model.to_json(), indent=4, ensure_ascii=False) - content = _ContentProvider(text.encode('utf-8'), 'model.netron', file) + content = _ContentProvider(text.encode('utf-8'), 'model.netron', None, file) address = _make_address(address) if isinstance(address[1], int) and address[1] != 0: