# Python script (40 lines, 1.5 KiB): inference demo for a fine-tuned T5 autocorrection model.
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# 1. Load the fine-tuned model and tokenizer.
# Location of the saved fine-tuned model and tokenizer; it is passed
# unchanged to both `from_pretrained` calls below.
model_path = "T5Autocorrection228"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForSeq2SeqLM.from_pretrained(model_path)

# Pick the compute device: the CUDA GPU when one is available, else the CPU,
# then move the model's weights there.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model.to(device)


# 2. Define the function for sentence correction (inference)
def correct_sentence(sentence, max_length=128, num_beams=5):
    """Return the model's corrected version of *sentence*.

    Uses the module-level ``tokenizer``, ``model`` and ``device``.

    Args:
        sentence: The input text to correct.
        max_length: Token limit applied both to the input (longer input is
            truncated) and to the generated output. Defaults to 128, as before.
        num_beams: Beam width used for beam-search decoding. Defaults to 5.

    Returns:
        The decoded text of the best (first) generated sequence, with special
        tokens removed.
    """
    # Tokenize without padding to max_length: a single sentence needs no
    # padding, and padding would only add pad tokens the encoder must ignore.
    inputs = tokenizer(
        sentence,
        return_tensors="pt",
        max_length=max_length,
        truncation=True,
    ).to(device)

    # Forward the attention mask along with the ids so the encoder never
    # attends to padding positions (the original passed input_ids only).
    with torch.no_grad():
        outputs = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs.get("attention_mask"),
            max_length=max_length,
            num_beams=num_beams,
            early_stopping=True,
        )

    # outputs[0] is the highest-scoring beam for our single input.
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


# 3. Test the model on some example sentences
if __name__ == "__main__":
    # Guarded so that importing this module (e.g. to reuse correct_sentence)
    # does not run inference and print as a side effect.
    incorrect_sentence = "Kral"
    corrected_sentence = correct_sentence(incorrect_sentence)

    print(f"Incorrect sentence: {incorrect_sentence}")
    print(f"Corrected sentence: {corrected_sentence}")

    # To try more inputs, call correct_sentence() the same way, e.g.:
    #     print(correct_sentence("The cat are playing with the dog."))