#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Unified training script for:
  - SFT (QLoRA + masked loss)
  - DPO (TRL DPOTrainer)

Execution model:
  - Run the script normally: it shows a simple menu (SFT / DPO / BOTH).
  - The script relaunches itself via accelerate (multi-process).
  - After relaunch, each process runs with LOCAL_RANK set by accelerate.
"""

import os
import sys
import gc
import json
import math
import argparse
import subprocess
import inspect
import traceback
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

import torch
from datasets import load_dataset, load_from_disk
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    Trainer as HFTrainer,
    TrainingArguments,
)
from peft import (
    LoraConfig,
    get_peft_model,
    prepare_model_for_kbit_training,
    PeftModel,
)
from trl import DPOTrainer, DPOConfig


# --------------------------
# Defaults
# --------------------------
DEFAULT_MODEL_DIR = "/home/hyrenko/Diploma/models/llama3.1-8b"

# SFT: source -> jsonl cache
SFT_SRC_DATASET_ID = "PKU-Alignment/PKU-SafeRLHF-30K"
SFT_SRC_SPLIT = "train"
SFT_JSONL_PATH = "./data/pku_sft.jsonl"

# Outputs
DEFAULT_SFT_OUT = "./out/llama3_1_8b_sft_safety_lora_masked"
DEFAULT_DPO_OUT = "./out/llama3_1_8b_dpo_safety_lora"
DEFAULT_DPO_DATA = "./data/pku_dpo"

# Number of accelerate processes used when relaunching (one per GPU).
DEFAULT_NUM_PROCESSES = 2

# Prompt wrappers
INSTR_PREFIX = "### Instruction:\n"
RESP_PREFIX = "\n\n### Response:\n"

# Label value ignored by the causal-LM cross-entropy loss.
IGNORE_INDEX = -100


# --------------------------
# Interactive menu (shown only before accelerate spawns processes)
# --------------------------
def choose_stage_simple() -> str:
    """
    Ask the user which stage to run.

    Returns one of: 'sft' | 'dpo' | 'both'
    """
    print("\n=== TRAIN MENU ===")
    print("1) SFT (QLoRA masked-loss)")
    print("2) DPO (over SFT adapter)")
    print("3) BOTH (SFT -> DPO)")
    while True:
        x = input("Choose 1/2/3: ").strip()
        if x == "1":
            return "sft"
        if x == "2":
            return "dpo"
        if x == "3":
            return "both"
        print("Only 1, 2, or 3.")


def relaunch_with_accelerate(num_processes: int, extra_argv: List[str]):
    """
    Relaunch this script via `accelerate launch` with the requested number of processes.

    No-op if already running inside accelerate (LOCAL_RANK is set). Otherwise it
    never returns: it raises SystemExit with the child launcher's exit code.
    """
    if os.environ.get("LOCAL_RANK") is not None:
        return

    cmd = [
        sys.executable, "-m", "accelerate.commands.launch",
        "--num_processes", str(num_processes),
        "--mixed_precision", "fp16",
        "--main_process_port", "0",
        sys.argv[0],
        *extra_argv,
    ]
    print("\n[INFO] Relaunching with accelerate:\n" + " ".join(cmd) + "\n")
    raise SystemExit(subprocess.call(cmd))


# --------------------------
# DDP helpers
# --------------------------
def ddp_info():
    """
    Returns (local_rank, rank, world_size) inferred from environment variables.
    """
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    rank = int(os.environ.get("RANK", "0"))
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    return local_rank, rank, world_size


def ddp_init_if_needed(world_size: int) -> None:
    """
    Initialize the default torch.distributed process group early.

    Without this, ddp_barrier() silently does nothing until the HF Trainer /
    accelerate lazily creates the process group, so barriers placed before
    trainer construction (e.g. in ensure_sft_jsonl_exists) would not
    synchronize ranks at all. accelerate is expected to reuse an already
    initialized group (it checks is_initialized() before creating one).
    Uses the env:// rendezvous variables set by `accelerate launch`.
    """
    if world_size <= 1:
        return
    if not torch.distributed.is_available() or torch.distributed.is_initialized():
        return
    backend = "nccl" if torch.cuda.is_available() else "gloo"
    torch.distributed.init_process_group(backend=backend)


def ddp_barrier():
    """
    Synchronization barrier for distributed runs (no-op if torch.distributed is not initialized).
    """
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        torch.distributed.barrier()


def ddp_cleanup() -> None:
    """
    Synchronize all ranks and destroy the default process group, if one exists.
    """
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        torch.distributed.barrier()
        torch.distributed.destroy_process_group()


def free_cuda_memory() -> None:
    """
    Best-effort release of unreferenced Python objects and cached CUDA blocks.

    Used between SFT and DPO in 'both' mode, so the SFT model/trainer do not
    stay resident while DPO loads another copy of the quantized base model.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


