Skip to content

Commit

Permalink
Reintroduce missing methods of DynapcnnNetwork: reset_states, `zero…
Browse files Browse the repository at this point in the history
…_grad`
  • Loading branch information
bauerfe committed Nov 12, 2024
1 parent 2680ac1 commit 1cb7b5c
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 5 deletions.
2 changes: 1 addition & 1 deletion sinabs/backend/dynapcnn/dynapcnn_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def forward(self, x) -> List[torch.Tensor]:

def zero_grad(self, set_to_none: bool = False) -> None:
"""Call `zero_grad` method of spiking layer"""
return self._spk.zero_grad(set_to_none)
return self.spk.zero_grad(set_to_none)

def get_neuron_shape(self) -> Tuple[int, int, int]:
"""Return the output shape of the neuron layer.
Expand Down
70 changes: 66 additions & 4 deletions sinabs/backend/dynapcnn/dynapcnn_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,20 @@ def __init__(
- batch_size (optional int): If `None`, will try to infer the batch size from the model.
If int value is provided, it has to match the actual batch size of the model.
- dvs_input (bool): optional (default as `None`). Wether or not dynapcnn receive
input from its DVS camera. If a `DVSLayer` is part of `snn` and `dvs_input` is
false, the DVS sensor will be configured but its output will not be sent as input
to the chip. If `dvs_input` is `True` and `snn` does not contain a `DVSLayer`,
it will be added.
input from its DVS camera.
If a `DVSLayer` is part of `snn`...
... and `dvs_input` is `False`, its `disable_pixel_array` attribute
will be set `True`. This means the DVS sensor will be configured
upon deployment but its output will not be sent as input
... and `dvs_input` is `None`, the `disable_pixel_array` attribute
of the layer will not be changed.
... and `dvs_input` is `True`, `disable_pixel_array` will be set
`False`, so that the DVS sensor data is sent to the network.
If no `DVSLayer` is part of `snn`...
... and `dvs_input` is `False` or `None`, no `DVSLayer` will be added
and the DVS sensor will not be configured upon deployment.
... and `dvs_input` is `True`, a `DVSLayer` instance will be added
to the network, with `disable_pixel_array` set to `False`.
- discretize (bool): If `True`, discretize the parameters and thresholds. This is needed for uploading
weights to dynapcnn. Set to `False` only for testing purposes.
- weight_rescaling_fn (callable): a method that handles how the re-scaling factor for one or more `SumPool2d` projecting to
Expand Down Expand Up @@ -514,6 +524,58 @@ def has_dvs_layer(self) -> bool:
"""
return self.dvs_layer is not None

def zero_grad(self, set_to_none: bool = False) -> None:
""" Call `zero_grad` method of each DynapCNN layer
Parameters
----------
- set_to_none (bool): This argument is passed directly to the
`zero_grad` method of each DynapCNN layer
"""
for lyr in self.dynapcnn_layers.values():
lyr.zero_grad(set_to_none)

def reset_states(self, randomize=False):
"""Reset the states of the network.
Parameters
----------
- randomize (bool): If `False` (default), will set all states to 0.
Otherwise will set to random values.
Notes
-----
- Setting `randomize` to `True` is only supported for models that have
not yet been deployed on a SynSense device.
"""
if hasattr(self, "device") and isinstance(self.device, str): # pragma: no cover
device_name, _ = parse_device_id(self.device)
# Reset states on SynSense device
if device_name in ChipFactory.supported_devices:
config_builder = ChipFactory(self.device).get_config_builder()
# Set all the vmem states in the samna config to zero
config_builder.reset_states(self.samna_config, randomize=randomize)
self.samna_device.get_model().apply_configuration(self.samna_config)
# wait for the config to be written
time.sleep(1)
# Note: The below shouldn't be necessary ideally
# Erase all vmem memory
if not randomize:
if hasattr(self, "samna_input_graph"):
self.samna_input_graph.stop()
for lyr_idx in self.chip_layers_ordering:
config_builder.set_all_v_mem_to_zeros(
self.samna_device, lyr_idx
)
time.sleep(0.1)
self.samna_input_graph.start()
return

# Reset states of `DynapcnnLayer` instances
for layer in self.sequence:
if isinstance(layer, DynapcnnLayer):
layer.spk_layer.reset_states(randomize=randomize)

####################################################### Private Methods #######################################################

def _make_config(
Expand Down

0 comments on commit 1cb7b5c

Please sign in to comment.