50 lines
2.1 KiB
Python
50 lines
2.1 KiB
Python
import os
|
|
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
|
|
os.environ['WANDB_DISABLED'] = 'true'
|
|
import torch
|
|
from tqdm import tqdm
|
|
from transformers import T5Tokenizer, T5ForConditionalGeneration
|
|
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
|
from transformers import ByT5Tokenizer, T5ForConditionalGeneration, Trainer, TrainingArguments
|
|
from sklearn.model_selection import train_test_split
|
|
|
|
|
|
def load_model(model_path):
    """Load the tokenizer and seq2seq model stored at ``model_path``.

    The tokenizer is a ``ByT5Tokenizer`` and the model a
    ``T5ForConditionalGeneration``; both are read from the same path.
    When CUDA is available the model is moved to the GPU, otherwise it is
    returned exactly as ``from_pretrained`` produced it.

    Args:
        model_path: Directory or identifier accepted by ``from_pretrained``.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    # Tokenizer first, then weights: a bad path fails before the heavy load.
    tokenizer = ByT5Tokenizer.from_pretrained(model_path)
    seq2seq = T5ForConditionalGeneration.from_pretrained(model_path)

    on_gpu = torch.cuda.is_available()
    return tokenizer, (seq2seq.cuda() if on_gpu else seq2seq)
|
|
|
|
def correct_sentence(tokenizer, model, sentence, max_length=4096):
    """Run the seq2seq model on one sentence and return the decoded output.

    Args:
        tokenizer: Callable tokenizer returning a mapping that contains
            ``input_ids`` and ``attention_mask`` tensors; its ``decode``
            method is used to turn the generated ids back into text.
        model: Object exposing ``generate``. Inputs are moved to
            ``model.device`` when that attribute exists.
        sentence: Text to correct.
        max_length: Limit passed both to the tokenizer (truncation) and to
            ``generate``. Defaults to 4096, the previously hard-coded value.

    Returns:
        The first generated sequence, decoded with special tokens skipped.
    """
    encoded = tokenizer(
        sentence,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_length,
    )

    # Follow the model's actual device instead of assuming that "CUDA is
    # available" means the model lives on the GPU. Objects without a
    # ``device`` attribute fall back to the old availability heuristic.
    device = getattr(model, "device", None)
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    encoded = {name: tensor.to(device) for name, tensor in encoded.items()}

    output_sequences = model.generate(
        input_ids=encoded["input_ids"],
        attention_mask=encoded["attention_mask"],
        max_length=max_length,
    )

    # Only one sentence is submitted, so the first sequence is the answer.
    return tokenizer.decode(output_sequences[0], skip_special_tokens=True)
|
|
|
|
def process_and_save_corrections(input_file_path, output_file_path, tokenizer, model):
    """Correct each non-blank line of a text file and write the results.

    Every line of ``input_file_path`` is stripped of surrounding whitespace;
    lines that are empty after stripping are skipped and produce no output
    line. Each remaining line is passed to ``correct_sentence`` and the
    result is written to ``output_file_path`` followed by a newline.

    Args:
        input_file_path: UTF-8 text file with one sentence per line.
        output_file_path: Destination file; it is opened in write mode, so
            any existing content is replaced.
        tokenizer: Tokenizer forwarded to ``correct_sentence``.
        model: Model forwarded to ``correct_sentence``.
    """
    with open(input_file_path, 'r', encoding='utf-8') as input_file:
        with open(output_file_path, 'w', encoding='utf-8') as output_file:
            # Read everything up front so the progress bar knows the total.
            raw_lines = input_file.readlines()

            for raw_line in tqdm(raw_lines, desc="Processing sentences"):
                text = raw_line.strip()
                if not text:
                    continue

                fixed = correct_sentence(tokenizer, model, text)
                output_file.write(fixed + "\n")
                # Flush per sentence so finished work survives an interruption.
                output_file.flush()
|
|
|
|
if __name__ == "__main__":
    # Fixed locations of the fine-tuned checkpoint, the sentences to correct
    # and the file that receives the model's output.
    model_path = "./fine_tuned_model"
    input_file_path = "./test_incorrect.txt"
    output_file_path = "./test_correct_model.txt"

    # load_model returns (tokenizer, model), which are exactly the trailing
    # arguments of process_and_save_corrections.
    process_and_save_corrections(
        input_file_path,
        output_file_path,
        *load_model(model_path),
    )

    print("Correction process completed. Corrected sentences saved to", output_file_path)
|