#!/usr/bin/env python3
|
|
# -*- coding: utf-8 -*-
|
|
|
|
"""
|
|
Unified training script for:
|
|
- SFT (QLoRA + masked loss)
|
|
- DPO (TRL DPOTrainer)
|
|
|
|
Execution model:
|
|
- Run the script normally: it shows a simple menu (SFT / DPO / BOTH).
|
|
- The script relaunches itself via accelerate (multi-process).
|
|
- After relaunch, each process runs with LOCAL_RANK set by accelerate.
|
|
"""
|
|
|
|
import os
|
|
import sys
|
|
import json
|
|
import math
|
|
import argparse
|
|
import subprocess
|
|
import inspect
|
|
import traceback
|
|
from dataclasses import dataclass
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
import torch
|
|
|
|
from datasets import load_dataset, load_from_disk
|
|
from transformers import (
|
|
AutoTokenizer,
|
|
AutoModelForCausalLM,
|
|
BitsAndBytesConfig,
|
|
Trainer as HFTrainer,
|
|
TrainingArguments,
|
|
)
|
|
from peft import (
|
|
LoraConfig,
|
|
get_peft_model,
|
|
prepare_model_for_kbit_training,
|
|
PeftModel,
|
|
)
|
|
from trl import DPOTrainer, DPOConfig
|
|
|
|
|
|
# --------------------------
# Defaults
# --------------------------
# Local checkpoint of the base model (Llama 3.1 8B).
DEFAULT_MODEL_DIR = "/home/hyrenko/Diploma/models/llama3.1-8b"

# SFT: source -> jsonl cache
# HF Hub dataset the SFT corpus is derived from, and where the derived
# jsonl is cached so the download/convert step runs only once.
SFT_SRC_DATASET_ID = "PKU-Alignment/PKU-SafeRLHF-30K"
SFT_SRC_SPLIT = "train"
SFT_JSONL_PATH = "./data/pku_sft.jsonl"

# Outputs
# Adapter output directories for each stage, plus the on-disk DPO dataset.
DEFAULT_SFT_OUT = "./out/llama3_1_8b_sft_safety_lora_masked"
DEFAULT_DPO_OUT = "./out/llama3_1_8b_dpo_safety_lora"
DEFAULT_DPO_DATA = "./data/pku_dpo"

# Prompt wrappers
# Alpaca-style instruction/response markers used by wrap_prompt().
INSTR_PREFIX = "### Instruction:\n"
RESP_PREFIX = "\n\n### Response:\n"
|
|
|
|
|
|
# --------------------------
|
|
# Interactive menu (shown only before accelerate spawns processes)
|
|
# --------------------------
|
|
def choose_stage_simple() -> str:
    """
    Interactive stage picker shown before the accelerate relaunch.

    Prompts on stdin until the user types 1, 2, or 3 and returns the
    matching stage name: 'sft' | 'dpo' | 'both'.
    """
    print("\n=== TRAIN MENU ===")
    print("1) SFT (QLoRA masked-loss)")
    print("2) DPO (over SFT adapter)")
    print("3) BOTH (SFT -> DPO)")
    choices = {"1": "sft", "2": "dpo", "3": "both"}
    while True:
        picked = input("Choose 1/2/3: ").strip()
        if picked in choices:
            return choices[picked]
        print("Only 1, 2, or 3.")
|
|
|
|
|
|
def relaunch_with_accelerate(num_processes: int, extra_argv: List[str]):
    """
    Re-exec this script under `accelerate launch` and exit with the
    child's return code.

    No-op when already running inside an accelerate-launched process
    (detected via the LOCAL_RANK environment variable).
    """
    if os.environ.get("LOCAL_RANK") is not None:
        return

    cmd = [sys.executable, "-m", "accelerate.commands.launch"]
    cmd += ["--num_processes", str(num_processes)]
    cmd += ["--mixed_precision", "fp16"]
    # NOTE(review): port "0" presumably means "auto-pick a free port" —
    # confirm the installed accelerate accepts it.
    cmd += ["--main_process_port", "0"]
    cmd.append(sys.argv[0])
    cmd.extend(extra_argv)

    print("\n[INFO] Relaunching with accelerate:\n" + " ".join(cmd) + "\n")
    raise SystemExit(subprocess.call(cmd))
