DP/Training/training_dpo_sft_llama.py
2026-02-04 20:38:37 +00:00

615 lines
18 KiB
Python

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Unified training script for:
- SFT (QLoRA + masked loss)
- DPO (TRL DPOTrainer)
Execution model:
- Run the script normally: it shows a simple menu (SFT / DPO / BOTH).
- The script relaunches itself via accelerate (multi-process).
- After relaunch, each process runs with LOCAL_RANK set by accelerate.
"""
import os
import sys
import json
import math
import argparse
import subprocess
import inspect
import traceback
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
import torch
from datasets import load_dataset, load_from_disk
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
BitsAndBytesConfig,
Trainer as HFTrainer,
TrainingArguments,
)
from peft import (
LoraConfig,
get_peft_model,
prepare_model_for_kbit_training,
PeftModel,
)
from trl import DPOTrainer, DPOConfig
# --------------------------
# Defaults
# --------------------------
# Base model location, passed to from_pretrained() (overridable with --model).
DEFAULT_MODEL_DIR = "/home/hyrenko/Diploma/models/llama3.1-8b"
# SFT: source -> jsonl cache
# ensure_sft_jsonl_exists() downloads SFT_SRC_DATASET_ID/SFT_SRC_SPLIT and writes
# one {"prompt": ..., "response": ...} JSON object per line to SFT_JSONL_PATH
# whenever that file is missing or empty.
SFT_SRC_DATASET_ID = "PKU-Alignment/PKU-SafeRLHF-30K"
SFT_SRC_SPLIT = "train"
SFT_JSONL_PATH = "./data/pku_sft.jsonl"
# Outputs
# Default LoRA adapter / checkpoint directories (--sft_out, --dpo_out).
DEFAULT_SFT_OUT = "./out/llama3_1_8b_sft_safety_lora_masked"
DEFAULT_DPO_OUT = "./out/llama3_1_8b_dpo_safety_lora"
# DPO preference dataset, read with datasets.load_from_disk() (--dpo_data).
# NOTE(review): this script never builds it; presumably it was saved elsewhere
# with the prompt/chosen/rejected columns DPOTrainer expects, ideally using the
# same prompt template as SFT below -- confirm.
DEFAULT_DPO_DATA = "./data/pku_dpo"
# Prompt wrappers
# Template applied by wrap_prompt() to SFT prompts. The prompt (ending with
# RESP_PREFIX) is masked out of the SFT loss, so the model learns to produce
# the text that follows RESP_PREFIX.
INSTR_PREFIX = "### Instruction:\n"
RESP_PREFIX = "\n\n### Response:\n"
# --------------------------
# Interactive menu (shown only before accelerate spawns processes)
# --------------------------
def choose_stage_simple() -> str:
    """
    Ask on stdin which training stage(s) to run.

    Prints a three-entry menu, then keeps prompting until the (whitespace-
    stripped) answer is "1", "2" or "3".

    Returns:
        'sft' for 1, 'dpo' for 2, 'both' for 3.
    """
    stage_by_key = {"1": "sft", "2": "dpo", "3": "both"}
    menu_lines = (
        "\n=== TRAIN MENU ===",
        "1) SFT (QLoRA masked-loss)",
        "2) DPO (over SFT adapter)",
        "3) BOTH (SFT -> DPO)",
    )
    for line in menu_lines:
        print(line)
    while True:
        picked = stage_by_key.get(input("Choose 1/2/3: ").strip())
        if picked is not None:
            return picked
        print("Only 1, 2, or 3.")
def relaunch_with_accelerate(num_processes: int, extra_argv: List[str]):
    """
    Re-execute this script under ``accelerate launch`` and exit with its status.

    Args:
        num_processes: value forwarded to ``--num_processes``.
        extra_argv: arguments appended after the script path for the children.

    Returns immediately when LOCAL_RANK is already in the environment, i.e.
    when this process was itself started by the launcher. Otherwise it never
    returns normally: the launcher's exit code is raised as SystemExit.
    """
    if "LOCAL_RANK" in os.environ:
        return
    launcher = [sys.executable, "-m", "accelerate.commands.launch"]
    launch_options = [
        "--num_processes", str(num_processes),
        "--mixed_precision", "fp16",
        "--main_process_port", "0",
    ]
    cmd = launcher + launch_options + [sys.argv[0]] + list(extra_argv)
    print("\n[INFO] Relaunching with accelerate:\n" + " ".join(cmd) + "\n")
    exit_status = subprocess.call(cmd)
    raise SystemExit(exit_status)
# --------------------------
# DDP helpers
# --------------------------
def ddp_info():
    """
    Read this process's distributed coordinates from the environment.

    Returns:
        Tuple ``(local_rank, rank, world_size)`` of ints taken from LOCAL_RANK,
        RANK and WORLD_SIZE. Unset variables fall back to a single-process
        layout ``(0, 0, 1)``.
    """
    env = os.environ
    spec = (("LOCAL_RANK", "0"), ("RANK", "0"), ("WORLD_SIZE", "1"))
    local_rank, rank, world_size = (int(env.get(key, fallback)) for key, fallback in spec)
    return local_rank, rank, world_size
def ddp_barrier():
    """
    Block until every process of a multi-process run reaches this point.

    The default process group is normally created lazily by HF Trainer /
    accelerate, which only happens once TrainingArguments is set up. Callers
    such as the SFT jsonl builder run *before* that. A bare
    ``is_initialized()`` guard therefore turned the barrier into a silent
    no-op and let non-zero ranks race ahead of rank 0.

    When the launcher reports WORLD_SIZE > 1 and no group exists yet, the
    default group is initialized here from the env:// variables
    (MASTER_ADDR, MASTER_PORT, RANK, WORLD_SIZE) set by the launcher. A later
    Trainer/accelerate setup reuses an already-initialized group.
    Single-process runs (WORLD_SIZE unset or 1) and builds without
    torch.distributed remain a no-op.
    """
    dist = torch.distributed
    if not dist.is_available():
        return
    if not dist.is_initialized():
        if int(os.environ.get("WORLD_SIZE", "1")) <= 1:
            return
        # NCCL needs CUDA; gloo is the CPU fallback.
        backend = "nccl" if torch.cuda.is_available() else "gloo"
        dist.init_process_group(backend=backend)
    dist.barrier()
# --------------------------
# TRL <-> Transformers patches (signature compatibility)
# --------------------------
def patch_trl_dpo_trainer_if_needed():
    """
    Monkey-patch DPOTrainer for signature compatibility with the HF Trainer.

    Two independent, best-effort patches are applied. A failure in either is
    reported (with the exception) and does not abort the run.

    1. ``get_batch_samples``: if DPOTrainer defines it with ``model`` as its
       first parameter after ``self``, the method is replaced by
       ``HFTrainer.get_batch_samples``. Presumably TRL's method of the same
       name shadows the Trainer hook the training loop calls. The TRL version
       is kept as ``DPOTrainer.get_batch_samples_trl``.
    2. ``compute_loss``: if it lacks a ``num_items_in_batch`` parameter, it is
       wrapped in a function that accepts (and ignores) that argument and any
       extra kwargs. The original is kept as ``DPOTrainer.compute_loss_trl``.
       When the original has no ``return_outputs`` parameter and outputs are
       requested, ``(loss, None)`` is returned.

    The original compute_loss signature is inspected once here rather than on
    every training step.
    """
    try:
        params = list(inspect.signature(DPOTrainer.get_batch_samples).parameters)
        if len(params) >= 3 and params[1] == "model":
            if not hasattr(DPOTrainer, "get_batch_samples_trl"):
                DPOTrainer.get_batch_samples_trl = DPOTrainer.get_batch_samples
            DPOTrainer.get_batch_samples = HFTrainer.get_batch_samples
            print("[PATCH] Patched DPOTrainer.get_batch_samples -> HF Trainer.get_batch_samples")
    except Exception as exc:
        print("[PATCH] Could not inspect/patch DPOTrainer.get_batch_samples: {!r}".format(exc))

    try:
        if "num_items_in_batch" not in inspect.signature(DPOTrainer.compute_loss).parameters:
            if not hasattr(DPOTrainer, "compute_loss_trl"):
                DPOTrainer.compute_loss_trl = DPOTrainer.compute_loss
            orig_compute_loss = DPOTrainer.compute_loss_trl
            # Resolved once at patch time; the wrapper runs every training step.
            orig_takes_return_outputs = (
                "return_outputs" in inspect.signature(orig_compute_loss).parameters
            )

            def compute_loss_patched(self, model, inputs, return_outputs=False, num_items_in_batch=None, **kwargs):
                # num_items_in_batch / kwargs are accepted only for call compatibility;
                # the original TRL loss does not take them.
                if orig_takes_return_outputs:
                    return orig_compute_loss(self, model, inputs, return_outputs=return_outputs)
                loss = orig_compute_loss(self, model, inputs)
                return (loss, None) if return_outputs else loss

            DPOTrainer.compute_loss = compute_loss_patched
            print("[PATCH] Patched DPOTrainer.compute_loss to accept num_items_in_batch")
    except Exception as exc:
        print("[PATCH] Could not inspect/patch DPOTrainer.compute_loss: {!r}".format(exc))
def assert_is_model(obj, name="model"):
    """
    Duck-type check that *obj* looks like a HF causal-LM model.

    Args:
        obj: object to check; it passes if it has a ``generate`` attribute.
        name: label used in the error message.

    Raises:
        TypeError: when ``obj`` has no ``generate`` attribute.
    """
    if hasattr(obj, "generate"):
        return
    message = "[FATAL] {} is not a HF CausalLM model. type={}; missing .generate().".format(name, type(obj))
    raise TypeError(message)
# --------------------------
# SFT helpers
# --------------------------
def wrap_prompt(p: str) -> str:
    """
    Embed a raw prompt in the instruction template.

    A falsy prompt (``None`` or ``""``) becomes empty, and surrounding
    whitespace is stripped. The result is ``INSTR_PREFIX + prompt +
    RESP_PREFIX``, so it ends right where the response should start.
    """
    body = p.strip() if p else ""
    return INSTR_PREFIX + body + RESP_PREFIX
def ensure_sft_jsonl_exists(rank: int):
    """
    Create SFT_JSONL_PATH from PKU-SafeRLHF-30K if it is missing or empty.

    Rank 0 converts every source row into
    ``{"prompt": wrap_prompt(prompt), "response": <safer response>}``. The
    safer response is ``response_0`` or ``response_1`` as selected by
    ``safer_response_id``. Rows with an empty prompt or response, or a
    ``safer_response_id`` outside {0, 1}, are dropped. All ranks then call
    ddp_barrier() so that other ranks wait for rank 0 to finish.

    The file is first written to a temporary sibling and then moved into
    place with os.replace(). A crash mid-build therefore never leaves a
    truncated jsonl behind. The "exists and non-empty" check below cannot
    distinguish such a truncated file from a finished cache and would reuse
    it forever.

    Args:
        rank: global rank of this process; only rank 0 downloads and writes.
    """
    os.makedirs(os.path.dirname(SFT_JSONL_PATH) or ".", exist_ok=True)
    if os.path.exists(SFT_JSONL_PATH) and os.path.getsize(SFT_JSONL_PATH) > 0:
        if rank == 0:
            print("[INFO] Found existing SFT jsonl:", SFT_JSONL_PATH)
        ddp_barrier()
        return
    if rank == 0:
        print("[INFO] SFT jsonl not found. Building from {} ...".format(SFT_SRC_DATASET_ID))
        ds = load_dataset(SFT_SRC_DATASET_ID, split=SFT_SRC_SPLIT)
        rows = 0
        dropped = 0
        tmp_path = SFT_JSONL_PATH + ".tmp"
        with open(tmp_path, "w", encoding="utf-8") as f:
            for ex in ds:
                prompt = (ex.get("prompt") or "").strip()
                r0 = (ex.get("response_0") or "").strip()
                r1 = (ex.get("response_1") or "").strip()
                sid = ex.get("safer_response_id", None)
                if not prompt or not r0 or not r1 or sid not in (0, 1):
                    dropped += 1
                    continue
                safe = r0 if sid == 0 else r1
                obj = {"prompt": wrap_prompt(prompt), "response": safe}
                f.write(json.dumps(obj, ensure_ascii=False) + "\n")
                rows += 1
        # Same directory -> same filesystem, so the rename is atomic.
        os.replace(tmp_path, SFT_JSONL_PATH)
        print("[OK] Built SFT jsonl:", SFT_JSONL_PATH)
        print("[OK] rows={}, dropped={}".format(rows, dropped))
    ddp_barrier()
@dataclass
class DataCollatorForCausalLMWithLabels:
    """
    Pad pre-tokenized causal-LM features and align their labels.

    ``input_ids`` / ``attention_mask`` are padded by ``tokenizer.pad``.
    ``labels`` are padded with ``label_pad_token_id`` (default -100, the
    value the LM loss ignores) on the same side the tokenizer pads on, so
    each label stays aligned with its token. Previously labels were always
    right-padded, which silently misaligned them for left-padding tokenizers.
    Labels longer than the padded length are truncated to it.
    """

    tokenizer: Any
    # Fill value for padded label positions (ignored by the loss).
    label_pad_token_id: int = -100

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        pad_batch = self.tokenizer.pad(
            {
                "input_ids": [f["input_ids"] for f in features],
                "attention_mask": [f["attention_mask"] for f in features],
            },
            padding=True,
            return_tensors="pt",
        )
        max_len = pad_batch["input_ids"].shape[1]
        pad_left = getattr(self.tokenizer, "padding_side", "right") == "left"
        labels = []
        for f in features:
            # list() also accepts tuples / other sequences, not just lists.
            lab = list(f["labels"])[:max_len]
            fill = [self.label_pad_token_id] * (max_len - len(lab))
            labels.append(fill + lab if pad_left else lab + fill)
        pad_batch["labels"] = torch.tensor(labels, dtype=torch.long)
        return pad_batch
def build_quant_base(model_dir: str, device_map: Dict[str, int]):
    """
    Load the base causal LM from *model_dir* with 4-bit bitsandbytes weights.

    Quantization: NF4 with double quantization and float16 compute dtype.
    The model is placed according to *device_map*, and ``config.use_cache``
    is turned off because both stages train with gradient checkpointing and
    need no generation cache.

    Returns:
        The loaded (quantized) AutoModelForCausalLM instance.
    """
    quant_cfg = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.float16,
    )
    load_kwargs = dict(
        quantization_config=quant_cfg,
        device_map=device_map,
        torch_dtype=torch.float16,
    )
    model = AutoModelForCausalLM.from_pretrained(model_dir, **load_kwargs)
    model.config.use_cache = False
    return model
def run_sft(
    model_dir: str,
    out_dir: str,
    epochs: float,
    lr: float,
    bs: int,
    gas: int,
    seq: int,
    lora_r: int,
    device_map: Dict[str, int],
    rank: int,
    world_size: int,
) -> str:
    """
    Run SFT training and return the saved LoRA adapter path (out_dir).

    Steps:
    1. Build or load the prompt/response jsonl cache.
    2. Load a 4-bit base model and attach a LoRA adapter on q_proj/v_proj.
    3. Tokenize each pair with the prompt tokens masked out of the loss
       (label -100).
    4. Train with the HF Trainer.
    5. Rank 0 saves the adapter and tokenizer to ``out_dir``.

    Examples whose response is truncated away entirely (the prompt alone
    fills ``seq`` tokens) are dropped before training. They carry no
    supervised token, so they only waste compute and, depending on the
    transformers version's loss reduction, can produce a NaN loss.

    Args:
        model_dir: base model directory passed to from_pretrained().
        out_dir: directory for Trainer checkpoints and the final adapter.
        epochs: num_train_epochs.
        lr: learning_rate.
        bs: per_device_train_batch_size.
        gas: gradient_accumulation_steps.
        seq: maximum tokenized length (prompt + response) per example.
        lora_r: LoRA rank.
        device_map: from_pretrained device_map for this process.
        rank: global rank of this process.
        world_size: number of processes (used only for the step estimate).

    Returns:
        ``out_dir``.

    Raises:
        ValueError: if no example keeps at least one response token.
    """
    os.makedirs(out_dir, exist_ok=True)
    ensure_sft_jsonl_exists(rank)
    ds = load_dataset("json", data_files=SFT_JSONL_PATH, split="train")

    tokenizer = AutoTokenizer.from_pretrained(model_dir, use_fast=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

    base = build_quant_base(model_dir, device_map)
    base = prepare_model_for_kbit_training(base)
    lora = LoraConfig(
        r=lora_r,
        lora_alpha=32,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=["q_proj", "v_proj"],
    )
    model = get_peft_model(base, lora)

    def tokenize_and_mask(example):
        """
        Tokenize prompt and response, and build labels masked for the prompt part.
        """
        prompt = example["prompt"]
        response = example["response"]
        p = tokenizer(prompt, add_special_tokens=False)
        r = tokenizer(response + tokenizer.eos_token, add_special_tokens=False)
        input_ids = p["input_ids"] + r["input_ids"]
        attention_mask = [1] * len(input_ids)
        labels = [-100] * len(p["input_ids"]) + r["input_ids"]
        if len(input_ids) > seq:
            input_ids = input_ids[:seq]
            attention_mask = attention_mask[:seq]
            labels = labels[:seq]
        return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}

    ds_tok = ds.map(tokenize_and_mask, remove_columns=ds.column_names)
    # Keep only rows that still supervise at least one response token.
    n_mapped = len(ds_tok)
    ds_tok = ds_tok.filter(lambda ex: any(t != -100 for t in ex["labels"]))
    n_no_target = n_mapped - len(ds_tok)
    if len(ds_tok) == 0:
        raise ValueError(
            "[SFT] No trainable examples: every response was truncated away at seq={}".format(seq)
        )

    if rank == 0:
        total_examples = len(ds_tok)
        effective_bs = bs * world_size * gas
        steps_per_epoch = int(math.ceil(float(total_examples) / float(effective_bs)))
        print("[SFT] base={}".format(model_dir))
        print("[SFT] out={}".format(out_dir))
        print("[SFT] examples={}, world_size={}".format(total_examples, world_size))
        if n_no_target:
            print("[SFT] dropped {} examples with no response tokens within seq={}".format(n_no_target, seq))
        print("[SFT] per_device_bs={}, gas={}, effective_batch={}".format(bs, gas, effective_bs))
        print("[SFT] approx steps/epoch ≈ {}".format(steps_per_epoch))

    collator = DataCollatorForCausalLMWithLabels(tokenizer)
    train_args = TrainingArguments(
        output_dir=out_dir,
        num_train_epochs=epochs,
        learning_rate=lr,
        per_device_train_batch_size=bs,
        gradient_accumulation_steps=gas,
        logging_steps=10,
        save_steps=200,
        save_total_limit=2,
        fp16=True,
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={"use_reentrant": False},
        ddp_find_unused_parameters=False,
        report_to="none",
    )
    trainer = HFTrainer(
        model=model,
        args=train_args,
        train_dataset=ds_tok,
        data_collator=collator,
        tokenizer=tokenizer,
    )
    trainer.train()

    # Save only on rank0; other ranks wait at a barrier
    if rank == 0:
        model.save_pretrained(out_dir)
        tokenizer.save_pretrained(out_dir)
        print("[OK] SFT LoRA adapter saved to:", out_dir)
    ddp_barrier()
    return out_dir
# --------------------------
# DPO stage
# --------------------------
def run_dpo(
    model_dir: str,
    sft_adapter: str,
    dpo_data_dir: str,
    out_dir: str,
    epochs: float,
    lr: float,
    beta: float,
    bs: int,
    gas: int,
    max_prompt: int,
    max_len: int,
    device_map: Dict[str, int],
    rank: int,
):
    """
    Run DPO training starting from the base model + an SFT LoRA adapter.

    The SFT adapter is loaded twice onto the same 4-bit base model:

    * ``"default"`` (trainable) is the DPO policy that is optimized and
      saved by ``trainer.save_model(out_dir)``.
    * ``"reference"`` (frozen) is handed to TRL via ``ref_adapter_name`` as
      the DPO reference policy.

    Without a reference adapter, TRL uses the PEFT model with all adapters
    disabled as the reference. That is the raw base model, not the SFT model
    that DPO should be anchored to. If the installed TRL's DPOConfig has no
    ``model_adapter_name``/``ref_adapter_name`` fields, that legacy behavior
    is kept and a warning is printed on rank 0.

    NOTE: PEFT's save_pretrained typically also writes non-default adapters
    into subdirectories, so expect an extra ``out_dir/reference`` copy of the
    SFT adapter next to the trained one.

    Args:
        model_dir: base model directory.
        sft_adapter: directory of the SFT LoRA adapter to start from.
        dpo_data_dir: dataset directory loaded with datasets.load_from_disk().
        out_dir: directory for checkpoints and the final DPO adapter.
        epochs: num_train_epochs.
        lr: learning_rate.
        beta: DPO beta.
        bs: per_device_train_batch_size.
        gas: gradient_accumulation_steps.
        max_prompt: DPOConfig max_prompt_length.
        max_len: DPOConfig max_length.
        device_map: from_pretrained device_map for this process.
        rank: global rank (controls logging only).
    """
    import dataclasses

    patch_trl_dpo_trainer_if_needed()
    os.makedirs(out_dir, exist_ok=True)
    ds = load_from_disk(dpo_data_dir)

    tokenizer = AutoTokenizer.from_pretrained(model_dir, use_fast=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    base_model = build_quant_base(model_dir, device_map)
    assert_is_model(base_model, "base_model")

    # Trainable DPO policy: the SFT adapter under PEFT's "default" adapter name.
    model = PeftModel.from_pretrained(base_model, sft_adapter, is_trainable=True)
    assert_is_model(model, "model(peft)")

    try:
        dpo_config_fields = {f.name for f in dataclasses.fields(DPOConfig)}
    except TypeError:
        dpo_config_fields = set()
    adapter_kwargs: Dict[str, str] = {}
    if {"model_adapter_name", "ref_adapter_name"} <= dpo_config_fields:
        # Frozen second copy of the SFT adapter = DPO reference policy.
        model.load_adapter(sft_adapter, adapter_name="reference", is_trainable=False)
        adapter_kwargs = {"model_adapter_name": "default", "ref_adapter_name": "reference"}
    elif rank == 0:
        print(
            "[WARN] DPOConfig has no ref_adapter_name; the DPO reference will be the "
            "base model with adapters disabled, not the SFT model."
        )

    if rank == 0:
        print("[DPO] base:", model_dir)
        print("[DPO] sft_adapter:", sft_adapter)
        print("[DPO] dpo_data:", dpo_data_dir)
        print("[DPO] out:", out_dir)
        print(
            "[DPO] reference:",
            "SFT adapter (frozen copy)" if adapter_kwargs else "base model (adapters disabled)",
        )

    dpo_args = DPOConfig(
        output_dir=out_dir,
        num_train_epochs=epochs,
        learning_rate=lr,
        per_device_train_batch_size=bs,
        gradient_accumulation_steps=gas,
        logging_steps=10,
        save_steps=200,
        save_total_limit=2,
        fp16=True,
        report_to="none",
        beta=beta,
        max_prompt_length=max_prompt,
        max_length=max_len,
        gradient_checkpointing=True,
        ddp_find_unused_parameters=False,
        **adapter_kwargs,
    )
    trainer = DPOTrainer(
        model=model,
        args=dpo_args,
        train_dataset=ds,
        tokenizer=tokenizer,
    )
    trainer.train()
    trainer.save_model(out_dir)
    if rank == 0:
        print("[OK] DPO LoRA adapter saved to:", out_dir)
# --------------------------
# Main
# --------------------------
def _build_relaunch_argv(args: argparse.Namespace, stage: str) -> List[str]:
    """
    Serialize parsed options into the argv handed to the accelerate workers.

    The menu-selected *stage* replaces ``--stage``. ``--sft_adapter`` is
    forwarded only for stages that run DPO, falling back to ``--sft_out``.
    """
    forwarded = (
        "model",
        "sft_out", "sft_epochs", "sft_lr", "sft_bs", "sft_gas", "sft_seq", "lora_r",
        "dpo_data", "dpo_out", "dpo_epochs", "dpo_lr", "beta", "dpo_bs", "dpo_gas",
        "max_prompt", "max_len",
    )
    argv = ["--stage", stage]
    for name in forwarded:
        argv += ["--" + name, str(getattr(args, name))]
    if stage in ("dpo", "both"):
        argv += ["--sft_adapter", args.sft_adapter or args.sft_out]
    return argv


def main():
    """
    CLI entry point.

    First invocation (no LOCAL_RANK and AUTO_ACCELERATE_RELAUNCH != "1"):
    show the stage menu, then re-run this script under accelerate with
    ``--num_processes`` workers (default 2), forwarding every option. The
    menu choice overrides ``--stage``. The parent exits with the launcher's
    status.

    Inside each worker: pin the process to CUDA device LOCAL_RANK, seed
    torch, and run SFT and/or DPO. For ``both``, the SFT adapter is the DPO
    starting point. Garbage and cached CUDA blocks left by the SFT stage are
    freed before DPO loads its own base model. The default process group, if
    one was created, is destroyed at the end.
    """
    import gc

    ap = argparse.ArgumentParser(description="Unified SFT(QLoRA masked) + DPO training script")
    # Stage (menu will override this in the initial run)
    ap.add_argument("--stage", choices=["sft", "dpo", "both"], default="both")
    ap.add_argument("--model", default=DEFAULT_MODEL_DIR)
    ap.add_argument(
        "--num_processes",
        type=int,
        default=2,
        help="Number of accelerate worker processes (GPUs) to launch",
    )
    # SFT args
    ap.add_argument("--sft_out", default=DEFAULT_SFT_OUT)
    ap.add_argument("--sft_epochs", type=float, default=1.0)
    ap.add_argument("--sft_lr", type=float, default=2e-4)
    ap.add_argument("--sft_bs", type=int, default=1)
    ap.add_argument("--sft_gas", type=int, default=16)
    ap.add_argument("--sft_seq", type=int, default=1024)
    ap.add_argument("--lora_r", type=int, default=16)
    # DPO args
    ap.add_argument("--sft_adapter", default=None, help="If omitted, uses --sft_out")
    ap.add_argument("--dpo_data", default=DEFAULT_DPO_DATA)
    ap.add_argument("--dpo_out", default=DEFAULT_DPO_OUT)
    ap.add_argument("--dpo_epochs", type=float, default=1.0)
    ap.add_argument("--dpo_lr", type=float, default=5e-6)
    ap.add_argument("--beta", type=float, default=0.1)
    ap.add_argument("--dpo_bs", type=int, default=1)
    ap.add_argument("--dpo_gas", type=int, default=16)
    ap.add_argument("--max_prompt", type=int, default=512)
    ap.add_argument("--max_len", type=int, default=1024)
    args = ap.parse_args()
    if args.num_processes < 1:
        ap.error("--num_processes must be >= 1")

    # Show menu only in the initial, non-accelerate run
    if os.environ.get("LOCAL_RANK") is None and os.environ.get("AUTO_ACCELERATE_RELAUNCH") != "1":
        os.environ["AUTO_ACCELERATE_RELAUNCH"] = "1"  # prevent menu duplication
        stage = choose_stage_simple()
        relaunch_with_accelerate(
            num_processes=args.num_processes,
            extra_argv=_build_relaunch_argv(args, stage),
        )

    # --------------------------
    # Running inside accelerate
    # --------------------------
    local_rank, rank, world_size = ddp_info()
    torch.cuda.set_device(local_rank)
    device_map = {"": torch.cuda.current_device()}
    # Fixed seed
    torch.manual_seed(42)

    sft_adapter_path = args.sft_adapter or args.sft_out
    if args.stage in ("sft", "both"):
        sft_adapter_path = run_sft(
            model_dir=args.model,
            out_dir=args.sft_out,
            epochs=args.sft_epochs,
            lr=args.sft_lr,
            bs=args.sft_bs,
            gas=args.sft_gas,
            seq=args.sft_seq,
            lora_r=args.lora_r,
            device_map=device_map,
            rank=rank,
            world_size=world_size,
        )
        if args.stage == "both":
            # Drop the SFT model/trainer (incl. reference cycles) and return cached
            # CUDA blocks before DPO loads a second 4-bit base on the same device.
            gc.collect()
            torch.cuda.empty_cache()

    if args.stage in ("dpo", "both"):
        run_dpo(
            model_dir=args.model,
            sft_adapter=sft_adapter_path,
            dpo_data_dir=args.dpo_data,
            out_dir=args.dpo_out,
            epochs=args.dpo_epochs,
            lr=args.dpo_lr,
            beta=args.beta,
            bs=args.dpo_bs,
            gas=args.dpo_gas,
            max_prompt=args.max_prompt,
            max_len=args.max_len,
            device_map=device_map,
            rank=rank,
        )

    # Destroy the default process group (if training created one) before exit.
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        torch.distributed.destroy_process_group()
if __name__ == "__main__":
    try:
        main()
    except Exception:
        # Tag the traceback with this process's ranks so interleaved logs from
        # several workers can be attributed, then re-raise so the process still
        # exits with an error.
        rank_env = os.environ.get("RANK", "?")
        local_rank_env = os.environ.get("LOCAL_RANK", "?")
        print(f"\n[FATAL] Exception on RANK={rank_env} LOCAL_RANK={local_rank_env}\n")
        print(traceback.format_exc())
        raise