A plug and play framework for Temporal Fusion Transformer. Predict your future!


Temporal Fusion Transformer PlugNPlay

This project provides a Temporal Fusion Transformer (TFT) framework for time series forecasting. TFT handles multiple related time series and efficiently combines heterogeneous inputs (static covariates, known future inputs, and observed history) to produce forecasts. With this project, users can plug in their own data, configure preprocessing, and train a TFT model quickly and easily.

Table of Contents

  • Features
  • Installation
  • Usage
  • Contributing
  • License

Features

  • Easy integration of custom datasets
  • Flexible configuration for preprocessing and model parameters
  • Automated training and evaluation pipeline
  • Baseline comparison and visualization of predictions

Installation

To install the necessary dependencies, run:

git clone https://github.com/anhphan2705/temporal_fusion_transformer_plugnplay.git
cd temporal_fusion_transformer_plugnplay
python -m venv tft_env
source tft_env/bin/activate  # On Windows, use `tft_env\Scripts\activate`
pip install -r requirements.txt

Usage

Datasets

  1. To use your own data, place it in the ./data directory.
  2. Ensure that your data is formatted appropriately and includes any necessary preprocessing scripts in the ./datasets directory.
  3. Import your preprocessing module in ./tools/data_process.py and add an option for it in the data pipeline.
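To illustrate step 3, one common way to expose a new preprocessor as a selectable option is a small registry. This is a hypothetical sketch: the function names, the registry, and the `data_pipeline` signature below are illustrative assumptions, not the actual API in tools/data_process.py.

```python
# Hypothetical registry pattern for wiring a custom preprocessor into
# the data pipeline; names are illustrative, not the repository's API.
PREPROCESSORS = {}

def register(name):
    """Decorator that records a preprocessing function under an option name."""
    def wrap(fn):
        PREPROCESSORS[name] = fn
        return fn
    return wrap

@register("my_dataset")
def preprocess_my_dataset(rows):
    # Add an integer time index, which TFT-style data loaders expect.
    return [{**row, "time_idx": i} for i, row in enumerate(rows)]

def data_pipeline(dataset_name, rows):
    # Dispatch to the preprocessor selected by the config option.
    return PREPROCESSORS[dataset_name](rows)
```

With this pattern, adding a dataset is one decorated function plus a config entry naming it.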

Configuration

Copy sample_config.yaml and edit the copy to set the desired parameters for training, hyperparameter tuning, and logging. The file includes paths for data, model checkpoints, and logging directories, as well as hyperparameters for the TFT model.
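For orientation, such a config might look like the following. The actual schema is defined by sample_config.yaml in the repository, so every key below is an illustrative assumption rather than the file's real layout.

```yaml
# Illustrative config sketch; consult sample_config.yaml for the real keys.
data:
  path: ./data/my_dataset.csv
checkpoints:
  dir: ./checkpoints
logging:
  dir: ./logs
model:
  hidden_size: 64
  dropout: 0.1
  learning_rate: 0.001
training:
  max_epochs: 30
  batch_size: 64
```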

Training

To train the model, run:

python main.py --mode train --config configs/your_config.yaml

This will start the training pipeline, which includes data loading, model initialization, hyperparameter tuning, and final training.
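The command-line surface shown above (`--mode`, `--config`, and the `--model` flag used for evaluation) can be sketched with argparse as follows; main.py's actual argument handling may differ, so treat this as an assumption about its interface.

```python
import argparse

def parse_args(argv=None):
    # Sketch of the CLI shown in this README; the real main.py may
    # accept additional flags.
    parser = argparse.ArgumentParser(description="TFT PlugNPlay entry point")
    parser.add_argument("--mode", choices=["train", "eval"], required=True,
                        help="Run the training or evaluation pipeline")
    parser.add_argument("--config", required=True,
                        help="Path to a YAML config file")
    parser.add_argument("--model", default=None,
                        help="Checkpoint path (used in eval mode)")
    return parser.parse_args(argv)
```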

Evaluation

To evaluate the model, run:

python main.py --mode eval --config configs/your_config.yaml --model path_to_your_model

The evaluation pipeline compares the trained model against a baseline model. Validation metrics (loss, RMSE, MAE) are logged during both training and evaluation.
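For reference, the RMSE and MAE metrics mentioned above are the root mean squared error and mean absolute error between predictions and actuals. A minimal standalone implementation:

```python
import math

def rmse(y_true, y_pred):
    # Root mean squared error: sqrt of the average squared residual.
    n = len(y_true)
    return math.sqrt(sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / n)

def mae(y_true, y_pred):
    # Mean absolute error: average magnitude of the residuals.
    n = len(y_true)
    return sum(abs(t - p) for t, p in zip(y_true, y_pred)) / n
```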

Visualization

After training, the script generates plots comparing the trained model's predictions with actual data and the baseline model's predictions. These plots are saved in the specified logs directory.
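A comparison plot of this kind can be produced along the following lines. This is a sketch assuming the three series are aligned 1-D sequences; the repository's plotting code may structure things differently.

```python
# Illustrative plotting helper; the project's actual plotting code may differ.
import matplotlib
matplotlib.use("Agg")  # render without a display, e.g. on a server
import matplotlib.pyplot as plt

def plot_predictions(actual, model_pred, baseline_pred, out_path):
    """Plot actuals against model and baseline predictions, save to out_path."""
    fig, ax = plt.subplots(figsize=(8, 4))
    ax.plot(actual, label="actual")
    ax.plot(model_pred, label="TFT prediction")
    ax.plot(baseline_pred, label="baseline", linestyle="--")
    ax.set_xlabel("time step")
    ax.set_ylabel("value")
    ax.legend()
    fig.savefig(out_path)
    plt.close(fig)
```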

Reading PyTorch Lightning Logs

The events.out.tfevents... file generated by PyTorch Lightning is meant to be read and visualized using TensorBoard. To read and interpret this log file, follow these steps:

  1. Install TensorBoard (if not already installed):

    pip install tensorboard
  2. Navigate to the directory containing your log files:

    cd path_to_your_log_directory
  3. Run TensorBoard:

    tensorboard --logdir=.
  4. Open TensorBoard in your browser: go to http://localhost:6006/. The TensorBoard dashboard visualizes the metrics, hyperparameters, and other data logged during your training process.

Contributing

Contributions are welcome! Please feel free to submit a Pull Request. If you are adding new data preprocessing scripts, place them in the ./datasets directory and update tools/data_process.py to include an option for your preprocessing.

License

This project is licensed under the Apache-2.0 License. See the LICENSE file for details.