diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a2bf1d6..dc763e9 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -19,7 +19,7 @@ repos: hooks: - id: black - - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.0.1 - hooks: - - id: mypy + # - repo: https://github.com/pre-commit/mirrors-mypy + # rev: v1.0.1 + # hooks: + # - id: mypy diff --git a/README.md b/README.md index 1a46bf0..9e27d68 100644 --- a/README.md +++ b/README.md @@ -4,4 +4,15 @@ -Logo made with the help of DALL-E 2. \ No newline at end of file +Logo made with the help of DALL-E 2. + +Installing: +1. Clone this repository +2. Create a `conda` environment with `python, pytorch, torchvision`; I recommend `mamba` +3. Activate your new environment (`mamba activate ...`) +4. Change into the directory holding this repository. +5. `pip install .` + +Installing as developper: +1. - 4. Same as above. +5. `pip install -e .\[dev\]` diff --git a/docs/tutorials/attribute.md b/docs/tutorials/attribute.md new file mode 100644 index 0000000..abb97c3 --- /dev/null +++ b/docs/tutorials/attribute.md @@ -0,0 +1,83 @@ +# Attribution and evaluation given counterfactuals + +## Attribution +```python +# Load the classifier +from quac.generate import load_classifier +classifier = load_classifier( + +) + +# Defining attributions +from quac.attribution import ( + DDeepLift, + DIntegratedGradients, + AttributionIO +) +from torchvision import transforms + +attributor = AttributionIO( + attributions = { + "deeplift" : DDeepLift(), + "ig" : DIntegratedGradients() + }, + output_directory = "my_attributions_directory" +) + +transform = transforms.Compose( + [ + transforms.Resize(224), + transforms.CenterCrop(224), + transforms.Normalize(...) + ] +) + +# This will run attributions and store all of the results in the output_directory +# Shows a progress bar +attributor.run( + source_directory="my_source_image_directory", + counterfactual_directory="my_counterfactual_image_directory", + transform=transform +) +``` + +## Evaluation +Once you have attributions, you can run evaluations. +You may want to try different methods for thresholding and smoothing the attributions to get masks. + + +In this example, we evaluate the results from the DeepLift attribution method. + +```python +# Defining processors and evaluators +from quac.evaluation import Processor, Evaluator +from sklearn.metrics import ConfusionMatrixDisplay + +classifier = load_classifier(...) 
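+# The evaluator needs the same transform that was used when computing the attributions,
+# so that source images, counterfactuals and attribution maps stay aligned.
+# A minimal sketch, reusing the transform defined in the attribution step above:
+from torchvision import transforms
+
+transform = transforms.Compose(
+    [
+        transforms.Resize(224),
+        transforms.CenterCrop(224),
+        transforms.Normalize(...)
+    ]
+)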
+ +evaluator = Evaluator( + classifier, + source_directory="my_source_image_directory", + counterfactual_directory="my_counterfactual_image_directory", + attribution_directory="my_attributions_directory/deeplift", + transform=transform +) + + +cf_confusion_matrix = evaluator.classification_report( + data="counterfactuals", # this is the default + return_classifications=False, + print_report=True, + ) + +# Plot the confusion matrix +disp = ConfusionMatrixDisplay( + confusion_matrix=cf_confusion_matrix, +) +disp.show() + +# Run QuAC evaluation on your attribution and store a report +report = evaluator.quantify(processor=Processor()) +# The report will be stored based on the processor's name, which is "default" by default +report.store("my_attributions_directory/deeplift/reports") +``` diff --git a/docs/tutorials/generate.md b/docs/tutorials/generate.md new file mode 100644 index 0000000..19aba8d --- /dev/null +++ b/docs/tutorials/generate.md @@ -0,0 +1,138 @@ +# How to generate images from a pre-trained network + +## Defining the dataset + +We will be generating images one source-target pair at a time. +As such, we need to point to the subdirectory that holds the source class that we are interested in. +For example, below, we are going to be using the validation data, and our source class will be class `0` which has no Diabetic Retinopathy. + +```python +from pathlib import Path +from quac.generate import load_data + +img_size = 224 +data_directory = Path("root_directory/val/0_No_DR") +dataset = load_data(data_directory, img_size, grayscale=False) +``` +## Loading the classifier + +Next we need to load the pre-trained classifier, and wrap it in the correct pre-processing step. +The classifier is expected to be saved as a `torchscript` checkpoint. This allows us to use it without having to redefine the python class from which it was generated. + +We also have a wrapper around the classifier that re-normalizes images to the range that it expects. The assumption is that these images come from the StarGAN trained with `quac`, so the images will have values in `[-1, 1]`. +Here, our pre-trained classifier expects images with the ImageNet normalization, for example. + +Finally, we need to define the device, and whether to put the classifier in `eval` mode. + +```python +from quac.generate import load_classifier + +mean = (0.485, 0.456, 0.406) +std = (0.229, 0.224, 0.225) +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +classifier = load_classifier(classifier_checkpoint, mean=mean, std=std, eval=True, device=device) +``` + +## Inference from random latents + +The StarGAN model used to generate images can have two sources for the style. +The first and simplest one is to use a random latent vector to create style. 
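+The snippets below also use `torch` and `tqdm`; if you are following along in a fresh session, make sure they are imported:
+
+```python
+import torch
+from tqdm import tqdm
+```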
+ +### Loading the StarGAN + +```python +from quac.generate import load_stargan + +latent_model_checkpoint_dir = Path("/path/to/directory/holding/the/stargan/checkpoints") + +inference_model = load_stargan( + latent_model_checkpoint_dir, + img_size=224, + input_dim=1, + style_dim=64, + latent_dim=16, + num_domains=5, + checkpoint_iter=100000, + kind = "latent" +) +``` + +### Running the image generation + +```python +from quac.generate import get_counterfactual +from torchvision.utils import save_image + +output_directory = Path("/path/to/output/latent/0_No_DR/1_Mild/") + +for x, name in tqdm(dataset): + xcf = get_counterfactual( + classifier, + inference_model, + x, + target=1, + kind="latent", + device=device, + max_tries=10, + batch_size=10 + ) + # For example, you can save the images here + save_image(xcf, output_directory / name) +``` + +## Inference using a reference dataset + +The alternative image generation method of a StarGAN is to use an image of the target class to generate the style using the `StyleEncoder`. +Although the structure is similar as above, there are a few key differences. + + +### Generating the reference dataset + +The first thing we need to do is to get the reference images. + +```python +reference_data_directory = Path(f"{root_directory}/val/1_Mild") +reference_dataset = load_data(reference_data_directory, img_size, grayscale=False) +``` + +### Loading the StarGAN +This time, we will be creating a `ReferenceInferenceModel`. + +```python +inference_model = load_stargan( + latent_model_checkpoint_dir, + img_size=224, + input_dim=1, + style_dim=64, + latent_dim=16, + num_domains=5, + checkpoint_iter=100000, + kind = "reference" +) +``` + +### Running the image generation + +Finally, we combine the two by changing the `kind` in our counterfactual generation, and giving it the reference dataset to use. + +```python +from torchvision.utils import save_image + +output_directory = Path("/path/to/output/reference/0_No_DR/1_Mild/") + +for x, name in tqdm(dataset): + xcf = get_counterfactual( + classifier, + inference_model, + x, + target=1, + kind="reference", # Change the kind of inference being done + dataset_ref=reference_dataset, # Add the reference dataset + device=device, + max_tries=10, + batch_size=10 + ) + # For example, you can save the images here + save_image(xcf, output_directory / name) +``` diff --git a/docs/tutorials/train.md b/docs/tutorials/train.md new file mode 100644 index 0000000..9e18aee --- /dev/null +++ b/docs/tutorials/train.md @@ -0,0 +1,151 @@ +# Training the StarGAN + +In this tutorial, we go over the basics of how to train a (slightly modified) StarGAN for use in QuAC. + +## Defining the dataset + +The data is expected to be in the form of image files with a directory structure denoting the classification. +For example: +``` +data_folder/ + crow/ + crow1.png + crow2.png + raven/ + raven1.png + raven2.png +``` + +A training dataset is defined in `quac.training.data` which will need to be given two directories: a `source` and a `reference`. These directories can be the same. + +The validation dataset will need the same information. 
+ +For example: +```python +from quac.training.data import TrainingDataset + +dataset = TrainingDataset( + source="path/to/training/data", + reference="path/to/training/data", + img_size=128, + batch_size=4, + num_workers=4 +) + +# Setup data for validation +val_dataset = ValidationData( + source="path/to/training/data", + reference="path/to/training/data", + img_size=128, + batch_size=16, + num_workers=16 +) + +``` +## Defining the models + +The models can be built using a function in `quac.training.stargan`. + +```python +from quac.training.stargan import build_model + +nets, nets_ema = build_model( + img_size=256, # Images are made square + style_dim=64, # The size of the style vector + input_dim=1, # Number of channels in the input + latent_dim=16, # The size of the random latent + num_domains=4, # Number of classes + single_output_style_encoder=False +) +## Defining the models +nets, nets_ema = build_model(**experiment.model.model_dump()) + +``` + +If using multiple or specific GPUs, it may be necessary to add the `gpu_ids` argument. + +The `nets_ema` are a copy of the `nets` that will not be trained but rather will be an exponential moving average of the weight of the `nets`. +The sub-networks of both can be accessed in a dictionary-like manner. + +## Creating a logger +```python +# Example using WandB +logger = Logger.create( + log_type="wandb", + project="project-name", + name="experiment name", + tags=["experiment", "project", "test", "quac", "stargan"], + hparams={ # this holds all of the hyperparameters you want to store for your run + "hyperparameter_key": "Hyperparameter values" + } +) + +# TODO example using tensorboard +``` + +## Defining the Solver + +It is now time to initiate the `Solver` object, which will do the bulk of the work in training. + +```python +solver = Solver( + nets, + nets_ema, + # Checkpointing + checkpoint_dir="path/to/store/checkpoints", + # Parameters for the Adam optimizers + lr=1e-4, + beta1=0.5, + beta2=0.99, + weight_decay=0.1, +) + +# TODO +solver = Solver(nets, nets_ema, **experiment.solver.model_dump(), run=logger) +``` + +## Training +We use the solver to train on the data as follows: + +```python +from quac.training.options import ValConfig +val_config=ValConfig( + classifier_checkpoint="/path/to/classifier/", mean=0.5, std=0.5 +) + +solver.train(dataset, val_config) +``` + +All results will be stored in the `checkpoint_directory` defined above. +Validation will be done during training at regular intervals (by default, every 10000 iterations). 
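+Once training is done, the EMA checkpoints stored in the checkpoint directory can be loaded for counterfactual generation (see the generation tutorial).
+Below is a minimal sketch, assuming the model settings from the `build_model` call above and the default `total_iter` of 100000; adjust both to your own run:
+
+```python
+from quac.generate import load_stargan
+
+# Load the EMA networks saved by the solver, for latent-based inference.
+inference_model = load_stargan(
+    "path/to/store/checkpoints",
+    img_size=256,
+    input_dim=1,
+    style_dim=64,
+    latent_dim=16,
+    num_domains=4,
+    checkpoint_iter=100000,
+    kind="latent",
+)
+```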
+ +## BONUS: Training with a Config file + +```python +run_config=RunConfig( + # All of these are default + resume_iter=0, + total_iter=100000, + log_every=1000, + save_every=10000, + eval_every=10000, +) +val_config=ValConfig( + classifier_checkpoint="/path/to/classifier/", + # The below is default + val_batch_size=32 + num_outs_per_domain=10, + mean=0.5, + std=0.5, + grayscale=True, +) +loss_config=LossConfig( + # The following should probably not be changed + # unless you really know what you're doing :) + # All of these are default + lambda_ds=1., + lambda_reg=1., + lambda_sty=1., + lambda_cyc=1., +) +``` diff --git a/pyproject.toml b/pyproject.toml index 59a1c12..a843ada 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,12 +18,15 @@ authors = [ dynamic = ["version"] dependencies = [ "captum", - "numpy", - "torch", + "numpy", + "munch", + "torch", "torchvision", "funlib.learn.torch@git+https://github.com/funkelab/funlib.learn.torch", "opencv-python", - "scipy" + "pydantic", + "scipy", + "scikit-learn" ] [project.optional-dependencies] diff --git a/src/quac/attribution.py b/src/quac/attribution.py index 233c483..7743671 100644 --- a/src/quac/attribution.py +++ b/src/quac/attribution.py @@ -1,13 +1,18 @@ """Holds all of the discriminative attribution methods that are accepted by QuAC.""" -from captum import attr -import torch + +from captum import attr import numpy as np import scipy +from pathlib import Path +from quac.data import PairedImageDataset +from tqdm import tqdm +import torch +from typing import Callable def residual(real_img, fake_img): - """Residual attribution method. - + """Residual attribution method. + This method just takes the standardized difference between the real and fake images. """ res = np.abs(real_img - fake_img) @@ -18,7 +23,7 @@ def residual(real_img, fake_img): def random(real_img, fake_img): """Random attribution method. - This method randomly assigns attribution to each pixel in the image, then applies a Gaussian filter for smoothing. + This method randomly assigns attribution to each pixel in the image, then applies a Gaussian filter for smoothing. """ rand = np.abs(np.random.randn(*np.shape(real_img))) rand = np.abs(scipy.ndimage.gaussian_filter(rand, 4)) @@ -31,60 +36,109 @@ class BaseAttribution: """ Basic format of an attribution class. """ - def __init__(self, classifier): + + def __init__(self, classifier, normalize=True): self.classifier = classifier - - def _attribute(self, real_img, counterfactual_img, real_class, target_class, **kwargs): - raise NotImplementedError("The base attribution class does not have an attribute method.") + self.normalize = normalize + + def _normalize(self, attribution): + """Scale the attribution to be between 0 and 1. + + Note that this also takes the absolute value of the attribution. + Generally in this framework, we only care about the absolute value of the attribution, + because if "negative changes" need to be made, this should be inherent in + the counterfactual image. + """ + attribution = torch.abs(attribution) + # We scale the attribution to be between 0 and 1, batch-wise + min_vals = attribution.flatten(1).min(1)[0][:, None, None, None] + max_vals = attribution.flatten(1).max(1)[0][:, None, None, None] + return (attribution - min_vals) / (max_vals - min_vals) + + def _attribute( + self, real_img, counterfactual_img, real_class, target_class, **kwargs + ): + raise NotImplementedError( + "The base attribution class does not have an attribute method." 
+ ) - def attribute(self, real_img, counterfactual_img, real_class, target_class, **kwargs): + def attribute( + self, + real_img, + counterfactual_img, + real_class, + target_class, + device="cuda", + **kwargs, + ): self.classifier.zero_grad() - attribution = self._attribute(real_img, counterfactual_img, real_class, target_class, **kwargs) - return attribution.detach().cpu().numpy() + # Check if there is a batch dimension, if not, add it + batch_added = False + if len(real_img.shape) == 3: + real_img = real_img[None, ...] + counterfactual_img = counterfactual_img[None, ...] + batch_added = True + + attribution = self._attribute( + real_img.to(device), + counterfactual_img.to(device), + real_class, + target_class, + **kwargs, + ) + attribution = attribution.detach() + if self.normalize: + attribution = self._normalize(attribution) + if batch_added: + attribution = attribution[0] + return attribution.cpu().numpy() class DIntegratedGradients(BaseAttribution): """ Discriminative version of the Integrated Gradients attribution method. """ - def __init__(self, classifier): - super().__init__(classifier) + + def __init__(self, classifier, normalize=True): + super().__init__(classifier, normalize=normalize) self.ig = attr.IntegratedGradients(classifier) def _attribute(self, real_img, counterfactual_img, real_class, target_class): # FIXME in the original DAPI code, the real and counterfactual were switched. attribution = self.ig.attribute( - real_img[None, ...].cuda(), - baselines=counterfactual_img[None, ...].cuda(), - target=real_class + real_img, + baselines=counterfactual_img, + target=real_class, ) - return attribution[0] + return attribution class DDeepLift(BaseAttribution): """ Discriminative version of the DeepLift attribution method. """ - def __init__(self, classifier): - super().__init__(classifier) + + def __init__(self, classifier, normalize=True): + super().__init__(classifier, normalize=normalize) self.dl = attr.DeepLift(classifier) def _attribute(self, real_img, counterfactual_img, real_class, target_class): # FIXME in the original DAPI code, the real and counterfactual were switched. attribution = self.dl.attribute( - real_img[None, ...].cuda(), - baselines=counterfactual_img[None, ...].cuda(), - target=real_class + real_img, + baselines=counterfactual_img, + target=real_class, ) - return attribution[0] + return attribution class DInGrad(BaseAttribution): """ Discriminative version of the InputxGradient attribution method. """ - def __init__(self, classifier): - super().__init__(classifier) + + def __init__(self, classifier, normalize=True): + super().__init__(classifier, normalize=normalize) self.saliency = attr.Saliency(self.classifier) def _attribute(self, real_img, counterfactual_img, real_class, target_class): @@ -92,7 +146,90 @@ def _attribute(self, real_img, counterfactual_img, real_class, target_class): # grads_fake = self.saliency.attribute(counterfactual_img, # target=target_class) # ingrad_diff_0 = grads_fake * (real_img - counterfactual_img) - grads_real = self.saliency.attribute(real_img[None, ...].cuda(), - target=real_class).detach().cpu() - ingrad_diff_1 = grads_real * (counterfactual_img[None, ...] - real_img[None, ...]) - return ingrad_diff_1[0] \ No newline at end of file + grads_real = self.saliency.attribute(real_img, target=real_class).detach().cpu() + ingrad_diff_1 = grads_real * (counterfactual_img - real_img) + return ingrad_diff_1 + + +class VanillaIntegratedGradients(BaseAttribution): + """Wrapper class for Integrated Gradients from Captum. 
+ + Allows us to use it as a baseline. + """ + + def __init__(self, classifier, normalize=True): + super().__init__(classifier, normalize=normalize) + self.ig = attr.IntegratedGradients(classifier) + + def _attribute(self, real_img, counterfactual_img, real_class, target_class): + batched_attribution = ( + self.ig.attribute(real_img, target=real_class).detach().cpu() + ) + return batched_attribution + + +class VanillaDeepLift(BaseAttribution): + """Wrapper class for DeepLift from Captum. + + Allows us to use it as a baseline. + """ + + def __init__(self, classifier, normalize=True): + super().__init__(classifier, normalize=normalize) + self.dl = attr.DeepLift(classifier) + + def _attribute(self, real_img, counterfactual_img, real_class, target_class): + batched_attribution = ( + self.dl.attribute(real_img, target=real_class).detach().cpu() + ) + return batched_attribution + + +class AttributionIO: + """ + Running the attribution methods on the images. + Storing the results in the output directory. + + """ + + def __init__(self, attributions: dict[str, BaseAttribution], output_directory: str): + self.attributions = attributions + self.output_directory = Path(output_directory) + + def get_directory(self, attr_name: str, source_class: str, target_class: str): + directory = self.output_directory / f"{attr_name}/{source_class}/{target_class}" + directory.mkdir(parents=True, exist_ok=True) + return directory + + def run( + self, + source_directory: str, + counterfactual_directory: str, + transform: Callable, + device: str = "cuda", + ): + if device == "cuda": + if not torch.cuda.is_available(): + raise ValueError("CUDA is not available on this machine.") + print("Loading paired data") + dataset = PairedImageDataset( + source_directory, counterfactual_directory, transform=transform + ) + print("Running attributions") + for sample in tqdm(dataset, total=len(dataset)): + for attr_name, attribution in self.attributions.items(): + attr = attribution.attribute( + sample.image, + sample.counterfactual, + sample.source_class_index, + sample.target_class_index, + device=device, + ) + # Store the attribution + np.save( + self.get_directory( + attr_name, sample.source_class, sample.target_class + ) + / f"{sample.path.stem}.npy", + attr, + ) diff --git a/src/quac/data.py b/src/quac/data.py index a789a73..56efa06 100644 --- a/src/quac/data.py +++ b/src/quac/data.py @@ -1,4 +1,8 @@ +from dataclasses import dataclass +import numpy as np import os +from pathlib import Path +import torch from torch.utils.data import Dataset from torchvision.datasets.folder import ( default_loader, @@ -23,22 +27,12 @@ def find_classes(directory): return classes, class_to_idx -def make_dataset( +def check_requirements( directory: str, - paired_directory: str, class_to_idx: Optional[Dict[str, int]] = None, extensions: Optional[Union[str, Tuple[str, ...]]] = None, is_valid_file: Optional[Callable[[str], bool]] = None, -) -> List[Tuple[str, str, int, int]]: - """Generates a list of samples of a form (path_to_sample, class). - - See :class:`DatasetFolder` for details. - - Note: The class_to_idx parameter is here optional and will use the logic of the ``find_classes`` function - by default. - """ - directory = os.path.expanduser(directory) - +): if class_to_idx is None: _, class_to_idx = find_classes(directory) elif not class_to_idx: @@ -46,6 +40,9 @@ def make_dataset( "'class_to_index' must have at least one entry to collect any samples." 
) + both_none = extensions is None and is_valid_file is None + both_something = extensions is not None and is_valid_file is not None + both_none = extensions is None and is_valid_file is None both_something = extensions is not None and is_valid_file is not None if both_none or both_something: @@ -60,6 +57,112 @@ def is_valid_file(x: str) -> bool: is_valid_file = cast(Callable[[str], bool], is_valid_file) + return is_valid_file + + +def make_counterfactual_dataset( + counterfactual_directory: str, + class_to_idx: Optional[Dict[str, int]] = None, + extensions: Optional[Union[str, Tuple[str, ...]]] = None, + is_valid_file: Optional[Callable[[str], bool]] = None, + allow_empty: bool = False, +) -> List[Tuple[str, int]]: + """Generates a list of samples of a form (path_to_sample, class) + for data organized in a counterfactual style directory. + + The dataset is organized in the following way: + ``` + root_directory/ + ├── class_x + | └── class_y + │ ├── xxx.ext + │ ├── xxy.ext + │ └── xxz.ext + └── class_y + └── class_x + ├── 123.ext + ├── nsdf3.ext + └── ... + └── asd932_.ext + ``` + + We want to use the most nested subdirectories as the class labels. + """ + directory = os.path.expanduser(counterfactual_directory) + if class_to_idx is None: + _, class_to_idx = find_classes(directory) + + is_valid_file = check_requirements( + counterfactual_directory, + class_to_idx, + extensions, + is_valid_file, + ) + + instances = [] + available_classes = set() + for source_class in sorted(class_to_idx.keys()): + source_dir = os.path.join(directory, source_class) + if not os.path.isdir(source_dir): + continue + target_directories = {} + for target_class in sorted(class_to_idx.keys()): + if target_class == source_class: + continue + target_dir = os.path.join(directory, source_class, target_class) + if os.path.isdir(target_dir): + target_directories[target_class] = target_dir + for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)): + assert source_class != target_class + for fname in sorted(fnames): + path = os.path.join(root, fname) + if is_valid_file(path): + item = ( + path, + class_to_idx[source_class], + class_to_idx[target_class], + ) + instances.append(item) + + if target_class not in available_classes: + available_classes.add(source_class) + + empty_classes = set(class_to_idx.keys()) - available_classes + if empty_classes and not allow_empty: + msg = ( + f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. " + ) + if extensions is not None: + msg += f"Supported extensions are: {extensions if isinstance(extensions, str) else ', '.join(extensions)}" + raise FileNotFoundError(msg) + + return instances + + +def make_paired_dataset( + directory: str, + paired_directory: str, + class_to_idx: Optional[Dict[str, int]], + extensions: Optional[Union[str, Tuple[str, ...]]] = None, + is_valid_file: Optional[Callable[[str], bool]] = None, + allow_empty: bool = False, +) -> List[Tuple[str, str, int, int]]: + """Generates a list of samples of a form (path_to_sample, class). + + See :class:`DatasetFolder` for details. + + Note: The class_to_idx parameter is here optional and will use the logic of the ``find_classes`` function + by default. 
+ """ + directory = os.path.expanduser(directory) + + if class_to_idx is None: + _, class_to_idx = find_classes(directory) + + is_valid_file = check_requirements( + directory, class_to_idx, extensions, is_valid_file + ) + instances = [] available_classes = set() for source_class in sorted(class_to_idx.keys()): @@ -74,26 +177,117 @@ def is_valid_file(x: str) -> bool: target_dir = os.path.join(paired_directory, source_class, target_class) if os.path.isdir(target_dir): target_directories[target_class] = target_dir - for root, _, fnames in sorted(os.walk(source_dir, followlinks=True)): - for fname in sorted(fnames): - path = os.path.join(root, fname) - if is_valid_file(path): - for target_class, target_dir in target_directories.items(): - target_path = os.path.join(target_dir, fname) - if os.path.isfile(target_path) and is_valid_file(target_path): - item = ( - path, - target_path, - class_index, - class_to_idx[target_class], + for root, _, fnames in sorted(os.walk(source_dir, followlinks=True)): + for fname in sorted(fnames): + path = os.path.join(root, fname) + if is_valid_file(path): + for target_class, target_dir in target_directories.items(): + target_path = os.path.join(target_dir, fname) + if os.path.isfile(target_path) and is_valid_file( + target_path + ): + item = ( + path, + target_path, + class_index, + class_to_idx[target_class], + ) + instances.append(item) + + if source_class not in available_classes: + available_classes.add(source_class) + + empty_classes = set(class_to_idx.keys()) - available_classes + if empty_classes and not allow_empty: + msg = ( + f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. " + ) + if extensions is not None: + msg += f"Supported extensions are: {extensions if isinstance(extensions, str) else ', '.join(extensions)}" + raise FileNotFoundError(msg) + + return instances + + +def make_paired_attribution_dataset( + directory: str, + paired_directory: str, + attribution_directory: str, + class_to_idx: Optional[Dict[str, int]] = None, + extensions: Optional[Union[str, Tuple[str, ...]]] = None, + is_valid_file: Optional[Callable[[str], bool]] = None, + allow_empty: bool = False, +) -> List[Tuple[str, str, int, int]]: + """Generates a list of samples of a form (path_to_sample, class). + + See :class:`DatasetFolder` for details. + + Note: The class_to_idx parameter is here optional and will use the logic of the ``find_classes`` function + by default. + """ + directory = os.path.expanduser(directory) + + if class_to_idx is None: + _, class_to_idx = find_classes(directory) + + is_valid_file = check_requirements( + directory, class_to_idx, extensions, is_valid_file + ) + + instances = [] + available_classes = set() + for source_class in sorted(class_to_idx.keys()): + class_index = class_to_idx[source_class] + source_dir = os.path.join(directory, source_class) + if not os.path.isdir(source_dir): + continue + target_directories = {} + attribution_directories = {} + for target_class in sorted(class_to_idx.keys()): + if target_class == source_class: + continue + target_dir = os.path.join(paired_directory, source_class, target_class) + if os.path.isdir(target_dir): + target_directories[target_class] = target_dir + # Add the attribution directory as well. 
It is organized in the same way + # as the paire directory is + attribution_dir = os.path.join( + attribution_directory, source_class, target_class + ) + if os.path.isdir(attribution_dir): + attribution_directories[target_class] = attribution_dir + + for root, _, fnames in sorted(os.walk(source_dir, followlinks=True)): + for fname in sorted(fnames): + path = os.path.join(root, fname) + if is_valid_file(path): + for target_class, target_dir in target_directories.items(): + target_path = os.path.join(target_dir, fname) + # attribution path must replace the extension to npy + attr_filename = fname.split(".")[-2] + ".npy" + attr_path = os.path.join( + attribution_directories[target_class], attr_filename ) - instances.append(item) - if source_class not in available_classes: - available_classes.add(source_class) + if ( + os.path.isfile(target_path) + and is_valid_file(target_path) + and os.path.isfile(attr_path) + ): + item = ( + path, + target_path, + attr_path, + class_index, + class_to_idx[target_class], + ) + instances.append(item) + + if source_class not in available_classes: + available_classes.add(source_class) empty_classes = set(class_to_idx.keys()) - available_classes - if empty_classes: + if empty_classes and not allow_empty: msg = ( f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. " ) @@ -104,8 +298,56 @@ def is_valid_file(x: str) -> bool: return instances -class PairedImageFolders(Dataset): - def __init__(self, source_directory, paired_directory, transform=None): +@dataclass +class Sample: + image: torch.Tensor + source_class_index: int + path: Path = None + source_class: str = None + + +# TODO remove? +@dataclass +class CounterfactualSample: + counterfactual: torch.Tensor + target_class_index: int + source_class_index: int + path: Path = None + counterfactual_path: Path = None + source_class: str = None + target_class: str = None + + +@dataclass +class PairedSample: + image: torch.Tensor + counterfactual: torch.Tensor + source_class_index: int + target_class_index: int + path: Path = None + counterfactual_path: Path = None + source_class: str = None + target_class: str = None + + +@dataclass +class SampleWithAttribution: + attribution: np.ndarray + image: torch.Tensor + counterfactual: torch.Tensor + source_class_index: int + target_class_index: int + path: Path = None + counterfactual_path: Path = None + source_class: str = None + target_class: str = None + attribution_path: Path = None + + +class PairedImageDataset(Dataset): + def __init__( + self, source_directory, paired_directory, transform=None, allow_empty=True + ): """A dataset that loads images from paired directories, where one has images generated based on the other. @@ -138,16 +380,22 @@ def __init__(self, source_directory, paired_directory, transform=None): └── ... └── asd932_.ext ``` - Note that this will not work if the file names do not match! + note:: this will not work if the file names do not match! + + note:: the transform is applied sequentially to the image, counterfactual, and attribution. + This means that if there is any randomness in the transform, the three images will faie to match. + Additionally, the attribution will be a torch tensor when the transform is applied, so no PIL-only transforms + can be used. 
""" classes, class_to_idx = find_classes(source_directory) self.classes = classes self.class_to_idx = class_to_idx - self.samples = make_dataset( + self.samples = make_paired_dataset( source_directory, paired_directory, class_to_idx, is_valid_file=is_image_file, + allow_empty=allow_empty, ) self.transform = transform @@ -159,14 +407,111 @@ def __getitem__(self, index): if self.transform is not None: sample = self.transform(sample) target_sample = self.transform(target_sample) - output = { - "sample_path": path, - "target_path": target_path, - "sample": sample, - "target_sample": target_sample, - "class_index": class_index, - "target_class_index": target_class_index, - } + output = PairedSample( + path=Path(path), + counterfactual_path=Path(target_path), + image=sample, + counterfactual=target_sample, + source_class_index=class_index, + target_class_index=target_class_index, + source_class=self.classes[class_index], + target_class=self.classes[target_class_index], + ) + return output + + def __len__(self): + return len(self.samples) + + +class CounterfactualDataset(Dataset): + def __init__(self, counterfactual_directory, transform=None, allow_empty=True): + classes, class_to_idx = find_classes(counterfactual_directory) + self.classes = classes + self.class_to_idx = class_to_idx + self.samples = make_counterfactual_dataset( + counterfactual_directory, + class_to_idx, + is_valid_file=is_image_file, + allow_empty=allow_empty, + ) + self.transform = transform + + def __getitem__(self, index): + path, source_class_index, target_class_index = self.samples[index] + sample = default_loader(path) + if self.transform is not None: + sample = self.transform(sample) + output = CounterfactualSample( + counterfactual_path=Path(path), + counterfactual=sample, + source_class_index=source_class_index, + source_class=self.classes[source_class_index], + target_class_index=target_class_index, + target_class=self.classes[target_class_index], + ) + return output + + def __len__(self): + return len(self.samples) + + +class PairedWithAttribution(Dataset): + """This dataset returns both the original and counterfactual images, + as well as an attribution heatmap. + + note:: the transform is applied sequentially to the image, counterfactual. + This means that if there is any randomness in the transform, the images will fail to match. + Additionally, no transform is applied to the attribution. 
+ """ + + def __init__( + self, + source_directory, + paired_directory, + attribution_directory, + transform=None, + allow_empty=True, + ): + classes, class_to_idx = find_classes(source_directory) + self.classes = classes + self.class_to_idx = class_to_idx + self.samples = make_paired_attribution_dataset( + source_directory, + paired_directory, + attribution_directory, + class_to_idx, + is_valid_file=is_image_file, + allow_empty=allow_empty, + ) + self.transform = transform + + def __getitem__(self, index) -> SampleWithAttribution: + ( + path, + target_path, + attribution_path, + class_index, + target_class_index, + ) = self.samples[index] + sample = default_loader(path) + target_sample = default_loader(target_path) + attribution = np.load(attribution_path) + if self.transform is not None: + sample = self.transform(sample) + target_sample = self.transform(target_sample) + + output = SampleWithAttribution( + path=Path(path), + counterfactual_path=Path(target_path), + attribution_path=Path(attribution_path), + image=sample, + counterfactual=target_sample, + attribution=attribution, + source_class_index=class_index, + target_class_index=target_class_index, + source_class=self.classes[class_index], + target_class=self.classes[target_class_index], + ) return output def __len__(self): diff --git a/src/quac/evaluation.py b/src/quac/evaluation.py index c6058bd..4608176 100644 --- a/src/quac/evaluation.py +++ b/src/quac/evaluation.py @@ -1,7 +1,14 @@ +import cv2 +from dataclasses import dataclass import numpy as np +from pathlib import Path +from quac.data import PairedImageDataset, CounterfactualDataset, PairedWithAttribution +from quac.report import Report +from sklearn.metrics import classification_report, confusion_matrix +from torchvision.datasets import ImageFolder from torch.nn import functional as F -import cv2 import torch +from tqdm import tqdm def image_to_tensor(image, device=None): @@ -12,36 +19,236 @@ def image_to_tensor(image, device=None): image_tensor = image_tensor.unsqueeze(0).unsqueeze(0) elif len(np.shape(image)) == 3: image_tensor = image_tensor.unsqueeze(0) + elif len(np.shape(image)) == 4: + return image_tensor.float() else: - raise ValueError("Input shape not understood") + raise ValueError(f"Input shape not understood, {image.shape}") return image_tensor.float() -class Evaluator: - """This class evaluates the quality of an attribution using the QuAC method. 
+class Processor: + """Class that turns attributions into masks.""" + + def __init__( + self, gaussian_kernel_size=11, struc=10, channel_wise=True, name="default" + ): + self.gaussian_kernel_size = gaussian_kernel_size + self.kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (struc, struc)) + self.channel_wise = channel_wise + self.name = name + + def create_mask(self, attribution, threshold, return_size=True): + channels, _, _ = attribution.shape + mask_size = 0 + mask = [] + # construct mask channel by channel + for c in range(channels): + # threshold + if self.channel_wise: + channel_mask = attribution[c, :, :] > threshold + else: + channel_mask = np.any(attribution > threshold, axis=0) + # TODO explain the reasoning behind the morphological closing + # Morphological closing + channel_mask = cv2.morphologyEx( + channel_mask.astype(np.uint8), cv2.MORPH_CLOSE, self.kernel + ) + # TODO This might be misleading, given the blur afterwards + mask_size += np.sum(channel_mask) + # TODO Add connected components + # Blur + mask.append( + cv2.GaussianBlur( + channel_mask.astype(np.float32), + (self.gaussian_kernel_size, self.gaussian_kernel_size), + 0, + ) + ) + # TODO should we do this instead? + # mask_size += np.sum(mask) + if not return_size: + return np.array(mask) + return np.array(mask), mask_size + - It it is based on the assumption that there exists a counterfactual for each image. +class UnblurredProcessor(Processor): """ + Processor without any blurring + """ + + def __init__(self, struc=10, channel_wise=True, name="no_blur"): + self.kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (struc, struc)) + self.channel_wise = channel_wise + self.name = name + + def create_mask(self, attribution, threshold, return_size=True): + channels, _, _ = attribution.shape + mask_size = 0 + mask = [] + # construct mask channel by channel + for c in range(channels): + # threshold + if self.channel_wise: + channel_mask = attribution[c, :, :] > threshold + else: + channel_mask = np.any(attribution > threshold, axis=0) + # Morphological closing + channel_mask = cv2.morphologyEx( + channel_mask.astype(np.uint8), cv2.MORPH_CLOSE, self.kernel + ) + mask.append(channel_mask) + mask_size += np.sum(channel_mask) + if not return_size: + return np.array(mask) + return np.array(mask), mask_size + + +class BaseEvaluator: + """Base class for evaluating attributions.""" def __init__( self, classifier, - sigma=11, - struc=10, - channel_wise=False, + source_dataset=None, + paired_dataset=None, + attribution_dataset=None, num_thresholds=200, device=None, ): + """Initializes the evaluator. + + It requires three different datasets: the source dataset, the counterfactual dataset and the attribution dataset. + All of them must return objects in the forms of the dataclasses in `quac.data`. + + + Parameters + ---------- + classifier: + The classifier to be used for the evaluation. + source_dataset: + The source dataset must returns a `quac.data.Sample` object in its `__getitem__` method. + paired_dataset: + The paired dataset must returns a `quac.data.PairedSample` object in its `__getitem__` method. + attribution_dataset: + The attribution dataset must returns a `quac.data.SampleWithAttribution` object in its `__getitem__` method. 
+ """ self.device = device if device is None: self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.classifier = classifier.to(self.device) - self.sigma = sigma - self.struc = struc - self.channel_wise = channel_wise self.num_thresholds = num_thresholds + self.classifier = classifier + + self._source_dataset = source_dataset + self._paired_dataset = paired_dataset + self._dataset_with_attribution = attribution_dataset + + @property + def source_dataset(self): + return self._source_dataset + + @property + def paired_dataset(self): + return self._paired_dataset - def evaluate(self, x, x_t, y, y_t, attribution, predictions, vmin=-1, vmax=1): + @property + def dataset_with_attribution(self): + return self._dataset_with_attribution + + def _source_classification_report( + self, return_classification=False, print_report=True + ): + """ + Classify the source data and return the confusion matrix. + """ + pred = [] + target = [] + for image, source_class_index in tqdm(self.source_dataset): + pred.append(self.run_inference(image).argmax()) + target.append(source_class_index) + + if print_report: + print(classification_report(target, pred)) + + cm = confusion_matrix(target, pred, normalize="true") + if return_classification: + return cm, pred, target + return cm + + def _counterfactual_classification_report( + self, + return_classification=False, + print_report=True, + ): + """ + Classify the counterfactual data and return the confusion matrix. + """ + pred = [] + source = [] + target = [] + for sample in tqdm(self.counterfactual_dataset): + pred.append(self.run_inference(sample.counterfactual).argmax()) + target.append(sample.target_class_index) + source.append(sample.source_class_index) + + if print_report: + print(classification_report(target, pred)) + + cm = confusion_matrix(target, pred, normalize="true") + if return_classification: + return cm, pred, source, target + return cm + + def classification_report( + self, + data="counterfactuals", + return_classification=False, + print_report=True, + ): + """ + Classify the data and return the confusion matrix. + """ + if data == "counterfactuals": + return self._counterfactual_classification_report( + return_classification=return_classification, + print_report=print_report, + ) + elif data == "source": + return self._source_classification_report( + return_classification=return_classification, + print_report=print_report, + ) + else: + raise ValueError(f"Data must be 'counterfactuals' or 'source', not {data}") + + def quantify(self, processor=None): + if processor is None: + processor = Processor() + report = Report(name=processor.name) + for inputs in tqdm(self.dataset_with_attribution): + predictions = { + "original": self.run_inference(inputs.image)[0], + "counterfactual": self.run_inference(inputs.counterfactual)[0], + } + results = self.evaluate( + inputs.image, + inputs.counterfactual, + inputs.source_class_index, + inputs.target_class_index, + inputs.attribution, + predictions, + processor, + ) + report.accumulate( + inputs, + predictions, + results, + ) + return report + + def evaluate( + self, x, x_t, y, y_t, attribution, predictions, processor, vmin=-1, vmax=1 + ): """ Run QuAC evaluation on the data point. 
@@ -51,7 +258,9 @@ def evaluate(self, x, x_t, y, y_t, attribution, predictions, vmin=-1, vmax=1): x_t: the counterfactual image y: the class of the input image y_t: the class of the counterfactual image - attrihbution: the attribution map + attribution: the attribution map + predictions: the predictions of the classifier + processor: the attribution processing function (to get mask) vmin: the minimal possible value of the attribution, to be used for thresholding. Defaults to -1 vmax: the maximal possible value of the attribution, to be used for thresholding. Defaults to 1. """ @@ -59,6 +268,7 @@ def evaluate(self, x, x_t, y, y_t, attribution, predictions, vmin=-1, vmax=1): # "real" changes into "fake_class" classification_real = predictions["original"] + # TODO remove the need for this results = { "thresholds": [], "hybrids": [], @@ -67,21 +277,23 @@ def evaluate(self, x, x_t, y, y_t, attribution, predictions, vmin=-1, vmax=1): } for threshold in np.arange(vmin, vmax, (vmax - vmin) / self.num_thresholds): # soft mask of the parts to copy - mask, mask_size = self.create_mask(attribution, threshold) - + mask, mask_size = processor.create_mask(attribution, threshold) # hybrid = real parts copied to fake hybrid = x_t * mask + x * (1.0 - mask) - - classification_hybrid = self.run_inference(hybrid)[0] - - score_change = classification_hybrid[y_t] - classification_real[y_t] - # Append results # TODO Do we want to store the hybrid? results["thresholds"].append(threshold) results["hybrids"].append(hybrid) results["mask_sizes"].append(mask_size / np.prod(x.shape)) - results["score_change"].append(score_change) + + # Classification + # classification_hybrid = self.run_inference(hybrid)[0] + # score_change = classification_hybrid[y_t] - classification_real[y_t] + # results["score_change"].append(score_change) + hybrid = np.stack(results["hybrids"], axis=0) + classification_hybrid = self.run_inference(hybrid) + score_change = classification_hybrid[:, y_t] - classification_real[y_t] + results["score_change"] = score_change return results @torch.no_grad() @@ -94,30 +306,71 @@ def run_inference(self, im): class_probs = F.softmax(self.classifier(im_tensor), dim=1).cpu().numpy() return class_probs - def create_mask(self, attribution, threshold): - channels, _, _ = attribution.shape - # TODO find a way to get rid of this - kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (self.struc, self.struc)) - mask_size = 0 - mask = [] - # construct mask channel by channel - for c in range(channels): - # threshold - if self.channel_wise: - channel_mask = attribution[c, :, :] > threshold - else: - channel_mask = np.any(attribution > threshold, axis=0) - # morphological closing - channel_mask = cv2.morphologyEx( - channel_mask.astype(np.uint8), cv2.MORPH_CLOSE, kernel - ) - # TODO This might be misleading, given the blur afterwards - mask_size += np.sum(channel_mask) - # TODO Add connected components - # Blur - mask.append( - cv2.GaussianBlur( - channel_mask.astype(np.float32), (self.sigma, self.sigma), 0 - ) - ) - return np.array(mask), mask_size + +class Evaluator(BaseEvaluator): + """This class evaluates the quality of an attribution using the QuAC method. + + Raises: + FileNotFoundError: If the source, counterfactual or attribution directories do not exist. 
+ """ + + def __init__( + self, + classifier, + source_directory, + counterfactual_directory, + attribution_directory, + transform=None, + num_thresholds=200, + device=None, + ): + # Check that they all exist + for directory in [ + source_directory, + counterfactual_directory, + attribution_directory, + ]: + if not Path(directory).exists(): + raise FileNotFoundError(f"Directory {directory} does not exist") + + super().__init__( + classifier, None, None, None, num_thresholds=num_thresholds, device=device + ) + self.transform = transform + self.source_directory = source_directory + self.counterfactual_directory = counterfactual_directory + self.attribution_directory = attribution_directory + + @property + def source_dataset(self): + # NOTE: Recomputed each time, but should be used sparingly. + dataset = ImageFolder(self.source_directory, transform=self.transform) + return dataset + + @property + def counterfactual_dataset(self): + # NOTE: Recomputed each time, but should be used sparingly. + dataset = CounterfactualDataset( + self.counterfactual_directory, transform=self.transform + ) + return dataset + + @property + def paired_dataset(self): + # NOTE: Recomputed each time, but should be used sparingly. + dataset = PairedImageDataset( + self.source_directory, + self.counterfactual_directory, + transform=self.transform, + ) + return dataset + + @property + def dataset_with_attribution(self): + dataset = PairedWithAttribution( + self.source_directory, + self.counterfactual_directory, + self.attribution_directory, + transform=self.transform, + ) + return dataset diff --git a/src/quac/generate/__init__.py b/src/quac/generate/__init__.py new file mode 100644 index 0000000..e81180c --- /dev/null +++ b/src/quac/generate/__init__.py @@ -0,0 +1,260 @@ +"""Utilities for generating counterfactual images.""" + +from .model import LatentInferenceModel, ReferenceInferenceModel +from .data import LabelFreePngFolder + +import logging +from quac.training.classification import ClassifierWrapper +import torch +from torchvision import transforms +from typing import Union + +logging.basicConfig(level=logging.WARNING) +logger = logging.getLogger(__name__) + + +class CounterfactualNotFound(Exception): + pass + + +def load_classifier( + checkpoint, mean=0.5, std=0.5, eval=True, assume_normalized=False, device=None +): + """ + Load a classifier from a torchscript checkpoint. + + This also creates a wrapper around the classifier, which normalizes the input. + The classifier expects the input range to be [-1, 1], and normalizes it with the give `mean` and `std`. + + Parameters: + checkpoint: the path to the checkpoint + mean: the mean to normalize the input + std: the standard deviation to normalize + eval: whether to put the classifier in evaluation mode, defaults to True + device: the device to use, defaults to None + """ + classifier = ClassifierWrapper(checkpoint, mean=mean, std=std) + if device: + classifier.to(device) + if eval: + classifier.eval() + return classifier + + +def load_data( + data_directory, img_size, grayscale=True, mean=0.5, std=0.5 +) -> LabelFreePngFolder: + """ + Load a dataset from a directory. + + This assumes that the images are in a folder, with no subfolders, and no labels. + The images are resized to `img_size`, and normalized with the given `mean` and `std`. + If `grayscale` is True, the images are converted to grayscale. + + The returned dataset will return the image file name as the second element of the tuple. 
+ + Parameters: + data_directory: the directory to load the images from + img_size: the size to resize the images to + grayscale: whether to convert the images to grayscale, defaults to True + mean: the mean to normalize the images, defaults to 0.5 + std: the standard deviation to normalize the images, defaults to 0.5 + """ + dataset = LabelFreePngFolder( + root=data_directory, + transform=transforms.Compose( + [ + transforms.Resize([img_size, img_size]), + transforms.Grayscale() if grayscale else transforms.Lambda(lambda x: x), + transforms.ToTensor(), + transforms.Normalize(mean=mean, std=std), + ] + ), + ) + return dataset + + +def load_stargan( + latent_model_checkpoint_dir: str, + img_size: int = 128, + input_dim: int = 1, + style_dim: int = 64, + latent_dim: int = 16, + num_domains: int = 6, + checkpoint_iter: int = 100000, + kind="latent", + single_output_encoder: bool = False, + final_activation: Union[str, None] = None, +) -> torch.nn.Module: + """ + Load an inference version of the StarGANv2 model from a checkpoint. + + Parameters: + latent_model_checkpoint_dir: the directory of the checkpoint + img_size: the size of the model input + input_dim: the number of input channels + style_dim: the dimension of the style + latent_dim: the dimension of the latent space + num_domains: the number of domains + checkpoint_iter: the iteration of the checkpoint to load + kind: the kind of style to use, either "latent" or "reference" + single_output_encoder: whether to use a single output encoder, only used if kind is "reference" + + Returns: + the loaded inference model + """ + if kind == "reference": + latent_inference_model = ReferenceInferenceModel( + checkpoint_dir=latent_model_checkpoint_dir, + img_size=img_size, + input_dim=input_dim, + style_dim=style_dim, + latent_dim=latent_dim, + num_domains=num_domains, + single_output_encoder=single_output_encoder, + final_activation=final_activation, + ) + else: + latent_inference_model = LatentInferenceModel( + checkpoint_dir=latent_model_checkpoint_dir, + img_size=img_size, + input_dim=input_dim, + style_dim=style_dim, + latent_dim=latent_dim, + num_domains=num_domains, + final_activation=final_activation, + ) + latent_inference_model.load_checkpoint(checkpoint_iter) + latent_inference_model.eval() + return latent_inference_model + + +@torch.no_grad() +def get_counterfactual( + classifier, + latent_inference_model, + x, + target, + kind="latent", # or "reference" + dataset_ref=None, + batch_size=10, + device=None, + max_tries=100, + best_pred_so_far=None, + best_cf_so_far=None, + best_cf_path_so_far=None, + error_if_not_found=False, + return_path=False, + return_pred=False, +) -> torch.Tensor: + """ + Tries to find a counterfactual for the given sample, given the target. + It creates a batch, and returns one of the samples if it is classified correctly. 
+ + Parameters: + classifier: the classifier to use + latent_inference_model: the latent inference model to use + x: the sample to find a counterfactual for + target: the target class + kind: the kind of style to use, either "latent" or "reference" + dataset_ref: the dataset of reference images to use, required if kind is "reference" + batch_size: the number of counterfactuals to generate + device: the device to use + max_tries: the maximum number of tries to find a counterfactual + error_if_not_found: whether to raise an error if no counterfactual is found, if set to False, the best counterfactual found so far is returned + return_path: whether to return the path of the reference used to create best counterfactual found so far, + only used if kind is "reference" + + Returns: + a counterfactual + + Raises: + CounterfactualNotFound: if no counterfactual is found after max_tries tries + """ + if best_pred_so_far is None: + best_pred_so_far = torch.zeros(target + 1) + # Copy x batch_size times + x_multiple = torch.stack([x] * batch_size) + if kind == "reference": + assert ( + dataset_ref is not None + ), "Reference dataset required for reference style." + if len(dataset_ref) // batch_size < max_tries: + max_tries = len(dataset_ref) // batch_size + logger.warning( + f"Not enough reference images, reducing max_tries to {max_tries}." + ) + # Get a batch of reference images, starting from batch_size * max_tries, of size batch_size + ref_batch, ref_paths = zip( + *[ + dataset_ref[i] + for i in range(batch_size * (max_tries - 1), batch_size * max_tries) + ] + ) + ref_batch = torch.stack(ref_batch) + # Generate batch_size counterfactuals + xcf = latent_inference_model( + x_multiple.to(device), + ref_batch.to(device), + torch.tensor([target] * batch_size).to(device), + ) + else: # kind == "latent" + # Generate batch_size counterfactuals from random latents + xcf = latent_inference_model( + x_multiple.to(device), + torch.tensor([target] * batch_size).to(device), + ) + + # Evaluate the counterfactuals + p = torch.softmax(classifier(xcf), dim=-1) + # Get the predictions + predictions = torch.argmax(p, dim=-1) + # Get best so far + best_idx_so_far = torch.argmax(p[:, target]) + if p[best_idx_so_far, target] > best_pred_so_far[target]: + best_pred_so_far = p[best_idx_so_far] # , target] + best_cf_so_far = xcf[best_idx_so_far].cpu() + if kind == "reference": + best_cf_path_so_far = ref_paths[best_idx_so_far] + else: + best_cf_path_so_far = None + # Get the indices of the correct predictions + indices = torch.where(predictions == target)[0] + + if len(indices) == 0: + if max_tries > 0: + logger.info( + f"Counterfactual not found, trying again. {max_tries} tries left." + ) + return get_counterfactual( + classifier, + latent_inference_model, + x, + target, + kind, + dataset_ref, + batch_size, + device, + max_tries - 1, + best_pred_so_far=best_pred_so_far, + best_cf_so_far=best_cf_so_far, + best_cf_path_so_far=best_cf_path_so_far, + return_path=return_path, + return_pred=return_pred, + ) + else: + if error_if_not_found: + raise CounterfactualNotFound( + "Counterfactual not found after max_tries tries." + ) + logger.info( + f"Counterfactual not found after {max_tries} tries, using best so far." 
+ ) + # Return the best counterfactual so far + if return_path and kind == "reference": + if return_pred: + return best_cf_so_far, best_cf_path_so_far, best_pred_so_far + return best_cf_so_far, best_cf_path_so_far + if return_pred: + return best_cf_so_far, best_pred_so_far + return best_cf_so_far diff --git a/src/quac/generate/data.py b/src/quac/generate/data.py new file mode 100644 index 0000000..0e08406 --- /dev/null +++ b/src/quac/generate/data.py @@ -0,0 +1,34 @@ +from PIL import Image +from pathlib import Path +import torch + + +class LabelFreePngFolder(torch.utils.data.Dataset): + # TODO Move to quac.data + """Get all images in a folder, no subfolders, no labels.""" + + def __init__(self, root, transform=None): + super().__init__() + self.root = Path(root) + self.transform = transform + self.samples = [ + path + for path in self.root.iterdir() + if path.is_file() + and path.name.endswith(".png") + or path.name.endswith(".jpg") + ] + assert len(self.samples) > 0, f"No images found in {self.root}." + + def load_image(self, path): + return Image.open(path) + + def __getitem__(self, index): + path = self.samples[index] + sample = self.load_image(path) + if self.transform is not None: + sample = self.transform(sample) + return sample, path.name + + def __len__(self): + return len(self.samples) diff --git a/src/quac/generate/model.py b/src/quac/generate/model.py new file mode 100644 index 0000000..d2ad1b5 --- /dev/null +++ b/src/quac/generate/model.py @@ -0,0 +1,105 @@ +"""Reduces the model into just want is needed for inference.""" + +from os.path import join as ospj +from quac.training.stargan import ( + Generator, + MappingNetwork, + StyleEncoder, + SingleOutputStyleEncoder, +) +from quac.training.checkpoint import CheckpointIO +import torch + + +class LatentInferenceModel(torch.nn.Module): + def __init__( + self, + checkpoint_dir, + img_size, + style_dim, + latent_dim, + input_dim=1, + num_domains=6, + final_activation=None, + ) -> None: + super().__init__() + generator = Generator( + img_size, style_dim, input_dim=input_dim, final_activation=final_activation + ) + mapping_network = MappingNetwork(latent_dim, style_dim, num_domains=num_domains) + + self.nets = torch.nn.ModuleDict( + { + "generator": generator, + "mapping_network": mapping_network, + } + ) + + self.checkpoint_io = CheckpointIO( + ospj(checkpoint_dir, "{:06d}_nets_ema.ckpt"), + data_parallel=False, + **self.nets, + ) + + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.nets.to(self.device) + self.latent_dim = latent_dim + self.style_dim = style_dim + + def load_checkpoint(self, step): + self.checkpoint_io.load(step) + + def forward(self, x_src, y_trg): + z = torch.randn(x_src.size(0), self.latent_dim).to(self.device) + s = self.nets.mapping_network(z, y_trg) + x_fake = self.nets.generator(x_src, s) + return x_fake + + +class ReferenceInferenceModel(torch.nn.Module): + def __init__( + self, + checkpoint_dir, + img_size, + style_dim=64, + latent_dim=16, + input_dim=1, + num_domains=6, + single_output_encoder=False, + final_activation=None, + ) -> None: + super().__init__() + generator = Generator( + img_size, style_dim, input_dim=input_dim, final_activation=final_activation + ) + if single_output_encoder: + style_encoder = SingleOutputStyleEncoder( + img_size, style_dim, num_domains, input_dim=input_dim + ) + else: + style_encoder = StyleEncoder( + img_size, style_dim, num_domains, input_dim=input_dim + ) + + self.nets = torch.nn.ModuleDict( + {"generator": generator, "style_encoder": 
style_encoder} + ) + + self.checkpoint_io = CheckpointIO( + ospj(checkpoint_dir, "{:06d}_nets_ema.ckpt"), + data_parallel=False, + **self.nets, + ) + + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.nets.to(self.device) + self.latent_dim = latent_dim + self.style_dim = style_dim + + def load_checkpoint(self, step): + self.checkpoint_io.load(step) + + def forward(self, x_src, x_ref, y_trg): + s = self.nets.style_encoder(x_ref, y_trg) + x_fake = self.nets.generator(x_src, s) + return x_fake diff --git a/src/quac/report.py b/src/quac/report.py index 669240a..e20fb2a 100644 --- a/src/quac/report.py +++ b/src/quac/report.py @@ -5,6 +5,7 @@ from scipy.interpolate import interp1d import torch from tqdm import tqdm +import warnings class Report: @@ -19,14 +20,12 @@ class Report: Optionally, it also stores the hybrids generated for each threshold. """ - def __init__(self, save_dir, name=None, metadata={}): - # TODO Set up all of the data in the namespace - self.save_dir = Path(save_dir) + def __init__(self, name=None, metadata={}): if name is None: - self.name = "attribution" + self.name = "report" else: self.name = name - self.attribution_dir = None + # Shows where the attribution is, if needed # TODO check that the metadata is JSON serializable self.metadata = metadata # Initialize as empty @@ -47,31 +46,18 @@ def __init__(self, save_dir, name=None, metadata={}): # Initialize interpolation values self.interp_mask_values = np.arange(0.0, 1.0001, 0.01) - def make_attribution_dir(self): - """Create a directory to store the attributions""" - if self.attribution_dir is None: - self.attribution_dir = self.save_dir / self.name - self.attribution_dir.mkdir(parents=True, exist_ok=True) - - def accumulate( - self, - inputs, - predictions, - attribution, - evaluation_results, - save_attribution=True, - save_intermediates=False, - ): + def accumulate(self, inputs, predictions, evaluation_results): """ Store a new result. If `save_intermediates` is `True`, the hybrids are stored to disk. Otherwise they are discarded. 
""" # Store the input information - self.paths.append(inputs["sample_path"]) - self.target_paths.append(inputs["target_path"]) - self.labels.append(inputs["class_index"]) - self.target_labels.append(inputs["target_class_index"]) + self.paths.append(inputs.path) + self.target_paths.append(inputs.counterfactual_path) + self.labels.append(inputs.source_class_index) + self.target_labels.append(inputs.target_class_index) + self.attribution_paths.append(inputs.attribution_path) # Store the prediction results self.predictions.append(predictions["original"]) self.target_predictions.append(predictions["counterfactual"]) @@ -79,27 +65,7 @@ def accumulate( self.thresholds.append(evaluation_results["thresholds"]) self.normalized_mask_sizes.append(evaluation_results["mask_sizes"]) self.score_changes.append(evaluation_results["score_change"]) - # Store the attribution to disk - if save_attribution: - self.make_attribution_dir() - filename = Path(inputs["sample_path"]).stem - attribution_path = ( - self.attribution_dir / f"{filename}_{inputs['target_class_index']}.npy" - ) - with open(attribution_path, "wb") as fd: - np.save(fd, attribution) - self.attribution_paths.append(attribution_path) - - # Store the hybrids to disk - if save_intermediates: - for h in evaluation_results["hybrids"]: - # TODO store the hybrids - pass - - def load_attribution(self, index): - """Load the attribution for a given index""" - with open(self.attribution_paths[index], "rb") as fd: - return np.load(fd) + # TODO Store the hybrids to disk ? def interpolate_score_values(self, normalized_mask_sizes, score_changes): """Computes the score changes interpolated at the desired mask sizes""" @@ -150,12 +116,16 @@ def make_json_serializable(self, obj): return [self.make_json_serializable(x) for x in obj] return obj - def store(self): + def store(self, save_dir): """Store report to disk""" - self.save_dir.mkdir(parents=True, exist_ok=True) - with open(self.save_dir / f"{self.name}.json", "w") as fd: + if self.quac_scores is None: + self.compute_scores() + save_dir = Path(save_dir) + save_dir.mkdir(parents=True, exist_ok=True) + with open(save_dir / f"{self.name}.json", "w") as fd: json.dump( { + "metadata": self.metadata, "thresholds": self.make_json_serializable(self.thresholds), "normalized_mask_sizes": self.make_json_serializable( self.normalized_mask_sizes @@ -172,6 +142,7 @@ def store(self): "attribution_paths": self.make_json_serializable( self.attribution_paths ), + "quac_scores": self.make_json_serializable(self.quac_scores), }, fd, ) @@ -180,6 +151,7 @@ def load(self, filename): """Load report from disk""" with open(filename, "r") as fd: data = json.load(fd) + self.metadata = data.get("metadata", {}) self.thresholds = data["thresholds"] self.normalized_mask_sizes = data["normalized_mask_sizes"] self.score_changes = data["score_changes"] @@ -190,11 +162,31 @@ def load(self, filename): self.predictions = data.get("predictions", []) self.target_predictions = data.get("target_predictions", []) self.attribution_paths = data.get("attribution_paths", []) + self.quac_scores = data.get("quac_scores", None) + + def get_curve(self): + """Gets the median and IQR of the QuAC curve""" + # TODO Cache the results, takes forever otherwise + plot_values = [] + for normalized_mask_sizes, score_changes in zip( + self.normalized_mask_sizes, self.score_changes + ): + interp_score_values = self.interpolate_score_values( + normalized_mask_sizes, score_changes + ) + plot_values.append(interp_score_values) + plot_values = np.array(plot_values) + # mean = 
np.mean(plot_values, axis=0) + # std = np.std(plot_values, axis=0) + median = np.median(plot_values, axis=0) + p25 = np.percentile(plot_values, 25, axis=0) + p75 = np.percentile(plot_values, 75, axis=0) + return median, p25, p75 def plot_curve(self, ax=None): """Plot the QuAC curve - We plot the mean and standard deviation of the QuAC curve acrosss all accumulated results. + We plot the median and IQR of the QuAC curve acrosss all accumulated results. Parameters ---------- @@ -204,28 +196,56 @@ def plot_curve(self, ax=None): if ax is None: fig, ax = plt.subplots() - plot_values = [] - for normalized_mask_sizes, score_changes in zip( - self.normalized_mask_sizes, self.score_changes - ): - interp_score_values = self.interpolate_score_values( - normalized_mask_sizes, score_changes - ) - plot_values.append(interp_score_values) - plot_values = np.array(plot_values) - mean = np.mean(plot_values, axis=0) - std = np.std(plot_values, axis=0) + mean, p25, p75 = self.get_curve() + ax.plot(self.interp_mask_values, mean, label=self.name) - ax.fill_between(self.interp_mask_values, mean - std, mean + std, alpha=0.2) + ax.fill_between(self.interp_mask_values, p25, p75, alpha=0.2) if ax is None: plt.show() + def optimal_thresholds(self, min_percentage=0.0): + """Get the optimal threshold for each sample + + The optimal threshold has a minimal mask size, and maximizes the score change. + We optimize $|m| - \delta f$ where $m$ is the mask size and $\delta f$ is the score change. + + Parameters + ---------- + min_percentage: float + The optimal threshold chosen needs to account for at least this percentage of total score change. + Increasing this value will favor high percentage changes even when they require larger masks. + """ + mask_scores = np.array(self.score_changes) + mask_sizes = np.array(self.normalized_mask_sizes) + thresholds = np.array(self.thresholds) + tradeoff_scores = np.abs(mask_sizes) - mask_scores + # Determine what to ignore + if min_percentage > 0.0: + min_value = np.min(mask_scores, axis=1) + max_value = np.max(mask_scores, axis=1) + threshold = min_value + min_percentage * (max_value - min_value) + below_threshold = mask_scores < threshold[:, None] + tradeoff_scores[ + below_threshold + ] = np.inf # Ignores the points with not enough score change + thr_idx = np.argmin(tradeoff_scores, axis=1) + + optimal_thresholds = np.take_along_axis( + thresholds, thr_idx[:, None], axis=1 + ).squeeze() + return optimal_thresholds + def get_optimal_threshold(self, index, return_index=False): + # TODO Deprecate, use vectorized version! + warnings.warn( + "This function is deprecated, please use the vectorized version instead.", + DeprecationWarning, + ) mask_scores = np.array(self.score_changes[index]) mask_sizes = np.array(self.normalized_mask_sizes[index]) pareto_scores = mask_sizes**2 + (1 - mask_scores) ** 2 thr_idx = np.argmin(pareto_scores) if return_index: - return thr_idx, self.thresholds[index][thr_idx] + return self.thresholds[index][thr_idx], thr_idx return self.thresholds[index][thr_idx] diff --git a/src/quac/training/checkpoint.py b/src/quac/training/checkpoint.py new file mode 100644 index 0000000..db51ceb --- /dev/null +++ b/src/quac/training/checkpoint.py @@ -0,0 +1,50 @@ +""" +StarGAN v2 +Copyright (c) 2020-present NAVER Corp. + +This work is licensed under the Creative Commons Attribution-NonCommercial +4.0 International License. 
To view a copy of this license, visit +http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to +Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. +""" + +import os +import torch + + +class CheckpointIO(object): + def __init__(self, fname_template, data_parallel=False, **kwargs): + os.makedirs(os.path.dirname(fname_template), exist_ok=True) + self.fname_template = fname_template + self.module_dict = kwargs + self.data_parallel = data_parallel + + def register(self, **kwargs): + self.module_dict.update(kwargs) + + def save(self, step): + fname = self.fname_template.format(step) + print("Saving checkpoint into %s..." % fname) + outdict = {} + for name, module in self.module_dict.items(): + if self.data_parallel: + outdict[name] = module.module.state_dict() + else: + outdict[name] = module.state_dict() + + torch.save(outdict, fname) + + def load(self, step): + fname = self.fname_template.format(step) + assert os.path.exists(fname), fname + " does not exist!" + print("Loading checkpoint from %s..." % fname) + if torch.cuda.is_available(): + module_dict = torch.load(fname) + else: + module_dict = torch.load(fname, map_location=torch.device("cpu")) + + for name, module in self.module_dict.items(): + if self.data_parallel: + module.module.load_state_dict(module_dict[name]) + else: + module.load_state_dict(module_dict[name]) diff --git a/src/quac/training/classification.py b/src/quac/training/classification.py new file mode 100644 index 0000000..cf392a1 --- /dev/null +++ b/src/quac/training/classification.py @@ -0,0 +1,42 @@ +import torch +from torchvision import transforms + + +class Identity(torch.nn.Module): + def forward(self, x): + return x + + +class ClassifierWrapper(torch.nn.Module): + """ + This class expects a torchscript model. See [here](https://pytorch.org/tutorials/beginner/saving_loading_models.html#export-load-model-in-torchscript-format) + for how to convert a model to torchscript. 
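+
+    A minimal usage sketch (the checkpoint path and normalization values below
+    are illustrative):
+
+        classifier = ClassifierWrapper("classifier.pt", mean=(0.5,), std=(0.5,))
+        logits = classifier(x)  # x is expected to be in [-1, 1]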
+ """ + + def __init__( + self, + model_checkpoint, + mean=None, + std=None, + assume_normalized=False, + do_nothing=False, + ): + """Wraps a torchscript model, and applies normalization.""" + super().__init__() + self.model = torch.jit.load(model_checkpoint) + self.model.eval() + self.transform = transforms.Normalize(mean, std) + if mean is None: + self.transform = Identity() + self.assume_normalized = assume_normalized # TODO Remove this, it's in forward + self.do_nothing = do_nothing + + def forward(self, x, assume_normalized=False, do_nothing=False): + """Assumes that x is between -1 and 1.""" + if do_nothing or self.do_nothing: + return self.model(x) + # TODO it would be even better if the range was between 0 and 1 so we wouldn't have to do the below + if not self.assume_normalized and not assume_normalized: + x = (x + 1) / 2 + x = self.transform(x) + return self.model(x) diff --git a/src/quac/training/config.py b/src/quac/training/config.py new file mode 100644 index 0000000..d40c846 --- /dev/null +++ b/src/quac/training/config.py @@ -0,0 +1,88 @@ +from pydantic import BaseModel +from typing import Optional, Union, Literal + + +class ModelConfig(BaseModel): + img_size: int = 128 + style_dim: int = 64 + latent_dim: int = 16 + num_domains: int = 5 + input_dim: int = 3 + final_activation: str = "tanh" + + +class DataConfig(BaseModel): + source: str + reference: str + img_size: int = 128 + batch_size: int = 1 + num_workers: int = 4 + grayscale: bool = False + mean: Optional[float] = 0.5 + std: Optional[float] = 0.5 + rand_crop_prob: Optional[float] = 0 + + +class RunConfig(BaseModel): + resume_iter: int = 0 + total_iters: int = 100000 + log_every: int = 1000 + save_every: int = 10000 + eval_every: int = 10000 + + +class ValConfig(BaseModel): + classifier_checkpoint: str + num_outs_per_domain: int = 10 + mean: Optional[float] = 0.5 + std: Optional[float] = 0.5 + img_size: int = 128 + val_batch_size: int = 16 + assume_normalized: bool = False + do_nothing: bool = False + + +class LossConfig(BaseModel): + lambda_ds: float = 0.0 # No diversity by default + lambda_sty: float = 1.0 + lambda_cyc: float = 1.0 + lambda_reg: float = 1.0 + ds_iter: int = 100000 + + +class SolverConfig(BaseModel): + root_dir: str + f_lr: float = 1e-6 + lr: float = 1e-4 + beta1: float = 0.0 + beta2: float = 0.99 + weight_decay: float = 1e-4 + + +class WandBLogConfig(BaseModel): + project: str = "default" + name: str = "default" + notes: str = "" + tags: list = [] + + +class TensorboardLogConfig(BaseModel): + log_dir: str + comment: str = "" + + +class ExperimentConfig(BaseModel): + # Metadata for keeping track of experiments + log_type: Literal["wandb", "tensorboard"] = "wandb" + log: Union[WandBLogConfig, TensorboardLogConfig] = WandBLogConfig() + # Some input required + data: DataConfig + solver: SolverConfig + validation_data: DataConfig + validation_config: ValConfig + # No input required + model: ModelConfig = ModelConfig() + run: RunConfig = RunConfig() + loss: LossConfig = LossConfig() + # Optional + test_data: Optional[DataConfig] = None diff --git a/src/quac/training/data_loader.py b/src/quac/training/data_loader.py new file mode 100644 index 0000000..7e944e4 --- /dev/null +++ b/src/quac/training/data_loader.py @@ -0,0 +1,566 @@ +""" +StarGAN v2 +Copyright (c) 2020-present NAVER Corp. + +This work is licensed under the Creative Commons Attribution-NonCommercial +4.0 International License. 
To view a copy of this license, visit +http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to +Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. +""" + +from pathlib import Path +from itertools import chain +import glob +import os +import random + +from munch import Munch +from PIL import Image +import numpy as np + +import torch +from torch.utils import data +from torch.utils.data.sampler import WeightedRandomSampler +from torchvision import transforms +from torchvision.datasets import ImageFolder + + +class RGB: + def __call__(self, img): + if isinstance(img, Image.Image): + return img.convert("RGB") + else: # Tensor + if img.size(0) == 1: + return torch.cat([img, img, img], dim=0) + return img + + +def listdir(dname): + fnames = list( + chain( + *[ + list(Path(dname).rglob("*." + ext)) + for ext in ["png", "jpg", "jpeg", "JPG"] + ] + ) + ) + return fnames + + +class DefaultDataset(data.Dataset): + def __init__(self, root, transform=None): + self.samples = listdir(root) + self.samples.sort() + self.transform = transform + self.targets = None + + def __getitem__(self, index): + fname = self.samples[index] + img = Image.open(fname) + if self.transform is not None: + img = self.transform(img) + return img + + def __len__(self): + return len(self.samples) + + +class AugmentedDataset(data.Dataset): + """Adds an augmented version of the input to the sample.""" + + def __init__(self, root, transform=None, augment=None): + self.samples, self.targets = self._make_dataset(root) + self.transform = transform + if augment is None: + # Default augmentation: random horizontal flip, random vertical flip + augment = transforms.Compose( + [ + transforms.RandomHorizontalFlip(), + transforms.RandomVerticalFlip(), + ] + ) + self.augment = augment + + def _make_dataset(self, root): + domains = glob.glob(os.path.join(root, "*")) + fnames, labels = [], [] + for idx, domain in enumerate(sorted(domains)): + class_dir = os.path.join(root, domain) + cls_fnames = listdir(class_dir) + fnames += cls_fnames + labels += [idx] * len(cls_fnames) + return fnames, labels + + def __getitem__(self, index): + fname = self.samples[index] + label = self.targets[index] + img = Image.open(fname) + img2 = self.augment(img) + if self.transform is not None: + img = self.transform(img) + img2 = self.transform(img2) + return img, img2, label + + def __len__(self): + return len(self.targets) + + +class ReferenceDataset(data.Dataset): + def __init__(self, root, transform=None): + self.samples, self.targets = self._make_dataset(root) + self.transform = transform + + def _make_dataset(self, root): + domains = glob.glob(os.path.join(root, "*")) + fnames, fnames2, labels = [], [], [] + for idx, domain in enumerate(sorted(domains)): + class_dir = os.path.join(root, domain) + cls_fnames = listdir(class_dir) + fnames += cls_fnames + fnames2 += random.sample(cls_fnames, len(cls_fnames)) + labels += [idx] * len(cls_fnames) + return list(zip(fnames, fnames2)), labels + + def __getitem__(self, index): + fname, fname2 = self.samples[index] + label = self.targets[index] + img = Image.open(fname) + img2 = Image.open(fname2) + if self.transform is not None: + img = self.transform(img) + img2 = self.transform(img2) + return img, img2, label + + def __len__(self): + return len(self.targets) + + +def _make_balanced_sampler(labels): + class_counts = np.bincount(labels) + assert np.all(class_counts > 0), f"Some of the classes are empty. 
{class_counts}" + class_weights = 1.0 / class_counts + weights = class_weights[labels] + return WeightedRandomSampler(weights, len(weights)) + + +def get_train_loader( + root, + which="source", + img_size=256, + batch_size=8, + prob=0.5, + num_workers=4, + grayscale=False, + mean=0.5, + std=0.5, +): + print( + "Preparing DataLoader to fetch %s images " + "during the training phase..." % which + ) + + crop = transforms.RandomResizedCrop(img_size, scale=[0.8, 1.0], ratio=[0.9, 1.1]) + rand_crop = transforms.Lambda(lambda x: crop(x) if random.random() < prob else x) + + transform_list = [rand_crop] + if grayscale: + transform_list.append(transforms.Grayscale()) + else: + transform_list.append(RGB()) + + transform_list += [ + transforms.Resize([img_size, img_size]), + transforms.RandomHorizontalFlip(), + transforms.RandomVerticalFlip(), + transforms.ToTensor(), + ] + if mean is not None and std is not None: + transform_list.append(transforms.Normalize(mean=mean, std=std)) + transform = transforms.Compose(transform_list) + + if which == "source": + # dataset = ImageFolder(root, transform) + dataset = AugmentedDataset(root, transform) + elif which == "reference": + dataset = ReferenceDataset(root, transform) + else: + raise NotImplementedError + + sampler = _make_balanced_sampler(dataset.targets) + return data.DataLoader( + dataset=dataset, + batch_size=batch_size, + sampler=sampler, + num_workers=num_workers, + pin_memory=True, + drop_last=True, + ) + + +def get_eval_loader( + root, + img_size=256, + batch_size=32, + imagenet_normalize=False, + shuffle=True, + num_workers=4, + drop_last=False, + grayscale=False, + mean=0.5, + std=0.5, +): + print("Preparing DataLoader for the evaluation phase...") + if imagenet_normalize: + height, width = 299, 299 + mean = [0.485, 0.456, 0.406] + std = [0.229, 0.224, 0.225] + else: + height, width = img_size, img_size + + if mean is not None: + normalize = transforms.Normalize(mean=mean, std=std) + else: + normalize = transforms.Lambda(lambda x: x) + + transform_list = [] + if grayscale: + transform_list.append(transforms.Grayscale()) + else: + transform_list.append(RGB()) + + transform = transforms.Compose( + [ + *transform_list, + transforms.Resize([height, width]), + transforms.ToTensor(), + normalize, + ] + ) + + dataset = DefaultDataset(root, transform=transform) + return data.DataLoader( + dataset=dataset, + batch_size=batch_size, + shuffle=shuffle, + num_workers=num_workers, + pin_memory=True, + drop_last=drop_last, + ) + + +def get_test_loader( + root, + img_size=256, + batch_size=32, + shuffle=False, + drop_last=False, + num_workers=4, + grayscale=False, + mean=0.5, + std=0.5, + return_dataset=False, +): + print("Preparing DataLoader for the generation phase...") + transform_list = [] + if grayscale: + transform_list.append(transforms.Grayscale()) + else: + transform_list.append(RGB()) + + transform_list += [ + transforms.Resize([img_size, img_size]), + transforms.ToTensor(), + ] + if mean is not None and std is not None: + transform_list.append(transforms.Normalize(mean=mean, std=std)) + transform = transforms.Compose(transform_list) + + dataset = ImageFolder(root, transform) + if return_dataset: + return dataset + return data.DataLoader( + dataset=dataset, + batch_size=batch_size, + shuffle=shuffle, + num_workers=num_workers, + pin_memory=True, + drop_last=drop_last, + ) + + +class InputFetcher: + def __init__(self, loader, loader_ref=None, latent_dim=16, mode=""): + self.loader = loader + self.loader_ref = loader_ref + self.latent_dim = latent_dim 
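+        # latent_dim is used in __next__ to sample z_trg / z_trg2 for the mapping network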
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.mode = mode + + def _fetch_inputs(self): + try: + x, y = next(self.iter) + except (AttributeError, StopIteration): + self.iter = iter(self.loader) + x, y = next(self.iter) + return x, y + + def _fetch_refs(self): + try: + x, x2, y = next(self.iter_ref) + except (AttributeError, StopIteration): + self.iter_ref = iter(self.loader_ref) + x, x2, y = next(self.iter_ref) + return x, x2, y + + def __next__(self): + x, y = self._fetch_inputs() + if self.mode == "train": + x_ref, x_ref2, y_ref = self._fetch_refs() + z_trg = torch.randn(x.size(0), self.latent_dim) + z_trg2 = torch.randn(x.size(0), self.latent_dim) + inputs = Munch( + x_src=x, + y_src=y, + y_ref=y_ref, + x_ref=x_ref, + x_ref2=x_ref2, + z_trg=z_trg, + z_trg2=z_trg2, + ) + elif self.mode == "val": + x_ref, y_ref = self._fetch_refs() + inputs = Munch(x_src=x, y_src=y, x_ref=x_ref, y_ref=y_ref) + elif self.mode == "test": + inputs = Munch(x=x, y=y) + else: + raise NotImplementedError + + return Munch({k: v.to(self.device) for k, v in inputs.items()}) + + +class AugmentedInputFetcher(InputFetcher): + def __init__(self, loader, loader_ref=None, latent_dim=16, mode=""): + super().__init__(loader, loader_ref, latent_dim, mode) + + def _fetch_inputs(self): + try: + x, x2, y = next(self.iter) + except (AttributeError, StopIteration): + self.iter = iter(self.loader) + x, x2, y = next(self.iter) + return x, x2, y + + def __next__(self): + x, x2, y = self._fetch_inputs() + if self.mode == "train": + x_ref, x_ref2, y_ref = self._fetch_refs() + z_trg = torch.randn(x.size(0), self.latent_dim) + z_trg2 = torch.randn(x.size(0), self.latent_dim) + inputs = Munch( + x_src=x, + y_src=y, + x_src2=x2, + y_ref=y_ref, + x_ref=x_ref, + x_ref2=x_ref2, + z_trg=z_trg, + z_trg2=z_trg2, + ) + elif self.mode == "val": + x_ref, _, y_ref = self._fetch_refs() + inputs = Munch(x_src=x, y_src=y, x_ref=x_ref, y_ref=y_ref) + elif self.mode == "test": + inputs = Munch(x=x, y=y) + else: + raise NotImplementedError + + return Munch({k: v.to(self.device) for k, v in inputs.items()}) + + +class TrainingData: + def __init__( + self, + source, + reference, + img_size=128, + batch_size=8, + num_workers=4, + grayscale=False, + mean=None, + std=None, + rand_crop_prob=0, + ): + self.src = get_train_loader( + root=source, + which="source", + img_size=img_size, + batch_size=batch_size, + num_workers=num_workers, + grayscale=grayscale, + mean=mean, + std=std, + prob=rand_crop_prob, + ) + self.reference = get_train_loader( + root=reference, + which="reference", + img_size=img_size, + batch_size=batch_size, + num_workers=num_workers, + grayscale=grayscale, + mean=mean, + std=std, + prob=rand_crop_prob, + ) + + +class ValidationData: + """ + A data loader for validation. + + """ + + def __init__( + self, + source, + reference=None, + mode="latent", + img_size=128, + batch_size=32, + num_workers=4, + grayscale=False, + mean=None, + std=None, + **kwargs, + ): + """ + Parameters + ---------- + source_directory : str + The directory containing the source images. + ref_directory : str + The directory containing the reference images, defaults to source_directory if None. + mode : str + The mode of the data loader, either "latent" or "reference". + If "latent", the data loader will only load the source images. + If "reference", the data loader will load both the source and reference images. + image_size : int + The size of the images; images of a different size will be resized. 
+ batch_size : int + The batch size for source data. + num_workers : int + The number of workers for the data loader. + grayscale : bool + Whether the images are grayscale. + mean: float + The mean for normalization, for the classifier. + std: float + The standard deviation for normalization, for the classifier. + kwargs : dict + Unused keyword arguments, for compatibility with configuration. + """ + assert mode in ["latent", "reference"] + # parameters + self.image_size = img_size + self.batch_size = batch_size + self.num_workers = num_workers + self.grayscale = grayscale + self.mean = mean + self.std = std + # The source and target classes + self.source = None + self.target = None + # The roots of the source and target directories + self.source_root = Path(source) + if reference is not None: + self.ref_root = Path(reference) + else: + self.ref_root = self.source_root + + # Available classes + self.available_sources = [ + subdir.name for subdir in self.source_root.iterdir() if subdir.is_dir() + ] + self._available_targets = None + self.set_mode(mode) + + def set_mode(self, mode): + assert mode in ["latent", "reference"] + self.mode = mode + + @property + def available_targets(self): + if self.mode == "latent": + return self.available_sources + elif self._available_targets is None: + self._available_targets = [ + subdir.name + for subdir in Path(self.ref_root).iterdir() + if subdir.is_dir() + ] + return self._available_targets + + def set_target(self, target): + assert ( + target in self.available_targets + ), f"{target} not in {self.available_targets}" + self.target = target + + def set_source(self, source): + assert ( + source in self.available_sources + ), f"{source} not in {self.available_sources}" + self.source = source + + @property + def reference_directory(self): + if self.mode == "latent": + return None + if self.target is None: + raise (ValueError("Target not set.")) + return self.ref_root / self.target + + @property + def source_directory(self): + if self.source is None: + raise (ValueError("Source not set.")) + return self.source_root / self.source + + def print_info(self): + print(f"Avaliable sources: {self.available_sources}") + print(f"Avaliable targets: {self.available_targets}") + print(f"Mode: {self.mode}") + try: + print(f"Current source directory: {self.source_directory}") + except ValueError: + print("Source not set.") + try: + print(f"Current target directory: {self.reference_directory}") + except ValueError: + print("Target not set.") + + @property + def loader_src(self): + return get_eval_loader( + self.source_directory, + img_size=self.image_size, + batch_size=self.batch_size, + num_workers=self.num_workers, + grayscale=self.grayscale, + mean=self.mean, + std=self.std, + drop_last=False, + ) + + @property + def loader_ref(self): + return get_eval_loader( + self.reference_directory, + img_size=self.image_size, + batch_size=self.batch_size, + num_workers=self.num_workers, + grayscale=self.grayscale, + mean=self.mean, + std=self.std, + drop_last=True, + ) diff --git a/src/quac/training/eval.py b/src/quac/training/eval.py new file mode 100644 index 0000000..cc7066f --- /dev/null +++ b/src/quac/training/eval.py @@ -0,0 +1,141 @@ +""" +StarGAN v2 +Copyright (c) 2020-present NAVER Corp. + +This work is licensed under the Creative Commons Attribution-NonCommercial +4.0 International License. To view a copy of this license, visit +http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to +Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 
+""" + +import os +from collections import OrderedDict + +import numpy as np +from pathlib import Path +from quac.training import utils +from quac.training.classification import ClassifierWrapper +from quac.training.data_loader import get_eval_loader +import torch +from tqdm import tqdm + + +@torch.no_grad() +def calculate_metrics( + eval_dir, + step=0, + mode="latent", + classifier_checkpoint=None, + img_size=128, + val_batch_size=16, + num_outs_per_domain=10, + mean=None, + std=None, + input_dim=3, + run=None, +): + print("Calculating conversion rate for all tasks...", flush=True) + translation_rate_values = ( + OrderedDict() + ) # How many output images are of the right class + conversion_rate_values = ( + OrderedDict() + ) # How many input samples have a valid counterfactual + + domains = [subdir.name for subdir in Path(eval_dir).iterdir() if subdir.is_dir()] + + for subdir in Path(eval_dir).iterdir(): + if not subdir.is_dir() or subdir.name.startswith("."): # Skip hidden files + continue + src_domain = subdir.name + + for subdir2 in Path(subdir).iterdir(): + if not subdir2.is_dir() or subdir2.name.startswith("."): + continue + trg_domain = subdir2.name + + task = "%s/%s" % (src_domain, trg_domain) + print("Calculating conversion rate for %s..." % task, flush=True) + target_class = domains.index(trg_domain) + + translation_rate, conversion_rate = calculate_conversion_given_path( + subdir2, + model_checkpoint=classifier_checkpoint, + target_class=target_class, + img_size=img_size, + batch_size=val_batch_size, + num_outs_per_domain=num_outs_per_domain, + mean=mean, + std=std, + grayscale=(input_dim == 1), + ) + conversion_rate_values[ + "conversion_rate_%s/%s" % (mode, task) + ] = conversion_rate + translation_rate_values[ + "translation_rate_%s/%s" % (mode, task) + ] = translation_rate + + # calculate the average conversion rate for all tasks + conversion_rate_mean = 0 + translation_rate_mean = 0 + for _, value in conversion_rate_values.items(): + conversion_rate_mean += value / len(conversion_rate_values) + for _, value in translation_rate_values.items(): + translation_rate_mean += value / len(translation_rate_values) + + conversion_rate_values["conversion_rate_%s/mean" % mode] = conversion_rate_mean + translation_rate_values["translation_rate_%s/mean" % mode] = translation_rate_mean + + # report conversion rate values + filename = os.path.join(eval_dir, "conversion_rate_%.5i_%s.json" % (step, mode)) + utils.save_json(conversion_rate_values, filename) + # report translation rate values + filename = os.path.join(eval_dir, "translation_rate_%.5i_%s.json" % (step, mode)) + utils.save_json(translation_rate_values, filename) + if run is not None: + run.log(conversion_rate_values, step=step) + run.log(translation_rate_values, step=step) + + +@torch.no_grad() +def calculate_conversion_given_path( + path, + model_checkpoint, + target_class, + img_size=128, + batch_size=50, + num_outs_per_domain=10, + mean=0.5, + std=0.5, + grayscale=False, +): + print("Calculating conversion given path %s..." 
% path, flush=True) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + classifier = ClassifierWrapper(model_checkpoint, mean=mean, std=std) + classifier.to(device) + classifier.eval() + + loader = get_eval_loader( + path, + img_size=img_size, + batch_size=batch_size, + imagenet_normalize=False, + shuffle=False, + grayscale=grayscale, + ) + + predictions = [] + for x in tqdm(loader, total=len(loader)): + x = x.to(device) + predictions.append(classifier(x).cpu().numpy()) + predictions = np.concatenate(predictions, axis=0) + # Do it in a vectorized way, by reshaping the predictions + predictions = predictions.reshape(-1, num_outs_per_domain, predictions.shape[-1]) + predictions = predictions.argmax(axis=-1) + # + at_least_one = np.any(predictions == target_class, axis=1) + # + conversion_rate = np.mean(at_least_one) # (sum(at_least_one) / len(at_least_one) + translation_rate = np.mean(predictions == target_class) + return translation_rate, conversion_rate diff --git a/src/quac/training/logging.py b/src/quac/training/logging.py new file mode 100644 index 0000000..1b91b24 --- /dev/null +++ b/src/quac/training/logging.py @@ -0,0 +1,84 @@ +from numbers import Number +from typing import Union, Optional +import torch +import numpy as np + +try: + import wandb +except ImportError: + wandb = None + +try: + from torch.utils.tensorboard import SummaryWriter +except ImportError: + SummaryWriter = None + + +class Logger: + def create(log_type, resume_iter=0, hparams={}, **kwargs): + if log_type == "wandb": + if wandb is None: + raise ImportError("wandb is not installed.") + resume = "allow" if resume_iter > 0 else False + return WandBLogger(hparams=hparams, resume=resume, **kwargs) + elif log_type == "tensorboard": + if SummaryWriter is None: + raise ImportError("Tensorboard is not available.") + purge_step = resume_iter if resume_iter > 0 else None + return TensorboardLogger(hparams=hparams, purge_step=purge_step, **kwargs) + else: + raise NotImplementedError + + +class WandBLogger: + def __init__( + self, + hparams: dict, + project: str, + name: str, + notes: str, + tags: list, + resume: bool = False, + id: Optional[str] = None, + ): + self.run = wandb.init( + project=project, + name=name, + notes=notes, + tags=tags, + config=hparams, + resume=resume, + id=id, + ) + + def log(self, data: dict[str, Number], step: int = 0): + self.run.log(data, step=step) + + def log_images( + self, data: dict[str, Union[torch.Tensor, np.ndarray]], step: int = 0 + ): + for key, value in data.items(): + self.run.log({key: wandb.Image(value)}, step=step) + + +class TensorboardLogger: + # NOTE: Not tested + def __init__( + self, + log_dir: str, + comment: str, + hparams: dict, + purge_step: Union[int, None] = None, + ): + self.writer = SummaryWriter(log_dir, comment=comment, purge_step=purge_step) + self.writer.add_hparams(hparams, {}) + + def log(self, data: dict[str, Number], step: int = 0): + for key, value in data.items(): + self.writer.add_scalar(key, value, step) + + def log_images( + self, data: dict[str, Union[torch.Tensor, np.ndarray]], step: int = 0 + ): + for key, value in data.items(): + self.writer.add_images(key, value, step) diff --git a/src/quac/training/solver.py b/src/quac/training/solver.py new file mode 100644 index 0000000..b00a351 --- /dev/null +++ b/src/quac/training/solver.py @@ -0,0 +1,605 @@ +""" +StarGAN v2 +Copyright (c) 2020-present NAVER Corp. + +This work is licensed under the Creative Commons Attribution-NonCommercial +4.0 International License. 
To view a copy of this license, visit +http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to +Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. +""" + +import datetime +from munch import Munch +import numpy as np +import os +from os.path import join as ospj +from pathlib import Path +from quac.training.data_loader import AugmentedInputFetcher +from quac.training.checkpoint import CheckpointIO +import quac.training.utils as utils +from quac.training.classification import ClassifierWrapper +import shutil +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchvision import transforms +from tqdm import tqdm +import wandb + + +transform = transforms.Compose( + [ + transforms.RandomHorizontalFlip(), + transforms.RandomVerticalFlip(), + ] +) + + +class Solver(nn.Module): + def __init__( + self, + nets, + nets_ema, + f_lr: float, + lr: float, + beta1: float, + beta2: float, + weight_decay: float, + root_dir: str, + run=None, + ): + super().__init__() + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.nets = nets + self.nets_ema = nets_ema + self.run = run + self.root_dir = Path(root_dir) + self.checkpoint_dir = self.root_dir / "checkpoints" + self.checkpoint_dir.mkdir(exist_ok=True, parents=True) + + checkpoint_dir = str(self.checkpoint_dir) + + # below setattrs are to make networks be children of Solver, e.g., for self.to(self.device) + for name, module in self.nets.items(): + utils.print_network(module, name) + setattr(self, name, module) + for name, module in self.nets_ema.items(): + setattr(self, name + "_ema", module) + + self.optims = Munch() + for net in self.nets.keys(): + self.optims[net] = torch.optim.Adam( + params=self.nets[net].parameters(), + lr=f_lr if net == "mapping_network" else lr, + betas=[beta1, beta2], + weight_decay=weight_decay, + ) + + self.ckptios = [ + CheckpointIO( + ospj(checkpoint_dir, "{:06d}_nets.ckpt"), + data_parallel=True, + **self.nets, + ), + CheckpointIO( + ospj(checkpoint_dir, "{:06d}_nets_ema.ckpt"), + data_parallel=True, + **self.nets_ema, + ), + CheckpointIO(ospj(checkpoint_dir, "{:06d}_optims.ckpt"), **self.optims), + ] + else: + self.ckptios = [ + CheckpointIO( + ospj(checkpoint_dir, "{:06d}_nets_ema.ckpt"), + data_parallel=True, + **self.nets_ema, + ) + ] + + self.to(self.device) + # TODO The EMA doesn't need to be in named_childeren() + for name, network in self.named_children(): + if "ema" not in name: + print("Initializing %s..." 
% name) + network.apply(utils.he_init) + + def _save_checkpoint(self, step): + for ckptio in self.ckptios: + ckptio.save(step) + + def _load_checkpoint(self, step): + for ckptio in self.ckptios: + ckptio.load(step) + + def _reset_grad(self): + for optim in self.optims.values(): + optim.zero_grad() + + @property + def latent_dim(self): + try: + latent_dim = self.nets.mapping_network.latent_dim + except AttributeError: + # it's a data parallel model + latent_dim = self.nets.mapping_network.module.latent_dim + return latent_dim + + def train( + self, + loader, + resume_iter: int = 0, + total_iters: int = 100000, + log_every: int = 100, + save_every: int = 10000, + eval_every: int = 10000, + # sample_dir: str = "samples", + lambda_ds: float = 1.0, + ds_iter: int = 10000, + lambda_reg: float = 1.0, + lambda_sty: float = 1.0, + lambda_cyc: float = 1.0, + # Validation things + val_loader=None, + val_config=None, + ): + start = datetime.datetime.now() + nets = self.nets + nets_ema = self.nets_ema + optims = self.optims + + fetcher = AugmentedInputFetcher( + loader.src, + loader.reference, + latent_dim=self.latent_dim, + mode="train", + ) + + # resume training if necessary + if resume_iter > 0: + self._load_checkpoint(resume_iter) + + # remember the initial value of ds weight + initial_lambda_ds = lambda_ds + + print("Start training...") + for i in range(resume_iter, total_iters): + # fetch images and labels + inputs = next(fetcher) + x_real, x_aug, y_org = inputs.x_src, inputs.x_src2, inputs.y_src + x_ref, x_ref2, y_trg = inputs.x_ref, inputs.x_ref2, inputs.y_ref + z_trg, z_trg2 = inputs.z_trg, inputs.z_trg2 + + # train the discriminator + d_loss, d_losses_latent = compute_d_loss( + nets, x_real, y_org, y_trg, z_trg=z_trg, lambda_reg=lambda_reg + ) + self._reset_grad() + d_loss.backward() + optims.discriminator.step() + + d_loss, d_losses_ref = compute_d_loss( + nets, x_real, y_org, y_trg, x_ref=x_ref, lambda_reg=lambda_reg + ) + self._reset_grad() + d_loss.backward() + optims.discriminator.step() + + # train the generator + g_loss, g_losses_latent, fake_x_latent = compute_g_loss( + nets, + x_real, + y_org, + y_trg, + z_trgs=[z_trg, z_trg2], + lambda_sty=lambda_sty, + lambda_ds=lambda_ds, + lambda_cyc=lambda_cyc, + ) + self._reset_grad() + g_loss.backward() + optims.generator.step() + optims.mapping_network.step() + optims.style_encoder.step() + + g_loss, g_losses_ref, fake_x_reference = compute_g_loss( + nets, + x_real, + y_org, + y_trg, + x_refs=[x_ref, x_ref2], + x_aug=x_aug, + lambda_sty=lambda_sty, + lambda_ds=lambda_ds, + lambda_cyc=lambda_cyc, + ) + self._reset_grad() + g_loss.backward() + optims.generator.step() + + # compute moving average of network parameters + moving_average(nets.generator, nets_ema.generator, beta=0.999) + moving_average(nets.style_encoder, nets_ema.style_encoder, beta=0.999) + moving_average(nets.mapping_network, nets_ema.mapping_network, beta=0.999) + + # decay weight for diversity sensitive loss + if lambda_ds > 0: + lambda_ds -= initial_lambda_ds / ds_iter + + if (i + 1) % eval_every == 0 and val_loader is not None: + self.evaluate( + val_loader, iteration=i + 1, mode="reference", val_config=val_config + ) + self.evaluate( + val_loader, iteration=i + 1, mode="latent", val_config=val_config + ) + + # save model checkpoints + if (i + 1) % save_every == 0: + self._save_checkpoint(step=i + 1) + + # print out log losses, images + if (i + 1) % log_every == 0: + elapsed = datetime.datetime.now() - start + # Log the images made by the EMA model! 
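+                # These EMA weights are the ones saved to the *_nets_ema checkpoints
+                # and loaded by the inference models, so log their outputs as well.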
+ with torch.no_grad(): + ema_fake_x_latent = nets_ema.generator( + x_real, nets_ema.mapping_network(z_trg, y_trg) + ) + ema_fake_x_reference = nets_ema.generator( + x_real, nets_ema.style_encoder(x_ref, y_trg) + ) + self.log( + d_losses_latent, + d_losses_ref, + g_losses_latent, + g_losses_ref, + lambda_ds, + x_real, + x_ref, + fake_x_latent, + fake_x_reference, + ema_fake_x_latent, + ema_fake_x_reference, + y_org, # Source classes + y_trg, # Target classes + step=i + 1, + total_iters=total_iters, + elapsed_time=elapsed, + ) + + def log( + self, + d_losses_latent, + d_losses_ref, + g_losses_latent, + g_losses_ref, + lambda_ds, + x_real, + x_ref, + fake_x_latent, + fake_x_reference, + ema_fake_x_latent, + ema_fake_x_reference, + y_source, + y_target, + step, + total_iters, + elapsed_time, + ): + all_losses = dict() + for loss, prefix in zip( + [d_losses_latent, d_losses_ref, g_losses_latent, g_losses_ref], + ["D/latent_", "D/ref_", "G/latent_", "G/ref_"], + ): + for key, value in loss.items(): + all_losses[prefix + key] = value + all_losses["G/lambda_ds"] = lambda_ds + # log all losses to wandb or print them + if self.run: + self.run.log(all_losses, step=step) + for name, img, label in zip( + [ + "x_real", + "x_ref", + "fake_x_latent", + "fake_x_reference", + "ema_fake_x_latent", + "ema_fake_x_reference", + ], + [ + x_real, + x_ref, + fake_x_latent, + fake_x_reference, + ema_fake_x_latent, + ema_fake_x_reference, + ], + [y_source, y_target, y_target, y_target, y_target, y_target], + ): + # TODO put captions back in somehow + self.run.log_images({name: img}, step=step) + + print( + f"[{elapsed_time}]: {step}/{total_iters}", + flush=True, + ) + g_losses = "\t".join( + [ + f"{key}: {value:.4f}" + for key, value in all_losses.items() + if not key.startswith("D/") + ] + ) + d_losses = "\t".join( + [ + f"{key}: {value:.4f}" + for key, value in all_losses.items() + if key.startswith("D/") + ] + ) + print(f"G Losses: {g_losses}", flush=True) + print(f"D Losses: {d_losses}", flush=True) + + @torch.no_grad() + def evaluate( + self, + val_loader, + iteration=None, + num_outs_per_domain=10, + mode="latent", + val_config=None, + ): + """ + Generates images for evaluation and stores them to disk. 
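+
+        For every source/target domain pair, conversion and translation rates
+        are computed with the validation classifier and written as JSON files
+        under the "eval" subdirectory of the solver's root directory.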
+ + Parameters + ---------- + val_loader + """ + if iteration is None: # Choose the iteration to evaluate + resume_iter = resume_iter + self._load_checkpoint(resume_iter) + + # Generate images for evaluation + eval_dir = self.root_dir / "eval" + eval_dir.mkdir(exist_ok=True, parents=True) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + # Load classifier + classifier = ClassifierWrapper( + val_config.classifier_checkpoint, + val_config.mean, + val_config.std, + assume_normalized=val_config.assume_normalized, + do_nothing=val_config.do_nothing, + ) + classifier.to(device) + assert mode in ["latent", "reference"] + + val_loader.set_mode(mode) + + domains = val_loader.available_targets + print("Number of domains: %d" % len(domains)) + + conversion_rate_values = {} + translation_rate_values = {} + + for trg_idx, trg_domain in enumerate(domains): + src_domains = [x for x in val_loader.available_sources if x != trg_domain] + val_loader.set_target(trg_domain) + if mode == "reference": + loader_ref = val_loader.loader_ref + + for src_idx, src_domain in enumerate(src_domains): + task = "%s/%s" % (src_domain, trg_domain) + # Creating the path + path_fake = os.path.join(eval_dir, task) + shutil.rmtree(path_fake, ignore_errors=True) + os.makedirs(path_fake) + + # Setting the source domain + val_loader.set_source(src_domain) + loader_src = val_loader.loader_src + + for i, x_src in enumerate(tqdm(loader_src, total=len(loader_src))): + N = x_src.size(0) + x_src = x_src.to(device) + y_trg = torch.tensor([trg_idx] * N).to(device) + + predictions = [] + # generate num_outs_per_domain outputs from the same input + for j in range(num_outs_per_domain): + if mode == "latent": + z_trg = torch.randn(N, self.latent_dim).to(device) + s_trg = self.nets_ema.mapping_network(z_trg, y_trg) + else: + # x_ref = x_trg.clone() + try: + # TODO don't need to re-do this every time, just use + # the same set of reference images for the whole dataset! + x_ref = next(iter_ref).to(device) + except: + iter_ref = iter(loader_ref) + x_ref = next(iter_ref).to(device) + + if x_ref.size(0) > N: + x_ref = x_ref[:N] + elif x_ref.size(0) < N: + raise ValueError( + "Not enough reference images." + "Make sure that the batch size of the validation loader is bigger than `num_outs_per_domain`." 
+ ) + s_trg = self.nets_ema.style_encoder(x_ref, y_trg) + + x_fake = self.nets_ema.generator(x_src, s_trg) + # Run the classification + pred = classifier( + x_fake, + assume_normalized=val_config.assume_normalized, + do_nothing=val_config.do_nothing, + ) + predictions.append(pred.cpu().numpy()) + predictions = np.stack(predictions, axis=0) + assert len(predictions) > 0 + # Do it in a vectorized way, by reshaping the predictions + predictions = predictions.reshape( + -1, num_outs_per_domain, predictions.shape[-1] + ) + predictions = predictions.argmax(axis=-1) + # + at_least_one = np.any(predictions == trg_idx, axis=1) + # + conversion_rate = np.mean(at_least_one) + translation_rate = np.mean(predictions == trg_idx) + + # STORE + conversion_rate_values[ + f"conversion_rate_{mode}/" + task + ] = conversion_rate + translation_rate_values[ + f"translation_rate_{mode}/" + task + ] = translation_rate + + # Add average conversion rate and translation rate + conversion_rate_values[f"conversion_rate_{mode}/average"] = np.mean( + [conversion_rate_values[key] for key in conversion_rate_values.keys()] + ) + translation_rate_values[f"translation_rate_{mode}/average"] = np.mean( + [translation_rate_values[key] for key in translation_rate_values.keys()] + ) + + # report conversion rate values + filename = os.path.join( + eval_dir, "conversion_rate_%.5i_%s.json" % (iteration, mode) + ) + utils.save_json(conversion_rate_values, filename) + # report translation rate values + filename = os.path.join( + eval_dir, "translation_rate_%.5i_%s.json" % (iteration, mode) + ) + utils.save_json(translation_rate_values, filename) + if self.run is not None: + self.run.log(conversion_rate_values, step=iteration) + self.run.log(translation_rate_values, step=iteration) + + +def compute_d_loss(nets, x_real, y_org, y_trg, z_trg=None, x_ref=None, lambda_reg=1.0): + assert (z_trg is None) != (x_ref is None) + # with real images + x_real.requires_grad_() + out = nets.discriminator(x_real, y_org) + loss_real = adv_loss(out, 1) + loss_reg = r1_reg(out, x_real) + + # with fake images + with torch.no_grad(): + if z_trg is not None: + s_trg = nets.mapping_network(z_trg, y_trg) + else: # x_ref is not None + s_trg = nets.style_encoder(x_ref, y_trg) + + x_fake = nets.generator(x_real, s_trg) + out = nets.discriminator(x_fake, y_trg) + loss_fake = adv_loss(out, 0) + + loss = loss_real + loss_fake + lambda_reg * loss_reg + return loss, Munch( + real=loss_real.item(), fake=loss_fake.item(), reg=loss_reg.item() + ) + + +def compute_g_loss( + nets, + x_real, + y_org, + y_trg, + z_trgs=None, + x_refs=None, + x_aug=None, + lambda_sty: float = 1.0, + lambda_ds: float = 1.0, + lambda_cyc: float = 1.0, +): + assert (z_trgs is None) != (x_refs is None) + if z_trgs is not None: + z_trg, z_trg2 = z_trgs + if x_refs is not None: + x_ref, x_ref2 = x_refs + + # adversarial loss + if z_trgs is not None: + s_trg = nets.mapping_network(z_trg, y_trg) + else: + s_trg = nets.style_encoder(x_ref, y_trg) + + x_fake = nets.generator(x_real, s_trg) + out = nets.discriminator(x_fake, y_trg) + loss_adv = adv_loss(out, 1) + + # style reconstruction loss + # Adds random augmentation to x_fake before passing to style encoder + s_pred = nets.style_encoder(transform(x_fake), y_trg) + loss_sty = torch.mean(torch.abs(s_pred - s_trg)) + + # diversity sensitive loss + if z_trgs is not None: + s_trg2 = nets.mapping_network(z_trg2, y_trg) + else: + s_trg2 = nets.style_encoder(x_ref2, y_trg) + x_fake2 = nets.generator(x_real, s_trg2) + x_fake2 = x_fake2.detach() + loss_ds = 
torch.mean(torch.abs(x_fake - x_fake2)) + + # cycle-consistency loss + s_org = nets.style_encoder(x_real, y_org) + x_rec = nets.generator(x_fake, s_org) + loss_cyc = torch.mean(torch.abs(x_rec - x_real)) + + # style invariance loss + if x_aug is not None: + s_pred2 = nets.style_encoder(x_aug, y_org) + loss_sty2 = torch.mean(torch.abs(s_pred2 - s_org)) + loss_sty = (loss_sty + loss_sty2) / 2 + + loss = ( + loss_adv + lambda_sty * loss_sty - lambda_ds * loss_ds + lambda_cyc * loss_cyc + ) + return ( + loss, + Munch( + adv=loss_adv.item(), + sty=loss_sty.item(), + ds=loss_ds.item(), + cyc=loss_cyc.item(), + ), + x_fake, + ) + + +def moving_average(model, model_test, beta=0.999): + for param, param_test in zip(model.parameters(), model_test.parameters()): + param_test.data = torch.lerp(param.data, param_test.data, beta) + + +def adv_loss(logits, target): + assert target in [1, 0] + targets = torch.full_like(logits, fill_value=target) + loss = F.binary_cross_entropy_with_logits(logits, targets) + return loss + + +def r1_reg(d_out, x_in): + # zero-centered gradient penalty for real images + batch_size = x_in.size(0) + grad_dout = torch.autograd.grad( + outputs=d_out.sum(), + inputs=x_in, + create_graph=True, + retain_graph=True, + only_inputs=True, + )[0] + grad_dout2 = grad_dout.pow(2) + assert grad_dout2.size() == x_in.size() + reg = 0.5 * grad_dout2.view(batch_size, -1).sum(1).mean(0) + return reg diff --git a/src/quac/training/stargan.py b/src/quac/training/stargan.py new file mode 100644 index 0000000..b3f5a68 --- /dev/null +++ b/src/quac/training/stargan.py @@ -0,0 +1,413 @@ +""" +StarGAN v2 +Copyright (c) 2020-present NAVER Corp. + +This work is licensed under the Creative Commons Attribution-NonCommercial +4.0 International License. To view a copy of this license, visit +http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to +Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. +""" + +import copy +import math + +from munch import Munch +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class ResBlk(nn.Module): + """ + Residual Block. + + Parameters + ---------- + dim_in: int + Number of input channels. + dim_out: int + Number of output channels. + actv: torch.nn.Module + Activation function. + normalize: bool + If True, apply instance normalization. Default: False. + downsample: bool + If True, apply average pooling with stride 2. Default: False. 
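+
+    The residual and shortcut branches are summed and divided by sqrt(2) to
+    keep the output at unit variance.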
+ """ + + def __init__( + self, dim_in, dim_out, actv=nn.LeakyReLU(0.2), normalize=False, downsample=False + ): + super().__init__() + self.actv = actv + self.normalize = normalize + self.downsample = downsample + self.learned_sc = dim_in != dim_out + self._build_weights(dim_in, dim_out) + + def _build_weights(self, dim_in, dim_out): + self.conv1 = nn.Conv2d(dim_in, dim_in, 3, 1, 1) + self.conv2 = nn.Conv2d(dim_in, dim_out, 3, 1, 1) + if self.normalize: + self.norm1 = nn.InstanceNorm2d(dim_in, affine=True) + self.norm2 = nn.InstanceNorm2d(dim_in, affine=True) + if self.learned_sc: + self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False) + + def _shortcut(self, x): + if self.learned_sc: + x = self.conv1x1(x) + if self.downsample: + x = F.avg_pool2d(x, 2) + return x + + def _residual(self, x): + if self.normalize: + x = self.norm1(x) + x = self.actv(x) + x = self.conv1(x) + if self.downsample: + x = F.avg_pool2d(x, 2) + if self.normalize: + x = self.norm2(x) + x = self.actv(x) + x = self.conv2(x) + return x + + def forward(self, x): + x = self._shortcut(x) + self._residual(x) + return x / math.sqrt(2) # unit variance + + +class AdaIN(nn.Module): + """ + Adaptive Instance normalization. + """ + + def __init__(self, style_dim, num_features): + super().__init__() + self.norm = nn.InstanceNorm2d(num_features, affine=False) + self.fc = nn.Linear(style_dim, num_features * 2) + + def forward(self, x, s): + h = self.fc(s) + h = h.view(h.size(0), h.size(1), 1, 1) + gamma, beta = torch.chunk(h, chunks=2, dim=1) + return (1 + gamma) * self.norm(x) + beta + + +class AdainResBlk(nn.Module): + """A Residual block with Adaptive Instance Normalization.""" + + def __init__( + self, + dim_in: int, + dim_out: int, + style_dim: int = 64, + actv: nn.Module = nn.LeakyReLU(0.2), + upsample: bool = False, + ): + super().__init__() + self.actv = actv + self.upsample = upsample + self.learned_sc = dim_in != dim_out + self._build_weights(dim_in, dim_out, style_dim) + + def _build_weights(self, dim_in, dim_out, style_dim=64): + self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1) + self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1) + self.norm1 = AdaIN(style_dim, dim_in) + self.norm2 = AdaIN(style_dim, dim_out) + if self.learned_sc: + self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False) + + def _shortcut(self, x): + if self.upsample: + x = F.interpolate(x, scale_factor=2, mode="nearest") + if self.learned_sc: + x = self.conv1x1(x) + return x + + def _residual(self, x, s): + x = self.norm1(x, s) + x = self.actv(x) + if self.upsample: + x = F.interpolate(x, scale_factor=2, mode="nearest") + x = self.conv1(x) + x = self.norm2(x, s) + x = self.actv(x) + x = self.conv2(x) + return x + + def forward(self, x, s): + out = self._residual(x, s) + out = (out + self._shortcut(x)) / math.sqrt(2) + return out + + +class Generator(nn.Module): + def __init__( + self, + img_size=256, + style_dim=64, + max_conv_dim=512, + input_dim=1, + final_activation=None, + ): + super().__init__() + dim_in = 2**14 // img_size + self.img_size = img_size + self.from_rgb = nn.Conv2d(input_dim, dim_in, 3, 1, 1) + self.encode = nn.ModuleList() + self.decode = nn.ModuleList() + self.to_rgb = nn.Sequential( + nn.InstanceNorm2d(dim_in, affine=True), + nn.LeakyReLU(0.2), + nn.Conv2d(dim_in, input_dim, 1, 1, 0), + ) + if final_activation == "sigmoid": + # print("Using sigmoid") + self.final_activation = nn.Sigmoid() + else: + # print("Using tanh") + self.final_activation = nn.Tanh() + + # down/up-sampling blocks + repeat_num = 
int(np.log2(img_size)) - 4 + for _ in range(repeat_num): + dim_out = min(dim_in * 2, max_conv_dim) + self.encode.append(ResBlk(dim_in, dim_out, normalize=True, downsample=True)) + self.decode.insert( + 0, AdainResBlk(dim_out, dim_in, style_dim, upsample=True) + ) # stack-like + dim_in = dim_out + + # bottleneck blocks + for _ in range(2): + self.encode.append(ResBlk(dim_out, dim_out, normalize=True)) + self.decode.insert(0, AdainResBlk(dim_out, dim_out, style_dim)) + + def forward(self, x, s): + x = self.from_rgb(x) + # cache = {} + for block in self.encode: + x = block(x) + for block in self.decode: + x = block(x, s) + return self.final_activation(self.to_rgb(x)) + + +class MappingNetwork(nn.Module): + def __init__(self, latent_dim=16, style_dim=64, num_domains=2): + super().__init__() + self.latent_dim = latent_dim + layers = [] + layers += [nn.Linear(latent_dim, 512)] + layers += [nn.ReLU()] + for _ in range(3): + layers += [nn.Linear(512, 512)] + layers += [nn.ReLU()] + self.shared = nn.Sequential(*layers) + + self.unshared = nn.ModuleList() + for _ in range(num_domains): + self.unshared += [ + nn.Sequential( + nn.Linear(512, 512), + nn.ReLU(), + nn.Linear(512, 512), + nn.ReLU(), + nn.Linear(512, 512), + nn.ReLU(), + nn.Linear(512, style_dim), + ) + ] + + def forward(self, z, y): + h = self.shared(z) + out = [] + for layer in self.unshared: + out += [layer(h)] + out = torch.stack(out, dim=1) # (batch, num_domains, style_dim) + idx = torch.LongTensor(range(y.size(0))).to(y.device) + s = out[idx, y] # (batch, style_dim) + return s + + +class StyleEncoder(nn.Module): + def __init__( + self, img_size=256, style_dim=64, num_domains=2, max_conv_dim=512, input_dim=3 + ): + super().__init__() + dim_in = 2**14 // img_size + + self.nearest_power = None + if np.ceil(np.log2(img_size)) != np.floor(np.log2(img_size)): # Not power of 2 + self.nearest_power = int(np.log2(img_size)) + + blocks = [] + blocks += [nn.Conv2d(input_dim, dim_in, 3, 1, 1)] + + repeat_num = int(np.log2(img_size)) - 2 + for _ in range(repeat_num): + # For img_size = 224, repeat_num = 5, dim_out = 256, 512, 512, 512, 512 + dim_out = min(dim_in * 2, max_conv_dim) + blocks += [ResBlk(dim_in, dim_out, downsample=True)] + dim_in = dim_out + + blocks += [nn.LeakyReLU(0.2)] + blocks += [nn.Conv2d(dim_out, dim_out, 4, 1, 0)] + blocks += [nn.LeakyReLU(0.2)] + self.shared = nn.Sequential(*blocks) + + self.unshared = nn.ModuleList() + for _ in range(num_domains): + self.unshared += [nn.Linear(dim_out, style_dim)] + + def forward(self, x, y): + if self.nearest_power is not None: + # Required for img_size=224 in the retina case + # Resize input image to nearest power of 2 + x = F.interpolate(x, size=2**self.nearest_power, mode="bilinear") + h = self.shared(x) + h = h.view(h.size(0), -1) + out = [] + for layer in self.unshared: + out += [layer(h)] + out = torch.stack(out, dim=1) # (batch, num_domains, style_dim) + idx = torch.LongTensor(range(y.size(0))).to(y.device) + s = out[idx, y] # (batch, style_dim) + return s + + +class SingleOutputStyleEncoder(nn.Module): + def __init__( + self, img_size=256, style_dim=64, num_domains=2, max_conv_dim=512, input_dim=3 + ): + super().__init__() + dim_in = 2**14 // img_size + + self.nearest_power = None + if np.ceil(np.log2(img_size)) != np.floor(np.log2(img_size)): # Not power of 2 + self.nearest_power = int(np.log2(img_size)) + + blocks = [] + blocks += [nn.Conv2d(input_dim, dim_in, 3, 1, 1)] + + repeat_num = int(np.log2(img_size)) - 2 + for _ in range(repeat_num): + # For img_size = 224, repeat_num 
= 5, dim_out = 256, 512, 512, 512, 512 + dim_out = min(dim_in * 2, max_conv_dim) + blocks += [ResBlk(dim_in, dim_out, downsample=True)] + dim_in = dim_out + + blocks += [nn.LeakyReLU(0.2)] + blocks += [nn.Conv2d(dim_out, dim_out, 4, 1, 0)] + blocks += [nn.LeakyReLU(0.2)] + self.shared = nn.Sequential(*blocks) + + # Making this shared again, to try to learn new things from data + self.output = nn.Linear(dim_out, style_dim) + + def forward(self, x, y): + if self.nearest_power is not None: + # Required for img_size=224 in the retina case + # Resize input image to nearest power of 2 + x = F.interpolate(x, size=2**self.nearest_power, mode="bilinear") + h = self.shared(x) + h = h.view(h.size(0), -1) + out = [] + s = self.output(h) + return s + + +class Discriminator(nn.Module): + def __init__(self, img_size=256, num_domains=2, max_conv_dim=512, input_dim=3): + super().__init__() + dim_in = 2**14 // img_size + blocks = [] + blocks += [nn.Conv2d(input_dim, dim_in, 3, 1, 1)] + + repeat_num = int(np.log2(img_size)) - 2 + for _ in range(repeat_num): + dim_out = min(dim_in * 2, max_conv_dim) + blocks += [ResBlk(dim_in, dim_out, downsample=True)] + dim_in = dim_out + + blocks += [nn.LeakyReLU(0.2)] + blocks += [nn.Conv2d(dim_out, dim_out, 4, 1, 0)] + blocks += [nn.LeakyReLU(0.2)] + blocks += [nn.Conv2d(dim_out, num_domains, 1, 1, 0)] + self.main = nn.Sequential(*blocks) + + def forward(self, x, y): + out = self.main(x) + out = out.view(out.size(0), -1) # (batch, num_domains) + idx = torch.LongTensor(range(y.size(0))).to(y.device) + out = out[idx, y] # (batch) + return out + + +def build_model( + img_size=128, + style_dim=64, + input_dim=3, + latent_dim=16, + num_domains=4, + single_output_style_encoder=False, + final_activation=None, + gpu_ids=[0], +): + generator = nn.DataParallel( + Generator( + img_size, style_dim, input_dim=input_dim, final_activation=final_activation + ), + device_ids=gpu_ids, + ) + mapping_network = nn.DataParallel( + MappingNetwork(latent_dim, style_dim, num_domains), + device_ids=gpu_ids, + ) + if single_output_style_encoder: + print("Using single output style encoder") + style_encoder = nn.DataParallel( + SingleOutputStyleEncoder( + img_size, + style_dim, + num_domains, + input_dim=input_dim, + ), + device_ids=gpu_ids, + ) + else: + style_encoder = nn.DataParallel( + StyleEncoder( + img_size, + style_dim, + num_domains, + input_dim=input_dim, + ), + device_ids=gpu_ids, + ) + discriminator = nn.DataParallel( + Discriminator(img_size, num_domains, input_dim=input_dim), + device_ids=gpu_ids, + ) + generator_ema = copy.deepcopy(generator) + mapping_network_ema = copy.deepcopy(mapping_network) + style_encoder_ema = copy.deepcopy(style_encoder) + + nets = Munch( + generator=generator, + mapping_network=mapping_network, + style_encoder=style_encoder, + discriminator=discriminator, + ) + nets_ema = Munch( + generator=generator_ema, + mapping_network=mapping_network_ema, + style_encoder=style_encoder_ema, + ) + + return nets, nets_ema diff --git a/src/quac/training/utils.py b/src/quac/training/utils.py new file mode 100644 index 0000000..6d7dd2f --- /dev/null +++ b/src/quac/training/utils.py @@ -0,0 +1,264 @@ +""" +StarGAN v2 +Copyright (c) 2020-present NAVER Corp. + +This work is licensed under the Creative Commons Attribution-NonCommercial +4.0 International License. To view a copy of this license, visit +http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to +Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 
+"""
+
+import json
+import matplotlib.pyplot as plt
+import numpy as np
+from os.path import join as ospj
+from pathlib import Path
+import pandas as pd
+import re
+import torch
+import torch.nn as nn
+import torchvision.utils as vutils
+
+
+def save_json(json_file, filename):
+    with open(filename, "w") as f:
+        json.dump(json_file, f, indent=4, sort_keys=False)
+
+
+def print_network(network, name):
+    num_params = 0
+    for p in network.parameters():
+        num_params += p.numel()
+    print("Number of parameters of %s: %i" % (name, num_params))
+
+
+def he_init(module):
+    if isinstance(module, nn.Conv2d):
+        nn.init.kaiming_normal_(module.weight, mode="fan_in", nonlinearity="relu")
+        if module.bias is not None:
+            nn.init.constant_(module.bias, 0)
+    if isinstance(module, nn.Linear):
+        nn.init.kaiming_normal_(module.weight, mode="fan_in", nonlinearity="relu")
+        if module.bias is not None:
+            nn.init.constant_(module.bias, 0)
+
+
+def denormalize(x):
+    out = (x + 1) / 2
+    return out.clamp_(0, 1)
+
+
+def save_image(x, ncol, filename):
+    x = denormalize(x)
+    vutils.save_image(x.cpu(), filename, nrow=ncol, padding=0)
+
+
+class Logger:
+    def __init__(
+        self, log_dir, nets, num_outs_per_domain=10, num_domains=2, latent_dim=16
+    ) -> None:
+        self.log_dir = Path(log_dir)
+        if not self.log_dir.exists():
+            self.log_dir.mkdir(parents=True)
+        self.nets = nets
+        # Keep the sampling configuration on the instance so that
+        # `debug_image` does not depend on an external `args` namespace.
+        self.num_outs_per_domain = num_outs_per_domain
+        self.num_domains = num_domains
+        self.latent_dim = latent_dim
+
+    @torch.no_grad()
+    def translate_and_reconstruct(self, x_src, y_src, x_ref, y_ref, filename):
+        N, C, H, W = x_src.size()
+        s_ref = self.nets.style_encoder(x_ref, y_ref)
+        x_fake = self.nets.generator(x_src, s_ref)
+        s_src = self.nets.style_encoder(x_src, y_src)
+        x_rec = self.nets.generator(x_fake, s_src)
+        x_concat = [x_src, x_ref, x_fake, x_rec]
+        x_concat = torch.cat(x_concat, dim=0)
+        save_image(x_concat, N, filename)
+        del x_concat
+
+    @torch.no_grad()
+    def translate_using_latent(self, x_src, y_trg_list, z_trg_list, psi, filename):
+        N, C, H, W = x_src.size()
+        latent_dim = z_trg_list[0].size(1)
+        x_concat = [x_src]
+
+        for i, y_trg in enumerate(y_trg_list):
+            z_many = torch.randn(10000, latent_dim).to(x_src.device)
+            y_many = torch.LongTensor(10000).to(x_src.device).fill_(y_trg[0])
+            s_many = self.nets.mapping_network(z_many, y_many)
+            s_avg = torch.mean(s_many, dim=0, keepdim=True)
+            s_avg = s_avg.repeat(N, 1)
+
+            for z_trg in z_trg_list:
+                s_trg = self.nets.mapping_network(z_trg, y_trg)
+                s_trg = torch.lerp(s_avg, s_trg, psi)
+                x_fake = self.nets.generator(x_src, s_trg)
+                x_concat += [x_fake]
+
+        x_concat = torch.cat(x_concat, dim=0)
+        save_image(x_concat, N, filename)
+
+    @torch.no_grad()
+    def translate_using_reference(self, x_src, x_ref, y_ref, filename):
+        N, C, H, W = x_src.size()
+        wb = torch.ones(1, C, H, W).to(x_src.device)
+        x_src_with_wb = torch.cat([wb, x_src], dim=0)
+
+        # The FAN heatmap masks of the original StarGAN v2 are not used here:
+        # the generator in this repository takes only (x, s).
+        s_ref = self.nets.style_encoder(x_ref, y_ref)
+        s_ref_list = s_ref.unsqueeze(1).repeat(1, N, 1)
+        x_concat = [x_src_with_wb]
+        for i, s_ref in enumerate(s_ref_list):
+            x_fake = self.nets.generator(x_src, s_ref)
+            x_fake_with_ref = torch.cat([x_ref[i : i + 1], x_fake], dim=0)
+            x_concat += [x_fake_with_ref]
+
+        x_concat = torch.cat(x_concat, dim=0)
+        save_image(x_concat, N + 1, filename)
+        del x_concat
+
+    @torch.no_grad()
+    def debug_image(self, inputs, step):
+        x_src, y_src = inputs.x_src, inputs.y_src
+        x_ref, y_ref = inputs.x_ref, inputs.y_ref
+
+        device = inputs.x_src.device
+        N = inputs.x_src.size(0)
+
+        # translate and reconstruct (reference-guided)
+        filename = ospj(self.log_dir, "%06d_cycle_consistency.jpg" % (step))
+        self.translate_and_reconstruct(x_src, y_src, x_ref, y_ref, filename)
+
+        # latent-guided image synthesis
+        y_trg_list = [
+            torch.tensor(y).repeat(N).to(device)
+            for y in range(min(self.num_domains, 5))
+        ]
+        z_trg_list = (
+            torch.randn(self.num_outs_per_domain, 1, self.latent_dim)
+            .repeat(1, N, 1)
+            .to(device)
+        )
+        for psi in [0.5, 0.7, 1.0]:
+            filename = ospj(self.log_dir, "%06d_latent_psi_%.1f.jpg" % (step, psi))
+            self.translate_using_latent(x_src, y_trg_list, z_trg_list, psi, filename)
+
+        # reference-guided image synthesis
+        filename = ospj(self.log_dir, "%06d_reference.jpg" % (step))
+        self.translate_using_reference(x_src, x_ref, y_ref, filename)
+
+
+###########################
+# LOSS PLOTTING FUNCTIONS #
+###########################
+def get_epoch_number(f):
+    # Get the epoch number from the filename and sort the files by epoch
+    # The epoch number is the only number in the filename, and can be 5 or 6 digits long
+    return int(re.findall(r"\d+", f.name)[0])
+
+
+def load_json_files(files):
+    # Load the data from the json files into a dictionary keyed by epoch number
+    data = {}
+    for f in files:
+        with open(f, "r") as file:
+            data[get_epoch_number(f)] = json.load(file)
+
+    df = pd.DataFrame.from_dict(data, orient="index")
+    df.columns = [col.split("/")[-1] for col in df.columns]
+    df.sort_index(inplace=True)
+    return df
+
+
+def plot_from_data(
+    reference_conversion_data,
+    latent_conversion_data,
+    reference_translation_data,
+    latent_translation_data,
+):
+    n_cols = int(np.ceil(len(reference_conversion_data.columns) / 7))
+    # squeeze=False keeps `axes` two-dimensional even when there is a single column
+    fig, axes = plt.subplots(7, n_cols, figsize=(15, 15), squeeze=False)
+    for col, ax in zip(reference_conversion_data.columns, axes.ravel()):
+        reference_conversion_data[col].plot(
+            ax=ax, label="Reference Conversion", color="black"
+        )
+        latent_conversion_data[col].plot(
+            ax=ax, label="Latent Conversion", color="black", linestyle="--"
+        )
+        reference_translation_data[col].plot(
+            ax=ax, label="Reference Translation", color="gray"
+        )
+        latent_translation_data[col].plot(
+            ax=ax, label="Latent Translation", color="gray", linestyle="--"
+        )
+        ax.set_ylim(0, 1)
+        # Format the title only if there are any alphabetic characters in the column name
+        alphabetic_title = any([c.isalpha() for c in col])
+        if alphabetic_title:
+            # Split the column name on underscores, drop the numbers, and join the words with "to"
+            title = " to ".join(
+                [word.capitalize() for word in col.split("_") if not word.isdigit()]
+            )
+            # Remove all remaining numbers from the title
+            title = "".join([i for i in title if not i.isdigit()])
+        else:
+            # The title is \d2\d and we want "\d to \d"; note that the outer digits can themselves be 2
+            title = re.sub(r"(\d)2(\d)", r"\1 to \2", col)
+        ax.set_title(title)
+    # Hide all of the extra axes if any
+    for ax in axes.ravel()[len(reference_conversion_data.columns) :]:
+        ax.axis("off")
+    # Add an x-axis label to the bottom row of plots that still has a visible x-axis
+    num_axes = len(reference_conversion_data.columns)
+    for ax in axes.ravel()[num_axes - n_cols :]:
+        ax.set_xlabel("Iteration")
+    # Add a y-axis label to the left column of plots
+    for ax in axes[:, 0]:
+        ax.set_ylabel("Rate")
+    # Make a legend for the whole figure, assuming that the labels are the same for all subplots
+    fig.legend(
+        *ax.get_legend_handles_labels(), loc="upper right", bbox_to_anchor=(1.15, 1)
+    )
+    fig.tight_layout()
+
+    return fig
+
+
+def plot_metrics(root, show: bool = True, save: str = ""):
+    """Plot the conversion and translation rates over the course of training.
+
+    root: str
+        The root directory containing the json files with the metrics.
+    show: bool
+        Whether to show the plot.
+    save: str
+        The path to save the plot to. If empty, the figure is not saved.
+    """
+    # Goal is to get a better idea of the success/failure of conversion as training goes on,
+    # and of how the translation/conversion/diversity tradeoff plays out over time.
+
+    files = list(Path(root).rglob("*.json"))
+
+    # Keep only the conversion_rate and translation_rate files; LPIPS files are not plotted here
+    conversion_files = [f for f in files if "conversion_rate" in f.name]
+    translation_files = [f for f in files if "translation_rate" in f.name]
+
+    # Split files by whether they are reference or latent
+    reference_conversion = [f for f in conversion_files if "reference" in f.name]
+    latent_conversion = [f for f in conversion_files if "latent" in f.name]
+
+    reference_translation = [f for f in translation_files if "reference" in f.name]
+    latent_translation = [f for f in translation_files if "latent" in f.name]
+
+    # Load the data from the json files
+    reference_conversion_data = load_json_files(reference_conversion)
+    latent_conversion_data = load_json_files(latent_conversion)
+    reference_translation_data = load_json_files(reference_translation)
+    latent_translation_data = load_json_files(latent_translation)
+
+    fig = plot_from_data(
+        reference_conversion_data,
+        latent_conversion_data,
+        reference_translation_data,
+        latent_translation_data,
+    )
+    if show:
+        plt.show()
+    if save:
+        fig.savefig(save)
diff --git a/tests/test_datasets.py b/tests/test_datasets.py
new file mode 100644
index 0000000..e7b21a4
--- /dev/null
+++ b/tests/test_datasets.py
@@ -0,0 +1,16 @@
+from quac.data import PairedImageDataset
+from torchvision import transforms
+
+
+def test_paired_image_folders():
+    transform = transforms.Compose([transforms.Grayscale(), transforms.ToTensor()])
+    dataset = PairedImageDataset(
+        "/nrs/funke/adjavond/data/synapses/test",
+        "/nrs/funke/adjavond/data/synapses/counterfactuals/stargan_invariance_v0/test",
+        transform=transform,
+    )
+    x, xc, y, yc = dataset[0]
+    assert x.shape == (1, 128, 128)
+    assert xc.shape == x.shape
+    assert y != yc
+    assert len(dataset.classes) == 6
diff --git a/tests/test_model.py b/tests/test_model.py
new file mode 100644
index 0000000..f47f567
--- /dev/null
+++ b/tests/test_model.py
@@ -0,0 +1,22 @@
+import pytest
+from quac.training.stargan import build_model
+from quac.training.config import ModelConfig
+import torch
+import os
+
+os.environ["CUDA_VISIBLE_DEVICES"] = "0"
+
+
+def test_model():
+    args = ModelConfig()
+    # `build_model` takes keyword arguments and returns (nets, nets_ema);
+    # assuming here that `ModelConfig` exposes matching attribute names.
+    nets, _ = build_model(**vars(args))
+    example_input = torch.randn(4, 1, 128, 128)
+    example_class = torch.randint(0, 5, (4,))
+    example_latent = torch.randn(4, 16)
+    # Ensure that the sizes of the outputs are as expected
+    latent_style = nets.mapping_network(example_latent, example_class)
+    assert latent_style.shape == (4, 64)
+    style = nets.style_encoder(example_input, example_class)
+    assert style.shape == (4, 64)
+    out = nets.generator(example_input, style)
+    assert out.shape == (4, 1, 128, 128)
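
For reviewers, a minimal usage sketch tying the new pieces together (not part of the patch). The directory names, domain count, and image size below are placeholder assumptions mirroring the 1-channel, 128x128, 6-class setup used in the tests, and the `num_domains`/`latent_dim` arguments refer to the amended `Logger` signature above.

```python
import torch
from quac.training.stargan import build_model
from quac.training.utils import Logger, plot_metrics

# Build the StarGAN components (nets) and their EMA copies (nets_ema).
nets, nets_ema = build_model(img_size=128, style_dim=64, input_dim=1, num_domains=6)

# Periodically dump sample grids during training. `inputs` is assumed to be a
# namespace (e.g. a Munch) holding x_src, y_src, x_ref and y_ref tensors.
logger = Logger("samples", nets, num_outs_per_domain=10, num_domains=6, latent_dim=16)
# logger.debug_image(inputs, step=1000)

# After training, summarize conversion/translation rates from the metric JSON files.
plot_metrics("metrics", show=False, save="metrics.png")
```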