Skip to content

Commit

Permalink
UP MY SOLUTION
Browse files Browse the repository at this point in the history
  • Loading branch information
Lyan168 committed Dec 21, 2023
1 parent 6ccb1be commit dd92da9
Showing 1 changed file with 43 additions and 12 deletions.
55 changes: 43 additions & 12 deletions sklearn_questions.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,14 @@

from sklearn.base import BaseEstimator
from sklearn.base import ClassifierMixin
from collections import Counter

from sklearn.model_selection import BaseCrossValidator

from sklearn.utils.validation import check_X_y, check_is_fitted
from sklearn.utils.validation import check_array
from sklearn.utils.multiclass import check_classification_targets
from sklearn.utils.multiclass import (
check_classification_targets, unique_labels)
from sklearn.metrics.pairwise import pairwise_distances


Expand All @@ -82,6 +84,13 @@ def fit(self, X, y):
self : instance of KNearestNeighbors
The current instance of the classifier
"""

X, y = check_X_y(X, y)
check_classification_targets(y)
self.X_ = X
self.y_ = y
self.classes_ = unique_labels(y)
self.n_features_in_ = X.shape[1]
return self

def predict(self, X):
Expand All @@ -97,8 +106,21 @@ def predict(self, X):
y : ndarray, shape (n_test_samples,)
Predicted class labels for each test data sample.
"""
y_pred = np.zeros(X.shape[0])
return y_pred
check_is_fitted(self)
X = check_array(X)
y_pred = []
for i, x in enumerate(X):
distances = pairwise_distances(self.X_, x[np.newaxis, :]).ravel()
# indices of the k nearest training samples (argpartition leaves them unordered)
args_k_nearest = np.argpartition(distances, self.n_neighbors)[
: self.n_neighbors
]
# majority vote: predict the most common label among the selected neighbors
y_pred.append(
Counter(self.y_[args_k_nearest]).most_common(1)[0][0]
)

return np.array(y_pred)

def score(self, X, y):
"""Calculate the score of the prediction.
Expand All @@ -115,7 +137,10 @@ def score(self, X, y):
score : float
Accuracy of the model computed for the (X, y) pairs.
"""
return 0.
X, y = check_X_y(X, y)
check_classification_targets(y)
y_pred = self.predict(X)
return (y_pred == y).mean()


class MonthlySplit(BaseCrossValidator):
Expand Down Expand Up @@ -155,7 +180,14 @@ def get_n_splits(self, X, y=None, groups=None):
n_splits : int
The number of splits.
"""
return 0
if not isinstance(X.index, pd.RangeIndex):
X = X.reset_index()

# the time column must have dtype datetime64[ns]
if X[self.time_col].dtype != "datetime64[ns]":
raise ValueError("datetime")

return len(X.resample("M", on=self.time_col)) - 1

def split(self, X, y, groups=None):
"""Generate indices to split data into training and test set.
Expand All @@ -177,12 +209,11 @@ def split(self, X, y, groups=None):
idx_test : ndarray
The testing set indices for that split.
"""

n_samples = X.shape[0]
if isinstance(X, pd.Series):
X = X.to_frame()
X = X.reset_index()
n_splits = self.get_n_splits(X, y, groups)
ids_by_mth = [groups.index.values
for _, groups in X.resample("M", on=self.time_col)]
for i in range(n_splits):
idx_train = range(n_samples)
idx_test = range(n_samples)
yield (
idx_train, idx_test
)
yield ids_by_mth[i], ids_by_mth[i + 1]

0 comments on commit dd92da9

Please sign in to comment.