From 9c7ea58cee3e3cce4f03ede30e965b5befc39c90 Mon Sep 17 00:00:00 2001 From: Frederic Boisnard Date: Wed, 16 Aug 2023 19:28:08 +0200 Subject: [PATCH] concepts: introduce the Craft method Signed-off-by: Frederic Boisnard --- tests/concepts/test_craft_tf.py | 261 ++++++++++++++++ tests/concepts/test_craft_torch.py | 325 ++++++++++++++++++++ tests/utils.py | 52 ++++ xplique/concepts/__init__.py | 18 +- xplique/concepts/craft.py | 475 +++++++++++++++++++++++++++++ xplique/concepts/craft_tf.py | 124 ++++++++ xplique/concepts/craft_torch.py | 191 ++++++++++++ 7 files changed, 1440 insertions(+), 6 deletions(-) create mode 100644 tests/concepts/test_craft_tf.py create mode 100644 tests/concepts/test_craft_torch.py create mode 100644 xplique/concepts/craft.py create mode 100644 xplique/concepts/craft_tf.py create mode 100644 xplique/concepts/craft_torch.py diff --git a/tests/concepts/test_craft_tf.py b/tests/concepts/test_craft_tf.py new file mode 100644 index 00000000..4fe4eea2 --- /dev/null +++ b/tests/concepts/test_craft_tf.py @@ -0,0 +1,261 @@ +import numpy as np +import tensorflow as tf +import pytest +import tensorflow as tf +from tensorflow.keras.models import Sequential, Model +from tensorflow.keras.layers import Dense, Conv1D, Conv2D, Activation, GlobalAveragePooling1D, Dropout, Flatten, MaxPooling2D, Input +from tensorflow.keras.utils import to_categorical + +from xplique.concepts import CraftTf as Craft +from ..utils import generate_data, generate_model, generate_txt_images_data + +def test_shape(): + """Ensure the output shape is correct""" + + input_shapes = [(32, 32, 3), (32, 32, 1), (64, 64, 3)] + nb_labels = 3 + nb_samples = 100 + + for input_shape in input_shapes: + # Generate a fake dataset + x, y = generate_data(input_shape, nb_labels, nb_samples) + model = generate_model(input_shape, nb_labels) + model.compile(loss='categorical_crossentropy', optimizer=tf.keras.optimizers.Adam(0.1)) + + # cut the model in two parts (as explained in the paper) + # first part is g(.) our 'input_to_latent' model, second part is h(.) 
our 'latent_to_logit' model + g = tf.keras.Model(model.input, model.layers[0].output) + h = tf.keras.Model(model.layers[1].input, model.layers[-1].output) + + # The activations must be positives + assert np.all(g(x) >= 0.0) + + # Initialize Craft + number_of_concepts = 10 + patch_size = 15 + craft = Craft(input_to_latent = g, + latent_to_logit = h, + number_of_concepts = number_of_concepts, + patch_size = patch_size, + batch_size = 64) + + # Now we can fit the concept using our images + # Focus on class id 0 + class_id = 0 + images_preprocessed = x[y.argmax(1)==class_id] # select only images of class 'class_id' + crops, crops_u, w = craft.fit(images_preprocessed, class_id) + + print('Checking shape of crops, crops_u, w:', crops.shape, crops_u.shape, w.shape) + assert(crops.shape[1] == crops.shape[2] == patch_size) # Check patch sizes + assert(crops.shape[0] == crops_u.shape[0]) # Check numbers of patches + assert(crops_u.shape[1] == w.shape[0]) + + # Importance estimation + importances = craft.estimate_importance(nb_most_important_concepts=5) + print('Checking shape of importances:', importances.shape) + assert(len(importances) == number_of_concepts) + + images_u = craft.transform(images_preprocessed) + print('Checking shape of images_u:', images_u.shape) + assert(images_u.shape[1:] == (input_shape[0]-1, input_shape[1]-1, number_of_concepts)) + + +def test_shape2(): + """Ensure the output shape is correct when activation.shape == 2""" + + input_shapes = [(32, 32, 3), (32, 32, 1), (64, 64, 3)] + nb_labels = 3 + nb_samples = 100 + + for input_shape in input_shapes: + # Generate a fake dataset + x, y = generate_data(input_shape, nb_labels, nb_samples) + model = generate_model(input_shape, nb_labels) + model.compile(loss='categorical_crossentropy', optimizer=tf.keras.optimizers.Adam(0.1)) + + # cut the model in two parts (as explained in the paper) + # first part is g(.) our 'input_to_latent' model, second part is h(.) 
our 'latent_to_logit' model + g = tf.keras.Model(model.input, model.layers[-1].output) + h = tf.keras.Model(model.layers[-1].input, model.layers[-1].output) + + # The activations must be positives + assert np.all(g(x) >= 0.0) + + # Initialize Craft + number_of_concepts = 10 + patch_size = 15 + craft = Craft(input_to_latent = g, + latent_to_logit = h, + number_of_concepts = number_of_concepts, + patch_size = patch_size, + batch_size = 64) + + # Now we can fit the concept using our images + # Focus on class id 0 + class_id = 0 + images_preprocessed = x[y.argmax(1)==class_id] # select only images of class 'class_id' + crops, crops_u, w = craft.fit(images_preprocessed, class_id) + + print('Checking shape of crops, crops_u, w:', crops.shape, crops_u.shape, w.shape) + assert(crops.shape[1] == crops.shape[2] == patch_size) # Check patch sizes + assert(crops.shape[0] == crops_u.shape[0]) # Check numbers of patches + assert(crops_u.shape[1] == w.shape[0] == number_of_concepts) + + # Importance estimation + importances = craft.estimate_importance(nb_most_important_concepts=5) + print('Checking shape of importances:', importances.shape) + assert(len(importances) == number_of_concepts) + + images_u = craft.transform(images_preprocessed) + print('Checking shape of images_u:', images_u.shape) + assert(input_shape[1], number_of_concepts) + +def test_wrong_layers(): + """Ensure that Craft complains when the input models are incompatible""" + + input_shapes = [(32, 32, 3)] + nb_labels = 3 + + for input_shape in input_shapes: + # Generate a fake dataset + model = generate_model(input_shape, nb_labels) + model.compile(loss='categorical_crossentropy', optimizer=tf.keras.optimizers.Adam(0.1)) + + g = tf.keras.Model(model.input, model.layers[0].output) + h = lambda x: 2*x + + # Initialize Craft + number_of_concepts = 10 + patch_size = 15 + with pytest.raises(TypeError): + craft = Craft(input_to_latent = g, + latent_to_logit = h, + number_of_concepts = number_of_concepts, + patch_size = patch_size, + batch_size = 64) + +def test_classifier(): + """ Check the Craft results on a small fake dataset """ + + input_shape = (64, 64, 3) + nb_labels = 3 + nb_samples = 200 + + # Create a dataset of 'ABC', 'BCD', 'CDE' images + x, y, nb_samples, labels_str = generate_txt_images_data(input_shape, nb_labels, nb_samples) + print(f'Dataset: {nb_samples} samples generated, {nb_labels} classes: {labels_str}') + + # train a small classifier on the dataset + def generate_model(input_shape=(64, 64, 3), output_shape=10): + model = Sequential() + model.add(Input(shape=input_shape)) + model.add(Conv2D(6, kernel_size=(2, 2))) + model.add(Activation('relu')) + model.add(Conv2D(6, kernel_size=(2, 2))) + model.add(Activation('relu')) + model.add(Conv2D(6, kernel_size=(2, 2))) + model.add(Activation('relu')) + model.add(Flatten()) + model.add(Dense(output_shape)) + model.add(Activation('softmax')) + model.compile(loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy']) + + return model + + model = generate_model(input_shape, nb_labels) + model.fit(x, y, batch_size=32, epochs=70) + print('acc:', np.sum(np.argmax(model(x), axis=1) == np.argmax(y, axis=1)) / nb_samples) + + # cut the model in two parts (as explained in the paper) + # first part is g(.) our 'input_to_latent' model, second part is h(.) 
our 'latent_to_logit' model + g = tf.keras.Model(model.input, model.layers[-4].output) + h = tf.keras.Model(model.layers[-3].input, model.layers[-1].output) + print(model.layers[-4]) + print(model.layers[-3]) + + assert np.all(g(x) >= 0.0) + + # Init Craft on the full dataset + craft = Craft(input_to_latent = g, + latent_to_logit = h, + number_of_concepts = 3, + patch_size = 12, + batch_size = 32) + + # Expected best crop for class 0 (ABC) is AB + AB_str = """ + 0 0 0 0 0 0 0 0 0 0 0 0 + 0 0 0 0 0 0 0 0 0 0 0 0 + 0 0 0 0 1 1 1 1 1 1 1 1 + 0 0 0 0 0 0 1 0 0 0 1 1 + 1 0 0 0 0 0 1 0 0 0 0 1 + 1 0 0 0 0 0 1 0 0 0 1 1 + 1 0 0 0 0 0 1 1 1 1 1 1 + 1 1 0 0 0 0 1 0 0 0 0 1 + 0 1 0 0 0 0 1 0 0 0 0 0 + 0 1 1 0 0 0 1 0 0 0 0 1 + 1 1 1 1 1 1 1 1 1 1 1 1 + 0 0 0 0 0 0 0 0 0 0 0 0 + """ + AB = np.genfromtxt(AB_str.splitlines()) + + # Expected best crop for class 1 (BCD) is BC + BC_str = """ + 0 0 0 0 0 0 0 0 0 0 0 0 + 0 0 0 0 0 0 0 0 0 0 0 0 + 0 0 0 0 0 0 0 0 0 0 0 0 + 1 1 1 1 1 1 0 0 0 0 1 1 + 1 0 0 0 1 1 0 0 0 1 1 0 + 1 0 0 0 0 1 0 0 0 1 0 0 + 1 0 0 0 1 1 0 0 0 1 0 0 + 1 1 1 1 1 1 0 0 0 1 0 0 + 1 0 0 0 0 1 1 0 0 1 0 0 + 1 0 0 0 0 0 1 0 0 1 0 0 + 1 0 0 0 0 1 1 0 0 1 1 0 + 1 1 1 1 1 1 0 0 0 0 1 1 + """ + BC = np.genfromtxt(BC_str.splitlines()) + + # Expected best crop for class 2 (CDE) is DE + DE_str = """ + 0 0 0 0 0 0 0 0 0 0 0 0 + 1 0 0 0 1 1 1 1 1 1 1 1 + 1 1 0 0 0 0 1 0 0 0 0 1 + 0 1 0 0 0 0 1 0 0 0 0 1 + 0 1 1 0 0 0 1 0 0 1 0 0 + 0 1 1 0 0 0 1 1 1 1 0 0 + 0 1 1 0 0 0 1 0 0 1 0 0 + 0 1 0 0 0 0 1 0 0 0 0 1 + 1 1 0 0 0 0 1 0 0 0 0 1 + 1 0 0 0 1 1 1 1 1 1 1 1 + 0 0 0 0 0 0 0 0 0 0 0 0 + 0 0 0 0 0 0 0 0 0 0 0 0 + """ + DE = np.genfromtxt(DE_str.splitlines()) + + expected_best_crops = [AB, BC, DE] + + # Run 3 Craft studies on each class, and in each case check if the best crop is the expected one + for class_id in range(3): + # Focus on class class_id + print(f'Selecting subset for class {class_id} : {labels_str[class_id]}') + x_subset = x[np.argmax(y, axis=1)==class_id,:,:,:] + + print('fit craft on the selected class') + crops, crops_u, w = craft.fit(x_subset, class_id) + + print('compute importances') + importances = craft.estimate_importance(nb_most_important_concepts=3) + assert(importances[0] > 0.7) + + print('find the best crop and compare it to the expected best crop') + most_important_concepts = np.argsort(importances)[::-1] + + # Find the best crop for the most important concept + c_id = most_important_concepts[0] + best_crops_ids = np.argsort(crops_u[:, c_id])[::-1] + best_crop = np.array(crops)[best_crops_ids[0]] + + # Compare this best crop to the expectation + bin_best_crop = np.where(best_crop.sum(axis=2) > 0.25, 1, 0) + assert(np.all(bin_best_crop == expected_best_crops[class_id])) diff --git a/tests/concepts/test_craft_torch.py b/tests/concepts/test_craft_torch.py new file mode 100644 index 00000000..cfc5d7eb --- /dev/null +++ b/tests/concepts/test_craft_torch.py @@ -0,0 +1,325 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import pytest +from math import ceil + +from xplique.concepts import CraftTorch as Craft +from ..utils import generate_txt_images_data + +def generate_torch_data(x_shape=(3, 32, 32), num_labels=10, samples=100): + x = torch.tensor(np.random.rand(samples, *x_shape).astype(np.float32)) + y = F.one_hot(torch.tensor(np.random.randint(0, num_labels, samples)), num_labels) + + return x, y + +def generate_torch_model(input_shape=(3, 32, 32, 3), output_shape=10): + c_in = input_shape[0] + h_in = input_shape[1] + w_in = 
input_shape[2] + + model = nn.Sequential() + + model.append(nn.Conv2d(c_in, 4, (2, 2))) + h_out = h_in - 1 + w_out = w_in -1 + c_out = 4 + + model.append(nn.ReLU()) + model.append(nn.MaxPool2d((2, 2))) + h_out = int((h_out - 2)/2 + 1) + w_out = int((w_out - 2)/2 + 1) + + model.append(nn.Dropout(0.25)) + model.append(nn.Flatten()) + flatten_size = c_out * h_out * w_out + + model.append(nn.Linear(int(flatten_size), output_shape)) + + return model + +def test_shape(): + """Ensure the output shape is correct""" + + input_shapes = [(3, 32, 32), (1, 32, 32), (3, 64, 64)] + nb_labels = 3 + nb_samples = 100 + + for input_shape in input_shapes: + # Generate a fake dataset + x, y = generate_torch_data(input_shape, nb_labels, nb_samples) + model = generate_torch_model(input_shape, nb_labels) + + # cut the model in two parts (as explained in the paper) + # first part is g(.) our 'input_to_latent' model, second part is h(.) our 'latent_to_logit' model + g = nn.Sequential(*list(model.children())[:2]) + h = nn.Sequential(*list(model.children())[2:]) + + # The activations must be positives + assert torch.all(g(x) >= 0.0) + + # Initialize Craft + number_of_concepts = 10 + patch_size = 15 + + craft = Craft(input_to_latent = g, + latent_to_logit = h, + number_of_concepts = number_of_concepts, + patch_size = patch_size, + batch_size = 64, + device = 'cpu') + + # Now we can fit the concept using our images + # Focus on class id 0 + class_id = 0 + images_preprocessed = x[y.argmax(1)==class_id] # select only images of class 'class_id' + print('images_preprocessed.shape:', images_preprocessed.shape) + crops, crops_u, w = craft.fit(images_preprocessed, class_id) + + print('Checking shape of crops, crops_u, w:', crops.shape, crops_u.shape, w.shape) + assert(crops.shape[2] == crops.shape[3] == patch_size) # Check patch sizes + assert(crops.shape[0] == crops_u.shape[0]) # Check numbers of patches + assert(crops_u.shape[1] == w.shape[0]) + + # Importance estimation + importances = craft.estimate_importance(nb_most_important_concepts=5) + print('Checking shape of importances:', importances.shape) + assert(len(importances) == number_of_concepts) + + images_u = craft.transform(images_preprocessed) + print('Checking shape of images_u:', images_u.shape) + assert(images_u.shape[1:] == (input_shape[1]-1, input_shape[2]-1, number_of_concepts)) + assert(images_u.shape[-1] == number_of_concepts) + +def test_shape2(): + """Ensure the output shape is correct when activation.shape == 2""" + + input_shapes = [(3, 32, 32), (1, 32, 32), (3, 64, 64)] + nb_labels = 3 + nb_samples = 100 + + for input_shape in input_shapes: + # Generate a fake dataset + x, y = generate_torch_data(input_shape, nb_labels, nb_samples) + model = generate_torch_model(input_shape, nb_labels) + + # cut the model in two parts (as explained in the paper) + # first part is g(.) our 'input_to_latent' model, second part is h(.) 
our 'latent_to_logit' model + g = nn.Sequential(*list(model.children())[:-1]) + h = nn.Sequential(*list(model.children())[-1:]) + + # The activations must be positives + assert torch.all(g(x) >= 0.0) + + # Initialize Craft + number_of_concepts = 10 + patch_size = 15 + + craft = Craft(input_to_latent = g, + latent_to_logit = h, + number_of_concepts = number_of_concepts, + patch_size = patch_size, + batch_size = 64, + device = 'cpu') + + # Now we can fit the concept using our images + # Focus on class id 0 + class_id = 0 + images_preprocessed = x[y.argmax(1)==class_id] # select only images of class 'class_id' + print('images_preprocessed.shape:', images_preprocessed.shape) + crops, crops_u, w = craft.fit(images_preprocessed, class_id) + + print('Checking shape of crops, crops_u, w:', crops.shape, crops_u.shape, w.shape) + assert(crops.shape[2] == crops.shape[3] == patch_size) # Check patch sizes + assert(crops.shape[0] == crops_u.shape[0]) # Check numbers of patches + assert(crops_u.shape[1] == w.shape[0]) + + # Importance estimation + importances = craft.estimate_importance(nb_most_important_concepts=5) + print('Checking shape of importances:', importances.shape) + assert(len(importances) == number_of_concepts) + + images_u = craft.transform(images_preprocessed) + print('Checking shape of images_u:', images_u.shape) + assert(images_u.shape[-1] == number_of_concepts) + +def test_wrong_layers(): + """Ensure that Craft complains when the input models are incompatible""" + + input_shapes = [(3, 32, 32)] + nb_labels = 3 + + for input_shape in input_shapes: + + model = generate_torch_model(input_shape, nb_labels) + + # cut the model in two parts (as explained in the paper) + # first part is g(.) our 'input_to_latent' model, second part is h(.) our 'latent_to_logit' model + g = nn.Sequential(*list(model.children())[:3]) + h = lambda x: 2*x + + # Initialize Craft + number_of_concepts = 10 + patch_size = 15 + with pytest.raises(TypeError): + craft = Craft(input_to_latent = g, + latent_to_logit = h, + number_of_concepts = number_of_concepts, + patch_size = patch_size, + batch_size = 64) + +def test_classifier(): + """ Check the Craft results on a small fake dataset """ + + input_shape = (64, 64, 3) + nb_labels = 3 + nb_samples = 200 + + # Create a dataset of 'ABC', 'BCD', 'CDE' images + x, y, nb_samples, labels_str = generate_txt_images_data(input_shape, nb_labels, nb_samples) + x = np.moveaxis(x, -1, 1) # reorder the axis to match torch format + x, y = torch.Tensor(x), torch.Tensor(y) + print(f'Dataset: {nb_samples} samples generated, {nb_labels} classes: {labels_str}') + + # train a small classifier on the dataset + def generate_torch_model(input_shape=(3, 64, 64), output_shape=10): + flatten_size = 6*(input_shape[1]-3)*(input_shape[2]-3) + model = nn.Sequential( + nn.Conv2d(3, 6, kernel_size=(2, 2)), + nn.ReLU(), + nn.Conv2d(6, 6, kernel_size=(2, 2)), + nn.ReLU(), + nn.Conv2d(6, 6, kernel_size=(2, 2)), + nn.ReLU(), + nn.Flatten(1, -1), + nn.Linear(flatten_size, output_shape)) + return model + model = generate_torch_model((input_shape[-1], *input_shape[0:2]), nb_labels) + + def train_model(model, x, y): + model.train() + optimizer = torch.optim.Adam(model.parameters(), lr=0.005) + loss_func = torch.nn.CrossEntropyLoss() + + batch_size = 64 + nb_batchs = ceil(len(x) / batch_size) + print('nb_batchs:', nb_batchs) + start_ids = [i*batch_size for i in range(nb_batchs)] + + for epoch in range(50): + print(f'=== epoch {epoch} ===') + for i in start_ids: + x_batch = x[i:i+batch_size] + y_batch = 
y[i:i+batch_size] + y_pred = model(torch.Tensor(x_batch)) + loss = loss_func(y_pred, torch.Tensor(y_batch)) + print(i, loss.item()) + optimizer.zero_grad() + loss.backward() + optimizer.step() + return model + model = train_model(model, x, y) + + # check accuracy + model.eval() + acc = torch.sum(torch.argmax(model(x), axis=1) == torch.argmax(y, axis=1))/len(y) + print('accuracy:', acc) + assert(acc > 0.9) + + # cut pytorch model + g = nn.Sequential(*(list(model.children())[:6])) # input to penultimate layer + h = nn.Sequential(*(list(model.children())[6:])) # penultimate layer to logits + + assert torch.all(g(x) >= 0.0) + + # Init Craft on the full dataset + craft = Craft(input_to_latent = g, + latent_to_logit = h, + number_of_concepts = 3, + patch_size = 12, + batch_size = 32, + device='cpu') + + # Expected best crop for class 0 (ABC) is AB + AB_str = """ + 0 0 0 0 0 0 0 0 0 0 0 0 + 0 0 0 0 0 0 0 0 0 0 0 0 + 0 0 0 0 1 1 1 1 1 1 1 1 + 0 0 0 0 0 0 1 0 0 0 1 1 + 1 0 0 0 0 0 1 0 0 0 0 1 + 1 0 0 0 0 0 1 0 0 0 1 1 + 1 0 0 0 0 0 1 1 1 1 1 1 + 1 1 0 0 0 0 1 0 0 0 0 1 + 0 1 0 0 0 0 1 0 0 0 0 0 + 0 1 1 0 0 0 1 0 0 0 0 1 + 1 1 1 1 1 1 1 1 1 1 1 1 + 0 0 0 0 0 0 0 0 0 0 0 0 + """ + AB = np.genfromtxt(AB_str.splitlines()) + + # Expected best crop for class 1 (BCD) is BC + BC_str = """ + 0 0 0 0 0 0 0 0 0 0 0 0 + 0 0 0 0 0 0 0 0 0 0 0 0 + 0 0 0 0 0 0 0 0 0 0 0 0 + 1 1 1 1 1 1 0 0 0 0 1 1 + 1 0 0 0 1 1 0 0 0 1 1 0 + 1 0 0 0 0 1 0 0 0 1 0 0 + 1 0 0 0 1 1 0 0 0 1 0 0 + 1 1 1 1 1 1 0 0 0 1 0 0 + 1 0 0 0 0 1 1 0 0 1 0 0 + 1 0 0 0 0 0 1 0 0 1 0 0 + 1 0 0 0 0 1 1 0 0 1 1 0 + 1 1 1 1 1 1 0 0 0 0 1 1 + """ + BC = np.genfromtxt(BC_str.splitlines()) + + # Expected best crop for class 2 (CDE) is DE + DE_str = """ + 0 0 0 0 0 0 0 0 0 0 0 0 + 1 0 0 0 1 1 1 1 1 1 1 1 + 1 1 0 0 0 0 1 0 0 0 0 1 + 0 1 0 0 0 0 1 0 0 0 0 1 + 0 1 1 0 0 0 1 0 0 1 0 0 + 0 1 1 0 0 0 1 1 1 1 0 0 + 0 1 1 0 0 0 1 0 0 1 0 0 + 0 1 0 0 0 0 1 0 0 0 0 1 + 1 1 0 0 0 0 1 0 0 0 0 1 + 1 0 0 0 1 1 1 1 1 1 1 1 + 0 0 0 0 0 0 0 0 0 0 0 0 + 0 0 0 0 0 0 0 0 0 0 0 0 + """ + DE = np.genfromtxt(DE_str.splitlines()) + + expected_best_crops = [AB, BC, DE] + + # Run 3 Craft studies on each class, and in each case check if the best crop is the expected one + for class_id in range(3): + # Focus on class class_id + print(f'Selecting subset for class {class_id} : {labels_str[class_id]}') + x_subset = x[np.argmax(y, axis=1)==class_id,:,:,:] + + print('fit craft on the selected class') + crops, crops_u, w = craft.fit(x_subset, class_id) + + print('compute importances') + importances = craft.estimate_importance(nb_most_important_concepts=3) + print('importances:', importances) + assert(importances[0] > 0.7) + + print('find the best crop and compare it to the expected best crop') + most_important_concepts = np.argsort(importances)[::-1] + + # Find the best crop for the most important concept + c_id = most_important_concepts[0] + best_crops_ids = np.argsort(crops_u[:, c_id])[::-1] + best_crop = np.array(crops)[best_crops_ids[0]] + best_crop = np.moveaxis(best_crop, 0, -1) + + # Compare this best crop to the expectation + bin_best_crop = np.where(best_crop.sum(axis=2) > 0.25, 1, 0) + print('best crop:', bin_best_crop) + check = np.all(bin_best_crop == expected_best_crops[class_id]) + print('check class ', class_id, ' : comparison result:', check) + assert(check == True) diff --git a/tests/utils.py b/tests/utils.py index f9457372..b6356b4c 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -4,6 +4,7 @@ from tensorflow.keras.models import Sequential, Model from 
tensorflow.keras.layers import Dense, Conv1D, Conv2D, Activation, GlobalAveragePooling1D, Dropout, Flatten, MaxPooling2D, Input from tensorflow.keras.utils import to_categorical +from PIL import Image, ImageDraw, ImageFont def generate_data(x_shape=(32, 32, 3), num_labels=10, samples=100): x = np.random.rand(samples, *x_shape).astype(np.float32) @@ -114,3 +115,54 @@ def __call__(self, inputs): return tf_model +def generate_txt_images_data(x_shape=(32, 32, 3), num_labels=10, samples=100): + """ + Generate an image dataset composed of white texts over black background. + The texts are words of 3 successive letters, the number of classes is set by the + parameter num_labels. The location of the text in the image is cycling over the + image dimensions. + Ex: with num_labels=3, the 3 classes will be 'ABC', 'BCD' and 'CDE'. + + """ + all_labels_str = "".join([chr(lab_idx) for lab_idx in range(65, 65+num_labels+2)]) # ABCDEF + labels_str = [all_labels_str[i:i+3] for i in range(len(all_labels_str) - 2)] # ['ABC', 'BCD', 'CDE', 'DEF'] + + def create_image_from_txt(image_shape, txt, offset_x, offset_y): + # Get a Pillow font (OS independant) + fnt = ImageFont.truetype("Pillow/Tests/fonts/FreeMono.ttf", 16) + + # Make a black image and draw the input text in white at the location offset_x, offset_y + rgb = (len(image_shape) == 3 and image_shape[2] > 1) + if rgb: + image = Image.new("RGB", (image_shape[0], image_shape[1]), (0, 0, 0)) + else: + # grayscale + image = Image.new("L", (image_shape[0], image_shape[1]), 0) + d = ImageDraw.Draw(image) + d.text((offset_x, offset_y), txt, font=fnt, fill='white') + return image + + x = np.empty((samples, *x_shape)).astype(np.float32) + y = np.empty(samples) + + # Iterate over the samples and generate images of labels shifted by increasing offsets + offset_x_max = x_shape[0] - 25 + offset_y_max = x_shape[1] - 10 + + current_label_id = 0 + offset_x = offset_y = 0 + for i in range(samples): + image = create_image_from_txt(x_shape, txt=labels_str[current_label_id], offset_x=offset_x, offset_y=offset_y) + image = np.reshape(image, x_shape) + x[i] = np.array(image).astype(np.float32)/255.0 + y[i] = current_label_id + + # cycle labels + current_label_id = (current_label_id + 1) % num_labels + offset_x = (offset_x + 1) % offset_x_max + offset_y = ((i+2) % offset_y_max) + if offset_y > offset_y_max: + break + x = x[0:i] + y = y[0:i] + return x, to_categorical(y, num_labels), i, labels_str \ No newline at end of file diff --git a/xplique/concepts/__init__.py b/xplique/concepts/__init__.py index af5c4467..3f7b4d07 100644 --- a/xplique/concepts/__init__.py +++ b/xplique/concepts/__init__.py @@ -1,6 +1,12 @@ -""" -Concept based methods -""" - -from .cav import Cav -from .tcav import Tcav +""" +Concept based methods +""" + +from .cav import Cav +from .tcav import Tcav +from .craft import BaseCraft +from .craft_tf import CraftTf +try: + from .craft_torch import CraftTorch +except ImportError: + pass diff --git a/xplique/concepts/craft.py b/xplique/concepts/craft.py new file mode 100644 index 00000000..3b447f5d --- /dev/null +++ b/xplique/concepts/craft.py @@ -0,0 +1,475 @@ +""" +CRAFT Module for Tensorflow/Pytorch +""" + +from abc import abstractmethod +import numpy as np +from sklearn.decomposition import NMF +import cv2 +import colorsys +from math import ceil +import matplotlib.pyplot as plt +from matplotlib.colors import ListedColormap, LinearSegmentedColormap +import matplotlib.gridspec as gridspec +from mpl_toolkits.axes_grid1.inset_locator import inset_axes + +from 
xplique.attributions.global_sensitivity_analysis import (HaltonSequenceRS, JansenEstimator) + +from ..types import Callable, Tuple, Optional +from .base import BaseConceptExtractor + +class BaseCraft(BaseConceptExtractor): + """ + Base class implementing the CRAFT Concept Extraction Mechanism. + + Ref. Fel et al., CRAFT Concept Recursive Activation FacTorization (2023). + https://arxiv.org/abs/2211.10154 + + It shall be subclassed in order to adapt to a specific framework + (Tensorflow or Pytorch) + + Parameters + ---------- + input_to_latent + The first part of the model taking an input and returning + positive activations, g(.) in the original paper. + latent_to_logit + The second part of the model taking activation and returning + logits, h(.) in the original paper. + number_of_concepts + The number of concepts to extract. Default is 20. + batch_size + The batch size to use during training and prediction. Default is 64. + patch_size + The size of the patches to extract from the input data. Default is 64. + """ + + def __init__(self, input_to_latent : Callable, + latent_to_logit : Callable, + number_of_concepts: int = 20, + batch_size: int = 64, + patch_size: int = 64): + super().__init__(number_of_concepts, batch_size) + self.input_to_latent = input_to_latent + self.latent_to_logit = latent_to_logit + self.patch_size = patch_size + self.cmaps = None + + # sanity checks + assert(callable(input_to_latent)), "input_to_latent must be a callable function" + assert(callable(latent_to_logit)), "latent_to_logit must be a callable function" + + @abstractmethod + def _latent_predict(self, inputs: np.ndarray): + raise NotImplementedError + + @abstractmethod + def _logit_predict(self, inputs: np.ndarray): + raise NotImplementedError + + @abstractmethod + def _extract_patches(self, inputs: np.ndarray): + raise NotImplementedError + + @abstractmethod + def _to_np_array(self, input, dtype) -> np.ndarray: + raise NotImplementedError + + def fit(self, inputs : np.ndarray, class_id: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + Fit the Craft model to the input data. + + Parameters + ---------- + inputs + Input data of shape (n_samples, height, width, channels). + (x1, x2, ..., xn) in the paper. + class_id + The class id of the inputs. + + Returns + ------- + (X, U, W) + A tuple containing the crops (X in the paper), + the concepts values (U) and the concepts basis (W). + """ + # extract patches from the input data + self.inputs = inputs + self.class_id = class_id + self.crops, activations = self._extract_patches(self.inputs) + + # apply NMF to the activations to obtain matrices U and W + reducer = NMF(n_components=self.number_of_concepts, alpha_W=1e-2) + self.crops_u = reducer.fit_transform(activations) + self.W = reducer.components_.astype(np.float32) + + # store the factorizer as attribute of the Craft instance + self.reducer = reducer + + return self.crops, self.crops_u, self.W + + def transform(self, inputs : np.ndarray, activations : np.ndarray = None): + """Transforms the inputs data into its concept representation. + + Parameters + ---------- + inputs : numpy array or Tensor + The input data to be transformed. + activations + Pre-computed activations of the input data. If not provided, the activations + will be computed using the input_to_latent model on the inputs. + + Returns + ------- + U + The concept value (U) of the inputs. 
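+            For 4D activations, U has shape (n_samples, height, width, number_of_concepts);
+            for 2D activations, it has shape (n_samples, number_of_concepts).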
+        """
+        self.check_if_fitted()
+
+        if activations is None:
+            activations = self._latent_predict(inputs)
+
+        is_4d = len(activations.shape) == 4
+
+        if is_4d:
+            # (N, W, H, C) -> (N * W * H, C)
+            original_shape = activations.shape[:-1]
+            activations = np.reshape(activations, (-1, activations.shape[-1]))
+
+        W_dtype = self.reducer.components_.dtype
+        U = self.reducer.transform(self._to_np_array(activations, dtype=W_dtype))
+
+        if is_4d:
+            # (N * W * H, R) -> (N, W, H, R) with R = nb_concepts
+            U = np.reshape(U, (*original_shape, U.shape[-1]))
+        return U
+
+    def estimate_importance(self, nb_design: int = 32, nb_most_important_concepts: int = 6):
+        """
+        Estimates the importance of each concept for a given class.
+
+        Parameters
+        ----------
+        nb_design
+            The number of designs to use for the importance estimation. Default is 32.
+        nb_most_important_concepts
+            The number of concepts to focus on. Default is 6.
+
+        Returns
+        -------
+        importances
+            The Sobol total index (importance score) for each concept.
+
+        """
+        self.check_if_fitted()
+
+        if nb_most_important_concepts > self.number_of_concepts:
+            raise RuntimeError('nb_most_important_concepts must be <= number_of_concepts')
+
+        U = self.transform(self.inputs)
+        self.images_u = U
+
+        masks = HaltonSequenceRS()(self.number_of_concepts, nb_design = nb_design)
+        estimator = JansenEstimator()
+
+        importances = []
+
+        if len(U.shape) == 2:
+            # apply the original method of the paper
+
+            for u in U:
+                u_perturbated = u[None, :] * masks
+                a_perturbated = u_perturbated @ self.W
+
+                y_pred = self._logit_predict(a_perturbated)
+                y_pred = y_pred[:, self.class_id]
+
+                stis = estimator(masks, y_pred, nb_design)
+
+                importances.append(stis)
+
+        elif len(U.shape) == 4:
+            # apply a re-parameterization trick and mask all spatial locations of a
+            # given concept id at once to estimate the Sobol indices
+            for u in U:
+                u_perturbated = u[None, :] * masks[:, None, None, :]
+                a_perturbated = np.reshape(u_perturbated, (-1, u.shape[-1])) @ self.W
+                a_perturbated = np.reshape(a_perturbated, (len(masks), U.shape[1], U.shape[2], -1))
+
+                # a_perturbated: (N, H, W, C)
+                y_pred = self._logit_predict(a_perturbated)
+                y_pred = y_pred[:, self.class_id]
+
+                stis = estimator(masks, y_pred, nb_design)
+
+                importances.append(stis)
+
+        self.importances = np.mean(importances, 0)
+        self.most_important_concepts = np.argsort(self.importances)[::-1][:nb_most_important_concepts]
+
+        return self.importances
+
+    def plot_concepts_importances(self, only_most_important: bool=False, quiet: bool=False):
+        """
+        Plot a bar chart displaying the importance value of each concept.
+
+        Parameters
+        ----------
+        only_most_important
+            Flag to reduce the display to only the most important concepts (their number
+            is set through the estimate_importance() method).
+        quiet
+            If True, no textual output is printed; otherwise the importance value
+            of each concept is printed.
+ """ + if not only_most_important: + plt.bar(range(len(self.importances)), self.importances) + plt.xticks(range(len(self.importances))) + plt.title("Concept Importance") + else: + if self.cmaps == None: + self.set_concept_attribution_cmap() + importances = self.importances[self.most_important_concepts] + colors = [c(1.0) for c in self.cmaps] + plt.bar(range(len(importances)), importances, color=colors) + plt.xticks(range(len(importances))) + plt.title("Concept Importance") + + if not quiet: + for c_id in self.most_important_concepts: + print("Concept", c_id, " has an importance value of ", self.importances[c_id]) + + @staticmethod + def _show(img, **kwargs): + img = np.array(img) + if img.shape[0] == 3: + img = img.transpose(1, 2, 0) + + img -= img.min() + if img.max() > 0: + img /= img.max() + plt.imshow(img, **kwargs); plt.axis('off') + + def plot_concepts_crops(self, nb_crops: int = 10) -> None: + """ + Display the crops for each concept. + + Parameters + ---------- + nb_crops + The number of crops to display per concept. Defaults to 10. + """ + for c_id in self.most_important_concepts: + best_crops_ids = np.argsort(self.crops_u[:, c_id])[::-1][:nb_crops] + best_crops = np.array(self.crops)[best_crops_ids] + + print("Concept", c_id, " has an importance value of ", self.importances[c_id]) + plt.figure(figsize=(7, (2.5/2)*ceil(nb_crops/5))) + for i in range(nb_crops): + plt.subplot(ceil(nb_crops/5), 5, i+1) + BaseCraft._show(best_crops[i]) + plt.show() + + @staticmethod + def _get_alpha_cmap(cmap): + if isinstance(cmap, str): + cmap = plt.get_cmap(cmap) + else: + c = np.array(cmap) + if np.any(c > 1.0): + c = c / 255.0 + + cmax = colorsys.rgb_to_hls(*c) + cmax = np.array(cmax) + cmax[-1] = 1.0 + + cmax = np.clip(np.array(colorsys.hls_to_rgb(*cmax)), 0, 1) + cmap = LinearSegmentedColormap.from_list("", [c,cmax]) + + alpha_cmap = cmap(np.arange(256)) + alpha_cmap[:,-1] = np.linspace(0, 0.85, 256) + alpha_cmap = ListedColormap(alpha_cmap) + + return alpha_cmap + + def set_concept_attribution_cmap(self, cmaps: list=None): + """ + Set the colormap used for the concepts displayed in the attribution maps. + + Parameters + ---------- + cmaps + A list of (r, g, b) colors. + Example: plt.get_cmap('tab10').colors + """ + if cmaps == None: + self.cmaps = [ + BaseCraft._get_alpha_cmap((54, 197, 240)), + BaseCraft._get_alpha_cmap((210, 40, 95)), + BaseCraft._get_alpha_cmap((236, 178, 46)), + BaseCraft._get_alpha_cmap((15, 157, 88)), + BaseCraft._get_alpha_cmap((84, 25, 85)), + BaseCraft._get_alpha_cmap((55, 35, 235)) + ] + else: + self.cmaps = [BaseCraft._get_alpha_cmap(cmap) for cmap in cmaps] + + if len(self.cmaps) < len(self.most_important_concepts): + raise RuntimeError(f'Not enough colors in cmaps ({len(self.cmaps)}) ' \ + f'compared to the number of important concepts ({len(self.most_important_concepts)})') + + def plot_concept_attribution_legend(self, p: int=5): + """ + Plot a legend for the concepts attribution maps. + + Parameters + ---------- + p + Width of the border around each concept image, in pixels. Defaults to 5. 
+ """ + if self.cmaps == None: + self.set_concept_attribution_cmap() + + for i, c_id in enumerate(self.most_important_concepts): + cmap = self.cmaps[i] + plt.subplot(1, len(self.most_important_concepts), i+1) + + best_crops_id = np.argsort(self.crops_u[:, c_id])[::-1][0] + best_crop = self.crops[best_crops_id] + + if best_crop.shape[0] > best_crop.shape[-1]: + mask = np.zeros(best_crop.shape[:-1]) # tf + else: + mask = np.zeros(best_crop.shape[1:]) # torch + mask[:p, :] = 1.0 + mask[:, :p] = 1.0 + mask[-p:, :] = 1.0 + mask[:, -p:] = 1.0 + + BaseCraft._show(best_crop) + BaseCraft._show(mask, cmap=cmap) + plt.title(f"{c_id}", color=cmap(1.0)) + + plt.show() + + def plot_concept_attribution_maps(self, + id: int, + percentile: int = 95): + """ + Display the concepts attribution maps for the image `id` given in argument. + + Parameters + ---------- + id + The id of the image to display. + percentile + Percentile used to filter the concept heatmap + (only show concept if excess N-th percentile). Defaults to 95. + """ + if self.cmaps == None: + self.set_concept_attribution_cmap() + + # img = images_preprocessed[id] + # u = images_u[id] + img = self.inputs[id] + u = self.images_u[id] + + BaseCraft._show(img) + + for i, c_id in enumerate(self.most_important_concepts): + + cmap = self.cmaps[i] + heatmap = u[:, :, c_id] + + # only show concept if excess N-th percentile + sigma = np.percentile(np.array(heatmap).flatten(), percentile) + heatmap = heatmap * np.array(heatmap > sigma, np.float32) + + heatmap = cv2.resize(heatmap[:, :, None], dsize=(224, 224), interpolation=cv2.INTER_CUBIC) + BaseCraft._show(heatmap, cmap=cmap, alpha=0.7) + + def plot_image_concepts(self, + id: int, + percentile: int = 95, + filepath: Optional[str] = None): + """ + All in one method displaying several plots for the image `id` given in argument: + - the concepts attribution map for this image + - the best crops for each concept (displayed around the heatmap) + - the importance of each concept + + Parameters + ---------- + id + The id of the image to display. + percentile + Percentile used to filter the concept heatmap + (only show concept if excess N-th percentile). Defaults to 95. + filepath + Path the file will be saved at. If None, the function will call plt.show(). 
+ """ + if self.cmaps == None: + self.set_concept_attribution_cmap() + + fig = plt.figure(figsize=(18, 7)) + + # create the main gridspec which is split in the left and right parts storing + # the crops, and the central part to display the heatmap + nb_rows = ceil(len(self.most_important_concepts) / 2.0) + nb_cols = 4 + gs_main = fig.add_gridspec(nb_rows, nb_cols, hspace=0.4, width_ratios=[0.2, 0.4, 0.2, 0.4]) + + # Central image + # + ax = fig.add_subplot(gs_main[:, 1]) + self.plot_concept_attribution_maps(id, percentile) + + # Concepts: creation of the axes on left and right of the image for the concepts + # + gs_concepts_axes = [gridspec.GridSpecFromSubplotSpec(1, 1, subplot_spec=gs_main[i, 0]) for i in range(nb_rows)] + gs_right = [gridspec.GridSpecFromSubplotSpec(1, 1, subplot_spec=gs_main[i, 2]) for i in range(nb_rows)] + gs_concepts_axes.extend(gs_right) + + # display the best crops for each concept, in the order of the most important concept + nb_crops = 6 + for i, c_id in enumerate(self.most_important_concepts): + cmap = self.cmaps[i] + + # use a ghost invisible subplot only to have a border around the crops + ghost_axe = fig.add_subplot(gs_concepts_axes[i][:,:]) + ghost_axe.set_title(f"{c_id}", color=cmap(1.0)) + ghost_axe.axis('off') + + inset_axes = ghost_axe.inset_axes([-0.04, -0.04, 1.08, 1.08]) # actually it is an outer border ... + inset_axes.set_xticks([]) + inset_axes.set_yticks([]) + for spine in inset_axes.spines.values(): # border color + spine.set_edgecolor(color=cmap(1.0)) + spine.set_linewidth(3) + + # draw each crop for this concept + gs_current = gridspec.GridSpecFromSubplotSpec(2, 3, subplot_spec=gs_concepts_axes[i][:,:]) + + best_crops_ids = np.argsort(self.crops_u[:, c_id])[::-1][:nb_crops] + best_crops = np.array(self.crops)[best_crops_ids] + for i in range(nb_crops): + ax = plt.Subplot(fig, gs_current[i // 3, i % 3]) + # plt.subplots_adjust(bottom=0.2, right=0.8, top=0.8, left=0.2) + fig.add_subplot(ax) + BaseCraft._show(best_crops[i]) + + # Right plot: importances + # + importance_axe = gridspec.GridSpecFromSubplotSpec(3, 2, width_ratios=[0.1, 0.9], + height_ratios=[0.15, 0.6, 0.15], + subplot_spec=gs_main[:, 3]) + ax = fig.add_subplot(importance_axe[1, 1]) + self.plot_concepts_importances(only_most_important=True, quiet=True) + + + if filepath is not None: + plt.savefig(filepath) + else: + plt.show() \ No newline at end of file diff --git a/xplique/concepts/craft_tf.py b/xplique/concepts/craft_tf.py new file mode 100644 index 00000000..878c22aa --- /dev/null +++ b/xplique/concepts/craft_tf.py @@ -0,0 +1,124 @@ + +""" +CRAFT Module for Tensorflow +""" + +from typing import Callable, Optional, Tuple +import keras +import tensorflow as tf +import numpy as np + +from .craft import BaseCraft + +class CraftTf(BaseCraft): + + """ + Class implementing the CRAFT Concept Extraction Mechanism on Tensorflow. + + Parameters + ---------- + input_to_latent + The first part of the model taking an input and returning + positive activations, g(.) in the original paper. + Must be a Tensorflow model (keras.engine.base_layer.Layer) accepting + data of shape (n_samples, height, width, channels). + latent_to_logit + The second part of the model taking activation and returning + logits, h(.) in the original paper. + Must be a Tensorflow model (keras.engine.base_layer.Layer). + number_of_concepts + The number of concepts to extract. Default is 20. + batch_size + The batch size to use during training and prediction. Default is 64. 
+ patch_size + The size of the patches to extract from the input data. Default is 64. + """ + def __init__(self, input_to_latent : Callable, + latent_to_logit : Callable, + number_of_concepts: int = 20, + batch_size: int = 64, + patch_size: int = 64): + super().__init__(input_to_latent, latent_to_logit, number_of_concepts, batch_size) + self.patch_size = patch_size + + # Check model type + is_tf_model = issubclass(type(input_to_latent), keras.engine.base_layer.Layer) & \ + issubclass(type(latent_to_logit), keras.engine.base_layer.Layer) + if not is_tf_model: + raise TypeError('input_to_latent and latent_to_logit are not Tensorflow models') + + def _latent_predict(self, inputs: tf.Tensor): + """ + Compute the embedding space using the 1st model `input_to_latent`. + + Parameters + ---------- + inputs + Input data of shape (n_samples, height, width, channels). + + Returns + ------- + activations + The latent activations of shape (n_samples, height, width, channels) + """ + return self.input_to_latent.predict(inputs, batch_size=self.batch_size, verbose=False) + + def _logit_predict(self, activations: np.ndarray): + """ + Compute logits from activations using the 2nd model `latent_to_logit`. + + Parameters + ---------- + activations + Activations produced by the 1st model `input_to_latent`, + of shape (n_samples, height, width, channels). + + Returns + ------- + logits + The logits of shape (n_samples, n_classes) + """ + return self.latent_to_logit.predict(activations, batch_size=self.batch_size, verbose=False) + + def _extract_patches(self, inputs: np.ndarray) -> Tuple[tf.Tensor, tf.Tensor]: + """ + Extract patches from the input images, and compute their embeddings. + + Parameters + ---------- + inputs + Input images (n_samples, height, width, channels). + + Returns + ------- + (patches, activations) : + A tuple containing the patches (n_patches, height, width, channels), + and their activations (n_patches, channels) + """ + + strides = int(self.patch_size * 0.80) + patches = tf.image.extract_patches(images=inputs, + sizes=[1, self.patch_size, self.patch_size, 1], + strides=[1, strides, strides, 1], + rates=[1, 1, 1, 1], + padding='VALID') + patches = tf.reshape(patches, (-1, self.patch_size, self.patch_size, inputs.shape[-1])) + + # encode the patches and obtain the activations + input_width, input_height = inputs.shape[1], inputs.shape[2] + activations = self._latent_predict(tf.image.resize(patches, (input_width, input_height), method="bicubic")) + assert np.min(activations) >= 0.0, "Activations must be positive." + + # if the activations have shape (n_samples, height, width, n_channels), + # apply average pooling + if len(activations.shape) == 4: + # activations: (N, H, W, C) + activations = tf.reduce_mean(activations, axis=(1, 2)) + + return patches, activations + + def _to_np_array(self, input: tf.Tensor, dtype: type): + """ + Converts a Tensorflow tensor into a numpy array. 
+ """ + return np.array(input, dtype) \ No newline at end of file diff --git a/xplique/concepts/craft_torch.py b/xplique/concepts/craft_torch.py new file mode 100644 index 00000000..16d187fb --- /dev/null +++ b/xplique/concepts/craft_torch.py @@ -0,0 +1,191 @@ + +""" +CRAFT Module for Pytorch +""" + +from typing import Callable, Optional, Tuple +import torch +import numpy as np +from math import ceil + +from .craft import BaseCraft + +def _batch_inference(model: torch.nn.Module, + dataset: torch.Tensor, + batch_size: int = 128, + resize: Optional[int] = None, + device: Optional[str]='cuda') -> torch.Tensor: + """ + Compute the model predictions of the input images. + + Parameters + ---------- + model + The model to use for inference. + dataset + The input images to be processed, of shape (n_samples, channels, height, width). + batch_size + The batch size to use during training and prediction. Defaults to 128. + resize + Optional argument to resize the outputs of the model. Defaults to None. + device + The compute device. Defaults to 'cuda'. + + Returns + ------- + activations + The latent activations of shape (n_samples, channels, height, width). + """ + # dataset: (N, C, H, W) + nb_batchs = ceil(len(dataset) / batch_size) + start_ids = [i*batch_size for i in range(nb_batchs)] + + results = [] + + with torch.no_grad(): + for i in start_ids: + x = dataset[i:i+batch_size] + x = x.to(device) + + if resize: + x = torch.nn.functional.interpolate(x, size=resize, mode='bilinear', align_corners=False) + + results.append(model(x).cpu()) + + results = torch.cat(results) + return results + + +class CraftTorch(BaseCraft): + """ + Class Implementing the CRAFT Concept Extraction Mechanism on Pytorch. + + Parameters + ---------- + input_to_latent + The first part of the model taking an input and returning + positive activations, g(.) in the original paper. + Must be a Pytorch model (torch.nn.modules.module.Module) accepting + data of shape (n_samples, channels, height, width). + latent_to_logit + The second part of the model taking activation and returning + logits, h(.) in the original paper. + Must be a Pytorch model (torch.nn.modules.module.Module). + number_of_concepts + The number of concepts to extract. Default is 20. + batch_size + The batch size to use during training and prediction. Default is 64. + patch_size + The size of the patches to extract from the input data. Default is 64. + device + The device to use. Default is 'cuda'. + """ + + def __init__(self, input_to_latent: Callable, + latent_to_logit: Callable, + number_of_concepts: int = 20, + batch_size: int = 64, + patch_size: int = 64, + device : str = 'cuda'): + super().__init__(input_to_latent, latent_to_logit, number_of_concepts, batch_size) + self.patch_size = patch_size + self.device = device + + # Check model type + is_torch_model = issubclass(type(input_to_latent), torch.nn.modules.module.Module) & \ + issubclass(type(latent_to_logit), torch.nn.modules.module.Module) + if not is_torch_model: + raise TypeError('input_to_latent and latent_to_logit are not Pytorch modules') + + def _latent_predict(self, inputs: torch.Tensor, resize=None) -> torch.Tensor: + """ + Compute the embedding space using the 1st model `input_to_latent`. + + Parameters + ---------- + inputs + Input data of shape (n_samples, channels, height, width). 
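+        resize
+            Optional size the inputs are interpolated to before being fed to
+            the model (forwarded to `_batch_inference`). Defaults to None.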
+ + Returns + ------- + activations + The latent activations of shape (n_samples, height, width, channels) + """ + # inputs: (N, C, H, W) + activations = _batch_inference(self.input_to_latent, inputs, self.batch_size, resize, device=self.device) + if len(activations.shape) == 4: + # activations: (N, C, H, W) -> (N, H, W, C) + activations = activations.permute(0, 2, 3, 1) + return activations + + def _logit_predict(self, activations: np.ndarray, resize=None) -> torch.Tensor: + """ + Compute logits from activations using the 2nd model `latent_to_logit`. + + Parameters + ---------- + activations + Activations produced by the 1st model `input_to_latent`, + of shape (n_samples, height, width, channels). + + Returns + ------- + logits + The logits of shape (n_samples, n_classes) + """ + activations_perturbated = torch.from_numpy(activations) + + if len(activations_perturbated.shape) == 4: + # activations_perturbated: (N, H, W, C) -> (N, C, H, W) + activations_perturbated = activations_perturbated.permute(0, 3, 1, 2) + + y_pred = _batch_inference(self.latent_to_logit, activations_perturbated, self.batch_size, resize, device=self.device) + return self._to_np_array(y_pred) + + def _extract_patches(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, np.ndarray]: + """ + Extract patches from the input images, and compute their embeddings. + + Parameters + ---------- + inputs + Input images (n_samples, channels, height, width) + + Returns + ------- + (patches, activations) : + A tuple containing the patches (n_patches, channels, height, width), + and their activations (n_patches, channels) + """ + + image_size = inputs.shape[2] + num_channels = inputs.shape[1] + + # extract patches from the input data, keep patches on cpu + strides = int(self.patch_size * 0.80) + + patches = torch.nn.functional.unfold(inputs, kernel_size=self.patch_size, stride=strides) + patches = patches.transpose(1, 2).contiguous().view(-1, num_channels, self.patch_size, self.patch_size) + + # encode the patches and obtain the activations + activations = self._latent_predict(patches, resize=image_size) + + assert torch.min(activations) >= 0.0, "Activations must be positive." + + # if the activations have shape (n_samples, height, width, n_channels), + # apply average pooling + if len(activations.shape) == 4: + # activations: (N, H, W, R) + activations = torch.mean(activations, dim=(1, 2)) + + return patches, self._to_np_array(activations) + + def _to_np_array(self, input: torch.Tensor, dtype: type=None): + """ + Converts a Pytorch tensor into a numpy array. + """ + res = input.detach().cpu().numpy() + if dtype != None: + return res.astype(dtype) + else: + return res \ No newline at end of file
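
Note for reviewers (not part of the patch): a minimal, self-contained sketch of how the new CraftTf API fits together. The toy model and random data below are placeholders, and the layer split is illustrative; the only requirement is that g(.) ends on non-negative activations.

import numpy as np
import tensorflow as tf
from xplique.concepts import CraftTf

# toy data and a tiny classifier, just to exercise the API
images = np.random.rand(50, 64, 64, 3).astype(np.float32)
labels = np.random.randint(0, 3, 50)
model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(64, 64, 3)),
    tf.keras.layers.Conv2D(8, (3, 3), activation='relu'),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(3, activation='softmax'),
])

# split the model: g(.) is 'input_to_latent', h(.) is 'latent_to_logit'
g = tf.keras.Model(model.input, model.layers[0].output)
h = tf.keras.Model(model.layers[1].input, model.layers[-1].output)

craft = CraftTf(input_to_latent=g, latent_to_logit=h,
                number_of_concepts=5, patch_size=16, batch_size=32)

# fit on the images of one class, then estimate and plot concept importances
class_id = 0
crops, crops_u, w = craft.fit(images[labels == class_id], class_id)
importances = craft.estimate_importance(nb_most_important_concepts=3)
craft.plot_image_concepts(id=0)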