Using Pytorch Lightning and Torchxrayvision's Pretrained Densenet121 Models

ihamdi/Covid-xRay-Classification

Covid xRay Binary Classification

Using Pytorch Lightning and Torchxrayvision's Pretrained Densenet121 Models

[Images: example chest xrays labeled Negative (1e6f48393e17_03) and Typical (09cf9767a7bf)]

Installation

  1. Create a conda environment:
conda create --name env-name python=3.6.13

     *Python 3.6.13 is needed since GDCM is not supported on versions above 3.6.

  2. Clone the repository:
git clone -v https://github.com/ihamdi/Covid-xRay-Classification.git /your/directory/

       or download and extract a copy of the files.

  3. Install PyTorch according to your machine. For example:
conda install pytorch torchvision torchaudio cudatoolkit=10.2 -c pytorch
  4. Install dependencies from the requirements.txt file:
pip install -r requirements.txt
  5. Download the data:

       Run python scripts/download_data.py from the repository root to download the data using the Kaggle API and extract it automatically. If you haven't used the Kaggle API before, see the instructions at the bottom on how to get your API key.

       Alternatively, extract the contents of the "train" directory from the official SIIM-FISABIO-RSNA COVID-19 Detection page to the train folder inside the data directory.
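
For the manual route, the extraction step could look roughly like the sketch below. It uses only the Python standard library and is not the repository's download_data.py; the function name extract_train_folder and the "train/" prefix inside the archive are assumptions made for illustration.

```python
import zipfile
from pathlib import Path


def extract_train_folder(archive_path, data_dir, prefix="train/"):
    """Extract only the members under `prefix` into `data_dir`.

    Hypothetical helper: assumes the downloaded archive stores the
    training images under a top-level "train/" folder.
    Returns the list of extracted member names.
    """
    data_dir = Path(data_dir)
    extracted = []
    with zipfile.ZipFile(archive_path) as archive:
        for member in archive.namelist():
            # Skip directory entries and anything outside train/
            if member.startswith(prefix) and not member.endswith("/"):
                archive.extract(member, data_dir)
                extracted.append(member)
    return extracted
```

Calling extract_train_folder("archive.zip", "data") (the archive name is a placeholder) would leave the files under data/train/, which is where the code expects them.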

Folder Structure

  1. bash directory contains schedule.sh bash file used to run multiple experiments in sequence.
  2. configs directory contains configuration files used for different experiments, callbacks, datamodules, sweeps, etc.
  3. data directory contains train folder where the data is expected to be.
  4. scripts directory contains download_data.py used to download the dataset directly from Kaggle.
  5. src directory contains folders where callbacks, datamodules, models, and utils functions, as well as training, are defined.

How to use:

Using Pytorch Lightning with the Lightning-Hydra-Template makes it possible to run experiments with different configurations/hyperparameters on the same dataset, or to run a hyperparameter sweep with Optuna:

Experiments:

The experiment folder inside configs directory contains a template for configuring an experiment. The easiest way is to make a copy of template.yaml and edit the parameters accordingly.

If num_classes is set to 4, the data will be a random mix of all labels. Otherwise, the code defaults to binary classification, and the data will be a balanced mix of negative and non-negative images (the latter randomly chosen from the other 3 classes). The program also rejects any data folder containing more than one xray file, to avoid training on lateral chest xrays.
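
The selection logic described above can be sketched as follows. This is a simplified illustration and not the repository's datamodule: the function names, the dict-based inputs, and the assumption that class 0 is the negative label are all hypothetical.

```python
import random

NEGATIVE = 0  # assumption: class index 0 is the "negative" label


def keep_single_image_studies(study_files):
    """Keep only studies whose folder holds exactly one image file.

    study_files: dict mapping study id -> list of file paths.
    """
    return {study: files[0] for study, files in study_files.items()
            if len(files) == 1}


def balanced_binary_subset(labels, dataset_size, seed=0):
    """Draw an equal number of negative and non-negative samples.

    labels: dict mapping image id -> class index in 0..3.
    Returns a shuffled list of (image id, binary label) pairs,
    where 0 = negative and 1 = any of the other three classes.
    """
    rng = random.Random(seed)
    negatives = sorted(i for i, c in labels.items() if c == NEGATIVE)
    others = sorted(i for i, c in labels.items() if c != NEGATIVE)
    # The smaller group caps the subset so it stays balanced
    per_class = min(dataset_size // 2, len(negatives), len(others))
    subset = [(i, 0) for i in rng.sample(negatives, per_class)]
    subset += [(i, 1) for i in rng.sample(others, per_class)]
    rng.shuffle(subset)
    return subset
```

Because the smaller group caps the subset, asking for more samples than the data can balance simply returns the largest balanced subset available.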

To run the default experiment, run the following command:

python train.py

or

python train.py experiment=template

This will run an experiment based on the template using the following configuration:

  1. 20 epochs (unless early stopping is triggered)
  2. Torchxrayvision's "ALL" (pretrained Densenet121) model with no Dropout
  3. Adam optimizer with learning rate of 0.003 and AMSGrad enabled.
  4. Batch size of 32
  5. Number of workers of 10
  6. 640 images only
  7. 60 : 20 : 20 split
  8. Image size of 128x128
  9. (IMG-MIN)/(MAX-MIN)x255 normalization
  10. No augmentations

*Torchxrayvision models expect 224x224 images, so the code automatically defaults to that size if one of them is chosen.
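
The min-max normalization listed in the configuration rescales each image to the range 0 to 255. A minimal NumPy sketch is given below; the function name min_max_to_255 and the handling of constant images are assumptions, not the repository's code.

```python
import numpy as np


def min_max_to_255(img):
    """Rescale pixel values to [0, 255]: (img - min) / (max - min) * 255."""
    img = np.asarray(img, dtype=np.float32)
    lo, hi = img.min(), img.max()
    if hi == lo:
        # Assumption: a constant image maps to all zeros to avoid dividing by 0
        return np.zeros_like(img)
    return (img - lo) / (hi - lo) * 255.0
```

For example, the pixel values 100, 300, 500, 900 map to 0, 63.75, 127.5, 255.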

Hyperparameter Search with Optuna:

As part of the Hydra template, Optuna can be used to find the best hyperparameters within a defined range. A template configuration file can be found in the hparams_search folder inside the configs directory. The template hyperparameter search can be started using:

python run.py -m hparams_search=template_optuna experiment=template

or, to set the number of trials explicitly:

python run.py -m hparams_search=template_optuna experiment=template hydra.sweeper.n_trials=30

Results

When in binary classification mode, the code is able to produce a model with just over 80% accuracy on the validation data before it starts overfitting:

[Figure: W&B chart, 12/14/2021, 9:37:59 AM]

This was done using the "All" model from torchxrayvision, Adam optimizer, and the following hyperparameters:

  1. drop_rate (Dropout) = 0
  2. lr (Learning Rate) = 0.0003
  3. amsgrad (for Adam) = True
  4. normal (Normalization Method) = 0 ((img-min)/(max-min)*255)
  5. rotation = 11.355
  6. scaling = 0.2789
  7. shear = 1
  8. translation = 0.07357
  9. horizontal_flip = True
  10. vertical_flip = True
  11. dataset_size (Sample Size) = 3350 (maximum possible to keep subset balanced)
  12. train_val_test_split = [70,20,10]
  13. batch_size = 156
  14. num_workers = 20
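
To make the split and batch settings concrete, the arithmetic can be sketched as below. The helper split_sizes is hypothetical, and giving any rounding remainder to the training set is an assumption; the repository may round differently.

```python
import math


def split_sizes(dataset_size, ratios):
    """Turn split ratios such as [70, 20, 10] into sample counts."""
    total = sum(ratios)
    sizes = [dataset_size * r // total for r in ratios]
    # Assumption: any remainder from integer division goes to training
    sizes[0] += dataset_size - sum(sizes)
    return sizes


train, val, test = split_sizes(3350, [70, 20, 10])
print(train, val, test)          # 2345 670 335
print(math.ceil(train / 156))    # 16 training batches per epoch

print(split_sizes(640, [60, 20, 20]))  # [384, 128, 128] for the template run
```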

Although the code accepts setting num_classes=4, it is currently unable to achieve a validation accuracy higher than 60-62% regardless of hyperparameter tuning:

[Figure: W&B chart, 12/14/2021, 9:36:37 AM]

The F1 heatmap and confusion matrix below show that the model performs especially poorly on xrays with Atypical appearance (label 3):

[Figures: confusion matrix and F1/precision/recall heatmap for run winter-totem-6]
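
For readers unfamiliar with how the heatmap values relate to the confusion matrix, here is a short sketch of computing per-class precision, recall, and F1. The function name is hypothetical, and the matrix in the usage note is made up for illustration; it is not the result of the runs above.

```python
def per_class_metrics(confusion):
    """Compute precision, recall and F1 per class.

    confusion[t][p] = number of samples with true class t predicted as p.
    """
    n = len(confusion)
    metrics = []
    for k in range(n):
        tp = confusion[k][k]
        fn = sum(confusion[k]) - tp                         # missed class-k samples
        fp = sum(confusion[r][k] for r in range(n)) - tp    # wrongly predicted as k
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = (2 * precision * recall / (precision + recall)
              if precision + recall else 0.0)
        metrics.append({"precision": precision, "recall": recall, "f1": f1})
    return metrics
```

With a toy matrix such as [[8, 1, 1, 0], [2, 6, 2, 0], [1, 3, 5, 1], [2, 3, 4, 1]], class 3 has a recall of only 0.1 even though overall accuracy is 0.5, which is the kind of pattern a per-class heatmap makes visible.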

I am currently investigating whether it is possible to achieve a significant improvement in the performance of the model. Dropout and augmentations seem to have an adverse effect on the accuracy, so I might need to switch to a different architecture altogether.


Background:

Initially, this code was based on my Dogs vs Cats code. I eventually adopted the Lightning-Hydra-Template to make it easier to use and log. No submission is made to the Kaggle competition and only the training data is used.


Contact:

For any questions or feedback, please feel free to post comments or contact me at [email protected]


References:

Torchxrayvision's page on Github (used for Densenet121 models pretrained on xrays).

Pytorch Lightning's page on Github.

Lightning Hydra Template's page on Github.

Weights & Biases's website.

Densenet paper by Gao Huang, Zhuang Liu, Laurens van der Maaten, Kilian Q. Weinberger.


Getting Key for Kaggle's API

[Screenshots: steps for getting a Kaggle API key]

Getting Key for Weights & Biases (wandb):

[Screenshot: getting the Weights & Biases API key, 2021-12-14]