#238 multiple datasets #311

Draft
wants to merge 8 commits into base: develop
2 changes: 1 addition & 1 deletion pybop/__init__.py
@@ -77,7 +77,7 @@
#
# Dataset class
#
from ._dataset import Dataset
from ._dataset import Dataset, DatasetCollection

#
# Model classes
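With this export, DatasetCollection becomes part of the public pybop namespace alongside Dataset. A minimal sketch of the resulting import surface (not part of the diff itself):

from pybop import Dataset, DatasetCollection  # both now importable at the top level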
122 changes: 121 additions & 1 deletion pybop/_dataset.py
@@ -22,10 +22,11 @@ def __init__(self, data_dictionary):

if isinstance(data_dictionary, pybamm.solvers.solution.Solution):
data_dictionary = data_dictionary.get_data_dict()
if not isinstance(data_dictionary, dict):
if not isinstance(data_dictionary, (dict, Dataset)):
raise ValueError("The input to pybop.Dataset must be a dictionary.")
self.data = data_dictionary
self.names = self.data.keys()
self.n_datasets = 1

def __repr__(self):
"""
@@ -131,3 +132,122 @@ def check(self, signal=["Voltage [V]"]):
raise ValueError(f"Time data and {s} data must be the same length.")

return True


class DatasetCollection:
"""
Represents a collection of Datasets. Provides a simple way to
handle multiple sets of experimental data.

Parameters
----------
datasets: list
The datasets to store within the dataset collection.
Individual datasets can be either a dict or a Dataset
instance.
"""

def __init__(self, datasets):
"""
Initialize a DatasetCollection instance with either a list of
Dataset objects, or a list of dictionaries to be turned into
Datasets.
"""
self.n_datasets = len(datasets)
try:
self.datasets = [Dataset(data) for data in datasets]
except AttributeError:
self.datasets = datasets
self.names = set().union(*(data.names for data in self.datasets))
self.data = {k: [data[k] for data in datasets] for k in self.names}

def check(self, signal=["Voltage [V]"]):
"""
Check the consistency of each PyBOP Dataset against the
expected format.

Returns
-------
bool
If True, every dataset in the collection has the expected attributes.

Raises
------
ValueError
If the time series and the data series are not consistent.
"""
for dataset in self.datasets:
dataset.check(signal)
return True

def __len__(self):
return self.n_datasets

def __repr__(self):
"""
Return a string representation of the DatasetCollection instance.

Returns
-------
str
A string that includes the type and contents of the collection.
"""
return f"DatasetCollection: {type(self.data)} \n Contains: {self.names}"

def __setitem__(self, key, value):
"""
Set the data corresponding to a particular key.

Parameters
----------
key : str
The name of the key to be set.

value : list or np.ndarray
The data series to be stored in the dataset.
"""
self.data[key] = value

def __getitem__(self, key):
"""
Return the data corresponding to a particular key.

Parameters
----------
key : str
The name of a data series within the dataset.

Returns
-------
list or np.ndarray
The data series corresponding to the key.

Raises
------
ValueError
The key must exist in the dataset.
"""
if key not in self.data.keys():
raise ValueError(f"The key {key} does not exist in this dataset.")

return self.data[key]

def __iter__(self):
"""
DatasetCollection can be iterated over to get each individual
dataset comprising the collection.
"""
self.__iter_index = 0
return self

def __next__(self):
"""
Return the next dataset in the collection.
"""
if self.__iter_index < len(self.datasets):
value = self.datasets[self.__iter_index]
self.__iter_index += 1
return value
else:
raise StopIteration
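A rough usage sketch for the new class, based on the methods above; the key names and numeric values below are invented for illustration and are not taken from the PR:

import numpy as np
import pybop

# Two illustrative datasets sharing the same keys but with different lengths
t1 = np.linspace(0, 10, 11)
t2 = np.linspace(0, 20, 21)
data_1 = {
    "Time [s]": t1,
    "Current function [A]": np.ones_like(t1),
    "Voltage [V]": 4.0 - 0.01 * t1,
}
data_2 = {
    "Time [s]": t2,
    "Current function [A]": 0.5 * np.ones_like(t2),
    "Voltage [V]": 4.0 - 0.005 * t2,
}

# A list of dicts (or of pybop.Dataset objects) is accepted
collection = pybop.DatasetCollection([data_1, data_2])
collection.check()                  # runs Dataset.check on each member
print(len(collection))              # 2
print(collection["Voltage [V]"])    # list holding one voltage series per dataset
for dataset in collection:          # iteration yields the individual Dataset objects
    print(dataset.names)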
22 changes: 15 additions & 7 deletions pybop/_optimisation.py
@@ -22,7 +22,7 @@ class Optimisation:
If True, the optimization progress is printed (default: False).
physical_viability : bool, optional
If True, the feasibility of the optimised parameters is checked (default: True).
allow_infeasible_solutions : bool, optional
allow_infeasible_solutions : bool or list[bool], optional
If True, infeasible parameter values will be allowed in the optimisation (default: True).

Attributes
@@ -65,9 +65,14 @@ def __init__(

# Set whether to allow infeasible locations
if self.cost.problem is not None and hasattr(self.cost.problem, "_model"):
self.cost.problem._model.allow_infeasible_solutions = (
self.allow_infeasible_solutions
)
try:
for model, allow_infeasible in zip(
self.cost.problem._model, self.allow_infeasible_solutions
):
model.allow_infeasible_solutions = allow_infeasible
except TypeError:
for model in self.cost.problem._model:
model.allow_infeasible_solutions = self.allow_infeasible_solutions
else:
# Turn off this feature as there is no model
self.physical_viability = False
@@ -517,13 +522,16 @@ def check_optimal_parameters(self, x):
Check if the optimised parameters are physically viable.
"""

if self.cost.problem._model.check_params(
inputs=x, allow_infeasible_solutions=False
if all(
[
model.check_params(inputs=x, allow_infeasible_solutions=False)
for model in self.cost.problem._model
]
):
return
else:
warnings.warn(
"Optimised parameters are not physically viable! \nConsider retrying the optimisation"
"One or more optimised parameters are not physically viable! \nConsider retrying the optimisation"
+ " with a non-gradient-based optimiser and the option allow_infeasible_solutions=False",
UserWarning,
stacklevel=2,
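The flag handling above relies on zip() raising TypeError when allow_infeasible_solutions is a single bool rather than a sequence of per-model flags. A standalone sketch of that dispatch, using placeholder names rather than the pybop objects:

def set_allow_infeasible(models, allow):
    """Apply one flag per model, or broadcast a single bool to every model."""
    try:
        # allow is a sequence: pair each model with its own flag
        for model, flag in zip(models, allow):
            model.allow_infeasible_solutions = flag
    except TypeError:
        # zip() raised because a bare bool is not iterable,
        # so apply the same flag to every model
        for model in models:
            model.allow_infeasible_solutions = allow


class _Model:  # stand-in for a pybop model
    pass


models = [_Model(), _Model()]
set_allow_infeasible(models, True)           # broadcast a single bool
set_allow_infeasible(models, [True, False])  # one flag per model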
6 changes: 5 additions & 1 deletion pybop/costs/base_cost.py
@@ -17,6 +17,9 @@ class BaseCost:
problem : object
A problem instance containing the data and functions necessary for
evaluating the cost function.
weights : np.ndarray, optional
Weights used to combine the cost evaluated on each dataset into a
single value: combined_cost = np.inner(weights, individual_costs)
_target : array-like
An array containing the target data to fit.
x0 : array-like
@@ -33,8 +36,9 @@ class BaseCost:
The number of outputs in the model.
"""

def __init__(self, problem=None, sigma=None):
def __init__(self, problem=None, sigma=None, weights=None):
self.problem = problem
self.weights = weights
self.x0 = None
self.bounds = None
self.sigma0 = sigma
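The new weights attribute is documented through combined_cost = np.inner(weights, individual_costs); a small numerical illustration of that combination (values invented for the example):

import numpy as np

individual_costs = np.array([0.12, 0.40, 0.08])  # illustrative per-dataset costs
weights = np.array([1.0, 0.5, 2.0])              # illustrative relative weights
combined_cost = np.inner(weights, individual_costs)
print(combined_cost)  # 1.0*0.12 + 0.5*0.40 + 2.0*0.08 = 0.48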