|
|
|
|
|
|
# --------------------------
|
|
# DDP helpers
|
|
# --------------------------
|
|
def ddp_info():
    """
    Read this process's distributed identity from the environment.

    Returns:
        (local_rank, rank, world_size) as ints. Defaults to (0, 0, 1)
        when the variables are absent, i.e. a plain single-process run.
    """
    env = os.environ
    return (
        int(env.get("LOCAL_RANK", "0")),
        int(env.get("RANK", "0")),
        int(env.get("WORLD_SIZE", "1")),
    )
|
|
|
|
|
|
def ddp_barrier():
    """
    Block until all distributed processes reach this point.

    Safe to call in single-process runs: it is a no-op when
    torch.distributed is unavailable or not initialized.
    """
    dist = torch.distributed
    if dist.is_available() and dist.is_initialized():
        dist.barrier()
|
|
|
|
|
|
# --------------------------
|
|
# TRL <-> Transformers patches (signature compatibility)
|
|
# --------------------------
|
|
def patch_trl_dpo_trainer_if_needed():
    """
    Best-effort monkey-patches bridging signature differences between the
    installed TRL DPOTrainer and the installed transformers Trainer.

    Both patches are idempotent: the original method is stashed once under
    a ``*_trl`` attribute, so repeated calls are safe. Failures are logged
    and swallowed — training proceeds unpatched.
    """
    # Patch get_batch_samples to avoid signature mismatches across versions
    try:
        # Parameter names of the installed TRL implementation.
        sig = list(inspect.signature(DPOTrainer.get_batch_samples).parameters.keys())
        # TRL variants whose second parameter is `model` clash with the HF
        # Trainer calling convention; swap in the HF implementation.
        if len(sig) >= 3 and sig[1] == "model":
            if not hasattr(DPOTrainer, "get_batch_samples_trl"):
                # Keep the original reachable under a different name.
                DPOTrainer.get_batch_samples_trl = DPOTrainer.get_batch_samples
            DPOTrainer.get_batch_samples = HFTrainer.get_batch_samples
            print("[PATCH] Patched DPOTrainer.get_batch_samples -> HF Trainer.get_batch_samples")
    except Exception:
        print("[PATCH] Could not inspect/patch DPOTrainer.get_batch_samples")

    # Patch compute_loss to accept num_items_in_batch if needed
    try:
        sig = inspect.signature(DPOTrainer.compute_loss)
        if "num_items_in_batch" not in sig.parameters:
            if not hasattr(DPOTrainer, "compute_loss_trl"):
                DPOTrainer.compute_loss_trl = DPOTrainer.compute_loss

            def compute_loss_patched(self, model, inputs, return_outputs=False, num_items_in_batch=None, **kwargs):
                # Shim: newer transformers passes num_items_in_batch, which
                # the older TRL compute_loss does not accept — drop it and
                # forward only the arguments the original understands.
                orig = self.compute_loss_trl
                osig = inspect.signature(orig)
                if "return_outputs" in osig.parameters:
                    return orig(model, inputs, return_outputs=return_outputs)
                out = orig(model, inputs)
                if return_outputs:
                    # Original returned only a loss; pad with None outputs
                    # to satisfy the (loss, outputs) contract.
                    return out, None
                return out

            DPOTrainer.compute_loss = compute_loss_patched
            print("[PATCH] Patched DPOTrainer.compute_loss to accept num_items_in_batch")
    except Exception:
        print("[PATCH] Could not inspect/patch DPOTrainer.compute_loss")
|
|
|
|
|
|
def assert_is_model(obj, name="model"):
    """
    Duck-type sanity check that *obj* looks like a HF causal-LM model.

    Raises:
        TypeError: when the object lacks a ``.generate`` attribute.
    """
    if hasattr(obj, "generate"):
        return
    raise TypeError(
        "[FATAL] {} is not a HF CausalLM model. type={}; missing .generate().".format(name, type(obj))
    )
|
|
|
|
|
|
# --------------------------
|
|
# SFT helpers
|
|
# --------------------------
|
|
def wrap_prompt(p: str) -> str:
    """
    Embed a raw prompt between the instruction/response template markers.

    None or empty input yields an empty instruction body.
    """
    body = (p or "").strip()
    return INSTR_PREFIX + body + RESP_PREFIX
