Skip to content

Commit

Permalink
WIP DVS - DVS node not given
Browse files Browse the repository at this point in the history
moved get_entry_nodes() to after the DVS node creating has been handled..
  • Loading branch information
Willian-Girao committed Oct 28, 2024
1 parent 5366513 commit 460e7c6
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions sinabs/backend/dynapcnn/nir_graph_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,6 @@ def __init__(self, spiking_model: nn.Module, dummy_input: torch.tensor, dvs_inpu
self._name_2_indx_map = self._get_name_2_indx_map(nir_graph)
# Extract edges list from graph
self._edges = self._get_edges_from_nir(nir_graph, self._name_2_indx_map)
# Determine entry points to graph
self._entry_nodes = self._get_entry_nodes(self._edges)
# Store the associated `nn.Module` (layer) of each node.
self._indx_2_module_map = self._get_named_modules(spiking_model)

Expand All @@ -72,6 +70,9 @@ def __init__(self, spiking_model: nn.Module, dummy_input: torch.tensor, dvs_inpu
_, _, height, width = dummy_input.shape
self._add_dvs_node(dvs_input_shape=(height, width))

# Determine entry points to graph
self._entry_nodes = self._get_entry_nodes(self._edges)

# Verify that graph is compatible
self.verify_graph_integrity()

Expand Down Expand Up @@ -240,6 +241,7 @@ def _add_dvs_node(self, dvs_input_shape: Tuple[int, int]) -> None:
"""

# [] @TODO - not considering pooling after the DVSLayer yet.
# [] @TODO - I/O shape in 'self._nodes_io_shapes' not being handled yet.

# add name entry for node 'dvs'.
self._name_2_indx_map['dvs'] = len(self._name_2_indx_map)
Expand All @@ -248,8 +250,6 @@ def _add_dvs_node(self, dvs_input_shape: Tuple[int, int]) -> None:
# set DVS node as input to each entry node of the graph.
self._edges.update({(self._name_2_indx_map['dvs'], entry_node) for entry_node in self._entry_nodes})

# [] @TODO - all indexes in 'self._entry_nodes' are no longer entry nodes of the network since a DVS layer is being added.

def _need_dvs_node(self, model: nn.Module, dvs_input: bool) -> bool:
""" Returns whether or not a node will need to be added to represent a `DVSLayer` instance. A new node will have
to be added if `model` does not start with a `DVSLayer` instance and `dvs_input == True`.
Expand Down

0 comments on commit 460e7c6

Please sign in to comment.