commit 5eac948712f94d41dda427593a8200bbcd114358
Author: Andrii Pervashov
Date:   Fri Aug 16 14:35:54 2024 +0000

    Add trainingscript.py

diff --git a/trainingscript.py b/trainingscript.py
new file mode 100644
index 0000000..24d0467
--- /dev/null
+++ b/trainingscript.py
@@ -0,0 +1,58 @@
+from datasets import load_dataset
+from transformers import T5Tokenizer, T5ForConditionalGeneration, Trainer, TrainingArguments
+
+# Initialize the tokenizer (T5Tokenizer requires the sentencepiece package)
+model_name = "t5-small"
+tokenizer = T5Tokenizer.from_pretrained(model_name)
+
+# Load the dataset with the specific configuration
+dataset = load_dataset("wiki_atomic_edits", "english_insertions", trust_remote_code=True)
+
+# Inspect the dataset splits
+print(dataset.keys())  # wiki_atomic_edits ships only a "train" split
+
+# Carve a validation split out of the train split for epoch-level evaluation
+dataset = dataset["train"].train_test_split(test_size=0.1, seed=42)
+
+# Preprocessing function: the model learns to map base_sentence -> edited_sentence
+def preprocess_function(examples):
+    inputs = examples["base_sentence"]
+    targets = examples["edited_sentence"]
+    model_inputs = tokenizer(inputs, max_length=128, truncation=True, padding="max_length")
+    labels = tokenizer(targets, max_length=128, truncation=True, padding="max_length")
+    # Replace padding token ids in the labels with -100 so the loss ignores them
+    labels["input_ids"] = [
+        [(label if label != tokenizer.pad_token_id else -100) for label in labels_example]
+        for labels_example in labels["input_ids"]
+    ]
+    model_inputs["labels"] = labels["input_ids"]
+    return model_inputs
+
+# Apply the preprocessing function to the dataset
+tokenized_datasets = dataset.map(preprocess_function, batched=True)
+
+# Initialize the model
+model = T5ForConditionalGeneration.from_pretrained(model_name)
+
+# Set up training arguments
+training_args = TrainingArguments(
+    output_dir="./results",
+    evaluation_strategy="epoch",  # renamed to eval_strategy in newer transformers releases
+    learning_rate=2e-5,
+    per_device_train_batch_size=4,
+    per_device_eval_batch_size=4,
+    num_train_epochs=3,
+    weight_decay=0.01,
+    logging_dir="./logs",
+)
+
+# Initialize Trainer
+trainer = Trainer(
+    model=model,
+    args=training_args,
+    train_dataset=tokenized_datasets["train"],
+    eval_dataset=tokenized_datasets["test"],  # split created by train_test_split above
+)
+
+# Start training
+trainer.train()
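
A quick way to sanity-check the fine-tuned model once this script finishes: Trainer writes checkpoints under ./results (the output_dir above) as checkpoint-<step> directories. The sketch below loads one and generates an edited sentence. The checkpoint-500 path and the example sentence are hypothetical placeholders, and the tokenizer is reloaded from t5-small because the script never passes it to Trainer, so it is not saved alongside the model weights.

from transformers import T5Tokenizer, T5ForConditionalGeneration

# Hypothetical checkpoint directory: Trainer names checkpoints checkpoint-<global_step>,
# so the exact number depends on the run; adjust the path to a directory that exists.
checkpoint = "./results/checkpoint-500"

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained(checkpoint)

# Made-up base sentence; the model should propose an insertion edit
base_sentence = "The cat sat on the mat."
inputs = tokenizer(base_sentence, return_tensors="pt", max_length=128, truncation=True)
outputs = model.generate(**inputs, max_length=128, num_beams=4)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))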