|
|
|
|
|
|
def ensure_sft_jsonl_exists(rank: int):
    """
    Build ./data/pku_sft.jsonl from PKU-SafeRLHF-30K when it is missing.

    Only rank 0 downloads and writes; every rank then meets at a barrier,
    so no process reads the file before it exists. Each output row pairs
    the template-wrapped prompt with the annotated safer response.
    """
    os.makedirs("./data", exist_ok=True)

    already_built = os.path.exists(SFT_JSONL_PATH) and os.path.getsize(SFT_JSONL_PATH) > 0
    if already_built:
        if rank == 0:
            print("[INFO] Found existing SFT jsonl:", SFT_JSONL_PATH)
        ddp_barrier()
        return

    if rank == 0:
        print("[INFO] SFT jsonl not found. Building from {} ...".format(SFT_SRC_DATASET_ID))
        ds = load_dataset(SFT_SRC_DATASET_ID, split=SFT_SRC_SPLIT)

        kept = 0
        dropped = 0
        with open(SFT_JSONL_PATH, "w", encoding="utf-8") as sink:
            for ex in ds:
                prompt = (ex.get("prompt") or "").strip()
                r0 = (ex.get("response_0") or "").strip()
                r1 = (ex.get("response_1") or "").strip()
                sid = ex.get("safer_response_id", None)

                # Skip incomplete rows and rows without a valid safety label.
                if not (prompt and r0 and r1) or sid not in (0, 1):
                    dropped += 1
                    continue

                safer = r0 if sid == 0 else r1
                record = {"prompt": wrap_prompt(prompt), "response": safer}
                sink.write(json.dumps(record, ensure_ascii=False) + "\n")
                kept += 1

        print("[OK] Built SFT jsonl:", SFT_JSONL_PATH)
        print("[OK] rows={}, dropped={}".format(kept, dropped))

    ddp_barrier()
|
|
|
|
|
|
@dataclass
class DataCollatorForCausalLMWithLabels:
    """
    Collate tokenized examples into a padded batch with aligned labels.

    input_ids/attention_mask are padded via the tokenizer; labels are
    padded (or truncated) to the same width with -100, the value ignored
    by the standard LM cross-entropy loss.
    """
    # Any object exposing a HF-style .pad() method.
    tokenizer: Any

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        batch = self.tokenizer.pad(
            {
                "input_ids": [feat["input_ids"] for feat in features],
                "attention_mask": [feat["attention_mask"] for feat in features],
            },
            padding=True,
            return_tensors="pt",
        )

        # Pad each label sequence with -100 up to the padded width, then
        # clip — the slice also truncates labels longer than the width.
        target_len = batch["input_ids"].shape[1]
        padded_labels = [
            (feat["labels"] + [-100] * (target_len - len(feat["labels"])))[:target_len]
            for feat in features
        ]
        batch["labels"] = torch.tensor(padded_labels, dtype=torch.long)
        return batch
|
|
|
|
|
|
def build_quant_base(model_dir: str, device_map: Dict[str, int]):
    """
    Load the base causal LM 4-bit quantized (NF4, double quantization,
    fp16 compute) and place it according to *device_map*.

    The KV cache is disabled because the model is about to be trained
    (required with gradient checkpointing).
    """
    quant_cfg = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_dir,
        quantization_config=quant_cfg,
        device_map=device_map,
        torch_dtype=torch.float16,
    )
    model.config.use_cache = False
    return model
