Added comments

This commit is contained in:
Artur Hyrenko 2026-02-04 20:38:37 +00:00
parent 18cdd5daa6
commit 7bb03b551f

View File

@ -2,18 +2,14 @@
# -*- coding: utf-8 -*-
"""
Unified training script: SFT (QLoRA + masked loss) and DPO (TRL) in ONE file.
Unified training script for:
- SFT (QLoRA + masked loss)
- DPO (TRL DPOTrainer)
UX requirement (as you asked):
- You run: python3 train_unified.py
- It shows a simple menu 1/2/3
- You press 1, 2, or 3
- Script relaunches itself with accelerate (2 GPUs) and runs the chosen stage
using DEFAULT paths/hparams already embedded.
Notes:
- Menu is shown ONLY once (before accelerate spawns processes).
- After relaunch, each process runs normally with LOCAL_RANK set by accelerate.
Execution model:
- Run the script normally: it shows a simple menu (SFT / DPO / BOTH).
- The script relaunches itself via accelerate (multi-process).
- After relaunch, each process runs with LOCAL_RANK set by accelerate.
"""
import os
@ -47,7 +43,7 @@ from trl import DPOTrainer, DPOConfig
# --------------------------
# Defaults (same spirit as your original scripts)
# Defaults
# --------------------------
DEFAULT_MODEL_DIR = "/home/hyrenko/Diploma/models/llama3.1-8b"
@ -61,17 +57,17 @@ DEFAULT_SFT_OUT = "./out/llama3_1_8b_sft_safety_lora_masked"
DEFAULT_DPO_OUT = "./out/llama3_1_8b_dpo_safety_lora"
DEFAULT_DPO_DATA = "./data/pku_dpo"
# Prompt wrappers (keep consistent)
# Prompt wrappers
INSTR_PREFIX = "### Instruction:\n"
RESP_PREFIX = "\n\n### Response:\n"
# --------------------------
# Simple interactive menu (shown only before accelerate)
# Interactive menu (shown only before accelerate spawns processes)
# --------------------------
def choose_stage_simple() -> str:
"""
Returns: 'sft' | 'dpo' | 'both'
Returns one of: 'sft' | 'dpo' | 'both'
"""
print("\n=== TRAIN MENU ===")
print("1) SFT (QLoRA masked-loss)")
@ -90,8 +86,8 @@ def choose_stage_simple() -> str:
def relaunch_with_accelerate(num_processes: int, extra_argv: List[str]):
"""
Relaunches the script via accelerate (2 processes by default).
Only runs if NOT already inside accelerate (LOCAL_RANK not set).
Relaunch the script via accelerate with the requested number of processes.
No-op if already running inside accelerate (LOCAL_RANK is set).
"""
if os.environ.get("LOCAL_RANK") is not None:
return
@ -109,7 +105,7 @@ def relaunch_with_accelerate(num_processes: int, extra_argv: List[str]):
sys.argv[0],
*extra_argv,
]
print("\n[AUTO] Relaunching with accelerate:\n" + " ".join(cmd) + "\n")
print("\n[INFO] Relaunching with accelerate:\n" + " ".join(cmd) + "\n")
raise SystemExit(subprocess.call(cmd))
@ -117,6 +113,9 @@ def relaunch_with_accelerate(num_processes: int, extra_argv: List[str]):
# DDP helpers
# --------------------------
def ddp_info():
"""
Returns (local_rank, rank, world_size) inferred from environment variables.
"""
local_rank = int(os.environ.get("LOCAL_RANK", "0"))
rank = int(os.environ.get("RANK", "0"))
world_size = int(os.environ.get("WORLD_SIZE", "1"))
@ -124,15 +123,18 @@ def ddp_info():
def ddp_barrier():
"""
Synchronization barrier for distributed runs (no-op if torch.distributed is not initialized).
"""
if torch.distributed.is_available() and torch.distributed.is_initialized():
torch.distributed.barrier()
# --------------------------
# TRL <-> Transformers patches (to avoid signature mismatches)
# TRL <-> Transformers patches (signature compatibility)
# --------------------------
def patch_trl_dpo_trainer_if_needed():
# Patch get_batch_samples
# Patch get_batch_samples to avoid signature mismatches across versions
try:
sig = list(inspect.signature(DPOTrainer.get_batch_samples).parameters.keys())
if len(sig) >= 3 and sig[1] == "model":
@ -143,7 +145,7 @@ def patch_trl_dpo_trainer_if_needed():
except Exception:
print("[PATCH] Could not inspect/patch DPOTrainer.get_batch_samples")
# Patch compute_loss
# Patch compute_loss to accept num_items_in_batch if needed
try:
sig = inspect.signature(DPOTrainer.compute_loss)
if "num_items_in_batch" not in sig.parameters:
@ -167,6 +169,9 @@ def patch_trl_dpo_trainer_if_needed():
def assert_is_model(obj, name="model"):
"""
Minimal runtime check: object should look like a HF CausalLM model.
"""
if not hasattr(obj, "generate"):
raise TypeError(
"[FATAL] {} is not a HF CausalLM model. type={}; missing .generate().".format(name, type(obj))
@ -177,14 +182,17 @@ def assert_is_model(obj, name="model"):
# SFT helpers
# --------------------------
def wrap_prompt(p: str) -> str:
"""
Wrap a raw prompt into the instruction/response template prefix.
"""
p = (p or "").strip()
return "{}{}{}".format(INSTR_PREFIX, p, RESP_PREFIX)
def ensure_sft_jsonl_exists(rank: int):
"""
Creates ./data/pku_sft.jsonl from PKU-SafeRLHF-30K if missing.
Only rank0 writes; others wait.
Create ./data/pku_sft.jsonl from PKU-SafeRLHF-30K if missing.
Only rank0 writes; all other ranks wait at a barrier.
"""
os.makedirs("./data", exist_ok=True)
@ -225,6 +233,10 @@ def ensure_sft_jsonl_exists(rank: int):
@dataclass
class DataCollatorForCausalLMWithLabels:
"""
Pads input_ids/attention_mask and aligns labels to the padded length.
Labels use -100 for ignored positions (standard LM loss masking).
"""
tokenizer: Any
def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
@ -252,6 +264,9 @@ class DataCollatorForCausalLMWithLabels:
def build_quant_base(model_dir: str, device_map: Dict[str, int]):
"""
Load a 4-bit quantized base model using bitsandbytes config.
"""
bnb = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
@ -282,7 +297,7 @@ def run_sft(
world_size: int,
) -> str:
"""
Runs SFT and returns the path to the saved LoRA adapter (out_dir).
Run SFT training and return the saved LoRA adapter path (out_dir).
"""
os.makedirs(out_dir, exist_ok=True)
@ -309,6 +324,9 @@ def run_sft(
model = get_peft_model(base, lora)
def tokenize_and_mask(example):
"""
Tokenize prompt and response, and build labels masked for the prompt part.
"""
prompt = example["prompt"]
response = example["response"]
@ -366,7 +384,7 @@ def run_sft(
trainer.train()
# Save only on rank0, then barrier so others wait
# Save only on rank0; other ranks wait at a barrier
if rank == 0:
model.save_pretrained(out_dir)
tokenizer.save_pretrained(out_dir)
@ -394,6 +412,9 @@ def run_dpo(
device_map: Dict[str, int],
rank: int,
):
"""
Run DPO training starting from the base model + an SFT LoRA adapter.
"""
patch_trl_dpo_trainer_if_needed()
os.makedirs(out_dir, exist_ok=True)
@ -406,7 +427,7 @@ def run_dpo(
base_model = build_quant_base(model_dir, device_map)
assert_is_model(base_model, "base_model")
# Load SFT adapter and keep training it with DPO
# Load SFT adapter as trainable and continue training with DPO
model = PeftModel.from_pretrained(base_model, sft_adapter, is_trainable=True)
assert_is_model(model, "model(peft)")
@ -482,15 +503,13 @@ def main():
args = ap.parse_args()
# --------------------------
# AUTO MENU (only once, BEFORE accelerate spawns processes)
# --------------------------
# Show menu only in the initial, non-accelerate run
if os.environ.get("LOCAL_RANK") is None and os.environ.get("AUTO_ACCELERATE_RELAUNCH") != "1":
os.environ["AUTO_ACCELERATE_RELAUNCH"] = "1" # prevent menu duplication
stage = choose_stage_simple()
# Build argv using defaults from args (you only press 1/2/3)
# Build argv using defaults from args
argv = [
"--stage",
stage,
@ -536,17 +555,17 @@ def main():
sft_adapter = args.sft_adapter or args.sft_out
argv += ["--sft_adapter", sft_adapter]
# Relaunch under accelerate (2 GPUs)
# Relaunch under accelerate (2 processes)
relaunch_with_accelerate(num_processes=2, extra_argv=argv)
# --------------------------
# We are inside accelerate now
# Running inside accelerate
# --------------------------
local_rank, rank, world_size = ddp_info()
torch.cuda.set_device(local_rank)
device_map = {"": torch.cuda.current_device()}
# Deterministic seed (as in your DPO script style)
# Fixed seed
torch.manual_seed(42)
sft_adapter_path = args.sft_adapter or args.sft_out
@ -567,7 +586,6 @@ def main():
)
if args.stage in ("dpo", "both"):
# If "both", barrier already happened after saving SFT adapter
run_dpo(
model_dir=args.model,
sft_adapter=sft_adapter_path,