Skip to content

Commit

Permalink
feat: add metrics.svg
Browse files Browse the repository at this point in the history
  • Loading branch information
eri24816 committed Apr 23, 2024
1 parent aa56219 commit 89ac766
Show file tree
Hide file tree
Showing 8 changed files with 32 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,10 @@ def render_from_fig(
return buf

@contextmanager
def open_fig() -> Generator[Tuple[Figure, Axes], None, None]:
def open_fig(equal=False) -> Generator[Tuple[Figure, Axes], None, None]:
fig = plt.figure()
ax = fig.gca()
ax.set_facecolor("black")
ax.set_aspect("equal")
try:
yield fig, ax
finally:
Expand Down Expand Up @@ -153,7 +152,7 @@ def build_node(self):
self.cmap = self.add_attribute(
"cmap",
StringTopic,
"gray",
"viridis",
editor_type="options",
options=[
"gray",
Expand Down Expand Up @@ -412,7 +411,7 @@ def preprocess_data(self, data):
return data

def update_image(self, data):
with open_fig() as (fig, ax):
with open_fig(equal=True) as (fig, ax):
for d in data:
if len(d.shape) == 3:
for slice in d:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,18 @@
import io
import pprint

from grapycal.sobjects.controls import TextControl
from grapycal.sobjects.edge import Edge
from grapycal.sobjects.node import RESTORE_FROM, Node
from grapycal.sobjects.node import Node
from grapycal.sobjects.port import InputPort


def get_pprint_str(data):
output = io.StringIO(newline="")
pprint.pprint(data, stream=output)
return output.getvalue()


class PrintNode(Node):
'''
Display the data received from the input edge.
Expand All @@ -24,13 +33,13 @@ def build_node(self):
def edge_activated(self, edge, port):
self.flash_running_indicator()
data = edge.get()
self.text_control.text.set(str(data))
self.text_control.text.set(get_pprint_str(data))

def input_edge_added(self, edge: Edge, port: InputPort):
if edge.is_data_ready():
self.flash_running_indicator()
data = edge.get()
self.text_control.text.set(str(data))
self.text_control.text.set(get_pprint_str(data))

def input_edge_removed(self, edge: Edge, port: InputPort):
self.text_control.text.set('')
5 changes: 3 additions & 2 deletions extensions/grapycal_torch/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,10 @@ def task(self):

ds = []
for i in range(size):
ds.append(raw_ds[i])
pair = raw_ds[i]
ds.append({'image': pair[0], 'label': pair[1]})

if self.include_labels.get() == 'False':
ds = [x[0] for x in ds]
ds = [x['image'] for x in ds]

self.out.push(ds)
3 changes: 3 additions & 0 deletions extensions/grapycal_torch/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ class BCEWithLogitsLossNode(SimpleModuleNode):
def build_node(self):
super().build_node()
self.label.set('BCEWithLogitsLoss')
self.icon_path.set('metrics')

def create_module(self) -> nn.Module:
return nn.BCEWithLogitsLoss()
Expand All @@ -29,6 +30,7 @@ class CrossEntropyLossNode(SimpleModuleNode):
def build_node(self):
super().build_node()
self.label.set('CrossEntropyLoss')
self.icon_path.set('metrics')

def create_module(self) -> nn.Module:
return nn.CrossEntropyLoss()
Expand All @@ -47,6 +49,7 @@ class MSELossNode(SimpleModuleNode):
def build_node(self):
super().build_node()
self.label.set('MSELoss')
self.icon_path.set('metrics')

def create_module(self) -> nn.Module:
return nn.MSELoss()
Expand Down
1 change: 1 addition & 0 deletions extensions/grapycal_torch/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ def build_node(self):
super().build_node()
self.label.set('Accuracy')
self.shape.set('simple')
self.icon_path.set('metrics')

def calculate(self, prediction, target) -> Any:
#TODO: if target is one-hot encoded, convert it to class labels
Expand Down
3 changes: 3 additions & 0 deletions extensions/grapycal_torch/networkDef.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ class NetworkCallNode(Node):
def build_node(self,name:str="Network"):
self.label.set('')
self.shape.set('normal')
self.icon_path.set('nn')
self.network_name = self.add_attribute('network name',StringTopic,editor_type='text',init_value=name)
self.network_name.add_validator(lambda x,_: x != '') # empty name may confuse users
self.mode_control = self.add_option_control(name='mode',options=['train','eval'],value='train',label='Mode')
Expand Down Expand Up @@ -142,6 +143,7 @@ def build_node(self,name:str="Network",inputs:List[str]|None=None):
inputs = ['x']

self.shape.set('normal')
self.icon_path.set('nn')

# setup attributes
# The self.outs attribute is actually "inputs" of the network, but it was mistakenly named "outs" and I didn't want to change it to avoid breaking backwards compatibility
Expand Down Expand Up @@ -229,6 +231,7 @@ def build_node(self,name:str="Network",outputs:List[str]|None=None):
if outputs is None:
outputs = ['y']
self.shape.set('normal')
self.icon_path.set('nn')

# setup attributes
self.ins = self.add_attribute('ins',ListTopic,editor_type='list',init_value=outputs,display_name='outputs')
Expand Down
11 changes: 6 additions & 5 deletions extensions/grapycal_torch/tensor_operations.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from ast import pattern
from typing import Any
from grapycal import FunctionNode, IntTopic, StringTopic
from grapycal.extension.utils import NodeInfo
from grapycal.sobjects.controls.textControl import TextControl

import einops
import torch
import torch.nn.functional as F
import einops
from grapycal import FunctionNode, IntTopic
from grapycal.extension.utils import NodeInfo
from grapycal.sobjects.controls.textControl import TextControl


class CatNode(FunctionNode):
category = 'torch/operations'
Expand Down
2 changes: 1 addition & 1 deletion frontend/dist/svg
Submodule svg updated from 6dcae2 to fc39ca

0 comments on commit 89ac766

Please sign in to comment.