Official PyTorch implementation of our ICML 2023 paper, "Regularization-free Diffeomorphic Temporal Alignment Nets" (RF-DTAN).
Clone the repository and move into it:
git clone https://github.com/BGU-CS-VIL/RF-DTAN.git
cd RF-DTAN
Create a new conda environment:
conda create --name rfdtan python=3.9
Activate the conda environment:
conda activate rfdtan
Install the required dependencies:
pip install -r requirements.txt
To train a model, run:
python train_model.py --dataset ECGFiveDays --ICAE_loss
Replace ECGFiveDays with the name of the desired dataset and add any additional arguments as needed.
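To train on several datasets in a row, the command above can be scripted. The sketch below is a hypothetical helper (`build_command` and `sweep` are not part of the repository). It reuses only the `--dataset` and `--ICAE_loss` flags shown above and assumes that the other entries in the loss list are passed the same way, as bare `--<name>` flags, and that several of them may be combined in one run. It also assumes `python` resolves to the interpreter of the activated `rfdtan` environment.

```python
import shlex
import subprocess
from typing import List, Sequence


def build_command(dataset: str, flags: Sequence[str] = ("ICAE_loss",)) -> List[str]:
    """Build the argv list for one run of train_model.py.

    Assumption: every loss/prior name is enabled by a bare `--<name>` flag,
    as `--ICAE_loss` is in the README's example command.
    """
    cmd = ["python", "train_model.py", "--dataset", dataset]
    cmd += [f"--{flag}" for flag in flags]
    return cmd


def sweep(datasets: Sequence[str], flags: Sequence[str] = ("ICAE_loss",),
          dry_run: bool = True) -> List[List[str]]:
    """Print (dry run) or execute one training command per dataset, in order."""
    commands = [build_command(name, flags) for name in datasets]
    for cmd in commands:
        if dry_run:
            print(shlex.join(cmd))  # show the exact shell command
        else:
            subprocess.run(cmd, check=True)  # raise on the first failed run
    return commands
```

For example, `sweep(["ECGFiveDays"])` prints `python train_model.py --dataset ECGFiveDays --ICAE_loss`; passing `dry_run=False` launches the runs one after another.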
We support the following losses:
- ICAE_loss
- ICAE_triplet_loss
- WCSS_loss
- WCSS_triplet_loss
- smoothness_prior
ICAE - Inverse Consistency Averaging Error
WCSS - Within-Class Sum of Squares
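The two criteria named above can be illustrated with a minimal PyTorch sketch. It is written from the names and the usual definitions, not taken from this repository, so normalization, batching, and the triplet variants will differ. WCSS sums, over classes, the squared distances between each aligned signal and its class mean. For ICAE, the sketch averages the aligned signals within each class, warps that average back toward each input with the inverse transformation, and penalizes the reconstruction error; it relies on the property of CPA-based warps that the inverse of the warp with parameters theta is the warp with -theta. To stay self-contained, `circular_shift` is a hypothetical stand-in for the learned diffeomorphic warp: a per-sample integer time shift, which is likewise inverted by negating its parameter.

```python
import torch


def circular_shift(x: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in warp: roll sample i of x (N, C, L) by theta[i] steps.

    Like a CPA-based warp, it is inverted by negating its parameter.
    """
    return torch.stack([torch.roll(xi, int(t), dims=-1) for xi, t in zip(x, theta)])


def wcss_loss(aligned: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Within-Class Sum of Squares of aligned signals (N, C, L) with labels (N,)."""
    loss = aligned.new_zeros(())
    for k in torch.unique(labels):
        members = aligned[labels == k]                  # (N_k, C, L)
        centroid = members.mean(dim=0, keepdim=True)    # (1, C, L)
        loss = loss + ((members - centroid) ** 2).sum()
    return loss


def icae_loss(x: torch.Tensor, theta: torch.Tensor, labels: torch.Tensor,
              warp=circular_shift) -> torch.Tensor:
    """Sketch of an inverse-consistency averaging error.

    1. Align every signal with its own warp parameters.
    2. Average the aligned signals within each class.
    3. Warp the class average back with the inverse warp (-theta).
    4. Sum the squared errors against the original signals.
    """
    aligned = warp(x, theta)
    loss = x.new_zeros(())
    for k in torch.unique(labels):
        mask = labels == k
        centroid = aligned[mask].mean(dim=0, keepdim=True)       # (1, C, L)
        n_k = int(mask.sum())
        recon = warp(centroid.expand(n_k, -1, -1), -theta[mask])  # back to input frames
        loss = loss + ((x[mask] - recon) ** 2).sum()
    return loss
```

With the toy warp, signals that are exact shifts of one another give zero for both losses once `theta` undoes the shifts, and a positive value otherwise. Note that WCSS only looks at the aligned signals, whereas ICAE also checks that the class average maps back onto each original input.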
Pinned dependencies (requirements.txt):
difw==0.0.29
matplotlib==3.5.1
numpy==1.20.3
scikit_learn==1.0.2
torch==1.10.1
tqdm==4.62.3
tsai==0.2.24
tslearn==0.5.2
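To confirm that an environment matches the pinned versions listed above, the installed versions can be compared against requirements.txt with the standard-library `importlib.metadata` module. The sketch below is a hypothetical helper, not part of the repository, and it only understands the simple `name==version` lines used here.

```python
from importlib.metadata import PackageNotFoundError, version
from pathlib import Path
from typing import Callable, Dict, Optional, Tuple


def parse_pins(text: str) -> Dict[str, str]:
    """Parse `name==version` lines, skipping blanks, comments, and non-exact pins."""
    pins = {}
    for line in text.splitlines():
        line = line.split("#", 1)[0].strip()  # drop trailing comments
        if not line:
            continue
        name, sep, pinned = line.partition("==")
        if sep:  # ignore ranges such as `pkg>=1.0`
            pins[name.strip()] = pinned.strip()
    return pins


def check_pins(pins: Dict[str, str],
               get_version: Callable[[str], str] = version
               ) -> Dict[str, Tuple[str, Optional[str]]]:
    """Return {name: (pinned, installed)} for every missing or mismatched package."""
    problems = {}
    for name, pinned in pins.items():
        try:
            installed = get_version(name)
        except PackageNotFoundError:
            installed = None  # package is not installed at all
        if installed != pinned:
            problems[name] = (pinned, installed)
    return problems


def check_requirements(path: str = "requirements.txt") -> Dict[str, Tuple[str, Optional[str]]]:
    """Check the current environment against a requirements file."""
    return check_pins(parse_pins(Path(path).read_text()))
```

Running `check_requirements()` from the repository root inside the activated environment returns an empty dict when every pin matches.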