Skip to content

Commit

Permalink
Add max tokens/temperature/top p to query_multimodal
Browse files Browse the repository at this point in the history
Summary: As titled; this makes query_multimodal consistent with the other query methods.

Reviewed By: SimonWan

Differential Revision: D64257049

fbshipit-source-id: 6a9a12df4efdf5abc0371a4e6b8a0942d22b1755
  • Loading branch information
onionymous authored and facebook-github-bot committed Oct 11, 2024
1 parent b85bf58 commit babeb62
Showing 1 changed file with 13 additions and 0 deletions.
13 changes: 13 additions & 0 deletions CybersecurityBenchmarks/benchmark/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@

LOG: logging.Logger = logging.getLogger(__name__)

DEFAULT_MAX_TOKENS = 2048
DEFAULT_TEMPERATURE = 0.6
DEFAULT_TOP_P = 0.9


class RetriesExceededException(Exception):
def __init__(
Expand Down Expand Up @@ -300,6 +304,9 @@ def query_multimodal(
text_prompt: Optional[str] = None,
image_path: Optional[str] = None,
audio_path: Optional[str] = None,
max_tokens: int = DEFAULT_MAX_TOKENS,
temperature: float = DEFAULT_TEMPERATURE,
top_p: float = DEFAULT_TOP_P,
) -> str:
"""
Abstract method to query an LLM and return the response. This method
Expand All @@ -325,6 +332,9 @@ def query_multimodal_with_retries(
image_path: Optional[str] = None,
audio_path: Optional[str] = None,
num_retries: int = NUM_LLM_RETRIES,
max_tokens: int = DEFAULT_MAX_TOKENS,
temperature: float = DEFAULT_TEMPERATURE,
top_p: float = DEFAULT_TOP_P,
backoff_factor: float = 0.5,
) -> Optional[str]:
"""
Expand Down Expand Up @@ -362,6 +372,9 @@ def query_multimodal_with_retries(
text_prompt=text_prompt,
image_path=image_path,
audio_path=audio_path,
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
)
if self.cache_handler:
self.cache_handler.save_response(
Expand Down

0 comments on commit babeb62

Please sign in to comment.