Skip to content

Commit

Permalink
Final test
Browse files Browse the repository at this point in the history
  • Loading branch information
Michel-debug committed Dec 21, 2023
1 parent 0505309 commit ad7cb62
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 33 deletions.
7 changes: 7 additions & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
{
"python.testing.pytestArgs": [
"."
],
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true
}
54 changes: 21 additions & 33 deletions sklearn_questions.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def __init__(self, n_neighbors=1): # noqa: D107
def fit(self, X, y):
"""
Fitting function.
Parameters
----------
X : ndarray, shape (n_samples, n_features)
Expand Down Expand Up @@ -102,7 +102,7 @@ def fit(self, X, y):
def predict(self, X):
"""
Predict function.
Parameters
----------
X : ndarray, shape (n_test_samples, n_features)
Expand Down Expand Up @@ -133,7 +133,7 @@ def predict(self, X):
def score(self, X, y):
"""
Calculate the score of the prediction.
Parameters
----------
X : ndarray, shape (n_samples, n_features)
Expand All @@ -160,7 +160,7 @@ def score(self, X, y):
class MonthlySplit(BaseCrossValidator):
"""
CrossValidator based on monthly split.
Split data based on the given `time_col` (or default to index). Each split
corresponds to one month of data for the training and the next month of
data for the test.
Expand All @@ -180,7 +180,7 @@ def __init__(self, time_col='index'): # noqa: D107
def get_n_splits(self, X, y=None, groups=None):
"""
Return the number of splitting iterations in the cross-validator.
Parameters
----------
X : array-like of shape (n_samples, n_features)
Expand All @@ -197,23 +197,19 @@ def get_n_splits(self, X, y=None, groups=None):
The number of splits.
"""
# Check whether time_col points to the index of the DataFrame
if self.time_col == 'index':
dates = X.index
else:
# Check that time_col exists among the columns of X
if self.time_col not in X.columns:
raise KeyError(f"{self.time_col} is not in the columns of X")
dates = X[self.time_col]
if not pd.api.types.is_datetime64_any_dtype(dates):
raise ValueError("The time column should be datetime type")

months = np.unique(dates.to_period('M'))
return max(0, len(months) - 1)
X = X.reset_index()
time_col = X[self.time_col]
if not np.issubdtype(time_col.dtype, np.datetime64):
raise ValueError(
'DataFrame should have at least one datetime column'
)
n_splits = len(pd.to_datetime(time_col).dt.to_period('M').unique()) - 1
return n_splits

def split(self, X, y, groups=None):
"""
Generate indices to split data into training and test set.
Parameters
----------
X : array-like of shape (n_samples, n_features)
Expand All @@ -233,25 +229,17 @@ def split(self, X, y, groups=None):
idx_test : ndarray
The testing set indices for that split.
"""
# obtain time data
if self.time_col == 'index':
dates = X.index
else:
dates = X[self.time_col]

# Convert to month
months = dates.to_period('M')
unique_months = np.unique(months)
n_splits = self.get_n_splits(X, y, groups)
X = X.reset_index().resample(
'M', on=self.time_col
).apply(
lambda x: x.index
)

for i in range(n_splits):
# Current month as the train set
train_mask = months == unique_months[i]
idx_train = np.where(train_mask)[0]

# Next month as a test set
test_mask = months == unique_months[i + 1]
idx_test = np.where(test_mask)[0]
idx_train = X.iloc[i].values
idx_test = X.iloc[i + 1].values
yield (
idx_train, idx_test
)

0 comments on commit ad7cb62

Please sign in to comment.