[result] Askathon QA comparison table for distilbert and nasa-v6 #35

Open
NISH1001 opened this issue Jul 12, 2023 · 0 comments
NISH1001 commented Jul 12, 2023

We use the cleaned askathon response dataset to evaluate the nasa-v6 model against the vanilla distilbert model.
We're using the ONNX version of the nasa-v6 model.
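
For context, a minimal sketch of how such an ONNX export can be produced with Hugging Face optimum (the export step and the tmp/nasa-v6 checkpoint path are assumed here, not taken from this issue):

# Hypothetical ONNX export sketch using Hugging Face optimum; the checkpoint
# path "tmp/nasa-v6" is assumed and this is not necessarily how nasa-v6 was exported.
from optimum.onnxruntime import ORTModelForQuestionAnswering
from transformers import AutoTokenizer

model_dir = "tmp/nasa-v6"      # assumed fine-tuned checkpoint location
onnx_dir = "tmp/onnx/nasa-v6"  # matches the path used in the pipeline below

model = ORTModelForQuestionAnswering.from_pretrained(model_dir, export=True)
tokenizer = AutoTokenizer.from_pretrained(model_dir)
model.save_pretrained(onnx_dir)
tokenizer.save_pretrained(onnx_dir)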

The dataset loader is tentatively:

import pandas as pd


def load_askathon_clean(path: str) -> pd.DataFrame:
    """Load the cleaned askathon CSV and rename its columns to short keys."""
    data = pd.read_csv(path)
    data = data.drop(columns=["Email Address"]).reset_index(drop=True)
    # Positional renaming: the remaining columns are context/id/source/topics,
    # followed by five question/answer pairs.
    data.rename(columns={
        data.columns[0]: "context",
        data.columns[1]: "id",
        data.columns[2]: "source",
        data.columns[3]: "topics",
        data.columns[4]: "q1",
        data.columns[5]: "a1",
        data.columns[6]: "q2",
        data.columns[7]: "a2",
        data.columns[8]: "q3",
        data.columns[9]: "a3",
        data.columns[10]: "q4",
        data.columns[11]: "a4",
        data.columns[12]: "q5",
        data.columns[13]: "a5",
    }, inplace=True)
    # source/topics aren't needed for QA evaluation
    data.drop(columns=["source", "topics"], inplace=True)
    return data

def create_qa_dataset(data: pd.DataFrame) -> pd.DataFrame:
    """Flatten the wide q1..q5/a1..a5 columns into one (context, question, answer) row each."""
    res = []
    q_keys = [f"q{i}" for i in range(1, 6)]
    a_keys = [f"a{i}" for i in range(1, 6)]

    def _index_fn(context: str, answer: str) -> int:
        # Character offset of the answer within the context (case-insensitive,
        # ignoring trailing punctuation); -1 if the answer isn't found verbatim.
        try:
            return context.lower().index(answer.rstrip(" ,.!?").lower())
        except ValueError:
            return -1

    for _df in data.itertuples():
        tmp = []
        for qk, ak in zip(q_keys, a_keys):
            q, a = getattr(_df, qk), getattr(_df, ak)

            # Skip unanswered questions (empty cells come through as NaN floats).
            if not isinstance(a, str):
                continue
            idx = _index_fn(_df.context, a)
            if idx > -1:
                tmp.append(dict(
                    id=str(_df.id),
                    context=_df.context,
                    question=q,
                    answer_text=a,
                    answer_start=idx,
                ))
        res.extend(tmp)
    return pd.DataFrame(res)

data = create_qa_dataset(load_askathon_clean("data/askathon.csv"))
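
As a quick sanity check (not part of the original snippet), the recorded answer_start offsets can be verified against each context; note that _index_fn lowercases and strips trailing punctuation before searching:

# Hypothetical verification: each answer_start should point at the normalized
# (lowercased, punctuation-stripped) answer inside the lowercased context.
for row in data.itertuples():
    answer = row.answer_text.rstrip(" ,.!?").lower()
    span = row.context.lower()[row.answer_start : row.answer_start + len(answer)]
    assert span == answer, f"misaligned answer for id={row.id}"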

metric             distilbert    nasa-v6-onnx
RougeMetric        0.569272      0.60426
BartScore          -3.4921       -3.26777
F1Metric           0.652469      0.687898
MeteorMetric       0.525007      0.563029
BertScore          0.74016       0.760525
AccuracyMetric     0.545293      0.588434
ExactMatchMetric   0.265625      0.28125
BleuMetric         0.228825      0.280672

Evaluation is done through evalem with the following tentative pipeline code:

from evalem.nlp.evaluators import QAEvaluator
from evalem.nlp.models import QuestionAnsweringHFPipelineWrapper
from evalem.nlp.metrics import BartScore, BertScore, BleuMetric, MeteorMetric, ExactMatchMetric, RougeMetric
from evalem import NamedSimpleEvaluationPipeline
from evalem.misc.utils import build_comparison_table


# define models
wrapped_model = QuestionAnsweringHFPipelineWrapper(device="mps")  # baseline: vanilla distilbert
wrapped_model_2 = QuestionAnsweringHFPipelineWrapper.from_onnx(
    model="tmp/onnx/nasa-v6/",
    tokenizer="tmp/onnx/nasa-v6/",
    device="mps"
)

# define evaluators/metrics
evaluators_common = [
    QAEvaluator(),
    BertScore(device="mps"),
    BartScore(device="mps"),
    RougeMetric(),
    MeteorMetric(),
    BleuMetric(),
]

# build pipelines
eval_pipe = NamedSimpleEvaluationPipeline(
    model=wrapped_model,
    evaluators=evaluators_common,
    name="distilbert"
)

eval_pipe_2 = NamedSimpleEvaluationPipeline(
    model=wrapped_model_2,
    evaluators=evaluators_common,
    name="nasa-v6-onnx"
)

# evaluate and get comparison table
results = build_comparison_table(
    eval_pipe, eval_pipe_2,
    # one {"context": ..., "question": ...} dict per row
    inputs=list(data[["context", "question"]].T.to_dict().values()),
    references=data["answer_text"].to_list(),
)
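
Assuming build_comparison_table returns a pandas DataFrame (its exact return type isn't shown in this issue), the comparison table above can be printed with something like:

# Illustrative only: relies on the assumption that `results` is a DataFrame.
print(results.to_markdown(index=False))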

cc: @muthukumaranR @xhagrg
