diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a2bf1d6..dc763e9 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -19,7 +19,7 @@ repos: hooks: - id: black - - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.0.1 - hooks: - - id: mypy + # - repo: https://github.com/pre-commit/mirrors-mypy + # rev: v1.0.1 + # hooks: + # - id: mypy diff --git a/README.md b/README.md index 1a46bf0..9e27d68 100644 --- a/README.md +++ b/README.md @@ -4,4 +4,15 @@ -Logo made with the help of DALL-E 2. \ No newline at end of file +Logo made with the help of DALL-E 2. + +Installing: +1. Clone this repository +2. Create a `conda` environment with `python, pytorch, torchvision`; I recommend `mamba` +3. Activate your new environment (`mamba activate ...`) +4. Change into the directory holding this repository. +5. `pip install .` + +Installing as developper: +1. - 4. Same as above. +5. `pip install -e .\[dev\]` diff --git a/docs/tutorials/attribute.md b/docs/tutorials/attribute.md new file mode 100644 index 0000000..abb97c3 --- /dev/null +++ b/docs/tutorials/attribute.md @@ -0,0 +1,83 @@ +# Attribution and evaluation given counterfactuals + +## Attribution +```python +# Load the classifier +from quac.generate import load_classifier +classifier = load_classifier( + +) + +# Defining attributions +from quac.attribution import ( + DDeepLift, + DIntegratedGradients, + AttributionIO +) +from torchvision import transforms + +attributor = AttributionIO( + attributions = { + "deeplift" : DDeepLift(), + "ig" : DIntegratedGradients() + }, + output_directory = "my_attributions_directory" +) + +transform = transforms.Compose( + [ + transforms.Resize(224), + transforms.CenterCrop(224), + transforms.Normalize(...) + ] +) + +# This will run attributions and store all of the results in the output_directory +# Shows a progress bar +attributor.run( + source_directory="my_source_image_directory", + counterfactual_directory="my_counterfactual_image_directory", + transform=transform +) +``` + +## Evaluation +Once you have attributions, you can run evaluations. +You may want to try different methods for thresholding and smoothing the attributions to get masks. + + +In this example, we evaluate the results from the DeepLift attribution method. + +```python +# Defining processors and evaluators +from quac.evaluation import Processor, Evaluator +from sklearn.metrics import ConfusionMatrixDisplay + +classifier = load_classifier(...) 
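+# Note: `load_classifier(...)` above is a placeholder. As in the attribution step, it takes the
+# path to a torchscript classifier checkpoint plus the normalization used during training,
+# e.g. (hypothetical path): classifier = load_classifier("/path/to/classifier.pt", mean=0.5, std=0.5)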
+ +evaluator = Evaluator( + classifier, + source_directory="my_source_image_directory", + counterfactual_directory="my_counterfactual_image_directory", + attribution_directory="my_attributions_directory/deeplift", + transform=transform +) + + +cf_confusion_matrix = evaluator.classification_report( + data="counterfactuals", # this is the default + return_classifications=False, + print_report=True, + ) + +# Plot the confusion matrix +disp = ConfusionMatrixDisplay( + confusion_matrix=cf_confusion_matrix, +) +disp.show() + +# Run QuAC evaluation on your attribution and store a report +report = evaluator.quantify(processor=Processor()) +# The report will be stored based on the processor's name, which is "default" by default +report.store("my_attributions_directory/deeplift/reports") +``` diff --git a/docs/tutorials/train.md b/docs/tutorials/train.md new file mode 100644 index 0000000..9e18aee --- /dev/null +++ b/docs/tutorials/train.md @@ -0,0 +1,151 @@ +# Training the StarGAN + +In this tutorial, we go over the basics of how to train a (slightly modified) StarGAN for use in QuAC. + +## Defining the dataset + +The data is expected to be in the form of image files with a directory structure denoting the classification. +For example: +``` +data_folder/ + crow/ + crow1.png + crow2.png + raven/ + raven1.png + raven2.png +``` + +A training dataset is defined in `quac.training.data` which will need to be given two directories: a `source` and a `reference`. These directories can be the same. + +The validation dataset will need the same information. + +For example: +```python +from quac.training.data import TrainingDataset + +dataset = TrainingDataset( + source="path/to/training/data", + reference="path/to/training/data", + img_size=128, + batch_size=4, + num_workers=4 +) + +# Setup data for validation +val_dataset = ValidationData( + source="path/to/training/data", + reference="path/to/training/data", + img_size=128, + batch_size=16, + num_workers=16 +) + +``` +## Defining the models + +The models can be built using a function in `quac.training.stargan`. + +```python +from quac.training.stargan import build_model + +nets, nets_ema = build_model( + img_size=256, # Images are made square + style_dim=64, # The size of the style vector + input_dim=1, # Number of channels in the input + latent_dim=16, # The size of the random latent + num_domains=4, # Number of classes + single_output_style_encoder=False +) +## Defining the models +nets, nets_ema = build_model(**experiment.model.model_dump()) + +``` + +If using multiple or specific GPUs, it may be necessary to add the `gpu_ids` argument. + +The `nets_ema` are a copy of the `nets` that will not be trained but rather will be an exponential moving average of the weight of the `nets`. +The sub-networks of both can be accessed in a dictionary-like manner. + +## Creating a logger +```python +# Example using WandB +logger = Logger.create( + log_type="wandb", + project="project-name", + name="experiment name", + tags=["experiment", "project", "test", "quac", "stargan"], + hparams={ # this holds all of the hyperparameters you want to store for your run + "hyperparameter_key": "Hyperparameter values" + } +) + +# TODO example using tensorboard +``` + +## Defining the Solver + +It is now time to initiate the `Solver` object, which will do the bulk of the work in training. 
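+The keyword arguments below largely mirror the fields of `SolverConfig` in `quac.training.config`
+(learning rates, Adam betas, weight decay); the values shown are illustrative rather than required.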
+ +```python +solver = Solver( + nets, + nets_ema, + # Checkpointing + checkpoint_dir="path/to/store/checkpoints", + # Parameters for the Adam optimizers + lr=1e-4, + beta1=0.5, + beta2=0.99, + weight_decay=0.1, +) + +# TODO +solver = Solver(nets, nets_ema, **experiment.solver.model_dump(), run=logger) +``` + +## Training +We use the solver to train on the data as follows: + +```python +from quac.training.options import ValConfig +val_config=ValConfig( + classifier_checkpoint="/path/to/classifier/", mean=0.5, std=0.5 +) + +solver.train(dataset, val_config) +``` + +All results will be stored in the `checkpoint_directory` defined above. +Validation will be done during training at regular intervals (by default, every 10000 iterations). + +## BONUS: Training with a Config file + +```python +run_config=RunConfig( + # All of these are default + resume_iter=0, + total_iter=100000, + log_every=1000, + save_every=10000, + eval_every=10000, +) +val_config=ValConfig( + classifier_checkpoint="/path/to/classifier/", + # The below is default + val_batch_size=32 + num_outs_per_domain=10, + mean=0.5, + std=0.5, + grayscale=True, +) +loss_config=LossConfig( + # The following should probably not be changed + # unless you really know what you're doing :) + # All of these are default + lambda_ds=1., + lambda_reg=1., + lambda_sty=1., + lambda_cyc=1., +) +``` diff --git a/pyproject.toml b/pyproject.toml index 59a1c12..a843ada 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,12 +18,15 @@ authors = [ dynamic = ["version"] dependencies = [ "captum", - "numpy", - "torch", + "numpy", + "munch", + "torch", "torchvision", "funlib.learn.torch@git+https://github.com/funkelab/funlib.learn.torch", "opencv-python", - "scipy" + "pydantic", + "scipy", + "scikit-learn" ] [project.optional-dependencies] diff --git a/src/quac/attribution.py b/src/quac/attribution.py index 233c483..7743671 100644 --- a/src/quac/attribution.py +++ b/src/quac/attribution.py @@ -1,13 +1,18 @@ """Holds all of the discriminative attribution methods that are accepted by QuAC.""" -from captum import attr -import torch + +from captum import attr import numpy as np import scipy +from pathlib import Path +from quac.data import PairedImageDataset +from tqdm import tqdm +import torch +from typing import Callable def residual(real_img, fake_img): - """Residual attribution method. - + """Residual attribution method. + This method just takes the standardized difference between the real and fake images. """ res = np.abs(real_img - fake_img) @@ -18,7 +23,7 @@ def residual(real_img, fake_img): def random(real_img, fake_img): """Random attribution method. - This method randomly assigns attribution to each pixel in the image, then applies a Gaussian filter for smoothing. + This method randomly assigns attribution to each pixel in the image, then applies a Gaussian filter for smoothing. """ rand = np.abs(np.random.randn(*np.shape(real_img))) rand = np.abs(scipy.ndimage.gaussian_filter(rand, 4)) @@ -31,60 +36,109 @@ class BaseAttribution: """ Basic format of an attribution class. """ - def __init__(self, classifier): + + def __init__(self, classifier, normalize=True): self.classifier = classifier - - def _attribute(self, real_img, counterfactual_img, real_class, target_class, **kwargs): - raise NotImplementedError("The base attribution class does not have an attribute method.") + self.normalize = normalize + + def _normalize(self, attribution): + """Scale the attribution to be between 0 and 1. 
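+        Concretely, values are rescaled per sample as (a - min) / (max - min),
+        with the minimum and maximum taken over all channels and pixels of that sample.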
+ + Note that this also takes the absolute value of the attribution. + Generally in this framework, we only care about the absolute value of the attribution, + because if "negative changes" need to be made, this should be inherent in + the counterfactual image. + """ + attribution = torch.abs(attribution) + # We scale the attribution to be between 0 and 1, batch-wise + min_vals = attribution.flatten(1).min(1)[0][:, None, None, None] + max_vals = attribution.flatten(1).max(1)[0][:, None, None, None] + return (attribution - min_vals) / (max_vals - min_vals) + + def _attribute( + self, real_img, counterfactual_img, real_class, target_class, **kwargs + ): + raise NotImplementedError( + "The base attribution class does not have an attribute method." + ) - def attribute(self, real_img, counterfactual_img, real_class, target_class, **kwargs): + def attribute( + self, + real_img, + counterfactual_img, + real_class, + target_class, + device="cuda", + **kwargs, + ): self.classifier.zero_grad() - attribution = self._attribute(real_img, counterfactual_img, real_class, target_class, **kwargs) - return attribution.detach().cpu().numpy() + # Check if there is a batch dimension, if not, add it + batch_added = False + if len(real_img.shape) == 3: + real_img = real_img[None, ...] + counterfactual_img = counterfactual_img[None, ...] + batch_added = True + + attribution = self._attribute( + real_img.to(device), + counterfactual_img.to(device), + real_class, + target_class, + **kwargs, + ) + attribution = attribution.detach() + if self.normalize: + attribution = self._normalize(attribution) + if batch_added: + attribution = attribution[0] + return attribution.cpu().numpy() class DIntegratedGradients(BaseAttribution): """ Discriminative version of the Integrated Gradients attribution method. """ - def __init__(self, classifier): - super().__init__(classifier) + + def __init__(self, classifier, normalize=True): + super().__init__(classifier, normalize=normalize) self.ig = attr.IntegratedGradients(classifier) def _attribute(self, real_img, counterfactual_img, real_class, target_class): # FIXME in the original DAPI code, the real and counterfactual were switched. attribution = self.ig.attribute( - real_img[None, ...].cuda(), - baselines=counterfactual_img[None, ...].cuda(), - target=real_class + real_img, + baselines=counterfactual_img, + target=real_class, ) - return attribution[0] + return attribution class DDeepLift(BaseAttribution): """ Discriminative version of the DeepLift attribution method. """ - def __init__(self, classifier): - super().__init__(classifier) + + def __init__(self, classifier, normalize=True): + super().__init__(classifier, normalize=normalize) self.dl = attr.DeepLift(classifier) def _attribute(self, real_img, counterfactual_img, real_class, target_class): # FIXME in the original DAPI code, the real and counterfactual were switched. attribution = self.dl.attribute( - real_img[None, ...].cuda(), - baselines=counterfactual_img[None, ...].cuda(), - target=real_class + real_img, + baselines=counterfactual_img, + target=real_class, ) - return attribution[0] + return attribution class DInGrad(BaseAttribution): """ Discriminative version of the InputxGradient attribution method. 
""" - def __init__(self, classifier): - super().__init__(classifier) + + def __init__(self, classifier, normalize=True): + super().__init__(classifier, normalize=normalize) self.saliency = attr.Saliency(self.classifier) def _attribute(self, real_img, counterfactual_img, real_class, target_class): @@ -92,7 +146,90 @@ def _attribute(self, real_img, counterfactual_img, real_class, target_class): # grads_fake = self.saliency.attribute(counterfactual_img, # target=target_class) # ingrad_diff_0 = grads_fake * (real_img - counterfactual_img) - grads_real = self.saliency.attribute(real_img[None, ...].cuda(), - target=real_class).detach().cpu() - ingrad_diff_1 = grads_real * (counterfactual_img[None, ...] - real_img[None, ...]) - return ingrad_diff_1[0] \ No newline at end of file + grads_real = self.saliency.attribute(real_img, target=real_class).detach().cpu() + ingrad_diff_1 = grads_real * (counterfactual_img - real_img) + return ingrad_diff_1 + + +class VanillaIntegratedGradients(BaseAttribution): + """Wrapper class for Integrated Gradients from Captum. + + Allows us to use it as a baseline. + """ + + def __init__(self, classifier, normalize=True): + super().__init__(classifier, normalize=normalize) + self.ig = attr.IntegratedGradients(classifier) + + def _attribute(self, real_img, counterfactual_img, real_class, target_class): + batched_attribution = ( + self.ig.attribute(real_img, target=real_class).detach().cpu() + ) + return batched_attribution + + +class VanillaDeepLift(BaseAttribution): + """Wrapper class for DeepLift from Captum. + + Allows us to use it as a baseline. + """ + + def __init__(self, classifier, normalize=True): + super().__init__(classifier, normalize=normalize) + self.dl = attr.DeepLift(classifier) + + def _attribute(self, real_img, counterfactual_img, real_class, target_class): + batched_attribution = ( + self.dl.attribute(real_img, target=real_class).detach().cpu() + ) + return batched_attribution + + +class AttributionIO: + """ + Running the attribution methods on the images. + Storing the results in the output directory. 
+ + """ + + def __init__(self, attributions: dict[str, BaseAttribution], output_directory: str): + self.attributions = attributions + self.output_directory = Path(output_directory) + + def get_directory(self, attr_name: str, source_class: str, target_class: str): + directory = self.output_directory / f"{attr_name}/{source_class}/{target_class}" + directory.mkdir(parents=True, exist_ok=True) + return directory + + def run( + self, + source_directory: str, + counterfactual_directory: str, + transform: Callable, + device: str = "cuda", + ): + if device == "cuda": + if not torch.cuda.is_available(): + raise ValueError("CUDA is not available on this machine.") + print("Loading paired data") + dataset = PairedImageDataset( + source_directory, counterfactual_directory, transform=transform + ) + print("Running attributions") + for sample in tqdm(dataset, total=len(dataset)): + for attr_name, attribution in self.attributions.items(): + attr = attribution.attribute( + sample.image, + sample.counterfactual, + sample.source_class_index, + sample.target_class_index, + device=device, + ) + # Store the attribution + np.save( + self.get_directory( + attr_name, sample.source_class, sample.target_class + ) + / f"{sample.path.stem}.npy", + attr, + ) diff --git a/src/quac/data.py b/src/quac/data.py index a789a73..56efa06 100644 --- a/src/quac/data.py +++ b/src/quac/data.py @@ -1,4 +1,8 @@ +from dataclasses import dataclass +import numpy as np import os +from pathlib import Path +import torch from torch.utils.data import Dataset from torchvision.datasets.folder import ( default_loader, @@ -23,22 +27,12 @@ def find_classes(directory): return classes, class_to_idx -def make_dataset( +def check_requirements( directory: str, - paired_directory: str, class_to_idx: Optional[Dict[str, int]] = None, extensions: Optional[Union[str, Tuple[str, ...]]] = None, is_valid_file: Optional[Callable[[str], bool]] = None, -) -> List[Tuple[str, str, int, int]]: - """Generates a list of samples of a form (path_to_sample, class). - - See :class:`DatasetFolder` for details. - - Note: The class_to_idx parameter is here optional and will use the logic of the ``find_classes`` function - by default. - """ - directory = os.path.expanduser(directory) - +): if class_to_idx is None: _, class_to_idx = find_classes(directory) elif not class_to_idx: @@ -46,6 +40,9 @@ def make_dataset( "'class_to_index' must have at least one entry to collect any samples." ) + both_none = extensions is None and is_valid_file is None + both_something = extensions is not None and is_valid_file is not None + both_none = extensions is None and is_valid_file is None both_something = extensions is not None and is_valid_file is not None if both_none or both_something: @@ -60,6 +57,112 @@ def is_valid_file(x: str) -> bool: is_valid_file = cast(Callable[[str], bool], is_valid_file) + return is_valid_file + + +def make_counterfactual_dataset( + counterfactual_directory: str, + class_to_idx: Optional[Dict[str, int]] = None, + extensions: Optional[Union[str, Tuple[str, ...]]] = None, + is_valid_file: Optional[Callable[[str], bool]] = None, + allow_empty: bool = False, +) -> List[Tuple[str, int]]: + """Generates a list of samples of a form (path_to_sample, class) + for data organized in a counterfactual style directory. + + The dataset is organized in the following way: + ``` + root_directory/ + ├── class_x + | └── class_y + │ ├── xxx.ext + │ ├── xxy.ext + │ └── xxz.ext + └── class_y + └── class_x + ├── 123.ext + ├── nsdf3.ext + └── ... 
+ └── asd932_.ext + ``` + + We want to use the most nested subdirectories as the class labels. + """ + directory = os.path.expanduser(counterfactual_directory) + if class_to_idx is None: + _, class_to_idx = find_classes(directory) + + is_valid_file = check_requirements( + counterfactual_directory, + class_to_idx, + extensions, + is_valid_file, + ) + + instances = [] + available_classes = set() + for source_class in sorted(class_to_idx.keys()): + source_dir = os.path.join(directory, source_class) + if not os.path.isdir(source_dir): + continue + target_directories = {} + for target_class in sorted(class_to_idx.keys()): + if target_class == source_class: + continue + target_dir = os.path.join(directory, source_class, target_class) + if os.path.isdir(target_dir): + target_directories[target_class] = target_dir + for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)): + assert source_class != target_class + for fname in sorted(fnames): + path = os.path.join(root, fname) + if is_valid_file(path): + item = ( + path, + class_to_idx[source_class], + class_to_idx[target_class], + ) + instances.append(item) + + if target_class not in available_classes: + available_classes.add(source_class) + + empty_classes = set(class_to_idx.keys()) - available_classes + if empty_classes and not allow_empty: + msg = ( + f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. " + ) + if extensions is not None: + msg += f"Supported extensions are: {extensions if isinstance(extensions, str) else ', '.join(extensions)}" + raise FileNotFoundError(msg) + + return instances + + +def make_paired_dataset( + directory: str, + paired_directory: str, + class_to_idx: Optional[Dict[str, int]], + extensions: Optional[Union[str, Tuple[str, ...]]] = None, + is_valid_file: Optional[Callable[[str], bool]] = None, + allow_empty: bool = False, +) -> List[Tuple[str, str, int, int]]: + """Generates a list of samples of a form (path_to_sample, class). + + See :class:`DatasetFolder` for details. + + Note: The class_to_idx parameter is here optional and will use the logic of the ``find_classes`` function + by default. 
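+    Each returned sample is a tuple of the form
+    ``(source_path, counterfactual_path, source_class_index, target_class_index)``.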
+ """ + directory = os.path.expanduser(directory) + + if class_to_idx is None: + _, class_to_idx = find_classes(directory) + + is_valid_file = check_requirements( + directory, class_to_idx, extensions, is_valid_file + ) + instances = [] available_classes = set() for source_class in sorted(class_to_idx.keys()): @@ -74,26 +177,117 @@ def is_valid_file(x: str) -> bool: target_dir = os.path.join(paired_directory, source_class, target_class) if os.path.isdir(target_dir): target_directories[target_class] = target_dir - for root, _, fnames in sorted(os.walk(source_dir, followlinks=True)): - for fname in sorted(fnames): - path = os.path.join(root, fname) - if is_valid_file(path): - for target_class, target_dir in target_directories.items(): - target_path = os.path.join(target_dir, fname) - if os.path.isfile(target_path) and is_valid_file(target_path): - item = ( - path, - target_path, - class_index, - class_to_idx[target_class], + for root, _, fnames in sorted(os.walk(source_dir, followlinks=True)): + for fname in sorted(fnames): + path = os.path.join(root, fname) + if is_valid_file(path): + for target_class, target_dir in target_directories.items(): + target_path = os.path.join(target_dir, fname) + if os.path.isfile(target_path) and is_valid_file( + target_path + ): + item = ( + path, + target_path, + class_index, + class_to_idx[target_class], + ) + instances.append(item) + + if source_class not in available_classes: + available_classes.add(source_class) + + empty_classes = set(class_to_idx.keys()) - available_classes + if empty_classes and not allow_empty: + msg = ( + f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. " + ) + if extensions is not None: + msg += f"Supported extensions are: {extensions if isinstance(extensions, str) else ', '.join(extensions)}" + raise FileNotFoundError(msg) + + return instances + + +def make_paired_attribution_dataset( + directory: str, + paired_directory: str, + attribution_directory: str, + class_to_idx: Optional[Dict[str, int]] = None, + extensions: Optional[Union[str, Tuple[str, ...]]] = None, + is_valid_file: Optional[Callable[[str], bool]] = None, + allow_empty: bool = False, +) -> List[Tuple[str, str, int, int]]: + """Generates a list of samples of a form (path_to_sample, class). + + See :class:`DatasetFolder` for details. + + Note: The class_to_idx parameter is here optional and will use the logic of the ``find_classes`` function + by default. + """ + directory = os.path.expanduser(directory) + + if class_to_idx is None: + _, class_to_idx = find_classes(directory) + + is_valid_file = check_requirements( + directory, class_to_idx, extensions, is_valid_file + ) + + instances = [] + available_classes = set() + for source_class in sorted(class_to_idx.keys()): + class_index = class_to_idx[source_class] + source_dir = os.path.join(directory, source_class) + if not os.path.isdir(source_dir): + continue + target_directories = {} + attribution_directories = {} + for target_class in sorted(class_to_idx.keys()): + if target_class == source_class: + continue + target_dir = os.path.join(paired_directory, source_class, target_class) + if os.path.isdir(target_dir): + target_directories[target_class] = target_dir + # Add the attribution directory as well. 
It is organized in the same way + # as the paire directory is + attribution_dir = os.path.join( + attribution_directory, source_class, target_class + ) + if os.path.isdir(attribution_dir): + attribution_directories[target_class] = attribution_dir + + for root, _, fnames in sorted(os.walk(source_dir, followlinks=True)): + for fname in sorted(fnames): + path = os.path.join(root, fname) + if is_valid_file(path): + for target_class, target_dir in target_directories.items(): + target_path = os.path.join(target_dir, fname) + # attribution path must replace the extension to npy + attr_filename = fname.split(".")[-2] + ".npy" + attr_path = os.path.join( + attribution_directories[target_class], attr_filename ) - instances.append(item) - if source_class not in available_classes: - available_classes.add(source_class) + if ( + os.path.isfile(target_path) + and is_valid_file(target_path) + and os.path.isfile(attr_path) + ): + item = ( + path, + target_path, + attr_path, + class_index, + class_to_idx[target_class], + ) + instances.append(item) + + if source_class not in available_classes: + available_classes.add(source_class) empty_classes = set(class_to_idx.keys()) - available_classes - if empty_classes: + if empty_classes and not allow_empty: msg = ( f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. " ) @@ -104,8 +298,56 @@ def is_valid_file(x: str) -> bool: return instances -class PairedImageFolders(Dataset): - def __init__(self, source_directory, paired_directory, transform=None): +@dataclass +class Sample: + image: torch.Tensor + source_class_index: int + path: Path = None + source_class: str = None + + +# TODO remove? +@dataclass +class CounterfactualSample: + counterfactual: torch.Tensor + target_class_index: int + source_class_index: int + path: Path = None + counterfactual_path: Path = None + source_class: str = None + target_class: str = None + + +@dataclass +class PairedSample: + image: torch.Tensor + counterfactual: torch.Tensor + source_class_index: int + target_class_index: int + path: Path = None + counterfactual_path: Path = None + source_class: str = None + target_class: str = None + + +@dataclass +class SampleWithAttribution: + attribution: np.ndarray + image: torch.Tensor + counterfactual: torch.Tensor + source_class_index: int + target_class_index: int + path: Path = None + counterfactual_path: Path = None + source_class: str = None + target_class: str = None + attribution_path: Path = None + + +class PairedImageDataset(Dataset): + def __init__( + self, source_directory, paired_directory, transform=None, allow_empty=True + ): """A dataset that loads images from paired directories, where one has images generated based on the other. @@ -138,16 +380,22 @@ def __init__(self, source_directory, paired_directory, transform=None): └── ... └── asd932_.ext ``` - Note that this will not work if the file names do not match! + note:: this will not work if the file names do not match! + + note:: the transform is applied sequentially to the image, counterfactual, and attribution. + This means that if there is any randomness in the transform, the three images will faie to match. + Additionally, the attribution will be a torch tensor when the transform is applied, so no PIL-only transforms + can be used. 
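+        Each item is returned as a ``PairedSample`` holding the image, its counterfactual,
+        the source and target class indices and names, and the corresponding file paths.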
""" classes, class_to_idx = find_classes(source_directory) self.classes = classes self.class_to_idx = class_to_idx - self.samples = make_dataset( + self.samples = make_paired_dataset( source_directory, paired_directory, class_to_idx, is_valid_file=is_image_file, + allow_empty=allow_empty, ) self.transform = transform @@ -159,14 +407,111 @@ def __getitem__(self, index): if self.transform is not None: sample = self.transform(sample) target_sample = self.transform(target_sample) - output = { - "sample_path": path, - "target_path": target_path, - "sample": sample, - "target_sample": target_sample, - "class_index": class_index, - "target_class_index": target_class_index, - } + output = PairedSample( + path=Path(path), + counterfactual_path=Path(target_path), + image=sample, + counterfactual=target_sample, + source_class_index=class_index, + target_class_index=target_class_index, + source_class=self.classes[class_index], + target_class=self.classes[target_class_index], + ) + return output + + def __len__(self): + return len(self.samples) + + +class CounterfactualDataset(Dataset): + def __init__(self, counterfactual_directory, transform=None, allow_empty=True): + classes, class_to_idx = find_classes(counterfactual_directory) + self.classes = classes + self.class_to_idx = class_to_idx + self.samples = make_counterfactual_dataset( + counterfactual_directory, + class_to_idx, + is_valid_file=is_image_file, + allow_empty=allow_empty, + ) + self.transform = transform + + def __getitem__(self, index): + path, source_class_index, target_class_index = self.samples[index] + sample = default_loader(path) + if self.transform is not None: + sample = self.transform(sample) + output = CounterfactualSample( + counterfactual_path=Path(path), + counterfactual=sample, + source_class_index=source_class_index, + source_class=self.classes[source_class_index], + target_class_index=target_class_index, + target_class=self.classes[target_class_index], + ) + return output + + def __len__(self): + return len(self.samples) + + +class PairedWithAttribution(Dataset): + """This dataset returns both the original and counterfactual images, + as well as an attribution heatmap. + + note:: the transform is applied sequentially to the image, counterfactual. + This means that if there is any randomness in the transform, the images will fail to match. + Additionally, no transform is applied to the attribution. 
+ """ + + def __init__( + self, + source_directory, + paired_directory, + attribution_directory, + transform=None, + allow_empty=True, + ): + classes, class_to_idx = find_classes(source_directory) + self.classes = classes + self.class_to_idx = class_to_idx + self.samples = make_paired_attribution_dataset( + source_directory, + paired_directory, + attribution_directory, + class_to_idx, + is_valid_file=is_image_file, + allow_empty=allow_empty, + ) + self.transform = transform + + def __getitem__(self, index) -> SampleWithAttribution: + ( + path, + target_path, + attribution_path, + class_index, + target_class_index, + ) = self.samples[index] + sample = default_loader(path) + target_sample = default_loader(target_path) + attribution = np.load(attribution_path) + if self.transform is not None: + sample = self.transform(sample) + target_sample = self.transform(target_sample) + + output = SampleWithAttribution( + path=Path(path), + counterfactual_path=Path(target_path), + attribution_path=Path(attribution_path), + image=sample, + counterfactual=target_sample, + attribution=attribution, + source_class_index=class_index, + target_class_index=target_class_index, + source_class=self.classes[class_index], + target_class=self.classes[target_class_index], + ) return output def __len__(self): diff --git a/src/quac/evaluation.py b/src/quac/evaluation.py index c6058bd..4608176 100644 --- a/src/quac/evaluation.py +++ b/src/quac/evaluation.py @@ -1,7 +1,14 @@ +import cv2 +from dataclasses import dataclass import numpy as np +from pathlib import Path +from quac.data import PairedImageDataset, CounterfactualDataset, PairedWithAttribution +from quac.report import Report +from sklearn.metrics import classification_report, confusion_matrix +from torchvision.datasets import ImageFolder from torch.nn import functional as F -import cv2 import torch +from tqdm import tqdm def image_to_tensor(image, device=None): @@ -12,36 +19,236 @@ def image_to_tensor(image, device=None): image_tensor = image_tensor.unsqueeze(0).unsqueeze(0) elif len(np.shape(image)) == 3: image_tensor = image_tensor.unsqueeze(0) + elif len(np.shape(image)) == 4: + return image_tensor.float() else: - raise ValueError("Input shape not understood") + raise ValueError(f"Input shape not understood, {image.shape}") return image_tensor.float() -class Evaluator: - """This class evaluates the quality of an attribution using the QuAC method. 
+class Processor: + """Class that turns attributions into masks.""" + + def __init__( + self, gaussian_kernel_size=11, struc=10, channel_wise=True, name="default" + ): + self.gaussian_kernel_size = gaussian_kernel_size + self.kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (struc, struc)) + self.channel_wise = channel_wise + self.name = name + + def create_mask(self, attribution, threshold, return_size=True): + channels, _, _ = attribution.shape + mask_size = 0 + mask = [] + # construct mask channel by channel + for c in range(channels): + # threshold + if self.channel_wise: + channel_mask = attribution[c, :, :] > threshold + else: + channel_mask = np.any(attribution > threshold, axis=0) + # TODO explain the reasoning behind the morphological closing + # Morphological closing + channel_mask = cv2.morphologyEx( + channel_mask.astype(np.uint8), cv2.MORPH_CLOSE, self.kernel + ) + # TODO This might be misleading, given the blur afterwards + mask_size += np.sum(channel_mask) + # TODO Add connected components + # Blur + mask.append( + cv2.GaussianBlur( + channel_mask.astype(np.float32), + (self.gaussian_kernel_size, self.gaussian_kernel_size), + 0, + ) + ) + # TODO should we do this instead? + # mask_size += np.sum(mask) + if not return_size: + return np.array(mask) + return np.array(mask), mask_size + - It it is based on the assumption that there exists a counterfactual for each image. +class UnblurredProcessor(Processor): """ + Processor without any blurring + """ + + def __init__(self, struc=10, channel_wise=True, name="no_blur"): + self.kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (struc, struc)) + self.channel_wise = channel_wise + self.name = name + + def create_mask(self, attribution, threshold, return_size=True): + channels, _, _ = attribution.shape + mask_size = 0 + mask = [] + # construct mask channel by channel + for c in range(channels): + # threshold + if self.channel_wise: + channel_mask = attribution[c, :, :] > threshold + else: + channel_mask = np.any(attribution > threshold, axis=0) + # Morphological closing + channel_mask = cv2.morphologyEx( + channel_mask.astype(np.uint8), cv2.MORPH_CLOSE, self.kernel + ) + mask.append(channel_mask) + mask_size += np.sum(channel_mask) + if not return_size: + return np.array(mask) + return np.array(mask), mask_size + + +class BaseEvaluator: + """Base class for evaluating attributions.""" def __init__( self, classifier, - sigma=11, - struc=10, - channel_wise=False, + source_dataset=None, + paired_dataset=None, + attribution_dataset=None, num_thresholds=200, device=None, ): + """Initializes the evaluator. + + It requires three different datasets: the source dataset, the counterfactual dataset and the attribution dataset. + All of them must return objects in the forms of the dataclasses in `quac.data`. + + + Parameters + ---------- + classifier: + The classifier to be used for the evaluation. + source_dataset: + The source dataset must returns a `quac.data.Sample` object in its `__getitem__` method. + paired_dataset: + The paired dataset must returns a `quac.data.PairedSample` object in its `__getitem__` method. + attribution_dataset: + The attribution dataset must returns a `quac.data.SampleWithAttribution` object in its `__getitem__` method. 
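+        num_thresholds:
+            The number of thresholds used when sweeping the attribution to build masks. Defaults to 200.
+        device:
+            The device used for inference; defaults to CUDA when available, otherwise CPU.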
+ """ self.device = device if device is None: self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.classifier = classifier.to(self.device) - self.sigma = sigma - self.struc = struc - self.channel_wise = channel_wise self.num_thresholds = num_thresholds + self.classifier = classifier + + self._source_dataset = source_dataset + self._paired_dataset = paired_dataset + self._dataset_with_attribution = attribution_dataset + + @property + def source_dataset(self): + return self._source_dataset + + @property + def paired_dataset(self): + return self._paired_dataset - def evaluate(self, x, x_t, y, y_t, attribution, predictions, vmin=-1, vmax=1): + @property + def dataset_with_attribution(self): + return self._dataset_with_attribution + + def _source_classification_report( + self, return_classification=False, print_report=True + ): + """ + Classify the source data and return the confusion matrix. + """ + pred = [] + target = [] + for image, source_class_index in tqdm(self.source_dataset): + pred.append(self.run_inference(image).argmax()) + target.append(source_class_index) + + if print_report: + print(classification_report(target, pred)) + + cm = confusion_matrix(target, pred, normalize="true") + if return_classification: + return cm, pred, target + return cm + + def _counterfactual_classification_report( + self, + return_classification=False, + print_report=True, + ): + """ + Classify the counterfactual data and return the confusion matrix. + """ + pred = [] + source = [] + target = [] + for sample in tqdm(self.counterfactual_dataset): + pred.append(self.run_inference(sample.counterfactual).argmax()) + target.append(sample.target_class_index) + source.append(sample.source_class_index) + + if print_report: + print(classification_report(target, pred)) + + cm = confusion_matrix(target, pred, normalize="true") + if return_classification: + return cm, pred, source, target + return cm + + def classification_report( + self, + data="counterfactuals", + return_classification=False, + print_report=True, + ): + """ + Classify the data and return the confusion matrix. + """ + if data == "counterfactuals": + return self._counterfactual_classification_report( + return_classification=return_classification, + print_report=print_report, + ) + elif data == "source": + return self._source_classification_report( + return_classification=return_classification, + print_report=print_report, + ) + else: + raise ValueError(f"Data must be 'counterfactuals' or 'source', not {data}") + + def quantify(self, processor=None): + if processor is None: + processor = Processor() + report = Report(name=processor.name) + for inputs in tqdm(self.dataset_with_attribution): + predictions = { + "original": self.run_inference(inputs.image)[0], + "counterfactual": self.run_inference(inputs.counterfactual)[0], + } + results = self.evaluate( + inputs.image, + inputs.counterfactual, + inputs.source_class_index, + inputs.target_class_index, + inputs.attribution, + predictions, + processor, + ) + report.accumulate( + inputs, + predictions, + results, + ) + return report + + def evaluate( + self, x, x_t, y, y_t, attribution, predictions, processor, vmin=-1, vmax=1 + ): """ Run QuAC evaluation on the data point. 
@@ -51,7 +258,9 @@ def evaluate(self, x, x_t, y, y_t, attribution, predictions, vmin=-1, vmax=1): x_t: the counterfactual image y: the class of the input image y_t: the class of the counterfactual image - attrihbution: the attribution map + attribution: the attribution map + predictions: the predictions of the classifier + processor: the attribution processing function (to get mask) vmin: the minimal possible value of the attribution, to be used for thresholding. Defaults to -1 vmax: the maximal possible value of the attribution, to be used for thresholding. Defaults to 1. """ @@ -59,6 +268,7 @@ def evaluate(self, x, x_t, y, y_t, attribution, predictions, vmin=-1, vmax=1): # "real" changes into "fake_class" classification_real = predictions["original"] + # TODO remove the need for this results = { "thresholds": [], "hybrids": [], @@ -67,21 +277,23 @@ def evaluate(self, x, x_t, y, y_t, attribution, predictions, vmin=-1, vmax=1): } for threshold in np.arange(vmin, vmax, (vmax - vmin) / self.num_thresholds): # soft mask of the parts to copy - mask, mask_size = self.create_mask(attribution, threshold) - + mask, mask_size = processor.create_mask(attribution, threshold) # hybrid = real parts copied to fake hybrid = x_t * mask + x * (1.0 - mask) - - classification_hybrid = self.run_inference(hybrid)[0] - - score_change = classification_hybrid[y_t] - classification_real[y_t] - # Append results # TODO Do we want to store the hybrid? results["thresholds"].append(threshold) results["hybrids"].append(hybrid) results["mask_sizes"].append(mask_size / np.prod(x.shape)) - results["score_change"].append(score_change) + + # Classification + # classification_hybrid = self.run_inference(hybrid)[0] + # score_change = classification_hybrid[y_t] - classification_real[y_t] + # results["score_change"].append(score_change) + hybrid = np.stack(results["hybrids"], axis=0) + classification_hybrid = self.run_inference(hybrid) + score_change = classification_hybrid[:, y_t] - classification_real[y_t] + results["score_change"] = score_change return results @torch.no_grad() @@ -94,30 +306,71 @@ def run_inference(self, im): class_probs = F.softmax(self.classifier(im_tensor), dim=1).cpu().numpy() return class_probs - def create_mask(self, attribution, threshold): - channels, _, _ = attribution.shape - # TODO find a way to get rid of this - kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (self.struc, self.struc)) - mask_size = 0 - mask = [] - # construct mask channel by channel - for c in range(channels): - # threshold - if self.channel_wise: - channel_mask = attribution[c, :, :] > threshold - else: - channel_mask = np.any(attribution > threshold, axis=0) - # morphological closing - channel_mask = cv2.morphologyEx( - channel_mask.astype(np.uint8), cv2.MORPH_CLOSE, kernel - ) - # TODO This might be misleading, given the blur afterwards - mask_size += np.sum(channel_mask) - # TODO Add connected components - # Blur - mask.append( - cv2.GaussianBlur( - channel_mask.astype(np.float32), (self.sigma, self.sigma), 0 - ) - ) - return np.array(mask), mask_size + +class Evaluator(BaseEvaluator): + """This class evaluates the quality of an attribution using the QuAC method. + + Raises: + FileNotFoundError: If the source, counterfactual or attribution directories do not exist. 
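+
+    Example (directory names are placeholders, mirroring ``docs/tutorials/attribute.md``)::
+
+        evaluator = Evaluator(
+            classifier,
+            source_directory="my_source_image_directory",
+            counterfactual_directory="my_counterfactual_image_directory",
+            attribution_directory="my_attributions_directory/deeplift",
+            transform=transform,
+        )
+        report = evaluator.quantify(processor=Processor())
+        report.store("my_attributions_directory/deeplift/reports")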
+ """ + + def __init__( + self, + classifier, + source_directory, + counterfactual_directory, + attribution_directory, + transform=None, + num_thresholds=200, + device=None, + ): + # Check that they all exist + for directory in [ + source_directory, + counterfactual_directory, + attribution_directory, + ]: + if not Path(directory).exists(): + raise FileNotFoundError(f"Directory {directory} does not exist") + + super().__init__( + classifier, None, None, None, num_thresholds=num_thresholds, device=device + ) + self.transform = transform + self.source_directory = source_directory + self.counterfactual_directory = counterfactual_directory + self.attribution_directory = attribution_directory + + @property + def source_dataset(self): + # NOTE: Recomputed each time, but should be used sparingly. + dataset = ImageFolder(self.source_directory, transform=self.transform) + return dataset + + @property + def counterfactual_dataset(self): + # NOTE: Recomputed each time, but should be used sparingly. + dataset = CounterfactualDataset( + self.counterfactual_directory, transform=self.transform + ) + return dataset + + @property + def paired_dataset(self): + # NOTE: Recomputed each time, but should be used sparingly. + dataset = PairedImageDataset( + self.source_directory, + self.counterfactual_directory, + transform=self.transform, + ) + return dataset + + @property + def dataset_with_attribution(self): + dataset = PairedWithAttribution( + self.source_directory, + self.counterfactual_directory, + self.attribution_directory, + transform=self.transform, + ) + return dataset diff --git a/src/quac/generate/__init__.py b/src/quac/generate/__init__.py index 0d8c1e2..e81180c 100644 --- a/src/quac/generate/__init__.py +++ b/src/quac/generate/__init__.py @@ -1,4 +1,4 @@ -"""Utilities for generating counterfacual images.""" +"""Utilities for generating counterfactual images.""" from .model import LatentInferenceModel, ReferenceInferenceModel from .data import LabelFreePngFolder @@ -7,6 +7,7 @@ from quac.training.classification import ClassifierWrapper import torch from torchvision import transforms +from typing import Union logging.basicConfig(level=logging.WARNING) logger = logging.getLogger(__name__) @@ -16,7 +17,9 @@ class CounterfactualNotFound(Exception): pass -def load_classifier(checkpoint, mean=0.5, std=0.5, eval=True, device=None): +def load_classifier( + checkpoint, mean=0.5, std=0.5, eval=True, assume_normalized=False, device=None +): """ Load a classifier from a torchscript checkpoint. @@ -81,6 +84,7 @@ def load_stargan( checkpoint_iter: int = 100000, kind="latent", single_output_encoder: bool = False, + final_activation: Union[str, None] = None, ) -> torch.nn.Module: """ Load an inference version of the StarGANv2 model from a checkpoint. 
@@ -108,6 +112,7 @@ def load_stargan( latent_dim=latent_dim, num_domains=num_domains, single_output_encoder=single_output_encoder, + final_activation=final_activation, ) else: latent_inference_model = LatentInferenceModel( @@ -117,6 +122,7 @@ def load_stargan( style_dim=style_dim, latent_dim=latent_dim, num_domains=num_domains, + final_activation=final_activation, ) latent_inference_model.load_checkpoint(checkpoint_iter) latent_inference_model.eval() @@ -134,9 +140,12 @@ def get_counterfactual( batch_size=10, device=None, max_tries=100, - best_pred_so_far=0, + best_pred_so_far=None, best_cf_so_far=None, + best_cf_path_so_far=None, error_if_not_found=False, + return_path=False, + return_pred=False, ) -> torch.Tensor: """ Tries to find a counterfactual for the given sample, given the target. @@ -153,6 +162,8 @@ def get_counterfactual( device: the device to use max_tries: the maximum number of tries to find a counterfactual error_if_not_found: whether to raise an error if no counterfactual is found, if set to False, the best counterfactual found so far is returned + return_path: whether to return the path of the reference used to create best counterfactual found so far, + only used if kind is "reference" Returns: a counterfactual @@ -160,6 +171,8 @@ def get_counterfactual( Raises: CounterfactualNotFound: if no counterfactual is found after max_tries tries """ + if best_pred_so_far is None: + best_pred_so_far = torch.zeros(target + 1) # Copy x batch_size times x_multiple = torch.stack([x] * batch_size) if kind == "reference": @@ -172,12 +185,13 @@ def get_counterfactual( f"Not enough reference images, reducing max_tries to {max_tries}." ) # Get a batch of reference images, starting from batch_size * max_tries, of size batch_size - ref_batch = torch.stack( - [ - dataset_ref[i][0] + ref_batch, ref_paths = zip( + *[ + dataset_ref[i] for i in range(batch_size * (max_tries - 1), batch_size * max_tries) ] ) + ref_batch = torch.stack(ref_batch) # Generate batch_size counterfactuals xcf = latent_inference_model( x_multiple.to(device), @@ -197,9 +211,13 @@ def get_counterfactual( predictions = torch.argmax(p, dim=-1) # Get best so far best_idx_so_far = torch.argmax(p[:, target]) - if p[best_idx_so_far, target] > best_pred_so_far: - best_pred_so_far = p[best_idx_so_far, target] - best_cf_so_far = xcf[best_idx_so_far].cpu().numpy() + if p[best_idx_so_far, target] > best_pred_so_far[target]: + best_pred_so_far = p[best_idx_so_far] # , target] + best_cf_so_far = xcf[best_idx_so_far].cpu() + if kind == "reference": + best_cf_path_so_far = ref_paths[best_idx_so_far] + else: + best_cf_path_so_far = None # Get the indices of the correct predictions indices = torch.where(predictions == target)[0] @@ -220,6 +238,9 @@ def get_counterfactual( max_tries - 1, best_pred_so_far=best_pred_so_far, best_cf_so_far=best_cf_so_far, + best_cf_path_so_far=best_cf_path_so_far, + return_path=return_path, + return_pred=return_pred, ) else: if error_if_not_found: @@ -230,4 +251,10 @@ def get_counterfactual( f"Counterfactual not found after {max_tries} tries, using best so far." 
) # Return the best counterfactual so far + if return_path and kind == "reference": + if return_pred: + return best_cf_so_far, best_cf_path_so_far, best_pred_so_far + return best_cf_so_far, best_cf_path_so_far + if return_pred: + return best_cf_so_far, best_pred_so_far return best_cf_so_far diff --git a/src/quac/generate/data.py b/src/quac/generate/data.py index 8d6aec1..0e08406 100644 --- a/src/quac/generate/data.py +++ b/src/quac/generate/data.py @@ -4,6 +4,7 @@ class LabelFreePngFolder(torch.utils.data.Dataset): + # TODO Move to quac.data """Get all images in a folder, no subfolders, no labels.""" def __init__(self, root, transform=None): diff --git a/src/quac/generate/model.py b/src/quac/generate/model.py index bd23183..d2ad1b5 100644 --- a/src/quac/generate/model.py +++ b/src/quac/generate/model.py @@ -20,9 +20,12 @@ def __init__( latent_dim, input_dim=1, num_domains=6, + final_activation=None, ) -> None: super().__init__() - generator = Generator(img_size, style_dim, input_dim=input_dim) + generator = Generator( + img_size, style_dim, input_dim=input_dim, final_activation=final_activation + ) mapping_network = MappingNetwork(latent_dim, style_dim, num_domains=num_domains) self.nets = torch.nn.ModuleDict( @@ -58,14 +61,17 @@ def __init__( self, checkpoint_dir, img_size, - style_dim, - latent_dim, + style_dim=64, + latent_dim=16, input_dim=1, num_domains=6, single_output_encoder=False, + final_activation=None, ) -> None: super().__init__() - generator = Generator(img_size, style_dim, input_dim=input_dim) + generator = Generator( + img_size, style_dim, input_dim=input_dim, final_activation=final_activation + ) if single_output_encoder: style_encoder = SingleOutputStyleEncoder( img_size, style_dim, num_domains, input_dim=input_dim diff --git a/src/quac/report.py b/src/quac/report.py index 669240a..e20fb2a 100644 --- a/src/quac/report.py +++ b/src/quac/report.py @@ -5,6 +5,7 @@ from scipy.interpolate import interp1d import torch from tqdm import tqdm +import warnings class Report: @@ -19,14 +20,12 @@ class Report: Optionally, it also stores the hybrids generated for each threshold. """ - def __init__(self, save_dir, name=None, metadata={}): - # TODO Set up all of the data in the namespace - self.save_dir = Path(save_dir) + def __init__(self, name=None, metadata={}): if name is None: - self.name = "attribution" + self.name = "report" else: self.name = name - self.attribution_dir = None + # Shows where the attribution is, if needed # TODO check that the metadata is JSON serializable self.metadata = metadata # Initialize as empty @@ -47,31 +46,18 @@ def __init__(self, save_dir, name=None, metadata={}): # Initialize interpolation values self.interp_mask_values = np.arange(0.0, 1.0001, 0.01) - def make_attribution_dir(self): - """Create a directory to store the attributions""" - if self.attribution_dir is None: - self.attribution_dir = self.save_dir / self.name - self.attribution_dir.mkdir(parents=True, exist_ok=True) - - def accumulate( - self, - inputs, - predictions, - attribution, - evaluation_results, - save_attribution=True, - save_intermediates=False, - ): + def accumulate(self, inputs, predictions, evaluation_results): """ Store a new result. If `save_intermediates` is `True`, the hybrids are stored to disk. Otherwise they are discarded. 
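+        (With the updated signature, hybrids and attributions are no longer written to disk here;
+        see the TODO in the method body.)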
""" # Store the input information - self.paths.append(inputs["sample_path"]) - self.target_paths.append(inputs["target_path"]) - self.labels.append(inputs["class_index"]) - self.target_labels.append(inputs["target_class_index"]) + self.paths.append(inputs.path) + self.target_paths.append(inputs.counterfactual_path) + self.labels.append(inputs.source_class_index) + self.target_labels.append(inputs.target_class_index) + self.attribution_paths.append(inputs.attribution_path) # Store the prediction results self.predictions.append(predictions["original"]) self.target_predictions.append(predictions["counterfactual"]) @@ -79,27 +65,7 @@ def accumulate( self.thresholds.append(evaluation_results["thresholds"]) self.normalized_mask_sizes.append(evaluation_results["mask_sizes"]) self.score_changes.append(evaluation_results["score_change"]) - # Store the attribution to disk - if save_attribution: - self.make_attribution_dir() - filename = Path(inputs["sample_path"]).stem - attribution_path = ( - self.attribution_dir / f"{filename}_{inputs['target_class_index']}.npy" - ) - with open(attribution_path, "wb") as fd: - np.save(fd, attribution) - self.attribution_paths.append(attribution_path) - - # Store the hybrids to disk - if save_intermediates: - for h in evaluation_results["hybrids"]: - # TODO store the hybrids - pass - - def load_attribution(self, index): - """Load the attribution for a given index""" - with open(self.attribution_paths[index], "rb") as fd: - return np.load(fd) + # TODO Store the hybrids to disk ? def interpolate_score_values(self, normalized_mask_sizes, score_changes): """Computes the score changes interpolated at the desired mask sizes""" @@ -150,12 +116,16 @@ def make_json_serializable(self, obj): return [self.make_json_serializable(x) for x in obj] return obj - def store(self): + def store(self, save_dir): """Store report to disk""" - self.save_dir.mkdir(parents=True, exist_ok=True) - with open(self.save_dir / f"{self.name}.json", "w") as fd: + if self.quac_scores is None: + self.compute_scores() + save_dir = Path(save_dir) + save_dir.mkdir(parents=True, exist_ok=True) + with open(save_dir / f"{self.name}.json", "w") as fd: json.dump( { + "metadata": self.metadata, "thresholds": self.make_json_serializable(self.thresholds), "normalized_mask_sizes": self.make_json_serializable( self.normalized_mask_sizes @@ -172,6 +142,7 @@ def store(self): "attribution_paths": self.make_json_serializable( self.attribution_paths ), + "quac_scores": self.make_json_serializable(self.quac_scores), }, fd, ) @@ -180,6 +151,7 @@ def load(self, filename): """Load report from disk""" with open(filename, "r") as fd: data = json.load(fd) + self.metadata = data.get("metadata", {}) self.thresholds = data["thresholds"] self.normalized_mask_sizes = data["normalized_mask_sizes"] self.score_changes = data["score_changes"] @@ -190,11 +162,31 @@ def load(self, filename): self.predictions = data.get("predictions", []) self.target_predictions = data.get("target_predictions", []) self.attribution_paths = data.get("attribution_paths", []) + self.quac_scores = data.get("quac_scores", None) + + def get_curve(self): + """Gets the median and IQR of the QuAC curve""" + # TODO Cache the results, takes forever otherwise + plot_values = [] + for normalized_mask_sizes, score_changes in zip( + self.normalized_mask_sizes, self.score_changes + ): + interp_score_values = self.interpolate_score_values( + normalized_mask_sizes, score_changes + ) + plot_values.append(interp_score_values) + plot_values = np.array(plot_values) + # mean = 
np.mean(plot_values, axis=0) + # std = np.std(plot_values, axis=0) + median = np.median(plot_values, axis=0) + p25 = np.percentile(plot_values, 25, axis=0) + p75 = np.percentile(plot_values, 75, axis=0) + return median, p25, p75 def plot_curve(self, ax=None): """Plot the QuAC curve - We plot the mean and standard deviation of the QuAC curve acrosss all accumulated results. + We plot the median and IQR of the QuAC curve acrosss all accumulated results. Parameters ---------- @@ -204,28 +196,56 @@ def plot_curve(self, ax=None): if ax is None: fig, ax = plt.subplots() - plot_values = [] - for normalized_mask_sizes, score_changes in zip( - self.normalized_mask_sizes, self.score_changes - ): - interp_score_values = self.interpolate_score_values( - normalized_mask_sizes, score_changes - ) - plot_values.append(interp_score_values) - plot_values = np.array(plot_values) - mean = np.mean(plot_values, axis=0) - std = np.std(plot_values, axis=0) + mean, p25, p75 = self.get_curve() + ax.plot(self.interp_mask_values, mean, label=self.name) - ax.fill_between(self.interp_mask_values, mean - std, mean + std, alpha=0.2) + ax.fill_between(self.interp_mask_values, p25, p75, alpha=0.2) if ax is None: plt.show() + def optimal_thresholds(self, min_percentage=0.0): + """Get the optimal threshold for each sample + + The optimal threshold has a minimal mask size, and maximizes the score change. + We optimize $|m| - \delta f$ where $m$ is the mask size and $\delta f$ is the score change. + + Parameters + ---------- + min_percentage: float + The optimal threshold chosen needs to account for at least this percentage of total score change. + Increasing this value will favor high percentage changes even when they require larger masks. + """ + mask_scores = np.array(self.score_changes) + mask_sizes = np.array(self.normalized_mask_sizes) + thresholds = np.array(self.thresholds) + tradeoff_scores = np.abs(mask_sizes) - mask_scores + # Determine what to ignore + if min_percentage > 0.0: + min_value = np.min(mask_scores, axis=1) + max_value = np.max(mask_scores, axis=1) + threshold = min_value + min_percentage * (max_value - min_value) + below_threshold = mask_scores < threshold[:, None] + tradeoff_scores[ + below_threshold + ] = np.inf # Ignores the points with not enough score change + thr_idx = np.argmin(tradeoff_scores, axis=1) + + optimal_thresholds = np.take_along_axis( + thresholds, thr_idx[:, None], axis=1 + ).squeeze() + return optimal_thresholds + def get_optimal_threshold(self, index, return_index=False): + # TODO Deprecate, use vectorized version! + warnings.warn( + "This function is deprecated, please use the vectorized version instead.", + DeprecationWarning, + ) mask_scores = np.array(self.score_changes[index]) mask_sizes = np.array(self.normalized_mask_sizes[index]) pareto_scores = mask_sizes**2 + (1 - mask_scores) ** 2 thr_idx = np.argmin(pareto_scores) if return_index: - return thr_idx, self.thresholds[index][thr_idx] + return self.thresholds[index][thr_idx], thr_idx return self.thresholds[index][thr_idx] diff --git a/src/quac/training/classification.py b/src/quac/training/classification.py index 9f37e0c..cf392a1 100644 --- a/src/quac/training/classification.py +++ b/src/quac/training/classification.py @@ -2,22 +2,41 @@ from torchvision import transforms +class Identity(torch.nn.Module): + def forward(self, x): + return x + + class ClassifierWrapper(torch.nn.Module): """ This class expects a torchscript model. 
See [here](https://pytorch.org/tutorials/beginner/saving_loading_models.html#export-load-model-in-torchscript-format) for how to convert a model to torchscript. """ - def __init__(self, model_checkpoint, mean, std): + def __init__( + self, + model_checkpoint, + mean=None, + std=None, + assume_normalized=False, + do_nothing=False, + ): """Wraps a torchscript model, and applies normalization.""" super().__init__() self.model = torch.jit.load(model_checkpoint) self.model.eval() self.transform = transforms.Normalize(mean, std) + if mean is None: + self.transform = Identity() + self.assume_normalized = assume_normalized # TODO Remove this, it's in forward + self.do_nothing = do_nothing - def forward(self, x): + def forward(self, x, assume_normalized=False, do_nothing=False): """Assumes that x is between -1 and 1.""" + if do_nothing or self.do_nothing: + return self.model(x) # TODO it would be even better if the range was between 0 and 1 so we wouldn't have to do the below - x = (x + 1) / 2 + if not self.assume_normalized and not assume_normalized: + x = (x + 1) / 2 x = self.transform(x) return self.model(x) diff --git a/src/quac/training/config.py b/src/quac/training/config.py index 9aad9d1..d40c846 100644 --- a/src/quac/training/config.py +++ b/src/quac/training/config.py @@ -1,4 +1,5 @@ from pydantic import BaseModel +from typing import Optional, Union, Literal class ModelConfig(BaseModel): @@ -6,27 +7,82 @@ class ModelConfig(BaseModel): style_dim: int = 64 latent_dim: int = 16 num_domains: int = 5 + input_dim: int = 3 + final_activation: str = "tanh" class DataConfig(BaseModel): - train_img_dir: str + source: str + reference: str img_size: int = 128 - batch_size: int = 16 - randcrop_prob: float = 0.0 + batch_size: int = 1 num_workers: int = 4 + grayscale: bool = False + mean: Optional[float] = 0.5 + std: Optional[float] = 0.5 + rand_crop_prob: Optional[float] = 0 -class TrainConfig(BaseModel): - f_lr: float = 1e-4 # Learning rate for the mapping network - lr: float = 1e-4 # Learning rate for the other networks - beta1: float = 0.5 # Beta1 for Adam optimizer - beta2: float = 0.999 # Beta2 for Adam optimizer - weight_decay: float = 1e-4 # Weight decay for Adam optimizer - latent_dim: int = 16 # Latent dimension for the mapping network - resume_iter: int = 0 # Iteration to resume training from - lamdba_ds: float = 1.0 # Weight for the diversity sensitive loss - total_iters: int = 100000 # Total number of iterations to train the model - ds_iter: int = 1000 # Number of iterations to optimize the diversity sensitive loss - log_every: int = 1000 # How often (iterations) to log training progress - save_every: int = 10000 # How often (iterations) to save the model - eval_every: int = 10000 # How often (iterations) to evaluate the model +class RunConfig(BaseModel): + resume_iter: int = 0 + total_iters: int = 100000 + log_every: int = 1000 + save_every: int = 10000 + eval_every: int = 10000 + + +class ValConfig(BaseModel): + classifier_checkpoint: str + num_outs_per_domain: int = 10 + mean: Optional[float] = 0.5 + std: Optional[float] = 0.5 + img_size: int = 128 + val_batch_size: int = 16 + assume_normalized: bool = False + do_nothing: bool = False + + +class LossConfig(BaseModel): + lambda_ds: float = 0.0 # No diversity by default + lambda_sty: float = 1.0 + lambda_cyc: float = 1.0 + lambda_reg: float = 1.0 + ds_iter: int = 100000 + + +class SolverConfig(BaseModel): + root_dir: str + f_lr: float = 1e-6 + lr: float = 1e-4 + beta1: float = 0.0 + beta2: float = 0.99 + weight_decay: float = 1e-4 + + 
+class WandBLogConfig(BaseModel): + project: str = "default" + name: str = "default" + notes: str = "" + tags: list = [] + + +class TensorboardLogConfig(BaseModel): + log_dir: str + comment: str = "" + + +class ExperimentConfig(BaseModel): + # Metadata for keeping track of experiments + log_type: Literal["wandb", "tensorboard"] = "wandb" + log: Union[WandBLogConfig, TensorboardLogConfig] = WandBLogConfig() + # Some input required + data: DataConfig + solver: SolverConfig + validation_data: DataConfig + validation_config: ValConfig + # No input required + model: ModelConfig = ModelConfig() + run: RunConfig = RunConfig() + loss: LossConfig = LossConfig() + # Optional + test_data: Optional[DataConfig] = None diff --git a/src/quac/training/data_loader.py b/src/quac/training/data_loader.py index 53970d0..508dc71 100644 --- a/src/quac/training/data_loader.py +++ b/src/quac/training/data_loader.py @@ -25,6 +25,16 @@ from torchvision.datasets import ImageFolder +class RGB: + def __call__(self, img): + if isinstance(img, Image.Image): + return img.convert("RGB") + else: # Tensor + if img.size(0) == 1: + return torch.cat([img, img, img], dim=0) + return img + + def listdir(dname): fnames = list( chain( @@ -102,7 +112,7 @@ def __init__(self, root, transform=None): self.transform = transform def _make_dataset(self, root): - domains = os.listdir(root) + domains = glob.glob(os.path.join(root, "*")) fnames, fnames2, labels = [], [], [] for idx, domain in enumerate(sorted(domains)): class_dir = os.path.join(root, domain) @@ -156,17 +166,18 @@ def get_train_loader( transform_list = [rand_crop] if grayscale: transform_list.append(transforms.Grayscale()) - - transform = transforms.Compose( - [ - *transform_list, - transforms.Resize([img_size, img_size]), - transforms.RandomHorizontalFlip(), - transforms.RandomVerticalFlip(), - transforms.ToTensor(), - transforms.Normalize(mean=mean, std=std), - ] - ) + else: + transform_list.append(RGB()) + + transform_list += [ + transforms.Resize([img_size, img_size]), + transforms.RandomHorizontalFlip(), + transforms.RandomVerticalFlip(), + transforms.ToTensor(), + ] + if mean is not None and std is not None: + transform_list.append(transforms.Normalize(mean=mean, std=std)) + transform = transforms.Compose(transform_list) if which == "source": # dataset = ImageFolder(root, transform) @@ -191,11 +202,13 @@ def get_eval_loader( root, img_size=256, batch_size=32, - imagenet_normalize=True, + imagenet_normalize=False, shuffle=True, num_workers=4, drop_last=False, grayscale=False, + mean=0.5, + std=0.5, ): print("Preparing DataLoader for the evaluation phase...") if imagenet_normalize: @@ -204,19 +217,24 @@ def get_eval_loader( std = [0.229, 0.224, 0.225] else: height, width = img_size, img_size - mean = 0.5 - std = 0.5 + + if mean is not None: + normalize = transforms.Normalize(mean=mean, std=std) + else: + normalize = transforms.Lambda(lambda x: x) transform_list = [] if grayscale: transform_list.append(transforms.Grayscale()) + else: + transform_list.append(RGB()) transform = transforms.Compose( [ *transform_list, transforms.Resize([height, width]), transforms.ToTensor(), - transforms.Normalize(mean=mean, std=std), + normalize, ] ) @@ -235,32 +253,39 @@ def get_test_loader( root, img_size=256, batch_size=32, - shuffle=True, + shuffle=False, + drop_last=False, num_workers=4, grayscale=False, mean=0.5, std=0.5, + return_dataset=False, ): print("Preparing DataLoader for the generation phase...") transform_list = [] if grayscale: transform_list.append(transforms.Grayscale()) - 
transform = transforms.Compose( - [ - *transform_list, - transforms.Resize([img_size, img_size]), - transforms.ToTensor(), - transforms.Normalize(mean=mean, std=std), - ] - ) + else: + transform_list.append(RGB()) + + transform_list += [ + transforms.Resize([img_size, img_size]), + transforms.ToTensor(), + ] + if mean is not None and std is not None: + transform_list.append(transforms.Normalize(mean=mean, std=std)) + transform = transforms.Compose(transform_list) dataset = ImageFolder(root, transform) + if return_dataset: + return dataset return data.DataLoader( dataset=dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=True, + drop_last=drop_last, ) @@ -304,7 +329,7 @@ def __next__(self): z_trg2=z_trg2, ) elif self.mode == "val": - x_ref, y_ref = self._fetch_inputs() + x_ref, y_ref = self._fetch_refs() inputs = Munch(x_src=x, y_src=y, x_ref=x_ref, y_ref=y_ref) elif self.mode == "test": inputs = Munch(x=x, y=y) @@ -343,7 +368,7 @@ def __next__(self): z_trg2=z_trg2, ) elif self.mode == "val": - x_ref, _, y_ref = self._fetch_inputs() + x_ref, _, y_ref = self._fetch_refs() inputs = Munch(x_src=x, y_src=y, x_ref=x_ref, y_ref=y_ref) elif self.mode == "test": inputs = Munch(x=x, y=y) @@ -351,3 +376,192 @@ def __next__(self): raise NotImplementedError return Munch({k: v.to(self.device) for k, v in inputs.items()}) + + +class TrainingData: + def __init__( + self, + source, + reference, + img_size=128, + batch_size=8, + num_workers=4, + grayscale=False, + mean=None, + std=None, + rand_crop_prob=0, + ): + self.src = get_train_loader( + root=source, + which="source", + img_size=img_size, + batch_size=batch_size, + num_workers=num_workers, + grayscale=grayscale, + mean=mean, + std=std, + prob=rand_crop_prob, + ) + self.reference = get_train_loader( + root=reference, + which="reference", + img_size=img_size, + batch_size=batch_size, + num_workers=num_workers, + grayscale=grayscale, + mean=mean, + std=std, + prob=rand_crop_prob, + ) + + +class ValidationData: + """ + A data loader for validation. + + """ + + def __init__( + self, + source, + reference=None, + mode="latent", + img_size=128, + batch_size=32, + num_workers=4, + grayscale=False, + mean=None, + std=None, + **kwargs, + ): + """ + Parameters + ---------- + source_directory : str + The directory containing the source images. + ref_directory : str + The directory containing the reference images, defaults to source_directory if None. + mode : str + The mode of the data loader, either "latent" or "reference". + If "latent", the data loader will only load the source images. + If "reference", the data loader will load both the source and reference images. + image_size : int + The size of the images; images of a different size will be resized. + batch_size : int + The batch size for source data. + num_workers : int + The number of workers for the data loader. + grayscale : bool + Whether the images are grayscale. + mean: float + The mean for normalization, for the classifier. + std: float + The standard deviation for normalization, for the classifier. + kwargs : dict + Unused keyword arguments, for compatibility with configuration. 
+ """ + assert mode in ["latent", "reference"] + # parameters + self.image_size = img_size + self.batch_size = batch_size + self.num_workers = num_workers + self.grayscale = grayscale + self.mean = mean + self.std = std + # The source and target classes + self.source = None + self.target = None + # The roots of the source and target directories + self.source_root = Path(source) + if reference is not None: + self.ref_root = Path(reference) + else: + self.ref_root = self.source_root + + # Available classes + self.available_sources = [ + subdir.name for subdir in self.source_root.iterdir() if subdir.is_dir() + ] + self._available_targets = None + self.set_mode(mode) + + def set_mode(self, mode): + assert mode in ["latent", "reference"] + self.mode = mode + + @property + def available_targets(self): + if self.mode == "latent": + return self.available_sources + elif self._available_targets is None: + self._available_targets = [ + subdir.name + for subdir in Path(self.ref_root).iterdir() + if subdir.is_dir() + ] + return self._available_targets + + def set_target(self, target): + assert ( + target in self.available_targets + ), f"{target} not in {self.available_targets}" + self.target = target + + def set_source(self, source): + assert ( + source in self.available_sources + ), f"{source} not in {self.available_sources}" + self.source = source + + @property + def reference_directory(self): + if self.mode == "latent": + return None + if self.target is None: + raise (ValueError("Target not set.")) + return self.ref_root / self.target + + @property + def source_directory(self): + if self.source is None: + raise (ValueError("Source not set.")) + return self.source_root / self.source + + def print_info(self): + print(f"Avaliable sources: {self.available_sources}") + print(f"Avaliable targets: {self.available_targets}") + print(f"Mode: {self.mode}") + try: + print(f"Current source directory: {self.source_directory}") + except ValueError: + print("Source not set.") + try: + print(f"Current target directory: {self.reference_directory}") + except ValueError: + print("Target not set.") + + @property + def loader_src(self): + return get_eval_loader( + self.source_directory, + img_size=self.image_size, + batch_size=self.batch_size, + num_workers=self.num_workers, + grayscale=self.grayscale, + mean=self.mean, + std=self.std, + drop_last=False, + ) + + @property + def loader_ref(self): + return get_eval_loader( + self.reference_directory, + img_size=self.image_size, + batch_size=self.batch_size, + num_workers=self.num_workers, + grayscale=self.grayscale, + mean=self.mean, + std=self.std, + drop_last=True, + ) diff --git a/src/quac/training/eval.py b/src/quac/training/eval.py new file mode 100644 index 0000000..cc7066f --- /dev/null +++ b/src/quac/training/eval.py @@ -0,0 +1,141 @@ +""" +StarGAN v2 +Copyright (c) 2020-present NAVER Corp. + +This work is licensed under the Creative Commons Attribution-NonCommercial +4.0 International License. To view a copy of this license, visit +http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to +Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 
+""" + +import os +from collections import OrderedDict + +import numpy as np +from pathlib import Path +from quac.training import utils +from quac.training.classification import ClassifierWrapper +from quac.training.data_loader import get_eval_loader +import torch +from tqdm import tqdm + + +@torch.no_grad() +def calculate_metrics( + eval_dir, + step=0, + mode="latent", + classifier_checkpoint=None, + img_size=128, + val_batch_size=16, + num_outs_per_domain=10, + mean=None, + std=None, + input_dim=3, + run=None, +): + print("Calculating conversion rate for all tasks...", flush=True) + translation_rate_values = ( + OrderedDict() + ) # How many output images are of the right class + conversion_rate_values = ( + OrderedDict() + ) # How many input samples have a valid counterfactual + + domains = [subdir.name for subdir in Path(eval_dir).iterdir() if subdir.is_dir()] + + for subdir in Path(eval_dir).iterdir(): + if not subdir.is_dir() or subdir.name.startswith("."): # Skip hidden files + continue + src_domain = subdir.name + + for subdir2 in Path(subdir).iterdir(): + if not subdir2.is_dir() or subdir2.name.startswith("."): + continue + trg_domain = subdir2.name + + task = "%s/%s" % (src_domain, trg_domain) + print("Calculating conversion rate for %s..." % task, flush=True) + target_class = domains.index(trg_domain) + + translation_rate, conversion_rate = calculate_conversion_given_path( + subdir2, + model_checkpoint=classifier_checkpoint, + target_class=target_class, + img_size=img_size, + batch_size=val_batch_size, + num_outs_per_domain=num_outs_per_domain, + mean=mean, + std=std, + grayscale=(input_dim == 1), + ) + conversion_rate_values[ + "conversion_rate_%s/%s" % (mode, task) + ] = conversion_rate + translation_rate_values[ + "translation_rate_%s/%s" % (mode, task) + ] = translation_rate + + # calculate the average conversion rate for all tasks + conversion_rate_mean = 0 + translation_rate_mean = 0 + for _, value in conversion_rate_values.items(): + conversion_rate_mean += value / len(conversion_rate_values) + for _, value in translation_rate_values.items(): + translation_rate_mean += value / len(translation_rate_values) + + conversion_rate_values["conversion_rate_%s/mean" % mode] = conversion_rate_mean + translation_rate_values["translation_rate_%s/mean" % mode] = translation_rate_mean + + # report conversion rate values + filename = os.path.join(eval_dir, "conversion_rate_%.5i_%s.json" % (step, mode)) + utils.save_json(conversion_rate_values, filename) + # report translation rate values + filename = os.path.join(eval_dir, "translation_rate_%.5i_%s.json" % (step, mode)) + utils.save_json(translation_rate_values, filename) + if run is not None: + run.log(conversion_rate_values, step=step) + run.log(translation_rate_values, step=step) + + +@torch.no_grad() +def calculate_conversion_given_path( + path, + model_checkpoint, + target_class, + img_size=128, + batch_size=50, + num_outs_per_domain=10, + mean=0.5, + std=0.5, + grayscale=False, +): + print("Calculating conversion given path %s..." 
% path, flush=True) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + classifier = ClassifierWrapper(model_checkpoint, mean=mean, std=std) + classifier.to(device) + classifier.eval() + + loader = get_eval_loader( + path, + img_size=img_size, + batch_size=batch_size, + imagenet_normalize=False, + shuffle=False, + grayscale=grayscale, + ) + + predictions = [] + for x in tqdm(loader, total=len(loader)): + x = x.to(device) + predictions.append(classifier(x).cpu().numpy()) + predictions = np.concatenate(predictions, axis=0) + # Do it in a vectorized way, by reshaping the predictions + predictions = predictions.reshape(-1, num_outs_per_domain, predictions.shape[-1]) + predictions = predictions.argmax(axis=-1) + # + at_least_one = np.any(predictions == target_class, axis=1) + # + conversion_rate = np.mean(at_least_one) # (sum(at_least_one) / len(at_least_one) + translation_rate = np.mean(predictions == target_class) + return translation_rate, conversion_rate diff --git a/src/quac/training/logging.py b/src/quac/training/logging.py new file mode 100644 index 0000000..1b91b24 --- /dev/null +++ b/src/quac/training/logging.py @@ -0,0 +1,84 @@ +from numbers import Number +from typing import Union, Optional +import torch +import numpy as np + +try: + import wandb +except ImportError: + wandb = None + +try: + from torch.utils.tensorboard import SummaryWriter +except ImportError: + SummaryWriter = None + + +class Logger: + def create(log_type, resume_iter=0, hparams={}, **kwargs): + if log_type == "wandb": + if wandb is None: + raise ImportError("wandb is not installed.") + resume = "allow" if resume_iter > 0 else False + return WandBLogger(hparams=hparams, resume=resume, **kwargs) + elif log_type == "tensorboard": + if SummaryWriter is None: + raise ImportError("Tensorboard is not available.") + purge_step = resume_iter if resume_iter > 0 else None + return TensorboardLogger(hparams=hparams, purge_step=purge_step, **kwargs) + else: + raise NotImplementedError + + +class WandBLogger: + def __init__( + self, + hparams: dict, + project: str, + name: str, + notes: str, + tags: list, + resume: bool = False, + id: Optional[str] = None, + ): + self.run = wandb.init( + project=project, + name=name, + notes=notes, + tags=tags, + config=hparams, + resume=resume, + id=id, + ) + + def log(self, data: dict[str, Number], step: int = 0): + self.run.log(data, step=step) + + def log_images( + self, data: dict[str, Union[torch.Tensor, np.ndarray]], step: int = 0 + ): + for key, value in data.items(): + self.run.log({key: wandb.Image(value)}, step=step) + + +class TensorboardLogger: + # NOTE: Not tested + def __init__( + self, + log_dir: str, + comment: str, + hparams: dict, + purge_step: Union[int, None] = None, + ): + self.writer = SummaryWriter(log_dir, comment=comment, purge_step=purge_step) + self.writer.add_hparams(hparams, {}) + + def log(self, data: dict[str, Number], step: int = 0): + for key, value in data.items(): + self.writer.add_scalar(key, value, step) + + def log_images( + self, data: dict[str, Union[torch.Tensor, np.ndarray]], step: int = 0 + ): + for key, value in data.items(): + self.writer.add_images(key, value, step) diff --git a/src/quac/training/solver.py b/src/quac/training/solver.py index a9d6003..b00a351 100644 --- a/src/quac/training/solver.py +++ b/src/quac/training/solver.py @@ -8,24 +8,23 @@ Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 
""" -import os -from os.path import join as ospj -import time import datetime from munch import Munch - +import numpy as np +import os +from os.path import join as ospj +from pathlib import Path +from quac.training.data_loader import AugmentedInputFetcher +from quac.training.checkpoint import CheckpointIO +import quac.training.utils as utils +from quac.training.classification import ClassifierWrapper +import shutil import torch import torch.nn as nn import torch.nn.functional as F from torchvision import transforms - - -from quac.training.stargan import build_model -from quac.training.checkpoint import CheckpointIO -from quac.training.data_loader import InputFetcher, AugmentedInputFetcher -import quac.training.utils as utils - -# from metrics.eval import calculate_metrics +from tqdm import tqdm +import wandb transform = transforms.Compose( @@ -37,13 +36,28 @@ class Solver(nn.Module): - def __init__(self, nets, nets_ema, run, args): + def __init__( + self, + nets, + nets_ema, + f_lr: float, + lr: float, + beta1: float, + beta2: float, + weight_decay: float, + root_dir: str, + run=None, + ): super().__init__() - self.args = args self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.nets = nets self.nets_ema = nets_ema self.run = run + self.root_dir = Path(root_dir) + self.checkpoint_dir = self.root_dir / "checkpoints" + self.checkpoint_dir.mkdir(exist_ok=True, parents=True) + + checkpoint_dir = str(self.checkpoint_dir) # below setattrs are to make networks be children of Solver, e.g., for self.to(self.device) for name, module in self.nets.items(): @@ -56,30 +70,28 @@ def __init__(self, nets, nets_ema, run, args): for net in self.nets.keys(): self.optims[net] = torch.optim.Adam( params=self.nets[net].parameters(), - lr=args.f_lr if net == "mapping_network" else args.lr, - betas=[args.beta1, args.beta2], - weight_decay=args.weight_decay, + lr=f_lr if net == "mapping_network" else lr, + betas=[beta1, beta2], + weight_decay=weight_decay, ) self.ckptios = [ CheckpointIO( - ospj(args.checkpoint_dir, "{:06d}_nets.ckpt"), + ospj(checkpoint_dir, "{:06d}_nets.ckpt"), data_parallel=True, **self.nets, ), CheckpointIO( - ospj(args.checkpoint_dir, "{:06d}_nets_ema.ckpt"), + ospj(checkpoint_dir, "{:06d}_nets_ema.ckpt"), data_parallel=True, **self.nets_ema, ), - CheckpointIO( - ospj(args.checkpoint_dir, "{:06d}_optims.ckpt"), **self.optims - ), + CheckpointIO(ospj(checkpoint_dir, "{:06d}_optims.ckpt"), **self.optims), ] else: self.ckptios = [ CheckpointIO( - ospj(args.checkpoint_dir, "{:06d}_nets_ema.ckpt"), + ospj(checkpoint_dir, "{:06d}_nets_ema.ckpt"), data_parallel=True, **self.nets_ema, ) @@ -88,7 +100,6 @@ def __init__(self, nets, nets_ema, run, args): self.to(self.device) # TODO The EMA doesn't need to be in named_childeren() for name, network in self.named_children(): - # Do not initialize the FAN parameters if "ema" not in name: print("Initializing %s..." 
% name) network.apply(utils.he_init) @@ -105,29 +116,54 @@ def _reset_grad(self): for optim in self.optims.values(): optim.zero_grad() - def train(self, loaders, run=None): - args = self.args + @property + def latent_dim(self): + try: + latent_dim = self.nets.mapping_network.latent_dim + except AttributeError: + # it's a data parallel model + latent_dim = self.nets.mapping_network.module.latent_dim + return latent_dim + + def train( + self, + loader, + resume_iter: int = 0, + total_iters: int = 100000, + log_every: int = 100, + save_every: int = 10000, + eval_every: int = 10000, + # sample_dir: str = "samples", + lambda_ds: float = 1.0, + ds_iter: int = 10000, + lambda_reg: float = 1.0, + lambda_sty: float = 1.0, + lambda_cyc: float = 1.0, + # Validation things + val_loader=None, + val_config=None, + ): + start = datetime.datetime.now() nets = self.nets nets_ema = self.nets_ema optims = self.optims - # fetch random validation images for debugging fetcher = AugmentedInputFetcher( - loaders.src, loaders.ref, args.latend_dim, "train" + loader.src, + loader.reference, + latent_dim=self.latent_dim, + mode="train", ) - fetcher_val = InputFetcher(loaders.val, None, args.latent_dim, "val") - inputs_val = next(fetcher_val) # resume training if necessary - if args.resume_iter > 0: - self._load_checkpoint(args.resume_iter) + if resume_iter > 0: + self._load_checkpoint(resume_iter) # remember the initial value of ds weight - initial_lambda_ds = args.lambda_ds + initial_lambda_ds = lambda_ds print("Start training...") - start_time = time.time() - for i in range(args.resume_iter, args.total_iters): + for i in range(resume_iter, total_iters): # fetch images and labels inputs = next(fetcher) x_real, x_aug, y_org = inputs.x_src, inputs.x_src2, inputs.y_src @@ -136,22 +172,29 @@ def train(self, loaders, run=None): # train the discriminator d_loss, d_losses_latent = compute_d_loss( - nets, args, x_real, y_org, y_trg, z_trg=z_trg + nets, x_real, y_org, y_trg, z_trg=z_trg, lambda_reg=lambda_reg ) self._reset_grad() d_loss.backward() optims.discriminator.step() d_loss, d_losses_ref = compute_d_loss( - nets, args, x_real, y_org, y_trg, x_ref=x_ref + nets, x_real, y_org, y_trg, x_ref=x_ref, lambda_reg=lambda_reg ) self._reset_grad() d_loss.backward() optims.discriminator.step() # train the generator - g_loss, g_losses_latent = compute_g_loss( - nets, args, x_real, y_org, y_trg, z_trgs=[z_trg, z_trg2] + g_loss, g_losses_latent, fake_x_latent = compute_g_loss( + nets, + x_real, + y_org, + y_trg, + z_trgs=[z_trg, z_trg2], + lambda_sty=lambda_sty, + lambda_ds=lambda_ds, + lambda_cyc=lambda_cyc, ) self._reset_grad() g_loss.backward() @@ -159,60 +202,289 @@ def train(self, loaders, run=None): optims.mapping_network.step() optims.style_encoder.step() - g_loss, g_losses_ref = compute_g_loss( - nets, args, x_real, y_org, y_trg, x_refs=[x_ref, x_ref2], x_aug=x_aug + g_loss, g_losses_ref, fake_x_reference = compute_g_loss( + nets, + x_real, + y_org, + y_trg, + x_refs=[x_ref, x_ref2], + x_aug=x_aug, + lambda_sty=lambda_sty, + lambda_ds=lambda_ds, + lambda_cyc=lambda_cyc, ) self._reset_grad() g_loss.backward() optims.generator.step() - # TODO do this with timm # compute moving average of network parameters - nets_ema.update(nets) + moving_average(nets.generator, nets_ema.generator, beta=0.999) + moving_average(nets.style_encoder, nets_ema.style_encoder, beta=0.999) + moving_average(nets.mapping_network, nets_ema.mapping_network, beta=0.999) # decay weight for diversity sensitive loss - if args.lambda_ds > 0: - args.lambda_ds 
-= initial_lambda_ds / args.ds_iter - - # print out log info - if (i + 1) % args.log_every == 0: - # TODO replace with wandb - all_losses = dict() - for loss, prefix in zip( - [d_losses_latent, d_losses_ref, g_losses_latent, g_losses_ref], - ["D/latent_", "D/ref_", "G/latent_", "G/ref_"], - ): - for key, value in loss.items(): - all_losses[prefix + key] = value - all_losses["G/lambda_ds"] = args.lambda_ds - run.log(all_losses) - - # TODO put this on wandb instead - os.makedirs(args.sample_dir, exist_ok=True) - utils.debug_image(nets_ema, args, inputs=inputs_val, step=i + 1) + if lambda_ds > 0: + lambda_ds -= initial_lambda_ds / ds_iter + + if (i + 1) % eval_every == 0 and val_loader is not None: + self.evaluate( + val_loader, iteration=i + 1, mode="reference", val_config=val_config + ) + self.evaluate( + val_loader, iteration=i + 1, mode="latent", val_config=val_config + ) # save model checkpoints - if (i + 1) % args.save_every == 0: + if (i + 1) % save_every == 0: self._save_checkpoint(step=i + 1) - # compute FID and LPIPS if necessary - if (i + 1) % args.eval_every == 0: - calculate_metrics(nets_ema, args, i + 1, mode="latent") - calculate_metrics(nets_ema, args, i + 1, mode="reference") + # print out log losses, images + if (i + 1) % log_every == 0: + elapsed = datetime.datetime.now() - start + # Log the images made by the EMA model! + with torch.no_grad(): + ema_fake_x_latent = nets_ema.generator( + x_real, nets_ema.mapping_network(z_trg, y_trg) + ) + ema_fake_x_reference = nets_ema.generator( + x_real, nets_ema.style_encoder(x_ref, y_trg) + ) + self.log( + d_losses_latent, + d_losses_ref, + g_losses_latent, + g_losses_ref, + lambda_ds, + x_real, + x_ref, + fake_x_latent, + fake_x_reference, + ema_fake_x_latent, + ema_fake_x_reference, + y_org, # Source classes + y_trg, # Target classes + step=i + 1, + total_iters=total_iters, + elapsed_time=elapsed, + ) + + def log( + self, + d_losses_latent, + d_losses_ref, + g_losses_latent, + g_losses_ref, + lambda_ds, + x_real, + x_ref, + fake_x_latent, + fake_x_reference, + ema_fake_x_latent, + ema_fake_x_reference, + y_source, + y_target, + step, + total_iters, + elapsed_time, + ): + all_losses = dict() + for loss, prefix in zip( + [d_losses_latent, d_losses_ref, g_losses_latent, g_losses_ref], + ["D/latent_", "D/ref_", "G/latent_", "G/ref_"], + ): + for key, value in loss.items(): + all_losses[prefix + key] = value + all_losses["G/lambda_ds"] = lambda_ds + # log all losses to wandb or print them + if self.run: + self.run.log(all_losses, step=step) + for name, img, label in zip( + [ + "x_real", + "x_ref", + "fake_x_latent", + "fake_x_reference", + "ema_fake_x_latent", + "ema_fake_x_reference", + ], + [ + x_real, + x_ref, + fake_x_latent, + fake_x_reference, + ema_fake_x_latent, + ema_fake_x_reference, + ], + [y_source, y_target, y_target, y_target, y_target, y_target], + ): + # TODO put captions back in somehow + self.run.log_images({name: img}, step=step) + + print( + f"[{elapsed_time}]: {step}/{total_iters}", + flush=True, + ) + g_losses = "\t".join( + [ + f"{key}: {value:.4f}" + for key, value in all_losses.items() + if not key.startswith("D/") + ] + ) + d_losses = "\t".join( + [ + f"{key}: {value:.4f}" + for key, value in all_losses.items() + if key.startswith("D/") + ] + ) + print(f"G Losses: {g_losses}", flush=True) + print(f"D Losses: {d_losses}", flush=True) @torch.no_grad() - def evaluate(self): - args = self.args - nets_ema = self.nets_ema - resume_iter = args.resume_iter - self._load_checkpoint(args.resume_iter) - 
calculate_metrics(nets_ema, args, step=resume_iter, mode="latent") - calculate_metrics(nets_ema, args, step=resume_iter, mode="reference") + def evaluate( + self, + val_loader, + iteration=None, + num_outs_per_domain=10, + mode="latent", + val_config=None, + ): + """ + Generates images for evaluation and stores them to disk. + + Parameters + ---------- + val_loader + """ + if iteration is None: # Choose the iteration to evaluate + resume_iter = resume_iter + self._load_checkpoint(resume_iter) + + # Generate images for evaluation + eval_dir = self.root_dir / "eval" + eval_dir.mkdir(exist_ok=True, parents=True) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + # Load classifier + classifier = ClassifierWrapper( + val_config.classifier_checkpoint, + val_config.mean, + val_config.std, + assume_normalized=val_config.assume_normalized, + do_nothing=val_config.do_nothing, + ) + classifier.to(device) + assert mode in ["latent", "reference"] + + val_loader.set_mode(mode) + + domains = val_loader.available_targets + print("Number of domains: %d" % len(domains)) + + conversion_rate_values = {} + translation_rate_values = {} + + for trg_idx, trg_domain in enumerate(domains): + src_domains = [x for x in val_loader.available_sources if x != trg_domain] + val_loader.set_target(trg_domain) + if mode == "reference": + loader_ref = val_loader.loader_ref + + for src_idx, src_domain in enumerate(src_domains): + task = "%s/%s" % (src_domain, trg_domain) + # Creating the path + path_fake = os.path.join(eval_dir, task) + shutil.rmtree(path_fake, ignore_errors=True) + os.makedirs(path_fake) + + # Setting the source domain + val_loader.set_source(src_domain) + loader_src = val_loader.loader_src + + for i, x_src in enumerate(tqdm(loader_src, total=len(loader_src))): + N = x_src.size(0) + x_src = x_src.to(device) + y_trg = torch.tensor([trg_idx] * N).to(device) + + predictions = [] + # generate num_outs_per_domain outputs from the same input + for j in range(num_outs_per_domain): + if mode == "latent": + z_trg = torch.randn(N, self.latent_dim).to(device) + s_trg = self.nets_ema.mapping_network(z_trg, y_trg) + else: + # x_ref = x_trg.clone() + try: + # TODO don't need to re-do this every time, just use + # the same set of reference images for the whole dataset! + x_ref = next(iter_ref).to(device) + except: + iter_ref = iter(loader_ref) + x_ref = next(iter_ref).to(device) + + if x_ref.size(0) > N: + x_ref = x_ref[:N] + elif x_ref.size(0) < N: + raise ValueError( + "Not enough reference images." + "Make sure that the batch size of the validation loader is bigger than `num_outs_per_domain`." 
+ ) + s_trg = self.nets_ema.style_encoder(x_ref, y_trg) + + x_fake = self.nets_ema.generator(x_src, s_trg) + # Run the classification + pred = classifier( + x_fake, + assume_normalized=val_config.assume_normalized, + do_nothing=val_config.do_nothing, + ) + predictions.append(pred.cpu().numpy()) + predictions = np.stack(predictions, axis=0) + assert len(predictions) > 0 + # Do it in a vectorized way, by reshaping the predictions + predictions = predictions.reshape( + -1, num_outs_per_domain, predictions.shape[-1] + ) + predictions = predictions.argmax(axis=-1) + # + at_least_one = np.any(predictions == trg_idx, axis=1) + # + conversion_rate = np.mean(at_least_one) + translation_rate = np.mean(predictions == trg_idx) + + # STORE + conversion_rate_values[ + f"conversion_rate_{mode}/" + task + ] = conversion_rate + translation_rate_values[ + f"translation_rate_{mode}/" + task + ] = translation_rate + + # Add average conversion rate and translation rate + conversion_rate_values[f"conversion_rate_{mode}/average"] = np.mean( + [conversion_rate_values[key] for key in conversion_rate_values.keys()] + ) + translation_rate_values[f"translation_rate_{mode}/average"] = np.mean( + [translation_rate_values[key] for key in translation_rate_values.keys()] + ) + # report conversion rate values + filename = os.path.join( + eval_dir, "conversion_rate_%.5i_%s.json" % (iteration, mode) + ) + utils.save_json(conversion_rate_values, filename) + # report translation rate values + filename = os.path.join( + eval_dir, "translation_rate_%.5i_%s.json" % (iteration, mode) + ) + utils.save_json(translation_rate_values, filename) + if self.run is not None: + self.run.log(conversion_rate_values, step=iteration) + self.run.log(translation_rate_values, step=iteration) -def compute_d_loss( - nets, args, x_real, y_org, y_trg, z_trg=None, x_ref=None, masks=None -): + +def compute_d_loss(nets, x_real, y_org, y_trg, z_trg=None, x_ref=None, lambda_reg=1.0): assert (z_trg is None) != (x_ref is None) # with real images x_real.requires_grad_() @@ -227,18 +499,27 @@ def compute_d_loss( else: # x_ref is not None s_trg = nets.style_encoder(x_ref, y_trg) - x_fake = nets.generator(x_real, s_trg, masks=masks) + x_fake = nets.generator(x_real, s_trg) out = nets.discriminator(x_fake, y_trg) loss_fake = adv_loss(out, 0) - loss = loss_real + loss_fake + args.lambda_reg * loss_reg + loss = loss_real + loss_fake + lambda_reg * loss_reg return loss, Munch( real=loss_real.item(), fake=loss_fake.item(), reg=loss_reg.item() ) def compute_g_loss( - nets, args, x_real, y_org, y_trg, z_trgs=None, x_refs=None, masks=None, x_aug=None + nets, + x_real, + y_org, + y_trg, + z_trgs=None, + x_refs=None, + x_aug=None, + lambda_sty: float = 1.0, + lambda_ds: float = 1.0, + lambda_cyc: float = 1.0, ): assert (z_trgs is None) != (x_refs is None) if z_trgs is not None: @@ -252,7 +533,7 @@ def compute_g_loss( else: s_trg = nets.style_encoder(x_ref, y_trg) - x_fake = nets.generator(x_real, s_trg, masks=masks) + x_fake = nets.generator(x_real, s_trg) out = nets.discriminator(x_fake, y_trg) loss_adv = adv_loss(out, 1) @@ -266,14 +547,13 @@ def compute_g_loss( s_trg2 = nets.mapping_network(z_trg2, y_trg) else: s_trg2 = nets.style_encoder(x_ref2, y_trg) - x_fake2 = nets.generator(x_real, s_trg2, masks=masks) + x_fake2 = nets.generator(x_real, s_trg2) x_fake2 = x_fake2.detach() loss_ds = torch.mean(torch.abs(x_fake - x_fake2)) # cycle-consistency loss - masks = nets.fan.get_heatmap(x_fake) if args.w_hpf > 0 else None s_org = nets.style_encoder(x_real, y_org) - x_rec = 
nets.generator(x_fake, s_org, masks=masks) + x_rec = nets.generator(x_fake, s_org) loss_cyc = torch.mean(torch.abs(x_rec - x_real)) # style invariance loss @@ -282,24 +562,18 @@ def compute_g_loss( loss_sty2 = torch.mean(torch.abs(s_pred2 - s_org)) loss_sty = (loss_sty + loss_sty2) / 2 - # TODO Triplet loss with anchor=x_fake (from reference), neg=x_fake2 (from reference 2), and pos=x_fake3 (from augmented x_ref1) - # e.g. - # from torch.nn.functional import triplet_margin_loss - # if x_trg is not None: - # x_ref3 = transform(x_ref) - # s_trg3 = nets.style_encoder(x_ref3, y_trg) - # x_fake3 = nets.generator(x_real, s_trg3, masks=masks) - # loss_triplet = triplet_margin_loss(x_fake, x_fake2, x_fake3) - # I'm hoping this has the effect of a invariance-diversity double whammy - loss = ( - loss_adv - + args.lambda_sty * loss_sty - - args.lambda_ds * loss_ds - + args.lambda_cyc * loss_cyc + loss_adv + lambda_sty * loss_sty - lambda_ds * loss_ds + lambda_cyc * loss_cyc ) - return loss, Munch( - adv=loss_adv.item(), sty=loss_sty.item(), ds=loss_ds.item(), cyc=loss_cyc.item() + return ( + loss, + Munch( + adv=loss_adv.item(), + sty=loss_sty.item(), + ds=loss_ds.item(), + cyc=loss_cyc.item(), + ), + x_fake, ) diff --git a/src/quac/training/stargan.py b/src/quac/training/stargan.py index 7a8aeb7..b3f5a68 100644 --- a/src/quac/training/stargan.py +++ b/src/quac/training/stargan.py @@ -147,7 +147,14 @@ def forward(self, x, s): class Generator(nn.Module): - def __init__(self, img_size=256, style_dim=64, max_conv_dim=512, input_dim=1): + def __init__( + self, + img_size=256, + style_dim=64, + max_conv_dim=512, + input_dim=1, + final_activation=None, + ): super().__init__() dim_in = 2**14 // img_size self.img_size = img_size @@ -159,7 +166,12 @@ def __init__(self, img_size=256, style_dim=64, max_conv_dim=512, input_dim=1): nn.LeakyReLU(0.2), nn.Conv2d(dim_in, input_dim, 1, 1, 0), ) - self.final_activation = nn.Tanh() + if final_activation == "sigmoid": + # print("Using sigmoid") + self.final_activation = nn.Sigmoid() + else: + # print("Using tanh") + self.final_activation = nn.Tanh() # down/up-sampling blocks repeat_num = int(np.log2(img_size)) - 4 @@ -178,7 +190,7 @@ def __init__(self, img_size=256, style_dim=64, max_conv_dim=512, input_dim=1): def forward(self, x, s): x = self.from_rgb(x) - cache = {} + # cache = {} for block in self.encode: x = block(x) for block in self.decode: @@ -189,6 +201,7 @@ def forward(self, x, s): class MappingNetwork(nn.Module): def __init__(self, latent_dim=16, style_dim=64, num_domains=2): super().__init__() + self.latent_dim = latent_dim layers = [] layers += [nn.Linear(latent_dim, 512)] layers += [nn.ReLU()] @@ -310,11 +323,11 @@ def forward(self, x, y): class Discriminator(nn.Module): - def __init__(self, img_size=256, num_domains=2, max_conv_dim=512): + def __init__(self, img_size=256, num_domains=2, max_conv_dim=512, input_dim=3): super().__init__() dim_in = 2**14 // img_size blocks = [] - blocks += [nn.Conv2d(1, dim_in, 3, 1, 1)] + blocks += [nn.Conv2d(input_dim, dim_in, 3, 1, 1)] repeat_num = int(np.log2(img_size)) - 2 for _ in range(repeat_num): @@ -336,38 +349,49 @@ def forward(self, x, y): return out -def build_model(args, gpu_ids=[0]): +def build_model( + img_size=128, + style_dim=64, + input_dim=3, + latent_dim=16, + num_domains=4, + single_output_style_encoder=False, + final_activation=None, + gpu_ids=[0], +): generator = nn.DataParallel( - Generator(args.img_size, args.style_dim, input_dim=args.input_dim), + Generator( + img_size, style_dim, 
input_dim=input_dim, final_activation=final_activation + ), device_ids=gpu_ids, ) mapping_network = nn.DataParallel( - MappingNetwork(args.latent_dim, args.style_dim, args.num_domains), + MappingNetwork(latent_dim, style_dim, num_domains), device_ids=gpu_ids, ) - if args.single_output_style_encoder: + if single_output_style_encoder: print("Using single output style encoder") style_encoder = nn.DataParallel( SingleOutputStyleEncoder( - args.img_size, - args.style_dim, - args.num_domains, - input_dim=args.input_dim, + img_size, + style_dim, + num_domains, + input_dim=input_dim, ), device_ids=gpu_ids, ) else: style_encoder = nn.DataParallel( StyleEncoder( - args.img_size, - args.style_dim, - args.num_domains, - input_dim=args.input_dim, + img_size, + style_dim, + num_domains, + input_dim=input_dim, ), device_ids=gpu_ids, ) discriminator = nn.DataParallel( - Discriminator(args.img_size, args.num_domains, input_dim=args.input_dim), + Discriminator(img_size, num_domains, input_dim=input_dim), device_ids=gpu_ids, ) generator_ema = copy.deepcopy(generator) diff --git a/src/quac/training/utils.py b/src/quac/training/utils.py index 2323caa..6d7dd2f 100644 --- a/src/quac/training/utils.py +++ b/src/quac/training/utils.py @@ -8,20 +8,15 @@ Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. """ -import os -from os.path import join as ospj import json -import glob -from shutil import copyfile - -from tqdm import tqdm -import ffmpeg - +import matplotlib.pyplot as plt import numpy as np +from os.path import join as ospj +from pathlib import Path +import pandas as pd +import re import torch import torch.nn as nn -import torch.nn.functional as F -import torchvision import torchvision.utils as vutils @@ -59,235 +54,211 @@ def save_image(x, ncol, filename): vutils.save_image(x.cpu(), filename, nrow=ncol, padding=0) -@torch.no_grad() -def translate_and_reconstruct(nets, args, x_src, y_src, x_ref, y_ref, filename): - N, C, H, W = x_src.size() - s_ref = nets.style_encoder(x_ref, y_ref) - masks = nets.fan.get_heatmap(x_src) if args.w_hpf > 0 else None - x_fake = nets.generator(x_src, s_ref, masks=masks) - s_src = nets.style_encoder(x_src, y_src) - masks = nets.fan.get_heatmap(x_fake) if args.w_hpf > 0 else None - x_rec = nets.generator(x_fake, s_src, masks=masks) - x_concat = [x_src, x_ref, x_fake, x_rec] - x_concat = torch.cat(x_concat, dim=0) - save_image(x_concat, N, filename) - del x_concat - - -@torch.no_grad() -def translate_using_latent(nets, args, x_src, y_trg_list, z_trg_list, psi, filename): - N, C, H, W = x_src.size() - latent_dim = z_trg_list[0].size(1) - x_concat = [x_src] - masks = nets.fan.get_heatmap(x_src) if args.w_hpf > 0 else None - - for i, y_trg in enumerate(y_trg_list): - z_many = torch.randn(10000, latent_dim).to(x_src.device) - y_many = torch.LongTensor(10000).to(x_src.device).fill_(y_trg[0]) - s_many = nets.mapping_network(z_many, y_many) - s_avg = torch.mean(s_many, dim=0, keepdim=True) - s_avg = s_avg.repeat(N, 1) - - for z_trg in z_trg_list: - s_trg = nets.mapping_network(z_trg, y_trg) - s_trg = torch.lerp(s_avg, s_trg, psi) - x_fake = nets.generator(x_src, s_trg, masks=masks) - x_concat += [x_fake] - - x_concat = torch.cat(x_concat, dim=0) - save_image(x_concat, N, filename) - - -@torch.no_grad() -def translate_using_reference(nets, args, x_src, x_ref, y_ref, filename): - N, C, H, W = x_src.size() - wb = torch.ones(1, C, H, W).to(x_src.device) - x_src_with_wb = torch.cat([wb, x_src], dim=0) - - masks = nets.fan.get_heatmap(x_src) if args.w_hpf > 0 else None - s_ref = 
nets.style_encoder(x_ref, y_ref) - s_ref_list = s_ref.unsqueeze(1).repeat(1, N, 1) - x_concat = [x_src_with_wb] - for i, s_ref in enumerate(s_ref_list): - x_fake = nets.generator(x_src, s_ref, masks=masks) - x_fake_with_ref = torch.cat([x_ref[i : i + 1], x_fake], dim=0) - x_concat += [x_fake_with_ref] - - x_concat = torch.cat(x_concat, dim=0) - save_image(x_concat, N + 1, filename) - del x_concat - - -@torch.no_grad() -def debug_image(nets, args, inputs, step): - x_src, y_src = inputs.x_src, inputs.y_src - x_ref, y_ref = inputs.x_ref, inputs.y_ref - - device = inputs.x_src.device - N = inputs.x_src.size(0) - - # translate and reconstruct (reference-guided) - filename = ospj(args.sample_dir, "%06d_cycle_consistency.jpg" % (step)) - translate_and_reconstruct(nets, args, x_src, y_src, x_ref, y_ref, filename) - - # latent-guided image synthesis - y_trg_list = [ - torch.tensor(y).repeat(N).to(device) for y in range(min(args.num_domains, 5)) - ] - z_trg_list = ( - torch.randn(args.num_outs_per_domain, 1, args.latent_dim) - .repeat(1, N, 1) - .to(device) +class Logger: + def __init__(self, log_dir, nets, num_outs_per_domain=10) -> None: + self.log_dir = Path(log_dir) + if not self.log_dir.exists(): + self.log_dir.mkdir(parents=True) + self.nets = nets + + @torch.no_grad() + def translate_and_reconstruct(self, x_src, y_src, x_ref, y_ref, filename): + N, C, H, W = x_src.size() + s_ref = self.nets.style_encoder(x_ref, y_ref) + x_fake = self.nets.generator(x_src, s_ref) + s_src = self.nets.style_encoder(x_src, y_src) + x_rec = self.nets.generator(x_fake, s_src) + x_concat = [x_src, x_ref, x_fake, x_rec] + x_concat = torch.cat(x_concat, dim=0) + save_image(x_concat, N, filename) + del x_concat + + @torch.no_grad() + def translate_using_latent(self, x_src, y_trg_list, z_trg_list, psi, filename): + N, C, H, W = x_src.size() + latent_dim = z_trg_list[0].size(1) + x_concat = [x_src] + + for i, y_trg in enumerate(y_trg_list): + z_many = torch.randn(10000, latent_dim).to(x_src.device) + y_many = torch.LongTensor(10000).to(x_src.device).fill_(y_trg[0]) + s_many = self.nets.mapping_network(z_many, y_many) + s_avg = torch.mean(s_many, dim=0, keepdim=True) + s_avg = s_avg.repeat(N, 1) + + for z_trg in z_trg_list: + s_trg = self.nets.mapping_network(z_trg, y_trg) + s_trg = torch.lerp(s_avg, s_trg, psi) + x_fake = self.nets.generator(x_src, s_trg) + x_concat += [x_fake] + + x_concat = torch.cat(x_concat, dim=0) + save_image(x_concat, N, filename) + + @torch.no_grad() + def translate_using_reference(self, x_src, x_ref, y_ref, filename): + N, C, H, W = x_src.size() + wb = torch.ones(1, C, H, W).to(x_src.device) + x_src_with_wb = torch.cat([wb, x_src], dim=0) + + masks = nets.fan.get_heatmap(x_src) if args.w_hpf > 0 else None + s_ref = nets.style_encoder(x_ref, y_ref) + s_ref_list = s_ref.unsqueeze(1).repeat(1, N, 1) + x_concat = [x_src_with_wb] + for i, s_ref in enumerate(s_ref_list): + x_fake = nets.generator(x_src, s_ref, masks=masks) + x_fake_with_ref = torch.cat([x_ref[i : i + 1], x_fake], dim=0) + x_concat += [x_fake_with_ref] + + x_concat = torch.cat(x_concat, dim=0) + save_image(x_concat, N + 1, filename) + del x_concat + + @torch.no_grad() + def debug_image(self, inputs, step): + x_src, y_src = inputs.x_src, inputs.y_src + x_ref, y_ref = inputs.x_ref, inputs.y_ref + + device = inputs.x_src.device + N = inputs.x_src.size(0) + + # translate and reconstruct (reference-guided) + filename = ospj(self.sample_dir, "%06d_cycle_consistency.jpg" % (step)) + self.translate_and_reconstruct(x_src, y_src, x_ref, y_ref, 
filename) + + # latent-guided image synthesis + y_trg_list = [ + torch.tensor(y).repeat(N).to(device) + for y in range(min(args.num_domains, 5)) + ] + z_trg_list = ( + torch.randn(self.num_outs_per_domain, 1, args.latent_dim) + .repeat(1, N, 1) + .to(device) + ) + for psi in [0.5, 0.7, 1.0]: + filename = ospj(self.sample_dir, "%06d_latent_psi_%.1f.jpg" % (step, psi)) + self.translate_using_latent(x_src, y_trg_list, z_trg_list, psi, filename) + + # reference-guided image synthesis + filename = ospj(self.sample_dir, "%06d_reference.jpg" % (step)) + self.translate_using_reference(x_src, x_ref, y_ref, filename) + + +########################### +# LOSS PLOTTING FUNCTIONS # +########################### +def get_epoch_number(f): + # Get the epoch number from the filename and sort the files by epoch + # The epoch number is the only number in the filename, and can be 5 or 6 digits long + return int(re.findall(r"\d+", f.name)[0]) + + +def load_json_files(files): + # Load the data from the json files, store everything into a dictionary, with the epoch number as the key + data = {} + for f in files: + with open(f, "r") as file: + data[get_epoch_number(f)] = json.load(file) + + df = pd.DataFrame.from_dict(data, orient="index") + df.columns = [col.split("/")[-1] for col in df.columns] + df.sort_index(inplace=True) + return df + + +def plot_from_data( + reference_conversion_data, + latent_conversion_data, + reference_translation_data, + latent_translation_data, +): + n_cols = int(np.ceil(len(reference_conversion_data.columns) / 7)) + fig, axes = plt.subplots(7, n_cols, figsize=(15, 15)) + for col, ax in zip(reference_conversion_data.columns, axes.ravel()): + reference_conversion_data[col].plot( + ax=ax, label="Reference Conversion", color="black" + ) + latent_conversion_data[col].plot( + ax=ax, label="Latent Conversion", color="black", linestyle="--" + ) + reference_translation_data[col].plot( + ax=ax, label="Reference Translation", color="gray" + ) + latent_translation_data[col].plot( + ax=ax, label="Latent Translation", color="gray", linestyle="--" + ) + ax.set_ylim(0, 1) + # format the title only if there are any non-numeric characters in the column name + alphabetic_title = any([c.isalpha() for c in col]) + if alphabetic_title: + # Split the words in the column by the underscore, remove all numbers, and add the word "to" between the words + title = " to ".join( + [word.capitalize() for word in col.split("_") if not word.isdigit()] + ) + # Remove all remaining numbers from the title + title = "".join([i for i in title if not i.isdigit()]) + else: + # The title is \d2\d and we want it to be \d to \d, unfortunately sometimes \d is the number 2, so we need to be careful + title = re.sub(r"(\d)2(\d)", r"\1 to \2", col) + ax.set_title(title) + # Hide all of the extra axes if any + for ax in axes.ravel()[len(reference_conversion_data.columns) :]: + ax.axis("off") + # Add an x-axis label to the bottom row of plots that still has a visible x-axis + num_axes = len(reference_conversion_data.columns) + for ax in axes.ravel()[num_axes - n_cols :]: + ax.set_xlabel("Iteration") + # Add a y-axis label to the left column of plots + for ax in axes[:, 0]: + ax.set_ylabel("Rate") + # Make a legend for the whole figure, assuming that the labels are the same for all subplots + fig.legend( + *ax.get_legend_handles_labels(), loc="upper right", bbox_to_anchor=(1.15, 1) ) - for psi in [0.5, 0.7, 1.0]: - filename = ospj(args.sample_dir, "%06d_latent_psi_%.1f.jpg" % (step, psi)) - translate_using_latent(nets, args, x_src, 
y_trg_list, z_trg_list, psi, filename) + fig.tight_layout() - # reference-guided image synthesis - filename = ospj(args.sample_dir, "%06d_reference.jpg" % (step)) - translate_using_reference(nets, args, x_src, x_ref, y_ref, filename) + return fig -# ======================= # -# Video-related functions # -# ======================= # +def plot_metrics(root, show: bool = True, save: str = ""): + """ + root: str + The root directory containing the json files with the metrics. + show: bool + Whether to show the plot. + save: str + The path to save the plot to. + """ + # Goal is to get a better idea of the success/failure of conversion as training goes on. + # Will be useful to understand how it changes over time, and how the translation/conversion/diversity tradeoff plays out. + files = list(Path(root).rglob("*.json")) -def sigmoid(x, w=1): - return 1.0 / (1 + np.exp(-w * x)) + # Split files by whether they are LPIPS, conversion_rate, or translation_rate + conversion_files = [f for f in files if "conversion_rate" in f.name] + translation_files = [f for f in files if "translation_rate" in f.name] + # Split files by whether they are reference or latent + reference_conversion = [f for f in conversion_files if "reference" in f.name] + latent_conversion = [f for f in conversion_files if "latent" in f.name] -def get_alphas(start=-5, end=5, step=0.5, len_tail=10): - return ( - [0] + [sigmoid(alpha) for alpha in np.arange(start, end, step)] + [1] * len_tail - ) + reference_translation = [f for f in translation_files if "reference" in f.name] + latent_translation = [f for f in translation_files if "latent" in f.name] + # Load the data from the json files + reference_conversion_data = load_json_files(reference_conversion) + latent_conversion_data = load_json_files(latent_conversion) + reference_translation_data = load_json_files(reference_translation) + latent_translation_data = load_json_files(latent_translation) -def interpolate(nets, args, x_src, s_prev, s_next): - """returns T x C x H x W""" - B = x_src.size(0) - frames = [] - masks = nets.fan.get_heatmap(x_src) if args.w_hpf > 0 else None - alphas = get_alphas() - - for alpha in alphas: - s_ref = torch.lerp(s_prev, s_next, alpha) - x_fake = nets.generator(x_src, s_ref, masks=masks) - entries = torch.cat([x_src.cpu(), x_fake.cpu()], dim=2) - frame = torchvision.utils.make_grid( - entries, nrow=B, padding=0, pad_value=-1 - ).unsqueeze(0) - frames.append(frame) - frames = torch.cat(frames) - return frames - - -def slide(entries, margin=32): - """Returns a sliding reference window. 
- Args: - entries: a list containing two reference images, x_prev and x_next, - both of which has a shape (1, 3, 256, 256) - Returns: - canvas: output slide of shape (num_frames, 3, 256*2, 256+margin) - """ - _, C, H, W = entries[0].shape - alphas = get_alphas() - T = len(alphas) # number of frames - - canvas = -torch.ones((T, C, H * 2, W + margin)) - merged = torch.cat(entries, dim=2) # (1, 3, 512, 256) - for t, alpha in enumerate(alphas): - top = int(H * (1 - alpha)) # top, bottom for canvas - bottom = H * 2 - m_top = 0 # top, bottom for merged - m_bottom = 2 * H - top - canvas[t, :, top:bottom, :W] = merged[:, :, m_top:m_bottom, :] - return canvas - - -@torch.no_grad() -def video_ref(nets, args, x_src, x_ref, y_ref, fname): - video = [] - s_ref = nets.style_encoder(x_ref, y_ref) - s_prev = None - for data_next in tqdm(zip(x_ref, y_ref, s_ref), "video_ref", len(x_ref)): - x_next, y_next, s_next = [d.unsqueeze(0) for d in data_next] - if s_prev is None: - x_prev, y_prev, s_prev = x_next, y_next, s_next - continue - if y_prev != y_next: - x_prev, y_prev, s_prev = x_next, y_next, s_next - continue - - interpolated = interpolate(nets, args, x_src, s_prev, s_next) - entries = [x_prev, x_next] - slided = slide(entries) # (T, C, 256*2, 256) - frames = torch.cat( - [slided, interpolated], dim=3 - ).cpu() # (T, C, 256*2, 256*(batch+1)) - video.append(frames) - x_prev, y_prev, s_prev = x_next, y_next, s_next - - # append last frame 10 time - for _ in range(10): - video.append(frames[-1:]) - video = tensor2ndarray255(torch.cat(video)) - save_video(fname, video) - - -@torch.no_grad() -def video_latent(nets, args, x_src, y_list, z_list, psi, fname): - latent_dim = z_list[0].size(1) - s_list = [] - for i, y_trg in enumerate(y_list): - z_many = torch.randn(10000, latent_dim).to(x_src.device) - y_many = torch.LongTensor(10000).to(x_src.device).fill_(y_trg[0]) - s_many = nets.mapping_network(z_many, y_many) - s_avg = torch.mean(s_many, dim=0, keepdim=True) - s_avg = s_avg.repeat(x_src.size(0), 1) - - for z_trg in z_list: - s_trg = nets.mapping_network(z_trg, y_trg) - s_trg = torch.lerp(s_avg, s_trg, psi) - s_list.append(s_trg) - - s_prev = None - video = [] - # fetch reference images - for idx_ref, s_next in enumerate(tqdm(s_list, "video_latent", len(s_list))): - if s_prev is None: - s_prev = s_next - continue - if idx_ref % len(z_list) == 0: - s_prev = s_next - continue - frames = interpolate(nets, args, x_src, s_prev, s_next).cpu() - video.append(frames) - s_prev = s_next - for _ in range(10): - video.append(frames[-1:]) - video = tensor2ndarray255(torch.cat(video)) - save_video(fname, video) - - -def save_video(fname, images, output_fps=30, vcodec="libx264", filters=""): - assert isinstance(images, np.ndarray), "images should be np.array: NHWC" - num_frames, height, width, channels = images.shape - stream = ffmpeg.input( - "pipe:", format="rawvideo", pix_fmt="rgb24", s="{}x{}".format(width, height) - ) - stream = ffmpeg.filter(stream, "setpts", "2*PTS") # 2*PTS is for slower playback - stream = ffmpeg.output( - stream, fname, pix_fmt="yuv420p", vcodec=vcodec, r=output_fps + fig = plot_from_data( + reference_conversion_data, + latent_conversion_data, + reference_translation_data, + latent_translation_data, ) - stream = ffmpeg.overwrite_output(stream) - process = ffmpeg.run_async(stream, pipe_stdin=True) - for frame in tqdm(images, desc="writing video to %s" % fname): - process.stdin.write(frame.astype(np.uint8).tobytes()) - process.stdin.close() - process.wait() - - -def tensor2ndarray255(images): - 
images = torch.clamp(images * 0.5 + 0.5, 0, 1)
-    return images.cpu().numpy().transpose(0, 2, 3, 1) * 255
+    if show:
+        plt.show()
+    if save:
+        fig.savefig(save)
diff --git a/tests/test_datasets.py b/tests/test_datasets.py
index 776918b..e7b21a4 100644
--- a/tests/test_datasets.py
+++ b/tests/test_datasets.py
@@ -1,10 +1,10 @@
-from quac.data import PairedImageFolders
+from quac.data import PairedImageDataset
 from torchvision import transforms
 
 
 def test_paired_image_folders():
     transform = transforms.Compose([transforms.Grayscale(), transforms.ToTensor()])
-    dataset = PairedImageFolders(
+    dataset = PairedImageDataset(
         "/nrs/funke/adjavond/data/synapses/test",
         "/nrs/funke/adjavond/data/synapses/counterfactuals/stargan_invariance_v0/test",
         transform=transform,
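
For reference, below is a minimal sketch of how the refactored training API in this patch could be wired together end to end. The class and function signatures (`TrainingData`, `ValidationData`, `build_model`, `Logger.create`, `ValConfig`, `Solver`, `Solver.train`) follow the code added above; all file paths, hyperparameter values, and the WandB project/run names are placeholders, not values from the repository.

```python
# Sketch only: paths, hyperparameters, and logger names are placeholders.
from quac.training.config import ValConfig
from quac.training.data_loader import TrainingData, ValidationData
from quac.training.logging import Logger
from quac.training.solver import Solver
from quac.training.stargan import build_model

# Build the networks and their EMA copies.
nets, nets_ema = build_model(
    img_size=128,
    style_dim=64,
    latent_dim=16,
    num_domains=4,  # must match the number of class sub-folders in the data
)

# Training and validation data; mean/std of 0.5 keep images in [-1, 1].
train_data = TrainingData(
    source="path/to/train", reference="path/to/train",
    img_size=128, batch_size=8, num_workers=4, mean=0.5, std=0.5,
)
val_data = ValidationData(
    source="path/to/val", reference="path/to/val",
    img_size=128, batch_size=16, num_workers=4, mean=0.5, std=0.5,
)

# Validation settings, including the torchscript classifier used for
# conversion/translation rates.
val_config = ValConfig(
    classifier_checkpoint="path/to/classifier.pt",
    img_size=128, val_batch_size=16, mean=0.5, std=0.5,
)

# Experiment logger (WandB in this example).
logger = Logger.create(
    log_type="wandb",
    project="quac", name="stargan-example", notes="", tags=["quac", "stargan"],
    hparams={"img_size": 128, "num_domains": 4},
)

# The solver now takes explicit optimizer settings and a root directory.
solver = Solver(
    nets, nets_ema,
    f_lr=1e-6, lr=1e-4, beta1=0.0, beta2=0.99, weight_decay=1e-4,
    root_dir="runs/example", run=logger,
)
solver.train(
    train_data,
    total_iters=100_000, log_every=1_000, save_every=10_000, eval_every=10_000,
    lambda_ds=0.0, ds_iter=100_000, lambda_reg=1.0, lambda_sty=1.0, lambda_cyc=1.0,
    val_loader=val_data, val_config=val_config,
)
```

The keyword arguments passed to `Solver` mirror the fields of the `SolverConfig` added in `config.py`, so the same call could also be driven by unpacking a parsed configuration object instead of hard-coding the values.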