diff --git a/sklearn_questions.py b/sklearn_questions.py index dd03014..696ea2c 100644 --- a/sklearn_questions.py +++ b/sklearn_questions.py @@ -69,8 +69,10 @@ def __init__(self, n_neighbors=1): # noqa: D107 self.n_neighbors = n_neighbors def fit(self, X, y): - """Fitting function. - Parameters + """ + Fitting function. + + Parameters ---------- X : ndarray, shape (n_samples, n_features) Data to train the model. @@ -98,7 +100,9 @@ def fit(self, X, y): return self def predict(self, X): - """Predict function. + """ + Predict function. + Parameters ---------- X : ndarray, shape (n_test_samples, n_features) @@ -127,7 +131,9 @@ def predict(self, X): return y_pred def score(self, X, y): - """Calculate the score of the prediction. + """ + Calculate the score of the prediction. + Parameters ---------- X : ndarray, shape (n_samples, n_features) @@ -152,7 +158,9 @@ def score(self, X, y): class MonthlySplit(BaseCrossValidator): - """CrossValidator based on monthly split. + """ + CrossValidator based on monthly split. + Split data based on the given `time_col` (or default to index). Each split corresponds to one month of data for the training and the next month of data for the test. @@ -170,7 +178,9 @@ def __init__(self, time_col='index'): # noqa: D107 self.time_col = time_col def get_n_splits(self, X, y=None, groups=None): - """Return the number of splitting iterations in the cross-validator. + """ + Return the number of splitting iterations in the cross-validator. + Parameters ---------- X : array-like of shape (n_samples, n_features) @@ -201,7 +211,9 @@ def get_n_splits(self, X, y=None, groups=None): return max(0, len(months) - 1) def split(self, X, y, groups=None): - """Generate indices to split data into training and test set. + """ + Generate indices to split data into training and test set. + Parameters ---------- X : array-like of shape (n_samples, n_features)