#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Unified training script: SFT (QLoRA + masked loss) and DPO (TRL) in ONE file.

Usage:
  - Run: python3 train_unified.py
  - A simple menu (1/2/3) is shown; press 1, 2, or 3.
  - The script relaunches itself under `accelerate` (2 GPUs) and runs the
    chosen stage using the DEFAULT paths/hyperparameters embedded below.

Notes:
  - The menu is shown ONLY once (before accelerate spawns worker processes).
  - After the relaunch, each worker process runs normally with LOCAL_RANK set
    by accelerate.
"""

import os
import sys
import json
import math
import argparse
import subprocess
import inspect
import traceback
from dataclasses import dataclass
from datetime import timedelta
from typing import Any, Dict, List, Optional

import torch
from datasets import load_dataset, load_from_disk
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    Trainer as HFTrainer,
    TrainingArguments,
)
from peft import (
    LoraConfig,
    get_peft_model,
    prepare_model_for_kbit_training,
    PeftModel,
)
from trl import DPOTrainer, DPOConfig

# --------------------------
# Defaults (same spirit as the original standalone scripts)
# --------------------------
DEFAULT_MODEL_DIR = "/home/hyrenko/Diploma/models/llama3.1-8b"

# SFT: source dataset -> jsonl cache
SFT_SRC_DATASET_ID = "PKU-Alignment/PKU-SafeRLHF-30K"
SFT_SRC_SPLIT = "train"
SFT_JSONL_PATH = "./data/pku_sft.jsonl"

# Outputs
DEFAULT_SFT_OUT = "./out/llama3_1_8b_sft_safety_lora_masked"
DEFAULT_DPO_OUT = "./out/llama3_1_8b_dpo_safety_lora"
DEFAULT_DPO_DATA = "./data/pku_dpo"

# Prompt wrappers (keep consistent between SFT data building and training)
INSTR_PREFIX = "### Instruction:\n"
RESP_PREFIX = "\n\n### Response:\n"


# --------------------------
# Simple interactive menu (shown only before accelerate)
# --------------------------
def choose_stage_simple() -> str:
    """
    Show the stage menu and block until a valid choice is entered.

    Returns:
        'sft' | 'dpo' | 'both'

    Raises:
        SystemExit: if stdin is closed (non-interactive run) before a valid
            choice is made.
    """
    choices = {"1": "sft", "2": "dpo", "3": "both"}
    print("\n=== TRAIN MENU ===")
    print("1) SFT (QLoRA masked-loss)")
    print("2) DPO (over SFT adapter)")
    print("3) BOTH (SFT -> DPO)")
    while True:
        try:
            x = input("Choose 1/2/3: ").strip()
        except EOFError:
            # input() raises EOFError when stdin is closed; exit cleanly instead of a traceback.
            raise SystemExit(
                "\n[ERROR] No menu choice received (stdin closed). To run non-interactively, "
                "start the script with `accelerate launch ... --stage <sft|dpo|both>`."
            ) from None
        if x in choices:
            return choices[x]
        print("Only 1, 2, or 3.")


def relaunch_with_accelerate(num_processes: int, extra_argv: List[str]):
    """
    Relaunch this script via `accelerate launch` and exit with its return code.

    Does nothing if LOCAL_RANK is already set (i.e. we are already a worker
    spawned by accelerate).

    Args:
        num_processes: number of worker processes to spawn.
        extra_argv: CLI arguments forwarded to each worker.

    Raises:
        SystemExit: always, when a relaunch happens (carries the child's exit code).
    """
    if os.environ.get("LOCAL_RANK") is not None:
        return

    cmd = [
        sys.executable, "-m", "accelerate.commands.launch",
        "--num_processes", str(num_processes),
        "--mixed_precision", "fp16",
        "--main_process_port", "0",
        sys.argv[0],
        *extra_argv,
    ]
    print("\n[AUTO] Relaunching with accelerate:\n" + " ".join(cmd) + "\n")
    raise SystemExit(subprocess.call(cmd))


# --------------------------
# DDP helpers
# --------------------------
def ddp_info():
    """Return (local_rank, rank, world_size) read from the launcher's env vars (defaults 0, 0, 1)."""
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    rank = int(os.environ.get("RANK", "0"))
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    return local_rank, rank, world_size


def ddp_barrier():
    """
    Synchronize all ranks; a no-op for single-process runs.

    The process group is normally created lazily by HF TrainingArguments /
    accelerate. A barrier that runs *before* any Trainer exists (e.g. while
    rank 0 builds the SFT jsonl) would otherwise be skipped silently, letting
    the other ranks race ahead and read a missing or half-written file. So when
    WORLD_SIZE > 1 and no group exists yet, create one here from the env://
    variables set by the launcher. The 1800 s timeout mirrors the default
    `ddp_timeout` that TrainingArguments would otherwise have applied.
    """
    if not torch.distributed.is_available():
        return
    if not torch.distributed.is_initialized():
        if int(os.environ.get("WORLD_SIZE", "1")) <= 1:
            return
        backend = "nccl" if torch.cuda.is_available() else "gloo"
        torch.distributed.init_process_group(backend=backend, timeout=timedelta(seconds=1800))
    torch.distributed.barrier()


# --------------------------
# TRL <-> Transformers patches (to avoid signature mismatches)
# --------------------------
def patch_trl_dpo_trainer_if_needed():
    """
    Monkey-patch DPOTrainer when its method signatures don't match the installed
    transformers Trainer.

    1. If DPOTrainer.get_batch_samples takes (self, model, ...), the TRL method
       is stashed as `get_batch_samples_trl` and HF Trainer.get_batch_samples is
       restored in its place (presumably because newer transformers' training
       loop calls it with a different signature -- confirm for your versions).
    2. If DPOTrainer.compute_loss lacks a `num_items_in_batch` parameter, it is
       wrapped so that argument is accepted and dropped.

    Both patches are best-effort: failures are reported and execution continues.
    """
    # Patch get_batch_samples
    try:
        sig = list(inspect.signature(DPOTrainer.get_batch_samples).parameters.keys())
        if len(sig) >= 3 and sig[1] == "model":
            if not hasattr(DPOTrainer, "get_batch_samples_trl"):
                DPOTrainer.get_batch_samples_trl = DPOTrainer.get_batch_samples
            DPOTrainer.get_batch_samples = HFTrainer.get_batch_samples
            print("[PATCH] Patched DPOTrainer.get_batch_samples -> HF Trainer.get_batch_samples")
    except Exception as e:
        print("[PATCH] Could not inspect/patch DPOTrainer.get_batch_samples: {!r}".format(e))

    # Patch compute_loss
    try:
        sig = inspect.signature(DPOTrainer.compute_loss)
        if "num_items_in_batch" not in sig.parameters:
            if not hasattr(DPOTrainer, "compute_loss_trl"):
                DPOTrainer.compute_loss_trl = DPOTrainer.compute_loss

            def compute_loss_patched(self, model, inputs, return_outputs=False, num_items_in_batch=None, **kwargs):
                # num_items_in_batch / kwargs are accepted only for signature compatibility and ignored.
                orig = self.compute_loss_trl
                osig = inspect.signature(orig)
                if "return_outputs" in osig.parameters:
                    return orig(model, inputs, return_outputs=return_outputs)
                out = orig(model, inputs)
                if return_outputs:
                    return out, None
                return out

            DPOTrainer.compute_loss = compute_loss_patched
            print("[PATCH] Patched DPOTrainer.compute_loss to accept num_items_in_batch")
    except Exception as e:
        print("[PATCH] Could not inspect/patch DPOTrainer.compute_loss: {!r}".format(e))


def assert_is_model(obj, name="model"):
    """
    Raise TypeError unless `obj` has a `.generate` attribute.

    This is a cheap duck-type check that something model-like was loaded.
    """
    if not hasattr(obj, "generate"):
        raise TypeError(
            "[FATAL] {} is not a HF CausalLM model. type={}; missing .generate().".format(name, type(obj))
        )


# --------------------------
# SFT helpers
# --------------------------
def wrap_prompt(p: str) -> str:
    """Strip `p` (None -> "") and wrap it in the instruction/response template."""
    p = (p or "").strip()
    return "{}{}{}".format(INSTR_PREFIX, p, RESP_PREFIX)


def ensure_sft_jsonl_exists(rank: int):
    """
    Create SFT_JSONL_PATH from PKU-SafeRLHF-30K if it is missing or empty.

    Each row keeps the prompt (wrapped via wrap_prompt) and the response marked
    safer by `safer_response_id`. Rows with an empty prompt/response or an
    invalid id are dropped.

    Only rank 0 writes; every rank then meets at a barrier before returning.
    The file is written to a temporary path and atomically renamed, so an
    interrupted build never leaves a partial file that a later run would
    mistake for a finished one.
    """
    os.makedirs(os.path.dirname(SFT_JSONL_PATH) or ".", exist_ok=True)

    if os.path.exists(SFT_JSONL_PATH) and os.path.getsize(SFT_JSONL_PATH) > 0:
        if rank == 0:
            print("[INFO] Found existing SFT jsonl:", SFT_JSONL_PATH)
        ddp_barrier()
        return

    if rank == 0:
        print("[INFO] SFT jsonl not found. Building from {} ...".format(SFT_SRC_DATASET_ID))
        ds = load_dataset(SFT_SRC_DATASET_ID, split=SFT_SRC_SPLIT)

        rows = 0
        dropped = 0
        tmp_path = SFT_JSONL_PATH + ".tmp"
        with open(tmp_path, "w", encoding="utf-8") as f:
            for ex in ds:
                prompt = (ex.get("prompt") or "").strip()
                r0 = (ex.get("response_0") or "").strip()
                r1 = (ex.get("response_1") or "").strip()
                sid = ex.get("safer_response_id", None)

                if not prompt or not r0 or not r1 or sid not in (0, 1):
                    dropped += 1
                    continue

                safe = r0 if sid == 0 else r1
                obj = {"prompt": wrap_prompt(prompt), "response": safe}
                f.write(json.dumps(obj, ensure_ascii=False) + "\n")
                rows += 1
        # Publish only the complete file.
        os.replace(tmp_path, SFT_JSONL_PATH)

        print("[OK] Built SFT jsonl:", SFT_JSONL_PATH)
        print("[OK] rows={}, dropped={}".format(rows, dropped))

    ddp_barrier()


@dataclass
class DataCollatorForCausalLMWithLabels:
    """
    Pad `input_ids`/`attention_mask` with the tokenizer and pad `labels` with -100.

    Labels are padded (or cut) to the padded input length, so padding positions
    never contribute to the loss.
    """

    tokenizer: Any

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        pad_batch = self.tokenizer.pad(
            {
                "input_ids": [f["input_ids"] for f in features],
                "attention_mask": [f["attention_mask"] for f in features],
            },
            padding=True,
            return_tensors="pt",
        )
        max_len = pad_batch["input_ids"].shape[1]

        labels = []
        for f in features:
            lab = f["labels"]
            if len(lab) < max_len:
                lab = lab + [-100] * (max_len - len(lab))
            else:
                lab = lab[:max_len]
            labels.append(lab)

        pad_batch["labels"] = torch.tensor(labels, dtype=torch.long)
        return pad_batch


def build_quant_base(model_dir: str, device_map: Dict[str, int]):
    """
    Load `model_dir` as a 4-bit NF4 (double-quant, fp16 compute) causal LM.

    The model is placed according to `device_map`, and `use_cache` is disabled
    for training.
    """
    bnb = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
    )
    base = AutoModelForCausalLM.from_pretrained(
        model_dir,
        quantization_config=bnb,
        device_map=device_map,
        torch_dtype=torch.float16,
    )
    base.config.use_cache = False
    return base


