Commit

Created multiple choice question/evaluation classes, with tests
jamesbraza committed Dec 17, 2024
1 parent d688ccc commit 3b4fc47
Showing 13 changed files with 1,395 additions and 54 deletions.
190 changes: 188 additions & 2 deletions src/aviary/utils.py
@@ -2,8 +2,15 @@
import contextlib
import inspect
import io
import random
import re
import string
from ast import literal_eval
from collections.abc import Awaitable, Callable, Sequence
from enum import StrEnum
from typing import TYPE_CHECKING, Any, cast
from typing import TYPE_CHECKING, Any, ClassVar, Literal, Self, cast

from pydantic import BaseModel, Field, model_validator

try:
from litellm import acompletion
@@ -14,7 +21,7 @@
import numpy as np


DEFAULT_EVAL_MODEL_NAME = "gpt-4o-mini"
DEFAULT_EVAL_MODEL_NAME = "gpt-4o"
LLM_BOOL_EVAL_CONFIG = {
"prompt": (
"Here is a question, the correct answer to the question, and a proposed answer"
@@ -166,3 +173,182 @@ async def eval_answer(
return float(gt in pred)

raise RuntimeError(f"Invalid evaluation mode: {eval_mode}")


_CAPITAL_A_INDEX = ord("A")


class MultipleChoiceQuestion(BaseModel):
QUESTION_PROMPT_TEMPLATE: ClassVar[str] = "Q: {question}\n\nOptions:\n{options}"
# TODO: combine with above eval_answer and its prompts
EVALUATION_PROMPT_TEMPLATE: ClassVar[str] = (
"Given the following question and a proposed answer to the question, return the"
" single-letter choice in the question that matches the proposed answer."
" If the proposed answer is blank or an empty string,"
" or multiple options are matched, respond with '0'."
"\n\nQuestion: {qa_prompt}"
"\n\nProposed Answer: {qa_answer}"
"\n\nSingle Letter Answer:"
)
DEFAULT_UNSURE_OPTION: ClassVar[str] = (
"Insufficient information to answer this question"
)
SEED_USING_QUESTION: ClassVar[Literal["SEED_USING_QUESTION"]] = (
"SEED_USING_QUESTION"
)

question: str = Field(
description="Question to answer (without multiple choice options)."
)
options: Sequence[str] = Field(description="All multiple choice options.")
ideal_answer: str = Field(
description=(
"Desired ideal answer. If not one of the provided options, it will be"
" automatically added."
)
)
unsure_answer: str | None = Field(
default=DEFAULT_UNSURE_OPTION,
description=(
"Unsure answer text. If not one of the provided options, it will be"
" automatically added."
),
)
shuffle_seed: int | Literal["SEED_USING_QUESTION"] | None = Field(
default=None,
description=(
"Optional seed to use in randomization of options, where seeding is not"
" global (e.g. no `random.seed`). Optionally pass in the string literal"
" 'SEED_USING_QUESTION' to hash the question for the seed"
),
)

@model_validator(mode="after")
def add_answers_and_shuffle(self) -> Self:
if self.ideal_answer not in self.options:
self.options = [*self.options, self.ideal_answer]
if self.unsure_answer and self.unsure_answer not in self.options:
self.options = [*self.options, self.unsure_answer]
if len(self.options) > len(string.ascii_lowercase):
raise NotImplementedError(
"Didn't handle more multiple choice options than letters, options were"
f" {self.options}."
)
if self.shuffle_seed == self.SEED_USING_QUESTION:
self.shuffle_seed = hash(self.question)
if self.shuffle_seed is not None:
self.options = random.Random(self.shuffle_seed).sample(
self.options, k=len(self.options)
)
# Ensure deserialization doesn't re-shuffle
self.shuffle_seed = None
return self

@property
def ideal_answer_index(self) -> int:
return self.options.index(self.ideal_answer)

@property
def unsure_answer_index(self) -> int | None:
if self.unsure_answer is None:
return None
return self.options.index(self.unsure_answer)

@property
def question_prompt(self) -> str:
return self.QUESTION_PROMPT_TEMPLATE.format(
question=self.question,
options="\n".join([
f"{_CAPITAL_A_INDEX + i:c}) {o}" for i, o in enumerate(self.options)
]),
)

@staticmethod
def split_options(options: str) -> list[str]:
"""Split options string into a list of options.
Examples:
>>> MultipleChoiceQuestion.split_options("apples, mangos")
['apples', 'mangos']
"""
try:
split_options = literal_eval(options)
if not isinstance(split_options, list):
raise TypeError("Need split_options to be a list.") # noqa: TRY301
except (ValueError, SyntaxError, TypeError):
split_options = [d.strip("'[ ]\"") for d in options.split(",")]
return split_options

async def grade(
self, answer: str, prompt_runner: Callable[[str], Awaitable[str]] | None = None
) -> "tuple[MultipleChoiceEvaluation, str, str]":
if prompt_runner is None:
prompt_runner = run_prompt
eval_prompt = self.EVALUATION_PROMPT_TEMPLATE.format(
qa_prompt=self.question_prompt, qa_answer=answer
)
raw_evaluation = await prompt_runner(eval_prompt)
evaluation, parsed_answer = MultipleChoiceEvaluation.from_answer(
raw_evaluation, self
)
return evaluation, raw_evaluation, parsed_answer
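
For reference, a minimal sketch of constructing one of these questions and rendering its prompt; the values are hypothetical (they mirror the recorded cassette later in this commit), and no shuffle seed is passed, so the options keep their input order with the ideal and unsure answers appended:

from aviary.utils import MultipleChoiceQuestion

mcq = MultipleChoiceQuestion(
    question="What is the meaning of life?",
    options=["-84", "11", "cheesecake"],
    ideal_answer="42",
)
print(mcq.question_prompt)
# Q: What is the meaning of life?
#
# Options:
# A) -84
# B) 11
# C) cheesecake
# D) 42
# E) Insufficient information to answer this question
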


class MultipleChoiceEvaluation(StrEnum):
CORRECT = "correct"
INCORRECT = "incorrect"
UNSURE = "unsure" # May be irrelevant if no unsure option provided

@classmethod
def calculate_accuracy_precision(
cls, evaluations: Sequence[Self | str]
) -> tuple[float, float]:
"""
Calculate QA-specific accuracy and precision metrics upon evaluations.
Raises:
ZeroDivisionError: if `evaluations` is empty or contains only unsure evaluations.
Returns:
Two-tuple of accuracy = (num correct) / (num questions) and
precision = (num correct) / ((num questions) - (num unsure)).
"""
evaluations = [e if isinstance(e, cls) else cls(e) for e in evaluations]
num_correct = sum(e == cls.CORRECT for e in evaluations)
accuracy = num_correct / len(evaluations)
precision = num_correct / sum(
e in {cls.CORRECT, cls.INCORRECT} for e in evaluations
)
return accuracy, precision

@classmethod
def from_answer(
cls, answer: str, question: MultipleChoiceQuestion
) -> "tuple[MultipleChoiceEvaluation, str]":
"""Make an evaluation from the input answer and multiple choice question.
Returns:
Two-tuple of answer enum and the raw answer extracted from the input answer.
"""
# SEE: https://regex101.com/r/vcE9Hb/1
letter_search = re.search(r"([A-Z])\)?", answer, re.DOTALL)
# Get the letter answer, or fail over to the first non-whitespace char
answer_char = (
letter_search.group(1)
if letter_search is not None
else answer.split()[0][0].upper()
)
answer_letter_index = ord(answer_char[0]) - _CAPITAL_A_INDEX
if answer_letter_index < 0 or answer_letter_index >= len(question.options):
# The result extracted was not in the options (e.g. '0')
return cls.INCORRECT, answer_char
# From here, if we don't match either the ideal or the unsure multiple choice
# options then we declare the answer as incorrect.
if (
question.unsure_answer_index is not None
and answer_letter_index == question.unsure_answer_index
):
return cls.UNSURE, cast(str, question.unsure_answer)
if answer_letter_index == question.ideal_answer_index:
return cls.CORRECT, question.ideal_answer
return cls.INCORRECT, question.options[answer_letter_index]
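
As a sketch of how grading and the aggregate metrics fit together, assuming a hypothetical stub prompt runner in place of the default LLM-backed run_prompt (so no API call is made):

import asyncio

from aviary.utils import MultipleChoiceEvaluation, MultipleChoiceQuestion


async def main() -> None:
    mcq = MultipleChoiceQuestion(
        question="What is the meaning of life?",
        options=["-84", "11", "cheesecake"],
        ideal_answer="42",
    )

    async def stub_prompt_runner(prompt: str) -> str:
        # Pretend the evaluation LLM matched the proposed answer to option D.
        return "D)"

    evaluation, raw, parsed = await mcq.grade("42", prompt_runner=stub_prompt_runner)
    assert evaluation == MultipleChoiceEvaluation.CORRECT
    assert (raw, parsed) == ("D)", "42")

    # One correct evaluation out of three; the unsure one is excluded from
    # the precision denominator.
    accuracy, precision = MultipleChoiceEvaluation.calculate_accuracy_precision(
        [evaluation, MultipleChoiceEvaluation.UNSURE, "incorrect"]
    )
    assert (accuracy, precision) == (1 / 3, 0.5)


asyncio.run(main())
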
@@ -0,0 +1,104 @@
interactions:
- request:
body:
'{"messages": [{"content": "Given the following question and a proposed
answer to the question, return the single-letter choice in the question that
matches the proposed answer. If the proposed answer is blank or an empty string,
or multiple options are matched, respond with ''0''.\n\nQuestion: Q: What is
the meaning of life?\n\nOptions:\nA) -84\nB) Insufficient information to answer
this question\nC) cheesecake\nD) 11\nE) 42\n\nProposed Answer: 14\n\nSingle
Letter Answer:", "role": "user"}], "model": "gpt-4o"}'
headers:
accept:
- application/json
accept-encoding:
- gzip, deflate
connection:
- keep-alive
content-length:
- "513"
content-type:
- application/json
host:
- api.openai.com
user-agent:
- AsyncOpenAI/Python 1.57.4
x-stainless-arch:
- arm64
x-stainless-async:
- async:asyncio
x-stainless-lang:
- python
x-stainless-os:
- MacOS
x-stainless-package-version:
- 1.57.4
x-stainless-raw-response:
- "true"
x-stainless-retry-count:
- "1"
x-stainless-runtime:
- CPython
x-stainless-runtime-version:
- 3.12.7
method: POST
uri: https://api.openai.com/v1/chat/completions
response:
body:
string: !!binary |
H4sIAAAAAAAAAwAAAP//jJJLa8MwEITv/hVC57goqamd3HpooPQBORVSilGktaNW1gpJoY+Q/15k
u7FDU+jFh/l2xrNr7xNCqJJ0QajY8iAaq9Prav24NO7rNmMKH55ulvd8VaxXHu/ejaOT6MDNK4jw
47oQ2FgNQaHpsHDAA8TUaX6ZZTnLi3kLGpSgo622Ic0wnbFZlrIiZVe9cYtKgKcL8pwQQsi+fcaK
RsIHXRA2+VEa8J7XQBfHIUKoQx0Vyr1XPnAT6GSAAk0A07ZmY91BtfM81jI7rXv9cHyRxto63Pie
H/VKGeW3pQPu0cRQH9DSlh4SQl7ahXYnHal12NhQBnwDEwOnrOjy6HDCEe1ZwMD12DSfnIkrJQSu
tB9dhAoutiAH63A+vpMKRyAZLf27zLnsbnFl6v/ED0AIsAFkaR1IJU4XHsYcxB/sr7HjkdvC1H/6
AE1ZKVODs05137iyJc/nspBcTCuaHJJvAAAA//8DAGY5XevsAgAA
headers:
CF-Cache-Status:
- DYNAMIC
CF-RAY:
- 8f39fde1cf88cf1b-SJC
Connection:
- keep-alive
Content-Encoding:
- gzip
Content-Type:
- application/json
Date:
- Tue, 17 Dec 2024 21:26:29 GMT
Server:
- cloudflare
Transfer-Encoding:
- chunked
X-Content-Type-Options:
- nosniff
access-control-expose-headers:
- X-Request-ID
alt-svc:
- h3=":443"; ma=86400
openai-organization:
- future-house-xr4tdh
openai-processing-ms:
- "363"
openai-version:
- "2020-10-01"
strict-transport-security:
- max-age=31536000; includeSubDomains; preload
x-ratelimit-limit-requests:
- "10000"
x-ratelimit-limit-tokens:
- "30000000"
x-ratelimit-remaining-requests:
- "9999"
x-ratelimit-remaining-tokens:
- "29999874"
x-ratelimit-reset-requests:
- 6ms
x-ratelimit-reset-tokens:
- 0s
x-request-id:
- req_aff8daa48aa43d3df077f97da6136e5a
status:
code: 200
message: OK
version: 1
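
The cassette above records the single chat-completion call made while grading the proposed answer "14" against the rendered question prompt. A hedged sketch of a test that might replay it via pytest-recording follows; the marker names, async plugin, and option ordering are assumptions, not something shown in this commit:

import pytest

from aviary.utils import MultipleChoiceEvaluation, MultipleChoiceQuestion


@pytest.mark.asyncio
@pytest.mark.vcr  # Assumed pytest-recording marker: replay the cassette instead of calling OpenAI.
async def test_grading_an_unmatched_answer() -> None:
    # The recorded prompt shows a shuffled option order; the order here is illustrative,
    # since cassette matching defaults to the request URI rather than the body.
    mcq = MultipleChoiceQuestion(
        question="What is the meaning of life?",
        options=["-84", "cheesecake", "11"],
        ideal_answer="42",
    )
    evaluation, *_ = await mcq.grade("14")
    # "14" matches none of the options, so the recorded reply is "0", which
    # parses outside the option range and is scored as incorrect.
    assert evaluation == MultipleChoiceEvaluation.INCORRECT
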