From 7bb03b551fd88cad8f2c278e30776bb1e7b1e17c Mon Sep 17 00:00:00 2001 From: Artur Hyrenko Date: Wed, 4 Feb 2026 20:38:37 +0000 Subject: [PATCH] Added comments --- Training/training_dpo_sft_llama.py | 86 ++++++++++++++++++------------ 1 file changed, 52 insertions(+), 34 deletions(-) diff --git a/Training/training_dpo_sft_llama.py b/Training/training_dpo_sft_llama.py index f876041..bd7e0c9 100644 --- a/Training/training_dpo_sft_llama.py +++ b/Training/training_dpo_sft_llama.py @@ -2,18 +2,14 @@ # -*- coding: utf-8 -*- """ -Unified training script: SFT (QLoRA + masked loss) and DPO (TRL) in ONE file. +Unified training script for: +- SFT (QLoRA + masked loss) +- DPO (TRL DPOTrainer) -UX requirement (as you asked): -- You run: python3 train_unified.py -- It shows a simple menu 1/2/3 -- You press 1, 2, or 3 -- Script relaunches itself with accelerate (2 GPUs) and runs the chosen stage - using DEFAULT paths/hparams already embedded. - -Notes: -- Menu is shown ONLY once (before accelerate spawns processes). -- After relaunch, each process runs normally with LOCAL_RANK set by accelerate. +Execution model: +- Run the script normally: it shows a simple menu (SFT / DPO / BOTH). +- The script relaunches itself via accelerate (multi-process). +- After relaunch, each process runs with LOCAL_RANK set by accelerate. """ import os @@ -47,7 +43,7 @@ from trl import DPOTrainer, DPOConfig # -------------------------- -# Defaults (same spirit as your original scripts) +# Defaults # -------------------------- DEFAULT_MODEL_DIR = "/home/hyrenko/Diploma/models/llama3.1-8b" @@ -61,17 +57,17 @@ DEFAULT_SFT_OUT = "./out/llama3_1_8b_sft_safety_lora_masked" DEFAULT_DPO_OUT = "./out/llama3_1_8b_dpo_safety_lora" DEFAULT_DPO_DATA = "./data/pku_dpo" -# Prompt wrappers (keep consistent) +# Prompt wrappers INSTR_PREFIX = "### Instruction:\n" RESP_PREFIX = "\n\n### Response:\n" # -------------------------- -# Simple interactive menu (shown only before accelerate) +# Interactive menu (shown only before accelerate spawns processes) # -------------------------- def choose_stage_simple() -> str: """ - Returns: 'sft' | 'dpo' | 'both' + Returns one of: 'sft' | 'dpo' | 'both' """ print("\n=== TRAIN MENU ===") print("1) SFT (QLoRA masked-loss)") @@ -90,8 +86,8 @@ def choose_stage_simple() -> str: def relaunch_with_accelerate(num_processes: int, extra_argv: List[str]): """ - Relaunches the script via accelerate (2 processes by default). - Only runs if NOT already inside accelerate (LOCAL_RANK not set). + Relaunch the script via accelerate with the requested number of processes. + No-op if already running inside accelerate (LOCAL_RANK is set). """ if os.environ.get("LOCAL_RANK") is not None: return @@ -109,7 +105,7 @@ def relaunch_with_accelerate(num_processes: int, extra_argv: List[str]): sys.argv[0], *extra_argv, ] - print("\n[AUTO] Relaunching with accelerate:\n" + " ".join(cmd) + "\n") + print("\n[INFO] Relaunching with accelerate:\n" + " ".join(cmd) + "\n") raise SystemExit(subprocess.call(cmd)) @@ -117,6 +113,9 @@ def relaunch_with_accelerate(num_processes: int, extra_argv: List[str]): # DDP helpers # -------------------------- def ddp_info(): + """ + Returns (local_rank, rank, world_size) inferred from environment variables. + """ local_rank = int(os.environ.get("LOCAL_RANK", "0")) rank = int(os.environ.get("RANK", "0")) world_size = int(os.environ.get("WORLD_SIZE", "1")) @@ -124,15 +123,18 @@ def ddp_info(): def ddp_barrier(): + """ + Synchronization barrier for distributed runs (no-op if torch.distributed is not initialized). + """ if torch.distributed.is_available() and torch.distributed.is_initialized(): torch.distributed.barrier() # -------------------------- -# TRL <-> Transformers patches (to avoid signature mismatches) +# TRL <-> Transformers patches (signature compatibility) # -------------------------- def patch_trl_dpo_trainer_if_needed(): - # Patch get_batch_samples + # Patch get_batch_samples to avoid signature mismatches across versions try: sig = list(inspect.signature(DPOTrainer.get_batch_samples).parameters.keys()) if len(sig) >= 3 and sig[1] == "model": @@ -143,7 +145,7 @@ def patch_trl_dpo_trainer_if_needed(): except Exception: print("[PATCH] Could not inspect/patch DPOTrainer.get_batch_samples") - # Patch compute_loss + # Patch compute_loss to accept num_items_in_batch if needed try: sig = inspect.signature(DPOTrainer.compute_loss) if "num_items_in_batch" not in sig.parameters: @@ -167,6 +169,9 @@ def patch_trl_dpo_trainer_if_needed(): def assert_is_model(obj, name="model"): + """ + Minimal runtime check: object should look like a HF CausalLM model. + """ if not hasattr(obj, "generate"): raise TypeError( "[FATAL] {} is not a HF CausalLM model. type={}; missing .generate().".format(name, type(obj)) @@ -177,14 +182,17 @@ def assert_is_model(obj, name="model"): # SFT helpers # -------------------------- def wrap_prompt(p: str) -> str: + """ + Wrap a raw prompt into the instruction/response template prefix. + """ p = (p or "").strip() return "{}{}{}".format(INSTR_PREFIX, p, RESP_PREFIX) def ensure_sft_jsonl_exists(rank: int): """ - Creates ./data/pku_sft.jsonl from PKU-SafeRLHF-30K if missing. - Only rank0 writes; others wait. + Create ./data/pku_sft.jsonl from PKU-SafeRLHF-30K if missing. + Only rank0 writes; all other ranks wait at a barrier. """ os.makedirs("./data", exist_ok=True) @@ -225,6 +233,10 @@ def ensure_sft_jsonl_exists(rank: int): @dataclass class DataCollatorForCausalLMWithLabels: + """ + Pads input_ids/attention_mask and aligns labels to the padded length. + Labels use -100 for ignored positions (standard LM loss masking). + """ tokenizer: Any def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]: @@ -252,6 +264,9 @@ class DataCollatorForCausalLMWithLabels: def build_quant_base(model_dir: str, device_map: Dict[str, int]): + """ + Load a 4-bit quantized base model using bitsandbytes config. + """ bnb = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_quant_type="nf4", @@ -282,7 +297,7 @@ def run_sft( world_size: int, ) -> str: """ - Runs SFT and returns the path to the saved LoRA adapter (out_dir). + Run SFT training and return the saved LoRA adapter path (out_dir). """ os.makedirs(out_dir, exist_ok=True) @@ -309,6 +324,9 @@ def run_sft( model = get_peft_model(base, lora) def tokenize_and_mask(example): + """ + Tokenize prompt and response, and build labels masked for the prompt part. + """ prompt = example["prompt"] response = example["response"] @@ -366,7 +384,7 @@ def run_sft( trainer.train() - # Save only on rank0, then barrier so others wait + # Save only on rank0; other ranks wait at a barrier if rank == 0: model.save_pretrained(out_dir) tokenizer.save_pretrained(out_dir) @@ -394,6 +412,9 @@ def run_dpo( device_map: Dict[str, int], rank: int, ): + """ + Run DPO training starting from the base model + an SFT LoRA adapter. + """ patch_trl_dpo_trainer_if_needed() os.makedirs(out_dir, exist_ok=True) @@ -406,7 +427,7 @@ def run_dpo( base_model = build_quant_base(model_dir, device_map) assert_is_model(base_model, "base_model") - # Load SFT adapter and keep training it with DPO + # Load SFT adapter as trainable and continue training with DPO model = PeftModel.from_pretrained(base_model, sft_adapter, is_trainable=True) assert_is_model(model, "model(peft)") @@ -482,15 +503,13 @@ def main(): args = ap.parse_args() - # -------------------------- - # AUTO MENU (only once, BEFORE accelerate spawns processes) - # -------------------------- + # Show menu only in the initial, non-accelerate run if os.environ.get("LOCAL_RANK") is None and os.environ.get("AUTO_ACCELERATE_RELAUNCH") != "1": os.environ["AUTO_ACCELERATE_RELAUNCH"] = "1" # prevent menu duplication stage = choose_stage_simple() - # Build argv using defaults from args (you only press 1/2/3) + # Build argv using defaults from args argv = [ "--stage", stage, @@ -536,17 +555,17 @@ def main(): sft_adapter = args.sft_adapter or args.sft_out argv += ["--sft_adapter", sft_adapter] - # Relaunch under accelerate (2 GPUs) + # Relaunch under accelerate (2 processes) relaunch_with_accelerate(num_processes=2, extra_argv=argv) # -------------------------- - # We are inside accelerate now + # Running inside accelerate # -------------------------- local_rank, rank, world_size = ddp_info() torch.cuda.set_device(local_rank) device_map = {"": torch.cuda.current_device()} - # Deterministic seed (as in your DPO script style) + # Fixed seed torch.manual_seed(42) sft_adapter_path = args.sft_adapter or args.sft_out @@ -567,7 +586,6 @@ def main(): ) if args.stage in ("dpo", "both"): - # If "both", barrier already happened after saving SFT adapter run_dpo( model_dir=args.model, sft_adapter=sft_adapter_path,