def run_sft(
    model_dir: str,
    out_dir: str,
    epochs: float,
    lr: float,
    bs: int,
    gas: int,
    seq: int,
    lora_r: int,
    device_map: Dict[str, int],
    rank: int,
    world_size: int,
) -> str:
    """
    Run QLoRA SFT with prompt-masked loss and return the saved adapter dir (out_dir).

    Loss is computed only on response tokens (plus the appended EOS). Examples
    whose response tokens are all removed by truncation to `seq` are dropped,
    because an all-masked example yields a NaN mean loss.

    Raises:
        ValueError: if no example keeps at least one supervised token.
    """
    os.makedirs(out_dir, exist_ok=True)
    ensure_sft_jsonl_exists(rank)

    ds = load_dataset("json", data_files=SFT_JSONL_PATH, split="train")

    tokenizer = AutoTokenizer.from_pretrained(model_dir, use_fast=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

    base = build_quant_base(model_dir, device_map)
    base = prepare_model_for_kbit_training(base)

    lora = LoraConfig(
        r=lora_r,
        lora_alpha=32,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=["q_proj", "v_proj"],
    )
    model = get_peft_model(base, lora)

    def tokenize_and_mask(example):
        # NOTE(review): add_special_tokens=False means no BOS token is prepended;
        # confirm this matches how prompts are tokenized for DPO / inference.
        prompt = example["prompt"]
        response = example["response"]

        p = tokenizer(prompt, add_special_tokens=False)
        r = tokenizer(response + tokenizer.eos_token, add_special_tokens=False)

        input_ids = p["input_ids"] + r["input_ids"]
        attention_mask = [1] * len(input_ids)
        # Mask the prompt so only response tokens are supervised.
        labels = [-100] * len(p["input_ids"]) + r["input_ids"]

        if len(input_ids) > seq:
            input_ids = input_ids[:seq]
            attention_mask = attention_mask[:seq]
            labels = labels[:seq]

        return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}

    ds_tok = ds.map(tokenize_and_mask, remove_columns=ds.column_names)
    n_tokenized = len(ds_tok)
    # A prompt longer than `seq` leaves only -100 labels -> NaN loss; drop those rows.
    ds_tok = ds_tok.filter(lambda ex: any(t != -100 for t in ex["labels"]))
    if len(ds_tok) == 0:
        raise ValueError("[SFT] No trainable examples left after truncation to seq={}".format(seq))

    if rank == 0:
        total_examples = len(ds_tok)
        effective_bs = bs * world_size * gas
        steps_per_epoch = int(math.ceil(float(total_examples) / float(effective_bs)))
        print("[SFT] base={}".format(model_dir))
        print("[SFT] out={}".format(out_dir))
        print("[SFT] examples={}, world_size={}".format(total_examples, world_size))
        if n_tokenized != total_examples:
            print("[SFT] dropped {} examples with no response tokens within seq={}".format(
                n_tokenized - total_examples, seq))
        print("[SFT] per_device_bs={}, gas={}, effective_batch={}".format(bs, gas, effective_bs))
        print("[SFT] approx steps/epoch ≈ {}".format(steps_per_epoch))

    collator = DataCollatorForCausalLMWithLabels(tokenizer)

    train_args = TrainingArguments(
        output_dir=out_dir,
        num_train_epochs=epochs,
        learning_rate=lr,
        per_device_train_batch_size=bs,
        gradient_accumulation_steps=gas,
        logging_steps=10,
        save_steps=200,
        save_total_limit=2,
        fp16=True,
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={"use_reentrant": False},
        ddp_find_unused_parameters=False,
        report_to="none",
    )

    trainer = HFTrainer(
        model=model,
        args=train_args,
        train_dataset=ds_tok,
        data_collator=collator,
        tokenizer=tokenizer,
    )

    trainer.train()

    # Save only on rank 0, then barrier so the other ranks wait for the adapter.
    if rank == 0:
        model.save_pretrained(out_dir)
        tokenizer.save_pretrained(out_dir)
        print("[OK] SFT LoRA adapter saved to:", out_dir)

    ddp_barrier()
    return out_dir