# --------------------------
# TRL <-> Transformers patches (signature compatibility)
# --------------------------
def patch_trl_dpo_trainer_if_needed():
    """
    Monkey-patch DPOTrainer methods whose signatures drift between TRL and Transformers.

    Best-effort: if a method cannot be inspected, the patch is skipped with a message.
    """
    # Patch get_batch_samples to avoid signature mismatches across versions:
    # some TRL versions define get_batch_samples(self, model, batch, ...), which
    # shadows HF Trainer.get_batch_samples(self, epoch_iterator, num_batches).
    try:
        sig = list(inspect.signature(DPOTrainer.get_batch_samples).parameters.keys())
        if len(sig) >= 3 and sig[1] == "model":
            if not hasattr(DPOTrainer, "get_batch_samples_trl"):
                DPOTrainer.get_batch_samples_trl = DPOTrainer.get_batch_samples
            DPOTrainer.get_batch_samples = HFTrainer.get_batch_samples
            print("[PATCH] Patched DPOTrainer.get_batch_samples -> HF Trainer.get_batch_samples")
    except (AttributeError, TypeError, ValueError) as err:
        print("[PATCH] Could not inspect/patch DPOTrainer.get_batch_samples ({})".format(err))

    # Patch compute_loss to accept num_items_in_batch if needed
    try:
        sig = inspect.signature(DPOTrainer.compute_loss)
        if "num_items_in_batch" not in sig.parameters:
            if not hasattr(DPOTrainer, "compute_loss_trl"):
                DPOTrainer.compute_loss_trl = DPOTrainer.compute_loss

            def compute_loss_patched(self, model, inputs, return_outputs=False, num_items_in_batch=None, **kwargs):
                # num_items_in_batch is accepted for compatibility and ignored.
                orig = self.compute_loss_trl
                osig = inspect.signature(orig)
                if "return_outputs" in osig.parameters:
                    return orig(model, inputs, return_outputs=return_outputs)
                out = orig(model, inputs)
                if return_outputs:
                    return out, None
                return out

            DPOTrainer.compute_loss = compute_loss_patched
            print("[PATCH] Patched DPOTrainer.compute_loss to accept num_items_in_batch")
    except (AttributeError, TypeError, ValueError) as err:
        print("[PATCH] Could not inspect/patch DPOTrainer.compute_loss ({})".format(err))


def assert_is_model(obj, name="model"):
    """
    Minimal runtime check: object should look like a HF CausalLM model.

    Raises:
        TypeError: if `obj` has no `.generate` attribute.
    """
    if not hasattr(obj, "generate"):
        raise TypeError(
            "[FATAL] {} is not a HF CausalLM model. type={}; missing .generate().".format(name, type(obj))
        )


# --------------------------
# SFT helpers
# --------------------------
def wrap_prompt(p: str) -> str:
    """
    Wrap a raw prompt into the instruction/response template prefix.
    """
    p = (p or "").strip()
    return "{}{}{}".format(INSTR_PREFIX, p, RESP_PREFIX)


