diff --git a/Training/training_dpo_sft_mistral_sk.py b/Training/training_dpo_sft_mistral_sk.py
index 5d7278f..1531433 100644
--- a/Training/training_dpo_sft_mistral_sk.py
+++ b/Training/training_dpo_sft_mistral_sk.py
@@ -1,5 +1,6 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
+
 """
 Combined QLoRA training script for:
   1) SFT (Supervised Fine-Tuning) with TRL SFTTrainer
@@ -10,7 +11,7 @@ Usage examples (multi-GPU):
   torchrun --nproc_per_node=2 training_sft_dpo_sk_combined.py dpo
   torchrun --nproc_per_node=2 training_sft_dpo_sk_combined.py both
 
-You can override any path/knob via CLI flags (see --help).
+All paths and hyperparameters can be overridden via CLI flags (see --help).
 """
 
 import os
@@ -47,24 +48,29 @@
 from trl import SFTTrainer, SFTConfig, DPOTrainer, DPOConfig
 
 # ==========================
 def patch_trl_dpo_trainer_if_needed():
     """
-    Fix TRL<->Transformers API mismatches.
+    Applies runtime patches to handle TRL <-> Transformers API incompatibilities.
+
+    Why this exists:
+      - Some TRL versions implement DPOTrainer methods with older signatures.
+      - Transformers Trainer may call methods with a newer signature / extra kwargs.
+      - These mismatches can lead to runtime errors during training.
+
+    Patch set:
     1) get_batch_samples signature mismatch:
-       Transformers Trainer calls get_batch_samples(epoch_iterator, num_batches)
-       Older TRL DPOTrainer defines get_batch_samples(model, batch)
-       => 'generator' object has no attribute 'generate'
-
-       Patch: delegate to HFTrainer.get_batch_samples.
+       - Transformers expects: get_batch_samples(epoch_iterator, num_batches)
+       - Some older TRL expects: get_batch_samples(model, batch, ...)
+       - Patch: route the call through HFTrainer.get_batch_samples.
     2) compute_loss signature mismatch:
-       Newer Transformers passes num_items_in_batch kwarg.
-       Patch wrapper to accept it if missing.
+       - Newer Transformers may pass num_items_in_batch
+       - Some TRL versions do not accept that kwarg
+       - Patch: wrap compute_loss to accept num_items_in_batch and delegate.
     """
     # ---- Patch get_batch_samples ----
     try:
         sig = inspect.signature(DPOTrainer.get_batch_samples)
         params = list(sig.parameters.keys())  # includes "self"
-        # Old TRL: (self, model, batch, ...) -> params[1] == "model"
+        # Old TRL often has (self, model, batch, ...) so params[1] == "model"
         if len(params) >= 3 and params[1] == "model":
             print("[PATCH] DPOTrainer.get_batch_samples signature looks old; patching to HFTrainer.get_batch_samples",
                   flush=True)
@@ -98,6 +104,15 @@
 # Callbacks
 # ==========================
 class ProgressETA(TrainerCallback):
+    """
+    Trainer callback that prints progress with ETA on the main local process.
+
+    Printed metrics:
+      - global step / max steps
+      - steps per second
+      - ETA in minutes
+      - optionally loss and learning rate from the trainer logs
+    """
     def __init__(self):
         self.t0 = None
         self.local_rank = int(os.environ.get("LOCAL_RANK", "0"))
@@ -108,6 +123,7 @@
             print(f"[INFO] Training start. max_steps={state.max_steps}", flush=True)
 
     def on_log(self, args, state, control, logs=None, **kwargs):
+        # Only print on local_rank==0 to reduce duplicated output under DDP
         if self.local_rank != 0:
             return
         if not logs or not state.max_steps:
             return
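# ----------------------------------------------------------------------------
# A minimal, runnable sketch of the signature-probing monkeypatch applied by
# patch_trl_dpo_trainer_if_needed() above. LegacyTrainer and
# hf_style_get_batch_samples are hypothetical stand-ins, not the real TRL or
# Transformers classes.
# ----------------------------------------------------------------------------
import inspect

class LegacyTrainer:
    # Stand-in for an old TRL DPOTrainer with the (self, model, batch) shape.
    def get_batch_samples(self, model, batch):
        return ("legacy", model, batch)

def hf_style_get_batch_samples(self, epoch_iterator, num_batches):
    # Stand-in for HFTrainer.get_batch_samples with the newer signature.
    return [next(epoch_iterator) for _ in range(num_batches)]

params = list(inspect.signature(LegacyTrainer.get_batch_samples).parameters)
if len(params) >= 3 and params[1] == "model":
    # Same move as the patch above: rebind the method at class level so every
    # instance picks up the newer signature.
    LegacyTrainer.get_batch_samples = hf_style_get_batch_samples

assert LegacyTrainer().get_batch_samples(iter("abc"), 2) == ["a", "b"]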
+ """ base_model_path: str = "/home/hyrenko/Diploma/models/mistral-sk-7b" target_modules: List[str] = None # set in __post_init__ - # QLoRA quant config + # QLoRA quant config (bitsandbytes) load_in_4bit: bool = True bnb_4bit_compute_dtype: str = "float16" bnb_4bit_use_double_quant: bool = True bnb_4bit_quant_type: str = "nf4" - # Tokenizer + # Tokenizer loading use_fast_tokenizer: bool = False trust_remote_code: bool = True def __post_init__(self): + # Default LoRA target modules for Mistral-style architectures if self.target_modules is None: self.target_modules = ["q_proj", "v_proj"] @dataclass class SFTCfg: + """ + SFT stage configuration. + + Dataset requirement: + - load_from_disk(data_path) must contain a "text" column. + """ data_path: str = "/home/hyrenko/Diploma/datasets/pku_saferlhf_sk_sft" # must have "text" out_dir: str = "/home/hyrenko/Diploma/outputs_sft/mistral-sk-7b-pku-saferlhf-sk-sft-qlora" @@ -168,7 +194,7 @@ class SFTCfg: save_steps: int = 200 save_total_limit: int = 2 - # LoRA + # LoRA hyperparameters lora_r: int = 16 lora_alpha: int = 32 lora_dropout: float = 0.05 @@ -176,6 +202,16 @@ class SFTCfg: @dataclass class DPOCfg: + """ + DPO stage configuration. + + Dataset requirement: + - load_from_disk(dpo_data_path) must contain: prompt, chosen, rejected + + Adapter behavior: + - Loads the SFT adapter as frozen weights. + - Adds a NEW LoRA adapter for DPO and trains only that adapter. + """ # SFT adapter input dir (from SFT run) sft_adapter_dir: str = "/home/hyrenko/Diploma/outputs_sft/mistral-sk-7b-pku-saferlhf-sk-sft-qlora" @@ -208,6 +244,10 @@ class DPOCfg: # Helpers # ========================== def get_local_rank_and_device(): + """ + Determines device placement based on LOCAL_RANK. + In DDP runs, each process binds itself to a separate GPU. + """ local_rank = int(os.environ.get("LOCAL_RANK", "0")) if torch.cuda.is_available(): torch.cuda.set_device(local_rank) @@ -219,6 +259,9 @@ def get_local_rank_and_device(): def make_bnb_config(common: CommonCfg) -> BitsAndBytesConfig: + """ + Builds BitsAndBytesConfig for 4-bit QLoRA. + """ dtype = getattr(torch, common.bnb_4bit_compute_dtype) return BitsAndBytesConfig( load_in_4bit=common.load_in_4bit, @@ -229,6 +272,9 @@ def make_bnb_config(common: CommonCfg) -> BitsAndBytesConfig: def load_tokenizer(common: CommonCfg): + """ + Loads tokenizer from base_model_path and ensures pad_token is defined. + """ tokenizer = AutoTokenizer.from_pretrained( common.base_model_path, use_fast=common.use_fast_tokenizer, @@ -240,6 +286,14 @@ def load_tokenizer(common: CommonCfg): def load_4bit_model(common: CommonCfg, bnb_config: BitsAndBytesConfig, device: torch.device): + """ + Loads the base model with 4-bit quantization and prepares it for k-bit training. + + Key settings: + - device_map binds the model to the current GPU under DDP. + - use_cache is disabled for training. + - gradient checkpointing is disabled here to avoid reentrant issues under DDP. + """ model = AutoModelForCausalLM.from_pretrained( common.base_model_path, quantization_config=bnb_config, @@ -258,6 +312,9 @@ def load_4bit_model(common: CommonCfg, bnb_config: BitsAndBytesConfig, device: t def save_manifest(out_dir: str, local_rank: int, manifest: dict): + """ + Writes a JSON manifest for reproducibility (rank0 only). 
+ """ if local_rank != 0: return os.makedirs(out_dir, exist_ok=True) @@ -269,6 +326,12 @@ def save_manifest(out_dir: str, local_rank: int, manifest: dict): # SFT # ========================== def run_sft(common: CommonCfg, sft: SFTCfg): + """ + Runs SFT using TRL SFTTrainer with QLoRA adapter training. + + Expected dataset: + - load_from_disk(sft.data_path) with a "text" field. + """ local_rank, device = get_local_rank_and_device() os.makedirs(sft.out_dir, exist_ok=True) @@ -281,6 +344,7 @@ def run_sft(common: CommonCfg, sft: SFTCfg): bnb_config = make_bnb_config(common) model = load_4bit_model(common, bnb_config, device) + # Trainable LoRA adapter for SFT stage lora = LoraConfig( r=sft.lora_r, lora_alpha=sft.lora_alpha, @@ -357,7 +421,18 @@ def run_sft(common: CommonCfg, sft: SFTCfg): # DPO # ========================== def run_dpo(common: CommonCfg, dpo: DPOCfg): - # Patches first (safe no-op if not needed) + """ + Runs DPO using TRL DPOTrainer with QLoRA adapters. + + Adapter composition: + - Load base model (4-bit) + - Attach SFT adapter as frozen (is_trainable=False) + - Add a new LoRA adapter for DPO training (trainable) + + Expected dataset: + - load_from_disk(dpo.dpo_data_path) with prompt/chosen/rejected columns + """ + # Apply patches early (no-op if already compatible) patch_trl_dpo_trainer_if_needed() local_rank, device = get_local_rank_and_device() @@ -373,9 +448,10 @@ def run_dpo(common: CommonCfg, dpo: DPOCfg): bnb_config = make_bnb_config(common) model = load_4bit_model(common, bnb_config, device) - # Load SFT adapter (frozen) + add NEW DPO adapter (trainable) + # Load SFT adapter as frozen weights model = PeftModel.from_pretrained(model, dpo.sft_adapter_dir, is_trainable=False) + # Trainable LoRA adapter for DPO stage dpo_lora = LoraConfig( r=dpo.lora_r, lora_alpha=dpo.lora_alpha, @@ -450,6 +526,12 @@ def run_dpo(common: CommonCfg, dpo: DPOCfg): # CLI # ========================== def build_parser(): + """ + Builds an argparse parser with subcommands: + - sft + - dpo + - both + """ p = argparse.ArgumentParser( description="Combined QLoRA training: SFT and DPO (TRL).", formatter_class=argparse.ArgumentDefaultsHelpFormatter, @@ -457,6 +539,7 @@ def build_parser(): sub = p.add_subparsers(dest="cmd", required=False) def add_common_args(sp): + # Shared base model/tokenizer arguments sp.add_argument("--base_model_path", default=CommonCfg.base_model_path) sp.add_argument("--target_modules", nargs="+", default=["q_proj", "v_proj"]) sp.add_argument("--use_fast_tokenizer", action="store_true", help="If set, use fast tokenizer") @@ -543,8 +626,7 @@ def build_parser(): def interactive_menu() -> str: """ - Simple interactive selector when the script is launched without CLI arguments. - Uses defaults for all paths/knobs. If you want to override anything, run with CLI flags. + Interactive selector used when the script is launched without CLI arguments. """ print("\n=== Training menu ===") print("1) SFT") @@ -566,15 +648,19 @@ def interactive_menu() -> str: def main(): + """ + Entry point: + - Parse CLI args (or use interactive menu when no args are provided) + - Create config objects + - Run the selected stage(s) + """ parser = build_parser() - # If user runs the script with no arguments, we show interactive menu - # and run with defaults. + # If run without args: show menu and run with defaults import sys if len(sys.argv) == 1: - # IMPORTANT for torchrun/DDP: - # In distributed mode, we must NOT prompt on every process. 
@@ -450,6 +526,12 @@
 # CLI
 # ==========================
 def build_parser():
+    """
+    Builds an argparse parser with subcommands:
+      - sft
+      - dpo
+      - both
+    """
     p = argparse.ArgumentParser(
         description="Combined QLoRA training: SFT and DPO (TRL).",
         formatter_class=argparse.ArgumentDefaultsHelpFormatter,
@@ -457,6 +539,7 @@
     sub = p.add_subparsers(dest="cmd", required=False)
 
     def add_common_args(sp):
+        # Shared base model/tokenizer arguments
         sp.add_argument("--base_model_path", default=CommonCfg.base_model_path)
         sp.add_argument("--target_modules", nargs="+", default=["q_proj", "v_proj"])
         sp.add_argument("--use_fast_tokenizer", action="store_true", help="If set, use fast tokenizer")
@@ -543,8 +626,7 @@
 def interactive_menu() -> str:
     """
-    Simple interactive selector when the script is launched without CLI arguments.
-    Uses defaults for all paths/knobs. If you want to override anything, run with CLI flags.
+    Interactive selector used when the script is launched without CLI arguments.
     """
     print("\n=== Training menu ===")
     print("1) SFT")
@@ -566,15 +648,19 @@
 
 def main():
+    """
+    Entry point:
+      - Parse CLI args (or use interactive menu when no args are provided)
+      - Create config objects
+      - Run the selected stage(s)
+    """
     parser = build_parser()
 
-    # If user runs the script with no arguments, we show interactive menu
-    # and run with defaults.
+    # If run without args: show menu and run with defaults
     import sys
     if len(sys.argv) == 1:
-        # IMPORTANT for torchrun/DDP:
-        # In distributed mode, we must NOT prompt on every process.
-        # Rank0 asks, then writes the choice to a small file; other ranks read it.
+        # Under DDP, avoid prompting on every process:
+        # local_rank==0 writes the choice; other ranks read it.
         world_size = int(os.environ.get("WORLD_SIZE", "1"))
         local_rank = int(os.environ.get("LOCAL_RANK", "0"))
@@ -594,7 +680,7 @@
             except Exception as e:
                 print(f"[WARN] Could not write choice file {choice_file}: {e}", flush=True)
         else:
-            # Wait for rank0 to write the choice
+            # Poll for choice_file created by local_rank==0
             deadline = time.time() + 600  # seconds
             cmd = None
             while time.time() < deadline:
@@ -608,7 +694,7 @@
                     pass
                 time.sleep(0.2)
             if not cmd:
-                print("[ERROR] Timed out waiting for rank0 menu selection. "
+                print("[ERROR] Timed out waiting for menu selection. "
                       "Run with explicit subcommand: sft/dpo/both.", flush=True)
                 raise SystemExit(2)
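# ----------------------------------------------------------------------------
# A standalone sketch of the file-based handshake main() uses so that only one
# process prompts under torchrun. The choice-file path and timeout here are
# illustrative values, not the script's actual ones.
# ----------------------------------------------------------------------------
import os
import time

choice_file = "/tmp/menu_choice.txt"
local_rank = int(os.environ.get("LOCAL_RANK", "0"))

if local_rank == 0:
    cmd = input("sft / dpo / both? ").strip()  # only rank0 prompts the user
    with open(choice_file, "w") as f:
        f.write(cmd)
else:
    cmd, deadline = None, time.time() + 600
    while time.time() < deadline:  # other ranks poll until the file appears
        if os.path.exists(choice_file):
            with open(choice_file) as f:
                cmd = f.read().strip()
            if cmd:
                break
        time.sleep(0.2)
    if not cmd:
        raise SystemExit("timed out waiting for menu selection")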