
ShawonAshraf/postagger-lstm-jax

postagger-lstm-jax

A single-layer LSTM part-of-speech tagger implemented in JAX and Flax, trained on the batterydata/pos_tagging dataset from Hugging Face Datasets.
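
For orientation, here is a minimal sketch of the forward pass of a single-layer LSTM tagger (embedding, then LSTM, then a per-token linear layer) written in plain JAX with jax.lax.scan over the time steps. It is not the repository's code: the repository uses Flax modules, and the parameter names, gate ordering, initialisation and the omission of dropout here are all assumptions made for illustration.

```python
import jax
import jax.numpy as jnp


def init_params(key, vocab_size, num_tags, embedding_dim, hidden_dim):
    """Hypothetical parameter layout for an embedding -> LSTM -> linear tagger."""
    k_emb, k_x, k_h, k_out = jax.random.split(key, 4)
    scale = 0.1
    return {
        "embedding": scale * jax.random.normal(k_emb, (vocab_size, embedding_dim)),
        # The four gates (input, forget, cell candidate, output) are stacked on the last axis.
        "w_x": scale * jax.random.normal(k_x, (embedding_dim, 4 * hidden_dim)),
        "w_h": scale * jax.random.normal(k_h, (hidden_dim, 4 * hidden_dim)),
        "b": jnp.zeros((4 * hidden_dim,)),
        "w_out": scale * jax.random.normal(k_out, (hidden_dim, num_tags)),
        "b_out": jnp.zeros((num_tags,)),
    }


def lstm_tagger(params, token_ids):
    """Map (batch, seq_len) token ids to (batch, seq_len, num_tags) tag logits."""
    x = params["embedding"][token_ids]  # (batch, seq_len, embedding_dim)
    batch = token_ids.shape[0]
    hidden_dim = params["w_h"].shape[0]

    def step(carry, x_t):
        h, c = carry
        gates = x_t @ params["w_x"] + h @ params["w_h"] + params["b"]
        i, f, g, o = jnp.split(gates, 4, axis=-1)
        c = jax.nn.sigmoid(f) * c + jax.nn.sigmoid(i) * jnp.tanh(g)
        h = jax.nn.sigmoid(o) * jnp.tanh(c)
        return (h, c), h  # new carry, per-step output

    init = (jnp.zeros((batch, hidden_dim)), jnp.zeros((batch, hidden_dim)))
    # lax.scan iterates over the leading axis, so put time first, then swap back.
    _, hs = jax.lax.scan(step, init, jnp.swapaxes(x, 0, 1))
    hs = jnp.swapaxes(hs, 0, 1)  # (batch, seq_len, hidden_dim)
    return hs @ params["w_out"] + params["b_out"]


params = init_params(jax.random.PRNGKey(2023), vocab_size=100, num_tags=10,
                     embedding_dim=16, hidden_dim=32)
logits = lstm_tagger(params, jnp.ones((2, 7), dtype=jnp.int32))
print(logits.shape)  # (2, 7, 10)
```

Because the tagger emits one set of logits per input token, the output keeps the sequence axis, unlike a sentence classifier that would keep only the final hidden state.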

Usage

Make sure you have a Weights & Biases (wandb) account and have logged in with your API key:

wandb login

Then run main.py with the following arguments:

python main.py --lr 0.01 --epochs 5 --batch-size 128 --seed 2023 --dropout 0.2 \
    --embedding-dim 300 --hidden-dim 300 --max_seq_len 300  \
    --pad_token_idx 1 --log_every_n_step 100
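
The flags above could be declared with argparse roughly as follows. This is a hypothetical reconstruction from the command line shown, not the contents of main.py; the types are assumptions and the defaults simply mirror the example values. Note that argparse turns dashes in long option names into underscores for the attribute names, so --batch-size is read back as args.batch_size, while flags such as --max_seq_len keep their names unchanged.

```python
import argparse


def build_parser():
    # Hypothetical reconstruction of the CLI shown above; main.py may differ.
    parser = argparse.ArgumentParser(description="Train an LSTM POS tagger")
    parser.add_argument("--lr", type=float, default=0.01)
    parser.add_argument("--epochs", type=int, default=5)
    parser.add_argument("--batch-size", type=int, default=128)
    parser.add_argument("--seed", type=int, default=2023)
    parser.add_argument("--dropout", type=float, default=0.2)
    parser.add_argument("--embedding-dim", type=int, default=300)
    parser.add_argument("--hidden-dim", type=int, default=300)
    parser.add_argument("--max_seq_len", type=int, default=300)
    parser.add_argument("--pad_token_idx", type=int, default=1)
    parser.add_argument("--log_every_n_step", type=int, default=100)
    return parser


# Parse an explicit list instead of sys.argv so the example is self-contained.
args = build_parser().parse_args(["--lr", "0.01", "--batch-size", "64"])
print(args.batch_size, args.embedding_dim, args.pad_token_idx)  # 64 300 1
```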

The Trainer module trains and evaluates the model and logs the metrics to wandb as training progresses.
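
One detail a training and evaluation loop like this has to handle is the padding implied by --max_seq_len and --pad_token_idx: padded positions should not count towards the loss or the accuracy. Below is a minimal sketch, assuming padding is identified by pad_token_idx in the input token ids (the repository may instead mask on the tag ids or carry a separate mask); the function name is hypothetical.

```python
import jax
import jax.numpy as jnp


def masked_loss_and_accuracy(logits, tags, token_ids, pad_token_idx=1):
    """Cross-entropy and accuracy averaged over non-padding positions only.

    logits: (batch, seq_len, num_tags); tags, token_ids: (batch, seq_len) ints.
    """
    mask = (token_ids != pad_token_idx).astype(logits.dtype)
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    # Log-probability assigned to the gold tag at every position.
    gold = jnp.take_along_axis(log_probs, tags[..., None], axis=-1)[..., 0]
    n_tokens = jnp.maximum(mask.sum(), 1.0)  # avoid dividing by zero on an all-pad batch
    loss = -(gold * mask).sum() / n_tokens
    correct = (jnp.argmax(logits, axis=-1) == tags).astype(logits.dtype)
    accuracy = (correct * mask).sum() / n_tokens
    return loss, accuracy


token_ids = jnp.array([[5, 7, 1, 1]])  # two real tokens followed by two pads (id 1)
tags = jnp.array([[0, 2, 0, 0]])
logits = jnp.array([[[5.0, 0.0, 0.0], [0.0, 0.0, 5.0],
                     [0.0, 5.0, 0.0], [0.0, 5.0, 0.0]]])
loss, acc = masked_loss_and_accuracy(logits, tags, token_ids)
print(float(acc))  # 1.0
```

Without the mask, the two wrong predictions at the padded positions would pull the accuracy in this example down to 0.5, which is why padding-aware metrics matter when most sequences are much shorter than max_seq_len.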

Logs and Results

Check the wandb metrics here.

Environment Setup

Version Requirements:

  • Python 3.11
  • CUDA 12.2

Create and activate a virtual environment, then install the dependencies:

python -m venv venv
source venv/bin/activate
pip install -r requirements.txt -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html --extra-index-url https://download.pytorch.org/whl/cpu
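
After installing, one quick way to check whether the CUDA-enabled jaxlib was picked up is to ask JAX which backend and devices it sees. This is a generic JAX check rather than something the repository provides; on a machine without a working CUDA setup JAX falls back to the CPU backend.

```python
import jax

# Reports "gpu" when the CUDA build of jaxlib is active, otherwise typically "cpu".
backend = jax.default_backend()
devices = jax.devices()
print(backend)
print(devices)
```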