From 552b09aa88beb7b7640e126a6565820d943451b6 Mon Sep 17 00:00:00 2001 From: Lutz Roeder Date: Thu, 4 Jul 2024 10:07:31 -0700 Subject: [PATCH] Update view.js (#1285) --- source/grapher.css | 34 +++-- source/grapher.js | 302 ++++++++++++++++++++++++---------------- source/pytorch.js | 2 +- source/view.js | 338 ++++++++++++++++++++++++++------------------- 4 files changed, 403 insertions(+), 273 deletions(-) diff --git a/source/grapher.css b/source/grapher.css index 0739516707..4df2bf9d6a 100644 --- a/source/grapher.css +++ b/source/grapher.css @@ -51,10 +51,11 @@ .node-item-undefined:hover { cursor: pointer; } .node-item-undefined:hover path { fill: #fff; } -.node-attribute-list > path { fill: #fff; stroke-width: 0; stroke: #000; } -.node-attribute-list:hover { cursor: pointer; } -.node-attribute-list:hover > path { fill: #f6f6f6; } -.node-attribute > text { font-family: -apple-system, BlinkMacSystemFont, "Segoe WPC", "Segoe UI", "Ubuntu", "Droid Sans", sans-serif, "PingFang SC"; font-size: 9px; font-weight: normal; text-rendering: geometricPrecision; user-select: none; } +.node-argument-list > path { fill: #fff; stroke-width: 0; stroke: #000; } +.node-argument-list:hover { cursor: pointer; } +.node-argument-list:hover > path { fill: #f6f6f6; } +.node-argument > text { font-family: -apple-system, BlinkMacSystemFont, "Segoe WPC", "Segoe UI", "Ubuntu", "Droid Sans", sans-serif, "PingFang SC"; font-size: 9px; font-weight: normal; text-rendering: geometricPrecision; user-select: none; } +.node-argument > rect { fill: transparent; } .graph-item-input path { fill: #eee; } .graph-item-input:hover { cursor: pointer; } @@ -65,13 +66,16 @@ .graph-item-output:hover path { fill: #fff; } #arrowhead { fill: #000; } -#arrowhead-select { fill: #e00; } +#arrowhead-hover { fill: rgba(238, 0, 0, 0.8); } +#arrowhead-select { fill: rgba(238, 0, 0, 0.8); } .edge-path { stroke: #000; stroke-width: 1px; fill: none; marker-end: url("#arrowhead"); } .edge-path-hit-test { pointer-events: stroke; stroke-width: 0.5em; fill: none; stroke: #000; stroke-opacity: 0.001; } -.select > .node.node-border { stroke: #e00; stroke-width: 2px; } -.select.edge-path { stroke: #e00; stroke-width: 1px; marker-end: url("#arrowhead-select"); } +.select > .node.node-border { stroke: rgba(238, 0, 0, 0.8); stroke-width: 2px; } +.select.edge-path { stroke: rgba(238, 0, 0, 0.8); stroke-width: 1px; marker-end: url("#arrowhead-select"); } +.select.node-argument > rect { fill: rgba(238, 0, 0, 0.8); } +.select.node-argument > text { fill: #f6f6f6; } .edge-label { font-family: -apple-system, BlinkMacSystemFont, "Segoe WPC", "Segoe UI", "Ubuntu", "Droid Sans", sans-serif, "PingFang SC"; font-size: 10px; } .edge-path-control-dependency { stroke-dasharray: 3, 2; } @@ -87,12 +91,14 @@ .node path { stroke: #1d1d1d; } .node line { stroke: #1d1d1d; } - .select > .node.node-border { stroke: #b00; } - .select.edge-path { stroke: #b00; } + .select > .node.node-border { stroke: rgba(187, 0, 0, 0.8); } + .select.edge-path { stroke: rgba(187, 0, 0, 0.8); } + .select.node-argument > rect { fill: rgba(187, 0, 0, 0.8); } + .select.node-argument > text { fill: #b2b2b2; } #arrowhead { fill: #888; } - #arrowhead-hover { fill: #b00; } - #arrowhead-select { fill: #b00 } + #arrowhead-hover { fill: rgba(187, 0, 0, 0.8); } + #arrowhead-select { fill: rgba(187, 0, 0, 0.8) } .edge-label { fill: #b2b2b2; } @@ -110,9 +116,9 @@ .node-item path { stroke: #fff; } .node-item text { fill: #dfdfdf; } - .node-attribute > text { fill: #b2b2b2; } - .node-attribute-list > path { fill: #2d2d2d; } - .node-attribute-list:hover > path { fill: #666666; } + .node-argument > text { fill: #b2b2b2; } + .node-argument-list > path { fill: #2d2d2d; } + .node-argument-list:hover > path { fill: #303030; } .graph-item-input path { fill: #404040; } .graph-item-input:hover { cursor: pointer; } diff --git a/source/grapher.js b/source/grapher.js index 7833323177..be18d79662 100644 --- a/source/grapher.js +++ b/source/grapher.js @@ -277,7 +277,7 @@ grapher.Node = class { } list() { - const block = new grapher.Node.List(); + const block = new grapher.ArgumentList(); this._blocks.push(block); return block; } @@ -359,16 +359,7 @@ grapher.Node = class { r2 = r2 ? radius : 0; r3 = r3 ? radius : 0; r4 = r4 ? radius : 0; - return `M${x + r1},${y - }h${width - r1 - r2 - }a${r2},${r2} 0 0 1 ${r2},${r2 - }v${height - r2 - r3 - }a${r3},${r3} 0 0 1 ${-r3},${r3 - }h${r3 + r4 - width - }a${r4},${r4} 0 0 1 ${-r4},${-r4 - }v${-height + r4 + r1 - }a${r1},${r1} 0 0 1 ${r1},${-r1 - }z`; + return `M${x + r1},${y}h${width - r1 - r2}a${r2},${r2} 0 0 1 ${r2},${r2}v${height - r2 - r3}a${r3},${r3} 0 0 1 ${-r3},${r3}h${r3 + r4 - width}a${r4},${r4} 0 0 1 ${-r4},${-r4}v${-height + r4 + r1}a${r1},${r1} 0 0 1 ${r1},${-r1}z`; } }; @@ -524,17 +515,19 @@ grapher.Node.Header.Entry = class { } }; -grapher.Node.List = class { +grapher.ArgumentList = class { constructor() { this._items = []; this._events = {}; } - add(name, value, tooltip, separator) { - const item = new grapher.Node.List.Item(name, value, tooltip, separator); - this._items.push(item); - return item; + argument(name, value) { + return new grapher.Argument(name, value); + } + + add(value) { + this._items.push(value); } on(event, callback) { @@ -553,7 +546,7 @@ grapher.Node.List = class { build(document, parent) { this._document = document; this.element = document.createElementNS('http://www.w3.org/2000/svg', 'g'); - this.element.setAttribute('class', 'node-attribute-list'); + this.element.setAttribute('class', 'node-argument-list'); if (this._events.click) { this.element.addEventListener('click', (e) => { e.stopPropagation(); @@ -564,38 +557,7 @@ grapher.Node.List = class { this.element.appendChild(this.background); parent.appendChild(this.element); for (const item of this._items) { - const group = document.createElementNS('http://www.w3.org/2000/svg', 'g'); - group.setAttribute('class', 'node-attribute'); - const text = document.createElementNS('http://www.w3.org/2000/svg', 'text'); - text.setAttribute('xml:space', 'preserve'); - if (item.tooltip) { - const title = document.createElementNS('http://www.w3.org/2000/svg', 'title'); - title.textContent = item.tooltip; - text.appendChild(title); - } - const colon = item.type === 'node' || item.type === 'node[]'; - const name = document.createElementNS('http://www.w3.org/2000/svg', 'tspan'); - name.textContent = colon ? `${item.name}:` : item.name; - if (item.separator.trim() !== '=' && !colon) { - name.style.fontWeight = 'bold'; - } - text.appendChild(name); - group.appendChild(text); - this.element.appendChild(group); - item.group = group; - item.text = text; - if (item.type === 'node') { - const node = item.value; - node.build(document, item.group); - } else if (item.type === 'node[]') { - for (const node of item.value) { - node.build(document, item.group); - } - } else { - const tspan = document.createElementNS('http://www.w3.org/2000/svg', 'tspan'); - tspan.textContent = item.separator + item.value; - item.text.appendChild(tspan); - } + item.build(document, this.element); } if (!this.first) { this.line = document.createElementNS('http://www.w3.org/2000/svg', 'line'); @@ -607,62 +569,31 @@ grapher.Node.List = class { measure() { this.width = 75; this.height = 3; - const yPadding = 1; - const xPadding = 6; for (let i = 0; i < this._items.length; i++) { const item = this._items[i]; - const size = item.text.getBBox(); - item.width = xPadding + size.width + xPadding; - item.height = yPadding + size.height + yPadding; - item.offset = size.y; + item.measure(); this.height += item.height; - if (item.type === 'node') { - const node = item.value; - node.measure(); - this.width = Math.max(150, this.width, node.width + (2 * xPadding)); - this.height += node.height + yPadding + yPadding + yPadding + yPadding; - if (i === this._items.length - 1) { - this.height += 3; - } - } else if (item.type === 'node[]') { - for (const node of item.value) { - node.measure(); - this.width = Math.max(150, this.width, node.width + (2 * xPadding)); - this.height += node.height + yPadding + yPadding + yPadding + yPadding; - } + this.width = Math.max(this.width, item.width); + if (item.type === 'node' || item.type === 'node[]') { if (i === this._items.length - 1) { this.height += 3; } } - this.width = Math.max(this.width, item.width); + } + for (const item of this._items) { + item.width = this.width; } this.height += 3; } layout() { - const yPadding = 1; - const xPadding = 6; let y = 3; for (const item of this._items) { - item.x = this.x + xPadding; - item.y = y + yPadding - item.offset; + item.x = this.x; + item.y = y; + item.width = this.width; + item.layout(); y += item.height; - if (item.type === 'node') { - const node = item.value; - node.width = this.width - xPadding - xPadding; - node.layout(); - node.x = this.x + xPadding + (node.width / 2); - node.y = y + (node.height / 2) + yPadding + yPadding; - y += node.height + yPadding + yPadding + yPadding + yPadding; - } else if (item.type === 'node[]') { - for (const node of item.value) { - node.width = this.width - xPadding - xPadding; - node.layout(); - node.x = this.x + xPadding + (node.width / 2); - node.y = y + (node.height / 2) + yPadding + yPadding; - y += node.height + yPadding + yPadding + yPadding + yPadding; - } - } } } @@ -670,17 +601,7 @@ grapher.Node.List = class { this.element.setAttribute('transform', `translate(${this.x},${this.y})`); this.background.setAttribute('d', grapher.Node.roundedRect(0, 0, this.width, this.height, this.first, this.first, this.last, this.last)); for (const item of this._items) { - const text = item.text; - text.setAttribute('x', item.x); - text.setAttribute('y', item.y); - if (item.type === 'node') { - const node = item.value; - node.update(); - } else if (item.type === 'node[]') { - for (const node of item.value) { - node.update(); - } - } + item.update(); } if (this.line) { this.line.setAttribute('x1', 0); @@ -688,28 +609,158 @@ grapher.Node.List = class { this.line.setAttribute('y1', 0); this.line.setAttribute('y2', 0); } - for (const item of this._items) { - if (item.value instanceof grapher.Node) { - const node = item.value; - node.update(); - } - } } }; -grapher.Node.List.Item = class { +grapher.Argument = class { - constructor(name, value, tooltip, separator) { + constructor(name, content) { this.name = name; - this.value = value; - this.tooltip = tooltip; - this.separator = separator; - if (value instanceof grapher.Node) { + this.content = content; + this.tooltip = ''; + this.separator = ''; + if (content instanceof grapher.Node) { this.type = 'node'; - } else if (Array.isArray(value) && value.every((value) => value instanceof grapher.Node)) { + } else if (Array.isArray(content) && content.every((value) => value instanceof grapher.Node)) { this.type = 'node[]'; } } + + build(document, parent) { + this.element = document.createElementNS('http://www.w3.org/2000/svg', 'g'); + this.element.setAttribute('class', 'node-argument'); + this.border = document.createElementNS('http://www.w3.org/2000/svg', 'rect'); + this.border.setAttribute('rx', 3); + this.border.setAttribute('ry', 3); + this.element.appendChild(this.border); + const text = document.createElementNS('http://www.w3.org/2000/svg', 'text'); + text.setAttribute('xml:space', 'preserve'); + if (this.tooltip) { + const title = document.createElementNS('http://www.w3.org/2000/svg', 'title'); + title.textContent = this.tooltip; + text.appendChild(title); + } + const colon = this.type === 'node' || this.type === 'node[]'; + const name = document.createElementNS('http://www.w3.org/2000/svg', 'tspan'); + name.textContent = colon ? `${this.name}:` : this.name; + if (this.separator.trim() !== '=' && !colon) { + name.style.fontWeight = 'bold'; + } + if (this.focus) { + this.element.addEventListener('pointerover', (e) => { + this.focus(); + e.stopPropagation(); + }); + } + if (this.blur) { + this.element.addEventListener('pointerleave', (e) => { + this.blur(); + e.stopPropagation(); + }); + } + if (this.activate) { + this.element.addEventListener('click', (e) => { + this.activate(); + e.stopPropagation(); + }); + } + text.appendChild(name); + this.element.appendChild(text); + parent.appendChild(this.element); + this.text = text; + switch (this.type) { + case 'node': { + const node = this.content; + node.build(document, this.element); + break; + } + case 'node[]': { + for (const node of this.content) { + node.build(document, this.element); + } + break; + } + default: { + const tspan = document.createElementNS('http://www.w3.org/2000/svg', 'tspan'); + tspan.textContent = (this.separator || '') + this.content; + this.text.appendChild(tspan); + break; + } + } + } + + measure() { + const yPadding = 1; + const xPadding = 6; + const size = this.text.getBBox(); + this.width = xPadding + size.width + xPadding; + this.bottom = yPadding + size.height + yPadding; + this.offset = size.y; + this.height = this.bottom; + if (this.type === 'node') { + const node = this.content; + node.measure(); + this.width = Math.max(150, this.width, node.width + (2 * xPadding)); + this.height += node.height + yPadding + yPadding + yPadding + yPadding; + } else if (this.type === 'node[]') { + for (const node of this.content) { + node.measure(); + this.width = Math.max(150, this.width, node.width + (2 * xPadding)); + this.height += node.height + yPadding + yPadding + yPadding + yPadding; + } + } + } + + layout() { + const yPadding = 1; + const xPadding = 6; + if (this.type === 'node') { + const node = this.content; + node.width = this.width - xPadding - xPadding; + node.layout(); + node.x = this.x + xPadding + (node.width / 2); + node.y = this.y + this.bottom + (node.height / 2) + yPadding + yPadding; + } else if (this.type === 'node[]') { + for (const node of this.content) { + node.layout(); + node.x = this.x + xPadding + (node.width / 2); + node.y = this.y + (node.height / 2) + yPadding + yPadding; + } + } + } + + update() { + const yPadding = 1; + const xPadding = 6; + this.text.setAttribute('x', this.x + xPadding); + this.text.setAttribute('y', this.y + yPadding - this.offset); + this.border.setAttribute('x', this.x + 3); + this.border.setAttribute('y', this.y + 1); + this.border.setAttribute('width', this.width - 6); + this.border.setAttribute('height', this.height - 1); + if (this.type === 'node') { + const node = this.content; + node.update(); + } else if (this.type === 'node[]') { + for (const node of this.content) { + node.update(); + } + } + } + + select() { + if (this.element) { + this.element.classList.add('select'); + return [this.element]; + } + return []; + } + + deselect() { + if (this.element) { + this.element.classList.remove('select'); + } + } }; grapher.Node.Canvas = class { @@ -745,9 +796,24 @@ grapher.Edge = class { edgePathGroupElement.appendChild(this.element); this.hitTest = createElement('path'); this.hitTest.setAttribute('class', 'edge-path-hit-test'); - this.hitTest.addEventListener('pointerover', () => this.emit('pointerover')); - this.hitTest.addEventListener('pointerleave', () => this.emit('pointerleave')); - this.hitTest.addEventListener('click', () => this.emit('click')); + if (this.focus) { + this.hitTest.addEventListener('pointerover', (e) => { + this.focus(); + e.stopPropagation(); + }); + } + if (this.blur) { + this.hitTest.addEventListener('pointerleave', (e) => { + this.blur(); + e.stopPropagation(); + }); + } + if (this.activate) { + this.hitTest.addEventListener('click', (e) => { + this.activate(); + e.stopPropagation(); + }); + } edgePathGroupElement.appendChild(this.hitTest); if (this.label) { const tspan = createElement('tspan'); @@ -947,4 +1013,4 @@ grapher.Edge.Path = class { } }; -export const { Graph, Node, Edge } = grapher; +export const { Graph, Node, Edge, Argument } = grapher; diff --git a/source/pytorch.js b/source/pytorch.js index 5840e271e2..dea02a904e 100644 --- a/source/pytorch.js +++ b/source/pytorch.js @@ -291,7 +291,7 @@ pytorch.Node = class { return type; }; const createAttribute = (metadata, name, value) => { - let visible = false; + let visible = true; let type = null; if (name === 'training') { visible = false; diff --git a/source/view.js b/source/view.js index 0f5de7e6ab..3edb9a9868 100644 --- a/source/view.js +++ b/source/view.js @@ -1719,6 +1719,7 @@ view.Graph = class extends grapher.Graph { this.options = options; this._nodeKey = 0; this._values = new Map(); + this._tensors = new Map(); this._table = new Map(); this._selection = new Set(); } @@ -1764,9 +1765,20 @@ view.Graph = class extends grapher.Graph { return this._values.get(name); } - createTensor(value) { - const obj = new view.Tensor(this, value); - this._table.set(value, obj); + createArgument(value) { + if (Array.isArray(value.value) && value.value.length === 1 && value.value[0].initializer) { + const [key] = value.value; + if (this._tensors.has(key)) { + const obj = this._tensors.get(key); + this._table.set(key, obj); + } else { + const obj = new view.Argument(this, value); + this._tensors.set(key, obj); + this._table.set(key, obj); + } + return this._tensors.get(key); + } + return null; } add(graph, signature) { @@ -1804,11 +1816,13 @@ view.Graph = class extends grapher.Graph { const inputs = node.inputs; for (const input of inputs) { if (!input.type || input.type.endsWith('*')) { - for (const value of input.value) { - if (value.name !== '' && !value.initializer) { - this.createValue(value).to.push(viewNode); - } else if (value.initializer) { - this.createTensor(value); + if (Array.isArray(input.value) && input.value.length === 1 && input.value[0].initializer) { + this.createArgument(input); + } else { + for (const value of input.value) { + if (value.name !== '' && !value.initializer) { + this.createValue(value).to.push(viewNode); + } } } } @@ -1962,12 +1976,8 @@ view.Node = class extends grapher.Node { _add(node) { const options = this.context.options; const header = this.header(); - const styles = ['node-item-type']; const type = node.type; const category = type && type.category ? type.category : ''; - if (category) { - styles.push(`node-item-type-${category.toLowerCase()}`); - } if (typeof type.name !== 'string' || !type.name.split) { // #416 const error = new view.Error(`Unsupported node type '${JSON.stringify(type.name)}'.`); if (this.context.model && this.context.model.identifier) { @@ -1980,6 +1990,7 @@ view.Node = class extends grapher.Node { if (content.length > 24) { content = `${content.substring(0, 12)}\u2026${content.substring(content.length - 12, content.length)}`; } + const styles = category ? ['node-item-type', `node-item-type-${category.toLowerCase()}`] : ['node-item-type']; const title = header.add(null, styles, content, tooltip); title.on('click', () => { this.context.activate(node); @@ -1992,113 +2003,98 @@ view.Node = class extends grapher.Node { // this._expand = header.add(null, styles, '+', null); // this._expand.on('click', () => this.toggle()); } - const initializers = []; - let hiddenInitializers = false; - if (options.weights) { - if (Array.isArray(node.inputs)) { - for (const input of node.inputs) { - if (input.visible !== false && input.value.length === 1 && input.value[0].initializer) { - initializers.push(input); - } - if ((input.visible === false || input.value.length > 1) && - (!input.type || input.type.endsWith('*')) && input.value.some((value) => value.initializer)) { - hiddenInitializers = true; - } - } + let current = null; + const list = () => { + if (!current) { + current = this.list(); + current.on('click', () => this.context.activate(node)); } - } + return current; + }; + let hiddenTensors = false; + const tensors = []; const objects = []; const attributes = []; - if (Array.isArray(node.attributes) && node.attributes.length > 0) { - for (const attribute of node.attributes) { - switch (attribute.type) { + if (Array.isArray(node.inputs)) { + for (const input of node.inputs) { + switch (input.type) { case 'graph': case 'object': case 'object[]': case 'function': case 'function[]': { - objects.push(attribute); + objects.push(input); break; } default: { - if (options.attributes && attribute.visible !== false) { - attributes.push(attribute); + if (options.weights && input.visible !== false && input.value.length === 1 && input.value[0].initializer) { + tensors.push(input); + } else if (options.weights && (input.visible === false || input.value.length > 1) && (!input.type || input.type.endsWith('*')) && input.value.some((value) => value.initializer)) { + hiddenTensors = true; + } else if (options.attributes && input.visible !== false && input.type && !input.type.endsWith('*')) { + attributes.push(input); } } } } - attributes.sort((a, b) => a.name.toUpperCase().localeCompare(b.name.toUpperCase())); } - if (Array.isArray(node.inputs)) { - for (const input of node.inputs) { - switch (input.type) { + if (Array.isArray(node.attributes)) { + for (const attribute of node.attributes) { + switch (attribute.type) { case 'graph': case 'object': case 'object[]': case 'function': case 'function[]': { - objects.push(input); + objects.push(attribute); break; } default: { - break; + if (options.attributes && attribute.visible !== false) { + attributes.push(attribute); + } } } } } - if (initializers.length > 0 || hiddenInitializers || attributes.length > 0 || objects.length > 0) { - const list = this.list(); - list.on('click', () => this.context.activate(node)); - for (const argument of initializers) { - const [value] = argument.value; - const type = value.type; - let shape = ''; - let separator = ''; - if (type && type.shape && type.shape.dimensions && Array.isArray(type.shape.dimensions)) { - shape = `\u3008${type.shape.dimensions.map((d) => (d !== null && d !== undefined) ? d : '?').join('\u00D7')}\u3009`; - if (type.shape.dimensions.length === 0 && value.initializer) { - try { - const tensor = new base.Tensor(value.initializer); - const encoding = tensor.encoding; - if ((encoding === '<' || encoding === '>' || encoding === '|') && !tensor.empty && tensor.type.dataType !== '?') { - shape = tensor.toString(); - if (shape && shape.length > 10) { - shape = `${shape.substring(0, 10)}\u2026`; - } - separator = ' = '; - } - } catch (error) { - this.context.view.exception(error, false); - } - } + if (attributes.length > 0) { + attributes.sort((a, b) => a.name.toUpperCase().localeCompare(b.name.toUpperCase())); + } + for (const argument of tensors) { + const item = this.context.createArgument(argument); + list().add(item); + } + if (hiddenTensors) { + const item = list().argument('\u3008\u2026\u3009', ''); + list().add(item); + } + for (const attribute of attributes) { + if (attribute.visible !== false) { + let value = new view.Formatter(attribute.value, attribute.type).toString(); + if (value && value.length > 12) { + value = `${value.substring(0, 12)}\u2026`; } - list.add(argument.name, shape, type ? type.toString() : '', separator); + const item = list().argument(attribute.name, value); + item.tooltip = attribute.type; + item.separator = ' = '; + list().add(item); } - if (hiddenInitializers) { - list.add('\u3008\u2026\u3009', '', null, ''); + } + for (const argument of objects) { + if (argument.type === 'graph') { + const node = this.context.createNode(null, argument.value); + const item = list().argument(argument.name, node); + list().add(item); } - for (const attribute of attributes) { - if (attribute.visible !== false) { - let value = new view.Formatter(attribute.value, attribute.type).toString(); - if (value && value.length > 25) { - value = `${value.substring(0, 25)}\u2026`; - } - list.add(attribute.name, value, attribute.type, ' = '); - } + if (argument.type === 'function' || argument.type === 'object') { + const node = this.context.createNode(argument.value); + const item = list().argument(argument.name, node); + list().add(item); } - for (const attribute of objects) { - if (attribute.type === 'graph') { - const node = this.context.createNode(null, attribute.value); - list.add(attribute.name, node, '', ''); - } - if (attribute.type === 'function' || attribute.type === 'object') { - const node = this.context.createNode(attribute.value); - list.add(attribute.name, node, '', ''); - } - if (attribute.type === 'function[]' || attribute.type === 'object[]') { - const nodes = attribute.value.map((value) => this.context.createNode(value)); - list.add(attribute.name, nodes, '', ''); - } + if (argument.type === 'function[]' || argument.type === 'object[]') { + const nodes = argument.value.map((value) => this.context.createNode(value)); + const item = list().argument(argument.name, nodes); + list().add(item); } } if (Array.isArray(node.nodes) && node.nodes.length > 0) { @@ -2292,22 +2288,43 @@ view.Value = class { } }; -view.Tensor = class { +view.Argument = class extends grapher.Argument { constructor(context, value) { + const name = value.name; + let content = ''; + let separator = ''; + let tooltip = ''; + if (Array.isArray(value.value) && value.value.length === 1 && value.value[0].initializer) { + const tensor = value.value[0].initializer; + const type = value.value[0].type; + if (type && type.shape && type.shape.dimensions && Array.isArray(type.shape.dimensions)) { + tooltip = type.toString(); + content = `\u3008${type.shape.dimensions.map((d) => (d !== null && d !== undefined) ? d : '?').join('\u00D7')}\u3009`; + if (type.shape.dimensions.length === 0) { + const formatter = new view.Formatter(tensor, 'tensor'); + content = formatter.toString(); + separator = ' = '; + } + } + } + super(name, content); this.context = context; this.value = value; + this.separator = separator; + this.tooltip = tooltip; } - select() { - return []; + focus() { + this.context.focus([this.value.value[0]]); } - deselect() { + blur() { + this.context.blur([this.value.value[0]]); } activate() { - this.context.view.showTensorProperties(this.value); + this.context.view.showTensorProperties(this.value.value[0]); } }; @@ -2326,23 +2343,16 @@ view.Edge = class extends grapher.Edge { return 1; } - emit(event) { - switch (event) { - case 'pointerover': { - this.value.context.focus([this.value.value]); - break; - } - case 'pointerleave': { - this.value.context.blur([this.value.value]); - break; - } - case 'click': { - this.value.context.activate(this.value.value); - break; - } - default: - break; - } + focus() { + this.value.context.focus([this.value.value]); + } + + blur() { + this.value.context.blur([this.value.value]); + } + + activate() { + this.value.context.activate(this.value.value); } }; @@ -2639,8 +2649,16 @@ view.NodeSidebar = class extends view.ObjectSidebar { addArgument(name, argument, source) { const value = new view.ArgumentView(this._view, argument, source); - value.on('focus', (sender, value) => this.emit('focus', value)); - value.on('blur', (sender, value) => this.emit('blur', value)); + value.on('focus', (sender, value) => { + this.emit('focus', value); + this._focused = this._focused || new Set(); + this._focused.add(value); + }); + value.on('blur', (sender, value) => { + this.emit('blur', value); + this._focused = this._focused || new Set(); + this._focused.delete(value); + }); value.on('select', (sender, value) => this.emit('select', value)); value.on('activate', (sender, value) => this.emit('activate', value)); this.addEntry(name, value); @@ -2649,6 +2667,16 @@ view.NodeSidebar = class extends view.ObjectSidebar { activate() { this.emit('select', this._node); } + + deactivate() { + this.emit('select', null); + if (this._focused) { + for (const value of this._focused) { + this.emit('blur', value); + } + this._focused.clear(); + } + } }; view.NameValueView = class extends view.Control { @@ -3217,6 +3245,8 @@ view.NodeListView = class extends view.Control { this._elements = []; for (const node of list) { const item = new view.NodeView(this._view, node); + item.on('focus', (sender, value) => this.emit('focus', value)); + item.on('blur', (sender, value) => this.emit('blur', value)); item.on('activate', (sender, value) => this.emit('activate', value)); item.on('deactivate', (sender, value) => this.emit('deactivate', value)); item.on('select', (sender, value) => this.emit('select', value)); @@ -3254,27 +3284,44 @@ view.ConnectionSidebar = class extends view.ObjectSidebar { } if (from) { this.addHeader('Inputs'); - const list = new view.NodeListView(this._view, [from]); - list.on('focus', (sender, value) => this.emit('focus', value)); - list.on('blur', (sender, value) => this.emit('blur', value)); - list.on('select', (sender, value) => this.emit('select', value)); - list.on('activate', (sender, value) => this.emit('activate', value)); - this.addEntry('from', list); + this.addNodeList('from', [from]); } if (Array.isArray(to) && to.length > 0) { this.addHeader('Outputs'); - const list = new view.NodeListView(this._view, to); - list.on('focus', (sender, value) => this.emit('focus', value)); - list.on('blur', (sender, value) => this.emit('blur', value)); - list.on('select', (sender, value) => this.emit('select', value)); - list.on('activate', (sender, value) => this.emit('activate', value)); - this.addEntry('to', list); + this.addNodeList('to', to); } } + addNodeList(name, list) { + const entry = new view.NodeListView(this._view, list); + entry.on('focus', (sender, value) => { + this.emit('focus', value); + this._focused = this._focused || new Set(); + this._focused.add(value); + }); + entry.on('blur', (sender, value) => { + this.emit('blur', value); + this._focused = this._focused || new Set(); + this._focused.delete(value); + }); + entry.on('select', (sender, value) => this.emit('select', value)); + entry.on('activate', (sender, value) => this.emit('activate', value)); + this.addEntry(name, entry); + } + activate() { this.emit('select', this._value); } + + deactivate() { + this.emit('select', null); + if (this._focused) { + for (const value of this._focused) { + this.emit('blur', value); + } + this._focused.clear(); + } + } }; view.TensorSidebar = class extends view.ObjectSidebar { @@ -3610,7 +3657,7 @@ view.FindSidebar = class extends view.Control { } _clear() { - for (const identifier in this._focused) { + for (const identifier of this._focused) { this._blur(identifier); } this._focused.clear(); @@ -3752,10 +3799,10 @@ view.FindSidebar = class extends view.Control { if (value.initializer && this._value(value)) { if (value.name && !edges.has(value.name)) { const content = `${value.name.split('\n').shift()}`; // split custom argument id - this._add(node, content, 'weight'); + this._add(value, content, 'weight'); } else if (value.type && value.type.shape && Array.isArray(value.type.shape.dimensions) && value.type.shape.dimensions.length > 0) { const content = `${value.type.shape.dimensions.map((d) => (d !== null && d !== undefined) ? d : '?').join('\u00D7')}`; - this._add(node, content, 'weight'); + this._add(value, content, 'weight'); } } } @@ -3834,6 +3881,7 @@ view.FindSidebar = class extends view.Control { for (const identifier of this._focused) { this._blur(identifier); } + this._focused.clear(); } error(error, fatal) { @@ -3844,15 +3892,6 @@ view.FindSidebar = class extends view.Control { } }; -view.Argument = class { - - constructor(name, value, type) { - this.name = name; - this.value = value; - this.type = type; - } -}; - view.Quantization = class { constructor(quantization) { @@ -4139,11 +4178,21 @@ view.Formatter = class { return value ? value.name : '(null)'; case 'graph[]': return value ? value.map((graph) => graph.name).join(', ') : '(null)'; - case 'tensor': - if (value && value.type && value.type.shape && value.type.shape.dimensions && value.type.shape.dimensions.length === 0) { - return value.toString(); + case 'tensor': { + const type = value.type; + if (type && type.shape && type.shape.dimensions && Array.isArray(type.shape.dimensions) && type.shape.dimensions.length === 0) { + const tensor = new base.Tensor(value); + const encoding = tensor.encoding; + if ((encoding === '<' || encoding === '>' || encoding === '|') && !tensor.empty && tensor.type.dataType !== '?') { + let content = tensor.toString(); + if (content && content.length > 10) { + content = `${content.substring(0, 10)}\u2026`; + } + return content; + } } - return '[...]'; + return '[\u2026]'; + } case 'object': case 'function': return value.type.name; @@ -4932,6 +4981,15 @@ markdown.Generator = class { } }; +metrics.Argument = class { + + constructor(name, value, type) { + this.name = name; + this.value = value; + this.type = type; + } +}; + metrics.Tensor = class { constructor(tensor) { @@ -4960,7 +5018,7 @@ metrics.Tensor = class { } } const value = parameters > 0 ? zeros / parameters : 0; - const argument = new view.Argument('sparsity', value, 'percentage'); + const argument = new metrics.Argument('sparsity', value, 'percentage'); this._metrics.push(argument); } }