Skip to content

sezan92/pytorch-siamese-triplet

 
 

Repository files navigation

Deep metric learning using Triplet network in PyTorch

The following repository contains code for training Triplet Network in Pytorch
Siamese and Triplet networks make use of a similarity metric with the aim of bringing similar images closer in the embedding space while separating non similar ones.
Popular uses of such networks being -

  • Face Verification / Classification
  • Learning deep embeddings for other tasks like classification / detection / segmentation

Paper - Deep metric learning using Triplet network

Installation

Install PyTorch

pip install -r requirements.txt  

Demo

Colab notebook with pretrained weights

Training

python train.py --cuda  

This by default will train on the MNIST dataset

MNIST / FashionMNIST

python train.py --result_dir results --exp_name MNIST_exp1 --cuda --dataset <manist>/<fmnist>  

To create a tSNE visualisation

python tsne.py --ckp <path to model>  

The embeddings and the labels are stored in the experiment folder as a pickle file, and you do not have to run the model everytime you create a visualisation. Just pass the saved embeddings as the --pkl parameter

python tsne.py --pkl <path to stored embeddings>  

Sample tSNE visualisation on MNIST tSNE

Specify the location of the dataset in test.yaml
The directory should have the following structure

+-- root
|   +-- train
|       +-- class1
|           +-- img1.jpg
|           +-- img2.jpg
|           +-- img3.jpg
|       +-- class2
|       +-- class3
|   +-- test
|       +-- class4
|       +-- class5
python train.py --result_dir results --exp_name VGGFace2_exp1 --cuda --epochs 50 --ckp_freq 5 --dataset vggface2 --num_train_samples 32000 --num_test_samples 5000 --train_log_step 50 

Custom Dataset

Specify the location of the dataset in test.yaml
The directory should have the following structure

+-- root
|   +-- train
|       +-- class1
|           +-- img1.jpg
|           +-- img2.jpg
|           +-- img3.jpg
|       +-- class2
|       +-- class3
|   +-- test
|       +-- class4
|       +-- class5
python train.py --result_dir results --exp_name Custom_exp1 --cuda --epochs 50 --ckp_freq 5 --dataset custom --num_train_samples 32000 --num_test_samples 5000 --train_log_step 50 

TODO

  • Train on MNIST / FashionMNIST
  • Train on a public dataset
  • Multi GPU Training
  • Custom Dataset
  • Include popular models - ResneXT / Resnet / VGG / Inception

About

One-Shot Learning with Triplet CNNs in Pytorch

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published

Languages

  • Python 100.0%