# bkpc/load.py


import os

# Pin inference to a single GPU and disable Weights & Biases logging.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
os.environ['WANDB_DISABLED'] = 'true'

import torch
from tqdm import tqdm
from transformers import ByT5Tokenizer, T5ForConditionalGeneration

def load_model(model_path):
    """Load the fine-tuned ByT5 tokenizer and model, moving the model to GPU if available."""
    tokenizer = ByT5Tokenizer.from_pretrained(model_path)
    model = T5ForConditionalGeneration.from_pretrained(model_path)
    if torch.cuda.is_available():
        model = model.cuda()
    return tokenizer, model
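
# Optional memory-saving variant (a sketch, not part of the original script,
# assuming a CUDA GPU): loading the weights in float16 roughly halves GPU
# memory; torch_dtype is a standard from_pretrained argument in transformers.
# model = T5ForConditionalGeneration.from_pretrained(model_path, torch_dtype=torch.float16).cuda()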

def correct_sentence(tokenizer, model, sentence):
    """Run a single sentence through the model and return the decoded correction."""
    # ByT5 tokenizes at the byte level, so max_length counts bytes, not words.
    inputs = tokenizer(sentence, return_tensors="pt", padding=True, truncation=True, max_length=4096)
    if torch.cuda.is_available():
        inputs = {k: v.cuda() for k, v in inputs.items()}
    output_sequences = model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_length=4096,
    )
    corrected = tokenizer.decode(output_sequences[0], skip_special_tokens=True)
    return corrected
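
# A batched variant (a minimal sketch, not part of the original script):
# generating for several sentences at once amortizes per-call GPU overhead.
# The helper name `correct_sentences_batch` and the batch_size default are
# assumptions, not from the original file.
def correct_sentences_batch(tokenizer, model, sentences, batch_size=8):
    """Correct a list of sentences in fixed-size batches; returns corrections in input order."""
    corrected = []
    for start in range(0, len(sentences), batch_size):
        batch = sentences[start:start + batch_size]
        # padding=True pads the batch to its longest member so it forms one tensor.
        inputs = tokenizer(batch, return_tensors="pt", padding=True, truncation=True, max_length=4096)
        if torch.cuda.is_available():
            inputs = {k: v.cuda() for k, v in inputs.items()}
        output_sequences = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_length=4096,
        )
        # batch_decode decodes the whole batch at once, skipping pad/special tokens.
        corrected.extend(tokenizer.batch_decode(output_sequences, skip_special_tokens=True))
    return corrected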

def process_and_save_corrections(input_file_path, output_file_path, tokenizer, model):
    """Correct every non-empty line of the input file and write the results line by line."""
    with open(input_file_path, 'r', encoding='utf-8') as input_file, \
            open(output_file_path, 'w', encoding='utf-8') as output_file:
        sentences = input_file.readlines()
        for sentence in tqdm(sentences, desc="Processing sentences"):
            sentence = sentence.strip()
            if sentence:
                corrected = correct_sentence(tokenizer, model, sentence)
                output_file.write(corrected + "\n")
                output_file.flush()  # Flush so partial results survive an interrupted run.

if __name__ == "__main__":
    model_path = "./fine_tuned_model"
    input_file_path = "./test_incorrect.txt"
    output_file_path = "./test_correct_model.txt"
    tokenizer, model = load_model(model_path)
    process_and_save_corrections(input_file_path, output_file_path, tokenizer, model)
    print("Correction process completed. Corrected sentences saved to", output_file_path)