import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# 1. Load the fine-tuned model and tokenizer
model_path = "T5Autocorrection_Book"  # Path where the fine-tuned model is saved
tokenizer_path = "T5TokenizerAutocorrection_Book"  # Path where the tokenizer is saved
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
model = AutoModelForSeq2SeqLM.from_pretrained(model_path)

# Set device (use GPU if available)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)


# 2. Define the function for sentence correction (inference)
def correct_sentence(sentence):
    # Tokenize the input sentence
    inputs = tokenizer(sentence, return_tensors="pt", max_length=128, truncation=True, padding="max_length").to(device)

    # Generate prediction (corrected sentence); pass the attention mask so padding is ignored
    with torch.no_grad():
        outputs = model.generate(inputs["input_ids"], attention_mask=inputs["attention_mask"], max_length=128, num_beams=5, early_stopping=True)

    # Decode the output ids to a string
    corrected_sentence = tokenizer.decode(outputs[0], skip_special_tokens=True)

    return corrected_sentence
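

# Optional batch variant (a sketch, not part of the original script): it reuses the tokenizer,
# model and device defined above, and the helper name correct_sentences is made up here.
# Tokenizing a list of sentences into one padded batch and calling generate once is usually
# faster on GPU than looping over correct_sentence().
def correct_sentences(sentences):
    # Tokenize all sentences together; pad dynamically to the longest sentence in the batch
    inputs = tokenizer(sentences, return_tensors="pt", max_length=128, truncation=True, padding=True).to(device)

    # Generate corrections for the whole batch
    with torch.no_grad():
        outputs = model.generate(inputs["input_ids"], attention_mask=inputs["attention_mask"], max_length=128, num_beams=5, early_stopping=True)

    # Decode every sequence in the batch back to text
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)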


# 3. Test the model on some example sentences
# Slovak sentence with missing diacritics ("My favourite book is called 'Pride and Prejudice',
# but I don't have much time to read because of schoolwork and family.")
incorrect_sentence = "Moja oblubena kniha sa nazýva 'Pýcha a predsudok', ale nemam vela casu na čítanie kvoli školskej práci a rodine."
corrected_sentence = correct_sentence(incorrect_sentence)
print(f"Incorrect sentence: {incorrect_sentence}")
print(f"Corrected sentence: {corrected_sentence}")

# You can test more sentences like this:
# incorrect_sentence2 = "The cat are playing with the dog."
# corrected_sentence2 = correct_sentence(incorrect_sentence2)
# print(f"Incorrect sentence: {incorrect_sentence2}")
# print(f"Corrected sentence: {corrected_sentence2}")