Skip to content

Commit

Permalink
Merge pull request #201 from NeuralEnsemble/feat-use-caching
Browse files Browse the repository at this point in the history
Feat: use caching for methods that return the same value each time
  • Loading branch information
sanjayankur31 authored Sep 11, 2024
2 parents e00f20e + b7f76df commit 4baafdd
Show file tree
Hide file tree
Showing 6 changed files with 41 additions and 2 deletions.
1 change: 1 addition & 0 deletions neuroml/nml/gds_imports-template.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import inspect
import math
import typing
from functools import cached_property, lru_cache
from math import pi, sqrt
from operator import attrgetter

Expand Down
1 change: 1 addition & 0 deletions neuroml/nml/generatedssupersuper.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,6 +535,7 @@ def parentinfo(self, return_format="string"):
"GdsCollector_",
"GeneratedsSuperSuper",
"attrgetter",
"cached_property",
]

# do not show parameters here, they are indicated by members below
Expand Down
15 changes: 15 additions & 0 deletions neuroml/nml/helper_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -911,6 +911,7 @@ def __str__(self):
source='''\
# Get segment object by its id
@lru_cache(maxsize=1000)
def get_segment(self, segment_id: int) -> Segment:
"""Get segment object by its id
Expand Down Expand Up @@ -947,6 +948,7 @@ def get_segments_by_substring(self, substring: str) -> typing.Dict[str, Segment]
# Get the proximal point of a segment, even the proximal field is None and
# so the proximal point is on the parent (at a point set by fraction_along)
@lru_cache(maxsize=1000)
def get_actual_proximal(self, segment_id: str):
"""Get the proximal point of a segment.
Expand Down Expand Up @@ -976,6 +978,7 @@ def get_actual_proximal(self, segment_id: str):
return p
@lru_cache(maxsize=1000)
def get_segment_length(self, segment_id: str) -> float:
"""Get the length of the segment.
Expand All @@ -993,6 +996,7 @@ def get_segment_length(self, segment_id: str) -> float:
return length
@lru_cache(maxsize=1000)
def get_segment_surface_area(self, segment_id: str) -> float:
"""Get the surface area of the segment.
Expand All @@ -1010,6 +1014,7 @@ def get_segment_surface_area(self, segment_id: str) -> float:
return temp_seg.surface_area
@lru_cache(maxsize=1000)
def get_segment_volume(self, segment_id: str) -> float:
"""Get volume of segment
Expand All @@ -1029,6 +1034,14 @@ def get_segment_volume(self, segment_id: str) -> float:
def get_segment_ids_vs_segments(self) -> typing.Dict[str, Segment]:
"""Get a dictionary of segment IDs and the segments in the cell.
:return: dictionary with segment ID as key, and segment as value
"""
return self.segment_ids_vs_segments
@cached_property
def segment_ids_vs_segments(self) -> typing.Dict[str, Segment]:
"""Get a dictionary of segment IDs and the segments in the cell.
:return: dictionary with segment ID as key, and segment as value
"""
Expand All @@ -1038,6 +1051,7 @@ def get_segment_ids_vs_segments(self) -> typing.Dict[str, Segment]:
return segments
@lru_cache(maxsize=1000)
def get_all_segments_in_group(self,
segment_group: typing.Union[str, SegmentGroup],
assume_all_means_all: bool = True) -> typing.List[int]:
Expand Down Expand Up @@ -1235,6 +1249,7 @@ def get_ordered_segments_in_groups(self,
return ord_segs
@lru_cache(maxsize=1000)
def get_segment_group(self, sg_id: str) -> SegmentGroup:
"""Return the SegmentGroup object for the specified segment group id.
Expand Down
20 changes: 18 additions & 2 deletions neuroml/nml/nml.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
# -*- coding: utf-8 -*-

#
# Generated Tue Aug 20 10:23:13 2024 by generateDS.py version 2.44.1.
# Python 3.11.9 (main, Apr 17 2024, 00:00:00) [GCC 14.0.1 20240411 (Red Hat 14.0.1-0)]
# Generated Wed Sep 11 16:30:43 2024 by generateDS.py version 2.44.1.
# Python 3.11.9 (main, Aug 23 2024, 00:00:00) [GCC 14.2.1 20240801 (Red Hat 14.2.1-1)]
#
# Command line options:
# ('-o', 'nml.py')
Expand Down Expand Up @@ -36,6 +36,7 @@
import os
import re as re_
import typing
from functools import cached_property, lru_cache
from math import pi, sqrt
from operator import attrgetter

Expand Down Expand Up @@ -48429,6 +48430,7 @@ def _buildChildren(
super(Cell, self)._buildChildren(child_, node, nodeName_, True)

# Get segment object by its id
@lru_cache(maxsize=1000)
def get_segment(self, segment_id: int) -> Segment:
"""Get segment object by its id

Expand Down Expand Up @@ -48471,6 +48473,7 @@ def get_segments_by_substring(self, substring: str) -> typing.Dict[str, Segment]

# Get the proximal point of a segment, even the proximal field is None and
# so the proximal point is on the parent (at a point set by fraction_along)
@lru_cache(maxsize=1000)
def get_actual_proximal(self, segment_id: str):
"""Get the proximal point of a segment.

Expand Down Expand Up @@ -48504,6 +48507,7 @@ def get_actual_proximal(self, segment_id: str):

return p

@lru_cache(maxsize=1000)
def get_segment_length(self, segment_id: str) -> float:
"""Get the length of the segment.

Expand All @@ -48521,6 +48525,7 @@ def get_segment_length(self, segment_id: str) -> float:

return length

@lru_cache(maxsize=1000)
def get_segment_surface_area(self, segment_id: str) -> float:
"""Get the surface area of the segment.

Expand All @@ -48538,6 +48543,7 @@ def get_segment_surface_area(self, segment_id: str) -> float:

return temp_seg.surface_area

@lru_cache(maxsize=1000)
def get_segment_volume(self, segment_id: str) -> float:
"""Get volume of segment

Expand All @@ -48557,6 +48563,14 @@ def get_segment_volume(self, segment_id: str) -> float:
def get_segment_ids_vs_segments(self) -> typing.Dict[str, Segment]:
"""Get a dictionary of segment IDs and the segments in the cell.

:return: dictionary with segment ID as key, and segment as value
"""
return self.segment_ids_vs_segments

@cached_property
def segment_ids_vs_segments(self) -> typing.Dict[str, Segment]:
"""Get a dictionary of segment IDs and the segments in the cell.

:return: dictionary with segment ID as key, and segment as value
"""

Expand All @@ -48566,6 +48580,7 @@ def get_segment_ids_vs_segments(self) -> typing.Dict[str, Segment]:

return segments

@lru_cache(maxsize=1000)
def get_all_segments_in_group(
self,
segment_group: typing.Union[str, SegmentGroup],
Expand Down Expand Up @@ -48773,6 +48788,7 @@ def get_ordered_segments_in_groups(

return ord_segs

@lru_cache(maxsize=1000)
def get_segment_group(self, sg_id: str) -> SegmentGroup:
"""Return the SegmentGroup object for the specified segment group id.

Expand Down
1 change: 1 addition & 0 deletions neuroml/nml/regenerate-nml.sh
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ reformat () {
if command -v ruff > /dev/null 2>&1
then
echo "Formatting new nml.py with ruff"
ruff check --select I --fix nml.py
ruff format nml.py
else
echo "ruff is not installed"
Expand Down
5 changes: 5 additions & 0 deletions neuroml/test/test_cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,11 @@ def test_get_morphology_root(self):
root_seg = acell.get_segment(0)
new_id = 99999
root_seg.id = new_id

# clear the cache, otherwise it'll return the old value again in
# other function calls also
acell.get_segment.cache_clear()

# also update all descendents to ensure cell remains valid
for seg in acell.morphology.segments:
par = seg.parent
Expand Down

0 comments on commit 4baafdd

Please sign in to comment.