Bakalarska_praca/tryout.py

40 lines
1.5 KiB
Python
Raw Normal View History

2024-10-29 13:08:43 +00:00
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# 1. Load the fine-tuned checkpoint (tokenizer + seq2seq model) from disk.
model_path = "T5Autocorrection228"  # directory the fine-tuned model and tokenizer were saved into
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForSeq2SeqLM.from_pretrained(model_path)

# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)
# 2. Define the function for sentence correction (inference)
def correct_sentence(sentence, *, max_length=128, num_beams=5):
    """Return the model's corrected version of *sentence*.

    Relies on the module-level ``tokenizer``, ``model`` and ``device``.

    Args:
        sentence: Input text to correct. Inputs longer than ``max_length``
            tokens are truncated.
        max_length: Token limit for both the encoded input and the generated
            output (default 128, matching the previous hard-coded value).
        num_beams: Beam width passed to ``model.generate`` (default 5).

    Returns:
        The first generated sequence, decoded with special tokens removed.
    """
    # Pad only to the longest input (a single sentence gets no padding);
    # fixed-length padding just wasted compute on pad tokens.
    inputs = tokenizer(
        sentence,
        return_tensors="pt",
        max_length=max_length,
        truncation=True,
        padding=True,
    ).to(device)

    with torch.no_grad():
        # Pass the attention mask explicitly so any padding is ignored by the
        # encoder; previously only input_ids was passed and pad tokens were
        # attended to.
        outputs = model.generate(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_length=max_length,
            num_beams=num_beams,
            early_stopping=True,
        )

    return tokenizer.decode(outputs[0], skip_special_tokens=True)
# 3. Smoke-test the model on an example input and show before/after.
incorrect_sentence = "Kral"
corrected_sentence = correct_sentence(incorrect_sentence)

report = (
    f"Incorrect sentence: {incorrect_sentence}",
    f"Corrected sentence: {corrected_sentence}",
)
print(*report, sep="\n")

# More inputs can be checked the same way, e.g.:
#     incorrect_sentence2 = "The cat are playing with the dog."
#     print(f"Corrected sentence: {correct_sentence(incorrect_sentence2)}")