Skip to content

Commit

Permalink
concepts: introduce the Craft method
Browse files Browse the repository at this point in the history
Signed-off-by: Frederic Boisnard <[email protected]>
  • Loading branch information
fredericboisnard committed Sep 8, 2023
1 parent ec30fe4 commit 9c7ea58
Show file tree
Hide file tree
Showing 7 changed files with 1,440 additions and 6 deletions.
261 changes: 261 additions & 0 deletions tests/concepts/test_craft_tf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,261 @@
import numpy as np
import tensorflow as tf
import pytest
import tensorflow as tf
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import Dense, Conv1D, Conv2D, Activation, GlobalAveragePooling1D, Dropout, Flatten, MaxPooling2D, Input
from tensorflow.keras.utils import to_categorical

from xplique.concepts import CraftTf as Craft
from ..utils import generate_data, generate_model, generate_txt_images_data

def test_shape():
"""Ensure the output shape is correct"""

input_shapes = [(32, 32, 3), (32, 32, 1), (64, 64, 3)]
nb_labels = 3
nb_samples = 100

for input_shape in input_shapes:
# Generate a fake dataset
x, y = generate_data(input_shape, nb_labels, nb_samples)
model = generate_model(input_shape, nb_labels)
model.compile(loss='categorical_crossentropy', optimizer=tf.keras.optimizers.Adam(0.1))

# cut the model in two parts (as explained in the paper)
# first part is g(.) our 'input_to_latent' model, second part is h(.) our 'latent_to_logit' model
g = tf.keras.Model(model.input, model.layers[0].output)
h = tf.keras.Model(model.layers[1].input, model.layers[-1].output)

# The activations must be positives
assert np.all(g(x) >= 0.0)

# Initialize Craft
number_of_concepts = 10
patch_size = 15
craft = Craft(input_to_latent = g,
latent_to_logit = h,
number_of_concepts = number_of_concepts,
patch_size = patch_size,
batch_size = 64)

# Now we can fit the concept using our images
# Focus on class id 0
class_id = 0
images_preprocessed = x[y.argmax(1)==class_id] # select only images of class 'class_id'
crops, crops_u, w = craft.fit(images_preprocessed, class_id)

print('Checking shape of crops, crops_u, w:', crops.shape, crops_u.shape, w.shape)
assert(crops.shape[1] == crops.shape[2] == patch_size) # Check patch sizes
assert(crops.shape[0] == crops_u.shape[0]) # Check numbers of patches
assert(crops_u.shape[1] == w.shape[0])

# Importance estimation
importances = craft.estimate_importance(nb_most_important_concepts=5)
print('Checking shape of importances:', importances.shape)
assert(len(importances) == number_of_concepts)

images_u = craft.transform(images_preprocessed)
print('Checking shape of images_u:', images_u.shape)
assert(images_u.shape[1:] == (input_shape[0]-1, input_shape[1]-1, number_of_concepts))


def test_shape2():
"""Ensure the output shape is correct when activation.shape == 2"""

input_shapes = [(32, 32, 3), (32, 32, 1), (64, 64, 3)]
nb_labels = 3
nb_samples = 100

for input_shape in input_shapes:
# Generate a fake dataset
x, y = generate_data(input_shape, nb_labels, nb_samples)
model = generate_model(input_shape, nb_labels)
model.compile(loss='categorical_crossentropy', optimizer=tf.keras.optimizers.Adam(0.1))

# cut the model in two parts (as explained in the paper)
# first part is g(.) our 'input_to_latent' model, second part is h(.) our 'latent_to_logit' model
g = tf.keras.Model(model.input, model.layers[-1].output)
h = tf.keras.Model(model.layers[-1].input, model.layers[-1].output)

# The activations must be positives
assert np.all(g(x) >= 0.0)

# Initialize Craft
number_of_concepts = 10
patch_size = 15
craft = Craft(input_to_latent = g,
latent_to_logit = h,
number_of_concepts = number_of_concepts,
patch_size = patch_size,
batch_size = 64)

# Now we can fit the concept using our images
# Focus on class id 0
class_id = 0
images_preprocessed = x[y.argmax(1)==class_id] # select only images of class 'class_id'
crops, crops_u, w = craft.fit(images_preprocessed, class_id)

print('Checking shape of crops, crops_u, w:', crops.shape, crops_u.shape, w.shape)
assert(crops.shape[1] == crops.shape[2] == patch_size) # Check patch sizes
assert(crops.shape[0] == crops_u.shape[0]) # Check numbers of patches
assert(crops_u.shape[1] == w.shape[0] == number_of_concepts)

# Importance estimation
importances = craft.estimate_importance(nb_most_important_concepts=5)
print('Checking shape of importances:', importances.shape)
assert(len(importances) == number_of_concepts)

images_u = craft.transform(images_preprocessed)
print('Checking shape of images_u:', images_u.shape)
assert(input_shape[1], number_of_concepts)

def test_wrong_layers():
"""Ensure that Craft complains when the input models are incompatible"""

input_shapes = [(32, 32, 3)]
nb_labels = 3

for input_shape in input_shapes:
# Generate a fake dataset
model = generate_model(input_shape, nb_labels)
model.compile(loss='categorical_crossentropy', optimizer=tf.keras.optimizers.Adam(0.1))

g = tf.keras.Model(model.input, model.layers[0].output)
h = lambda x: 2*x