|
|
|
|
|
|
def run_sft(
    model_dir: str,
    out_dir: str,
    epochs: float,
    lr: float,
    bs: int,
    gas: int,
    seq: int,
    lora_r: int,
    device_map: Dict[str, int],
    rank: int,
    world_size: int,
) -> str:
    """
    Run masked-loss QLoRA SFT and return the saved adapter path (out_dir).

    Args:
        model_dir: base model checkpoint directory.
        out_dir: where the LoRA adapter and tokenizer are saved.
        epochs, lr, bs, gas, seq: epochs, learning rate, per-device batch
            size, gradient-accumulation steps, max sequence length.
        lora_r: LoRA rank.
        device_map: HF device map pinning the model to this process's GPU.
        rank, world_size: DDP identity used for rank-0-only work.
    """
    os.makedirs(out_dir, exist_ok=True)

    # Rank 0 builds the jsonl cache if missing; all ranks sync after.
    ensure_sft_jsonl_exists(rank)

    ds = load_dataset("json", data_files=SFT_JSONL_PATH, split="train")

    tokenizer = AutoTokenizer.from_pretrained(model_dir, use_fast=True)
    if tokenizer.pad_token is None:
        # Reuse EOS as pad when the tokenizer ships without a pad token.
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

    # 4-bit base + k-bit training preparation (PEFT).
    base = build_quant_base(model_dir, device_map)
    base = prepare_model_for_kbit_training(base)

    lora = LoraConfig(
        r=lora_r,
        lora_alpha=32,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
        # Attention query/value projections only.
        target_modules=["q_proj", "v_proj"],
    )
    model = get_peft_model(base, lora)

    def tokenize_and_mask(example):
        """
        Tokenize prompt + response and build labels with the prompt part
        masked to -100 so the loss covers only response tokens.
        """
        prompt = example["prompt"]
        response = example["response"]

        p = tokenizer(prompt, add_special_tokens=False)
        # EOS appended so the model learns where to stop.
        r = tokenizer(response + tokenizer.eos_token, add_special_tokens=False)

        input_ids = p["input_ids"] + r["input_ids"]
        attention_mask = [1] * len(input_ids)
        labels = [-100] * len(p["input_ids"]) + r["input_ids"]

        # Hard truncation to `seq` tokens; over-long responses lose their tail.
        if len(input_ids) > seq:
            input_ids = input_ids[:seq]
            attention_mask = attention_mask[:seq]
            labels = labels[:seq]

        return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}

    ds_tok = ds.map(tokenize_and_mask, remove_columns=ds.column_names)

    if rank == 0:
        # Informational only: rough optimizer steps per epoch.
        total_examples = len(ds_tok)
        effective_bs = bs * world_size * gas
        steps_per_epoch = int(math.ceil(float(total_examples) / float(effective_bs)))
        print("[SFT] base={}".format(model_dir))
        print("[SFT] out={}".format(out_dir))
        print("[SFT] examples={}, world_size={}".format(total_examples, world_size))
        print("[SFT] per_device_bs={}, gas={}, effective_batch={}".format(bs, gas, effective_bs))
        print("[SFT] approx steps/epoch ≈ {}".format(steps_per_epoch))

    collator = DataCollatorForCausalLMWithLabels(tokenizer)

    train_args = TrainingArguments(
        output_dir=out_dir,
        num_train_epochs=epochs,
        learning_rate=lr,
        per_device_train_batch_size=bs,
        gradient_accumulation_steps=gas,
        logging_steps=10,
        save_steps=200,
        save_total_limit=2,
        fp16=True,
        gradient_checkpointing=True,
        # Non-reentrant checkpointing — required for stable DDP + PEFT.
        gradient_checkpointing_kwargs={"use_reentrant": False},
        ddp_find_unused_parameters=False,
        report_to="none",
    )

    # NOTE(review): `tokenizer=` is the pre-4.46 transformers Trainer API
    # (newer versions prefer `processing_class=`) — assumed to match the
    # pinned transformers here.
    trainer = HFTrainer(
        model=model,
        args=train_args,
        train_dataset=ds_tok,
        data_collator=collator,
        tokenizer=tokenizer,
    )

    trainer.train()

    # Save only on rank0; other ranks wait at a barrier
    if rank == 0:
        model.save_pretrained(out_dir)
        tokenizer.save_pretrained(out_dir)
        print("[OK] SFT LoRA adapter saved to:", out_dir)
    ddp_barrier()

    return out_dir
