Skip to content

Commit

Permalink
UP my solution
Browse files (browse the repository at this point in the history)
  • Loading branch information
MattVerlynde committed Dec 22, 2023
1 parent 6ccb1be commit 1bce1ef
Showing 1 changed file with 42 additions and 8 deletions.
50 changes: 42 additions & 8 deletions sklearn_questions.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -48,7 +48,8 @@
to compute distances between 2 sets of samples.
"""
import numpy as np
import pandas as pd
# import pandas as pd
from pandas.api.types import is_datetime64_dtype as is_datetime64

from sklearn.base import BaseEstimator
from sklearn.base import ClassifierMixin
Expand Down Expand Up @@ -82,6 +83,13 @@ def fit(self, X, y):
self : instance of KNearestNeighbors
The current instance of the classifier
"""

X, y = check_X_y(X, y)
check_classification_targets(y)
self.X_, self.y_ = X, y
self.classes_ = np.unique(self.y_)
self.n_features_in_ = X.shape[1]

return self

def predict(self, X):
Expand All @@ -97,7 +105,17 @@ def predict(self, X):
y : ndarray, shape (n_test_samples,)
Predicted class labels for each test data sample.
"""
y_pred = np.zeros(X.shape[0])
check_is_fitted(self)
X = check_array(X)

y_pred = np.zeros(X.shape[0]).astype(self.y_.dtype)
dist_matrix = pairwise_distances(X, self.X_, metric='euclidean')

for i in range(dist_matrix.shape[0]):
yi_near = (self.y_[np.argsort(dist_matrix[i])[:self.n_neighbors]])
near_unique, counts_unique = np.unique(yi_near, return_counts=True)
y_pred[i] = near_unique[np.argmax(counts_unique)]

return y_pred

def score(self, X, y):
Expand All @@ -115,7 +133,8 @@ def score(self, X, y):
score : float
Accuracy of the model computed for the (X, y) pairs.
"""
return 0.
s = np.where(self.predict(X) == y, 1, 0).mean()
return s


class MonthlySplit(BaseCrossValidator):
Expand Down Expand Up @@ -155,7 +174,16 @@ def get_n_splits(self, X, y=None, groups=None):
n_splits : int
The number of splits.
"""
return 0
if self.time_col == 'index':
X = X.reset_index()
self.X_time = X[self.time_col]

if not is_datetime64(self.X_time):
raise ValueError(
"The column {} is not a datetime.".format(self.time_col)
)

return len(self.X_time.dt.to_period('M').unique())-1

def split(self, X, y, groups=None):
"""Generate indices to split data into training and test set.
Expand All @@ -177,12 +205,18 @@ def split(self, X, y, groups=None):
idx_test : ndarray
The testing set indices for that split.
"""

n_samples = X.shape[0]
n_splits = self.get_n_splits(X, y, groups)

X = X.reset_index()

index_sorted = X.resample('M',
on=self.time_col
).apply(lambda x: x.index)
index_sorted.index = index_sorted.index.to_period('M')

for i in range(n_splits):
idx_train = range(n_samples)
idx_test = range(n_samples)
idx_train = np.array(index_sorted[index_sorted.index[i]])
idx_test = np.array(index_sorted[index_sorted.index[i+1]])
yield (
idx_train, idx_test
)

0 comments on commit 1bce1ef

Please sign in to comment.