Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Small changes to DimensionSet #32

Merged
merged 5 commits into from
Sep 24, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sodym/data_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class DataReader(ABC):
"""
def read_dimensions(self, dimension_definitions: List[DimensionDefinition]) -> DimensionSet:
dimensions = [self.read_dimension(definition) for definition in dimension_definitions]
return DimensionSet(dimensions=dimensions)
return DimensionSet(dim_list=dimensions)

@abstractmethod
def read_dimension(self, dimension_definition: DimensionDefinition) -> Dimension:
Expand Down
39 changes: 27 additions & 12 deletions sodym/dimensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ class DimensionSet(PydanticBaseModel):
"""

dimensions: list[Dimension]
dim_list: list[Dimension]

@model_validator(mode='after')
def no_repeated_dimensions(self):
Expand All @@ -69,29 +69,29 @@ def no_repeated_dimensions(self):
def drop(self, key: str, inplace: bool=False):
dim_to_drop = self._dict[key]
if not inplace:
dimensions = copy(self.dimensions)
dimensions = copy(self.dim_list)
dimensions.remove(dim_to_drop)
return DimensionSet(dimensions=dimensions)
self.dimensions.remove(dim_to_drop)
return DimensionSet(dim_list=dimensions)
self.dim_list.remove(dim_to_drop)

@property
def _dict(self) -> Dict[str, Dimension]:
"""Contains mappings.
letter --> dim object and name --> dim object
"""
return {dim.name: dim for dim in self.dimensions} | {dim.letter: dim for dim in self.dimensions}
return {dim.name: dim for dim in self.dim_list} | {dim.letter: dim for dim in self.dim_list}

def __getitem__(self, key) -> Dimension:
if isinstance(key, str):
return self._dict[key]
elif isinstance(key, int):
return self.dimensions[key]
return self.dim_list[key]
else:
raise TypeError("Key must be string or int")

def __iter__(self):
return iter(self.dimensions)
return iter(self.dim_list)

def size(self, key: str):
return self._dict[key].len
Expand All @@ -101,26 +101,41 @@ def shape(self, keys: tuple = None):
return tuple(self.size(key) for key in keys)

