Merge pull request #318 from TeamHG-Memex/catboost-docs-cleanup
Catboost docs cleanup
kmike authored Jun 27, 2019
2 parents e19139a + aa32308 commit 716a4d7
Showing 2 changed files with 20 additions and 12 deletions.
README.rst (3 changes: 2 additions & 1 deletion)
@@ -42,7 +42,8 @@ It provides support for the following machine learning frameworks and packages:
 * LightGBM_ - show feature importances and explain predictions of
   LGBMClassifier and LGBMRegressor.

-* CatBoost_ - show feature importances and explain predictions of CatBoostClassifier and CatBoostRegressor.
+* CatBoost_ - show feature importances of CatBoostClassifier,
+  CatBoostRegressor and catboost.CatBoost.

 * lightning_ - explain weights and predictions of lightning classifiers and
   regressors.
eli5/catboost.py (29 changes: 18 additions & 11 deletions)
@@ -6,8 +6,8 @@
 from eli5.explain import explain_weights
 from eli5._feature_importances import get_feature_importance_explanation

-DESCRIPTION_CATBOOST = """CatBoost feature importances; values are numbers 0 <= x <= 1;
-all values sum to 1."""
+DESCRIPTION_CATBOOST = """CatBoost feature importances;
+values are numbers 0 <= x <= 1; all values sum to 1."""

 @explain_weights.register(catboost.CatBoost)
 @explain_weights.register(catboost.CatBoostClassifier)
@@ -20,8 +20,8 @@ def explain_weights_catboost(catb,
                              pool=None
                              ):
     """
-    Return an explanation of an CatBoost estimator (CatBoostClassifier, CatBoost, CatBoostRegressor)
-    as feature importances.
+    Return an explanation of an CatBoost estimator (CatBoostClassifier,
+    CatBoost, CatBoostRegressor) as feature importances.

     See :func:`eli5.explain_weights` for description of
     ``top``, ``feature_names``,
@@ -34,13 +34,16 @@ def explain_weights_catboost(catb,
     :param 'importance_type' : str, optional
         A way to get feature importance. Possible values are:
-        - 'PredictionValuesChange' - The individual importance values for each of the input features.
-          (default)
-        - 'LossFunctionChange' - The individual importance values for each of the input features for ranking metrics (requires training data to be passed or a similar dataset with Pool)
+        - 'PredictionValuesChange' (default) - The individual importance
+          values for each of the input features.
+        - 'LossFunctionChange' - The individual importance values for
+          each of the input features for ranking metrics
+          (requires training data to be passed or a similar dataset with Pool)

     :param 'pool' : catboost.Pool, optional
-        To be passed if explain_weights_catboost has importance_type set to LossFunctionChange.
-        The catboost feature_importances uses the Pool datatype to calculate the parameter for the specific importance_type.
+        To be passed if explain_weights_catboost has importance_type set
+        to LossFunctionChange. The catboost feature_importances uses the Pool
+        datatype to calculate the parameter for the specific importance_type.
     """
     is_regression = _is_regression(catb)
     catb_feature_names = catb.feature_names_
@@ -69,10 +72,14 @@ def _catb_feature_importance(catb, importance_type, pool=None):
             fs = catb.get_feature_importance(data=pool, type=importance_type)
         else:
             raise ValueError(
-                'importance_type: "LossFunctionChange" requires catboost.Pool datatype to be passed with parameter pool to calculate metric. Either no datatype or invalid datatype was passed'
+                'importance_type: "LossFunctionChange" requires catboost.Pool '
+                'datatype to be passed with parameter pool to calculate '
+                'metric. Either no datatype or invalid datatype was passed'
             )
     else:
         raise ValueError(
-            'Only two importance_type "PredictionValuesChange" and "LossFunctionChange" are supported. Invalid Parameter {} for importance_type'.format(importance_type))
+            'Only two importance_type "PredictionValuesChange" '
+            'and "LossFunctionChange" are supported. Invalid Parameter '
+            '{} for importance_type'.format(importance_type))
     all_features = np.array(fs, dtype=np.float32)
     return all_features/all_features.sum()
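The behaviour this diff documents, raw importances are normalized to sum to 1 (as the reworded `DESCRIPTION_CATBOOST` states), unknown `importance_type` values are rejected, and `'LossFunctionChange'` requires a pool, can be sketched without catboost. The function below is a hypothetical stand-in for `_catb_feature_importance`, with made-up raw importance values; the real code checks `isinstance(pool, catboost.Pool)` rather than `pool is None`:

```python
import numpy as np

def normalize_importances(fs, importance_type, pool=None):
    """Sketch of the validation and normalization in _catb_feature_importance.

    fs stands in for the raw scores catboost.get_feature_importance would
    return; here the caller supplies them directly for illustration.
    """
    if importance_type == 'PredictionValuesChange':
        pass  # catboost computes fs without needing a pool
    elif importance_type == 'LossFunctionChange':
        if pool is None:  # real code: isinstance(pool, catboost.Pool)
            raise ValueError(
                'importance_type: "LossFunctionChange" requires catboost.Pool '
                'datatype to be passed with parameter pool to calculate metric')
    else:
        raise ValueError(
            'Only two importance_type "PredictionValuesChange" '
            'and "LossFunctionChange" are supported')
    all_features = np.array(fs, dtype=np.float32)
    # scale so all values are in [0, 1] and sum to 1
    return all_features / all_features.sum()

# made-up raw scores [2, 1, 1] normalize to [0.5, 0.25, 0.25]
print(normalize_importances([2.0, 1.0, 1.0], 'PredictionValuesChange'))
```

The normalization is why `DESCRIPTION_CATBOOST` can promise `0 <= x <= 1` regardless of the scale of the raw catboost scores.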
