From 55ec77032eb12eda3e406d15d6cd6c261fe9b3cb Mon Sep 17 00:00:00 2001 From: Lutz Roeder Date: Thu, 4 Jul 2024 10:07:31 -0700 Subject: [PATCH] Update view.js (#1285) --- source/grapher.css | 14 ++-- source/grapher.js | 99 ++++++++++++++------------ source/pytorch.js | 2 +- source/view.js | 170 ++++++++++++++++++++++++--------------------- 4 files changed, 155 insertions(+), 130 deletions(-) diff --git a/source/grapher.css b/source/grapher.css index 0739516707b..173e17ab049 100644 --- a/source/grapher.css +++ b/source/grapher.css @@ -51,10 +51,10 @@ .node-item-undefined:hover { cursor: pointer; } .node-item-undefined:hover path { fill: #fff; } -.node-attribute-list > path { fill: #fff; stroke-width: 0; stroke: #000; } -.node-attribute-list:hover { cursor: pointer; } -.node-attribute-list:hover > path { fill: #f6f6f6; } -.node-attribute > text { font-family: -apple-system, BlinkMacSystemFont, "Segoe WPC", "Segoe UI", "Ubuntu", "Droid Sans", sans-serif, "PingFang SC"; font-size: 9px; font-weight: normal; text-rendering: geometricPrecision; user-select: none; } +.node-argument-list > path { fill: #fff; stroke-width: 0; stroke: #000; } +.node-argument-list:hover { cursor: pointer; } +.node-argument-list:hover > path { fill: #f6f6f6; } +.node-argument > text { font-family: -apple-system, BlinkMacSystemFont, "Segoe WPC", "Segoe UI", "Ubuntu", "Droid Sans", sans-serif, "PingFang SC"; font-size: 9px; font-weight: normal; text-rendering: geometricPrecision; user-select: none; } .graph-item-input path { fill: #eee; } .graph-item-input:hover { cursor: pointer; } @@ -110,9 +110,9 @@ .node-item path { stroke: #fff; } .node-item text { fill: #dfdfdf; } - .node-attribute > text { fill: #b2b2b2; } - .node-attribute-list > path { fill: #2d2d2d; } - .node-attribute-list:hover > path { fill: #666666; } + .node-argument > text { fill: #b2b2b2; } + .node-argument-list > path { fill: #2d2d2d; } + .node-argument-list:hover > path { fill: #666666; } .graph-item-input path { fill: #404040; } .graph-item-input:hover { cursor: pointer; } diff --git a/source/grapher.js b/source/grapher.js index 7833323177d..9acca522f71 100644 --- a/source/grapher.js +++ b/source/grapher.js @@ -277,7 +277,7 @@ grapher.Node = class { } list() { - const block = new grapher.Node.List(); + const block = new grapher.ArgumentList(); this._blocks.push(block); return block; } @@ -524,17 +524,19 @@ grapher.Node.Header.Entry = class { } }; -grapher.Node.List = class { +grapher.ArgumentList = class { constructor() { this._items = []; this._events = {}; } - add(name, value, tooltip, separator) { - const item = new grapher.Node.List.Item(name, value, tooltip, separator); - this._items.push(item); - return item; + argument(name, value) { + return new grapher.Argument(name, value); + } + + add(value) { + this._items.push(value); } on(event, callback) { @@ -553,7 +555,7 @@ grapher.Node.List = class { build(document, parent) { this._document = document; this.element = document.createElementNS('http://www.w3.org/2000/svg', 'g'); - this.element.setAttribute('class', 'node-attribute-list'); + this.element.setAttribute('class', 'node-argument-list'); if (this._events.click) { this.element.addEventListener('click', (e) => { e.stopPropagation(); @@ -564,38 +566,7 @@ grapher.Node.List = class { this.element.appendChild(this.background); parent.appendChild(this.element); for (const item of this._items) { - const group = document.createElementNS('http://www.w3.org/2000/svg', 'g'); - group.setAttribute('class', 'node-attribute'); - const text = document.createElementNS('http://www.w3.org/2000/svg', 'text'); - text.setAttribute('xml:space', 'preserve'); - if (item.tooltip) { - const title = document.createElementNS('http://www.w3.org/2000/svg', 'title'); - title.textContent = item.tooltip; - text.appendChild(title); - } - const colon = item.type === 'node' || item.type === 'node[]'; - const name = document.createElementNS('http://www.w3.org/2000/svg', 'tspan'); - name.textContent = colon ? `${item.name}:` : item.name; - if (item.separator.trim() !== '=' && !colon) { - name.style.fontWeight = 'bold'; - } - text.appendChild(name); - group.appendChild(text); - this.element.appendChild(group); - item.group = group; - item.text = text; - if (item.type === 'node') { - const node = item.value; - node.build(document, item.group); - } else if (item.type === 'node[]') { - for (const node of item.value) { - node.build(document, item.group); - } - } else { - const tspan = document.createElementNS('http://www.w3.org/2000/svg', 'tspan'); - tspan.textContent = item.separator + item.value; - item.text.appendChild(tspan); - } + item.build(document, this.element); } if (!this.first) { this.line = document.createElementNS('http://www.w3.org/2000/svg', 'line'); @@ -697,19 +668,59 @@ grapher.Node.List = class { } }; -grapher.Node.List.Item = class { +grapher.Argument = class { - constructor(name, value, tooltip, separator) { + constructor(name, value) { this.name = name; this.value = value; - this.tooltip = tooltip; - this.separator = separator; if (value instanceof grapher.Node) { this.type = 'node'; } else if (Array.isArray(value) && value.every((value) => value instanceof grapher.Node)) { this.type = 'node[]'; } } + + build(document, parent) { + const group = document.createElementNS('http://www.w3.org/2000/svg', 'g'); + group.setAttribute('class', 'node-argument'); + const text = document.createElementNS('http://www.w3.org/2000/svg', 'text'); + text.setAttribute('xml:space', 'preserve'); + if (this.tooltip) { + const title = document.createElementNS('http://www.w3.org/2000/svg', 'title'); + title.textContent = this.tooltip; + text.appendChild(title); + } + const colon = this.type === 'node' || this.type === 'node[]'; + const name = document.createElementNS('http://www.w3.org/2000/svg', 'tspan'); + name.textContent = colon ? `${this.name}:` : this.name; + if (this.separator && this.separator.trim() !== '=' && !colon) { + name.style.fontWeight = 'bold'; + } + text.appendChild(name); + group.appendChild(text); + parent.appendChild(group); + this.group = group; + this.text = text; + switch (this.type) { + case 'node': { + const node = this.value; + node.build(document, this.group); + break; + } + case 'node[]': { + for (const node of this.value) { + node.build(document, this.group); + } + break; + } + default: { + const tspan = document.createElementNS('http://www.w3.org/2000/svg', 'tspan'); + tspan.textContent = (this.separator || '') + this.value; + this.text.appendChild(tspan); + break; + } + } + } }; grapher.Node.Canvas = class { @@ -947,4 +958,4 @@ grapher.Edge.Path = class { } }; -export const { Graph, Node, Edge } = grapher; +export const { Graph, Node, Edge, Argument } = grapher; diff --git a/source/pytorch.js b/source/pytorch.js index 5840e271e24..dea02a904e8 100644 --- a/source/pytorch.js +++ b/source/pytorch.js @@ -291,7 +291,7 @@ pytorch.Node = class { return type; }; const createAttribute = (metadata, name, value) => { - let visible = false; + let visible = true; let type = null; if (name === 'training') { visible = false; diff --git a/source/view.js b/source/view.js index 0f5de7e6ab4..d1568d40ee1 100644 --- a/source/view.js +++ b/source/view.js @@ -1719,6 +1719,7 @@ view.Graph = class extends grapher.Graph { this.options = options; this._nodeKey = 0; this._values = new Map(); + this._tensors = new Map(); this._table = new Map(); this._selection = new Set(); } @@ -1765,8 +1766,15 @@ view.Graph = class extends grapher.Graph { } createTensor(value) { - const obj = new view.Tensor(this, value); - this._table.set(value, obj); + if (this._tensors.has(value)) { + const obj = this._tensors.get(value); + this._table.set(value, obj); + } else { + const obj = new view.Tensor(this, value); + this._tensors.set(value, obj); + this._table.set(value, obj); + } + return this._tensors.get(value); } add(graph, signature) { @@ -1962,12 +1970,8 @@ view.Node = class extends grapher.Node { _add(node) { const options = this.context.options; const header = this.header(); - const styles = ['node-item-type']; const type = node.type; const category = type && type.category ? type.category : ''; - if (category) { - styles.push(`node-item-type-${category.toLowerCase()}`); - } if (typeof type.name !== 'string' || !type.name.split) { // #416 const error = new view.Error(`Unsupported node type '${JSON.stringify(type.name)}'.`); if (this.context.model && this.context.model.identifier) { @@ -1980,6 +1984,7 @@ view.Node = class extends grapher.Node { if (content.length > 24) { content = `${content.substring(0, 12)}\u2026${content.substring(content.length - 12, content.length)}`; } + const styles = category ? ['node-item-type', `node-item-type-${category.toLowerCase()}`] : ['node-item-type']; const title = header.add(null, styles, content, tooltip); title.on('click', () => { this.context.activate(node); @@ -1992,113 +1997,122 @@ view.Node = class extends grapher.Node { // this._expand = header.add(null, styles, '+', null); // this._expand.on('click', () => this.toggle()); } - const initializers = []; - let hiddenInitializers = false; - if (options.weights) { - if (Array.isArray(node.inputs)) { - for (const input of node.inputs) { - if (input.visible !== false && input.value.length === 1 && input.value[0].initializer) { - initializers.push(input); - } - if ((input.visible === false || input.value.length > 1) && - (!input.type || input.type.endsWith('*')) && input.value.some((value) => value.initializer)) { - hiddenInitializers = true; - } - } + let current = null; + const list = () => { + if (!current) { + current = this.list(); + current.on('click', () => this.context.activate(node)); } - } + return current; + }; + let hiddenTensors = false; + const tensors = []; const objects = []; const attributes = []; - if (Array.isArray(node.attributes) && node.attributes.length > 0) { - for (const attribute of node.attributes) { - switch (attribute.type) { + if (Array.isArray(node.inputs)) { + for (const input of node.inputs) { + switch (input.type) { case 'graph': case 'object': case 'object[]': case 'function': case 'function[]': { - objects.push(attribute); + objects.push(input); break; } default: { - if (options.attributes && attribute.visible !== false) { - attributes.push(attribute); + if (options.weights && input.visible !== false && input.value.length === 1 && input.value[0].initializer) { + tensors.push(input); + } else if (options.weights && (input.visible === false || input.value.length > 1) && (!input.type || input.type.endsWith('*')) && input.value.some((value) => value.initializer)) { + hiddenTensors = true; + } else if (options.attributes && input.visible !== false && input.type && !input.type.endsWith('*')) { + attributes.push(input); } } } } - attributes.sort((a, b) => a.name.toUpperCase().localeCompare(b.name.toUpperCase())); } - if (Array.isArray(node.inputs)) { - for (const input of node.inputs) { - switch (input.type) { + if (Array.isArray(node.attributes)) { + for (const attribute of node.attributes) { + switch (attribute.type) { case 'graph': case 'object': case 'object[]': case 'function': case 'function[]': { - objects.push(input); + objects.push(attribute); break; } default: { - break; + if (options.attributes && attribute.visible !== false) { + attributes.push(attribute); + } } } } } - if (initializers.length > 0 || hiddenInitializers || attributes.length > 0 || objects.length > 0) { - const list = this.list(); - list.on('click', () => this.context.activate(node)); - for (const argument of initializers) { - const [value] = argument.value; - const type = value.type; - let shape = ''; - let separator = ''; - if (type && type.shape && type.shape.dimensions && Array.isArray(type.shape.dimensions)) { - shape = `\u3008${type.shape.dimensions.map((d) => (d !== null && d !== undefined) ? d : '?').join('\u00D7')}\u3009`; - if (type.shape.dimensions.length === 0 && value.initializer) { - try { - const tensor = new base.Tensor(value.initializer); - const encoding = tensor.encoding; - if ((encoding === '<' || encoding === '>' || encoding === '|') && !tensor.empty && tensor.type.dataType !== '?') { - shape = tensor.toString(); - if (shape && shape.length > 10) { - shape = `${shape.substring(0, 10)}\u2026`; - } - separator = ' = '; + if (attributes.length > 0) { + attributes.sort((a, b) => a.name.toUpperCase().localeCompare(b.name.toUpperCase())); + } + for (const argument of tensors) { + const [value] = argument.value; + const type = value.type; + let shape = ''; + let separator = ''; + if (type && type.shape && type.shape.dimensions && Array.isArray(type.shape.dimensions)) { + shape = `\u3008${type.shape.dimensions.map((d) => (d !== null && d !== undefined) ? d : '?').join('\u00D7')}\u3009`; + if (type.shape.dimensions.length === 0 && value.initializer) { + try { + const tensor = new base.Tensor(value.initializer); + const encoding = tensor.encoding; + if ((encoding === '<' || encoding === '>' || encoding === '|') && !tensor.empty && tensor.type.dataType !== '?') { + shape = tensor.toString(); + if (shape && shape.length > 10) { + shape = `${shape.substring(0, 10)}\u2026`; } - } catch (error) { - this.context.view.exception(error, false); + separator = ' = '; } + } catch (error) { + this.context.view.exception(error, false); } } - list.add(argument.name, shape, type ? type.toString() : '', separator); } - if (hiddenInitializers) { - list.add('\u3008\u2026\u3009', '', null, ''); - } - for (const attribute of attributes) { - if (attribute.visible !== false) { - let value = new view.Formatter(attribute.value, attribute.type).toString(); - if (value && value.length > 25) { - value = `${value.substring(0, 25)}\u2026`; - } - list.add(attribute.name, value, attribute.type, ' = '); + const item = list().argument(argument.name, shape); + item.tooltip = type ? type.toString() : ''; + item.separator = separator; + list().add(item); + } + if (hiddenTensors) { + const item = list().argument('\u3008\u2026\u3009', ''); + list().add(item); + } + for (const attribute of attributes) { + if (attribute.visible !== false) { + let value = new view.Formatter(attribute.value, attribute.type).toString(); + if (value && value.length > 25) { + value = `${value.substring(0, 25)}\u2026`; } + const item = list().argument(attribute.name, value); + item.tooltip = attribute.type; + item.separator = ' = '; + list().add(item); } - for (const attribute of objects) { - if (attribute.type === 'graph') { - const node = this.context.createNode(null, attribute.value); - list.add(attribute.name, node, '', ''); - } - if (attribute.type === 'function' || attribute.type === 'object') { - const node = this.context.createNode(attribute.value); - list.add(attribute.name, node, '', ''); - } - if (attribute.type === 'function[]' || attribute.type === 'object[]') { - const nodes = attribute.value.map((value) => this.context.createNode(value)); - list.add(attribute.name, nodes, '', ''); - } + } + for (const argument of objects) { + if (argument.type === 'graph') { + const node = this.context.createNode(null, argument.value); + const item = list().argument(argument.name, node); + list().add(item); + } + if (argument.type === 'function' || argument.type === 'object') { + const node = this.context.createNode(argument.value); + const item = list().argument(argument.name, node); + list().add(item); + } + if (argument.type === 'function[]' || argument.type === 'object[]') { + const nodes = argument.value.map((value) => this.context.createNode(value)); + const item = list().argument(argument.name, nodes); + list().add(item); } } if (Array.isArray(node.nodes) && node.nodes.length > 0) {