From 1bce1ef2e75f60860f0421b7a64ac524b805819d Mon Sep 17 00:00:00 2001 From: MattVerlynde Date: Fri, 22 Dec 2023 10:04:17 +0100 Subject: [PATCH] UP my solution --- sklearn_questions.py | 50 +++++++++++++++++++++++++++++++++++++------- 1 file changed, 42 insertions(+), 8 deletions(-) diff --git a/sklearn_questions.py b/sklearn_questions.py index fa02e0d..2aab4d2 100644 --- a/sklearn_questions.py +++ b/sklearn_questions.py @@ -48,7 +48,8 @@ to compute distances between 2 sets of samples. """ import numpy as np -import pandas as pd +# import pandas as pd +from pandas.api.types import is_datetime64_dtype as is_datetime64 from sklearn.base import BaseEstimator from sklearn.base import ClassifierMixin @@ -82,6 +83,13 @@ def fit(self, X, y): self : instance of KNearestNeighbors The current instance of the classifier """ + + X, y = check_X_y(X, y) + check_classification_targets(y) + self.X_, self.y_ = X, y + self.classes_ = np.unique(self.y_) + self.n_features_in_ = X.shape[1] + return self def predict(self, X): @@ -97,7 +105,17 @@ def predict(self, X): y : ndarray, shape (n_test_samples,) Predicted class labels for each test data sample. """ - y_pred = np.zeros(X.shape[0]) + check_is_fitted(self) + X = check_array(X) + + y_pred = np.zeros(X.shape[0]).astype(self.y_.dtype) + dist_matrix = pairwise_distances(X, self.X_, metric='euclidean') + + for i in range(dist_matrix.shape[0]): + yi_near = (self.y_[np.argsort(dist_matrix[i])[:self.n_neighbors]]) + near_unique, counts_unique = np.unique(yi_near, return_counts=True) + y_pred[i] = near_unique[np.argmax(counts_unique)] + return y_pred def score(self, X, y): @@ -115,7 +133,8 @@ def score(self, X, y): score : float Accuracy of the model computed for the (X, y) pairs. """ - return 0. + s = np.where(self.predict(X) == y, 1, 0).mean() + return s class MonthlySplit(BaseCrossValidator): @@ -155,7 +174,16 @@ def get_n_splits(self, X, y=None, groups=None): n_splits : int The number of splits. """ - return 0 + if self.time_col == 'index': + X = X.reset_index() + self.X_time = X[self.time_col] + + if not is_datetime64(self.X_time): + raise ValueError( + "The column {} is not a datetime.".format(self.time_col) + ) + + return len(self.X_time.dt.to_period('M').unique())-1 def split(self, X, y, groups=None): """Generate indices to split data into training and test set. @@ -177,12 +205,18 @@ def split(self, X, y, groups=None): idx_test : ndarray The testing set indices for that split. """ - - n_samples = X.shape[0] n_splits = self.get_n_splits(X, y, groups) + + X = X.reset_index() + + index_sorted = X.resample('M', + on=self.time_col + ).apply(lambda x: x.index) + index_sorted.index = index_sorted.index.to_period('M') + for i in range(n_splits): - idx_train = range(n_samples) - idx_test = range(n_samples) + idx_train = np.array(index_sorted[index_sorted.index[i]]) + idx_test = np.array(index_sorted[index_sorted.index[i+1]]) yield ( idx_train, idx_test )