concepts: introduce the Craft method
Signed-off-by: Frederic Boisnard <[email protected]>
Commit 9c7ea58 (1 parent: ec30fe4)
Showing 7 changed files with 1,440 additions and 6 deletions.
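For context, the snippet below is a minimal usage sketch of the CraftTf API this commit introduces, following the calls exercised in the tests further down: the model is split into g ('input_to_latent') and h ('latent_to_logit') as in the CRAFT paper, then fit, estimate_importance and transform are called. The toy classifier, the random images and the split point are illustrative assumptions for this sketch, not part of the commit.

    import numpy as np
    import tensorflow as tf
    from xplique.concepts import CraftTf as Craft

    # illustrative toy classifier and random images (assumptions for this sketch)
    model = tf.keras.Sequential([
        tf.keras.layers.Input(shape=(64, 64, 3)),
        tf.keras.layers.Conv2D(6, (2, 2), activation='relu'),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(3, activation='softmax'),
    ])
    images = np.random.rand(32, 64, 64, 3).astype(np.float32)

    # split the model into g ('input_to_latent', non-negative activations)
    # and h ('latent_to_logit'), as done in the tests below
    g = tf.keras.Model(model.input, model.layers[-3].output)
    h = tf.keras.Model(model.layers[-2].input, model.layers[-1].output)

    craft = Craft(input_to_latent=g, latent_to_logit=h,
                  number_of_concepts=10, patch_size=15, batch_size=64)

    # fit the concept bank on images of one class (class id 0 here), then use it
    crops, crops_u, w = craft.fit(images, 0)
    importances = craft.estimate_importance(nb_most_important_concepts=5)
    images_u = craft.transform(images)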
@@ -0,0 +1,261 @@
import numpy as np
import pytest
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Conv2D, Activation, Flatten, Input

from xplique.concepts import CraftTf as Craft
from ..utils import generate_data, generate_model, generate_txt_images_data

def test_shape():
    """Ensure the output shape is correct"""

    input_shapes = [(32, 32, 3), (32, 32, 1), (64, 64, 3)]
    nb_labels = 3
    nb_samples = 100

    for input_shape in input_shapes:
        # Generate a fake dataset
        x, y = generate_data(input_shape, nb_labels, nb_samples)
        model = generate_model(input_shape, nb_labels)
        model.compile(loss='categorical_crossentropy', optimizer=tf.keras.optimizers.Adam(0.1))

        # cut the model in two parts (as explained in the paper)
        # first part is g(.) our 'input_to_latent' model, second part is h(.) our 'latent_to_logit' model
        g = tf.keras.Model(model.input, model.layers[0].output)
        h = tf.keras.Model(model.layers[1].input, model.layers[-1].output)

        # The activations must be positive
        assert np.all(g(x) >= 0.0)

        # Initialize Craft
        number_of_concepts = 10
        patch_size = 15
        craft = Craft(input_to_latent = g,
                      latent_to_logit = h,
                      number_of_concepts = number_of_concepts,
                      patch_size = patch_size,
                      batch_size = 64)

        # Now we can fit the concepts using our images
        # Focus on class id 0
        class_id = 0
        images_preprocessed = x[y.argmax(1)==class_id]  # select only images of class 'class_id'
        crops, crops_u, w = craft.fit(images_preprocessed, class_id)

        print('Checking shape of crops, crops_u, w:', crops.shape, crops_u.shape, w.shape)
        assert(crops.shape[1] == crops.shape[2] == patch_size)  # Check patch sizes
        assert(crops.shape[0] == crops_u.shape[0])  # Check numbers of patches
        assert(crops_u.shape[1] == w.shape[0])

        # Importance estimation
        importances = craft.estimate_importance(nb_most_important_concepts=5)
        print('Checking shape of importances:', importances.shape)
        assert(len(importances) == number_of_concepts)

        images_u = craft.transform(images_preprocessed)
        print('Checking shape of images_u:', images_u.shape)
        assert(images_u.shape[1:] == (input_shape[0]-1, input_shape[1]-1, number_of_concepts))

def test_shape2():
    """Ensure the output shape is correct when activation.shape == 2"""

    input_shapes = [(32, 32, 3), (32, 32, 1), (64, 64, 3)]
    nb_labels = 3
    nb_samples = 100

    for input_shape in input_shapes:
        # Generate a fake dataset
        x, y = generate_data(input_shape, nb_labels, nb_samples)
        model = generate_model(input_shape, nb_labels)
        model.compile(loss='categorical_crossentropy', optimizer=tf.keras.optimizers.Adam(0.1))

        # cut the model in two parts (as explained in the paper)
        # first part is g(.) our 'input_to_latent' model, second part is h(.) our 'latent_to_logit' model
        g = tf.keras.Model(model.input, model.layers[-1].output)
        h = tf.keras.Model(model.layers[-1].input, model.layers[-1].output)

        # The activations must be positive
        assert np.all(g(x) >= 0.0)

        # Initialize Craft
        number_of_concepts = 10
        patch_size = 15
        craft = Craft(input_to_latent = g,
                      latent_to_logit = h,
                      number_of_concepts = number_of_concepts,
                      patch_size = patch_size,
                      batch_size = 64)

        # Now we can fit the concepts using our images
        # Focus on class id 0
        class_id = 0
        images_preprocessed = x[y.argmax(1)==class_id]  # select only images of class 'class_id'
        crops, crops_u, w = craft.fit(images_preprocessed, class_id)

        print('Checking shape of crops, crops_u, w:', crops.shape, crops_u.shape, w.shape)
        assert(crops.shape[1] == crops.shape[2] == patch_size)  # Check patch sizes
        assert(crops.shape[0] == crops_u.shape[0])  # Check numbers of patches
        assert(crops_u.shape[1] == w.shape[0] == number_of_concepts)

        # Importance estimation
        importances = craft.estimate_importance(nb_most_important_concepts=5)
        print('Checking shape of importances:', importances.shape)
        assert(len(importances) == number_of_concepts)

        images_u = craft.transform(images_preprocessed)
        print('Checking shape of images_u:', images_u.shape)
        # for 2D activations, transform should map each image to a vector of concept coefficients
        assert(images_u.shape[1] == number_of_concepts)

def test_wrong_layers():
    """Ensure that Craft complains when the input models are incompatible"""

    input_shapes = [(32, 32, 3)]
    nb_labels = 3

    for input_shape in input_shapes:
        # Generate a fake model
        model = generate_model(input_shape, nb_labels)
        model.compile(loss='categorical_crossentropy', optimizer=tf.keras.optimizers.Adam(0.1))

        g = tf.keras.Model(model.input, model.layers[0].output)
        h = lambda x: 2*x

        # Initialize Craft
        number_of_concepts = 10
        patch_size = 15
        with pytest.raises(TypeError):
            craft = Craft(input_to_latent = g,
                          latent_to_logit = h,
                          number_of_concepts = number_of_concepts,
                          patch_size = patch_size,
                          batch_size = 64)

def test_classifier():
    """ Check the Craft results on a small fake dataset """

    input_shape = (64, 64, 3)
    nb_labels = 3
    nb_samples = 200

    # Create a dataset of 'ABC', 'BCD', 'CDE' images
    x, y, nb_samples, labels_str = generate_txt_images_data(input_shape, nb_labels, nb_samples)
    print(f'Dataset: {nb_samples} samples generated, {nb_labels} classes: {labels_str}')

    # train a small classifier on the dataset
    # (note: this local definition shadows the generate_model imported from ..utils)
    def generate_model(input_shape=(64, 64, 3), output_shape=10):
        model = Sequential()
        model.add(Input(shape=input_shape))
        model.add(Conv2D(6, kernel_size=(2, 2)))
        model.add(Activation('relu'))
        model.add(Conv2D(6, kernel_size=(2, 2)))
        model.add(Activation('relu'))
        model.add(Conv2D(6, kernel_size=(2, 2)))
        model.add(Activation('relu'))
        model.add(Flatten())
        model.add(Dense(output_shape))
        model.add(Activation('softmax'))
        model.compile(loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])

        return model

    model = generate_model(input_shape, nb_labels)
    model.fit(x, y, batch_size=32, epochs=70)
    print('acc:', np.sum(np.argmax(model(x), axis=1) == np.argmax(y, axis=1)) / nb_samples)

    # cut the model in two parts (as explained in the paper)
    # first part is g(.) our 'input_to_latent' model, second part is h(.) our 'latent_to_logit' model
    g = tf.keras.Model(model.input, model.layers[-4].output)
    h = tf.keras.Model(model.layers[-3].input, model.layers[-1].output)
    print(model.layers[-4])
    print(model.layers[-3])

    assert np.all(g(x) >= 0.0)

    # Init Craft on the full dataset
    craft = Craft(input_to_latent = g,
                  latent_to_logit = h,
                  number_of_concepts = 3,
                  patch_size = 12,
                  batch_size = 32)

    # Expected best crop for class 0 (ABC) is AB
    AB_str = """
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 1 1 1 1 1 1 1 1
0 0 0 0 0 0 1 0 0 0 1 1
1 0 0 0 0 0 1 0 0 0 0 1
1 0 0 0 0 0 1 0 0 0 1 1
1 0 0 0 0 0 1 1 1 1 1 1
1 1 0 0 0 0 1 0 0 0 0 1
0 1 0 0 0 0 1 0 0 0 0 0
0 1 1 0 0 0 1 0 0 0 0 1
1 1 1 1 1 1 1 1 1 1 1 1
0 0 0 0 0 0 0 0 0 0 0 0
"""
    AB = np.genfromtxt(AB_str.splitlines())

    # Expected best crop for class 1 (BCD) is BC
    BC_str = """
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
1 1 1 1 1 1 0 0 0 0 1 1
1 0 0 0 1 1 0 0 0 1 1 0
1 0 0 0 0 1 0 0 0 1 0 0
1 0 0 0 1 1 0 0 0 1 0 0
1 1 1 1 1 1 0 0 0 1 0 0
1 0 0 0 0 1 1 0 0 1 0 0
1 0 0 0 0 0 1 0 0 1 0 0
1 0 0 0 0 1 1 0 0 1 1 0
1 1 1 1 1 1 0 0 0 0 1 1
"""
    BC = np.genfromtxt(BC_str.splitlines())

    # Expected best crop for class 2 (CDE) is DE
    DE_str = """
0 0 0 0 0 0 0 0 0 0 0 0
1 0 0 0 1 1 1 1 1 1 1 1
1 1 0 0 0 0 1 0 0 0 0 1
0 1 0 0 0 0 1 0 0 0 0 1
0 1 1 0 0 0 1 0 0 1 0 0
0 1 1 0 0 0 1 1 1 1 0 0
0 1 1 0 0 0 1 0 0 1 0 0
0 1 0 0 0 0 1 0 0 0 0 1
1 1 0 0 0 0 1 0 0 0 0 1
1 0 0 0 1 1 1 1 1 1 1 1
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
"""
    DE = np.genfromtxt(DE_str.splitlines())

    expected_best_crops = [AB, BC, DE]

    # Run a Craft study on each of the 3 classes, and check in each case that the best crop is the expected one
    for class_id in range(3):
        # Focus on class class_id
        print(f'Selecting subset for class {class_id} : {labels_str[class_id]}')
        x_subset = x[np.argmax(y, axis=1)==class_id,:,:,:]

        print('fit craft on the selected class')
        crops, crops_u, w = craft.fit(x_subset, class_id)

        print('compute importances')
        importances = craft.estimate_importance(nb_most_important_concepts=3)
        assert(importances[0] > 0.7)

        print('find the best crop and compare it to the expected best crop')
        most_important_concepts = np.argsort(importances)[::-1]

        # Find the best crop for the most important concept
        c_id = most_important_concepts[0]
        best_crops_ids = np.argsort(crops_u[:, c_id])[::-1]
        best_crop = np.array(crops)[best_crops_ids[0]]

        # Compare this best crop to the expectation
        bin_best_crop = np.where(best_crop.sum(axis=2) > 0.25, 1, 0)
        assert(np.all(bin_best_crop == expected_best_crops[class_id]))
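A note on the shapes checked in these tests: CRAFT builds its concept bank from a non-negative factorization of the crop activations into concept coefficients (crops_u) and a concept basis (w), which is why the tests require g(x) >= 0 and check crops_u.shape[1] == w.shape[0] == number_of_concepts. The standalone sketch below illustrates that relationship with scikit-learn's NMF on fake data; it is an illustration only, not CraftTf's internal implementation.

    import numpy as np
    from sklearn.decomposition import NMF

    number_of_concepts = 10
    activations = np.random.rand(80, 128)      # stand-in for non-negative crop activations (n_crops, n_features)

    nmf = NMF(n_components=number_of_concepts, max_iter=500)
    crops_u = nmf.fit_transform(activations)   # concept coefficients per crop: (n_crops, number_of_concepts)
    w = nmf.components_                        # concept basis: (number_of_concepts, n_features)

    assert crops_u.shape[1] == w.shape[0] == number_of_concepts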