Skip to content
/ LitMNIST Public template

Deep learning project template based on PyTorch Lightning and Hydra.

License

Notifications You must be signed in to change notification settings

XavierJiezou/LitMNIST

Repository files navigation

logo

LitMNIST

基于 PyTorch Lightning + Hydra 的深度学习项目模板。

(以 MNIST 分类任务为例)

点击 Use this template 即可使用该模板来初始化你的新仓库。

GitHub Workflow Release Status GitHub Workflow Test Status GitHub Workflow Lint Status Codacy Badge codecov PyPI PyPI - Downloads GitHub stars GitHub forks GitHub issues GitHub license

Python PyTorch Lightning Config: hydra Code style: black

观看演示报告错误功能需求

English 简体中文

喜欢这个项目吗?请考虑捐赠(微信 | 支付宝),以帮助它改善!

目录

LitMNIST

演示

demo

安装

开始之前,你必须熟练使用 PyTorch Lightning,并对 Hydra 有一定的了解。

  1. 克隆仓库到本地
git clone https://github.com/XavierJiezou/LitMNIST.git
cd LitMNIST
  1. 创建并激活 conda 虚拟环境
conda create -n myenv python=3.8
conda activate myenv
  1. 安装项目依赖包(如需安装 GPU 版 PyTorch,请参考官网安装教程
pip install -r requirements.txt

运行

CPU

python train.py

GPU

python train.py \
trainer.gpus=4 \
+trainer.strategy=ddp_find_unused_parameters_false \
datamodule.num_workers=16 \
datamodule.pin_memory=True \
datamodule.persistent_workers=True

结构

项目的主要目录结构如下:

├── configs # 存放 Hydra 配置文件
│   ├── callbacks # Callbacks 配置(例如 EarlyStopping、ModelCheckpoint 等)
│   ├── datamodule # Datamodule 配置(例如 batch_size、num_workers 等)
│   ├── debug # 调试配置
│   ├── experiment # 实验配置
│   ├── hparams_search # 超参数搜索配置
│   ├── local # 本地配置(暂时可以忽略)
│   ├── log_dir # 日志存放目录配置
│   ├── logger # 日志配置
│   ├── model # 模型配置
│   ├── trainer # Trainer 配置
│   │
│   ├── test.yaml # 测试的主要配置
│   └── train.yaml # 训练的主要配置
│
├── data # 存放项目数据
│
├── logs # 存放项目日志(Hydra 日志 和 PyTorch Lightning loggers 生成的日志)
│
├── src # 项目源代码
│   ├── datamodules # LightningDataModule
│   ├── models # 存放基于原生 PyTorch 框架编写的模型
│   ├── litmodules # LightningModule
│   ├── utils # 存放一些实用的脚本(例如数据预处理的脚本)
│   │
│   ├── testing_pipeline.py # 测试流水线(实例化对象)
│   └── training_pipeline.py # 训练流水线(实例化对象)
│
├── tests # 单元测试(可选)
│
├── test.py # 开始测试(加载配置文件)
├── train.py # 开始训练(加载配置文件)
│
├── .env # 存储私有环境变量(例如 wandb 的 API_KEY)【注意:该文件不受版本控制】
├── .gitignore # 设置版本控制需要排除的文件或目录(例如 .env 文件)
├── requirements.txt # 项目依赖环境(pip install -r requirements.txt)
└── README.md # 项目概述文档

用法

本仓库是一个基于 PyTorch Lightning + Hydra 的深度学习项目模板。因此你仅需要套用该模板,并作出如下修改:

  1. 编写你自己的 PyTorch nn.Module 模型(参见 src/models/simple_densenet.py
  2. 编写你自己的 PyTorch Lightning LightningModule(参见 src/litmodules/mnist_litmodule.py
  3. 编写你自己的 PyTorch Lightning LightningDataModule(参见 src/datamodules/mnist_datamodule.py
  4. 编写你自己的实验配置文件(参见 configs/experiment/example.yaml
  5. 使用选定的实验配置运行训练代码:python train.py experiment=<EXPERIMENT_NAME>

基础

train.py 集成了模型训练验证测试的一整套工作流,安装好环境后,运行即可:

python train.py

test.py 仅包含测试步骤,允许你单独加载预训练模型进行测试(但要指定模型路径):

python test.py ckpt_path=checkpoints/last.ckpt

进阶

  • 从命令行覆盖任何配置参数

Hydra 允许你轻松覆盖配置文件中定义的任何参数。

train.py 默认从 configs/train.yaml 中获取参数。因此,你可以先修改 yaml 配置文件中的参数,然后再运行。

或者,你也可以在命令行中直接指定参数。命令行中参数的优先级要大于 yaml 配置文件中参数的优先级。

python train.py trainer.max_epochs=3

对于某些不太重要的参数,它们没有在 yaml 配置文件中定义,因此你在命令行中指定的时候必须添加 +

python train.py +trainer.precision=16
  • 在 CPU、GPU、多 GPU 和 TPU 上训练

PyTorch Lightning 使得在不同硬件上训练模型变得容易。

在 CPU 上训练

python train.py trainer.gpus=0

在 GPU 上训练

python train.py trainer.gpus=1

在 TPU 上训练

python train.py +trainer.tpu_cores=8

基于 DDP(Distributed Data Parallel,分布式数据并行)的训练【4 个 GPU】

python train.py trainer.gpus=4 +trainer.strategy=ddp

基于 DDP(Distributed Data Parallel,分布式数据并行)的训练【8 个 GPU,两个节点】

python train.py trainer.gpus=4 +trainer.num_nodes=2 +trainer.strategy=ddp
  • 混合精度训练

PyTorch Lightning 允许你使用半精度或混合精度以减少训练期间的内存占用。(在 GPU 上能够实现 3 倍的加速效果,但可能损失精度)

python train.py trainer.gpus=1 +trainer.precision=16
  • 使用 PyTorch Lightning 中的日志记录器来记录训练日志

PyTorch Lightning 集成了多种主流日志记录框架,包括 TensorBoard 和 Weights&Biases等。

这里以 wandb 为例,展示如何使用:

  1. 安装 wandb
pip install wandb
  1. 转到 wandb.ai/authorize 获取 API key

  2. 执行 login 命令(需要用到上一步获取的 API key

wandb login
  1. configs/logger/ 目录下新建一个名为 wandb.yaml 的文件,并写入以下内容
wandb:
  _target_: pytorch_lightning.loggers.wandb.WandbLogger
  project: "mnist"
  1. 执行训练代码的时候指定 loggerwandb
python train.py logger=wandb
  • 根据自定义实验配置来训练模型

配置文件见 configs/experiment/

python train.py experiment=example
  • 带回调函数的训练

配置文件见 configs/callbacks/

python train.py callbacks=default
  • 使用 Pytorch Lightning 中的训练策略

点击这里了解 Pytorch Lightning 中的各种训练策略

梯度裁剪来避免梯度爆炸

python train.py +trainer.gradient_clip_val=0.5

随机加权平均可以使您的模型更好地泛化

python train.py +trainer.stochastic_weight_avg=true

梯度累积

python train.py +trainer.accumulate_grad_batches=10
  • 轻松调试

配置文件见 configs/debug/

默认调试模式(运行 1 个 epoch)

python train.py debug=default

仅对 test epoch 进行调试

python train.py debug=test_only

执行一次 train,val 和 test 步骤(仅使用 1 个 batch)

python train.py +trainer.fast_dev_run=true

训练完成后打印各个阶段的执行时间(用于快速发现训练瓶颈)

python train.py +trainer.profiler="simple"
  • 断点续训
python train.py trainer.resume_from_checkpoint="/path/to/name.ckpt"
  • 一次执行多个实验

例如,下方代码将按顺序运行所有参数组合(共 6 个)的实验。

python train.py -m datamodule.batch_size=32,64,128 litmodule.lr=0.001,0.0005

此外,你也可以执行 /configs/experiment/ 目录下的的所有实验

python train.py -m 'experiment=glob(*)'
  • 使用 Optuna 进行超参数搜索

Optuna Sweeper plugin | Hydra

  1. 安装 hydra-optuna-sweeper 插件
pip install hydra-optuna-sweeper
  1. 修改 configs/hparams_search/ 目录下的配置文件

  2. 执行训练代码的时候指定 hparams_search

python train.py -m hparams_search=mnist_optuna
  • 使用 Tab 键智能提示可选配置参数

Tab completion | Hydra

$ eval "$(python train.py -sc install=bash)" # 安装
$ python train.py logger= # 按下 Tab 键后会智能提示有哪些可选参数
logger=comet         logger=csv           logger=many_loggers  logger=mlflow        logger=neptune       logger=tensorboard   logger=wandb

提示

  • .env 文件中设置私有环境变量
  1. 例如,你可以将 cometAPI Key 添加到 .env 文件中
COMET_API_KEY="xxx"
  1. 并在配置文件 configs/logger/comet.yaml 中进行调用
comet:
  _target_: pytorch_lightning.loggers.comet.CometLogger
  api_key: ${oc.env:COMET_API_KEY}
  1. 在训练的时候指定 logger 参数为 comet
python train.py logger=comet

注意:.env 文件不应受版本控制,因此我们已将其添加到 .gitignore 文件中了。

推荐使用 PytorchLightning 官方提供的 torchmetrics 库来计算指标(像准确率,F1 score 和混淆矩阵等)。这对于多 GPU 训练尤为重要!并且,推荐对每个步骤使用不同的指标实例,以确保所有 GPU 进程都有正确的累积值。下面给出了一个简单示例。

from torchmetrics.classification.accuracy import Accuracy


class LitModel(LightningModule):
    def __init__(self)
        self.train_acc = Accuracy()
        self.val_acc = Accuracy()

    def training_step(self, batch, batch_idx):
        ...
        acc = self.train_acc(predictions, targets)
        self.log("train/acc", acc)
        ...

    def validation_step(self, batch, batch_idx):
        ...
        acc = self.val_acc(predictions, targets)
        self.log("val/acc", acc)
        ...
  • 可以使用 DVC 对数据和模型这些大文件进行版本控制
dvc init
dvc add data/MNIST
git add data/MNIST.dvc data/.gitignore
git commit -m "Add raw data"

更新

CHANGELOG.md

证书

MIT License

参考

此模板引用了以下仓库并进行了一些细微的修改。

Readme Card