Continual Learning in Open-vocabulary Classification with Complementary Memory Systems

TMLR 2024

OpenReview ArXiv Project Page YouTube Video

Authors

Zhen Zhu · Weijie Lyu · Yao Xiao · Derek Hoiem

Overview

We introduce a method for flexible and efficient continual learning in open-vocabulary image classification, drawing inspiration from the complementary learning systems observed in human cognition. Specifically, we propose to combine predictions from a CLIP zero-shot model and an exemplar-based model, using the zero-shot estimated probability that a sample's class is within the exemplar classes. We also propose a "tree probe" method, an adaptation of lazy learning principles, which enables fast learning from new examples with accuracy competitive with batch-trained linear models. We test in data-incremental, class-incremental, and task-incremental settings, as well as the ability to perform flexible inference on varying subsets of zero-shot and learned categories. Our proposed method achieves a good balance of learning speed, target task effectiveness, and zero-shot effectiveness.
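
The fusion idea described above can be illustrated with a minimal sketch. It is one plausible reading of that sentence, not the repository's implementation: the function name `fuse_predictions`, its arguments, and the toy numbers are all hypothetical, and the exact formulation in the paper may differ. The sketch assumes both models output probabilities over the same candidate label set and uses the zero-shot probability mass on the exemplar classes as a mixing weight.

```python
import numpy as np


def fuse_predictions(p_zero_shot, p_exemplar, exemplar_mask):
    """Blend zero-shot and exemplar-based probabilities (illustrative only).

    p_zero_shot   : probabilities over all candidate labels from the zero-shot model
    p_exemplar    : probabilities over the same labels from the exemplar-based model
    exemplar_mask : True for labels that the exemplar-based model has seen
    """
    p_zero_shot = np.asarray(p_zero_shot, dtype=float)
    p_exemplar = np.asarray(p_exemplar, dtype=float)
    exemplar_mask = np.asarray(exemplar_mask, dtype=bool)

    # Zero-shot estimate of the probability that the true class is an exemplar class.
    w = p_zero_shot[..., exemplar_mask].sum(axis=-1, keepdims=True)

    # The exemplar-based model carries no mass outside the classes it has seen.
    p_exemplar = np.where(exemplar_mask, p_exemplar, 0.0)

    # Trust the exemplar model in proportion to w, the zero-shot model otherwise.
    return w * p_exemplar + (1.0 - w) * p_zero_shot


# Toy example: four candidate labels, the first two have exemplars.
p_zs = np.array([0.5, 0.2, 0.2, 0.1])
p_ex = np.array([0.9, 0.1, 0.0, 0.0])
mask = np.array([True, True, False, False])
fused = fuse_predictions(p_zs, p_ex, mask)
```

Under these assumptions the fused vector remains a valid probability distribution, because both inputs sum to one and the mixing weight lies between zero and one.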

Hardware

We tested our code on a single NVIDIA RTX 3090 Ti GPU.

Installation

Prerequisites

  • Anaconda or Miniconda
  • Git

Setup

  1. Clone the repository:

    git clone https://github.com/jessemelpolio/TreeProbe.git
    cd TreeProbe
    
  2. Create and activate the Conda environment:

    conda env create -f environment.yml
    conda activate TreeProbe
    

Project Structure

  • data/: Dataset handling and preprocessing
  • encode_features/: Scripts for encoding features using CLIP
  • engines/: Engine implementations for training and evaluation
  • models/: Model architectures and components
  • options/: Command-line argument parsing
  • scripts/: Utility scripts
  • main_xx.py: Main entry points for running experiments, where xx is data, task, or class.

Usage

  1. Prepare datasets: Our project uses various datasets for target tasks and zero-shot tasks.

    Target Tasks: CIFAR100, SUN397, EuroSAT, OxfordIIITPet, Flowers102, FGVCAircraft, StanfordCars, Food101

    Zero-shot Tasks: ImageNet, UCF101, DTD

    Note: SUN397, EuroSAT, UCF101, and ImageNet must be downloaded manually from their original sources. Please follow the instructions in tutorials/download_data.md to obtain these datasets. The other datasets can be downloaded using the torchvision.datasets package. We also provide additional datasets in the data/ folder for your convenience, but be aware that they have not been tested rigorously and may not work with the codebase.

    To speed up training, you can pre-encode the intermediate image representations of these datasets with scripts/encode_features.sh. After setting the correct data root in the script, run it with:

    bash scripts/encode_features.sh
    
  2. Train and evaluate: Example scripts for task, data, and class-incremental learning:

    bash scripts/task_incremental.sh
    
    bash scripts/data_incremental.sh
    
    bash scripts/class_incremental.sh
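
For the datasets in step 1 that torchvision can download, a small helper along the following lines may be convenient. This is a sketch under assumptions, not part of the repository: `DATASET_SPECS`, `download_datasets`, and the `./datasets` root are hypothetical names, and the directory layout the codebase expects is not confirmed here, so check the data root used in scripts/encode_features.sh before relying on it. StanfordCars is left out because its automatic download does not work in some torchvision versions.

```python
# Hypothetical mapping from torchvision dataset class names to the keyword
# arguments that select a training split. The split choices are assumptions.
DATASET_SPECS = {
    "CIFAR100": {"train": True},
    "OxfordIIITPet": {"split": "trainval"},
    "Flowers102": {"split": "train"},
    "FGVCAircraft": {"split": "train"},
    "Food101": {"split": "train"},
    "DTD": {"split": "train"},
}


def download_datasets(root, names=None):
    """Download the selected datasets into `root` and return them by name."""
    # Imported lazily so that inspecting DATASET_SPECS does not require torchvision.
    import torchvision.datasets as tvd

    selected = list(DATASET_SPECS) if names is None else list(names)
    datasets = {}
    for name in selected:
        dataset_cls = getattr(tvd, name)
        # download=True fetches the files only if they are not already in root.
        datasets[name] = dataset_cls(root=root, download=True, **DATASET_SPECS[name])
    return datasets
```

Calling `download_datasets("./datasets", ["CIFAR100"])` would fetch only CIFAR100, which is a quick way to confirm the data root before downloading the larger datasets.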
    

Warning

This codebase has only been tested on a single GPU. To use multiple GPUs, you will need to modify the code.

We'd appreciate it if you could report any issues you encounter.

Bibtex

If you use this code for your research, please consider citing:

@article{zhu2024treeprobe,
  author       = {Zhen Zhu and Weijie Lyu and Yao Xiao and Derek Hoiem},
  title        = {Continual Learning in Open-vocabulary Classification with Complementary Memory Systems},
  journal      = {Trans. Mach. Learn. Res.},
  volume       = {2024},
  year         = {2024},
  url          = {https://openreview.net/forum?id=6j5M75iK3a}
}

Acknowledgements

  • This project uses DINOv2 by Facebook Research.
  • The project incorporates CLIP for vision-language learning.
  • The argument configuration is inspired by SPADE.
  • This codebase shares a lot in common with AnytimeCL.