diff --git a/source/browser.js b/source/browser.js
index 5dd8ee0586..96bc3064b9 100644
--- a/source/browser.js
+++ b/source/browser.js
@@ -148,7 +148,8 @@ host.BrowserHost = class {
const [url] = this._meta.file;
if (this._view.accept(url)) {
const identifier = Array.isArray(this._meta.identifier) && this._meta.identifier.length === 1 ? this._meta.identifier[0] : null;
- const status = await this._openModel(this._url(url), identifier || null);
+ const name = this._meta.name || null;
+ const status = await this._openModel(this._url(url), identifier || null, name);
if (status === '') {
return;
}
@@ -413,7 +414,7 @@ host.BrowserHost = class {
return `${location.protocol}//${location.host}${pathname}${file}`;
}
- async _openModel(url, identifier) {
+ async _openModel(url, identifier, name) {
url = url.startsWith('data:') ? url : `${url + ((/\?/).test(url) ? '&' : '?')}cb=${(new Date()).getTime()}`;
this._view.show('welcome spinner');
let context = null;
@@ -430,7 +431,7 @@ host.BrowserHost = class {
stream = await this._request(url, null, null, progress);
}
}
- context = new host.BrowserHost.Context(this, url, identifier, stream);
+ context = new host.BrowserHost.Context(this, url, identifier, name, stream);
this._telemetry.set('session_engaged', 1);
} catch (error) {
await this._view.error(error, 'Model load request failed.');
@@ -474,7 +475,7 @@ host.BrowserHost = class {
const encoder = new TextEncoder();
const buffer = encoder.encode(file.content);
const stream = new base.BinaryStream(buffer);
- const context = new host.BrowserHost.Context(this, '', identifier, stream);
+ const context = new host.BrowserHost.Context(this, '', identifier, null, stream);
await this._openContext(context);
} catch (error) {
await this._view.error(error, 'Error while loading Gist.');
@@ -487,7 +488,7 @@ host.BrowserHost = class {
try {
const model = await this._view.open(context);
if (model) {
- this.document.title = context.identifier;
+ this.document.title = context.name || context.identifier;
return '';
}
this.document.title = '';
@@ -787,8 +788,9 @@ host.BrowserHost.FileStream = class {
host.BrowserHost.Context = class {
- constructor(host, url, identifier, stream) {
+ constructor(host, url, identifier, name, stream) {
this._host = host;
+ this._name = name;
this._stream = stream;
if (identifier) {
this._identifier = identifier;
@@ -807,6 +809,10 @@ host.BrowserHost.Context = class {
return this._identifier;
}
+ get name() {
+ return this._name;
+ }
+
get stream() {
return this._stream;
}
diff --git a/source/onnx.py b/source/onnx.py
index c5f31ce52c..7004e4669a 100644
--- a/source/onnx.py
+++ b/source/onnx.py
@@ -75,19 +75,19 @@ def _metadata_props(self, metadata_props): # pylint: disable=missing-function-do
class _Graph:
def __init__(self, graph, metadata):
self.metadata = metadata
- self.value = graph
- self.arguments_index = {}
- self.arguments = []
+ self.graph = graph
+ self.values_index = {}
+ self.values = []
def _tensor(self, tensor): # pylint: disable=unused-argument
return {}
- def argument(self, name, tensor_type=None, initializer=None): # pylint: disable=missing-function-docstring
- if not name in self.arguments_index:
- argument = _Argument(name, tensor_type, initializer)
- self.arguments_index[name] = len(self.arguments)
- self.arguments.append(argument)
- index = self.arguments_index[name]
+ def value(self, name, tensor_type=None, initializer=None): # pylint: disable=missing-function-docstring
+ if not name in self.values_index:
+ argument = _Value(name, tensor_type, initializer)
+ self.values_index[name] = len(self.values)
+ self.values.append(argument)
+ index = self.values_index[name]
# argument.set_initializer(initializer)
return index
@@ -138,17 +138,17 @@ def attribute(self, _, op_type): # pylint: disable=missing-function-docstring,to
return json_attribute
def to_json(self): # pylint: disable=missing-function-docstring
- graph = self.value
+ graph = self.graph
json_graph = {
'nodes': [],
'inputs': [],
'outputs': [],
- 'arguments': []
+ 'values': []
}
for value_info in graph.value_info:
- self.argument(value_info.name)
+ self.value(value_info.name)
for initializer in graph.initializer:
- self.argument(initializer.name, None, initializer)
+ self.value(initializer.name, None, initializer)
for node in graph.node:
op_type = node.op_type
json_node = {}
@@ -164,24 +164,24 @@ def to_json(self): # pylint: disable=missing-function-docstring
for value in node.input:
json_node['inputs'].append({
'name': 'X',
- 'arguments': [ self.argument(value) ]
+ 'value': [ self.value(value) ]
})
json_node['outputs'] = []
for value in node.output:
json_node['outputs'].append({
'name': 'X',
- 'arguments': [ self.argument(value) ]
+ 'value': [ self.value(value) ]
})
json_node['attributes'] = []
for _ in node.attribute:
json_attribute = self.attribute(_, op_type)
json_node['attributes'].append(json_attribute)
json_graph['nodes'].append(json_node)
- for _ in self.arguments:
- json_graph['arguments'].append(_.to_json())
+ for _ in self.values:
+ json_graph['values'].append(_.to_json())
return json_graph
-class _Argument: # pylint: disable=too-few-public-methods
+class _Value: # pylint: disable=too-few-public-methods
def __init__(self, name, tensor_type=None, initializer=None):
self.name = name
self.type = tensor_type
@@ -190,8 +190,8 @@ def __init__(self, name, tensor_type=None, initializer=None):
def to_json(self): # pylint: disable=missing-function-docstring
target = {}
target['name'] = self.name
- if self.initializer:
- target['initializer'] = {}
+ # if self.initializer:
+ # target['initializer'] = {}
return target
class _Metadata: # pylint: disable=too-few-public-methods
diff --git a/source/pytorch.py b/source/pytorch.py
index ee20017908..55953c2779 100644
--- a/source/pytorch.py
+++ b/source/pytorch.py
@@ -56,7 +56,7 @@ def to_json(self): # pylint: disable=missing-function-docstring,too-many-locals,
import torch # pylint: disable=import-outside-toplevel,import-error
graph = self.value
json_graph = {
- 'arguments': [],
+ 'values': [],
'nodes': [],
'inputs': [],
'outputs': []
@@ -73,59 +73,61 @@ def constant_value(node):
selector = node.kindOf('value')
return getattr(node, selector)('value')
return None
- arguments_map = {}
+ values_index = {}
def argument(value):
- if not value in arguments_map:
- json_argument = {}
- json_argument['name'] = str(value.unique())
+ if not value in values_index:
+ json_value = {}
+ json_value['name'] = str(value.unique())
node = value.node()
if node.kind() == "prim::GetAttr":
tensor, name = self._getattr(node)
if tensor is not None and len(name) > 0 and \
isinstance(tensor, torch.Tensor):
- json_argument['name'] = name
- json_argument['initializer'] = {}
json_tensor_shape = {
'dimensions': list(tensor.shape)
}
- json_argument['type'] = {
+ tensor_type = {
'dataType': data_type_map[tensor.dtype],
'shape': json_tensor_shape
}
+ json_value['name'] = name
+ json_value['type'] = tensor_type
+ json_value['initializer'] = { 'type': tensor_type }
elif node.kind() == "prim::Constant":
tensor = constant_value(node)
if tensor and isinstance(tensor, torch.Tensor):
- json_argument['initializer'] = {}
json_tensor_shape = {
'dimensions': list(tensor.shape)
}
- json_argument['type'] = {
+ tensor_type = {
'dataType': data_type_map[tensor.dtype],
'shape': json_tensor_shape
}
+ json_value['type'] = tensor_type
+ json_value['initializer'] = { 'type': tensor_type }
elif value.isCompleteTensor():
json_tensor_shape = {
'dimensions': value.type().sizes()
}
- json_argument['type'] = {
+ json_value['type'] = {
'dataType': data_type_map[value.type().dtype()],
'shape': json_tensor_shape
}
- arguments = json_graph['arguments']
- arguments_map[value] = len(arguments)
- arguments.append(json_argument)
- return arguments_map[value]
+ values = json_graph['values']
+ values_index[value] = len(values)
+ values.append(json_value)
+ return values_index[value]
for value in graph.inputs():
if len(value.uses()) != 0 and value.type().kind() != 'ClassType':
json_graph['inputs'].append({
'name': value.debugName(),
- 'arguments': [ argument(value) ]
+ 'value': [ argument(value) ]
})
for value in graph.outputs():
json_graph['outputs'].append({
'name': value.debugName(),
- 'arguments': [ argument(value) ]
+ 'value': [ argument(value) ]
})
constants = {}
for node in graph.nodes():
@@ -163,7 +165,7 @@ def create_node(node):
if torch.is_tensor(value):
json_node['inputs'].append({
'name': name,
- 'arguments': []
+ 'value': []
})
else:
json_node['attributes'].append(json_attribute)
@@ -177,7 +179,7 @@ def create_node(node):
if parameter_type == 'Tensor' or value.type().kind() == 'TensorType':
json_node['inputs'].append({
'name': parameter_name,
- 'arguments': [ argument(value) ]
+ 'value': [ argument(value) ]
})
else:
json_attribute = {
@@ -203,7 +205,7 @@ def create_node(node):
continue
json_node['inputs'].append({
'name': parameter_name,
- 'arguments': [ argument(value) ]
+ 'value': [ argument(value) ]
})
for i, value in enumerate(node.outputs()):
@@ -211,7 +213,7 @@ def create_node(node):
name = parameter['name'] if parameter and 'name' in parameter else 'output'
json_node['outputs'].append({
'name': name,
- 'arguments': [ argument(value) ]
+ 'value': [ argument(value) ]
})
for node in graph.nodes():
diff --git a/source/server.js b/source/server.js
index e6eab129cb..34c23479dc 100644
--- a/source/server.js
+++ b/source/server.js
@@ -47,25 +47,25 @@ message.Graph = class {
this.inputs = [];
this.outputs = [];
this.nodes = [];
- const args = data.arguments ? data.arguments.map((argument) => new message.Value(argument)) : [];
- for (const parameter of data.inputs || []) {
- parameter.arguments = parameter.arguments.map((index) => args[index]).filter((argument) => !argument.initializer);
- if (parameter.arguments.filter((argument) => !argument.initializer).length > 0) {
- this.inputs.push(new message.Argument(parameter));
+ const values = data.values ? data.values.map((value) => new message.Value(value)) : [];
+ for (const argument of data.inputs || []) {
+ argument.value = argument.value.map((index) => values[index]).filter((argument) => !argument.initializer);
+ if (argument.value.filter((argument) => !argument.initializer).length > 0) {
+ this.inputs.push(new message.Argument(argument));
}
}
- for (const parameter of data.outputs || []) {
- parameter.arguments = parameter.arguments.map((index) => args[index]);
- if (parameter.arguments.filter((argument) => !argument.initializer).length > 0) {
- this.outputs.push(new message.Argument(parameter));
+ for (const argument of data.outputs || []) {
+ argument.value = argument.value.map((index) => values[index]);
+ if (argument.value.filter((argument) => !argument.initializer).length > 0) {
+ this.outputs.push(new message.Argument(argument));
}
}
for (const node of data.nodes || []) {
- for (const parameter of node.inputs || []) {
- parameter.arguments = parameter.arguments.map((index) => args[index]);
+ for (const argument of node.inputs || []) {
+ argument.value = argument.value.map((index) => values[index]);
}
- for (const parameter of node.outputs || []) {
- parameter.arguments = parameter.arguments.map((index) => args[index]);
+ for (const argument of node.outputs || []) {
+ argument.value = argument.value.map((index) => values[index]);
}
this.nodes.push(new message.Node(node));
}
@@ -76,7 +76,7 @@ message.Argument = class {
constructor(data) {
this.name = data.name || '';
- this.value = (data.arguments || []);
+ this.value = data.value || [];
this.type = data.type || '';
}
};
@@ -129,6 +129,10 @@ message.TensorShape = class {
};
message.Tensor = class {
+
+ constructor(data) {
+ this.type = new message.TensorType(data.type);
+ }
};
message.Error = class extends Error {
diff --git a/source/server.py b/source/server.py
index 7c1f53ca8e..cb65d6ed14 100755
--- a/source/server.py
+++ b/source/server.py
@@ -22,9 +22,10 @@ class _ContentProvider: # pylint: disable=too-few-public-methods
base_dir = ''
base = ''
identifier = ''
- def __init__(self, data, path, file):
+ def __init__(self, data, path, file, name):
self.data = data if data else bytearray()
self.identifier = os.path.basename(file) if file else ''
+ self.name = name
if path:
self.dir = os.path.dirname(path) if os.path.dirname(path) else '.'
self.base = os.path.basename(path)
@@ -95,6 +96,9 @@ def do_GET(self): # pylint: disable=invalid-name
base = self.content.base
if base:
meta.append('')
+ name = self.content.name
+ if name:
+ meta.append('')
identifier = self.content.identifier
if identifier:
meta.append('')
@@ -281,14 +285,14 @@ def serve(file, data, address=None, browse=False, verbosity=1):
if not data and file and not os.path.exists(file):
raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), file)
- content = _ContentProvider(data, file, file)
+ content = _ContentProvider(data, file, file, file)
if data and not isinstance(data, bytearray) and isinstance(data.__class__, type):
_log(verbosity > 1, 'Experimental\n')
model = _open(data)
if model:
text = json.dumps(model.to_json(), indent=4, ensure_ascii=False)
- content = _ContentProvider(text.encode('utf-8'), 'model.netron', file)
+ content = _ContentProvider(text.encode('utf-8'), 'model.netron', None, file)
address = _make_address(address)
if isinstance(address[1], int) and address[1] != 0: