From c5ba169558034c1104a22e354cdfa37f78e7bdb4 Mon Sep 17 00:00:00 2001 From: Lutz Roeder Date: Sun, 28 Oct 2018 03:19:22 -0700 Subject: [PATCH] TensorFlow shape class --- src/tf-model.js | 237 +++++++++++++++++++++++++----------------------- 1 file changed, 123 insertions(+), 114 deletions(-) diff --git a/src/tf-model.js b/src/tf-model.js index b370343551..23e78e8e9a 100644 --- a/src/tf-model.js +++ b/src/tf-model.js @@ -521,7 +521,8 @@ class TensorFlowNode { } get category() { - return this._graph.metadata.getOperatorCategory(this.operator); + var schema = this._graph.metadata.getSchema(this.operator); + return (schema && schema.category) ? schema.category : null; } get inputs() { @@ -565,19 +566,18 @@ class TensorFlowAttribute { this._name = name; this._value = null; this._type = null; + var schema = metadata.getAttributeSchema(operator, name); if (value.hasOwnProperty('tensor')) { this._type = new TensorFlowTensor(value.tensor).type; this._tensor = value.tensor.tensor_shape && value.tensor.tensor_shape.dim && value.tensor.tensor_shape.dim.length > 0; } - else { - var schema = metadata.getAttributeSchema(operator, name); - if (schema && schema.type) { - this._type = schema.type; - } + else if (schema && schema.type) { + this._type = schema.type; } if (value.hasOwnProperty('type')) { this._value = () => TensorFlowTensor.formatDataType(value.type); - } + this._type = 'type'; + } else if (value.hasOwnProperty('i')) { this._value = value.i; } @@ -588,7 +588,8 @@ class TensorFlowAttribute { this._value = value.b; } else if (value.hasOwnProperty('shape')) { - this._value = () => TensorFlowTensor.formatTensorShape(value.shape); + this._type = 'shape'; + this._value = new TensorFlowTensorShape(value.shape); } else if (value.hasOwnProperty('s')) { if (value.s.filter(c => c <= 32 && c >= 128).length == 0) { @@ -638,7 +639,8 @@ class TensorFlowAttribute { this._value = () => '...'; } else { - this._value = () => list.type.map((type) => TensorFlowTensor.formatDataType(type)); + this._value = list.type.map((type) => TensorFlowTensor.formatDataType(type)); + this._type = 'type[]'; } } else if (list.shape && list.shape.length > 0) { @@ -646,14 +648,37 @@ class TensorFlowAttribute { this._value = () => '...'; } else { - this._value = () => - list.shape.map((shape) => TensorFlowTensor.formatTensorShape(shape)).toString(); + this._value = list.shape.map((shape) => new TensorFlowTensorShape(shape)); + this._type = 'shape[]'; } } + } - if (!metadata.getAttributeVisible(operator, name, this._value)) { + if (schema) { + if (schema.hasOwnProperty('visible') && !attributeSchema.visible) { this._visible = false; } + else if (schema.hasOwnProperty('default')) { + var valueText = TensorFlowGraphOperatorMetadata._formatAttributeValue(this._value); + var defaultValueText = TensorFlowGraphOperatorMetadata._formatAttributeValue(schema.default); + if (JSON.stringify(valueText) == JSON.stringify(defaultValueText)) { + this._visible = false; + } + } + } + if (name == '_output_shapes') { + this._visible = false; + this._type = 'shape[]'; + } + if (name == '_class') { + this._visible = false; + } + var attributeVisibleMap = metadata.getAttributeVisibleMap(operator); + if (attributeVisibleMap[name]) { + this._visible = false; + } + if (this._type == 'list(shape)') { + this._type = 'shape[]'; } } @@ -901,29 +926,13 @@ class TensorFlowTensor { } return '?'; } - - static formatTensorShape(shape) { - if (shape && shape.dim) { - if (shape.unknown_rank) { - return '[-]'; - } - if (shape.dim.length == 0) { - return ''; - } - if (shape.dim.length == 1 && !shape.dim[0].size) { - return '[0]'; - } - return '[' + shape.dim.map((dim) => (dim.size && dim.size != -1) ? dim.size.toString() : '?').join(',') + ']'; - } - return '?'; - } } class TensorFlowTensorType { constructor(dtype, shape) { this._dtype = dtype; - this._shape = shape; + this._shape = new TensorFlowTensorShape(shape); } get dataType() { @@ -931,6 +940,21 @@ class TensorFlowTensorType { } get shape() { + return this._shape; + } + + toString() { + return this.dataType + this._shape.toString(); + } +} + +class TensorFlowTensorShape { + + constructor(shape) { + this._shape = shape; + } + + get dimensions() { if (this._shape && this._shape.dim) { if (this._shape.unknown_rank) { return null; @@ -947,9 +971,20 @@ class TensorFlowTensorType { } toString() { - return this.dataType + TensorFlowTensor.formatTensorShape(this._shape); + if (this._shape && this._shape.dim) { + if (this._shape.unknown_rank) { + return '[-]'; + } + if (this._shape.dim.length == 0) { + return ''; + } + if (this._shape.dim.length == 1 && !this._shape.dim[0].size) { + return '[0]'; + } + return '[' + this._shape.dim.map((dim) => (dim.size && dim.size != -1) ? dim.size.toString() : '?').join(',') + ']'; + } + return '?'; } - } class TensorFlowGraphOperatorMetadata { @@ -970,6 +1005,63 @@ class TensorFlowGraphOperatorMetadata { return schema; } + getAttributeSchema(operator, name, value) { + var schema = this.getSchema(operator); + if (schema) { + var attributeMap = schema.attributeMap; + if (!attributeMap) { + attributeMap = {}; + if (schema.attributes) { + schema.attributes.forEach((attribute) => { + attributeMap[attribute.name] = attribute; + }); + } + schema.attributeMap = attributeMap; + } + return attributeMap[name] || null; + } + return null; + } + + getAttributeVisibleMap(operator) { + var schema = this.getSchema(operator); + if (schema) { + var map = schema.__visisbleAttributeMap__; + if (!map) { + map = {}; + if (schema.inputs) { + schema.inputs.forEach((input) => { + if (input.typeAttr) { + map[input.typeAttr] = true; + } + else if (input.typeListAttr) { + map[input.typeListAttr] = true; + } + if (input.numberAttr) { + map[input.numberAttr] = true; + } + }); + } + if (schema.outputs) { + schema.outputs.forEach((output) => { + if (output.typeAttr) { + map[output.typeAttr] = true; + } + else if (output.typeListAttr) { + map[output.typeListAttr] = true; + } + if (output.numberAttr) { + map[output.numberAttr] = true; + } + }); + } + schema.__visisbleAttributeMap__ = map; + } + return map; + } + return {}; + } + getInputs(node) { var results = []; var index = 0; @@ -1060,89 +1152,6 @@ class TensorFlowGraphOperatorMetadata { return results; } - getAttributeSchema(operator, name, value) { - var schema = this.getSchema(operator); - if (schema) { - var attributeMap = schema.attributeMap; - if (!attributeMap) { - attributeMap = {}; - if (schema.attributes) { - schema.attributes.forEach((attribute) => { - attributeMap[attribute.name] = attribute; - }); - } - schema.attributeMap = attributeMap; - } - return attributeMap[name] || null; - } - return null; - } - - getAttributeVisible(operator, name, value) { - var schema = this.getSchema(operator); - if (schema) { - var attributeSchema = this.getAttributeSchema(operator, name); - if (attributeSchema) { - if (attributeSchema.hasOwnProperty('visible')) { - return attributeSchema.visible; - } - if (attributeSchema.hasOwnProperty('default')) { - var valueText = TensorFlowGraphOperatorMetadata._formatAttributeValue(value); - var defaultValueText = TensorFlowGraphOperatorMetadata._formatAttributeValue(attributeSchema.default); - if (JSON.stringify(valueText) == JSON.stringify(defaultValueText)) { - return false; - } - } - } - if (name == '_output_shapes' || name == '_class') { - return false; - } - var hiddenAttributeMap = schema.hiddenAttributeMap; - if (!hiddenAttributeMap) { - hiddenAttributeMap = {}; - if (schema.inputs) { - schema.inputs.forEach((input) => { - if (input.typeAttr) { - hiddenAttributeMap[input.typeAttr] = true; - } - else if (input.typeListAttr) { - hiddenAttributeMap[input.typeListAttr] = true; - } - if (input.numberAttr) { - hiddenAttributeMap[input.numberAttr] = true; - } - }); - } - if (schema.outputs) { - schema.outputs.forEach((output) => { - if (output.typeAttr) { - hiddenAttributeMap[output.typeAttr] = true; - } - else if (output.typeListAttr) { - hiddenAttributeMap[output.typeListAttr] = true; - } - if (output.numberAttr) { - hiddenAttributeMap[output.numberAttr] = true; - } - }); - } - schema.hiddenAttributeMap = hiddenAttributeMap; - } - if (hiddenAttributeMap[name]) { - return false; - } - } - return true; - } - - getOperatorCategory(node) { - var schema = this.getSchema(node); - if (schema && schema.category) { - return schema.category; - } - return null; - } - getOperatorDocumentation(operator) { var schema = this.getSchema(operator); if (schema) {