[Research Report] | [WandB Runs]
This project is an enhanced implementation of the T5 model using JAX. By adopting a functional approach and harnessing the capabilities of JAX, this implementation strives for superior performance and compatibility with Google Cloud TPU. Beyond the technical advancements, this project's motivation stems from the desire to establish a cleaner T5 codebase and to serve as a valuable educational resource for both AI researchers and engineers, facilitating their understanding and exploration of Transformer-based LLM architectures.
This project is supported by Cloud TPUs from Google's TPU Research Cloud (TRC).
This project is inspired by ayaka/bart-base-jax, but the code for this project was written entirely by me.
- Setup Instructions
- Usage examples
- Discoveries
- Analysis
- 1. JAX precision
- 2. Layer normalisation
- 3. Dropout
- 4. Scaling QK matrices
- 5. Relative Attention Bias / Position embeddings
- 6. Layer norm in T5 does not subtract mean
- 7. T5 employs a final layer norm on the output of the encoder and decoder
- 8. T5 uses tied word embeddings
- 9. T5 also rescales the decoder output for tied word embedding in the language model head
- T5 Jax Implementation Results
- Fine-tuning
- Install jax

  ```bash
  pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
  ```

- Then install requirements.txt:

  ```bash
  pip install -r requirements.txt
  ```
- Tokenize inputs

  ```python
  from transformers import AutoTokenizer, FlaxT5ForConditionalGeneration

  model = FlaxT5ForConditionalGeneration.from_pretrained("allenai/unifiedqa-t5-base")
  tokenizer = AutoTokenizer.from_pretrained("t5-base")

  inputs = tokenizer(
      ["summarize: My friends are cool but they eat too many carbs."],
      return_tensors="np",
  )
  input_ids = inputs["input_ids"]
  ```
- Initialize model parameters

  ```python
  from utils.params_utils import init_params_pretrained

  params = init_params_pretrained()
  ```
- Encoder

  ```python
  from model.transformer_encoder import fwd_transformer_encoder

  encoder_output = fwd_transformer_encoder(
      encoder_params=params["encoder"],
      embedding_params=params["shared"],
      input_ids=input_ids,
  )
  ```
- Decoder

  ```python
  import jax.numpy as jnp

  from model.transformer_decoder import fwd_transformer_decoder

  decoder_start_token_id = model.config.decoder_start_token_id
  decoder_input_ids = (
      jnp.ones((input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id
  )
  decoder_output = fwd_transformer_decoder(
      decoder_params=params["decoder"],
      embedding_params=params["shared"],
      decoder_input_ids=decoder_input_ids,
      encoder_output=encoder_output,
  )
  ```
- Generate

  ```python
  from model.t5_generate import fwd_t5_generate
  from config import config

  sequences = fwd_t5_generate(
      params,
      encoder_input_ids=input_ids,
      eos_token_id=config.EOS_TOKEN_ID,
      decoder_start_token_id=config.DECODER_START_TOKEN_ID,
  )
  output = tokenizer.batch_decode(sequences, skip_special_tokens=True)
  ```
I discovered an issue in the Hugging Face transformers FlaxT5: their hidden state outputs were not consistent with mine. My encoder and decoder block 11 hidden states do not align with their block 11 hidden states, even though my hidden states from blocks 0 to 10 align with theirs. Additionally, my final hidden state (after applying the final layer norm) aligns with their final hidden state after the layer norm.
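A minimal sketch of how such a comparison can be made against the Hugging Face model (`model`, `input_ids`, and `decoder_input_ids` come from the usage example above; `my_encoder_hidden_states` is a hypothetical list of per-block outputs from my encoder):

```python
import jax.numpy as jnp

hf_out = model(
    input_ids=input_ids,
    decoder_input_ids=decoder_input_ids,
    output_hidden_states=True,
)

# hf_out.encoder_hidden_states[0] is the embedding output; later entries follow the blocks.
for i, mine in enumerate(my_encoder_hidden_states):
    theirs = hf_out.encoder_hidden_states[i + 1]
    print(f"encoder block {i} aligned:", jnp.allclose(mine, theirs, atol=1e-5))
```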
I then raised an issue and submitted a PR to fix it.
- On TPU, JAX defaults to using `bfloat16` for matrix multiplication even when the data type is specified as `float32`. While this may speed up training, some precision is lost.
- When utilizing a GPU, the Hugging Face transformers model exhibits different numerical precision compared to JAX.
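For illustration, a minimal sketch of how matmul precision can be controlled in JAX, using the `precision` argument and the `jax.default_matmul_precision` context manager (shapes and values here are arbitrary):

```python
import jax
import jax.numpy as jnp

a = jnp.ones((128, 128), dtype=jnp.float32)
b = jnp.ones((128, 128), dtype=jnp.float32)

# Default on TPU: inputs may be reduced to bfloat16 inside the matrix unit.
fast = a @ b

# Request full float32 precision for a single matmul...
precise_op = jnp.matmul(a, b, precision=jax.lax.Precision.HIGHEST)

# ...or for every matmul inside a scope.
with jax.default_matmul_precision("float32"):
    precise_scope = a @ b
```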
T5 performs pre-layer norm instead of post-layer norm.

- Attention: `(layer norm -> self-attention -> dropout -> add)` instead of `(self-attention -> dropout -> add -> layer norm)`
- Feed forward: `(layer norm -> DenseReluDense -> dropout -> add)` instead of `(DenseReluDense -> dropout -> add -> layer norm)`
- Dropout is performed once at the end of the feed-forward block, `(linear -> linear -> dropout)`, instead of twice after each linear layer, `(linear -> dropout -> linear -> dropout)` (see the sketch below)
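A minimal JAX sketch of the pre-layer-norm feed-forward sub-block with a single dropout at the end, as described in points 2 and 3 (parameter names are illustrative, not the repository's actual API):

```python
import jax
import jax.numpy as jnp

def ff_block(params, x, *, dropout_rate=0.1, key=None):
    """Pre-layer norm: normalise first, transform, dropout once, then add the residual."""
    # T5-style layer norm (scale only, no mean subtraction; see point 6 below).
    h = params["ln"] * x / jnp.sqrt(jnp.mean(x * x, axis=-1, keepdims=True) + 1e-6)
    h = jax.nn.relu(h @ params["wi"])   # first linear + ReLU
    h = h @ params["wo"]                # second linear
    if key is not None:                 # dropout applied once, at the end
        keep = jax.random.bernoulli(key, 1.0 - dropout_rate, h.shape)
        h = jnp.where(keep, h / (1.0 - dropout_rate), 0.0)
    return x + h                        # residual add; no post-layer norm
```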
Hugging Face T5 does not scale the QK matrices. The T5 paper did not mention this exclusion of QK matrix scaling:

$$ \text{Attention}(Q, K, V) = \text{softmax}\left(QK^\top\right)V $$

instead of

$$ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V $$
T5's position embeddings (relative attention bias) are different from Self-Attention with Relative Position Representations (Hugging Face's implementation).
- Uses binned relative attention bias to reduce time complexity for long sequences
- Only applies the bias before the $\text{softmax}$

It is not mentioned in the T5 paper that they only apply the bias before the $\text{softmax}$:

$$ \text{Attention}(Q, K, V) = \text{softmax}\left(QK^\top + B\right)V $$

instead of

$$ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V $$

Where:

- $Q$ is the query matrix
- $K$ is the key matrix
- $V$ is the value matrix
- $d_k$ is the dimension of the keys
- $B$ is the relative attention bias
In the case of multi-head attention, the above process is performed multiple times with different learned linear transformations of the original $Q$, $K$, and $V$. If we have $h$ heads, then we have:

$$ \text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)\,W_O $$

where each head is defined as:

$$ \text{head}_i = \text{Attention}(QW_{Q_i}, KW_{K_i}, VW_{V_i}) $$
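A minimal JAX sketch of single-head attention as described in points 4 and 5, with no $\sqrt{d_k}$ scaling and the relative attention bias added to the logits before the softmax (names and shapes are illustrative):

```python
import jax
import jax.numpy as jnp

def t5_style_attention(q, k, v, bias):
    """q, k, v: [seq, d_k]; bias: [seq, seq] relative attention bias for one head."""
    logits = q @ k.T                                  # no 1/sqrt(d_k) scaling (point 4)
    weights = jax.nn.softmax(logits + bias, axis=-1)  # bias added only before the softmax (point 5)
    return weights @ v

# Toy usage with random inputs.
kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (8, 64))
k = jax.random.normal(kk, (8, 64))
v = jax.random.normal(kv, (8, 64))
out = t5_style_attention(q, k, v, jnp.zeros((8, 8)))  # [8, 64]
```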
Layer Norm Definition
Given an input $x \in \mathbb{R}^d$, layer normalization is defined as:

$$ \text{LayerNorm}(x) = \gamma \odot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta $$

Where:

- $\mu$ is the mean of the input $x$: $\mu = \frac{1}{d} \sum_{i=1}^{d} x_i$
- $\sigma^2$ is the variance of the input $x$: $\sigma^2 = \frac{1}{d} \sum_{i=1}^{d} (x_i - \mu)^2$
- $\gamma$ and $\beta$ are learnable parameters (the weight and bias), which have the same dimension as $x$
- $\epsilon$ is a small number for numerical stability, typically on the order of $10^{-5}$ to $10^{-8}$
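A direct JAX translation of this definition, as a minimal sketch (`gamma` and `beta` are the learnable weight and bias):

```python
import jax.numpy as jnp

def layer_norm(gamma, beta, x, eps=1e-5):
    """Standard layer norm: subtract the mean, divide by the standard deviation, scale and shift."""
    mu = jnp.mean(x, axis=-1, keepdims=True)
    var = jnp.mean(jnp.square(x - mu), axis=-1, keepdims=True)
    return gamma * (x - mu) / jnp.sqrt(var + eps) + beta
```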
T5 Layer Norm
T5's layer norm does not subtract the mean ($\mu$) and has no bias ($\beta$). The T5 paper did not mention that they used Root Mean Square Layer Normalization.

Root Mean Square Layer Normalization formula:

$$ \bar{a}_i = \frac{a_i}{\text{RMS}(a)} g_i, \qquad \text{RMS}(a) = \sqrt{\frac{1}{n} \sum_{i=1}^{n} a_i^2} $$

Where:

- $g_i$ is the gain (weight) parameter
- $a_i$ are the inputs
- $\bar{a}_i$ are the scaled values of the inputs
- $\text{RMS}(a)$ is the root mean square of $a$
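For comparison with the standard layer norm above, a minimal JAX sketch of the T5-style layer norm, which drops both the mean subtraction and the bias:

```python
import jax.numpy as jnp

def t5_layer_norm(g, x, eps=1e-6):
    """RMS-style layer norm: no mean subtraction, no bias, scale only."""
    return g * x / jnp.sqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + eps)
```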
In the original transformer model proposed by Vaswani et al., 2017, there is no final layer normalization on the outputs of the encoder and decoder. The outputs of these components are fed directly into subsequent operations.
In the T5 model, there is a final layer normalization step after the output from both the encoder and decoder.
T5 uses tied word embeddings, which are applied to the output of the final decoder block. This differs from the conventional Transformer architecture, which uses a separate linear layer for the language model head (`lm_head`).

However, for T5 during training, the `lm_head` is the transpose of the word embedding matrix. This reduces the number of trainable parameters in the model by sharing the same embeddings for the input and output layers. This not only decreases the computational load, but also helps regularize the model, leading to improved generalization and potentially better performance on unseen data.

The output of the final decoder block is fed into a dense layer with a softmax output, whose weights are shared with the input embedding matrix.
The rescaling of the decoder output before passing it into the `lm_head` is not mentioned in the T5 paper. However, their T5 implementation scales the decoder output:

$$ y = \mathrm{lm\_head}\left(x \cdot d_{\text{model}}^{-0.5}\right) = \left(x \cdot d_{\text{model}}^{-0.5}\right) W_e^{\top} $$

Where:

- $x$ is the decoder output
- $y$ is the logits
- $d_{\text{model}}$ is the dimensionality of the model
- $W_e$ is the input embedding matrix used for the tied word embeddings
- $\mathrm{lm\_head}$ is the language model head, whose weights are the transpose of $W_e$
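A minimal sketch of this rescaling with a tied `lm_head` (the shapes and stand-in values are illustrative):

```python
import jax.numpy as jnp

d_model, vocab_size, seq_len = 768, 32128, 5

W_e = jnp.ones((vocab_size, d_model)) * 0.02     # shared embedding matrix (stand-in values)
decoder_output = jnp.ones((seq_len, d_model))    # output of the final decoder block

# Rescale by 1/sqrt(d_model), then reuse the transposed embeddings as the lm_head.
logits = (decoder_output * d_model ** -0.5) @ W_e.T   # [seq, vocab]
```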
| Input | Hugging Face Output | My Output |
|---|---|---|
| translate English to German: That is good. | 'Das ist gut so.' | 'Das ist gut so.' |
| cola sentence: The course is jumping well. | acceptable | acceptable |
| stsb sentence1: The rhino grazed on the grass. sentence2: A rhino is grazing in a field. | 4.0 | 4.0 |
| summarize: In recent times, rapid advancements in technology have revolutionized various industries, enhancing efficiency, connectivity, and convenience for individuals and businesses alike. | rapid advancements in technology have revolutionized various industries | rapid advancements in technology have revolutionized various industries |
The inputs above are fed into the Hugging Face transformers model and my own model. Generation was repeated 100 times and here is the total time taken:
| Device | Hugging Face | Mine | Speed Improvement |
|---|---|---|---|
| GPU | 190.63s | 64.36s | 66.24% faster |
| TPU | 466.59s | 42.31s | 90.93% faster |
In a direct comparison, my implementation achieves results comparable to Hugging Face's while being substantially faster. Both implementations produced identical translations, acceptability scores, and summarization outputs on the provided examples, yet mine completed the generation runs approximately 90.93% faster on TPU and 66.24% faster on GPU.
Upon reading the original T5 paper, I discovered that it primarily focused on translating English to German, French, and Romanian. This sparked my curiosity about whether the model could also handle translating from French to English. To test this, I utilized the pre-trained model and applied a task prefix of "translate French to English: ". Unfortunately, the model proved incapable of performing the desired translation. Determined to overcome this limitation, I embarked on the journey of fine-tuning my own model specifically tailored for the task of French to English translation.
For more in-depth information regarding my fine-tuning process, please read the research report or visit the GitHub branch or explore the WandB runs. These resources provide additional insights into the details of my fine-tuning procedure.
To fine-tune my model, I used the WMT-14 fr-en dataset, which consists of approximately 40 million sentence pairs in the training set and around 3,000 rows for the test and validation sets.
For a comprehensive understanding and detailed analysis of my findings, I invite you to explore my research report.
Through rigorous experimentation with different factors, such as initialising the language model head with embeddings, scaling the decoder output, the task prefix, different Adafactor learning rates, and various optimisers, my results reveal that the optimal configuration uses the Adafactor optimiser with a learning rate of 0.001, coupled with a scaled decoder output and an embedding-initialised `lm_head`.
| Optimiser | Steps | Generation BLEU | One Pass BLEU |
|---|---|---|---|
| Original (No fine-tuning) | | 1.01 | 16.43 |
| Adafactor, 0.001, scale | 20,480 | 29.96 | 34.89 |
| | 40,960 | 30.61 | 35.44 |
| | 61,440 | 30.64 | 35.50 |
| | 81,920 | 31.20 | 36.12 |
Prior to fine-tuning for French to English translation, the model only managed to secure a modest BLEU score of 16.43. After fine-tuning with the best configuration (Adafactor optimiser, learning rate of 0.001, scaled decoder output, and embedding-initialised `lm_head`), the model achieved a BLEU score of 31.20, almost twice that of the non-fine-tuned model.
| Task | Input | Output |
|---|---|---|
| English to German | "translate english to german: That is good" | "That is good news." |
| Summarize | "summarize: In recent times, rapid advancements in technology have revolutionized various industries, enhancing efficiency, connectivity, and convenience for individuals and businesses alike." | "Rapid advances in technology have revolutionized various industries recently, increasing efficiency, connectivity, and convenience for individuals and businesses alike." |
| Semantic Textual Similarity Benchmark | "stsb sentence1: The rhino grazed on the grass. sentence2: A rhino is grazing in a field." | "sentence1: The rhinograzed on the grass. sentence2: A rhino is grazing in a field. sentence3: A rhino is grazing in a field. sentence4: A rhino is grazing in a field. sentence5: A rhino is grazing in a field. sentence6: A rhino is grazing in a field. sentence7: A rhino is grazing in a field. sentence8: A rhino is grazing in a field. sentence8: A rhino is grazing in a field. sentence8:" |
| Corpus of Linguistic Acceptability | "cola sentence: The course is jumping well." | "The course is jumping well. The course is jumping well." |
| French to English | "translate french to english: Une stratégie" | "A Strategy" |
| | "translate french to english: Cette année, je pense que c'est la bonne." | "This year I think it's nice." |
| | "translate french to english: L'effet de la vitamine D sur le cancer n'est pas clairement établi non plus." | "Vitamin D's effect on cancer is not clear either." |
| | "translate french to english: Une boîte noire dans votre voiture?" | "Black box in your car?" |
| | "translate french to english: Le sportif Jhonathan Florez a sauté jeudi d'un hélicoptère au-dessus de Bogota, la capitale colombienne." | "Jhonathan Florez crashed helicopter over Bogotá City Thursday night. He survived injuries sustained by teammates from teammates from Bogotá City." |
| | "translate french to english: j'aime manger du riz au poulet le matin." | "I like eating rice chicken morning." |
As seen from the table, the model performs quite well on French to English translation, but it fails to perform the original tasks well. This is a striking example of catastrophic forgetting: the fine-tuned model loses proficiency on its original tasks, including translation from English to German, summarization, STS-B, and CoLA, even when fine-tuned for a relatively small number of steps such as 20,480, 40,960, 61,440, and 81,920. Despite these careful, incremental adjustments, the models continue to exhibit catastrophic forgetting, underlining the challenge of integrating new knowledge while maintaining proficiency on the originally trained tasks.