Bakalar_work/train_ner.py


import os
import numpy as np
import torch
from datasets import load_dataset, concatenate_datasets, Sequence, Value
from sklearn.metrics import precision_recall_fscore_support
from transformers import (
    AutoTokenizer,
    AutoModelForTokenClassification,
    Trainer,
    TrainingArguments,
    set_seed,
)
from peft import LoraConfig, get_peft_model, TaskType

set_seed(42)
print("Loading WikiANN and conll2003-SK-NER datasets...")
try:
    # WikiANN Slovak: hold out the last 10% of the train split for validation
    sk_wikiann_train = load_dataset("unimelb-nlp/wikiann", "sk", split="train[:90%]")
    sk_wikiann_val = load_dataset("unimelb-nlp/wikiann", "sk", split="train[90%:]")
    # conll2003-SK-NER
    conll_sk = load_dataset("ju-bezdek/conll2003-SK-NER", split="train")

    # Check label sets
    wikiann_labels = sk_wikiann_train.features["ner_tags"].feature.names
    conll_labels = conll_sk.features["ner_tags"].feature.names
    print("WikiANN label names:", wikiann_labels)
    print("Conll2003-SK label names:", conll_labels)

    # Unified label set shared by both corpora
    full_label_list = sorted(set(wikiann_labels + conll_labels))
    label2id = {label: i for i, label in enumerate(full_label_list)}
    id2label = {i: label for label, i in label2id.items()}
    num_labels = len(full_label_list)

    # Remap every dataset's integer tags onto the unified label ids.
    # Both corpora need this: the sorted unified ordering differs from the original
    # class orderings of both WikiANN and conll2003-SK-NER.
    def remap_labels(example, source_labels):
        example["ner_tags"] = [label2id[source_labels[i]] for i in example["ner_tags"]]
        return example

    conll_sk = conll_sk.map(remap_labels, fn_kwargs={"source_labels": conll_labels})
    sk_wikiann_train = sk_wikiann_train.map(remap_labels, fn_kwargs={"source_labels": wikiann_labels})
    sk_wikiann_val = sk_wikiann_val.map(remap_labels, fn_kwargs={"source_labels": wikiann_labels})

    # Cast ner_tags to plain int64 so the two different ClassLabel schemas align
    conll_sk = conll_sk.cast_column("ner_tags", Sequence(feature=Value(dtype="int64")))
    sk_wikiann_train = sk_wikiann_train.cast_column("ner_tags", Sequence(feature=Value(dtype="int64")))
    sk_wikiann_val = sk_wikiann_val.cast_column("ner_tags", Sequence(feature=Value(dtype="int64")))

    # Keep only the columns shared by both corpora so they can be concatenated
    keep_columns = ["tokens", "ner_tags"]
    conll_sk = conll_sk.remove_columns([c for c in conll_sk.column_names if c not in keep_columns])
    sk_wikiann_train = sk_wikiann_train.remove_columns([c for c in sk_wikiann_train.column_names if c not in keep_columns])
    sk_wikiann_val = sk_wikiann_val.remove_columns([c for c in sk_wikiann_val.column_names if c not in keep_columns])

    # Combine the two training sets
    train_combined = concatenate_datasets([sk_wikiann_train, conll_sk])
except Exception as e:
    print(f"Failed to load datasets: {e}")
    exit(1)
model_name = "gerulata/slovakbert"
print("Loading model and tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_name, add_prefix_space=True)
base_model = AutoModelForTokenClassification.from_pretrained(
    model_name,
    num_labels=num_labels,
    id2label=id2label,
    label2id=label2id,
)
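# Wrap the encoder with LoRA adapters on the attention "query" and "value" projections
# (module names as used by RoBERTa-style models such as SlovakBERT).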
lora_config = LoraConfig(
    task_type=TaskType.TOKEN_CLS,
    r=8,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["query", "value"],
)
model = get_peft_model(base_model, lora_config)
model.print_trainable_parameters()
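# Tokenize pre-split words and align word-level NER tags to subword tokens:
# special tokens get label -100 (ignored by the loss); every subword of a word
# inherits that word's tag.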
def tokenize_and_align_labels(batch):
    tokenized_inputs = tokenizer(
        batch["tokens"],
        padding="max_length",
        truncation=True,
        max_length=256,
        is_split_into_words=True,
    )
    labels = []
    for i, label in enumerate(batch["ner_tags"]):
        word_ids = tokenized_inputs.word_ids(batch_index=i)
        label_ids = []
        for word_id in word_ids:
            if word_id is None:
                label_ids.append(-100)
            else:
                label_ids.append(label[word_id])
        labels.append(label_ids)
    tokenized_inputs["labels"] = labels
    return tokenized_inputs
print("Tokenizing training and validation datasets...")
train_dataset = train_combined.map(tokenize_and_align_labels, batched=True, remove_columns=["tokens", "ner_tags"])
val_dataset = sk_wikiann_val.map(tokenize_and_align_labels, batched=True, remove_columns=["tokens", "ner_tags"])
output_base = "./ner_ner_model"
os.makedirs(output_base, exist_ok=True)
training_args = TrainingArguments(
    output_dir=output_base,
    save_strategy="steps",
    save_steps=1000,
    evaluation_strategy="steps",
    eval_steps=1000,
    save_total_limit=2,
    logging_steps=50,
    num_train_epochs=3,
    learning_rate=5e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    weight_decay=0.01,
    logging_dir=os.path.join(output_base, "logs"),
    warmup_steps=500,
    metric_for_best_model="f1",
    greater_is_better=True,
    load_best_model_at_end=True,
    seed=42,
)
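# Token-level metrics: drop positions labelled -100, then compute weighted
# precision/recall/F1 over the remaining subword predictions. Note this is not
# entity-level (seqeval-style) scoring.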
def compute_metrics(pred):
    predictions = np.argmax(pred.predictions, axis=2)
    true_labels = pred.label_ids
    cleaned_predictions, cleaned_labels = [], []
    for prediction, label in zip(predictions, true_labels):
        temp_pred, temp_label = [], []
        for p, l in zip(prediction, label):
            if l != -100:
                temp_pred.append(p)
                temp_label.append(l)
        cleaned_predictions.extend(temp_pred)
        cleaned_labels.extend(temp_label)
    precision, recall, f1, _ = precision_recall_fscore_support(
        cleaned_labels, cleaned_predictions, average="weighted"
    )
    return {"precision": precision, "recall": recall, "f1": f1}
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics,
    tokenizer=tokenizer,
)
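# Resume from the newest checkpoint-<step> directory under the output dir, if one exists.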
checkpoint_dir = output_base
last_checkpoint = None
if os.path.isdir(checkpoint_dir):
    checkpoints = [os.path.join(checkpoint_dir, d) for d in os.listdir(checkpoint_dir) if d.startswith("checkpoint-")]
    if checkpoints:
        last_checkpoint = sorted(checkpoints, key=lambda x: int(x.split("-")[-1]))[-1]
        print(f"Resuming from checkpoint: {last_checkpoint}")
    else:
        print("No checkpoints found. Starting from scratch.")
print("Starting training...")
try:
    if last_checkpoint:
        trainer.train(resume_from_checkpoint=last_checkpoint)
    else:
        trainer.train()
except Exception as e:
    print(f"Training failed: {e}")
    exit(1)
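# trainer.save_model on a PEFT-wrapped model writes only the LoRA adapter weights and
# config; the merged, standalone model is saved separately below.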
print("Saving model...")
try:
    model_save_path = os.path.join(output_base, "final_model")
    trainer.save_model(model_save_path)
    print(f"Model saved successfully to {model_save_path}")
except Exception as e:
    print(f"Failed to save the model: {e}")

print("Merging LoRA into base model and saving full model...")
try:
    merged_model = model.merge_and_unload()
    full_model_path = os.path.join(output_base, "full_model")
    merged_model.save_pretrained(full_model_path)
    tokenizer.save_pretrained(full_model_path)
    print(f"✅ Full merged model saved to {full_model_path}")
except Exception as e:
    print(f"❌ Failed to merge and save full model: {e}")