# --------------------------
# DPO stage
# --------------------------
def run_dpo(
    model_dir: str,
    sft_adapter: str,
    dpo_data_dir: str,
    out_dir: str,
    epochs: float,
    lr: float,
    beta: float,
    bs: int,
    gas: int,
    max_prompt: int,
    max_len: int,
    device_map: Dict[str, int],
    rank: int,
):
    """
    Continue training the SFT LoRA adapter with TRL DPO and save it to out_dir.

    `dpo_data_dir` is loaded with `load_from_disk`. If it holds a DatasetDict,
    its 'train' split is used. The dataset's column schema is whatever
    DPOTrainer expects for the installed TRL version (not checked here).

    Raises:
        KeyError: if the saved DatasetDict has no 'train' split.
    """
    patch_trl_dpo_trainer_if_needed()
    os.makedirs(out_dir, exist_ok=True)

    ds = load_from_disk(dpo_data_dir)
    # DatasetDict subclasses dict; DPOTrainer needs a single split.
    if isinstance(ds, dict):
        if "train" not in ds:
            raise KeyError("[DPO] {} has splits {} but no 'train' split".format(dpo_data_dir, list(ds.keys())))
        ds = ds["train"]

    tokenizer = AutoTokenizer.from_pretrained(model_dir, use_fast=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    base_model = build_quant_base(model_dir, device_map)
    assert_is_model(base_model, "base_model")

    # Load SFT adapter and keep training it with DPO
    model = PeftModel.from_pretrained(base_model, sft_adapter, is_trainable=True)
    assert_is_model(model, "model(peft)")

    if rank == 0:
        print("[DPO] base:", model_dir)
        print("[DPO] sft_adapter:", sft_adapter)
        print("[DPO] dpo_data:", dpo_data_dir)
        print("[DPO] out:", out_dir)

    dpo_args = DPOConfig(
        output_dir=out_dir,
        num_train_epochs=epochs,
        learning_rate=lr,
        per_device_train_batch_size=bs,
        gradient_accumulation_steps=gas,
        logging_steps=10,
        save_steps=200,
        save_total_limit=2,
        fp16=True,
        report_to="none",
        beta=beta,
        max_prompt_length=max_prompt,
        max_length=max_len,
        gradient_checkpointing=True,
        # Same as the SFT stage: non-reentrant checkpointing is the DDP-safe choice
        # with ddp_find_unused_parameters=False.
        gradient_checkpointing_kwargs={"use_reentrant": False},
        ddp_find_unused_parameters=False,
    )

    trainer = DPOTrainer(
        model=model,
        args=dpo_args,
        train_dataset=ds,
        tokenizer=tokenizer,
    )

    trainer.train()
    trainer.save_model(out_dir)

    if rank == 0:
        print("[OK] DPO LoRA adapter saved to:", out_dir)
# --------------------------
# Main
# --------------------------
def main():
    """
    Entry point: parse CLI args, optionally show the menu and relaunch under
    accelerate, then run the selected stage(s) on this worker.

    On the initial (non-accelerate) run, the menu choice overrides --stage.
    All other CLI values are forwarded to the accelerate workers unchanged.
    """
    import gc  # only used to release SFT-stage memory before the DPO stage

    ap = argparse.ArgumentParser(description="Unified SFT(QLoRA masked) + DPO training script")

    # Stage (menu will override this in the initial run)
    ap.add_argument("--stage", choices=["sft", "dpo", "both"], default="both")
    ap.add_argument("--model", default=DEFAULT_MODEL_DIR)

    # SFT args
    ap.add_argument("--sft_out", default=DEFAULT_SFT_OUT)
    ap.add_argument("--sft_epochs", type=float, default=1.0)
    ap.add_argument("--sft_lr", type=float, default=2e-4)
    ap.add_argument("--sft_bs", type=int, default=1)
    ap.add_argument("--sft_gas", type=int, default=16)
    ap.add_argument("--sft_seq", type=int, default=1024)
    ap.add_argument("--lora_r", type=int, default=16)

    # DPO args
    ap.add_argument("--sft_adapter", default=None, help="If omitted, uses --sft_out")
    ap.add_argument("--dpo_data", default=DEFAULT_DPO_DATA)
    ap.add_argument("--dpo_out", default=DEFAULT_DPO_OUT)
    ap.add_argument("--dpo_epochs", type=float, default=1.0)
    ap.add_argument("--dpo_lr", type=float, default=5e-6)
    ap.add_argument("--beta", type=float, default=0.1)
    ap.add_argument("--dpo_bs", type=int, default=1)
    ap.add_argument("--dpo_gas", type=int, default=16)
    ap.add_argument("--max_prompt", type=int, default=512)
    ap.add_argument("--max_len", type=int, default=1024)

    args = ap.parse_args()

    # --------------------------
    # AUTO MENU (only once, BEFORE accelerate spawns processes)
    # --------------------------
    if os.environ.get("LOCAL_RANK") is None and os.environ.get("AUTO_ACCELERATE_RELAUNCH") != "1":
        os.environ["AUTO_ACCELERATE_RELAUNCH"] = "1"  # inherited by workers: prevents menu duplication
        stage = choose_stage_simple()

        # Forward every current arg value; only the stage comes from the menu.
        argv = [
            "--stage", stage,
            "--model", args.model,
            "--sft_out", args.sft_out,
            "--sft_epochs", str(args.sft_epochs),
            "--sft_lr", str(args.sft_lr),
            "--sft_bs", str(args.sft_bs),
            "--sft_gas", str(args.sft_gas),
            "--sft_seq", str(args.sft_seq),
            "--lora_r", str(args.lora_r),
            "--dpo_data", args.dpo_data,
            "--dpo_out", args.dpo_out,
            "--dpo_epochs", str(args.dpo_epochs),
            "--dpo_lr", str(args.dpo_lr),
            "--beta", str(args.beta),
            "--dpo_bs", str(args.dpo_bs),
            "--dpo_gas", str(args.dpo_gas),
            "--max_prompt", str(args.max_prompt),
            "--max_len", str(args.max_len),
        ]
        if stage in ("dpo", "both"):
            sft_adapter = args.sft_adapter or args.sft_out
            argv += ["--sft_adapter", sft_adapter]

        # Relaunch under accelerate (2 GPUs); this raises SystemExit.
        relaunch_with_accelerate(num_processes=2, extra_argv=argv)

    # --------------------------
    # We are inside accelerate now
    # --------------------------
    local_rank, rank, world_size = ddp_info()
    torch.cuda.set_device(local_rank)
    device_map = {"": torch.cuda.current_device()}

    # Deterministic seed
    torch.manual_seed(42)

    sft_adapter_path = args.sft_adapter or args.sft_out

    if args.stage in ("sft", "both"):
        sft_adapter_path = run_sft(
            model_dir=args.model,
            out_dir=args.sft_out,
            epochs=args.sft_epochs,
            lr=args.sft_lr,
            bs=args.sft_bs,
            gas=args.sft_gas,
            seq=args.sft_seq,
            lora_r=args.lora_r,
            device_map=device_map,
            rank=rank,
            world_size=world_size,
        )
        if args.stage == "both":
            # DPO loads a second 4-bit base model onto the same GPU. Reference
            # cycles may keep the SFT model/optimizer alive after run_sft returns,
            # so collect them and return the cached CUDA blocks first.
            gc.collect()
            torch.cuda.empty_cache()

    if args.stage in ("dpo", "both"):
        # If "both", run_sft already hit a barrier after rank 0 saved the adapter.
        run_dpo(
            model_dir=args.model,
            sft_adapter=sft_adapter_path,
            dpo_data_dir=args.dpo_data,
            out_dir=args.dpo_out,
            epochs=args.dpo_epochs,
            lr=args.dpo_lr,
            beta=args.beta,
            bs=args.dpo_bs,
            gas=args.dpo_gas,
            max_prompt=args.max_prompt,
            max_len=args.max_len,
            device_map=device_map,
            rank=rank,
        )


if __name__ == "__main__":
    try:
        main()
    except Exception:
        # Label the traceback with the failing rank before re-raising.
        r = os.environ.get("RANK", "?")
        lr = os.environ.get("LOCAL_RANK", "?")
        print("\n[FATAL] Exception on RANK={} LOCAL_RANK={}\n".format(r, lr))
        print(traceback.format_exc())
        raise