[WIP] Experiments with BART-based Generative QA #36

Open
NISH1001 opened this issue Jul 19, 2023 · 1 comment

@NISH1001
Collaborator

Currently, I am experimenting with BART, an encoder-decoder model that is mostly used in seq2seq settings (mainly summarization and translation). However, it can also be trained for question answering (generative QA).
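For context, here is a minimal sketch of what generative QA with BART looks like at inference time, using the same `facebook/bart-large` checkpoint and the same `context: ...\nquestion: ...` prompt format used further below (the example prompt text itself is just illustrative):

```python
# Minimal sketch: generative QA framed as seq2seq generation with BART.
from transformers import BartForConditionalGeneration, BartTokenizer

tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
model = BartForConditionalGeneration.from_pretrained("facebook/bart-large")

# Illustrative prompt in the "context: ...\nquestion: ..." format used below.
prompt = "context: BART is a denoising seq2seq pretraining method.\nquestion: What is BART?"
inputs = tokenizer(prompt, return_tensors="pt", truncation=True)

# The decoder generates the answer text token by token.
gen_ids = model.generate(
    inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    max_new_tokens=64,
)
print(tokenizer.decode(gen_ids[0], skip_special_tokens=True))
```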

As a very rudimentary run on the askathon data (128 samples), I fine-tuned (and deliberately overfitted) the vanilla BART model to see whether the approach works at all.

Then I used evalem to evaluate the models and generate a comparison table.

| metric | askathon-tuned | bart-vanilla |
| --- | --- | --- |
| MeteorMetric | 0.456912 | 0.130023 |
| BleuMetric | 0.211337 | 0.0652976 |
| F1Metric | 0.530609 | 0.142459 |
| AccuracyMetric | 0.494754 | 0.0874828 |
| RougeMetric | 0.48671 | 0.101733 |
| BertScore | 0.690443 | 0.41951 |
| BartScore | -3.49565 | -5.38511 |
| ExactMatchMetric | 0.421875 | 0 |

evalem code

The evalem code is a monkey-patch where I have created a new temporary component for generative QA.

I) Generative QA component (temporary for now)

```python
import logging
from typing import Iterable

import pandas as pd
from tqdm import tqdm

from transformers import BartForConditionalGeneration, BartTokenizer, TrainingArguments, Trainer

from evalem import NamedSimpleEvaluationPipeline

from evalem.nlp.models import QuestionAnsweringHFPipelineWrapper
from evalem.nlp.evaluators import QAEvaluator
from evalem.nlp.metrics import BertScore, RougeMetric, MeteorMetric, ExactMatchMetric, BartScore, BleuMetric
from evalem.nlp.structures import QuestionAnsweringDTO

from evalem.nlp.models._base import HFLMWrapper

from evalem.misc.utils import build_comparison_table

logger = logging.getLogger(__name__)

class GenerativeBartQAWrapper(HFLMWrapper):
    
#     def _predict(self, inputs, **kwargs):
#         gen_ids = self.model.generate(
#             inputs["input_ids"],
#             attention_mask=inputs["attention_mask"]
#         )
#         return self.token_ids_to_token(gen_ids)
    
    def _predict(self, inputs, **kwargs):
        res = []

        batch_size = kwargs.get("batch_size", 8)
        logger.debug(f"batch_size={batch_size}")
        n_items = len(inputs["input_ids"])
        n_batches = (n_items + batch_size - 1) // batch_size
        with tqdm(total=n_batches) as pbar:
            for input_ids, attention_mask in zip(
                self.batch_iterator(inputs["input_ids"], batch_size),
                self.batch_iterator(inputs["attention_mask"], batch_size),
            ):
                # Generate answer token ids for the batch and decode them to text.
                gen_ids = self.model.generate(
                    input_ids,
                    attention_mask=attention_mask,
                )
                res.extend(self.token_ids_to_token(gen_ids))
                pbar.update()
        return res

    @staticmethod
    def batch_iterator(iterable, batch_size):
        for start in range(0, len(iterable), batch_size):
            yield iterable[start:start + batch_size]

    
    def token_ids_to_token(self, token_ids):
        return [
            self.tokenizer.decode(token_id, skip_special_tokens=True) for token_id in token_ids
        ]
    
    def _preprocess_inputs(self, inputs: Iterable, **kwargs) -> Iterable:
        """
        Transform raw QA records into tensors the model can ingest.

        Each input dict is flattened into a single "context: ...\nquestion: ..."
        prompt, and the prompts are tokenized as one padded batch.
        """
        input_texts = []
        labels = []

        for dct in inputs:
            ctx, q, a = dct["context"], dct["question"], dct.get("answer")
            input_texts.append(f"context: {ctx}\nquestion: {q}")
            labels.append(a)  # collected for reference only; not used at inference time

        tokenized_inputs = self.tokenizer(
            input_texts, truncation=True, padding="longest", return_tensors="pt"
        )

        return dict(
            input_ids=tokenized_inputs["input_ids"].to(self.model.device),
            attention_mask=tokenized_inputs["attention_mask"].to(self.model.device),
        )
```

II) Connecting the components

DEVICE = "mps"
BATCH_SIZE = 8

tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
predictions_postprocessor = lambda x: list(map(lambda p: QuestionAnsweringDTO(value=p), x))

wrapped_model_1 = GenerativeBartQAWrapper(
    model=BartForConditionalGeneration.from_pretrained("tmp/bart-askathon-v1/").to(DEVICE),
    tokenizer=tokenizer,
    predictions_postprocessor=predictions_postprocessor,
)

wrapped_model_2 = GenerativeBartQAWrapper(
    model=BartForConditionalGeneration.from_pretrained("facebook/bart-large").to(DEVICE),
    tokenizer=tokenizer,
    predictions_postprocessor=predictions_postprocessor,
)

data = pd.DataFrame(get_askathon_data("data/askathon.csv"))\
    .rename(columns={"contexts": "context", "questions": "question", "answers": "answer"})#.to_dict("records")

evaluators_common = [
    QAEvaluator(),
    BertScore(device="mps"),
    BartScore(device="mps"),
    RougeMetric(),
    MeteorMetric(),
    BleuMetric(),
]

eval_pipe_1 = NamedSimpleEvaluationPipeline(
    model=wrapped_model_1,
    evaluators=evaluators_common,
    name="askathon-tuned"
)

eval_pipe_2 = NamedSimpleEvaluationPipeline(
    model=wrapped_model_2,
    evaluators=evaluators_common,
    name="bart-vanilla"
)

results = build_comparison_table(
    eval_pipe_1, eval_pipe_2,
    inputs=list(data[["context", "question"]].T.to_dict().values()),
    references=data["answer"].to_list(),
)
```

III) Askathon dataloader

```python
def load_askathon_clean(path: str) -> pd.DataFrame:
    data = pd.read_csv(path)
    data = data.drop(columns=["Email Address"]).reset_index(drop=True)
    data.rename(columns={
        data.columns[0] : "context",
        data.columns[1]: "id",
        data.columns[2]: "source",
        data.columns[3]: "topics",
        data.columns[4]: "q1",
        data.columns[5]: "a1",
        data.columns[6]: "q2",
        data.columns[7]: "a2",
        data.columns[8]: "q3",
        data.columns[9]: "a3",
        data.columns[10]: "q4",
        data.columns[11]: "a4",
        data.columns[12]: "q5",
        data.columns[13]: "a5"
    }, inplace=True)
    data.drop(columns=["source", "topics"], inplace=True)
    return data

def create_qa_dataset(data: pd.DataFrame) -> pd.DataFrame:
    res = []
    q_keys = [f"q{i}" for i in range(1, 6)]
    a_keys = [f"a{i}" for i in range(1, 6)]
    
    def _index_fn(context: str, answer: str) -> int:
        try:
            return context.lower().index(answer.rstrip(" ,.!?").lower())
        except ValueError:
            return -1
    
    for _df in data.itertuples():
        tmp = []
        for qk, ak in zip(q_keys, a_keys):
            q, a = getattr(_df, qk), getattr(_df, ak)
            
            if not isinstance(a, str):
                continue
            idx = _index_fn(_df.context, a)
            if idx > -1:
                tmp.append(dict(
                    id=str(_df.id),
                    context=_df.context,
                    question=q,
                    answer_text=a,
                    answer_start=idx,
                ))
        res.extend(tmp)
    return pd.DataFrame(res)
```
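For completeness, the two helpers above would roughly compose like this to produce the record dicts fed into the evaluation pipeline. `get_askathon_records` below is a hypothetical stand-in; the actual `get_askathon_data` called in section II is not shown in this issue:

```python
# Hypothetical composition of the two helpers above; the real
# get_askathon_data(...) used in section II is not shown here.
def get_askathon_records(path: str) -> list[dict]:
    clean = load_askathon_clean(path)   # wide CSV -> one row per context
    qa = create_qa_dataset(clean)       # -> one row per (context, question, answer) triple
    qa = qa.rename(columns={"answer_text": "answer"})
    return qa[["id", "context", "question", "answer"]].to_dict("records")

records = get_askathon_records("data/askathon.csv")
```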

cc: @muthukumaranR @xhagrg

@NISH1001
Collaborator Author

NISH1001 commented Jul 21, 2023

July 21

In this run, we combined two datasets:

  • askathon (128 samples)
  • ES SQuAD v2 data (117 samples; impossible questions excluded)

Total data size = 245 samples

We trained two runs with the same settings on the same data split (a minimal sketch of the setup follows the list below):

  • train split of 75%
  • AdamW (lr=5e-5)
  • batch size = 8
  • epochs = 10
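
A minimal sketch of how those settings map onto the HF `Trainer`/`TrainingArguments` already imported in the evalem snippet above; the output path and the commented-out dataset objects are placeholders, not the actual training script:

```python
# Sketch only: the listed hyperparameters expressed as HF TrainingArguments.
from transformers import BartForConditionalGeneration, Trainer, TrainingArguments

model = BartForConditionalGeneration.from_pretrained("facebook/bart-large")

training_args = TrainingArguments(
    output_dir="tmp/bart-askathon-v2/",  # placeholder; checkpoint paths below follow this pattern
    learning_rate=5e-5,                  # Trainer optimizes with AdamW by default
    per_device_train_batch_size=8,
    num_train_epochs=10,
)

# trainer = Trainer(
#     model=model,
#     args=training_args,
#     train_dataset=train_dataset,  # tokenized 75% split (not shown in this issue)
#     eval_dataset=eval_dataset,    # remaining 25% (not shown)
# )
# trainer.train()
```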

Finally, we compare the two runs against vanilla BART:

| metric | v2 | v3 | vanilla |
| --- | --- | --- | --- |
| BleuMetric | 0.109138 | 0.167192 | 0.0510784 |
| AccuracyMetric | 0.161637 | 0.185594 | 0.0807655 |
| MeteorMetric | 0.150026 | 0.201007 | 0.125672 |
| ExactMatchMetric | 0.0851064 | 0.0806452 | 0 |
| RougeMetric | 0.159285 | 0.192992 | 0.0944004 |
| BertScore | 0.496189 | 0.497045 | 0.414053 |
| F1Metric | 0.181625 | 0.22986 | 0.126766 |
| BartScore | -5.26783 | -4.95518 | -5.31429 |

For training, context and question were combined into a single input as: f"context: {context}\nquestion:{question}"
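
For reference, a minimal sketch of how one such training pair might be tokenized into encoder inputs and decoder labels; this is an assumption about the preprocessing, since the actual training script is not part of this issue (the example triple is illustrative only):

```python
# Sketch only: turning one (context, question, answer) triple into seq2seq
# training tensors under the format described above. Not the actual script.
from transformers import BartTokenizer

tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")

context, question, answer = "BART is a seq2seq model.", "What is BART?", "a seq2seq model"
source = f"context: {context}\nquestion:{question}"

model_inputs = tokenizer(source, truncation=True, return_tensors="pt")
# Decoder targets are the tokenized answer; BART's LM loss is computed against these.
model_inputs["labels"] = tokenizer(text_target=answer, truncation=True, return_tensors="pt")["input_ids"]
```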


Now we train a new model (v4) with dropout=0.2.
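
The dropout change presumably goes through BART's config; a hedged sketch of one way to set it (how it was actually configured for v4 is not shown):

```python
# Sketch: raising BART's dropout (default 0.1) via its config before fine-tuning.
from transformers import BartForConditionalGeneration

model_v4 = BartForConditionalGeneration.from_pretrained("facebook/bart-large", dropout=0.2)
print(model_v4.config.dropout)  # 0.2
```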

| metric | tmp/bart-askathon-v2/ | tmp/bart-askathon-v3/ | tmp/bart-askathon-v4/ | facebook/bart-large |
| --- | --- | --- | --- | --- |
| AccuracyMetric | 0.161637 | 0.185594 | 0.18856 | 0.0807655 |
| MeteorMetric | 0.150026 | 0.201007 | 0.187716 | 0.125672 |
| RougeMetric | 0.159389 | 0.192661 | 0.188657 | 0.093801 |
| ExactMatchMetric | 0.0851064 | 0.0806452 | 0.0967742 | 0 |
| F1Metric | 0.181625 | 0.22986 | 0.233943 | 0.126766 |
| BertScore | 0.496189 | 0.497045 | 0.500905 | 0.414053 |
| BartScore | -5.26783 | -4.95518 | -4.85861 | -5.31429 |
| BleuMetric | 0.109138 | 0.167192 | 0.110142 | 0.0510784 |

Comparison table on train data (183 samples):

| metric | tmp/bart-askathon-v2/ | tmp/bart-askathon-v3/ | tmp/bart-askathon-v4/ | facebook/bart-large |
| --- | --- | --- | --- | --- |
| BleuMetric | 0.100755 | 0.2866164510826 | 0.2498173874370021 | 0.0748681 |
| F1Metric | 0.278363 | 0.3612513677597083 | 0.392318298794673 | 0.154488 |
| AccuracyMetric | 0.245453 | 0.2977759512961105 | 0.34619882292279 | 0.094062 |
| RougeMetric | 0.234801 | 0.31110798938483697 | 0.3485161151517261 | 0.109307 |
| BartScore | -5.28446 | -4.387980080856 | -4966262879296 | -5.37412 |
| ExactMatchMetric | 0.184874 | 0.18032786885245902 | 0.262298196721313 | 0 |
| BertScore | 0.56794 | 0.5850964090863212 | 0.230477664197 | 0.424385 |
| MeteorMetric | 0.202797 | 0.3124119917021371 | 0.337576884785 | 0.140888 |

cc: @muthukumaranR @xhagrg
