Skip to content

Commit

Permalink
Resolution de problemes de style : lignes trop longues, utiliser isin…
Browse files Browse the repository at this point in the history
…stance, ne pas utiliser de lambda, ne pas importer des modules non utilisés
  • Loading branch information
Florent-LC committed Dec 19, 2023
1 parent b3cbebc commit 52e8569
Showing 1 changed file with 28 additions and 11 deletions.
39 changes: 28 additions & 11 deletions sklearn_questions.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@
from sklearn.utils.validation import check_X_y, check_is_fitted
from sklearn.utils.validation import check_array
from sklearn.utils.multiclass import check_classification_targets
from sklearn.metrics.pairwise import pairwise_distances

# from sklearn.metrics.pairwise import pairwise_distances
from sklearn.utils.multiclass import unique_labels

from dateutil.relativedelta import relativedelta
Expand Down Expand Up @@ -97,6 +98,9 @@ def fit(self, X, y):

return self

def distance_func(x1, x2):
    """Return the norm of the difference between two samples.

    Parameters
    ----------
    x1, x2 : array-like supporting subtraction
        The two samples to compare.

    Returns
    -------
    float
        ``np.linalg.norm(x1 - x2)``; for 1-D vectors this is the
        Euclidean distance between the two points.
    """
    difference = x1 - x2
    return np.linalg.norm(difference)

def predict(self, X):
"""Predict function.
Expand All @@ -115,15 +119,16 @@ def predict(self, X):

y_pred = np.empty(X.shape[0], dtype=self.dtype_)

distance_func = lambda x1, x2: np.linalg.norm(x1 - x2)

for i, X_predict in enumerate(X):
# the Euclidean distance for each sample of X_train
distances = np.array(
[distance_func(X_predict, X_train) for X_train in self.X_]
[
KNearestNeighbors.distance_func(X_predict, X_train)
for X_train in self.X_
]
)

# the indices for each sample sorted by its distance (the k-nearest)
# indices for each sample sorted by its distance (the k-nearest)
indexes_sort = np.argsort(distances)[: self.n_neighbors]

# the labels of the nearest neighbors
Expand Down Expand Up @@ -201,13 +206,19 @@ def get_n_splits(self, X, y=None, groups=None):
n_splits : int
The number of splits.
"""
if not (type(X) in [pd.core.frame.DataFrame, pd.core.series.Series]):
if not (
isinstance(X, pd.core.frame.DataFrame)
or isinstance(X, pd.core.series.Series)
):
raise TypeError(
f"The type of X ({type(X)}) is not consistent with a pandas dataframe or series"
f"The type of X ({type(X)}) is not consistent \
with a pandas dataframe or series"
)

if self.time_col == "index":
if type(X.index) != pd.core.indexes.datetimes.DatetimeIndex:
if not (
isinstance(X.index, pd.core.indexes.datetimes.DatetimeIndex)
):
raise ValueError("datetime")

else:
Expand Down Expand Up @@ -247,13 +258,19 @@ def split(self, X, y=None, groups=None):
idx_test : ndarray
The testing set indices for that split.
"""
if not (type(X) in [pd.core.frame.DataFrame, pd.core.series.Series]):
if not (
isinstance(X, pd.core.frame.DataFrame)
or isinstance(X, pd.core.series.Series)
):
raise TypeError(
f"The type of X ({type(X)}) is not consistent with a pandas dataframe or series"
f"The type of X ({type(X)}) is not consistent \
with a pandas dataframe or series"
)

if self.time_col == "index":
if type(X.index) != pd.core.indexes.datetimes.DatetimeIndex:
if not (
isinstance(X.index, pd.core.indexes.datetimes.DatetimeIndex)
):
raise ValueError("datetime")

else:
Expand Down

0 comments on commit 52e8569

Please sign in to comment.