Skip to content

Commit

Permalink
TensorFlow shape class
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Oct 28, 2018
1 parent 4ad3a43 commit c5ba169
Showing 1 changed file with 123 additions and 114 deletions.
237 changes: 123 additions & 114 deletions src/tf-model.js
Original file line number Diff line number Diff line change
Expand Up @@ -521,7 +521,8 @@ class TensorFlowNode {
}

get category() {
return this._graph.metadata.getOperatorCategory(this.operator);
var schema = this._graph.metadata.getSchema(this.operator);
return (schema && schema.category) ? schema.category : null;
}

get inputs() {
Expand Down Expand Up @@ -565,19 +566,18 @@ class TensorFlowAttribute {
this._name = name;
this._value = null;
this._type = null;
var schema = metadata.getAttributeSchema(operator, name);
if (value.hasOwnProperty('tensor')) {
this._type = new TensorFlowTensor(value.tensor).type;
this._tensor = value.tensor.tensor_shape && value.tensor.tensor_shape.dim && value.tensor.tensor_shape.dim.length > 0;
}
else {
var schema = metadata.getAttributeSchema(operator, name);
if (schema && schema.type) {
this._type = schema.type;
}
else if (schema && schema.type) {
this._type = schema.type;
}
if (value.hasOwnProperty('type')) {
this._value = () => TensorFlowTensor.formatDataType(value.type);
}
this._type = 'type';
}
else if (value.hasOwnProperty('i')) {
this._value = value.i;
}
Expand All @@ -588,7 +588,8 @@ class TensorFlowAttribute {
this._value = value.b;
}
else if (value.hasOwnProperty('shape')) {
this._value = () => TensorFlowTensor.formatTensorShape(value.shape);
this._type = 'shape';
this._value = new TensorFlowTensorShape(value.shape);
}
else if (value.hasOwnProperty('s')) {
if (value.s.filter(c => c <= 32 && c >= 128).length == 0) {
Expand Down Expand Up @@ -638,22 +639,46 @@ class TensorFlowAttribute {
this._value = () => '...';
}
else {
this._value = () => list.type.map((type) => TensorFlowTensor.formatDataType(type));
this._value = list.type.map((type) => TensorFlowTensor.formatDataType(type));
this._type = 'type[]';
}
}
else if (list.shape && list.shape.length > 0) {
if (list.shape.length > 65536) {
this._value = () => '...';
}
else {
this._value = () =>
list.shape.map((shape) => TensorFlowTensor.formatTensorShape(shape)).toString();
this._value = list.shape.map((shape) => new TensorFlowTensorShape(shape));
this._type = 'shape[]';
}
}
}

if (!metadata.getAttributeVisible(operator, name, this._value)) {
if (schema) {
if (schema.hasOwnProperty('visible') && !attributeSchema.visible) {
this._visible = false;
}
else if (schema.hasOwnProperty('default')) {
var valueText = TensorFlowGraphOperatorMetadata._formatAttributeValue(this._value);
var defaultValueText = TensorFlowGraphOperatorMetadata._formatAttributeValue(schema.default);
if (JSON.stringify(valueText) == JSON.stringify(defaultValueText)) {
this._visible = false;
}
}
}
if (name == '_output_shapes') {
this._visible = false;
this._type = 'shape[]';
}
if (name == '_class') {
this._visible = false;
}
var attributeVisibleMap = metadata.getAttributeVisibleMap(operator);
if (attributeVisibleMap[name]) {
this._visible = false;
}
if (this._type == 'list(shape)') {
this._type = 'shape[]';
}
}

Expand Down Expand Up @@ -901,36 +926,35 @@ class TensorFlowTensor {
}
return '?';
}

static formatTensorShape(shape) {
if (shape && shape.dim) {
if (shape.unknown_rank) {
return '[-]';
}
if (shape.dim.length == 0) {
return '';
}
if (shape.dim.length == 1 && !shape.dim[0].size) {
return '[0]';
}
return '[' + shape.dim.map((dim) => (dim.size && dim.size != -1) ? dim.size.toString() : '?').join(',') + ']';
}
return '?';
}
}

class TensorFlowTensorType {

constructor(dtype, shape) {
this._dtype = dtype;
this._shape = shape;
this._shape = new TensorFlowTensorShape(shape);
}

get dataType() {
return this._dtype ? TensorFlowTensor.formatDataType(this._dtype) : '?';
}

get shape() {
return this._shape;
}

toString() {
return this.dataType + this._shape.toString();
}
}

