import os
import numpy as np
import torch
from datasets import load_dataset, concatenate_datasets, Sequence, Value
from sklearn.metrics import precision_recall_fscore_support
from transformers import (
    AutoTokenizer,
    AutoModelForTokenClassification,
    Trainer,
    TrainingArguments,
    set_seed,
)
from peft import LoraConfig, get_peft_model, TaskType

set_seed(42)

print("Loading WikiANN and conll2003-SK-NER datasets...")
try:
    # WikiANN Slovak: 90% for training, the remaining 10% held out for validation
    sk_wikiann_train = load_dataset("unimelb-nlp/wikiann", "sk", split="train[:90%]")
    sk_wikiann_val = load_dataset("unimelb-nlp/wikiann", "sk", split="train[90%:]")

    # conll2003-SK-NER
    conll_sk = load_dataset("ju-bezdek/conll2003-SK-NER", split="train")

    # Inspect the label sets of both datasets
    wikiann_labels = sk_wikiann_train.features["ner_tags"].feature.names
    conll_labels = conll_sk.features["ner_tags"].feature.names
    print("WikiANN label names:", wikiann_labels)
    print("Conll2003-SK label names:", conll_labels)

    # Unified label set across both datasets
    full_label_list = sorted(set(wikiann_labels + conll_labels))
    label2id = {label: i for i, label in enumerate(full_label_list)}
    id2label = {i: label for label, i in label2id.items()}
    num_labels = len(full_label_list)

    # Remap every dataset onto the unified label ids. Both datasets need this:
    # their original integer tags index their own label lists, whose ordering
    # differs from the merged label list.
    def make_remap_fn(source_labels):
        def remap_labels(example):
            example["ner_tags"] = [label2id[source_labels[i]] for i in example["ner_tags"]]
            return example
        return remap_labels

    conll_sk = conll_sk.map(make_remap_fn(conll_labels))
    sk_wikiann_train = sk_wikiann_train.map(make_remap_fn(wikiann_labels))
    sk_wikiann_val = sk_wikiann_val.map(make_remap_fn(wikiann_labels))

    # Cast the tag column to a plain int64 sequence so the datasets can be concatenated
    conll_sk = conll_sk.cast_column("ner_tags", Sequence(feature=Value(dtype="int64")))
    sk_wikiann_train = sk_wikiann_train.cast_column("ner_tags", Sequence(feature=Value(dtype="int64")))
    sk_wikiann_val = sk_wikiann_val.cast_column("ner_tags", Sequence(feature=Value(dtype="int64")))

    # Combined training corpus
    train_combined = concatenate_datasets([sk_wikiann_train, conll_sk])
except Exception as e:
    print(f"Failed to load datasets: {e}")
    exit(1)

model_name = "gerulata/slovakbert"
print("Loading model and tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_name, add_prefix_space=True)
base_model = AutoModelForTokenClassification.from_pretrained(
    model_name,
    num_labels=num_labels,
    id2label=id2label,
    label2id=label2id,
)

# LoRA adapters on the attention query/value projections
lora_config = LoraConfig(
    task_type=TaskType.TOKEN_CLS,
    r=8,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["query", "value"],
)
model = get_peft_model(base_model, lora_config)
model.print_trainable_parameters()


def tokenize_and_align_labels(batch):
    """Tokenize pre-split words and align word-level tags to sub-word tokens.

    Special and padding tokens get the label -100 so they are ignored by the loss.
    """
    tokenized_inputs = tokenizer(
        batch["tokens"],
        padding="max_length",
        truncation=True,
        max_length=256,
        is_split_into_words=True,
    )
    labels = []
    for i, label in enumerate(batch["ner_tags"]):
        word_ids = tokenized_inputs.word_ids(batch_index=i)
        label_ids = []
        for word_id in word_ids:
            if word_id is None:
                label_ids.append(-100)
            else:
                label_ids.append(label[word_id])
        labels.append(label_ids)
    tokenized_inputs["labels"] = labels
    return tokenized_inputs


print("Tokenizing training and validation datasets...")
train_dataset = train_combined.map(tokenize_and_align_labels, batched=True, remove_columns=["tokens", "ner_tags"])
val_dataset = sk_wikiann_val.map(tokenize_and_align_labels, batched=True, remove_columns=["tokens", "ner_tags"])

output_base = "./ner_ner_model"
os.makedirs(output_base, exist_ok=True)

training_args = TrainingArguments(
    output_dir=output_base,
    save_strategy="steps",
    save_steps=1000,
    evaluation_strategy="steps",  # renamed to `eval_strategy` in newer transformers releases
    eval_steps=1000,
    save_total_limit=2,
    logging_steps=50,
    num_train_epochs=3,
    learning_rate=5e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    weight_decay=0.01,
    logging_dir=os.path.join(output_base, "logs"),
    warmup_steps=500,
    metric_for_best_model="f1",
    greater_is_better=True,
    load_best_model_at_end=True,
    seed=42,
)


def compute_metrics(pred):
    """Token-level weighted precision/recall/F1, ignoring positions labelled -100.

    Note: this includes the "O" class in the weighted average; entity-level
    scoring (e.g. seqeval) would be stricter.
    """
    predictions = np.argmax(pred.predictions, axis=2)
    true_labels = pred.label_ids
    cleaned_predictions, cleaned_labels = [], []
    for prediction, label in zip(predictions, true_labels):
        for p, l in zip(prediction, label):
            if l != -100:
                cleaned_predictions.append(p)
                cleaned_labels.append(l)
    precision, recall, f1, _ = precision_recall_fscore_support(
        cleaned_labels, cleaned_predictions, average="weighted"
    )
    return {"precision": precision, "recall": recall, "f1": f1}


trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics,
    tokenizer=tokenizer,
)

# Resume from the most recent checkpoint in the output directory, if one exists
checkpoint_dir = output_base
last_checkpoint = None
if os.path.isdir(checkpoint_dir):
    checkpoints = [
        os.path.join(checkpoint_dir, d)
        for d in os.listdir(checkpoint_dir)
        if d.startswith("checkpoint-")
    ]
    if checkpoints:
        last_checkpoint = sorted(checkpoints, key=lambda x: int(x.split("-")[-1]))[-1]
        print(f"Resuming from checkpoint: {last_checkpoint}")
    else:
        print("No checkpoints found. Starting from scratch.")

print("Starting training...")
try:
    if last_checkpoint:
        trainer.train(resume_from_checkpoint=last_checkpoint)
    else:
        trainer.train()
except Exception as e:
    print(f"Training failed: {e}")
    exit(1)

print("Saving model...")
try:
    # With a PEFT-wrapped model this saves the LoRA adapter weights, not the full base model
    model_save_path = os.path.join(output_base, "final_model")
    trainer.save_model(model_save_path)
    print(f"Model saved successfully to {model_save_path}")
except Exception as e:
    print(f"Failed to save the model: {e}")

print("Merging LoRA into base model and saving full model...")
try:
    merged_model = model.merge_and_unload()
    full_model_path = os.path.join(output_base, "full_model")
    merged_model.save_pretrained(full_model_path)
    tokenizer.save_pretrained(full_model_path)
    print(f"✅ Full merged model saved to {full_model_path}")
except Exception as e:
    print(f"❌ Failed to merge and save full model: {e}")
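
# Optional sanity check (illustrative sketch, not part of the original script):
# reload the merged model with a token-classification pipeline and tag one
# Slovak sentence. The example sentence is an assumption chosen for
# demonstration only; it reuses `full_model_path` from the merge step above.
try:
    from transformers import pipeline

    ner_pipe = pipeline(
        "token-classification",
        model=full_model_path,
        tokenizer=full_model_path,
        aggregation_strategy="simple",  # group sub-word predictions into whole entities
    )
    print(ner_pipe("Bratislava je hlavné mesto Slovenska."))
except Exception as e:
    print(f"Sanity-check inference failed: {e}")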