This repository contains the implementation of the paper "Multimodal Meta-Learning for Time Series Regression", which adapts MAML (Finn et al., 2017) and MMAML (Vuorio et al., 2019) to time series regression.
The MMAML code builds on the official multimodal-meta-learning implementation.
Requirements:
- Python 3.7.0
- Pytorch 1.4.0
- Learn2learn 0.1.1
The code can be used with two public datasets, which must be preprocessed before running MAML or MMAML. The data are available at:
- Air Pollution Dataset: PM2.5 Data of Five Chinese Cities Data Set.
- Heart Rate Dataset: PPG-DaLiA Data Set.
- Download the data and create the following folder structure:
MMAML-TSR/
├── logs
│   ├── MAML_output/
│   ├── MMAML_output/
│   ...
├── data
├── code
│   ├── tools
│   ├── pre_processing
│   ├── models
│   ...
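If you prefer to create the empty folders from a script, the following is a minimal sketch using only the standard library. It assumes the nesting shown in the tree (the two output folders inside logs, and tools, pre_processing and models inside code); the "..." entries are skipped because their contents are not specified. In a clone of the repository the code/ folders already exist, so in practice only logs/ and data/ need creating. The function name create_layout is hypothetical.

```python
from pathlib import Path

# Assumed nesting of the folder tree; "..." entries are intentionally omitted.
LAYOUT = [
    "logs/MAML_output",
    "logs/MMAML_output",
    "data",
    "code/tools",
    "code/pre_processing",
    "code/models",
]


def create_layout(root="MMAML-TSR"):
    """Create the folder skeleton under `root`; safe to call more than once."""
    root = Path(root)
    for rel in LAYOUT:
        # parents=True creates intermediate folders; exist_ok=True makes it idempotent
        (root / rel).mkdir(parents=True, exist_ok=True)
    return root
```

Calling create_layout(".") from inside an existing MMAML-TSR checkout fills in only the missing folders.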
- Preprocess the data and generate the .pickle files
Change the paths to the raw data in pre_processing/ts_dataset.py accordingly, then run pre_processing/dataset_creation.ipynb to pickle the object containing the transformed data. To use a new dataset, write a loading function modeled on the ones for our datasets. Alternatively, you can download the preprocessed data HERE.
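When writing a loader for a new dataset, the core step is turning a raw series into input/target pairs and pickling the result. The sketch below shows that idea under loudly stated assumptions: the WindowedSeries container, make_windows, save_dataset and load_dataset are hypothetical names, and the object actually built by pre_processing/ts_dataset.py almost certainly has a different structure, so use the real loaders as the reference for the expected format.

```python
import pickle
from dataclasses import dataclass
from typing import List


@dataclass
class WindowedSeries:
    """Hypothetical container; the repository's pickled object differs."""
    name: str
    inputs: List[List[float]]  # each entry holds `lookback` consecutive values
    targets: List[float]       # the value `horizon` steps after each input window


def make_windows(series, lookback, horizon=1):
    """Slide a window over a 1-D series to build (input window, target) pairs."""
    inputs, targets = [], []
    for start in range(len(series) - lookback - horizon + 1):
        end = start + lookback
        inputs.append(list(series[start:end]))
        targets.append(series[end + horizon - 1])
    return inputs, targets


def save_dataset(ds, path):
    """Pickle the dataset object to `path`."""
    with open(path, "wb") as f:
        pickle.dump(ds, f)


def load_dataset(path):
    """Load a dataset object previously written by save_dataset."""
    with open(path, "rb") as f:
        return pickle.load(f)
```

For example, make_windows over the values 0 to 9 with lookback=3 yields seven windows, the first being [0.0, 1.0, 2.0] with target 3.0. Keep in mind that unpickling requires the class definition to be importable, which is why the real pickles depend on the code in pre_processing/.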
- Run MAML
Assuming the pickled files are in data/, train with the default parameters on the Air Pollution Dataset as follows:
cd code/
python run_MAML.py
To train on the Heart Rate Dataset:
cd code/
python run_MAML.py --dataset HR
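For intuition about what MAML optimizes, here is a toy, self-contained sketch in NumPy. It is not the repository's training loop (run_MAML.py trains neural networks on time series through Learn2learn); it applies exact second-order MAML with one inner gradient step to synthetic linear-regression tasks y = a*x + b. For mean squared error the support-loss Hessian is (2/n) XᵀX, so the meta-gradient has a closed form. The task distribution and step sizes ALPHA and BETA are illustrative assumptions.

```python
import numpy as np

ALPHA = 0.1   # inner-loop (adaptation) step size, illustrative value
BETA = 0.05   # outer-loop (meta) step size, illustrative value


def features(x):
    """Design matrix for a linear model with bias: y_hat = w * x + c."""
    return np.stack([x, np.ones_like(x)], axis=1)


def loss(theta, X, y):
    return float(np.mean((X @ theta - y) ** 2))


def grad(theta, X, y):
    """Gradient of the mean squared error with respect to theta."""
    return 2.0 / len(y) * X.T @ (X @ theta - y)


def sample_task(rng, n=10):
    """Synthetic task y = a * x + b, split into support and query sets."""
    a, b = rng.uniform(0.5, 1.5), rng.uniform(-1.0, 1.0)
    x = rng.uniform(-2.0, 2.0, size=2 * n)
    X, y = features(x), a * x + b
    return (X[:n], y[:n]), (X[n:], y[n:])


def adapt(theta, support):
    """One inner-loop gradient step on the support set."""
    return theta - ALPHA * grad(theta, *support)


def meta_grad(theta, support, query):
    """Exact MAML gradient of the query loss after adaptation.

    The chain rule through the inner step gives (I - ALPHA * H) @ grad_query(theta'),
    where H = (2/n) X^T X is the (symmetric) Hessian of the support MSE.
    """
    Xs, _ = support
    hessian = 2.0 / len(Xs) * Xs.T @ Xs
    jac = np.eye(len(theta)) - ALPHA * hessian
    return jac @ grad(adapt(theta, support), *query)


def meta_train(steps=500, meta_batch=5, seed=0):
    """Outer loop: average meta-gradients over a batch of tasks, then step."""
    rng = np.random.default_rng(seed)
    theta = np.zeros(2)
    for _ in range(steps):
        tasks = [sample_task(rng) for _ in range(meta_batch)]
        g = np.mean([meta_grad(theta, s, q) for s, q in tasks], axis=0)
        theta = theta - BETA * g
    return theta


def post_adaptation_loss(theta, tasks):
    """Mean query loss after one adaptation step, averaged over tasks."""
    return float(np.mean([loss(adapt(theta, s), *q) for s, q in tasks]))
```

On held-out tasks, the meta-trained initialization should reach a lower post-adaptation loss than the zero initialization; that post-adaptation loss is the quantity MAML minimizes.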
- Run MMAML
Assuming the pickled files are in data/, train with the default parameters on the Air Pollution Dataset as follows:
cd code
python run_MMAML.py
To train on the Heart Rate Dataset:
cd code
python run_MMAML.py --dataset HR
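MMAML (Vuorio et al., 2019) adds a modulation step to MAML: a task embedding is computed from the support set and used to modulate the task network before gradient-based adaptation, with FiLM (feature-wise scale and shift) as one of the modulation schemes. The NumPy sketch below shows only that forward pass. The single-layer tanh encoder, the layer sizes and all function names are hypothetical, and it makes no claim to mirror run_MMAML.py; the MAML inner loop, which would then adapt the task network, is omitted.

```python
import numpy as np


def task_embedding(support_x, support_y, w_enc):
    """Encode each (x, y) support pair and mean-pool into one task vector.

    Hypothetical encoder: a single tanh layer over concatenated inputs and targets.
    """
    pairs = np.concatenate([support_x, support_y[:, None]], axis=1)
    return np.tanh(pairs @ w_enc).mean(axis=0)


def film_params(embedding, w_gamma, w_beta):
    """Map the task embedding to per-unit FiLM scale (gamma) and shift (beta)."""
    gamma = 1.0 + embedding @ w_gamma  # offset by 1 so zero weights mean "no modulation"
    beta = embedding @ w_beta
    return gamma, beta


def modulated_forward(x, w1, w2, gamma, beta):
    """Two-layer task network whose hidden ReLU units are FiLM-modulated."""
    h = np.maximum(x @ w1, 0.0)
    h = gamma * h + beta  # broadcast per hidden unit across the batch
    return h @ w2
```

With zero modulation weights, gamma is 1 and beta is 0, so the output reduces to the unmodulated network; this is why the scale is offset by 1.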
If you find this repository useful, please cite:
@inproceedings{arango2021multimodal,
title={Multimodal meta-learning for time series regression},
author={Arango, Sebastian Pineda and Heinrich, Felix and Madhusudhanan, Kiran and Schmidt-Thieme, Lars},
booktitle={Advanced Analytics and Learning on Temporal Data: 6th ECML PKDD Workshop, AALTD 2021, Bilbao, Spain, September 13, 2021, Revised Selected Papers 6},
pages={123--138},
year={2021},
organization={Springer}
}
To ask questions or report problems, please open an issue on the issue tracker.