Skip to content

Commit

Permalink
UP my solution
Browse files Browse the repository at this point in the history
  • Loading branch information
RomainPoupon committed Dec 22, 2023
1 parent 6ccb1be commit e1d3eeb
Showing 1 changed file with 30 additions and 10 deletions.
40 changes: 30 additions & 10 deletions sklearn_questions.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
to compute distances between 2 sets of samples.
"""
import numpy as np
import pandas as pd
# import pandas as pd

from sklearn.base import BaseEstimator
from sklearn.base import ClassifierMixin
Expand All @@ -59,6 +59,7 @@
from sklearn.utils.validation import check_array
from sklearn.utils.multiclass import check_classification_targets
from sklearn.metrics.pairwise import pairwise_distances
from collections import Counter


class KNearestNeighbors(BaseEstimator, ClassifierMixin):
Expand All @@ -82,6 +83,13 @@ def fit(self, X, y):
self : instance of KNearestNeighbors
The current instance of the classifier
"""
X, y = check_X_y(X, y)
check_classification_targets(y)

self.sample_ = X
self.features_ = y
self.classes_ = np.unique(y)
self.n_features_in_ = X.shape[1]
return self

def predict(self, X):
Expand All @@ -97,8 +105,19 @@ def predict(self, X):
y : ndarray, shape (n_test_samples,)
Predicted class labels for each test data sample.
"""
y_pred = np.zeros(X.shape[0])
return y_pred
check_is_fitted(self)
X = check_array(X)
ind = np.argsort(pairwise_distances(X,
self.sample_,
metric='euclidean'), axis=1)

ind = ind[:, 0:self.n_neighbors]

ydat = self.features_[ind]
y_pred = []
for i in ydat:
y_pred.append(Counter(i).most_common(1)[0][0])
return np.array(y_pred)

def score(self, X, y):
"""Calculate the score of the prediction.
Expand All @@ -115,7 +134,7 @@ def score(self, X, y):
score : float
Accuracy of the model computed for the (X, y) pairs.
"""
return 0.
return np.mean(self.predict(X) == y)


class MonthlySplit(BaseCrossValidator):
Expand Down Expand Up @@ -155,11 +174,15 @@ def get_n_splits(self, X, y=None, groups=None):
n_splits : int
The number of splits.
"""
return 0
X = X.reset_index()
if X[self.time_col].dtype != 'datetime64[ns]':
raise ValueError("The column is not a datetime")
X = X.set_index(self.time_col)
n_splits = X[self.time_col].dt.to_period("M").nunique()-1
return n_splits

def split(self, X, y, groups=None):
"""Generate indices to split data into training and test set.
Parameters
----------
X : array-like of shape (n_samples, n_features)
Expand All @@ -177,12 +200,9 @@ def split(self, X, y, groups=None):
idx_test : ndarray
The testing set indices for that split.
"""

n_samples = X.shape[0]
n_splits = self.get_n_splits(X, y, groups)
for i in range(n_splits):
idx_train = range(n_samples)
idx_test = range(n_samples)
yield (
idx_train, idx_test
)
yield (idx_train, idx_test)

0 comments on commit e1d3eeb

Please sign in to comment.