import os

import numpy as np
import torch
from datasets import load_dataset, concatenate_datasets, Sequence, Value
from sklearn.metrics import precision_recall_fscore_support
from transformers import (
    AutoTokenizer,
    AutoModelForTokenClassification,
    Trainer,
    TrainingArguments,
    set_seed
)
from peft import LoraConfig, get_peft_model, TaskType

set_seed(42)

print("Loading WikiANN and conll2003-SK-NER datasets...")
|
|
try:
|
|
# WikiANN Slovak
|
|
sk_wikiann_train = load_dataset("unimelb-nlp/wikiann", "sk", split="train[:90%]")
|
|
sk_wikiann_val = load_dataset("unimelb-nlp/wikiann", "sk", split="train[90%:]")
|
|
|
|
# conll2003-SK-NER
|
|
conll_sk = load_dataset("ju-bezdek/conll2003-SK-NER", split="train")
|
|
|
|
# Check label sets
|
|
wikiann_labels = sk_wikiann_train.features["ner_tags"].feature.names
|
|
conll_labels = conll_sk.features["ner_tags"].feature.names
|
|
print("WikiANN label names:", wikiann_labels)
|
|
print("Conll2003-SK label names:", conll_labels)
|
|
|
|
# Full label set
|
|
full_label_list = list(sorted(set(wikiann_labels + conll_labels)))
|
|
label2id = {label: i for i, label in enumerate(full_label_list)}
|
|
id2label = {i: label for label, i in label2id.items()}
|
|
num_labels = len(full_label_list)
|
|
|
|
    # Cast the tags to plain int64, dropping each dataset's own ClassLabel ordering
    conll_sk = conll_sk.cast_column("ner_tags", Sequence(feature=Value(dtype="int64")))
    sk_wikiann_train = sk_wikiann_train.cast_column("ner_tags", Sequence(feature=Value(dtype="int64")))
    sk_wikiann_val = sk_wikiann_val.cast_column("ner_tags", Sequence(feature=Value(dtype="int64")))

    # Remap each dataset's original tag ids onto the unified label2id space
    # (WikiANN and CoNLL2003-SK order their label names differently, so both need remapping)
    def remap_labels(example, source_label_names):
        example["ner_tags"] = [label2id[source_label_names[i]] for i in example["ner_tags"]]
        return example

    conll_sk = conll_sk.map(remap_labels, fn_kwargs={"source_label_names": conll_labels})
    sk_wikiann_train = sk_wikiann_train.map(remap_labels, fn_kwargs={"source_label_names": wikiann_labels})
    sk_wikiann_val = sk_wikiann_val.map(remap_labels, fn_kwargs={"source_label_names": wikiann_labels})

    # Keep only the shared columns so concatenate_datasets can align the schemas,
    # then combine both training sets
    conll_sk = conll_sk.select_columns(["tokens", "ner_tags"])
    sk_wikiann_train = sk_wikiann_train.select_columns(["tokens", "ner_tags"])
    sk_wikiann_val = sk_wikiann_val.select_columns(["tokens", "ner_tags"])
    train_combined = concatenate_datasets([sk_wikiann_train, conll_sk])
except Exception as e:
    print(f"Failed to load datasets: {e}")
    exit(1)

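# Quick sanity check: after the remapping above, the concatenated data should use one
# shared tag space. Prints are illustrative only and can be removed.
print("Unified label list:", full_label_list)
sample = train_combined[0]
print("Sample tokens:", sample["tokens"][:10])
print("Sample labels:", [id2label[t] for t in sample["ner_tags"][:10]])
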
model_name = "gerulata/slovakbert"
|
|
print("Loading model and tokenizer...")
|
|
tokenizer = AutoTokenizer.from_pretrained(model_name, add_prefix_space=True)
|
|
base_model = AutoModelForTokenClassification.from_pretrained(
|
|
model_name,
|
|
num_labels=num_labels,
|
|
id2label=id2label,
|
|
label2id=label2id
|
|
)
|
|
|
|
# LoRA adapters on the self-attention query/value projections
lora_config = LoraConfig(
    task_type=TaskType.TOKEN_CLS,
    r=8,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["query", "value"]
)

model = get_peft_model(base_model, lora_config)
model.print_trainable_parameters()

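# Only the LoRA matrices and the freshly initialised classification head should be
# trainable; peft normally adds the head to modules_to_save for TaskType.TOKEN_CLS.
# Listing the trainable tensors is a quick, optional way to confirm that.
trainable = [n for n, p in model.named_parameters() if p.requires_grad]
print(f"Trainable parameter tensors: {len(trainable)}, e.g. {trainable[:3]}")
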
def tokenize_and_align_labels(batch):
    tokenized_inputs = tokenizer(
        batch["tokens"],
        padding="max_length",
        truncation=True,
        max_length=256,
        is_split_into_words=True
    )
    labels = []
    for i, label in enumerate(batch["ner_tags"]):
        word_ids = tokenized_inputs.word_ids(batch_index=i)
        label_ids = []
        for word_id in word_ids:
            if word_id is None:
                # Special tokens and padding get -100 so the loss ignores them
                label_ids.append(-100)
            else:
                # Every sub-token of a word inherits that word's label
                label_ids.append(label[word_id])
        labels.append(label_ids)
    tokenized_inputs["labels"] = labels
    return tokenized_inputs

print("Tokenizing training and validation datasets...")
|
|
train_dataset = train_combined.map(tokenize_and_align_labels, batched=True, remove_columns=["tokens", "ner_tags"])
|
|
val_dataset = sk_wikiann_val.map(tokenize_and_align_labels, batched=True, remove_columns=["tokens", "ner_tags"])
|
|
|
|
|
|
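# Optional alignment check: pair the first example's sub-tokens with their label ids to
# eyeball that -100 only lands on special tokens and padding.
example = train_dataset[0]
tokens_preview = tokenizer.convert_ids_to_tokens(example["input_ids"])[:15]
print(list(zip(tokens_preview, example["labels"][:15])))
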
output_base = "./ner_ner_model"
os.makedirs(output_base, exist_ok=True)

training_args = TrainingArguments(
    output_dir=output_base,
    save_strategy="steps",
    save_steps=1000,
    evaluation_strategy="steps",
    eval_steps=1000,
    save_total_limit=2,
    logging_steps=50,
    num_train_epochs=3,
    learning_rate=5e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    weight_decay=0.01,
    logging_dir=os.path.join(output_base, "logs"),
    warmup_steps=500,
    metric_for_best_model="f1",
    greater_is_better=True,
    load_best_model_at_end=True,
    seed=42
)

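# Note: newer transformers releases rename evaluation_strategy to eval_strategy; switch
# the keyword above if TrainingArguments rejects it.
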
def compute_metrics(pred):
    predictions = np.argmax(pred.predictions, axis=2)
    true_labels = pred.label_ids

    # Collect token-level predictions and labels, skipping the -100 positions
    cleaned_predictions, cleaned_labels = [], []
    for prediction, label in zip(predictions, true_labels):
        for p, l in zip(prediction, label):
            if l != -100:
                cleaned_predictions.append(p)
                cleaned_labels.append(l)

    # Token-level (not entity-level) precision/recall/F1
    precision, recall, f1, _ = precision_recall_fscore_support(
        cleaned_labels, cleaned_predictions, average="weighted"
    )
    return {"precision": precision, "recall": recall, "f1": f1}

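# Alternative (sketch): entity-level scores via seqeval are the usual yardstick for NER.
# Not wired into the Trainer below; requires `pip install seqeval`.
# from seqeval.metrics import f1_score as seqeval_f1
#
# def compute_entity_f1(pred):
#     predictions = np.argmax(pred.predictions, axis=2)
#     y_true, y_pred = [], []
#     for prediction, label in zip(predictions, pred.label_ids):
#         y_true.append([id2label[l] for l in label if l != -100])
#         y_pred.append([id2label[p] for p, l in zip(prediction, label) if l != -100])
#     return {"f1": seqeval_f1(y_true, y_pred)}
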
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics,
    tokenizer=tokenizer
)

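# Alternative (sketch): dynamic per-batch padding with DataCollatorForTokenClassification
# is usually faster than padding everything to max_length=256.
# from transformers import DataCollatorForTokenClassification
# data_collator = DataCollatorForTokenClassification(tokenizer)
# # ...then pass data_collator=data_collator to the Trainer and drop padding="max_length".
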
# Resume from the most recent checkpoint in the output directory, if one exists
checkpoint_dir = output_base
last_checkpoint = None
if os.path.isdir(checkpoint_dir):
    checkpoints = [os.path.join(checkpoint_dir, d) for d in os.listdir(checkpoint_dir) if d.startswith("checkpoint-")]
    if checkpoints:
        last_checkpoint = sorted(checkpoints, key=lambda x: int(x.split("-")[-1]))[-1]
        print(f"Resuming from checkpoint: {last_checkpoint}")
    else:
        print("No checkpoints found. Starting from scratch.")

print("Starting training...")
|
|
try:
|
|
if last_checkpoint:
|
|
trainer.train(resume_from_checkpoint=last_checkpoint)
|
|
else:
|
|
trainer.train()
|
|
except Exception as e:
|
|
print(f"Training failed: {e}")
|
|
exit(1)
|
|
|
|
print("Saving model...")
|
|
try:
|
|
model_save_path = os.path.join(output_base, "final_model")
|
|
trainer.save_model(model_save_path)
|
|
print(f"Model saved successfully to {model_save_path}")
|
|
except Exception as e:
|
|
print(f"Failed to save the model: {e}")
|
|
|
|
|
|
print("Merging LoRA into base model and saving full model...")
try:
    # merge_and_unload folds the LoRA weights back into the base model, so the result
    # can be reloaded later without peft
    merged_model = model.merge_and_unload()
    full_model_path = os.path.join(output_base, "full_model")
    merged_model.save_pretrained(full_model_path)
    tokenizer.save_pretrained(full_model_path)
    print(f"✅ Full merged model saved to {full_model_path}")
except Exception as e:
    print(f"❌ Failed to merge and save full model: {e}")
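
# Quick smoke test (optional): reload the merged model from disk and tag one Slovak
# sentence. The example sentence is only illustrative.
from transformers import pipeline

if os.path.isdir(os.path.join(output_base, "full_model")):
    ner = pipeline(
        "token-classification",
        model=os.path.join(output_base, "full_model"),
        aggregation_strategy="simple",
    )
    print(ner("Milan Rastislav Štefánik sa narodil v Košariskách."))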