248 lines
5.7 KiB
Python
248 lines
5.7 KiB
Python
import os
|
|
from pathlib import Path
|
|
|
|
# Process-wide switches. In this file they are set before torch/transformers
# are imported (presumably so those libraries see them at import time).
os.environ.update(
    {
        "CUDA_VISIBLE_DEVICES": "0",
        "TOKENIZERS_PARALLELISM": "false",
        "WANDB_DISABLED": "true",
    }
)
|
|
|
|
import torch
|
|
from datasets import load_dataset
|
|
from transformers import (
|
|
AutoTokenizer,
|
|
AutoModelForCausalLM,
|
|
BitsAndBytesConfig,
|
|
Trainer,
|
|
TrainingArguments,
|
|
default_data_collator,
|
|
)
|
|
from transformers.trainer_utils import get_last_checkpoint
|
|
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
|
|
|
|
|
|
# Identifiers later passed to `from_pretrained` and `load_dataset`.
MODEL_NAME = "slovak-nlp/mistral-sk-7b"
DATASET_NAME = "saillab/alpaca-slovak-cleaned"

# Every artefact of the run lives under one folder in the user's home directory.
PROJECT_DIR = Path.home().joinpath("diplomovka", "mistral_sk_alpaca")
# Trainer output directory (also scanned for checkpoints before training).
OUTPUT_DIR = PROJECT_DIR.joinpath("outputs-full")
# Destination of the final adapter and tokenizer files.
ADAPTER_DIR = PROJECT_DIR.joinpath("mistral-sk-7b-alpaca-slovak-lora-full")

# Only the project and output folders are created up front; ADAPTER_DIR is not.
for _directory in (PROJECT_DIR, OUTPUT_DIR):
    _directory.mkdir(parents=True, exist_ok=True)
|
|
|
|
# --- Dataset size caps (None keeps the whole split) ---
MAX_TRAIN_SAMPLES = None
MAX_EVAL_SAMPLES = 1000

# --- Sequence length used for truncation and padding during tokenization ---
MAX_LENGTH = 1024

# --- Optimisation ---
# Per-device batch size and gradient accumulation steps handed to the Trainer.
BATCH_SIZE = 1
GRAD_ACCUM = 8
LEARNING_RATE = 2e-4
NUM_EPOCHS = 1
WARMUP_STEPS = 150
# NOTE(review): a negative value presumably leaves the run length to
# NUM_EPOCHS - confirm against the TrainingArguments documentation.
MAX_STEPS = -1

# --- Checkpointing / evaluation cadence (in optimizer steps) ---
SAVE_STEPS = EVAL_STEPS = 1000
|
|
|
|
|
|
# Startup diagnostics: report whether CUDA is usable, describe device 0 when
# it is, then list the directories this run works with.
_cuda_ok = torch.cuda.is_available()
print("CUDA available:", _cuda_ok)
if _cuda_ok:
    print("GPU:", torch.cuda.get_device_name(0))
    _total_bytes = torch.cuda.get_device_properties(0).total_memory
    print("VRAM GB:", round(_total_bytes / 1024**3, 2))

for _caption, _location in (
    ("Project dir:", PROJECT_DIR),
    ("Output dir:", OUTPUT_DIR),
    ("Adapter dir:", ADAPTER_DIR),
):
    print(_caption, _location)
|
|
|
|
|
|
# Load the base model's tokenizer, explicitly requesting the non-fast variant.
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_NAME,
    use_fast=False,
)

# Padding to a fixed length needs a pad token; reuse EOS when none is defined.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
|
|
|
|
|
|
# Fetch the instruction dataset and print its summary (splits and columns)
# so the structure can be checked in the log.
raw_dataset = load_dataset(DATASET_NAME)
print(raw_dataset)
|
|
|
|
|
|
def is_empty(value):
    """Return True when *value* carries no usable text.

    ``None`` counts as empty. Any other value is converted with ``str()``
    and stripped of surrounding whitespace; it is empty when the result is
    blank or spells "nan" in any letter case.
    """
    return value is None or str(value).strip().lower() in ("", "nan")
|
|
|
|
|
|
def build_prompt(example):
    """Turn one raw record into prompt, completion and full training text.

    Reads the "instruction", optional "input" and "output" fields. The
    "### Vstup:" section is included only when the input is not empty
    according to ``is_empty``. The completion is the stripped output
    followed by the tokenizer's EOS token string.

    Returns:
        dict with the keys "prompt", "completion" and "text", where
        "text" is the prompt immediately followed by the completion.
    """
    instruction = str(example["instruction"]).strip()
    extra_input = example.get("input")

    # Optional middle section; empty string when there is no usable input.
    if is_empty(extra_input):
        input_section = ""
    else:
        input_section = f"### Vstup:\n{str(extra_input).strip()}\n\n"

    prompt = f"### Inštrukcia:\n{instruction}\n\n{input_section}### Odpoveď:\n"
    completion = str(example["output"]).strip() + tokenizer.eos_token

    record = {"prompt": prompt, "completion": completion}
    record["text"] = prompt + completion
    return record
|
|
|
|
|
|
# Replace the raw columns of every split with prompt / completion / text.
_raw_columns = raw_dataset["train"].column_names
dataset = raw_dataset.map(build_prompt, remove_columns=_raw_columns)

# Optionally cap each split to its first N rows; a limit of None keeps it whole.
for _split, _limit in (("train", MAX_TRAIN_SAMPLES), ("test", MAX_EVAL_SAMPLES)):
    if _limit is None:
        continue
    _keep = min(_limit, len(dataset[_split]))
    dataset[_split] = dataset[_split].select(range(_keep))