def ensure_sft_jsonl_exists(rank: int):
    """
    Create SFT_JSONL_PATH from PKU-SafeRLHF-30K if missing or empty.

    Only rank0 writes; all ranks then meet at a barrier (which requires the
    process group to be initialized beforehand, see ddp_init_if_needed).
    The file is written to a temporary path and atomically renamed, so an
    interrupted build never leaves a truncated file that a later run would
    mistake for a complete cache.
    """
    os.makedirs(os.path.dirname(SFT_JSONL_PATH) or ".", exist_ok=True)

    if os.path.exists(SFT_JSONL_PATH) and os.path.getsize(SFT_JSONL_PATH) > 0:
        if rank == 0:
            print("[INFO] Found existing SFT jsonl:", SFT_JSONL_PATH)
        ddp_barrier()
        return

    if rank == 0:
        print("[INFO] SFT jsonl not found. Building from {} ...".format(SFT_SRC_DATASET_ID))
        ds = load_dataset(SFT_SRC_DATASET_ID, split=SFT_SRC_SPLIT)

        rows = 0
        dropped = 0
        tmp_path = SFT_JSONL_PATH + ".tmp"
        try:
            with open(tmp_path, "w", encoding="utf-8") as f:
                for ex in ds:
                    prompt = (ex.get("prompt") or "").strip()
                    r0 = (ex.get("response_0") or "").strip()
                    r1 = (ex.get("response_1") or "").strip()
                    sid = ex.get("safer_response_id", None)

                    if not prompt or not r0 or not r1 or sid not in (0, 1):
                        dropped += 1
                        continue

                    # NOTE(review): "safer" is relative; the selected response may
                    # itself be unsafe. Consider also filtering on the dataset's
                    # per-response safety flags (presumably is_response_{0,1}_safe)
                    # — confirm the column names before relying on them.
                    safe = r0 if sid == 0 else r1
                    obj = {"prompt": wrap_prompt(prompt), "response": safe}
                    f.write(json.dumps(obj, ensure_ascii=False) + "\n")
                    rows += 1
            os.replace(tmp_path, SFT_JSONL_PATH)
        except Exception:
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            raise

        print("[OK] Built SFT jsonl:", SFT_JSONL_PATH)
        print("[OK] rows={}, dropped={}".format(rows, dropped))

    ddp_barrier()


@dataclass
class DataCollatorForCausalLMWithLabels:
    """
    Pads input_ids/attention_mask and aligns labels to the padded length.

    Labels use -100 for ignored positions (standard LM loss masking). Label
    padding is placed on the same side as the tokenizer's padding
    (`tokenizer.padding_side`, right by default), so labels stay aligned
    with input_ids.
    """
    tokenizer: Any

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        pad_batch = self.tokenizer.pad(
            {
                "input_ids": [f["input_ids"] for f in features],
                "attention_mask": [f["attention_mask"] for f in features],
            },
            padding=True,
            return_tensors="pt",
        )
        max_len = pad_batch["input_ids"].shape[1]
        pad_left = getattr(self.tokenizer, "padding_side", "right") == "left"

        labels = []
        for f in features:
            lab = list(f["labels"])[:max_len]
            fill = [IGNORE_INDEX] * (max_len - len(lab))
            labels.append(fill + lab if pad_left else lab + fill)

        pad_batch["labels"] = torch.tensor(labels, dtype=torch.long)
        return pad_batch


def build_quant_base(model_dir: str, device_map: Dict[str, int]):
    """
    Load a 4-bit (NF4, double-quantized, fp16 compute) base model using bitsandbytes.

    The KV cache is disabled because it is not used during training.
    """
    bnb = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
    )
    base = AutoModelForCausalLM.from_pretrained(
        model_dir,
        quantization_config=bnb,
        device_map=device_map,
        torch_dtype=torch.float16,
    )
    base.config.use_cache = False
    return base


