Skip to content

Commit

Permalink
WIP DVS - DVS node not given
Browse files Browse the repository at this point in the history
dvs_input (bool) arg passed down to collect_dynapcnn_layer_info() to raise error if edge type involving DVS is not found when dvs_input == True.
  • Loading branch information
Willian-Girao committed Oct 28, 2024
1 parent cb63e6f commit 49c013a
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 9 deletions.
4 changes: 2 additions & 2 deletions sinabs/backend/dynapcnn/dynapcnn_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,15 +69,15 @@ def __init__(
batch_size = sinabs.utils.get_smallest_compatible_time_dimension(snn)
# computational graph from original PyTorch module.
self._graph_extractor = GraphExtractor(
snn, torch.randn((batch_size, *self.input_shape))
snn, torch.randn((batch_size, *self.input_shape), self.dvs_input)
) # needs the batch dimension.

# Remove nodes of ignored classes (including merge nodes)
self._graph_extractor.remove_nodes_by_class(DEFAULT_IGNORED_LAYER_TYPES)

# Module to execute forward pass through network
self._dynapcnn_module = self._graph_extractor.get_dynapcnn_network_module(
discretize=discretize, weight_rescaling_fn=weight_rescaling_fn
discretize=discretize, weight_rescaling_fn=weight_rescaling_fn, dvs_input=self.dvs_input
)
self._dynapcnn_module.setup_dynapcnnlayer_graph(index_layers_topologically=True)

Expand Down
10 changes: 7 additions & 3 deletions sinabs/backend/dynapcnn/nir_graph_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def indx_2_module_map(self) -> Dict[int, nn.Module]:
return {n: module for n, module in self._indx_2_module_map.items()}

def get_dynapcnn_network_module(
self, discretize: bool = False, weight_rescaling_fn: Optional[Callable] = None
self, discretize: bool = False, weight_rescaling_fn: Optional[Callable] = None, dvs_input: bool = False
) -> DynapcnnNetworkModule:
""" Create DynapcnnNetworkModule based on stored graph representation
Expand All @@ -117,6 +117,7 @@ def get_dynapcnn_network_module(
weights to dynapcnn. Set to `False` only for testing purposes.
- weight_rescaling_fn (callable): a method that handles how the re-scaling factor for one or more `SumPool2d` projecting to
the same convolutional layer are combined/re-scaled before applying them.
- dvs_input (bool): wether or not dynapcnn receive input from its DVS camera.
Returns
-------
Expand All @@ -129,6 +130,7 @@ def get_dynapcnn_network_module(
edges = self.edges,
nodes_io_shapes=self.nodes_io_shapes,
entry_nodes=self.entry_nodes,
dvs_input=dvs_input,
)

# build `DynapcnnLayer` instances from mapper.
Expand Down Expand Up @@ -236,8 +238,8 @@ def _add_dvs_node(self, dvs_input_shape: Tuple[int, int]) -> None:
----------
- dvs_input_shape (tuple): Shape of input in format `(height, width)`.
"""
# @TODO - not considering pooling after the DVSLayer yet.
# @TODO - does self._entry_nodes need to have the index of the DVS node?

# [] @TODO - not considering pooling after the DVSLayer yet.

# add name entry for node 'dvs'.
self._name_2_indx_map['dvs'] = len(self._name_2_indx_map)
Expand All @@ -246,6 +248,8 @@ def _add_dvs_node(self, dvs_input_shape: Tuple[int, int]) -> None:
# set DVS node as input to each entry node of the graph.
self._edges.update({(self._name_2_indx_map['dvs'], entry_node) for entry_node in self._entry_nodes})

# [] @TODO - all indexes in 'self._entry_nodes' are no longer entry nodes of the network since a DVS layer is being added.

def _need_dvs_node(self, model: nn.Module, dvs_input: bool) -> bool:
""" Returns whether or not a node will need to be added to represent a `DVSLayer` instance. A new node will have
to be added if `model` does not start with a `DVSLayer` instance and `dvs_input == True`.
Expand Down
15 changes: 11 additions & 4 deletions sinabs/backend/dynapcnn/sinabs_edges_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def collect_dynapcnn_layer_info(
edges: Set[Edge],
nodes_io_shapes: Dict[int, Dict[str, Tuple[Size, Size]]],
entry_nodes: Set[int],
dvs_input: bool,
) -> Dict[int, Dict]:
"""Collect information to construct DynapcnnLayer instances.
Expand All @@ -37,10 +38,11 @@ def collect_dynapcnn_layer_info(
Parameters
----------
indx_2_module_map (dict): Maps node IDs of the graph as `key` to their associated module as `value`
edges (set of tuples): Represent connections between two nodes in computational graph
nodes_io_shapes (dict): Map from node ID to dict containing node's in- and output shapes
entry_nodes (set of int): IDs of nodes that receive external input
- indx_2_module_map (dict): Maps node IDs of the graph as `key` to their associated module as `value`
- edges (set of tuples): Represent connections between two nodes in computational graph
- nodes_io_shapes (dict): Map from node ID to dict containing node's in- and output shapes
- entry_nodes (set of int): IDs of nodes that receive external input
- dvs_input (bool): wether or not dynapcnn receive input from its DVS camera.
Returns
-------
Expand All @@ -62,6 +64,11 @@ def collect_dynapcnn_layer_info(
"None such weight-neuron pair has been found in the provided network."
)

if not any(edge in edges_by_type for edge in ["dvs-weight", "dvs-pooling"]) and dvs_input:
raise InvalidGraphStructure(
"DVS camera is set selected for usage (dvs_input == True) but edge type involving it has not been found."
)

# Dict to collect information for each future dynapcnn layer
dynapcnn_layer_info = dict()
# Map node IDs to dynapcnn layer ID
Expand Down

0 comments on commit 49c013a

Please sign in to comment.