import os

# Pin the job to the first GPU and disable Weights & Biases logging.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
os.environ['WANDB_DISABLED'] = 'true'

import torch
from tqdm import tqdm
from transformers import ByT5Tokenizer, T5ForConditionalGeneration


def load_model(model_path):
    """Load the fine-tuned ByT5 tokenizer and model, moving the model to GPU if available."""
    tokenizer = ByT5Tokenizer.from_pretrained(model_path)
    model = T5ForConditionalGeneration.from_pretrained(model_path)
    if torch.cuda.is_available():
        model = model.cuda()
    return tokenizer, model


def correct_sentence(tokenizer, model, sentence):
    """Run a single sentence through the model and return the corrected text."""
    inputs = tokenizer(sentence, return_tensors="pt", padding=True, truncation=True, max_length=4096)
    if torch.cuda.is_available():
        inputs = {k: v.cuda() for k, v in inputs.items()}
    output_sequences = model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_length=4096,
    )
    return tokenizer.decode(output_sequences[0], skip_special_tokens=True)


def process_and_save_corrections(input_file_path, output_file_path, tokenizer, model):
    """Correct every non-empty line of the input file and write the results line by line."""
    with open(input_file_path, 'r', encoding='utf-8') as input_file, \
            open(output_file_path, 'w', encoding='utf-8') as output_file:
        sentences = input_file.readlines()
        for sentence in tqdm(sentences, desc="Processing sentences"):
            sentence = sentence.strip()
            if sentence:
                corrected = correct_sentence(tokenizer, model, sentence)
                output_file.write(corrected + "\n")
                output_file.flush()  # keep partial results on disk in case the run is interrupted


if __name__ == "__main__":
    model_path = "./fine_tuned_model"
    input_file_path = "./test_incorrect.txt"
    output_file_path = "./test_correct_model.txt"

    tokenizer, model = load_model(model_path)
    process_and_save_corrections(input_file_path, output_file_path, tokenizer, model)
    print("Correction process completed. Corrected sentences saved to", output_file_path)