diff --git a/gollm/entities.py b/gollm/entities.py
index 852fafb..ae26870 100644
--- a/gollm/entities.py
+++ b/gollm/entities.py
@@ -20,7 +20,7 @@ class ModelCardModel(BaseModel):
 
 
 class ModelCompareModel(BaseModel):
-    cards: List # model cards
+    amrs: List[str] # expects AMRs to be a stringified JSON object
 
 
 class EmbeddingModel(BaseModel):
diff --git a/gollm/openai/prompts/model_meta_compare.py b/gollm/openai/prompts/model_meta_compare.py
index 22c4996..2b6bf1c 100644
--- a/gollm/openai/prompts/model_meta_compare.py
+++ b/gollm/openai/prompts/model_meta_compare.py
@@ -1,3 +1,18 @@
 MODEL_METADATA_COMPARE_PROMPT = """
-You are a helpful agent designed to compare the metadata of multiple models. Use as much detail as possible and assume that your audience is domain experts. When you mention bias and limitations, provide detailed examples. Do not repeat the model card schema headers. You have access to the model cards.\n{model_cards}\nComparison:
+You are a helpful agent designed to compare multiple AMR models.
+
+Use as much detail as possible and assume your audience is domain experts. When you mention bias and limitations, provide detailed examples. Do not repeat the model card schema headers. Do not refer to 'gollmCard' in your response, refer to 'gollmCard metadata' as 'metadata'. Format the response in Markdown and include section headers.
+
+If all the AMR models contain gollmCard metadata, focus solely on comparing gollmCard information.
+
+If some but not all of the AMR models contain gollmCard metadata, compare headers, gollmCard, and semantic information together.
+
+If none of the AMR models contain gollmCard metadata, only focus on comparing headers and semantic information. Avoid making assumptions about the AMR models to maintain an objective perspective. 
+
+AMRs:
+
+{amrs}
+
+
+Comparison:
 """
diff --git a/gollm/openai/tool_utils.py b/gollm/openai/tool_utils.py
index 7f49846..325cc21 100644
--- a/gollm/openai/tool_utils.py
+++ b/gollm/openai/tool_utils.py
@@ -197,19 +197,23 @@ def config_from_dataset(amr: str, model_mapping: str, datasets: List[str]) -> st
     return postprocess_oai_json(output.choices[0].message.content)
 
 
-def compare_models(model_cards: List[str]) -> str:
+def compare_models(amrs: List[str]) -> str:
+    print("Comparing models...")
+
+    joined_escaped_amrs = "\n------\n".join([escape_curly_braces(amr) for amr in amrs])
     prompt = MODEL_METADATA_COMPARE_PROMPT.format(
-        model_cards="--------".join(model_cards)
+        amrs=joined_escaped_amrs
     )
+    client = OpenAI()
     output = client.chat.completions.create(
-        model="gpt-4o-2024-05-13",
+        model="gpt-4o-mini",
         top_p=1,
         frequency_penalty=0,
         presence_penalty=0,
         seed=123,
         temperature=0,
-        max_tokens=1024,
+        max_tokens=2048,
         messages=[
             {"role": "user", "content": prompt},
         ],