Skip to content

Commit

Permalink
WIP DVS - DVS node not given
Browse files Browse the repository at this point in the history
modifying graph extractor's name 2 index / index to module mapping and edges (in-place) to add entry for DVS node
  • Loading branch information
Willian-Girao committed Oct 28, 2024
1 parent e9bf8a1 commit d0ae8a5
Showing 1 changed file with 33 additions and 34 deletions.
67 changes: 33 additions & 34 deletions sinabs/backend/dynapcnn/nir_graph_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,22 +57,20 @@ def __init__(self, spiking_model: nn.Module, dummy_input: torch.tensor, dvs_inpu
spiking_model, dummy_input, model_name=None
).ignore_tensors()

# This var. will be set to `True` if `dvs_input == True` and `spiking_model` does not start with DVS layer.
need_dvs_node = self._need_dvs_node(spiking_model, dvs_input)
dvs_input_shape = None
if need_dvs_node:
# We need to provide `(height, width)` to the DVSLayer instance that will be the module of the node 'dvs'.
_, _, height, width = dummy_input.shape
dvs_input_shape = (height, width)

# Map node names to indices
self._name_2_indx_map = self._get_name_2_indx_map(nir_graph, need_dvs_node)
self._name_2_indx_map = self._get_name_2_indx_map(nir_graph)
# Extract edges list from graph
self._edges = self._get_edges_from_nir(nir_graph, self._name_2_indx_map) # @TODO edges need to be modified in place if DVS layer is needed.
self._edges = self._get_edges_from_nir(nir_graph, self._name_2_indx_map)
# Determine entry points to graph
self._entry_nodes = self._get_entry_nodes(self._edges) # @TODO maybe functionality has to change here a when DVS layer is needed.
self._entry_nodes = self._get_entry_nodes(self._edges)
# Store the associated `nn.Module` (layer) of each node.
self._indx_2_module_map = self._get_named_modules(spiking_model, need_dvs_node, dvs_input_shape)
self._indx_2_module_map = self._get_named_modules(spiking_model)

# True if `dvs_input == True` and `spiking_model` does not start with DVS layer.
if self._need_dvs_node(spiking_model, dvs_input):
# input shape for `DVSLayer` instance that will be the module of the node 'dvs'.
_, _, height, width = dummy_input.shape
self._add_dvs_node(dvs_input_shape=(height, width))

# Verify that graph is compatible
self.verify_graph_integrity()
Expand Down Expand Up @@ -230,6 +228,24 @@ def verify_graph_integrity(self):

####################################################### Pivate Methods #######################################################

def _add_dvs_node(self, dvs_input_shape: Tuple[int, int]) -> None:
""" In-place modification of `self._name_2_indx_map`, `self._indx_2_module_map`, and `self._edges` to accomodate the
creation of an extra node in the graph representing the DVS camera of the chip.
Parameters
----------
- dvs_input_shape (tuple): Shape of input in format `(height, width)`.
"""
# @TODO - not considering pooling after the DVSLayer yet.
# @TODO - does self._entry_nodes need to have the index of the DVS node?

# add name entry for node 'dvs'.
self._name_2_indx_map['dvs'] = len(self._name_2_indx_map)
# add module entry for node 'dvs'.
self._indx_2_module_map[self._name_2_indx_map['dvs']] = DVSLayer(input_shape=dvs_input_shape)
# set DVS node as input to each entry node of the graph.
self._edges.update({(self._name_2_indx_map['dvs'], entry_node) for entry_node in self._entry_nodes})

def _need_dvs_node(self, model: nn.Module, dvs_input: bool) -> bool:
""" Returns whether or not a node will need to be added to represent a `DVSLayer` instance. A new node will have
to be added if `model` does not start with a `DVSLayer` instance and `dvs_input == True`.
Expand All @@ -252,14 +268,12 @@ def _need_dvs_node(self, model: nn.Module, dvs_input: bool) -> bool:

return not isinstance(first_module, DVSLayer) and dvs_input

def _get_name_2_indx_map(self, nir_graph: nirtorch.graph.Graph, need_dvs_node: bool) -> Dict[str, int]:
"""Assign unique index to each node and return mapper from name to index. If `need_dvs_node == Ture` we want to
leave index `0` free to be assigned to the `DVSLayer` node that will have to be created.
def _get_name_2_indx_map(self, nir_graph: nirtorch.graph.Graph) -> Dict[str, int]:
"""Assign unique index to each node and return mapper from name to index.
Parameters
----------
- nir_graph (nirtorch.graph.Graph): a NIR graph representation of `spiking_model`.
- need_dvs_node (bool): True of `dvs_input == True` and `spiking_model` doesn't start with a `DVSLayer`.
Returns
----------
Expand All @@ -268,17 +282,10 @@ def _get_name_2_indx_map(self, nir_graph: nirtorch.graph.Graph, need_dvs_node: b
"""

# Start name indexing from 1 if a DVS node needs to be added
name_2_indx_map = {
node.name: (node_idx + 1 if need_dvs_node else node_idx)
for node_idx, node in enumerate(nir_graph.node_list)
return {
node.name: node_idx for node_idx, node in enumerate(nir_graph.node_list)
}

if need_dvs_node:
# Adds entry for the DVS node that needs to be created - default node name is 'dvs'
name_2_indx_map['dvs'] = 0

return name_2_indx_map

def _get_edges_from_nir(
self, nir_graph: nirtorch.graph.Graph, name_2_indx_map: Dict[str, int]
) -> Set[Edge]:
Expand Down Expand Up @@ -317,22 +324,18 @@ def _get_entry_nodes(self, edges: Set[Edge]) -> Set[Edge]:
all_sources, all_targets = zip(*edges)
return set(all_sources) - set(all_targets)

def _get_named_modules(self, model: nn.Module, need_dvs_node: bool, dvs_input_shape: Tuple[int, int]) -> Dict[int, nn.Module]:
def _get_named_modules(self, model: nn.Module) -> Dict[int, nn.Module]:
"""Find for each node in the graph what its associated layer in `model` is.
Parameters
----------
- model (nn.Module): the `spiking_model` used as argument to the class instance.
- need_dvs_node (bool): True of `dvs_input == True` and `spiking_model` doesn't start with a `DVSLayer`.
- dvs_input_shape (tuple): Shape of input in format `(height, width)`.
Returns
----------
- indx_2_module_map (dict): the mapping between a node (`key` as an `int`) and its module (`value` as a `nn.Module`).
"""

assert need_dvs_node and isinstance(dvs_input_shape, tuple), f"DVSLayer instantiation is needed but 'dvs_input_shape == {dvs_input_shape}'."

indx_2_module_map = dict()

for name, module in model.named_modules():
Expand All @@ -341,10 +344,6 @@ def _get_named_modules(self, model: nn.Module, need_dvs_node: bool, dvs_input_sh
if name in self._name_2_indx_map:
indx_2_module_map[self._name_2_indx_map[name]] = module

if need_dvs_node:
# Adds an entry for the `DVSLayer` node that is needed - default node name is 'dvs'
indx_2_module_map[self._name_2_indx_map['dvs']] = DVSLayer(input_shape=dvs_input_shape)

return indx_2_module_map

def _update_internal_representation(self, remapped_nodes: Dict[int, int]):
Expand Down

0 comments on commit d0ae8a5

Please sign in to comment.