Skip to content

Commit

Permalink
final solution
Browse files Browse the repository at this point in the history
  • Loading branch information
Michel-debug committed Dec 21, 2023
1 parent 6ccb1be commit cd81a6a
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 9 deletions.
7 changes: 7 additions & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
{
"python.testing.pytestArgs": [
"."
],
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true
}
81 changes: 72 additions & 9 deletions sklearn_questions.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
from sklearn.utils.validation import check_array
from sklearn.utils.multiclass import check_classification_targets
from sklearn.metrics.pairwise import pairwise_distances
from sklearn.preprocessing import LabelEncoder


class KNearestNeighbors(BaseEstimator, ClassifierMixin):
Expand All @@ -76,7 +77,20 @@ def fit(self, X, y):
Data to train the model.
y : ndarray, shape (n_samples,)
Labels associated with the training data.
"""
# Validate the input
X, y = check_X_y(X, y)
# Check if the target is suitable for classification
check_classification_targets(y)
# init LabelEncoder
self.label_encoder_ = LabelEncoder()
self.X_train_ = X
# Convert y using LabelEncoder for different types of labels
self.y_train_ = self.label_encoder_.fit_transform(y)
# Set the classes_ attribute from the fitted encoder
self.classes_ = self.label_encoder_.classes_
self.n_features_in_ = X.shape[1]
"""
Returns
----------
self : instance of KNearestNeighbors
Expand All @@ -91,13 +105,27 @@ def predict(self, X):
----------
X : ndarray, shape (n_test_samples, n_features)
Data to predict on.
"""
# Check that fit has been called
check_is_fitted(self, ['X_train_', 'y_train_'])
# Input validation
X = check_array(X)
# Compute distances from each point in X to each point in self.X_train_
distances = pairwise_distances(X, self.X_train_, metric='euclidean')
y_train_int = self.y_train_.astype(int)
# Find the indices of the k closest training samples
nearest_indices = np.argsort(distances, axis=1)[:, :self.n_neighbors]
# Predict the most common label of the k nearest training samples
y_pred_indices = np.array([np.argmax(np.bincount(y_train_int[indices]))
for indices in nearest_indices])
# Convert the predicted integer labels back to the original labels
y_pred = self.label_encoder_.inverse_transform(y_pred_indices)
"""
Returns
----------
y : ndarray, shape (n_test_samples,)
Predicted class labels for each test data sample.
"""
y_pred = np.zeros(X.shape[0])
return y_pred

def score(self, X, y):
Expand All @@ -109,13 +137,21 @@ def score(self, X, y):
Data to score on.
y : ndarray, shape (n_samples,)
target values.
"""
# Check that fit has been called
check_is_fitted(self, ['X_train_', 'y_train_'])

# Input validation
X, y = check_X_y(X, y)

y_pred = self.predict(X)
"""
Returns
----------
score : float
Accuracy of the model computed for the (X, y) pairs.
"""
return 0.
return np.mean(y_pred == y)


class MonthlySplit(BaseCrossValidator):
Expand Down Expand Up @@ -155,7 +191,19 @@ def get_n_splits(self, X, y=None, groups=None):
n_splits : int
The number of splits.
"""
return 0
# Check whether time_col points to the index of the DataFrame
if self.time_col == 'index':
dates = X.index
else:
# Check that time_col exists among the columns of X
if self.time_col not in X.columns:
raise KeyError(f"{self.time_col} is not in the columns of X")
dates = X[self.time_col]
if not pd.api.types.is_datetime64_any_dtype(dates):
raise ValueError("The time column should be datetime type")

months = np.unique(dates.to_period('M'))
return max(0, len(months) - 1)

def split(self, X, y, groups=None):
"""Generate indices to split data into training and test set.
Expand All @@ -169,20 +217,35 @@ def split(self, X, y, groups=None):
Always ignored, exists for compatibility.
groups : array-like of shape (n_samples,)
Always ignored, exists for compatibility.
"""

"""
Yields
------
idx_train : ndarray
The training set indices for that split.
idx_test : ndarray
The testing set indices for that split.
"""

n_samples = X.shape[0]
# obtain time data
if self.time_col == 'index':
dates = X.index
else:
dates = X[self.time_col]

# Convert to month
months = dates.to_period('M')
unique_months = np.unique(months)
n_splits = self.get_n_splits(X, y, groups)

for i in range(n_splits):
idx_train = range(n_samples)
idx_test = range(n_samples)
# Current month as the training set
train_mask = months == unique_months[i]
idx_train = np.where(train_mask)[0]

# Next month as a test set
test_mask = months == unique_months[i + 1]
idx_test = np.where(test_mask)[0]
yield (
idx_train, idx_test
)

0 comments on commit cd81a6a

Please sign in to comment.