Skip to content

Commit

Permalink
fix: dataloader
Browse files Browse the repository at this point in the history
  • Loading branch information
eri24816 committed Jun 14, 2024
1 parent d106ef7 commit 46044ba
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 2 deletions.
14 changes: 13 additions & 1 deletion extensions/grapycal_torch/grapycal_torch/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from grapycal.stores import main_store
from grapycal_torch.store import GrapycalTorchStore
from topicsync.topic import GenericTopic
from torch import Tensor
from torch.utils.data import DataLoader


Expand Down Expand Up @@ -56,7 +57,18 @@ def task(self):
)
return
if self.to_defalut_device.get():
batch = self.get_store(GrapycalTorchStore).to_default_device(batch)
if isinstance(batch, dict):
for key in batch:
if isinstance(batch[key], Tensor):
batch[key] = self.get_store(
GrapycalTorchStore
).to_default_device(batch[key])
elif isinstance(batch, Tensor):
batch = self.get_store(GrapycalTorchStore).to_default_device(
batch
)
else:
pass
self.out.push(batch)
self.flash_running_indicator()
yield
Expand Down
2 changes: 1 addition & 1 deletion extensions/grapycal_torch/grapycal_torch/moduleNode.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def destroy(self):

class SimpleModuleNode(ModuleNode):
module_type: type[nn.Module] = nn.Module
inputs: list[str] = ["input"]
inputs: list[str] = []
max_in_degree = [1]
outputs = ["output"]
display_port_names: bool | None = None
Expand Down

0 comments on commit 46044ba

Please sign in to comment.