#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Unified training script: SFT (QLoRA + masked loss) and DPO (TRL) in ONE file.
UX requirement (as you asked):
- You run: python3 train_unified.py
- It shows a simple menu 1/2/3
- You press 1, 2, or 3
- Script relaunches itself with accelerate (2 GPUs) and runs the chosen stage
using DEFAULT paths/hparams already embedded.
Notes:
- Menu is shown ONLY once (before accelerate spawns processes).
- After relaunch, each process runs normally with LOCAL_RANK set by accelerate.
"""
import os
import sys
import json
import math
import argparse
import subprocess
import inspect
import traceback
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
import torch
from datasets import load_dataset, load_from_disk
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    Trainer as HFTrainer,
    TrainingArguments,
)
from peft import (
    LoraConfig,
    get_peft_model,
    prepare_model_for_kbit_training,
    PeftModel,
)
from trl import DPOTrainer, DPOConfig
# --------------------------
# Defaults (same spirit as your original scripts)
# --------------------------
DEFAULT_MODEL_DIR = "/home/hyrenko/Diploma/models/llama3.1-8b"
# SFT: source -> jsonl cache
SFT_SRC_DATASET_ID = "PKU-Alignment/PKU-SafeRLHF-30K"
SFT_SRC_SPLIT = "train"
SFT_JSONL_PATH = "./data/pku_sft.jsonl"
# Outputs
DEFAULT_SFT_OUT = "./out/llama3_1_8b_sft_safety_lora_masked"
DEFAULT_DPO_OUT = "./out/llama3_1_8b_dpo_safety_lora"
DEFAULT_DPO_DATA = "./data/pku_dpo"
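# NOTE: DEFAULT_DPO_DATA is assumed to be a HuggingFace dataset saved with
# save_to_disk() and carrying the "prompt"/"chosen"/"rejected" string columns
# that TRL's DPOTrainer expects; this script loads it but does not build it.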
# Prompt wrappers (keep consistent)
INSTR_PREFIX = "### Instruction:\n"
RESP_PREFIX = "\n\n### Response:\n"
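# With these wrappers, one SFT training example is serialized as:
#   ### Instruction:
#   <prompt text>
#
#   ### Response:
#   <safer response><eos>
# Only the response tokens receive loss (see tokenize_and_mask in run_sft).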
# --------------------------
# Simple interactive menu (shown only before accelerate)
# --------------------------
def choose_stage_simple() -> str:
    """
    Returns: 'sft' | 'dpo' | 'both'
    """
    print("\n=== TRAIN MENU ===")
    print("1) SFT (QLoRA masked-loss)")
    print("2) DPO (over SFT adapter)")
    print("3) BOTH (SFT -> DPO)")
    while True:
        x = input("Choose 1/2/3: ").strip()
        if x == "1":
            return "sft"
        if x == "2":
            return "dpo"
        if x == "3":
            return "both"
        print("Only 1, 2, or 3.")

def relaunch_with_accelerate(num_processes: int, extra_argv: List[str]):
    """
    Relaunches the script via accelerate (2 processes by default).
    Only runs if NOT already inside accelerate (LOCAL_RANK not set).
    """
    if os.environ.get("LOCAL_RANK") is not None:
        return
    cmd = [
        sys.executable,
        "-m",
        "accelerate.commands.launch",
        "--num_processes",
        str(num_processes),
        "--mixed_precision",
        "fp16",
        "--main_process_port",
        "0",
        sys.argv[0],
        *extra_argv,
    ]
    print("\n[AUTO] Relaunching with accelerate:\n" + " ".join(cmd) + "\n")
    raise SystemExit(subprocess.call(cmd))
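# NOTE: SystemExit(subprocess.call(cmd)) makes the interactive parent exit with
# the accelerate run's return code, so shell callers see the success/failure of
# the actual training, not of the menu process.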
# --------------------------
# DDP helpers
# --------------------------
def ddp_info():
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    rank = int(os.environ.get("RANK", "0"))
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    return local_rank, rank, world_size


def ddp_barrier():
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        torch.distributed.barrier()
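# NOTE: ddp_barrier() is a silent no-op until the process group is initialized;
# main() therefore calls init_process_group before any stage runs, so the
# rank0-writes/others-wait pattern in ensure_sft_jsonl_exists() really blocks.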
# --------------------------
# TRL <-> Transformers patches (to avoid signature mismatches)
# --------------------------
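# Background (best-effort version notes, not pinned requirements): transformers
# around 4.46 changed the Trainer loop to pass num_items_in_batch into
# compute_loss() and to call get_batch_samples() while iterating epochs. TRL
# releases that predate or lag those changes have incompatible signatures, so
# we monkey-patch DPOTrainer at runtime only when a mismatch is detected.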
def patch_trl_dpo_trainer_if_needed():
    # Patch get_batch_samples: some TRL versions define it as
    # get_batch_samples(self, model, batch) for generation, which clashes with
    # the HF Trainer training loop that calls it with an epoch iterator.
    try:
        sig = list(inspect.signature(DPOTrainer.get_batch_samples).parameters.keys())
        if len(sig) >= 3 and sig[1] == "model":
            if not hasattr(DPOTrainer, "get_batch_samples_trl"):
                DPOTrainer.get_batch_samples_trl = DPOTrainer.get_batch_samples
            DPOTrainer.get_batch_samples = HFTrainer.get_batch_samples
            print("[PATCH] Patched DPOTrainer.get_batch_samples -> HF Trainer.get_batch_samples")
    except Exception:
        print("[PATCH] Could not inspect/patch DPOTrainer.get_batch_samples")
    # Patch compute_loss: newer transformers pass num_items_in_batch, which
    # older TRL compute_loss signatures do not accept.
    try:
        sig = inspect.signature(DPOTrainer.compute_loss)
        if "num_items_in_batch" not in sig.parameters:
            if not hasattr(DPOTrainer, "compute_loss_trl"):
                DPOTrainer.compute_loss_trl = DPOTrainer.compute_loss

            def compute_loss_patched(self, model, inputs, return_outputs=False, num_items_in_batch=None, **kwargs):
                orig = self.compute_loss_trl  # bound method of the original TRL implementation
                osig = inspect.signature(orig)
                if "return_outputs" in osig.parameters:
                    return orig(model, inputs, return_outputs=return_outputs)
                out = orig(model, inputs)
                if return_outputs:
                    return out, None
                return out

            DPOTrainer.compute_loss = compute_loss_patched
            print("[PATCH] Patched DPOTrainer.compute_loss to accept num_items_in_batch")
    except Exception:
        print("[PATCH] Could not inspect/patch DPOTrainer.compute_loss")
def assert_is_model(obj, name="model"):
    if not hasattr(obj, "generate"):
        raise TypeError(
            "[FATAL] {} is not a HF CausalLM model. type={}; missing .generate().".format(name, type(obj))
        )
# --------------------------
# SFT helpers
# --------------------------
def wrap_prompt(p: str) -> str:
    p = (p or "").strip()
    return "{}{}{}".format(INSTR_PREFIX, p, RESP_PREFIX)

def ensure_sft_jsonl_exists(rank: int):
    """
    Creates ./data/pku_sft.jsonl from PKU-SafeRLHF-30K if missing.
    Only rank0 writes; others wait.
    """
    os.makedirs("./data", exist_ok=True)
    if os.path.exists(SFT_JSONL_PATH) and os.path.getsize(SFT_JSONL_PATH) > 0:
        if rank == 0:
            print("[INFO] Found existing SFT jsonl:", SFT_JSONL_PATH)
        ddp_barrier()
        return
    if rank == 0:
        print("[INFO] SFT jsonl not found. Building from {} ...".format(SFT_SRC_DATASET_ID))
        ds = load_dataset(SFT_SRC_DATASET_ID, split=SFT_SRC_SPLIT)
        rows = 0
        dropped = 0
        with open(SFT_JSONL_PATH, "w", encoding="utf-8") as f:
            for ex in ds:
                prompt = (ex.get("prompt") or "").strip()
                r0 = (ex.get("response_0") or "").strip()
                r1 = (ex.get("response_1") or "").strip()
                sid = ex.get("safer_response_id", None)
                if not prompt or not r0 or not r1 or sid not in (0, 1):
                    dropped += 1
                    continue
                safe = r0 if sid == 0 else r1
                obj = {"prompt": wrap_prompt(prompt), "response": safe}
                f.write(json.dumps(obj, ensure_ascii=False) + "\n")
                rows += 1
        print("[OK] Built SFT jsonl:", SFT_JSONL_PATH)
        print("[OK] rows={}, dropped={}".format(rows, dropped))
    ddp_barrier()

@dataclass
class DataCollatorForCausalLMWithLabels:
    tokenizer: Any

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        pad_batch = self.tokenizer.pad(
            {
                "input_ids": [f["input_ids"] for f in features],
                "attention_mask": [f["attention_mask"] for f in features],
            },
            padding=True,
            return_tensors="pt",
        )
        max_len = pad_batch["input_ids"].shape[1]
        labels = []
        for f in features:
            lab = f["labels"]
            if len(lab) < max_len:
                lab = lab + [-100] * (max_len - len(lab))
            else:
                lab = lab[:max_len]
            labels.append(lab)
        pad_batch["labels"] = torch.tensor(labels, dtype=torch.long)
        return pad_batch

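# Toy illustration of the collator above (hypothetical 2- and 3-token rows):
#   [{"input_ids": [5, 6],    "attention_mask": [1, 1],    "labels": [-100, 6]},
#    {"input_ids": [5, 6, 7], "attention_mask": [1, 1, 1], "labels": [-100, 6, 7]}]
# -> input_ids/attention_mask are padded to shape (2, 3) by tokenizer.pad();
#    labels are padded with -100 so pad positions never contribute to the loss.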
def build_quant_base(model_dir: str, device_map: Dict[str, int]):
    bnb = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
    )
    base = AutoModelForCausalLM.from_pretrained(
        model_dir,
        quantization_config=bnb,
        device_map=device_map,
        torch_dtype=torch.float16,
    )
    base.config.use_cache = False
    return base

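# Rough memory math (estimate, not measured): NF4 stores ~0.5 bytes/parameter,
# so an 8B base occupies on the order of 4-5 GB per GPU before LoRA weights,
# activations, and optimizer state; double quantization further shrinks the
# per-block quantization constants.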
def run_sft(
    model_dir: str,
    out_dir: str,
    epochs: float,
    lr: float,
    bs: int,
    gas: int,
    seq: int,
    lora_r: int,
    device_map: Dict[str, int],
    rank: int,
    world_size: int,
) -> str:
    """
    Runs SFT and returns the path to the saved LoRA adapter (out_dir).
    """
    os.makedirs(out_dir, exist_ok=True)
    ensure_sft_jsonl_exists(rank)
    ds = load_dataset("json", data_files=SFT_JSONL_PATH, split="train")
    tokenizer = AutoTokenizer.from_pretrained(model_dir, use_fast=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"
    base = build_quant_base(model_dir, device_map)
    base = prepare_model_for_kbit_training(base)
    lora = LoraConfig(
        r=lora_r,
        lora_alpha=32,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=["q_proj", "v_proj"],
    )
    model = get_peft_model(base, lora)
    def tokenize_and_mask(example):
        prompt = example["prompt"]
        response = example["response"]
        p = tokenizer(prompt, add_special_tokens=False)
        r = tokenizer(response + tokenizer.eos_token, add_special_tokens=False)
        input_ids = p["input_ids"] + r["input_ids"]
        attention_mask = [1] * len(input_ids)
        # Mask the prompt: -100 positions are ignored by the cross-entropy loss.
        labels = [-100] * len(p["input_ids"]) + r["input_ids"]
        if len(input_ids) > seq:
            # NOTE: if the prompt alone exceeds `seq`, the truncated example is
            # fully masked and contributes nothing to the loss.
            input_ids = input_ids[:seq]
            attention_mask = attention_mask[:seq]
            labels = labels[:seq]
        return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
    ds_tok = ds.map(tokenize_and_mask, remove_columns=ds.column_names)
    if rank == 0:
        total_examples = len(ds_tok)
        effective_bs = bs * world_size * gas
        steps_per_epoch = int(math.ceil(float(total_examples) / float(effective_bs)))
        print("[SFT] base={}".format(model_dir))
        print("[SFT] out={}".format(out_dir))
        print("[SFT] examples={}, world_size={}".format(total_examples, world_size))
        print("[SFT] per_device_bs={}, gas={}, effective_batch={}".format(bs, gas, effective_bs))
        print("[SFT] approx steps/epoch ≈ {}".format(steps_per_epoch))
    collator = DataCollatorForCausalLMWithLabels(tokenizer)
    train_args = TrainingArguments(
        output_dir=out_dir,
        num_train_epochs=epochs,
        learning_rate=lr,
        per_device_train_batch_size=bs,
        gradient_accumulation_steps=gas,
        logging_steps=10,
        save_steps=200,
        save_total_limit=2,
        fp16=True,
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={"use_reentrant": False},
        ddp_find_unused_parameters=False,
        report_to="none",
    )
    trainer = HFTrainer(
        model=model,
        args=train_args,
        train_dataset=ds_tok,
        data_collator=collator,
        tokenizer=tokenizer,
    )
    trainer.train()
    # Save only on rank0, then barrier so others wait
    if rank == 0:
        model.save_pretrained(out_dir)
        tokenizer.save_pretrained(out_dir)
        print("[OK] SFT LoRA adapter saved to:", out_dir)
    ddp_barrier()
    return out_dir

# --------------------------
# DPO stage
# --------------------------
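# For each (prompt x, chosen y_w, rejected y_l) triple, DPO minimizes
#   L = -log sigmoid( beta * [ (log pi(y_w|x) - log pi_ref(y_w|x))
#                            - (log pi(y_l|x) - log pi_ref(y_l|x)) ] ).
# No explicit ref_model is passed below: with a PEFT model, TRL falls back to
# the frozen base weights (adapters disabled) as the implicit reference policy.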
def run_dpo(
    model_dir: str,
    sft_adapter: str,
    dpo_data_dir: str,
    out_dir: str,
    epochs: float,
    lr: float,
    beta: float,
    bs: int,
    gas: int,
    max_prompt: int,
    max_len: int,
    device_map: Dict[str, int],
    rank: int,
):
    patch_trl_dpo_trainer_if_needed()
    os.makedirs(out_dir, exist_ok=True)
    ds = load_from_disk(dpo_data_dir)
    tokenizer = AutoTokenizer.from_pretrained(model_dir, use_fast=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    base_model = build_quant_base(model_dir, device_map)
    assert_is_model(base_model, "base_model")
    # Load SFT adapter and keep training it with DPO
    model = PeftModel.from_pretrained(base_model, sft_adapter, is_trainable=True)
    assert_is_model(model, "model(peft)")
    if rank == 0:
        print("[DPO] base:", model_dir)
        print("[DPO] sft_adapter:", sft_adapter)
        print("[DPO] dpo_data:", dpo_data_dir)
        print("[DPO] out:", out_dir)
    dpo_args = DPOConfig(
        output_dir=out_dir,
        num_train_epochs=epochs,
        learning_rate=lr,
        per_device_train_batch_size=bs,
        gradient_accumulation_steps=gas,
        logging_steps=10,
        save_steps=200,
        save_total_limit=2,
        fp16=True,
        report_to="none",
        beta=beta,
        max_prompt_length=max_prompt,
        max_length=max_len,
        gradient_checkpointing=True,
        ddp_find_unused_parameters=False,
    )
    trainer = DPOTrainer(
        model=model,
        args=dpo_args,
        train_dataset=ds,
        tokenizer=tokenizer,
    )
    trainer.train()
    trainer.save_model(out_dir)
    if rank == 0:
        print("[OK] DPO LoRA adapter saved to:", out_dir)
# --------------------------
# Main
# --------------------------
def main():
    ap = argparse.ArgumentParser(description="Unified SFT(QLoRA masked) + DPO training script")
    # Stage (menu will override this in the initial run)
    ap.add_argument("--stage", choices=["sft", "dpo", "both"], default="both")
    ap.add_argument("--model", default=DEFAULT_MODEL_DIR)
    # SFT args
    ap.add_argument("--sft_out", default=DEFAULT_SFT_OUT)
    ap.add_argument("--sft_epochs", type=float, default=1.0)
    ap.add_argument("--sft_lr", type=float, default=2e-4)
    ap.add_argument("--sft_bs", type=int, default=1)
    ap.add_argument("--sft_gas", type=int, default=16)
    ap.add_argument("--sft_seq", type=int, default=1024)
    ap.add_argument("--lora_r", type=int, default=16)
    # DPO args
    ap.add_argument("--sft_adapter", default=None, help="If omitted, uses --sft_out")
    ap.add_argument("--dpo_data", default=DEFAULT_DPO_DATA)
    ap.add_argument("--dpo_out", default=DEFAULT_DPO_OUT)
    ap.add_argument("--dpo_epochs", type=float, default=1.0)
    ap.add_argument("--dpo_lr", type=float, default=5e-6)
    ap.add_argument("--beta", type=float, default=0.1)
    ap.add_argument("--dpo_bs", type=int, default=1)
    ap.add_argument("--dpo_gas", type=int, default=16)
    ap.add_argument("--max_prompt", type=int, default=512)
    ap.add_argument("--max_len", type=int, default=1024)
    args = ap.parse_args()
    # --------------------------
    # AUTO MENU (only once, BEFORE accelerate spawns processes)
    # --------------------------
    if os.environ.get("LOCAL_RANK") is None and os.environ.get("AUTO_ACCELERATE_RELAUNCH") != "1":
        os.environ["AUTO_ACCELERATE_RELAUNCH"] = "1"  # prevent menu duplication
        stage = choose_stage_simple()
        # Build argv using defaults from args (you only press 1/2/3)
        argv = [
            "--stage", stage,
            "--model", args.model,
            "--sft_out", args.sft_out,
            "--sft_epochs", str(args.sft_epochs),
            "--sft_lr", str(args.sft_lr),
            "--sft_bs", str(args.sft_bs),
            "--sft_gas", str(args.sft_gas),
            "--sft_seq", str(args.sft_seq),
            "--lora_r", str(args.lora_r),
            "--dpo_data", args.dpo_data,
            "--dpo_out", args.dpo_out,
            "--dpo_epochs", str(args.dpo_epochs),
            "--dpo_lr", str(args.dpo_lr),
            "--beta", str(args.beta),
            "--dpo_bs", str(args.dpo_bs),
            "--dpo_gas", str(args.dpo_gas),
            "--max_prompt", str(args.max_prompt),
            "--max_len", str(args.max_len),
        ]
        if stage in ("dpo", "both"):
            sft_adapter = args.sft_adapter or args.sft_out
            argv += ["--sft_adapter", sft_adapter]
        # Relaunch under accelerate (2 GPUs)
        relaunch_with_accelerate(num_processes=2, extra_argv=argv)
    # --------------------------
    # We are inside accelerate now
    # --------------------------
    local_rank, rank, world_size = ddp_info()
    torch.cuda.set_device(local_rank)
    # Initialize the process group up front so that ddp_barrier() is a real
    # barrier during dataset preparation; otherwise non-zero ranks could race
    # past rank0 while it is still writing the SFT jsonl.
    if world_size > 1 and not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend="nccl")
    device_map = {"": torch.cuda.current_device()}
    # Deterministic seed
    torch.manual_seed(42)
    sft_adapter_path = args.sft_adapter or args.sft_out
    if args.stage in ("sft", "both"):
        sft_adapter_path = run_sft(
            model_dir=args.model,
            out_dir=args.sft_out,
            epochs=args.sft_epochs,
            lr=args.sft_lr,
            bs=args.sft_bs,
            gas=args.sft_gas,
            seq=args.sft_seq,
            lora_r=args.lora_r,
            device_map=device_map,
            rank=rank,
            world_size=world_size,
        )
    if args.stage in ("dpo", "both"):
        # If "both", the barrier already happened after saving the SFT adapter
        run_dpo(
            model_dir=args.model,
            sft_adapter=sft_adapter_path,
            dpo_data_dir=args.dpo_data,
            out_dir=args.dpo_out,
            epochs=args.dpo_epochs,
            lr=args.dpo_lr,
            beta=args.beta,
            bs=args.dpo_bs,
            gas=args.dpo_gas,
            max_prompt=args.max_prompt,
            max_len=args.max_len,
            device_map=device_map,
            rank=rank,
        )

if __name__ == "__main__":
    try:
        main()
    except Exception:
        r = os.environ.get("RANK", "?")
        lr = os.environ.get("LOCAL_RANK", "?")
        print("\n[FATAL] Exception on RANK={} LOCAL_RANK={}\n".format(r, lr))
        print(traceback.format_exc())
        raise