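"""Evaluate predictions against references on several string metrics.

Reads two newline-aligned text files and reports average sentence-level
GLEU, F0.5, WER, CER, exact-match accuracy, and sentence error rate (SER).
"""
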
import collections
import sys
import time

import numpy as np

def ngram_counts(text, max_n=4):
    """Count every n-gram of length 1..max_n in a token sequence."""
    counts = collections.defaultdict(int)
    for n in range(1, max_n + 1):
        for i in range(len(text) - n + 1):
            ngram = tuple(text[i:i + n])
            counts[ngram] += 1
    return counts

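# For example, ngram_counts("a b a".split()) yields ('a',): 2, ('b',): 1,
# ('a', 'b'): 1, ('b', 'a'): 1, and ('a', 'b', 'a'): 1; a three-token
# sequence contains no 4-grams.
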
def gleu_score(reference, hypothesis, max_n=4):
    """Sentence-level GLEU (Wu et al., 2016): the minimum of n-gram
    precision and n-gram recall, using n-grams up to max_n."""
    ref_counts = ngram_counts(reference.split(), max_n)
    hyp_counts = ngram_counts(hypothesis.split(), max_n)

    # Clipped matches: each reference n-gram is credited at most as often
    # as it occurs in the hypothesis. .get avoids growing the defaultdict.
    overlap = sum(min(count, hyp_counts.get(gram, 0))
                  for gram, count in ref_counts.items())

    hyp_count_sum = sum(hyp_counts.values())
    ref_count_sum = sum(ref_counts.values())

    precision = overlap / hyp_count_sum if hyp_count_sum > 0 else 0
    recall = overlap / ref_count_sum if ref_count_sum > 0 else 0

    return min(precision, recall)

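# Sanity check: an identical hypothesis gives precision = recall = 1.0, so
# gleu_score("the cat sat", "the cat sat") == 1.0, while a hypothesis that
# shares no n-grams with the reference, e.g. "dogs bark", scores 0.0.
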
def fbeta_score(reference, hypothesis, beta=0.5, max_n=4):
    """N-gram F-beta score; the default beta=0.5 weights precision twice
    as heavily as recall."""
    ref_counts = ngram_counts(reference.split(), max_n)
    hyp_counts = ngram_counts(hypothesis.split(), max_n)

    overlap = sum(min(count, hyp_counts.get(gram, 0))
                  for gram, count in ref_counts.items())

    hyp_count_sum = sum(hyp_counts.values())
    ref_count_sum = sum(ref_counts.values())

    precision = overlap / hyp_count_sum if hyp_count_sum > 0 else 0
    recall = overlap / ref_count_sum if ref_count_sum > 0 else 0

    if precision + recall == 0:
        return 0.0
    return (1 + beta**2) * (precision * recall) / ((beta**2 * precision) + recall)

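# Worked example: with precision = 1.0 and recall = 0.5, the default
# beta = 0.5 gives (1.25 * 0.5) / (0.25 + 0.5) = 0.8333..., closer to
# precision than the balanced F1 of 0.6667 would be.
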
def edit_distance(ref, hyp):
    """Levenshtein distance between two sequences, via dynamic programming."""
    d = np.zeros((len(ref) + 1, len(hyp) + 1), dtype=np.int32)
    for i in range(len(ref) + 1):
        for j in range(len(hyp) + 1):
            if i == 0:
                d[i, j] = j  # empty ref: insert all of hyp[:j]
            elif j == 0:
                d[i, j] = i  # empty hyp: delete all of ref[:i]
            elif ref[i - 1] == hyp[j - 1]:
                d[i, j] = d[i - 1, j - 1]  # match: no edit needed
            else:
                d[i, j] = 1 + min(d[i - 1, j],      # deletion
                                  d[i, j - 1],      # insertion
                                  d[i - 1, j - 1])  # substitution
    return int(d[len(ref), len(hyp)])

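# Sanity check: edit_distance(list("kitten"), list("sitting")) == 3
# (two substitutions and one insertion), and the distance between any
# sequence and an empty one is the non-empty sequence's length.
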
def wer(reference, hypothesis):
    """Word error rate: word-level edit distance / reference word count."""
    ref_words = reference.split()
    if len(ref_words) == 0:
        return 1.0  # convention here: an empty reference counts as fully wrong
    hyp_words = hypothesis.split()
    return edit_distance(ref_words, hyp_words) / len(ref_words)

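# Example: wer("the cat sat", "the cat sat down") == 1/3, one insertion
# against a three-word reference. Note that WER can exceed 1.0 when the
# hypothesis is much longer than the reference.
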
def cer(reference, hypothesis):
    """Character error rate: character-level edit distance / reference
    length (whitespace counts as characters)."""
    ref_chars = list(reference)
    if len(ref_chars) == 0:
        return 1.0
    hyp_chars = list(hypothesis)
    return edit_distance(ref_chars, hyp_chars) / len(ref_chars)

def accuracy(refs, preds):
    """Exact-match accuracy over aligned reference/prediction pairs."""
    exact_matches = sum(1 for ref, pred in zip(refs, preds) if ref == pred)
    return exact_matches / len(refs) if len(refs) > 0 else 0

def ser(refs, preds):
    """Sentence error rate: fraction of pairs that are not exact matches."""
    sentence_errors = sum(1 for ref, pred in zip(refs, preds) if ref != pred)
    return sentence_errors / len(refs) if len(refs) > 0 else 0

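# For non-empty, equal-length inputs, ser and accuracy are complements:
# ser(refs, preds) == 1 - accuracy(refs, preds).
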
def main(target_test, target_pred):
    start_time = time.time()

    # Read newline-aligned references and predictions.
    with open(target_test) as test:
        refs = [line.strip() for line in test]
    with open(target_pred) as pred:
        preds = [line.strip() for line in pred]

    # Refuse to score misaligned files rather than crash or silently truncate.
    if len(refs) != len(preds):
        sys.exit("Error: reference and prediction files have different line counts.")

    gleu_scores = [gleu_score(ref, pred) for ref, pred in zip(refs, preds)]
    print(f"Average GLEU: {np.mean(gleu_scores) * 100:.2f}%")

    fbeta_scores = [fbeta_score(ref, pred) for ref, pred in zip(refs, preds)]
    print(f"Average F0.5 Score: {np.mean(fbeta_scores) * 100:.2f}%")

    wer_scores = [wer(ref, pred) for ref, pred in zip(refs, preds)]
    print(f"Average WER: {np.mean(wer_scores) * 100:.2f}%")

    cer_scores = [cer(ref, pred) for ref, pred in zip(refs, preds)]
    print(f"Average CER: {np.mean(cer_scores) * 100:.2f}%")

    print(f"Accuracy: {accuracy(refs, preds) * 100:.2f}%")
    print(f"SER: {ser(refs, preds) * 100:.2f}%")

    end_time = time.time()
    print(f"Execution Time: {end_time - start_time:.2f} seconds")

if __name__ == "__main__":
    if len(sys.argv) != 3:
        print("Usage: python script.py target_test target_pred")
    else:
        main(sys.argv[1], sys.argv[2])
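# Example invocation (illustrative file names; each line of the prediction
# file must align with the same line of the test file):
#   python script.py target.test target.pred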