Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add basic tensorflow subgraph (function) support #318

Merged
merged 1 commit into from
Aug 14, 2019
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
218 changes: 217 additions & 1 deletion src/tf.js
Original file line number Diff line number Diff line change
Expand Up @@ -228,8 +228,23 @@ tf.Model = class {
else if (model.meta_graphs.length > 1) {
name = i.toString();
}
else {
name = 'main';
}
this._graphs.push(new tf.Graph(metadata, metaGraph, name));
}

// Recursively add all subgraphs.
let visited_graph = [];
let pending_graphs = [...this._graphs];
while (pending_graphs.length > 0) {
let g = pending_graphs.shift();
visited_graph.push(g);
for (let f of g.functions)
pending_graphs.push(f);
}
this._graphs = visited_graph;

this._activeGraph = (this._graphs.length > 0) ? this._graphs[0] : null;
}

Expand Down Expand Up @@ -260,6 +275,8 @@ tf.Graph = class {
this._inputs = [];
this._outputs = [];
this._nodes = [];
this._functions = [];

if (metaGraph.graph_def) {
var graph = metaGraph.graph_def;
if (graph.versions) {
Expand Down Expand Up @@ -356,7 +373,7 @@ tf.Graph = class {
var shape = node.attr.shape;
if (dtype && dtype.type && shape && shape.shape) {
var type = new tf.TensorType(dtype.type, shape.shape);
var argument = new tf.Argument(node.output[0], type, null);
var argument = new tf.Argument(node.output[0], type, null);
inputMap[node.output[0]] = new tf.Parameter(node.name, [ argument ]);
}
}
Expand All @@ -371,6 +388,13 @@ tf.Graph = class {
}
}
}

if (graph.library) {
var funcs = graph.library.function;
for (var func of funcs) {
this._functions.push(new tf.Function(this, func, this._metadata));
}
}
}
}

Expand Down Expand Up @@ -411,6 +435,10 @@ tf.Graph = class {
return this._namespaces;
}

get functions() {
return this._functions;
}

_checkSingleOutput(node) {
if (node.output.length != 1) {
return false;
Expand Down Expand Up @@ -466,6 +494,189 @@ tf.Argument = class {
}
};

// Technically it is used as graph subclass...
tf.Function = class {
constructor(graph, func, metadata) {
this._name = func.signature.name;
this._version = null;
this._tags = null;
this._inputs = [];
this._outputs = [];
this._nodes = [];
this._metadata = metadata;
this._namespaces = {};
this._functions = [];

var in_args = func.signature.input_arg;
if (in_args) {
in_args.forEach((arg, indx) => {
var dtype = new tf.TensorType(arg.type, null);
var a = new tf.Argument(arg.name, dtype, null);
var p = new tf.Parameter(arg.name, [a])
this._inputs.push(p);
});
}

var ret_map = {};
for (let k of Object.keys(func.ret)) {
let v = func.ret[k].split(':', 2);
ret_map[k] = v[0];
}

var out_args_reverse_map = {};
var out_args = func.signature.output_arg;
if (out_args) {
out_args.forEach((arg, indx) => {
let dtype = new tf.TensorType(arg.type, null);
let name = ret_map[arg.name];
let a = new tf.Argument(name, dtype, null);
let p = new tf.Parameter(arg.name, [a])
this._outputs.push(p);
out_args_reverse_map[name] = arg.name;
});
}

var nodes = func.node_def;
if (nodes) {
var nodeMap = {};

for (let node of nodes) {
let nodeName = node.name;
nodeMap[nodeName] = node;
if (node.op != 'Const') {
let lastIndex = nodeName.lastIndexOf('/');
if (lastIndex != -1) {
let namespace = nodeName.substring(0, lastIndex);
this._namespaces[namespace] = true;
}
}
node.output = [];
}
for (let node of nodes) {
let inputs = node.input;
node.input = [];
node.controlDependencies = [];
for (let input of inputs) {
var split = input.split(':', 3);
var inputName = split[0];
var outputIndex = split.length == 1 ? 0 : parseInt(split[split.length - 1]);
var outputName = inputName.startsWith('^') ? inputName.substring(1) : inputName;
var outputNode = nodeMap[outputName];
outputName = outputIndex == 0 ? outputName : outputName + ':' + outputIndex.toString();
if (inputName.startsWith('^')) {
node.controlDependencies.push(outputName);
}
else {
node.input.push(outputName);
}
if (outputNode) {
for (var j = outputNode.output.length; j <= outputIndex; j++) {
outputNode.output.push('');
}
outputNode.output[outputIndex] = outputName;
}
}

if (out_args_reverse_map[node.name]) {
node.output.push(node.name);
}
}

let nodeOutputCountMap = {};
for (let node of nodes) {
for (let input of node.input) {
nodeOutputCountMap[input] = (nodeOutputCountMap[input] || 0) + 1;
}
for (let controlDependency of node.controlDependencies) {
nodeOutputCountMap[controlDependency] = (nodeOutputCountMap[controlDependency] || 0) + 1;
}
}

function _checkSingleOutput(node) {
if (node.output.length != 1) {
return false;
}
var output = node.output[0];
var count = nodeOutputCountMap[output];
if (count != 1) {
return false;
}
return true;
}

var initializers = {};
for (let node of nodes) {
if (node.op == 'Const' && node.input.length == 0 && node.controlDependencies.length == 0 && _checkSingleOutput(node)) {
var value = node.attr.value;
if (value && Object.prototype.hasOwnProperty.call(value, 'tensor')) {
var output = node.output[0];
if (output) {
initializers[output] = new tf.Tensor(value.tensor, node.name, 'Constant');
}
}
}
}
for (let node of nodes) {
if (node.op == 'Identity' && node.input.length == 1 && node.controlDependencies.length == 0 && _checkSingleOutput(node)) {
var initializer_name = node.input[0];
var initializer = initializers[initializer_name];
if (initializer) {
initializers[initializer_name] = "-";
initializer.kind = 'Identity Constant';
initializers[node.output[0]] = initializer;
}
}
}

for (let node of nodes) {
if (!initializers[node.name])
this._nodes.push(new tf.Node(this, node, initializers));
}
}
}

get name() {
return this._name;
}

get version() {
return this._version;
}

get tags() {
return this._tags;
}

get groups() {
return false;
// TODO return true;
}

get inputs() {
return this._inputs;
}

get outputs() {
return this._outputs;
}

get nodes() {
return this._nodes;
}

get metadata() {
return this._metadata;
}

get namespaces() {
return this._namespaces;
}

get functions() {
return this._functions;
}
}

tf.Node = class {

constructor(graph, node, initializers) {
Expand Down Expand Up @@ -744,6 +955,11 @@ tf.Attribute = class {
this._value = list.shape.map((shape) => new tf.TensorShape(shape));
}
}
else if (Object.prototype.hasOwnProperty.call(value, 'func')) {
var func = value.func;
this._type = 'function';
this._value = func.name;
}

if (schema) {
if (Object.prototype.hasOwnProperty.call(schema, 'visible') && !schema.visible) {
Expand Down