-
Notifications
You must be signed in to change notification settings - Fork 1
/
run_zero_shot_chatgpt.py
38 lines (29 loc) · 1.91 KB
/
run_zero_shot_chatgpt.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import os
import shutil
import subprocess
import sys
# Zero-shot BLI driver: for each language pair, build the test-prompt pickle
# with ./src/extract_bli_test_data.py, then query the OpenAI model with
# ./src/openaigen.py using the --best_template option.

lang_pairs = [("de", "fr"), ("en", "ru"), ("ru", "fr")]
# Languages whose embeddings/test dictionaries live under the XLING paths.
XLING = set(["en", "de", "fr", "it", "ru", "tr", "hr", "fi"])
# Languages whose embeddings/test dictionaries live under the PanLex paths.
PanLex = set(["bg", "ca", "hu"])
Model = "gpt-4-turbo-2024-04-09"
size_train = 0  # Seed dictionary size
n_shot = 0  # Number of in-context examples. Zero-shot prompting (also known as unsupervised BLI in previous BLI work): n_shot=0.
DATA_ROOT = "/media/data/T2TData/"
SAVE_ROOT = "/media/data/T2TModel/"  # save dir; NOTE(review): not referenced by this script
TMP_DIR = "./TMP/"


def resolve_data_paths(lang1, lang2):
    """Return ``(emb_src, emb_trg, test_dict)`` file paths for a language pair.

    Both languages must belong to the same resource family (XLING or PanLex);
    the directory layout differs between the two.

    Raises:
        ValueError: if the pair is not entirely XLING or entirely PanLex, since
            no consistent set of data paths exists for it.
    """
    if lang1 in XLING and lang2 in XLING:
        emb_src = "/media/data/WES/fasttext.wiki.{}.300.vocab_200K.vec".format(lang1)
        emb_trg = "/media/data/WES/fasttext.wiki.{}.300.vocab_200K.vec".format(lang2)
        test_dict = "/media/data/xling-eval/bli_datasets/{}-{}/yacle.test.freq.2k.{}-{}.tsv".format(
            lang1, lang2, lang1, lang2
        )
        return emb_src, emb_trg, test_dict
    if lang1 in PanLex and lang2 in PanLex:
        emb_src = "/media/data/WESPLX/fasttext.cc.{}.300.vocab_200K.vec".format(lang1)
        emb_trg = "/media/data/WESPLX/fasttext.cc.{}.300.vocab_200K.vec".format(lang2)
        test_dict = "/media/data/panlex-bli/lexicons/all/{}-{}/{}-{}.test.2000.cc.trans".format(
            lang1, lang2, lang1, lang2
        )
        return emb_src, emb_trg, test_dict
    raise ValueError(
        "unsupported language pair {}-{}: both languages must be in XLING or both in PanLex".format(
            lang1, lang2
        )
    )


def build_commands(lang1, lang2):
    """Return the ``(extract_cmd, generate_cmd)`` argv lists for one pair.

    Commands are argv lists (not shell strings) so paths containing spaces or
    shell metacharacters are passed through verbatim.
    """
    emb_src, emb_trg, test_dict = resolve_data_paths(lang1, lang2)
    test_prompt_dict_dir = os.path.join(
        TMP_DIR, "{}2{}_test_prompt_{}.pkl".format(lang1, lang2, size_train)
    )
    Dtrain_dir = None  # no seed-dictionary file is supplied (size_train is 0)
    extract_cmd = [
        sys.executable, "./src/extract_bli_test_data.py",
        "--l1", lang1,
        "--l2", lang2,
        "--emb_src_dir", emb_src,
        "--emb_tgt_dir", emb_trg,
        # Sent as the literal string "None"; presumably the extractor treats
        # that as "no training dictionary" — confirm in extract_bli_test_data.py.
        "--train_dict_dir", str(Dtrain_dir),
        "--test_dict_dir", test_dict,
        "--save_dir", test_prompt_dict_dir,
        "--source_data", DATA_ROOT,
    ]
    generate_cmd = [
        sys.executable, "./src/openaigen.py",
        "--l1", lang1,
        "--l2", lang2,
        "--model_name", Model,
        "--train_size", str(size_train),
        "--n_shot", str(n_shot),
        "--data_dir", DATA_ROOT,
        "--test_dict_dir", test_prompt_dict_dir,
        "--best_template",
    ]
    return extract_cmd, generate_cmd


def main():
    """Run extraction then generation for every pair in ``lang_pairs``.

    A pair whose extraction step fails is reported on stderr and its
    generation step is skipped (the prompt pickle it needs was not produced).

    Returns:
        0 if every subprocess exited successfully, 1 otherwise.
    """
    # Start from an empty scratch directory so prompt pickles from a previous
    # run are never picked up.
    shutil.rmtree(TMP_DIR, ignore_errors=True)
    os.makedirs(TMP_DIR)

    failed = []
    for lang1, lang2 in lang_pairs:
        # Flush so this header appears before the child processes' output.
        print(lang1, lang2, flush=True)
        extract_cmd, generate_cmd = build_commands(lang1, lang2)
        if subprocess.run(extract_cmd).returncode != 0:
            print(
                "extraction failed for {}-{}; skipping generation".format(lang1, lang2),
                file=sys.stderr,
            )
            failed.append((lang1, lang2))
            continue
        if subprocess.run(generate_cmd).returncode != 0:
            print("generation failed for {}-{}".format(lang1, lang2), file=sys.stderr)
            failed.append((lang1, lang2))
    return 1 if failed else 0


if __name__ == "__main__":
    sys.exit(main())