Skip to content

Commit

Permalink
concepts: add the base class of all concept methods
Browse files Browse the repository at this point in the history
Signed-off-by: Frederic Boisnard <[email protected]>
  • Loading branch information
fredericboisnard committed Sep 8, 2023
1 parent 7923975 commit ec30fe4
Showing 1 changed file with 104 additions and 0 deletions.
104 changes: 104 additions & 0 deletions xplique/concepts/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
"""
Module related to abstract concept explainer
"""

from abc import ABC, abstractmethod

from sklearn.exceptions import NotFittedError

from ..types import Callable, Optional

class BaseConceptExtractor(ABC):
    """
    Base class for concept extraction models.

    Concrete extractors must implement `__init__`, `fit` and `transform`.

    Parameters
    ----------
    number_of_concepts
        The number of concepts to extract. Default is 20.
    batch_size
        The batch size to use during training and prediction. Default is 64.

    Raises
    ------
    AssertionError
        If `number_of_concepts` or `batch_size` is not greater than 0.
    """

    @abstractmethod
    def __init__(self, number_of_concepts: int = 20,
                 batch_size: int = 64):
        self.number_of_concepts = number_of_concepts
        self.batch_size = batch_size

        # sanity checks
        # Raised explicitly instead of via `assert` statements: asserts are
        # removed when Python runs with -O, which would silently disable the
        # validation. AssertionError is kept so existing callers see the same
        # exception type and message.
        if not number_of_concepts > 0:
            raise AssertionError("number_of_concepts must be greater than 0")
        if not batch_size > 0:
            raise AssertionError("batch_size must be greater than 0")

    @abstractmethod
    def fit(self, inputs):
        """
        Fit the CAVs to the input data.

        Parameters
        ----------
        inputs
            The input data to fit the model on.

        Returns
        -------
        tuple
            A tuple containing the input data and the matrices (U, W) that factorize the data.
        """
        raise NotImplementedError

    def check_if_fitted(self):
        """
        Check if the factorization model has been fitted to input data.

        The instance is considered fitted as soon as it has a `reducer`
        attribute; concrete subclasses are presumably expected to set it
        during `fit`.

        Raises
        ------
        NotFittedError
            If the factorization model has not been fitted to input data.
        """
        if not hasattr(self, 'reducer'):
            raise NotFittedError("The factorization model has not been fitted to input data yet.")

    @abstractmethod
    def transform(self, inputs):
        """
        Transform the input data into a concepts embedding.

        Parameters
        ----------
        inputs
            The input data to transform.

        Returns
        -------
        array-like
            The transformed embedding of the input data.
        """
        raise NotImplementedError


class ClassifierConceptExtractor(BaseConceptExtractor):
    """
    Base class for concept extraction based on a single classifier model.

    Parameters
    ----------
    model_encoder
        The model to explain.
    number_of_concepts
        The number of concepts to extract. Default is 20.
    batch_size
        The batch size to use during training and prediction. Default is 64.

    Raises
    ------
    AssertionError
        If `model_encoder` is not callable, or if the parent constructor
        rejects `number_of_concepts` or `batch_size`.
    """

    def __init__(self, model_encoder: Callable,
                 number_of_concepts: int = 20,
                 batch_size: int = 64):
        super().__init__(number_of_concepts, batch_size)
        self.model_encoder = model_encoder

        # sanity checks
        # Raised explicitly instead of via an `assert` statement: asserts are
        # removed when Python runs with -O, which would silently disable the
        # validation. AssertionError is kept so existing callers see the same
        # exception type and message.
        if not callable(model_encoder):
            raise AssertionError("model_encoder must be a callable function")

0 comments on commit ec30fe4

Please sign in to comment.