Skip to content

Commit

Permalink
UP my solution
Browse files Browse the repository at this point in the history
  • Loading branch information
Aymane-Rahmoune committed Dec 20, 2023
1 parent 6ccb1be commit ac16a64
Showing 1 changed file with 63 additions and 9 deletions.
72 changes: 63 additions & 9 deletions sklearn_questions.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,11 @@
"""
import numpy as np
import pandas as pd
# import datetime as date

from sklearn.base import BaseEstimator
from sklearn.base import ClassifierMixin
from sklearn.metrics import accuracy_score

from sklearn.model_selection import BaseCrossValidator

Expand Down Expand Up @@ -82,6 +84,13 @@ def fit(self, X, y):
self : instance of KNearestNeighbors
The current instance of the classifier
"""
X, y = check_X_y(X, y)
check_classification_targets(y)
self.classes_ = np.unique(y)
self.X_ = X
self.y_ = y
self.n_features_in_ = X.shape[1]

return self

def predict(self, X):
Expand All @@ -97,7 +106,25 @@ def predict(self, X):
y : ndarray, shape (n_test_samples,)
Predicted class labels for each test data sample.
"""
y_pred = np.zeros(X.shape[0])
check_is_fitted(self)
X = check_array(X)

# calcul distances
distances = pairwise_distances(X, self.X_)

# select k points les plus proches
closest = np.argsort(distances, axis=1)[:, : self.n_neighbors]
closest = self.y_[closest]

# choix du y le plus frequent
y_pred = np.apply_along_axis(
lambda x: np.unique(x, return_counts=True)[0][
np.argmax(np.unique(x, return_counts=True)[1])
],
axis=1,
arr=closest,
)

return y_pred

def score(self, X, y):
Expand All @@ -115,7 +142,9 @@ def score(self, X, y):
score : float
Accuracy of the model computed for the (X, y) pairs.
"""
return 0.
predictions = self.predict(X)

return accuracy_score(y, predictions)


class MonthlySplit(BaseCrossValidator):
Expand Down Expand Up @@ -155,7 +184,15 @@ def get_n_splits(self, X, y=None, groups=None):
n_splits : int
The number of splits.
"""
return 0
X = X.reset_index()

if X[self.time_col].dtype != "datetime64[ns]":
raise ValueError("datetime")
column_date = X[self.time_col]
max = column_date.max()
min = column_date.min()

return (max.year - min.year) * 12 + max.month - min.month

def split(self, X, y, groups=None):
"""Generate indices to split data into training and test set.
Expand All @@ -178,11 +215,28 @@ def split(self, X, y, groups=None):
The testing set indices for that split.
"""

n_samples = X.shape[0]
n_splits = self.get_n_splits(X, y, groups)
for i in range(n_splits):
idx_train = range(n_samples)
idx_test = range(n_samples)
yield (
idx_train, idx_test

X = X.reset_index()
X.index.names = ["Index_nb"]
X = X.reset_index()
X["Month"] = pd.DatetimeIndex(X[self.time_col]).month
X["Year"] = pd.DatetimeIndex(X[self.time_col]).year
X_ = X.copy()
X_ = X_[["Month", "Year"]].drop_duplicates().sort_values(
["Year", "Month"]
)
for i in range(n_splits):
idx_train = X.merge(
X_.iloc[[i]],
how="inner",
left_on=["Month", "Year"],
right_on=["Month", "Year"],
)["Index_nb"].to_numpy()
idx_test = X.merge(
X_.iloc[[i + 1]],
how="inner",
left_on=["Month", "Year"],
right_on=["Month", "Year"],
)["Index_nb"].to_numpy()
yield (idx_train, idx_test)

0 comments on commit ac16a64

Please sign in to comment.