Skip to content

Commit

Permalink
Final Solution Sklearn Task 2
Browse files Browse the repository at this point in the history
  • Loading branch information
Pablo-Molla-Charlez committed Dec 21, 2023
1 parent 22eb618 commit bb7e96f
Showing 1 changed file with 2 additions and 12 deletions.
14 changes: 2 additions & 12 deletions sklearn_questions.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@

class KNearestNeighbors(BaseEstimator, ClassifierMixin):
"""KNearestNeighbors classifier."""

def __init__(self, n_neighbors=1): # noqa: D107
self.n_neighbors = n_neighbors

Expand All @@ -85,7 +84,6 @@ def fit(self, X, y):
The current instance of the classifier
"""
# Checks

"""
The check_classification_targets function ensures that target y is of
a non-regression type. Only the following target types
Expand All @@ -94,7 +92,6 @@ def fit(self, X, y):
'multilabel-sequences'.
"""
check_classification_targets(y)

"""
The check_X_y function performs an input validation for standard
estimators (=models). It checks X and y for consistent length,
Expand All @@ -118,7 +115,8 @@ def fit(self, X, y):
return self

def predict(self, X):
"""Predict function.
"""
Predict function.
Parameters
----------
Expand All @@ -130,7 +128,6 @@ def predict(self, X):
y : ndarray, shape (n_test_samples,)
Predicted class labels for each test data sample (y_test).
"""

"""
The check_is_fitted function is a sklearn.utils.validation function
used to check whether an estimator (such as a classifier or regressor)
Expand All @@ -139,7 +136,6 @@ def predict(self, X):
an error.
"""
check_is_fitted(self)

"""
The check_array function is a sklearn.utils.validation function used
to validate whether an input array is suitable for use in scikit-learn
Expand All @@ -150,7 +146,6 @@ def predict(self, X):
these requirements, check_array will throw an error.
"""
check_array(X)

# Calculate pairwise distances
"""
pairwise_distances(X, self.X_):
Expand All @@ -162,7 +157,6 @@ def predict(self, X):
the i-th sample in X and the j-th sample in self.X_.
"""
dist_matrix = pairwise_distances(X, self.X_)

# Find Indices of Nearest Neighbors
"""
np.argsort(dist_matrix, axis=1):
Expand All @@ -176,7 +170,6 @@ def predict(self, X):
the nearest neighbors.
"""
dist_sort_pos = np.argsort(dist_matrix, axis=1)[:, :self.n_neighbors]

# Find Indices of Nearest Neighbors
"""
np.argsort(dist_matrix, axis=1):
Expand All @@ -189,7 +182,6 @@ def predict(self, X):
This slicing operation takes the first self.n_neighbors
indices for each row. These are the indices of the nearest neighbors.
"""

# Get labels of nearest neighbors
"""
self.y_ (y_train) is the array of labels corresponding to the training
Expand All @@ -198,7 +190,6 @@ def predict(self, X):
each sample in X.
"""
y_closest = self.y_[dist_sort_pos]

# Determine predicted values
"""
This line predicts the label for each sample in X based on the
Expand Down Expand Up @@ -256,7 +247,6 @@ class MonthlySplit(BaseCrossValidator):
for which this column is not a datetime, it will raise a ValueError.
To use the index as column just set `time_col` to `'index'`.
"""

def __init__(self, time_col='index'): # noqa: D107
self.time_col = time_col

Expand Down

0 comments on commit bb7e96f

Please sign in to comment.