-
Notifications
You must be signed in to change notification settings - Fork 50
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
concepts: add the base class of all concept methods
Signed-off-by: Frederic Boisnard <[email protected]>
- Loading branch information
1 parent
7923975
commit ec30fe4
Showing
1 changed file
with
104 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
""" | ||
Module related to abstract concept explainer | ||
""" | ||
|
||
from abc import ABC, abstractmethod | ||
|
||
from sklearn.exceptions import NotFittedError | ||
|
||
from ..types import Callable, Optional | ||
|
||
class BaseConceptExtractor(ABC): | ||
|
||
""" | ||
Base class for concept extraction models. | ||
Parameters | ||
---------- | ||
number_of_concepts | ||
The number of concepts to extract. Default is 20. | ||
batch_size | ||
The batch size to use during training and prediction. Default is 64. | ||
""" | ||
@abstractmethod | ||
def __init__(self, number_of_concepts: int = 20, | ||
batch_size: int = 64): | ||
self.number_of_concepts = number_of_concepts | ||
self.batch_size = batch_size | ||
|
||
# sanity checks | ||
assert(number_of_concepts > 0), "number_of_concepts must be greater than 0" | ||
assert(batch_size > 0), "batch_size must be greater than 0" | ||
|
||
@abstractmethod | ||
def fit(self, inputs): | ||
""" | ||
Fit the CAVs to the input data. | ||
Parameters | ||
---------- | ||
inputs | ||
The input data to fit the model on. | ||
Returns | ||
------- | ||
tuple | ||
A tuple containing the input data and the matrices (U, W) that factorize the data. | ||
""" | ||
raise NotImplementedError | ||
|
||
def check_if_fitted(self): | ||
"""Checks if the factorization model has been fitted to input data. | ||
Raises | ||
------ | ||
NotFittedError | ||
If the factorization model has not been fitted to input data. | ||
""" | ||
|
||
if not hasattr(self, 'reducer'): | ||
raise NotFittedError("The factorization model has not been fitted to input data yet.") | ||
|
||
@abstractmethod | ||
def transform(self, inputs): | ||
""" | ||
Transform the input data into a concepts embedding. | ||
Parameters | ||
---------- | ||
inputs | ||
The input data to transform. | ||
Returns | ||
------- | ||
array-like | ||
The transformed embedding of the input data. | ||
""" | ||
raise NotImplementedError | ||
|
||
|
||
class ClassifierConceptExtractor(BaseConceptExtractor): | ||
""" | ||
Base class for concept extraction based on a single classifier model. | ||
Parameters | ||
---------- | ||
model_encoder | ||
The model to explain. | ||
number_of_concepts | ||
The number of concepts to extract. Default is 20. | ||
batch_size | ||
The batch size to use during training and prediction. Default is 64. | ||
""" | ||
def __init__(self, model_encoder : Callable, | ||
number_of_concepts: int = 20, | ||
batch_size: int = 64): | ||
super().__init__(number_of_concepts, batch_size) | ||
self.model_encoder = model_encoder | ||
|
||
# sanity checks | ||
assert(callable(model_encoder)), "model_encoder must be a callable function" |