97 lines
5.6 KiB
Python
97 lines
5.6 KiB
Python
import torch
|
|
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
|
import jiwer
|
|
import sacrebleu
|
|
|
|
# 1. Загрузка модели и токенизатора
|
|
model_path = "T5Autocorrection_Book2" # Укажите путь к вашей модели
|
|
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
|
model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
|
|
|
|
# Устройство (GPU/CPU)
|
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
model.to(device)
|
|
|
|
# 2. Функция для генерации предсказаний модели
|
|
def generate_correction(sentence):
|
|
inputs = tokenizer(sentence, return_tensors="pt", max_length=128, truncation=True, padding="max_length").to(device)
|
|
with torch.no_grad():
|
|
outputs = model.generate(inputs["input_ids"], max_length=128, num_beams=5, early_stopping=True)
|
|
corrected_sentence = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
|
return corrected_sentence
|
|
|
|
test_data = [
|
|
{"incorrect": "Toto je jednoduchy priklad textu.", "correct": "Toto je jednoduchý príklad textu."},
|
|
{"incorrect": "Mam rad slovensku kulturu a historiiu.", "correct": "Mám rád slovenskú kultúru a históriu."},
|
|
{"incorrect": "Dnes bol velmi krasny den na prechadzku.", "correct": "Dnes bol veľmi krásny deň na prechádzku."},
|
|
{"incorrect": "V škole sme mali prednasku o dejepise.", "correct": "V škole sme mali prednášku o dejepise."},
|
|
{"incorrect": "Chcel by som sa naucit lepsie pisat po slovensky.", "correct": "Chcel by som sa naučiť lepšie písať po slovensky."},
|
|
{"incorrect": "Zahrada je plna kvetou a voni.", "correct": "Záhrada je plná kvetov a vôní."},
|
|
{"incorrect": "Potrebujem pomoc s mojou domácou ulohou.", "correct": "Potrebujem pomoc s mojou domácou úlohou."},
|
|
{"incorrect": "Kazdy den sa ucime nove veci.", "correct": "Každý deň sa učíme nové veci."},
|
|
{"incorrect": "Moja oblubena kniha sa nazyva 'Pýcha a predsudok'.", "correct": "Moja obľúbená kniha sa nazýva 'Pýcha a predsudok'."},
|
|
{"incorrect": "Som velmi stastny, ze zijem na Slovensku.", "correct": "Som veľmi šťastný, že žijem na Slovensku."},
|
|
{"incorrect": "Prajem ti krasny a uspesny den.", "correct": "Prajem ti krásny a úspešný deň."},
|
|
{"incorrect": "Moje oblubene jedlo je bryndzove halusky.", "correct": "Moje obľúbené jedlo je bryndzové halušky."},
|
|
{"incorrect": "Tento vikend pojdeme na vylet do Tater.", "correct": "Tento víkend pôjdeme na výlet do Tatier."},
|
|
{"incorrect": "Kazdy rok chodime na dovolenku k moru.", "correct": "Každý rok chodíme na dovolenku k moru."},
|
|
{"incorrect": "Na stole je kniha o slovenskej literatur.", "correct": "Na stole je kniha o slovenskej literatúre."},
|
|
{"incorrect": "Ucim sa po slovensky kazdy den.", "correct": "Učím sa po slovensky každý deň."},
|
|
{"incorrect": "Vcera som videl nadherny zapad slnka.", "correct": "Včera som videl nádherný západ slnka."},
|
|
{"incorrect": "Moje oblubene mesto je Bratislava.", "correct": "Moje obľúbené mesto je Bratislava."},
|
|
{"incorrect": "Dnes sme mali skvelu prednasku o umeni.", "correct": "Dnes sme mali skvelú prednášku o umení."},
|
|
{"incorrect": "Chcem sa naucit viac o slovenskej historii.", "correct": "Chcem sa naučiť viac o slovenskej histórii."},
|
|
{"incorrect": "Slovensko je krasna krajina s bohatou kulturou.", "correct": "Slovensko je krásna krajina s bohatou kultúrou."},
|
|
{"incorrect": "Zajtra planujeme ist na turu do hor.", "correct": "Zajtra plánujeme ísť na túru do hôr."},
|
|
{"incorrect": "Moje oblubene miesto na relax je pri jazere.", "correct": "Moje obľúbené miesto na relax je pri jazere."},
|
|
{"incorrect": "Rano som si dal kavu s mliekom.", "correct": "Ráno som si dal kávu s mliekom."},
|
|
{"incorrect": "Na dovolenke sme navstivili historicke pamiatky.", "correct": "Na dovolenke sme navštívili historické pamiatky."},
|
|
{"incorrect": "Vcera sme isli na vylet do lesa.", "correct": "Včera sme išli na výlet do lesa."},
|
|
{"incorrect": "V praci mame velmi dobry kolektiv.", "correct": "V práci máme veľmi dobrý kolektív."},
|
|
{"incorrect": "Na obed sme mali typicke slovenske jedlo.", "correct": "Na obed sme mali typické slovenské jedlo."},
|
|
{"incorrect": "Dnes vecer pojdeme na koncert do mesta.", "correct": "Dnes večer pôjdeme na koncert do mesta."},
|
|
{"incorrect": "Mam rad slovenske ludove piesne.", "correct": "Mám rád slovenské ľudové piesne."}
|
|
]
|
|
|
|
|
|
|
|
# Списки для хранения истинных и предсказанных значений
|
|
references = []
|
|
predictions = []
|
|
|
|
# 4. Генерация предсказаний для тестового набора
|
|
for item in test_data:
|
|
incorrect_sentence = item['incorrect']
|
|
correct_sentence_reference = item['correct']
|
|
|
|
# Генерация предсказания
|
|
prediction = generate_correction(incorrect_sentence) # Вызов функции generate_correction
|
|
references.append(correct_sentence_reference)
|
|
predictions.append(prediction)
|
|
|
|
# 5. Оценка метрик
|
|
|
|
# WER (Word Error Rate)
|
|
wer = jiwer.wer(references, predictions)
|
|
print(f"WER: {wer}")
|
|
|
|
# CER (Character Error Rate)
|
|
cer = jiwer.cer(references, predictions)
|
|
print(f"CER: {cer}")
|
|
|
|
# SER (Sentence Error Rate)
|
|
def calculate_ser(references, predictions):
|
|
errors = 0
|
|
for ref, pred in zip(references, predictions):
|
|
if ref != pred:
|
|
errors += 1
|
|
ser = errors / len(references)
|
|
return ser
|
|
|
|
ser = calculate_ser(references, predictions)
|
|
print(f"SER: {ser}")
|
|
# BLEU (Bilingual Evaluation Understudy Score)
|
|
bleu = sacrebleu.corpus_bleu(predictions, [references])
|
|
print(f"BLEU: {bleu.score}")
|
|
|