|
|
|
|
|
|
# --------------------------
|
|
# DPO stage
|
|
# --------------------------
|
|
def run_dpo(
    model_dir: str,
    sft_adapter: str,
    dpo_data_dir: str,
    out_dir: str,
    epochs: float,
    lr: float,
    beta: float,
    bs: int,
    gas: int,
    max_prompt: int,
    max_len: int,
    device_map: Dict[str, int],
    rank: int,
):
    """
    Run DPO training starting from the base model + an SFT LoRA adapter.

    Args:
        model_dir: base model checkpoint directory.
        sft_adapter: path to the SFT LoRA adapter to continue training.
        dpo_data_dir: datasets `save_to_disk` directory with preference pairs.
        out_dir: output directory for the DPO-tuned adapter.
        epochs, lr, bs, gas: usual training hyperparameters.
        beta: DPO temperature coefficient.
        max_prompt, max_len: truncation limits for prompt / full sequence.
        device_map: HF device map pinning the model to this process's GPU.
        rank: DDP rank; rank 0 handles logging.
    """
    # Apply TRL<->transformers signature shims before building the trainer.
    patch_trl_dpo_trainer_if_needed()
    os.makedirs(out_dir, exist_ok=True)

    ds = load_from_disk(dpo_data_dir)

    tokenizer = AutoTokenizer.from_pretrained(model_dir, use_fast=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    base_model = build_quant_base(model_dir, device_map)
    assert_is_model(base_model, "base_model")

    # Load SFT adapter as trainable and continue training with DPO
    model = PeftModel.from_pretrained(base_model, sft_adapter, is_trainable=True)
    assert_is_model(model, "model(peft)")

    if rank == 0:
        print("[DPO] base:", model_dir)
        print("[DPO] sft_adapter:", sft_adapter)
        print("[DPO] dpo_data:", dpo_data_dir)
        print("[DPO] out:", out_dir)

    dpo_args = DPOConfig(
        output_dir=out_dir,
        num_train_epochs=epochs,
        learning_rate=lr,
        per_device_train_batch_size=bs,
        gradient_accumulation_steps=gas,
        logging_steps=10,
        save_steps=200,
        save_total_limit=2,
        fp16=True,
        report_to="none",
        beta=beta,
        max_prompt_length=max_prompt,
        max_length=max_len,
        gradient_checkpointing=True,
        ddp_find_unused_parameters=False,
    )

    # NOTE(review): no explicit ref_model is passed — TRL presumably builds
    # the reference implicitly for PEFT models by disabling the adapter;
    # confirm for the installed TRL version. `tokenizer=` is the pre-0.12
    # TRL API (newer versions use `processing_class=`).
    trainer = DPOTrainer(
        model=model,
        args=dpo_args,
        train_dataset=ds,
        tokenizer=tokenizer,
    )

    trainer.train()
    # Called on every rank; Trainer.save_model is presumed main-process
    # aware — verify for the installed transformers.
    trainer.save_model(out_dir)

    if rank == 0:
        print("[OK] DPO LoRA adapter saved to:", out_dir)
|
|
|
|
|
|
# --------------------------
|
|
# Main
|
|
# --------------------------
|
|
def main():
    """
    Entry point.

    First (non-accelerate) invocation: parse args, show the stage menu,
    serialize the resolved configuration back into argv, and relaunch the
    script under accelerate. Accelerate-spawned invocations: skip the menu
    and run the selected stage(s) on this process's GPU.
    """
    ap = argparse.ArgumentParser(description="Unified SFT(QLoRA masked) + DPO training script")

    # Stage (menu will override this in the initial run)
    ap.add_argument("--stage", choices=["sft", "dpo", "both"], default="both")

    ap.add_argument("--model", default=DEFAULT_MODEL_DIR)

    # SFT args
    ap.add_argument("--sft_out", default=DEFAULT_SFT_OUT)
    ap.add_argument("--sft_epochs", type=float, default=1.0)
    ap.add_argument("--sft_lr", type=float, default=2e-4)
    ap.add_argument("--sft_bs", type=int, default=1)
    ap.add_argument("--sft_gas", type=int, default=16)
    ap.add_argument("--sft_seq", type=int, default=1024)
    ap.add_argument("--lora_r", type=int, default=16)

    # DPO args
    ap.add_argument("--sft_adapter", default=None, help="If omitted, uses --sft_out")
    ap.add_argument("--dpo_data", default=DEFAULT_DPO_DATA)
    ap.add_argument("--dpo_out", default=DEFAULT_DPO_OUT)
    ap.add_argument("--dpo_epochs", type=float, default=1.0)
    ap.add_argument("--dpo_lr", type=float, default=5e-6)
    ap.add_argument("--beta", type=float, default=0.1)
    ap.add_argument("--dpo_bs", type=int, default=1)
    ap.add_argument("--dpo_gas", type=int, default=16)
    ap.add_argument("--max_prompt", type=int, default=512)
    ap.add_argument("--max_len", type=int, default=1024)

    args = ap.parse_args()

    # Show menu only in the initial, non-accelerate run
    if os.environ.get("LOCAL_RANK") is None and os.environ.get("AUTO_ACCELERATE_RELAUNCH") != "1":
        # The env var is inherited by the accelerate child processes, so
        # they skip this branch even before LOCAL_RANK is considered.
        os.environ["AUTO_ACCELERATE_RELAUNCH"] = "1"  # prevent menu duplication

        stage = choose_stage_simple()

        # Build argv using defaults from args
        # (re-serialize every option so the relaunched processes see the
        # exact same configuration plus the menu-chosen stage).
        argv = [
            "--stage",
            stage,
            "--model",
            args.model,

            "--sft_out",
            args.sft_out,
            "--sft_epochs",
            str(args.sft_epochs),
            "--sft_lr",
            str(args.sft_lr),
            "--sft_bs",
            str(args.sft_bs),
            "--sft_gas",
            str(args.sft_gas),
            "--sft_seq",
            str(args.sft_seq),
            "--lora_r",
            str(args.lora_r),

            "--dpo_data",
            args.dpo_data,
            "--dpo_out",
            args.dpo_out,
            "--dpo_epochs",
            str(args.dpo_epochs),
            "--dpo_lr",
            str(args.dpo_lr),
            "--beta",
            str(args.beta),
            "--dpo_bs",
            str(args.dpo_bs),
            "--dpo_gas",
            str(args.dpo_gas),
            "--max_prompt",
            str(args.max_prompt),
            "--max_len",
            str(args.max_len),
        ]

        if stage in ("dpo", "both"):
            # The DPO stage continues from an SFT adapter; default to the
            # SFT output directory when no explicit adapter was given.
            sft_adapter = args.sft_adapter or args.sft_out
            argv += ["--sft_adapter", sft_adapter]

        # Relaunch under accelerate (2 processes)
        relaunch_with_accelerate(num_processes=2, extra_argv=argv)

    # --------------------------
    # Running inside accelerate
    # --------------------------
    local_rank, rank, world_size = ddp_info()
    # Pin this process to its assigned GPU; the device_map places the
    # whole model on that single device.
    torch.cuda.set_device(local_rank)
    device_map = {"": torch.cuda.current_device()}

    # Fixed seed
    torch.manual_seed(42)

    sft_adapter_path = args.sft_adapter or args.sft_out

    if args.stage in ("sft", "both"):
        # run_sft returns the adapter path, which feeds the DPO stage below.
        sft_adapter_path = run_sft(
            model_dir=args.model,
            out_dir=args.sft_out,
            epochs=args.sft_epochs,
            lr=args.sft_lr,
            bs=args.sft_bs,
            gas=args.sft_gas,
            seq=args.sft_seq,
            lora_r=args.lora_r,
            device_map=device_map,
            rank=rank,
            world_size=world_size,
        )

    if args.stage in ("dpo", "both"):
        run_dpo(
            model_dir=args.model,
            sft_adapter=sft_adapter_path,
            dpo_data_dir=args.dpo_data,
            out_dir=args.dpo_out,
            epochs=args.dpo_epochs,
            lr=args.dpo_lr,
            beta=args.beta,
            bs=args.dpo_bs,
            gas=args.dpo_gas,
            max_prompt=args.max_prompt,
            max_len=args.max_len,
            device_map=device_map,
            rank=rank,
        )
|
|
|
|
|
|
if __name__ == "__main__":
    try:
        main()
    except Exception:
        # Tag the failing process so multi-rank tracebacks are attributable,
        # then re-raise so the launcher sees a non-zero exit code.
        rank_env = os.environ.get("RANK", "?")
        local_rank_env = os.environ.get("LOCAL_RANK", "?")
        print("\n[FATAL] Exception on RANK={} LOCAL_RANK={}\n".format(rank_env, local_rank_env))
        print(traceback.format_exc())
        raise
|