# Initialize Craft
number_of_concepts = 10
patch_size = 15
with pytest.raises(TypeError):
craft = Craft(input_to_latent = g,
latent_to_logit = h,
number_of_concepts = number_of_concepts,
patch_size = patch_size,
batch_size = 64)

def test_classifier():
""" Check the Craft results on a small fake dataset """

input_shape = (64, 64, 3)
nb_labels = 3
nb_samples = 200

# Create a dataset of 'ABC', 'BCD', 'CDE' images
x, y, nb_samples, labels_str = generate_txt_images_data(input_shape, nb_labels, nb_samples)
print(f'Dataset: {nb_samples} samples generated, {nb_labels} classes: {labels_str}')

# train a small classifier on the dataset
def generate_model(input_shape=(64, 64, 3), output_shape=10):
model = Sequential()
model.add(Input(shape=input_shape))
model.add(Conv2D(6, kernel_size=(2, 2)))
model.add(Activation('relu'))
model.add(Conv2D(6, kernel_size=(2, 2)))
model.add(Activation('relu'))
model.add(Conv2D(6, kernel_size=(2, 2)))
model.add(Activation('relu'))
model.add(Flatten())
model.add(Dense(output_shape))
model.add(Activation('softmax'))
model.compile(loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])

return model

model = generate_model(input_shape, nb_labels)
model.fit(x, y, batch_size=32, epochs=70)
print('acc:', np.sum(np.argmax(model(x), axis=1) == np.argmax(y, axis=1)) / nb_samples)

# cut the model in two parts (as explained in the paper)
# first part is g(.) our 'input_to_latent' model, second part is h(.) our 'latent_to_logit' model
g = tf.keras.Model(model.input, model.layers[-4].output)
h = tf.keras.Model(model.layers[-3].input, model.layers[-1].output)
print(model.layers[-4])
print(model.layers[-3])

assert np.all(g(x) >= 0.0)

# Init Craft on the full dataset
craft = Craft(input_to_latent = g,
latent_to_logit = h,
number_of_concepts = 3,
patch_size = 12,
batch_size = 32)

# Expected best crop for class 0 (ABC) is AB
AB_str = """
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 1 1 1 1 1 1 1 1
0 0 0 0 0 0 1 0 0 0 1 1
1 0 0 0 0 0 1 0 0 0 0 1
1 0 0 0 0 0 1 0 0 0 1 1
1 0 0 0 0 0 1 1 1 1 1 1
1 1 0 0 0 0 1 0 0 0 0 1
0 1 0 0 0 0 1 0 0 0 0 0
0 1 1 0 0 0 1 0 0 0 0 1
1 1 1 1 1 1 1 1 1 1 1 1
0 0 0 0 0 0 0 0 0 0 0 0
"""
AB = np.genfromtxt(AB_str.splitlines())

# Expected best crop for class 1 (BCD) is BC
BC_str = """
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
1 1 1 1 1 1 0 0 0 0 1 1
1 0 0 0 1 1 0 0 0 1 1 0
1 0 0 0 0 1 0 0 0 1 0 0
1 0 0 0 1 1 0 0 0 1 0 0
1 1 1 1 1 1 0 0 0 1 0 0
1 0 0 0 0 1 1 0 0 1 0 0
1 0 0 0 0 0 1 0 0 1 0 0
1 0 0 0 0 1 1 0 0 1 1 0
1 1 1 1 1 1 0 0 0 0 1 1
"""
BC = np.genfromtxt(BC_str.splitlines())

# Expected best crop for class 2 (CDE) is DE
DE_str = """
0 0 0 0 0 0 0 0 0 0 0 0
1 0 0 0 1 1 1 1 1 1 1 1
1 1 0 0 0 0 1 0 0 0 0 1
0 1 0 0 0 0 1 0 0 0 0 1
0 1 1 0 0 0 1 0 0 1 0 0
0 1 1 0 0 0 1 1 1 1 0 0
0 1 1 0 0 0 1 0 0 1 0 0
0 1 0 0 0 0 1 0 0 0 0 1
1 1 0 0 0 0 1 0 0 0 0 1
1 0 0 0 1 1 1 1 1 1 1 1
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
"""
DE = np.genfromtxt(DE_str.splitlines())

expected_best_crops = [AB, BC, DE]

# Run 3 Craft studies on each class, and in each case check if the best crop is the expected one
for class_id in range(3):
# Focus on class class_id
print(f'Selecting subset for class {class_id} : {labels_str[class_id]}')
x_subset = x[np.argmax(y, axis=1)==class_id,:,:,:]

print('fit craft on the selected class')
crops, crops_u, w = craft.fit(x_subset, class_id)

print('compute importances')
importances = craft.estimate_importance(nb_most_important_concepts=3)
assert(importances[0] > 0.7)

print('find the best crop and compare it to the expected best crop')
most_important_concepts = np.argsort(importances)[::-1]

# Find the best crop for the most important concept
c_id = most_important_concepts[0]
best_crops_ids = np.argsort(crops_u[:, c_id])[::-1]
best_crop = np.array(crops)[best_crops_ids[0]]

# Compare this best crop to the expectation
bin_best_crop = np.where(best_crop.sum(axis=2) > 0.25, 1, 0)
assert(np.all(bin_best_crop == expected_best_crops[class_id]))
Loading

0 comments on commit 9c7ea58

Please sign in to comment.