Bakalarska_praca/tryout.py
2024-11-10 12:40:38 +01:00

41 lines
1.7 KiB
Python

import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# 1. Load the fine-tuned seq2seq model and its tokenizer.
#    Both are expected to be saved directories produced by save_pretrained().
model_path = "T5Autocorrection_Book"  # directory with the fine-tuned model weights/config
tokenizer_path = "T5TokenizerAutocorrection_Book"  # directory with the saved tokenizer

tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
model = AutoModelForSeq2SeqLM.from_pretrained(model_path)

# Prefer the GPU when CUDA is available; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Move the model's parameters to the chosen device so they match the inputs
# that correct_sentence() sends to the same device.
model.to(device)
# 2. Define the function for sentence correction (inference).
def correct_sentence(sentence, *, max_length=128, num_beams=5):
    """Return the model's corrected version of ``sentence``.

    Uses the module-level ``tokenizer``, ``model`` and ``device``.

    Args:
        sentence: The input text to correct.
        max_length: Maximum number of tokens for both the (truncated) input
            and the generated output. Defaults to 128.
        num_beams: Beam width used by ``model.generate``. Defaults to 5.

    Returns:
        The decoded first generated sequence, with special tokens removed.
    """
    # A single sentence needs no padding. Keep the tokenizer's attention
    # mask and forward it to generate(); otherwise transformers must guess
    # which positions are padding.
    encoded = tokenizer(
        sentence,
        return_tensors="pt",
        max_length=max_length,
        truncation=True,  # longer inputs are silently cut to max_length tokens
    ).to(device)

    # Inference only: disable gradient tracking to save memory and time.
    with torch.no_grad():
        outputs = model.generate(
            input_ids=encoded["input_ids"],
            attention_mask=encoded["attention_mask"],
            max_length=max_length,
            num_beams=num_beams,
            early_stopping=True,
        )

    # Decode the first (best) returned sequence into plain text.
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
# 3. Try the model on example sentences.
#    The example is Slovak text with missing diacritics; append further
#    sentences to this tuple to test them as well.
example_sentences = (
    "Moja oblubena kniha sa nazýva 'Pýcha a predsudok', ale nemam vela casu na čítanie kvoli školskej práci a rodine.",
)

for incorrect_sentence in example_sentences:
    corrected_sentence = correct_sentence(incorrect_sentence)
    print(f"Incorrect sentence: {incorrect_sentence}")
    print(f"Corrected sentence: {corrected_sentence}")