def run_sft(
    model_dir: str,
    out_dir: str,
    epochs: float,
    lr: float,
    bs: int,
    gas: int,
    seq: int,
    lora_r: int,
    device_map: Dict[str, int],
    rank: int,
    world_size: int,
) -> str:
    """
    Run SFT training and return the saved LoRA adapter path (out_dir).

    Loss is computed only on response tokens (prompt tokens are masked with -100).
    Examples whose response is truncated away entirely by `seq` are dropped,
    because an all-masked sequence yields a NaN loss.
    """
    os.makedirs(out_dir, exist_ok=True)
    ensure_sft_jsonl_exists(rank)

    ds = load_dataset("json", data_files=SFT_JSONL_PATH, split="train")

    tokenizer = AutoTokenizer.from_pretrained(model_dir, use_fast=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

    base = build_quant_base(model_dir, device_map)
    base = prepare_model_for_kbit_training(base)

    lora = LoraConfig(
        r=lora_r,
        lora_alpha=32,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=["q_proj", "v_proj"],
    )
    model = get_peft_model(base, lora)

    def tokenize_and_mask(example):
        """
        Tokenize prompt and response, and build labels masked for the prompt part.

        Sequences longer than `seq` are truncated from the right.
        """
        prompt = example["prompt"]
        response = example["response"]

        p = tokenizer(prompt, add_special_tokens=False)
        r = tokenizer(response + tokenizer.eos_token, add_special_tokens=False)

        input_ids = p["input_ids"] + r["input_ids"]
        attention_mask = [1] * len(input_ids)
        labels = [IGNORE_INDEX] * len(p["input_ids"]) + r["input_ids"]

        if len(input_ids) > seq:
            input_ids = input_ids[:seq]
            attention_mask = attention_mask[:seq]
            labels = labels[:seq]

        return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}

    def has_supervised_tokens(example):
        """True if at least one label survived prompt masking and truncation."""
        return any(t != IGNORE_INDEX for t in example["labels"])

    ds_tok = ds.map(tokenize_and_mask, remove_columns=ds.column_names)
    n_before_filter = len(ds_tok)
    ds_tok = ds_tok.filter(has_supervised_tokens)
    n_no_target = n_before_filter - len(ds_tok)

    if rank == 0:
        total_examples = len(ds_tok)
        effective_bs = bs * world_size * gas
        steps_per_epoch = int(math.ceil(float(total_examples) / float(effective_bs)))
        print("[SFT] base={}".format(model_dir))
        print("[SFT] out={}".format(out_dir))
        print("[SFT] examples={}, world_size={}".format(total_examples, world_size))
        print("[SFT] dropped (no response tokens after truncation to {})={}".format(seq, n_no_target))
        print("[SFT] per_device_bs={}, gas={}, effective_batch={}".format(bs, gas, effective_bs))
        print("[SFT] approx steps/epoch ≈ {}".format(steps_per_epoch))

    collator = DataCollatorForCausalLMWithLabels(tokenizer)

    train_args = TrainingArguments(
        output_dir=out_dir,
        num_train_epochs=epochs,
        learning_rate=lr,
        per_device_train_batch_size=bs,
        gradient_accumulation_steps=gas,
        logging_steps=10,
        save_steps=200,
        save_total_limit=2,
        fp16=True,
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={"use_reentrant": False},
        ddp_find_unused_parameters=False,
        report_to="none",
    )

    trainer = HFTrainer(
        model=model,
        args=train_args,
        train_dataset=ds_tok,
        data_collator=collator,
        tokenizer=tokenizer,
    )
    trainer.train()

    # Save only on rank0; other ranks wait at a barrier so that a following
    # DPO stage never reads a partially written adapter.
    if rank == 0:
        model.save_pretrained(out_dir)
        tokenizer.save_pretrained(out_dir)
        print("[OK] SFT LoRA adapter saved to:", out_dir)
    ddp_barrier()

    return out_dir


