-
Notifications
You must be signed in to change notification settings - Fork 0
/
eval_retrieval.py
34 lines (23 loc) · 1.19 KB
/
eval_retrieval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import argparse
from LMTEB_retrieval import *
from flag_dres_model import FlagDRESModel
from mteb import MTEB
def get_args(argv=None):
    """Parse command-line options for the long-document retrieval evaluation.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default), argparse falls back to ``sys.argv[1:]``, which
            matches the original behavior. Passing a list lets callers and
            tests drive the parser without touching ``sys.argv``.

    Returns:
        argparse.Namespace with the attributes ``model_name_or_path``,
        ``task_type``, ``add_instruction`` and ``pooling_method``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name_or_path', default="BAAI/bge-large-zh", type=str,
                        help="Model name or local path of the embedding model to evaluate.")
    parser.add_argument('--task_type', default=None, type=str,
                        help="Task type (not read by this script).")
    parser.add_argument('--add_instruction', action='store_true',
                        help="whether to add instruction for query")
    parser.add_argument('--pooling_method', default='cls', type=str,
                        help="Pooling method passed to the embedding model (default: cls).")
    return parser.parse_args(argv)
# Query-side retrieval instruction (Chinese). Roughly: "Generate a
# representation for this sentence to retrieve related articles:".
QUERY_INSTRUCTION = "为这个句子生成表示以用于检索相关文章:"


def result_folder_name(model_name_or_path, root="zh_results"):
    """Return the results directory for a model name or local model path.

    The directory is ``<root>/<last path component>``. Trailing slashes are
    stripped first. Otherwise a local path such as ``/models/bge/`` would
    yield an empty component, and results for different models would all be
    written directly into ``root``.

    Args:
        model_name_or_path: Hub-style name (``"org/model"``) or a local path.
        root: Parent directory for all results.

    Returns:
        The folder path as a string, using ``/`` as the separator.
    """
    # Fall back to the raw value if stripping leaves nothing (e.g. "/").
    trimmed = model_name_or_path.rstrip('/') or model_name_or_path
    return f"{root}/{trimmed.split('/')[-1]}"


if __name__ == '__main__':
    args = get_args()
    # The model is constructed with the instruction, and the attribute is then
    # overwritten according to --add_instruction. The post-construction
    # assignment is kept as in the original. NOTE(review): FlagDRESModel's
    # constructor is not visible here. Confirm it does not consume the
    # instruction at construction time.
    model = FlagDRESModel(model_name_or_path=args.model_name_or_path,
                          query_instruction_for_retrieval=QUERY_INSTRUCTION,
                          pooling_method=args.pooling_method)
    model.query_instruction_for_retrieval = QUERY_INSTRUCTION if args.add_instruction else None
    # LongDocRetrieval comes from the wildcard import of LMTEB_retrieval.
    evaluation = MTEB(tasks=[LongDocRetrieval()])
    evaluation.run(model, output_folder=result_folder_name(args.model_name_or_path))