Merge branch 'dev' into dev_duplex

adjavon committed Nov 7, 2024
2 parents 2a8c699 + 41f3890 commit 3017d96
Showing 21 changed files with 2,419 additions and 599 deletions.
8 changes: 4 additions & 4 deletions .pre-commit-config.yaml
@@ -19,7 +19,7 @@ repos:
    hooks:
      - id: black

-  - repo: https://github.com/pre-commit/mirrors-mypy
-    rev: v1.0.1
-    hooks:
-      - id: mypy
+  # - repo: https://github.com/pre-commit/mirrors-mypy
+  #   rev: v1.0.1
+  #   hooks:
+  #     - id: mypy
13 changes: 12 additions & 1 deletion README.md
@@ -4,4 +4,15 @@



Logo made with the help of DALL-E 2.

Installing:
1. Clone this repository.
2. Create a `conda` environment with `python`, `pytorch`, and `torchvision`; I recommend `mamba`.
3. Activate your new environment (`mamba activate ...`).
4. Change into the directory holding this repository.
5. `pip install .`

Installing as a developer:
1.-4. Same as above.
5. `pip install -e .\[dev\]`
83 changes: 83 additions & 0 deletions docs/tutorials/attribute.md
@@ -0,0 +1,83 @@
# Attribution and evaluation given counterfactuals

## Attribution
```python
# Load the classifier
from quac.generate import load_classifier
classifier = load_classifier(...)

# Defining attributions
from quac.attribution import (
    DDeepLift,
    DIntegratedGradients,
    AttributionIO,
)
from torchvision import transforms

attributor = AttributionIO(
    attributions={
        "deeplift": DDeepLift(),
        "ig": DIntegratedGradients(),
    },
    output_directory="my_attributions_directory",
)

transform = transforms.Compose(
    [
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.Normalize(...),
    ]
)

# This will run attributions and store all of the results in the output_directory
# Shows a progress bar
attributor.run(
    source_directory="my_source_image_directory",
    counterfactual_directory="my_counterfactual_image_directory",
    transform=transform,
)
```

## Evaluation
Once you have attributions, you can run evaluations.
You may want to try different methods for thresholding and smoothing the attributions to get masks.


In this example, we evaluate the results from the DeepLift attribution method.
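Before diving in, it helps to know what a processor does conceptually: it turns a raw attribution map into a binary mask, typically by smoothing and then thresholding. The sketch below is library-agnostic and illustrative only; the function name and parameters are not quac's `Processor` API.

```python
import numpy as np
from scipy.ndimage import gaussian_filter


def attribution_to_mask(attribution, sigma=2.0, quantile=0.9):
    """Illustrative sketch: smooth an attribution map, then threshold it into a mask."""
    smoothed = gaussian_filter(attribution, sigma=sigma)  # suppress pixel-level noise
    threshold = np.quantile(smoothed, quantile)  # keep only the strongest attributions
    return smoothed >= threshold  # boolean mask of the most salient regions
```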

```python
# Defining processors and evaluators
from quac.evaluation import Processor, Evaluator
from sklearn.metrics import ConfusionMatrixDisplay
import matplotlib.pyplot as plt

classifier = load_classifier(...)

evaluator = Evaluator(
    classifier,
    source_directory="my_source_image_directory",
    counterfactual_directory="my_counterfactual_image_directory",
    attribution_directory="my_attributions_directory/deeplift",
    transform=transform,
)

cf_confusion_matrix = evaluator.classification_report(
    data="counterfactuals",  # this is the default
    return_classifications=False,
    print_report=True,
)

# Plot the confusion matrix
disp = ConfusionMatrixDisplay(
    confusion_matrix=cf_confusion_matrix,
)
disp.plot()
plt.show()

# Run QuAC evaluation on your attribution and store a report
report = evaluator.quantify(processor=Processor())
# The report will be stored based on the processor's name, which is "default" by default
report.store("my_attributions_directory/deeplift/reports")
```
151 changes: 151 additions & 0 deletions docs/tutorials/train.md
@@ -0,0 +1,151 @@
# Training the StarGAN

In this tutorial, we go over the basics of how to train a (slightly modified) StarGAN for use in QuAC.

## Defining the dataset

The data is expected to be image files organized in a directory structure that denotes the classification.
For example:
```
data_folder/
crow/
crow1.png
crow2.png
raven/
raven1.png
raven2.png
```
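This class-per-subfolder layout matches the convention used by `torchvision.datasets.ImageFolder`, so as a quick sanity check you can enumerate the classes with torchvision (a sketch, assuming the folder above exists):

```python
from torchvision import datasets, transforms

# Sanity check: torchvision discovers one class per subdirectory.
dataset = datasets.ImageFolder("data_folder", transform=transforms.ToTensor())
print(dataset.classes)  # ['crow', 'raven']
```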

A training dataset class is defined in `quac.training.data`; it needs to be given two directories: a `source` and a `reference`. These directories can be the same.

The validation dataset needs the same information.

For example:
```python
# ValidationData is assumed to live in quac.training.data alongside TrainingDataset
from quac.training.data import TrainingDataset, ValidationData

dataset = TrainingDataset(
    source="path/to/training/data",
    reference="path/to/training/data",
    img_size=128,
    batch_size=4,
    num_workers=4,
)

# Setup data for validation
val_dataset = ValidationData(
    source="path/to/training/data",
    reference="path/to/training/data",
    img_size=128,
    batch_size=16,
    num_workers=16,
)
```
## Defining the models

The models can be built using a function in `quac.training.stargan`.

```python
from quac.training.stargan import build_model

nets, nets_ema = build_model(
    img_size=256,  # Images are made square
    style_dim=64,  # The size of the style vector
    input_dim=1,  # Number of channels in the input
    latent_dim=16,  # The size of the random latent
    num_domains=4,  # Number of classes
    single_output_style_encoder=False,
)

# Alternatively, if you are working from an experiment configuration object:
# nets, nets_ema = build_model(**experiment.model.model_dump())
```

If using multiple or specific GPUs, it may be necessary to add the `gpu_ids` argument.
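A sketch of what that might look like (the exact semantics of `gpu_ids` are defined by `build_model`; a list of device indices is assumed here):

```python
# Hypothetical: pin the models to the first two GPUs.
nets, nets_ema = build_model(
    img_size=256,
    style_dim=64,
    input_dim=1,
    latent_dim=16,
    num_domains=4,
    single_output_style_encoder=False,
    gpu_ids=[0, 1],  # assumed: a list of CUDA device indices
)
```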

The `nets_ema` are a copy of the `nets` that is not trained directly; instead, its weights track an exponential moving average of the weights of the `nets`.
The sub-networks of both can be accessed in a dictionary-like manner, as sketched below.
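For intuition, a minimal sketch of the EMA update idea (quac's training loop handles this internally; the decay value here is illustrative):

```python
import torch


@torch.no_grad()
def ema_update(net: torch.nn.Module, net_ema: torch.nn.Module, beta: float = 0.999) -> None:
    """Nudge the EMA copy toward the live weights after each optimizer step."""
    for p, p_ema in zip(net.parameters(), net_ema.parameters()):
        p_ema.mul_(beta).add_(p, alpha=1 - beta)
```

Dictionary-style access works on both collections, for example `nets_ema["generator"]` (the key names are an assumption here; inspect `nets.keys()` to see the actual ones).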

## Creating a logger
```python
# Example using WandB
logger = Logger.create(
    log_type="wandb",
    project="project-name",
    name="experiment name",
    tags=["experiment", "project", "test", "quac", "stargan"],
    hparams={  # this holds all of the hyperparameters you want to store for your run
        "hyperparameter_key": "Hyperparameter values"
    },
)

# TODO example using tensorboard
```

## Defining the Solver

It is now time to instantiate the `Solver` object, which does the bulk of the work during training.

```python
solver = Solver(
    nets,
    nets_ema,
    # Checkpointing
    checkpoint_dir="path/to/store/checkpoints",
    # Parameters for the Adam optimizers
    lr=1e-4,
    beta1=0.5,
    beta2=0.99,
    weight_decay=0.1,
    # Attach the logger created above
    run=logger,
)

# Alternatively, from an experiment configuration object:
# solver = Solver(nets, nets_ema, **experiment.solver.model_dump(), run=logger)
```

## Training
We use the solver to train on the data as follows:

```python
from quac.training.options import ValConfig

val_config = ValConfig(
    classifier_checkpoint="/path/to/classifier/",
    mean=0.5,
    std=0.5,
)

solver.train(dataset, val_config)
```

All results will be stored in the `checkpoint_dir` defined above.
Validation runs at regular intervals during training (by default, every 10000 iterations; see `eval_every` in the bonus section below).

## BONUS: Training with a Config file

The training options can also be gathered into configuration objects before training:

```python
# RunConfig and LossConfig are assumed to live alongside ValConfig
from quac.training.options import RunConfig, ValConfig, LossConfig

run_config = RunConfig(
    # All of these are default
    resume_iter=0,
    total_iter=100000,
    log_every=1000,
    save_every=10000,
    eval_every=10000,
)
val_config = ValConfig(
    classifier_checkpoint="/path/to/classifier/",
    # The below are defaults
    val_batch_size=32,
    num_outs_per_domain=10,
    mean=0.5,
    std=0.5,
    grayscale=True,
)
loss_config = LossConfig(
    # The following should probably not be changed
    # unless you really know what you're doing :)
    # All of these are default
    lambda_ds=1.0,
    lambda_reg=1.0,
    lambda_sty=1.0,
    lambda_cyc=1.0,
)
```
9 changes: 6 additions & 3 deletions pyproject.toml
@@ -18,12 +18,15 @@ authors = [
dynamic = ["version"]
dependencies = [
"captum",
"numpy",
"torch",
"numpy",
"munch",
"torch",
"torchvision",
"funlib.learn.torch@git+https://github.com/funkelab/funlib.learn.torch",
"opencv-python",
"scipy"
"pydantic",
"scipy",
"scikit-learn"
]

[project.optional-dependencies]