Bakalarska_praca/trainingscript.py

from transformers import T5ForConditionalGeneration, T5Tokenizer, Trainer, TrainingArguments
from datasets import load_dataset

# Fine-tune mT5 on a "before -> after" text-correction dataset
model_name = "google/mt5-base"
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)
def preprocess_function(examples):
    before_list = []
    after_list = []
    for ex in examples["before after"]:
        if ex is not None:
            splits = ex.split(" before after ")
            before_list.append(splits[0] if len(splits) == 2 else ex)
            after_list.append(splits[1] if len(splits) == 2 else '')
        else:
            before_list.append('')
            after_list.append('')
    # Tokenize with a length limit
    model_inputs = tokenizer(before_list, padding="max_length", truncation=True, max_length=512)
    labels = tokenizer(after_list, padding="max_length", truncation=True, max_length=512)
    # Replace padding token ids in the labels with -100 so they are ignored by the loss
    model_inputs["labels"] = [
        [(token if token != tokenizer.pad_token_id else -100) for token in label]
        for label in labels["input_ids"]
    ]
    return model_inputs
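
# The loader below reads converted.csv into a single "before after" column; judging
# from the split above, each row packs source and target into one string separated
# by the literal marker " before after ". Hypothetical example row (illustration only):
#   "teh qick fox before after the quick fox"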
dataset = load_dataset("csv", data_files={"train": "converted.csv"}, delimiter=" ", column_names=["before after"])
tokenized_datasets = dataset.map(preprocess_function, batched=True)
training_args = TrainingArguments(
    output_dir="./results2",
    evaluation_strategy="no",  # no eval_dataset is passed to the Trainer below
    save_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    num_train_epochs=1,
    weight_decay=0.01,
    gradient_accumulation_steps=64,  # effective batch size of 64 with per-device batch 1
    fp16=True,
)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    tokenizer=tokenizer,
)
trainer.train()
# Make all parameter tensors contiguous before saving
for param in model.parameters():
    param.data = param.data.contiguous()
model.save_pretrained("T5Autocorrection", safe_serialization=False)  # disable safetensors for plain saving
tokenizer.save_pretrained("T5TokenizerAutocorrection")
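
# Usage sketch (an assumption, not part of the original training run): reload the
# saved model and tokenizer and correct a single sentence. Inputs are tokenized the
# same way as in preprocess_function, i.e. the raw text with no task prefix; the
# generation settings are illustrative defaults.
def correct(text: str) -> str:
    tok = T5Tokenizer.from_pretrained("T5TokenizerAutocorrection")
    mdl = T5ForConditionalGeneration.from_pretrained("T5Autocorrection")
    inputs = tok(text, return_tensors="pt", truncation=True, max_length=512)
    output_ids = mdl.generate(**inputs, max_length=512)
    return tok.decode(output_ids[0], skip_special_tokens=True)

# Hypothetical call (illustration only):
# print(correct("teh qick fox"))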