print(dataset)
|
|
|
|
|
|
def tokenize_example(example):
    """Tokenize one prompt/completion record into fixed-length features.

    The full text is truncated and padded to ``MAX_LENGTH``. Labels mirror
    the input ids, except that prompt tokens and padding positions are set
    to -100 (the label value the loss is expected to ignore), so only the
    completion contributes to the loss.

    The prompt span is located relative to the first real (attention mask
    1) token, so the masking is correct for right-padding and left-padding
    tokenizers alike. With right padding the result is identical to
    masking the first ``len(prompt_ids)`` positions.

    Args:
        example: Mapping with the "prompt" and "text" strings.

    Returns:
        dict with "input_ids", "attention_mask" and "labels" lists.
    """
    shared_kwargs = {
        "add_special_tokens": False,
        "truncation": True,
        "max_length": MAX_LENGTH,
    }

    # NOTE(review): the prompt is tokenized on its own, which assumes its
    # tokens form a prefix of the tokenized full text - confirm for this
    # tokenizer.
    prompt_ids = tokenizer(example["prompt"], **shared_kwargs)["input_ids"]
    full = tokenizer(example["text"], padding="max_length", **shared_kwargs)

    input_ids = list(full["input_ids"])
    attention_mask = list(full["attention_mask"])

    # Index of the first non-padding token: 0 with right padding, the length
    # of the pad run with left padding, len() when everything is padding.
    try:
        first_real = attention_mask.index(1)
    except ValueError:
        first_real = len(attention_mask)

    # End (exclusive) of the prompt span, clamped to the sequence length.
    # A prompt that fills the whole window leaves no supervised token.
    prompt_end = min(first_real + len(prompt_ids), len(input_ids))

    labels = [
        token if mask == 1 and position >= prompt_end else -100
        for position, (token, mask) in enumerate(zip(input_ids, attention_mask))
    ]

    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels,
    }
|
|
|
|
|
|
# Tokenize every split in a single process. The text columns are dropped so
# only the features returned by tokenize_example remain.
_text_columns = dataset["train"].column_names
tokenized_dataset = dataset.map(
    tokenize_example,
    num_proc=1,
    remove_columns=_text_columns,
)

print(tokenized_dataset)
print("Tokenizácia hotová.")
|
|
|
|
|
|
# 4-bit NF4 quantization with double quantization; compute dtype is float16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
)

# Load the base model quantized, with every module mapped to device 0.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    dtype=torch.float16,
    device_map={"": 0},
    quantization_config=bnb_config,
)

# Training setup: turn off `use_cache`, enable gradient checkpointing, then
# run PEFT's preparation step for k-bit training.
model.config.use_cache = False
model.gradient_checkpointing_enable()
model = prepare_model_for_kbit_training(model)
|
|
|
|
# Names of the modules that receive LoRA adapters.
_lora_targets = [
    "q_proj",
    "k_proj",
    "v_proj",
    "o_proj",
    "gate_proj",
    "up_proj",
    "down_proj",
]

# Rank-16 adapters, alpha 32, 5 % dropout, no bias training, causal-LM task.
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    target_modules=_lora_targets,
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
)

# Wrap the prepared model with the adapters and report the trainable share.
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
|
|
|
|
|
|
# Trainer configuration, assembled from the constants defined earlier.
training_args = TrainingArguments(
    output_dir=str(OUTPUT_DIR),
    # Run length.
    num_train_epochs=NUM_EPOCHS,
    max_steps=MAX_STEPS,
    # Batching.
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRAD_ACCUM,
    # Optimizer and learning-rate schedule.
    optim="paged_adamw_8bit",
    learning_rate=LEARNING_RATE,
    lr_scheduler_type="cosine",
    warmup_steps=WARMUP_STEPS,
    max_grad_norm=0.3,
    # Precision and memory: fp16 mixed precision with gradient checkpointing.
    fp16=True,
    bf16=False,
    gradient_checkpointing=True,
    # Logging, evaluation and checkpoint cadence.
    logging_steps=10,
    eval_strategy="steps",
    eval_steps=EVAL_STEPS,
    save_steps=SAVE_STEPS,
    save_total_limit=2,
    report_to="none",
    # The dataset already holds exactly the tensors to feed the model.
    remove_unused_columns=False,
)
|
|
|
|
|
|
# Examples are already padded to a fixed length and carry their own labels,
# so the default collator is used to batch them.
trainer = Trainer(
    args=training_args,
    model=model,
    data_collator=default_data_collator,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["test"],
)
|
|
|
|
|
|
# Look for an earlier checkpoint in the output directory so an interrupted
# run can be resumed; None means there is nothing to resume from.
last_checkpoint = (
    get_last_checkpoint(str(OUTPUT_DIR)) if OUTPUT_DIR.exists() else None
)

if last_checkpoint is None:
    print("Začínam nový tréning.")
else:
    print("Pokračujem z checkpointu:", last_checkpoint)
|
|
|
|
|
|
# Train (resuming when a checkpoint was found), then evaluate and print the
# metrics.
trainer.train(resume_from_checkpoint=last_checkpoint)

metrics = trainer.evaluate()
print(metrics)

# Save the trained model and the tokenizer side by side in ADAPTER_DIR.
_adapter_path = str(ADAPTER_DIR)
trainer.save_model(_adapter_path)
tokenizer.save_pretrained(_adapter_path)

print("Hotovo.")
print("LoRA adaptér uložený do:")
print(ADAPTER_DIR)
|