diff --git a/trainingscript.py b/trainingscript.py index 4662d38..1c288e4 100644 --- a/trainingscript.py +++ b/trainingscript.py @@ -1,7 +1,8 @@ from transformers import T5ForConditionalGeneration, T5Tokenizer, Trainer, TrainingArguments from datasets import load_dataset -model_name = "t5-base" + +model_name = "google/mt5-base" tokenizer = T5Tokenizer.from_pretrained(model_name) model = T5ForConditionalGeneration.from_pretrained(model_name) @@ -11,37 +12,38 @@ def preprocess_function(examples): for ex in examples["before after"]: if ex is not None: splits = ex.split(" before after ") - if len(splits) == 2: - before_list.append(splits[0]) - after_list.append(splits[1]) - else: - before_list.append(ex) - after_list.append('') + before_list.append(splits[0] if len(splits) == 2 else ex) + after_list.append(splits[1] if len(splits) == 2 else '') else: before_list.append('') after_list.append('') - model_inputs = tokenizer(before_list, padding="max_length", truncation=True) - labels = tokenizer(after_list, padding="max_length", truncation=True) + # Токенизация с ограничением по длине + model_inputs = tokenizer(before_list, padding="max_length", truncation=True, max_length=512) + labels = tokenizer(after_list, padding="max_length", truncation=True, max_length=512) model_inputs["labels"] = labels["input_ids"] return model_inputs - dataset = load_dataset("csv", data_files={"train": "converted.csv"}, delimiter=" ", column_names=["before after"]) + tokenized_datasets = dataset.map(preprocess_function, batched=True) + training_args = TrainingArguments( - output_dir="./results1", + output_dir="./results2", evaluation_strategy="epoch", save_strategy="epoch", learning_rate=2e-5, - per_device_train_batch_size=64, - per_device_eval_batch_size=64, + per_device_train_batch_size=1, + per_device_eval_batch_size=1, num_train_epochs=1, weight_decay=0.01, + gradient_accumulation_steps=64, + fp16=True, ) + trainer = Trainer( model=model, args=training_args, @@ -49,6 +51,13 @@ trainer = Trainer( tokenizer=tokenizer, ) + trainer.train() -model.save_pretrained("T5Autocorrection") -tokenizer.save_pretrained("T5TokenizerAutocorrection") + + +for param in model.parameters(): + param.data = param.data.contiguous() + + +model.save_pretrained("T5Autocorrection", safe_serialization=False) # Отключаем safetensors для простого сохранения +tokenizer.save_pretrained("T5TokenizerAutocorrection") \ No newline at end of file