A single-layer LSTM part-of-speech tagger implemented in JAX (with Flax), trained on the batterydata/pos_tagging dataset from Hugging Face Datasets.
Make sure you have a wandb account and are logged in with your API key:
wandb login
Then run main.py with the following arguments:
python main.py --lr 0.01 --epochs 5 --batch-size 128 --seed 2023 --dropout 0.2 \
--embedding-dim 300 --hidden-dim 300 --max_seq_len 300 \
--pad_token_idx 1 --log_every_n_step 100
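The --max_seq_len and --pad_token_idx arguments suggest that every sentence is brought to a fixed length before batching. The sketch below shows one way to do that: truncate long sequences, right-pad short ones with the pad id, and keep a mask of real tokens. It assumes right-padding and truncation; the helper names `pad_or_truncate` and `make_batch` are hypothetical and not taken from main.py. Fixed shapes also matter in JAX because a jit-compiled step is recompiled whenever input shapes change.

```python
from typing import List, Tuple

MAX_SEQ_LEN = 300   # mirrors --max_seq_len
PAD_TOKEN_IDX = 1   # mirrors --pad_token_idx


def pad_or_truncate(ids: List[int], max_len: int = MAX_SEQ_LEN,
                    pad_idx: int = PAD_TOKEN_IDX) -> List[int]:
    """Cut sequences longer than max_len; right-pad shorter ones with pad_idx."""
    ids = ids[:max_len]
    return ids + [pad_idx] * (max_len - len(ids))


def make_batch(batch: List[List[int]], max_len: int = MAX_SEQ_LEN,
               pad_idx: int = PAD_TOKEN_IDX) -> Tuple[List[List[int]], List[List[int]]]:
    """Return fixed-length token ids plus a 0/1 mask marking real (non-pad) positions."""
    padded = [pad_or_truncate(seq, max_len, pad_idx) for seq in batch]
    # Build the mask from the original lengths, so a real token equal to pad_idx is not masked out
    mask = [[int(i < min(len(seq), max_len)) for i in range(max_len)] for seq in batch]
    return padded, mask


ids, mask = make_batch([[5, 6, 7], [8]], max_len=4)
print(ids)   # [[5, 6, 7, 1], [8, 1, 1, 1]]
print(mask)  # [[1, 1, 1, 0], [1, 0, 0, 0]]
```

The tag labels would be padded to the same length so that logits, labels, and mask line up position by position.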
The Trainer module handles training, evaluation, and logging to wandb in a single loop.
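When sequences are padded, the loss and accuracy that get logged are usually computed only over real tokens. The function below is a minimal sketch of such a masked metric in jax.numpy, assuming logits of shape (batch, seq_len, num_tags), integer labels, and a 0/1 padding mask; the name `masked_loss_and_accuracy` is hypothetical and not necessarily how the Trainer computes its metrics.

```python
import jax
import jax.numpy as jnp


def masked_loss_and_accuracy(logits, labels, mask):
    """Token-level cross-entropy and accuracy that ignore padded positions.

    logits: (batch, seq_len, num_tags); labels and mask: (batch, seq_len).
    """
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    # Log-probability assigned to the gold tag at each position
    gold = jnp.take_along_axis(log_probs, labels[..., None], axis=-1)[..., 0]
    mask = mask.astype(logits.dtype)
    n_tokens = jnp.maximum(mask.sum(), 1.0)  # guard against an all-padding batch
    loss = -(gold * mask).sum() / n_tokens
    correct = (jnp.argmax(logits, axis=-1) == labels).astype(logits.dtype)
    accuracy = (correct * mask).sum() / n_tokens
    return loss, accuracy


logits = jnp.array([[[2.0, 0.0], [0.0, 2.0], [5.0, 0.0]]])
labels = jnp.array([[0, 1, 1]])
mask = jnp.array([[1, 1, 0]])  # third position is padding, so its wrong prediction is ignored
loss, acc = masked_loss_and_accuracy(logits, labels, mask)
print(float(acc))  # 1.0
```

In a training loop, values like these could be passed to `wandb.log` every --log_every_n_step steps.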
Check the wandb metrics here.
Version Requirements:
- Python 3.11
- CUDA 12.2
To set up the environment, create a virtual environment and install the dependencies:
python -m venv venv
source venv/bin/activate
pip install -r requirements.txt -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html --extra-index-url https://download.pytorch.org/whl/cpu
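After installation, a quick way to check whether the CUDA build of jaxlib was picked up is to ask JAX which backend it is using. The snippet below is a small sanity check, not part of the repo; it should report "gpu" on a machine with a visible CUDA 12 GPU and falls back to "cpu" otherwise.

```python
import jax

# "gpu" means the CUDA-enabled jaxlib is active; "cpu" means it fell back to CPU.
backend = jax.default_backend()
devices = jax.devices()
print(backend)
print(devices)
```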