class TensorFlowTensorShape {

constructor(shape) {
this._shape = shape;
}

get dimensions() {
if (this._shape && this._shape.dim) {
if (this._shape.unknown_rank) {
return null;
Expand All @@ -947,9 +971,20 @@ class TensorFlowTensorType {
}

toString() {
return this.dataType + TensorFlowTensor.formatTensorShape(this._shape);
if (this._shape && this._shape.dim) {
if (this._shape.unknown_rank) {
return '[-]';
}
if (this._shape.dim.length == 0) {
return '';
}
if (this._shape.dim.length == 1 && !this._shape.dim[0].size) {
return '[0]';
}
return '[' + this._shape.dim.map((dim) => (dim.size && dim.size != -1) ? dim.size.toString() : '?').join(',') + ']';
}
return '?';
}

}

class TensorFlowGraphOperatorMetadata {
Expand All @@ -970,6 +1005,63 @@ class TensorFlowGraphOperatorMetadata {
return schema;
}

getAttributeSchema(operator, name, value) {
var schema = this.getSchema(operator);
if (schema) {
var attributeMap = schema.attributeMap;
if (!attributeMap) {
attributeMap = {};
if (schema.attributes) {
schema.attributes.forEach((attribute) => {
attributeMap[attribute.name] = attribute;
});
}
schema.attributeMap = attributeMap;
}
return attributeMap[name] || null;
}
return null;
}

getAttributeVisibleMap(operator) {
var schema = this.getSchema(operator);
if (schema) {
var map = schema.__visisbleAttributeMap__;
if (!map) {
map = {};
if (schema.inputs) {
schema.inputs.forEach((input) => {
if (input.typeAttr) {
map[input.typeAttr] = true;
}
else if (input.typeListAttr) {
map[input.typeListAttr] = true;
}
if (input.numberAttr) {
map[input.numberAttr] = true;
}
});
}
if (schema.outputs) {
schema.outputs.forEach((output) => {
if (output.typeAttr) {
map[output.typeAttr] = true;
}
else if (output.typeListAttr) {
map[output.typeListAttr] = true;
}
if (output.numberAttr) {
map[output.numberAttr] = true;
}
});
}
schema.__visisbleAttributeMap__ = map;
}
return map;
}
return {};
}

getInputs(node) {
var results = [];
var index = 0;
Expand Down Expand Up @@ -1060,89 +1152,6 @@ class TensorFlowGraphOperatorMetadata {
return results;
}

getAttributeSchema(operator, name, value) {
var schema = this.getSchema(operator);
if (schema) {
var attributeMap = schema.attributeMap;
if (!attributeMap) {
attributeMap = {};
if (schema.attributes) {
schema.attributes.forEach((attribute) => {
attributeMap[attribute.name] = attribute;
});
}
schema.attributeMap = attributeMap;
}
return attributeMap[name] || null;
}
return null;
}

getAttributeVisible(operator, name, value) {
var schema = this.getSchema(operator);
if (schema) {
var attributeSchema = this.getAttributeSchema(operator, name);
if (attributeSchema) {
if (attributeSchema.hasOwnProperty('visible')) {
return attributeSchema.visible;
}
if (attributeSchema.hasOwnProperty('default')) {
var valueText = TensorFlowGraphOperatorMetadata._formatAttributeValue(value);
var defaultValueText = TensorFlowGraphOperatorMetadata._formatAttributeValue(attributeSchema.default);
if (JSON.stringify(valueText) == JSON.stringify(defaultValueText)) {
return false;
}
}
}
if (name == '_output_shapes' || name == '_class') {
return false;
}
var hiddenAttributeMap = schema.hiddenAttributeMap;
if (!hiddenAttributeMap) {
hiddenAttributeMap = {};
if (schema.inputs) {
schema.inputs.forEach((input) => {
if (input.typeAttr) {
hiddenAttributeMap[input.typeAttr] = true;
}
else if (input.typeListAttr) {
hiddenAttributeMap[input.typeListAttr] = true;
}
if (input.numberAttr) {
hiddenAttributeMap[input.numberAttr] = true;
}
});
}
if (schema.outputs) {
schema.outputs.forEach((output) => {
if (output.typeAttr) {
hiddenAttributeMap[output.typeAttr] = true;
}
else if (output.typeListAttr) {
hiddenAttributeMap[output.typeListAttr] = true;
}
if (output.numberAttr) {
hiddenAttributeMap[output.numberAttr] = true;
}
});
}
schema.hiddenAttributeMap = hiddenAttributeMap;
}
if (hiddenAttributeMap[name]) {
return false;
}
}
return true;
}

getOperatorCategory(node) {
var schema = this.getSchema(node);
if (schema && schema.category) {
return schema.category;
}
return null;
}

getOperatorDocumentation(operator) {
var schema = this.getSchema(operator);
if (schema) {
Expand Down

0 comments on commit c5ba169

Please sign in to comment.