# --------------------------
# DPO stage
# --------------------------
def run_dpo(
    model_dir: str,
    sft_adapter: str,
    dpo_data_dir: str,
    out_dir: str,
    epochs: float,
    lr: float,
    beta: float,
    bs: int,
    gas: int,
    max_prompt: int,
    max_len: int,
    device_map: Dict[str, int],
    rank: int,
):
    """
    Run DPO training starting from the base model + an SFT LoRA adapter.

    `dpo_data_dir` must be a dataset saved with `Dataset.save_to_disk`
    (it is read with `load_from_disk`); its columns must match what the
    installed DPOTrainer expects (presumably prompt/chosen/rejected).
    """
    patch_trl_dpo_trainer_if_needed()
    os.makedirs(out_dir, exist_ok=True)

    ds = load_from_disk(dpo_data_dir)

    tokenizer = AutoTokenizer.from_pretrained(model_dir, use_fast=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    base_model = build_quant_base(model_dir, device_map)
    assert_is_model(base_model, "base_model")

    # Load SFT adapter as trainable and continue training with DPO.
    # NOTE(review): no ref_model is passed, so TRL presumably derives reference
    # log-probs from this PeftModel with the adapter disabled, i.e. the reference
    # is the raw base model, not the SFT policy. If the SFT policy should be the
    # reference, load the adapter a second time under another adapter name and
    # set DPOConfig(model_adapter_name=..., ref_adapter_name=...) — confirm
    # against the installed TRL version.
    model = PeftModel.from_pretrained(base_model, sft_adapter, is_trainable=True)
    assert_is_model(model, "model(peft)")

    if rank == 0:
        print("[DPO] base:", model_dir)
        print("[DPO] sft_adapter:", sft_adapter)
        print("[DPO] dpo_data:", dpo_data_dir)
        print("[DPO] out:", out_dir)

    dpo_args = DPOConfig(
        output_dir=out_dir,
        num_train_epochs=epochs,
        learning_rate=lr,
        per_device_train_batch_size=bs,
        gradient_accumulation_steps=gas,
        logging_steps=10,
        save_steps=200,
        save_total_limit=2,
        fp16=True,
        report_to="none",
        beta=beta,
        max_prompt_length=max_prompt,
        max_length=max_len,
        gradient_checkpointing=True,
        # Same as SFT: non-reentrant checkpointing is the variant that works
        # with DDP when ddp_find_unused_parameters=False.
        gradient_checkpointing_kwargs={"use_reentrant": False},
        ddp_find_unused_parameters=False,
    )

    trainer = DPOTrainer(
        model=model,
        args=dpo_args,
        train_dataset=ds,
        tokenizer=tokenizer,
    )
    trainer.train()

    trainer.save_model(out_dir)
    if rank == 0:
        print("[OK] DPO LoRA adapter saved to:", out_dir)
    ddp_barrier()


# --------------------------
# Main
# --------------------------
def main():
    """
    Entry point.

    First invocation (no LOCAL_RANK): show the stage menu, rebuild argv and
    relaunch under accelerate (exits with the launcher's return code).
    Relaunched invocation (LOCAL_RANK set): bind this process to its GPU,
    initialize the process group and run the selected stage(s).
    """
    ap = argparse.ArgumentParser(description="Unified SFT(QLoRA masked) + DPO training script")

    # Stage (menu will override this in the initial run)
    ap.add_argument("--stage", choices=["sft", "dpo", "both"], default="both")
    ap.add_argument("--model", default=DEFAULT_MODEL_DIR)
    ap.add_argument(
        "--num_processes", type=int, default=DEFAULT_NUM_PROCESSES,
        help="Number of accelerate processes (one per GPU) used for the relaunch",
    )

    # SFT args
    ap.add_argument("--sft_out", default=DEFAULT_SFT_OUT)
    ap.add_argument("--sft_epochs", type=float, default=1.0)
    ap.add_argument("--sft_lr", type=float, default=2e-4)
    ap.add_argument("--sft_bs", type=int, default=1)
    ap.add_argument("--sft_gas", type=int, default=16)
    ap.add_argument("--sft_seq", type=int, default=1024)
    ap.add_argument("--lora_r", type=int, default=16)

    # DPO args
    ap.add_argument("--sft_adapter", default=None, help="If omitted, uses --sft_out")
    ap.add_argument("--dpo_data", default=DEFAULT_DPO_DATA)
    ap.add_argument("--dpo_out", default=DEFAULT_DPO_OUT)
    ap.add_argument("--dpo_epochs", type=float, default=1.0)
    ap.add_argument("--dpo_lr", type=float, default=5e-6)
    ap.add_argument("--beta", type=float, default=0.1)
    ap.add_argument("--dpo_bs", type=int, default=1)
    ap.add_argument("--dpo_gas", type=int, default=16)
    ap.add_argument("--max_prompt", type=int, default=512)
    ap.add_argument("--max_len", type=int, default=1024)

    args = ap.parse_args()
    if args.num_processes < 1:
        ap.error("--num_processes must be >= 1")

    # Show menu only in the initial, non-accelerate run
    if os.environ.get("LOCAL_RANK") is None and os.environ.get("AUTO_ACCELERATE_RELAUNCH") != "1":
        os.environ["AUTO_ACCELERATE_RELAUNCH"] = "1"  # prevent menu duplication
        stage = choose_stage_simple()

        # Build argv using defaults from args
        argv = [
            "--stage", stage,
            "--model", args.model,
            "--sft_out", args.sft_out,
            "--sft_epochs", str(args.sft_epochs),
            "--sft_lr", str(args.sft_lr),
            "--sft_bs", str(args.sft_bs),
            "--sft_gas", str(args.sft_gas),
            "--sft_seq", str(args.sft_seq),
            "--lora_r", str(args.lora_r),
            "--dpo_data", args.dpo_data,
            "--dpo_out", args.dpo_out,
            "--dpo_epochs", str(args.dpo_epochs),
            "--dpo_lr", str(args.dpo_lr),
            "--beta", str(args.beta),
            "--dpo_bs", str(args.dpo_bs),
            "--dpo_gas", str(args.dpo_gas),
            "--max_prompt", str(args.max_prompt),
            "--max_len", str(args.max_len),
        ]
        if stage in ("dpo", "both"):
            sft_adapter = args.sft_adapter or args.sft_out
            argv += ["--sft_adapter", sft_adapter]

        # Relaunch under accelerate (does not return)
        relaunch_with_accelerate(num_processes=args.num_processes, extra_argv=argv)

    # --------------------------
    # Running inside accelerate
    # --------------------------
    local_rank, rank, world_size = ddp_info()
    torch.cuda.set_device(local_rank)
    device_map = {"": torch.cuda.current_device()}

    # Must happen before any ddp_barrier() (e.g. in ensure_sft_jsonl_exists),
    # otherwise those barriers are silent no-ops.
    ddp_init_if_needed(world_size)

    # Fixed seed
    torch.manual_seed(42)

    sft_adapter_path = args.sft_adapter or args.sft_out

    if args.stage in ("sft", "both"):
        sft_adapter_path = run_sft(
            model_dir=args.model,
            out_dir=args.sft_out,
            epochs=args.sft_epochs,
            lr=args.sft_lr,
            bs=args.sft_bs,
            gas=args.sft_gas,
            seq=args.sft_seq,
            lora_r=args.lora_r,
            device_map=device_map,
            rank=rank,
            world_size=world_size,
        )
        if args.stage == "both":
            # Drop the SFT model/trainer before DPO loads another quantized base.
            free_cuda_memory()

    if args.stage in ("dpo", "both"):
        run_dpo(
            model_dir=args.model,
            sft_adapter=sft_adapter_path,
            dpo_data_dir=args.dpo_data,
            out_dir=args.dpo_out,
            epochs=args.dpo_epochs,
            lr=args.dpo_lr,
            beta=args.beta,
            bs=args.dpo_bs,
            gas=args.dpo_gas,
            max_prompt=args.max_prompt,
            max_len=args.max_len,
            device_map=device_map,
            rank=rank,
        )

    ddp_cleanup()


if __name__ == "__main__":
    try:
        main()
    except Exception:
        r = os.environ.get("RANK", "?")
        lr = os.environ.get("LOCAL_RANK", "?")
        print("\n[FATAL] Exception on RANK={} LOCAL_RANK={}\n".format(r, lr))
        print(traceback.format_exc())
        raise