Commit 146002c

Merge pull request #221 from EvolvingLMMs-Lab/dev/metric_fix
[Fix] Fix bugs in returning result dict and bring back anls metric
Luodian authored Sep 3, 2024
2 parents 2be039d + 0126968 commit 146002c
Showing 2 changed files with 47 additions and 14 deletions.
48 changes: 47 additions & 1 deletion lmms_eval/api/metrics.py
@@ -275,6 +275,52 @@ def bits_per_byte_fn(items): # This is a passthrough function
     return items
 
 
+def levenshtein_distance(s1, s2):
+    if len(s1) > len(s2):
+        s1, s2 = s2, s1
+
+    distances = range(len(s1) + 1)
+    for i2, c2 in enumerate(s2):
+        distances_ = [i2 + 1]
+        for i1, c1 in enumerate(s1):
+            if c1 == c2:
+                distances_.append(distances[i1])
+            else:
+                distances_.append(1 + min((distances[i1], distances[i1 + 1], distances_[-1])))
+        distances = distances_
+    return distances[-1]
+
+
+@register_metric(
+    metric="anls",
+    higher_is_better=True,
+    output_type="generate_until",
+    aggregation="mean",
+)
+def anls(
+    references,
+    predictions,
+    thresh_hold=0.5,
+):
+    """https://github.com/QwenLM/Qwen-VL/blob/master/eval_mm/infographicsvqa_eval.py"""
+    values = []
+    for answer in references:
+        # preprocess both the answers - gt and prediction
+        gt_answer = " ".join(answer.strip().lower().split())
+        det_answer = " ".join(predictions[0].strip().lower().split())
+
+        # dist = levenshtein_distance(answer.lower(), detObject['answer'].lower())
+        dist = levenshtein_distance(gt_answer, det_answer)
+        length = max(len(answer.upper()), len(predictions[0].upper()))
+        values.append(0.0 if length == 0 else float(dist) / float(length))
+
+    question_result = 1 - min(values)
+
+    if question_result < thresh_hold:
+        question_result = 0
+    return {"anls": question_result}
+
+
 def pop_stddev(arr):
     mu = mean(arr)
     return math.sqrt(sum([(x - mu) ** 2 for x in arr]) / len(arr))
@@ -296,7 +342,7 @@ def mean_stderr(arr):
     aggregation="bypass",
 )
 def bypass(items):
-    return None
+    return items
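For reference, a minimal sketch of how the restored metric behaves (assuming both functions are imported from lmms_eval.api.metrics; the example strings and printed score are illustrative):

from lmms_eval.api.metrics import anls, levenshtein_distance

# Plain edit distance between two raw strings.
assert levenshtein_distance("kitten", "sitting") == 3

# ANLS normalizes both sides (strip, lowercase, collapse whitespace),
# scores the prediction against the closest reference by normalized
# edit distance, and zeroes out anything below thresh_hold (0.5).
score = anls(references=["the eiffel tower", "eiffel tower"], predictions=["the eifel tower"])
print(score)  # {'anls': 0.9375}, i.e. 1 - 1/16 against the closest reference

The second hunk is the other half of the commit title: bypass now returns items instead of None, so a "bypass" aggregation passes per-sample results through unchanged instead of silently dropping them from the result dict.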
13 changes: 0 additions & 13 deletions lmms_eval/evaluator.py
@@ -607,19 +607,6 @@ def evaluate(
         else:
             results_dict = None
 
-        with open(f"{cli_args.output_path}/rank{int(os.environ.get('RANK', 0))}_metric_eval_done.txt", "w") as f:
-            f.write(f"rank {int(os.environ.get('RANK', 0))} eval done")
-        while len([file for file in os.listdir(cli_args.output_path) if file.endswith("metric_eval_done.txt")]) < lm._world_size:
-            time.sleep(1)
-
     else:
         return None
 
-    with open(f"{cli_args.output_path}/rank{int(os.environ.get('RANK', 0))}_metric_eval_done.txt", "w") as f:
-        f.write(f"rank {int(os.environ.get('RANK', 0))} eval done")
-    while len([file for file in os.listdir(cli_args.output_path) if file.endswith("metric_eval_done.txt")]) < lm._world_size:
-        time.sleep(1)
-
     lm.accelerator.wait_for_everyone()
     return results_dict
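The deleted lines were a hand-rolled cross-rank barrier, duplicated in two places: each rank wrote a rank{RANK}_metric_eval_done.txt marker and then polled the output directory until lm._world_size markers appeared. The commit drops both copies and relies on the lm.accelerator.wait_for_everyone() call that follows. A minimal sketch of the two styles, assuming a Hugging Face Accelerate setup (the file_barrier helper is hypothetical, for illustration only):

import os
import time

from accelerate import Accelerator

def file_barrier(output_path: str, world_size: int) -> None:
    # Old style (removed): signal completion with a marker file, then poll
    # the directory until every rank has written its own marker.
    rank = int(os.environ.get("RANK", 0))
    with open(f"{output_path}/rank{rank}_metric_eval_done.txt", "w") as f:
        f.write(f"rank {rank} eval done")
    while len([name for name in os.listdir(output_path) if name.endswith("metric_eval_done.txt")]) < world_size:
        time.sleep(1)

# New style: a single collective barrier, with no filesystem polling.
accelerator = Accelerator()
accelerator.wait_for_everyone()

The collective call also avoids the failure mode where leftover *_metric_eval_done.txt files from a previous run satisfy the polling loop prematurely.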
