Refactor codebase #8

Merged 39 commits on Nov 27, 2024

Commits
a658227
feat: :tada: Add StarGAN model for training
adjavon Feb 18, 2024
5916215
feat: :construction: Sketch StarGAN code
adjavon Mar 9, 2024
6651ff6
feat: :sparkles: Define how to generate counterfactuals
adjavon Mar 21, 2024
d549533
refactor: :art: Make dataloaders more general and add tutorial
adjavon Mar 28, 2024
3943f78
refactor: :art: Add parametrization with defaults
adjavon Apr 1, 2024
3b70926
refactor: :construction: Remove reliance on cuda
adjavon Apr 1, 2024
8e3f28e
feat: :sparkles: Update reference-based inference
adjavon Apr 1, 2024
f6af01b
refactor: :construction: Make attribution user-friendly
adjavon Apr 2, 2024
33eddbd
chore: :heavy_plus_sign: Update installation
adjavon Apr 3, 2024
41037ed
refactor: :art: Add and match attribution tutorial
adjavon Apr 3, 2024
2847bdb
refactor: :art: Add normalization to attributions
adjavon Apr 8, 2024
e1ed2c2
refactor: :art: Fix typos and restructure
adjavon Apr 8, 2024
e997147
refactor: :construction: Add the tutorial on training StarGAN
adjavon Apr 11, 2024
de3b5a5
fix: Account for incomplete reports
adjavon Apr 23, 2024
68114c7
refactor: :construction: Match train code to API
adjavon Apr 23, 2024
22c4f35
fix: :art: Debug using fictus experiment
adjavon Apr 28, 2024
75928ea
fix: :ambulance: Avoid collapse with new StarGAN
adjavon Apr 30, 2024
2993787
refactor: :poop: Diversify logging and update tutorial
adjavon May 1, 2024
c6ccbd3
fix: :bug: Add final activation during inference
adjavon May 23, 2024
e714749
fix: :bug: Add normalization assumption in classifier
adjavon May 23, 2024
f41e813
Fix typos and shape bugs
adjavon Jun 2, 2024
8542144
Get access to counterfactual prediction
adjavon Jun 2, 2024
2f0cfcd
Update training code
adjavon Jun 28, 2024
e547fcc
Remove random cropping by default
adjavon Aug 8, 2024
bdc056b
feat: :adhesive_bandage: Add the ability to do nothing in Classifier …
adjavon Sep 30, 2024
e7d3dd5
feat: :art: Log EMA images during training
adjavon Sep 30, 2024
67ece91
Merge pull request #7 from funkelab/tutorial_refactor
adjavon Sep 30, 2024
d559dd9
style: :pencil2: Fix typo in docstring
adjavon Sep 30, 2024
1f14bd9
refactor: :art: Expand do nothing in classifier
adjavon Oct 1, 2024
f3a2095
Add forced resume on logging
adjavon Oct 9, 2024
2dc04eb
fix: :bug: Fix how the optimal mask is chosen.
adjavon Oct 10, 2024
fe44eed
feat: :poop: Add an unblurred version of the processor
adjavon Oct 10, 2024
3b131cb
refactor: :adhesive_bandage: Increase the number of tries during eval…
adjavon Oct 14, 2024
aa65ec0
fix: :heavy_plus_sign: Add dependency which broke tests
adjavon Oct 14, 2024
38ad914
feat: :sparkles: Add test data loader
adjavon Oct 15, 2024
2e81122
feat: :zap: Allow batched attribution as well as unbatched
adjavon Oct 16, 2024
4762a94
fix: :zap: Use torch in batched attributions
adjavon Nov 4, 2024
eb9e348
feat: :alembic: Make RGB explicit
adjavon Nov 4, 2024
41f3890
refactor: :chart_with_upwards_trend: Do not force resume on WandB
adjavon Nov 4, 2024
8 changes: 4 additions & 4 deletions .pre-commit-config.yaml
@@ -19,7 +19,7 @@ repos:
   hooks:
     - id: black

-  - repo: https://github.com/pre-commit/mirrors-mypy
-    rev: v1.0.1
-    hooks:
-      - id: mypy
+  # - repo: https://github.com/pre-commit/mirrors-mypy
+  #   rev: v1.0.1
+  #   hooks:
+  #     - id: mypy
13 changes: 12 additions & 1 deletion README.md
@@ -4,4 +4,15 @@
-Logo made with the help of DALL-E 2.
+Logo made with the help of DALL-E 2.
+
+Installing:
+1. Clone this repository.
+2. Create a `conda` environment with `python, pytorch, torchvision`; I recommend `mamba`.
+3. Activate your new environment (`mamba activate ...`).
+4. Change into the directory holding this repository.
+5. `pip install .`
+
+Installing as a developer:
+1.-4. Same as above.
+5. `pip install -e .\[dev\]`
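
As a quick check that the installation worked, the imports below, taken from the tutorials added in this PR, should succeed in the new environment:

```python
# Post-install smoke test: these module paths come from the tutorials in this PR.
from quac.attribution import AttributionIO, DDeepLift, DIntegratedGradients
from quac.evaluation import Evaluator, Processor
from quac.generate import get_counterfactual, load_classifier, load_data, load_stargan

print("quac imports OK")
```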
83 changes: 83 additions & 0 deletions docs/tutorials/attribute.md
@@ -0,0 +1,83 @@
# Attribution and evaluation given counterfactuals

## Attribution
```python
# Load the classifier
from quac.generate import load_classifier

classifier = load_classifier(...)

# Defining attributions
from quac.attribution import (
    DDeepLift,
    DIntegratedGradients,
    AttributionIO,
)
from torchvision import transforms

attributor = AttributionIO(
    attributions={
        "deeplift": DDeepLift(),
        "ig": DIntegratedGradients(),
    },
    output_directory="my_attributions_directory",
)

transform = transforms.Compose(
    [
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.Normalize(...),
    ]
)

# This will run the attributions and store all of the results in the
# output directory, showing a progress bar as it goes
attributor.run(
    source_directory="my_source_image_directory",
    counterfactual_directory="my_counterfactual_image_directory",
    transform=transform,
)
```

## Evaluation
Once you have attributions, you can run evaluations.
You may want to try different methods for thresholding and smoothing the attributions to get masks; a sketch of comparing several such processors follows the example below.


In this example, we evaluate the results from the DeepLift attribution method.

```python
# Defining processors and evaluators
import matplotlib.pyplot as plt
from quac.evaluation import Processor, Evaluator
from sklearn.metrics import ConfusionMatrixDisplay

classifier = load_classifier(...)

evaluator = Evaluator(
    classifier,
    source_directory="my_source_image_directory",
    counterfactual_directory="my_counterfactual_image_directory",
    attribution_directory="my_attributions_directory/deeplift",
    transform=transform,
)

cf_confusion_matrix = evaluator.classification_report(
    data="counterfactuals",  # this is the default
    return_classifications=False,
    print_report=True,
)

# Plot the confusion matrix
disp = ConfusionMatrixDisplay(confusion_matrix=cf_confusion_matrix)
disp.plot()
plt.show()

# Run the QuAC evaluation on your attribution and store a report
report = evaluator.quantify(processor=Processor())
# The report will be stored based on the processor's name, which is "default" by default
report.store("my_attributions_directory/deeplift/reports")
```
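
To compare thresholding and smoothing strategies, you can run `quantify` once per processor and store each report under the processor's name. A minimal sketch, with hypothetical `Processor` keyword arguments (`name` and `gaussian_sigma` below are illustrative assumptions, not confirmed `quac` API):

```python
# Hypothetical comparison of mask-extraction settings.
# NOTE: the Processor keyword arguments used here are assumptions for
# illustration; check the quac documentation for the real parameters.
processors = [
    Processor(),  # the default processor, whose name is "default"
    Processor(name="unblurred", gaussian_sigma=0),  # assumed: skip smoothing
]

for processor in processors:
    report = evaluator.quantify(processor=processor)
    # Reports are stored under each processor's name, so they do not collide.
    report.store("my_attributions_directory/deeplift/reports")
```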
138 changes: 138 additions & 0 deletions docs/tutorials/generate.md
@@ -0,0 +1,138 @@
# How to generate images from a pre-trained network

## Defining the dataset

We generate images one source-target pair at a time.
As such, we need to point to the subdirectory holding the source class we are interested in.
For example, below we use the validation data, and our source class is class `0`, which has no Diabetic Retinopathy.
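
The tutorial assumes an image-folder layout with one subdirectory per class, named `<label>_<name>`. The tree below is illustrative, not prescribed by `quac`:

```
root_directory/
├── train/
├── val/
│   ├── 0_No_DR/
│   │   ├── image_0.png
│   │   └── ...
│   └── 1_Mild/
└── test/
```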

```python
from pathlib import Path
from quac.generate import load_data

img_size = 224
data_directory = Path("root_directory/val/0_No_DR")
dataset = load_data(data_directory, img_size, grayscale=False)
```
## Loading the classifier

Next, we need to load the pre-trained classifier and wrap it in the correct pre-processing step.
The classifier is expected to be saved as a `torchscript` checkpoint. This allows us to use it without having to redefine the Python class from which it was generated.

We also have a wrapper around the classifier that re-normalizes images to the range it expects. The assumption is that the input images come from the StarGAN trained with `quac`, so they will have values in `[-1, 1]`. Our pre-trained classifier here, for example, expects images with ImageNet normalization.
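
This re-normalization amounts to two affine steps, sketched below as a minimal illustration of the behavior described above, not the library's actual implementation:

```python
import torch

def renormalize(x: torch.Tensor, mean, std) -> torch.Tensor:
    """Map a StarGAN output in [-1, 1] to the classifier's expected range."""
    x = (x + 1.0) / 2.0  # [-1, 1] -> [0, 1]
    mean = torch.as_tensor(mean, dtype=x.dtype).view(-1, 1, 1)
    std = torch.as_tensor(std, dtype=x.dtype).view(-1, 1, 1)
    return (x - mean) / std  # [0, 1] -> e.g. ImageNet statistics
```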

Finally, we need to define the device, and whether to put the classifier in `eval` mode.

```python
import torch
from quac.generate import load_classifier

classifier_checkpoint = "/path/to/classifier/model.pt"  # placeholder: your torchscript checkpoint

mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

classifier = load_classifier(classifier_checkpoint, mean=mean, std=std, eval=True, device=device)
```

## Inference from random latents

The StarGAN model used to generate images can obtain its style from two sources.
The first and simplest one is a random latent vector.

### Loading the StarGAN

```python
from quac.generate import load_stargan

latent_model_checkpoint_dir = Path("/path/to/directory/holding/the/stargan/checkpoints")

inference_model = load_stargan(
    latent_model_checkpoint_dir,
    img_size=224,
    input_dim=1,
    style_dim=64,
    latent_dim=16,
    num_domains=5,
    checkpoint_iter=100000,
    kind="latent",
)
```

### Running the image generation

```python
from quac.generate import get_counterfactual
from torchvision.utils import save_image
from tqdm import tqdm

output_directory = Path("/path/to/output/latent/0_No_DR/1_Mild/")
output_directory.mkdir(parents=True, exist_ok=True)  # make sure the output directory exists

for x, name in tqdm(dataset):
    xcf = get_counterfactual(
        classifier,
        inference_model,
        x,
        target=1,
        kind="latent",
        device=device,
        max_tries=10,
        batch_size=10,
    )
    # For example, you can save the generated images here
    save_image(xcf, output_directory / name)
```

## Inference using a reference dataset

The alternative image generation method of a StarGAN is to use an image of the target class to generate the style, via the `StyleEncoder`.
Although the structure is similar to the one above, there are a few key differences.


### Generating the reference dataset

The first thing we need to do is to get the reference images.

```python
reference_data_directory = Path("root_directory/val/1_Mild")
reference_dataset = load_data(reference_data_directory, img_size, grayscale=False)
```

### Loading the StarGAN
This time, we will be creating a `ReferenceInferenceModel`.

```python
inference_model = load_stargan(
    latent_model_checkpoint_dir,
    img_size=224,
    input_dim=1,
    style_dim=64,
    latent_dim=16,
    num_domains=5,
    checkpoint_iter=100000,
    kind="reference",
)
```

### Running the image generation

Finally, we combine the two by changing the `kind` in our counterfactual generation, and giving it the reference dataset to use.

```python
from torchvision.utils import save_image

output_directory = Path("/path/to/output/reference/0_No_DR/1_Mild/")
output_directory.mkdir(parents=True, exist_ok=True)

for x, name in tqdm(dataset):
    xcf = get_counterfactual(
        classifier,
        inference_model,
        x,
        target=1,
        kind="reference",  # change the kind of inference being done
        dataset_ref=reference_dataset,  # add the reference dataset
        device=device,
        max_tries=10,
        batch_size=10,
    )
    # For example, you can save the generated images here
    save_image(xcf, output_directory / name)
```