def get_subset(self, dims: tuple = None) -> 'DimensionSet':
"""Selects :py:class:`Dimension` objects from the object attribute dimensions,
"""Selects :py:class:`Dimension` objects from the object attribute dim_list,
according to the dims passed, which can be either letters or names.
Returns a copy if dims are not given.
"""
subset = copy(self)
if dims is not None:
subset.dimensions = [self._dict[dim_key] for dim_key in dims]
subset.dim_list = [self._dict[dim_key] for dim_key in dims]
return subset

def expand_by(self, added_dims: list[Dimension]) -> 'DimensionSet':
    """Return a new DimensionSet with ``added_dims`` appended to the current dimensions.

    The current object is not modified; a fresh DimensionSet is constructed
    from this set's ``dim_list`` followed by the added dimensions.

    Args:
        added_dims: Dimension objects to append. None of their letters may
            already be present in this set.

    Returns:
        A newly constructed DimensionSet.

    Raises:
        ValueError: If any dimension in ``added_dims`` has a letter that this
            set already contains.
    """
    # Materialise once so tuples and one-shot iterables survive both the
    # membership check and the concatenation below.
    added_dims = list(added_dims)
    # ``letters`` is a property that builds a new tuple on each access; read it once.
    existing_letters = self.letters
    if any(dim.letter in existing_letters for dim in added_dims):
        raise ValueError('DimensionSet already contains one or more of the dimensions to be added.')
    return DimensionSet(dim_list=[*self.dim_list, *added_dims])

def intersect_with(self, other: 'DimensionSet') -> 'DimensionSet':
    """Return the subset of this set whose dimension letters also occur in ``other``.

    The letters are collected in the order of this set's ``dim_list`` and the
    actual selection is delegated to ``get_subset``.

    Args:
        other: The DimensionSet to intersect with; only its ``letters`` are read.

    Returns:
        Whatever ``get_subset`` returns for the shared letters.
    """
    # ``letters`` is a property that builds a new tuple on each access; read it once.
    other_letters = other.letters
    shared_letters = [dim.letter for dim in self.dim_list if dim.letter in other_letters]
    return self.get_subset(shared_letters)

def union_with(self, other: 'DimensionSet') -> 'DimensionSet':
    """Return this set extended by the dimensions of ``other`` that it lacks.

    Dimensions of ``other`` whose letter is already present here are left out;
    the remaining ones are handed to ``expand_by`` in the order of
    ``other.dim_list``.

    Args:
        other: The DimensionSet whose missing dimensions are to be added.

    Returns:
        Whatever ``expand_by`` returns for the missing dimensions.
    """
    # ``letters`` is a property that builds a new tuple on each access; read it once.
    own_letters = self.letters
    missing_dims = [dim for dim in other.dim_list if dim.letter not in own_letters]
    return self.expand_by(missing_dims)

@property
def names(self):
return tuple([dim.name for dim in self.dimensions])
return tuple([dim.name for dim in self.dim_list])

@property
def letters(self):
return tuple([dim.letter for dim in self.dimensions])
return tuple([dim.letter for dim in self.dim_list])

@property
def string(self):
return "".join(self.letters)

def index(self, key):
return [d.letter for d in self.dimensions].index(key)
return [d.letter for d in self.dim_list].index(key)
2 changes: 1 addition & 1 deletion sodym/export/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def fill_fig_ax(self):
def get_x_array_like_value_array(self):
if self.x_array is None:
x_dim_obj = self.array.dims[self.intra_line_dim]
x_dimset = DimensionSet(dimensions=[x_dim_obj])
x_dimset = DimensionSet(dim_list=[x_dim_obj])
self.x_array = NamedDimArray(dims=x_dimset, values=np.array(x_dim_obj.items), name=self.intra_line_dim)
self.x_array = self.x_array.cast_to(self.array.dims)

Expand Down
27 changes: 6 additions & 21 deletions sodym/named_dim_arrays.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,58 +130,43 @@ def _prepare_other(self, other):
other = NamedDimArray(dims=self.dims, values=other * np.ones(self.shape))
return other

def intersect_dims_with(self, other):
matching_dims = []
for dim in self.dims.dimensions:
if dim.letter in other.dims.letters:
matching_dims.append(dim)
return DimensionSet(dimensions=matching_dims)

def union_dims_with(self, other):
all_dims = copy(self.dims.dimensions)
letters_self = self.dims.letters
for dim in other.dims.dimensions:
if dim.letter not in letters_self:
all_dims.append(dim)
return DimensionSet(dimensions=all_dims)

def __add__(self, other):
other = self._prepare_other(other)
dims_out = self.intersect_dims_with(other)
dims_out = self.dims.intersect_with(other.dims)
return NamedDimArray(
dims=dims_out, values=self.sum_values_to(dims_out.letters) + other.sum_values_to(dims_out.letters)
)

def __sub__(self, other):
other = self._prepare_other(other)
dims_out = self.intersect_dims_with(other)
dims_out = self.dims.intersect_with(other.dims)
return NamedDimArray(
dims=dims_out, values=self.sum_values_to(dims_out.letters) - other.sum_values_to(dims_out.letters)
)

def __mul__(self, other):
other = self._prepare_other(other)
dims_out = self.union_dims_with(other)
dims_out = self.dims.union_with(other.dims)
values_out = np.einsum(f"{self.dims.string},{other.dims.string}->{dims_out.string}", self.values, other.values)
return NamedDimArray(dims=dims_out, values=values_out)

def __truediv__(self, other):
other = self._prepare_other(other)
dims_out = self.union_dims_with(other)
dims_out = self.dims.union_with(other.dims)
values_out = np.einsum(
f"{self.dims.string},{other.dims.string}->{dims_out.string}", self.values, 1.0 / other.values
)
return NamedDimArray(dims=dims_out, values=values_out)

def minimum(self, other):
other = self._prepare_other(other)
dims_out = self.intersect_dims_with(other)
dims_out = self.dims.intersect_with(other.dims)
values_out = np.minimum(self.sum_values_to(dims_out.letters), other.sum_values_to(dims_out.letters))
return NamedDimArray(dims=dims_out, values=values_out)

def maximum(self, other):
other = self._prepare_other(other)
dims_out = self.intersect_dims_with(other)
dims_out = self.dims.intersect_with(other.dims)
values_out = np.maximum(self.sum_values_to(dims_out.letters), other.sum_values_to(dims_out.letters))
return NamedDimArray(dims=dims_out, values=values_out)

Expand Down
8 changes: 4 additions & 4 deletions tests/test_dimensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@ def test_validate_dimension_set():
{'name': 'time', 'letter': 't', 'items': [1990, 2000, 2010]},
{'name': 'place', 'letter': 'p', 'items': ['World', ]}
]
DimensionSet(dimensions=dimensions)
DimensionSet(dim_list=dimensions)

# example with repeated dimension letters in DimensionSet
dimensions.append(
{'name': 'another_time', 'letter': 't', 'items': [2020, 2030]}
)
with pytest.raises(ValidationError) as error_msg:
DimensionSet(dimensions=dimensions)
DimensionSet(dim_list=dimensions)
assert 'letter' in str(error_msg.value)


Expand All @@ -29,11 +29,11 @@ def test_get_subset():
material_dimension = {'name': 'material', 'letter': 'm', 'items': ['material_0', 'material_1']}

parent_dimensions = subset_dimensions + [material_dimension]
dimension_set = DimensionSet(dimensions=parent_dimensions)
dimension_set = DimensionSet(dim_list=parent_dimensions)

# example of subsetting the dimension set using dimension letters
subset_from_letters = dimension_set.get_subset(dims=('t', 'p'))
assert subset_from_letters == DimensionSet(dimensions=subset_dimensions)
assert subset_from_letters == DimensionSet(dim_list=subset_dimensions)

# example of subsetting the dimension set using dimension names
subset_from_names = dimension_set.get_subset(dims=('time', 'place'))
Expand Down
10 changes: 5 additions & 5 deletions tests/test_named_dim_arrays.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@
{'name': 'place', 'letter': 'p', 'items': ['Earth', 'Sun', 'Moon', 'Venus']},
{'name': 'time', 'letter': 't', 'items': [1990, 2000, 2010]},
]
dims = DimensionSet(dimensions=dimensions)
dims = DimensionSet(dim_list=dimensions)
values = np.random.rand(4, 3)
numbers = NamedDimArray(name='two', dims=dims, values=values)

animals = {'name': 'animal', 'letter': 'a', 'items': ['cat', 'mouse']}
dims_incl_animals = DimensionSet(dimensions=dimensions+[animals])
dims_incl_animals = DimensionSet(dim_list=dimensions+[animals])
animal_values = np.random.rand(4, 3, 2)
space_animals = NamedDimArray(name='space_animals', dims=dims_incl_animals, values=animal_values)

Expand All @@ -25,7 +25,7 @@ def test_named_dim_array_validations():
{'name': 'place', 'letter': 'p', 'items': ['World', ]},
{'name': 'time', 'letter': 't', 'items': [1990, 2000, 2010]},
]
dims = DimensionSet(dimensions=dimensions)
dims = DimensionSet(dim_list=dimensions)

# example with values with the correct shape
NamedDimArray(name='numbers', dims=dims, values=np.array([[1, 2, 3], ]))
Expand All @@ -52,15 +52,15 @@ def test_cast_to():
assert_almost_equal(np.sum(casted_named_dim_array.values), 2 * np.sum(values))

# example with differently ordered dimensions
target_dims = DimensionSet(dimensions=[animals]+dimensions[::-1])
target_dims = DimensionSet(dim_list=[animals]+dimensions[::-1])
casted_named_dim_array = numbers.cast_to(target_dims=target_dims)
assert casted_named_dim_array.values.shape == (2, 3, 4)


def test_sum_nda_to():
# sum over one dimension
summed_named_dim_array = space_animals.sum_nda_to(result_dims=('p', 't'))
assert summed_named_dim_array.dims == DimensionSet(dimensions=dimensions)
assert summed_named_dim_array.dims == DimensionSet(dim_list=dimensions)
assert_array_almost_equal(summed_named_dim_array.values, np.sum(animal_values, axis=2))

# sum over two dimensions
Expand Down