Update Training/training_dpo_sft_mistral_sk.py

Artur Hyrenko 2026-02-04 20:42:52 +00:00
parent 7bb03b551f
commit 5f05bc2460


@@ -1,5 +1,6 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Combined QLoRA training script for:
1) SFT (Supervised Fine-Tuning) with TRL SFTTrainer
@@ -10,7 +11,7 @@ Usage examples (multi-GPU):
torchrun --nproc_per_node=2 training_dpo_sft_mistral_sk.py dpo
torchrun --nproc_per_node=2 training_dpo_sft_mistral_sk.py both
You can override any path/knob via CLI flags (see --help).
All paths and hyperparameters can be overridden via CLI flags (see --help).
"""
import os
@@ -47,24 +48,29 @@ from trl import SFTTrainer, SFTConfig, DPOTrainer, DPOConfig
# ==========================
def patch_trl_dpo_trainer_if_needed():
"""
Fix TRL<->Transformers API mismatches.
Applies runtime patches to handle TRL <-> Transformers API incompatibilities.
Why this exists:
- Some TRL versions implement DPOTrainer methods with older signatures.
- Transformers Trainer may call methods with a newer signature / extra kwargs.
- These mismatches can lead to runtime errors during training.
Patch set:
1) get_batch_samples signature mismatch:
Transformers Trainer calls get_batch_samples(epoch_iterator, num_batches)
Older TRL DPOTrainer defines get_batch_samples(model, batch)
=> 'generator' object has no attribute 'generate'
Patch: delegate to HFTrainer.get_batch_samples.
- Transformers expects: get_batch_samples(epoch_iterator, num_batches)
- Some older TRL expects: get_batch_samples(model, batch, ...)
- Patch: route the call through HFTrainer.get_batch_samples.
2) compute_loss signature mismatch:
Newer Transformers passes num_items_in_batch kwarg.
Patch wrapper to accept it if missing.
- Newer Transformers may pass num_items_in_batch
- Some TRL versions do not accept that kwarg
- Patch: wrap compute_loss to accept num_items_in_batch and delegate.
"""
# ---- Patch get_batch_samples ----
try:
sig = inspect.signature(DPOTrainer.get_batch_samples)
params = list(sig.parameters.keys()) # includes "self"
# Old TRL: (self, model, batch, ...) -> params[1] == "model"
# Old TRL often has (self, model, batch, ...) so params[1] == "model"
if len(params) >= 3 and params[1] == "model":
print("[PATCH] DPOTrainer.get_batch_samples signature looks old; patching to HFTrainer.get_batch_samples", flush=True)
@@ -98,6 +104,15 @@ def patch_trl_dpo_trainer_if_needed():
# Callbacks
# ==========================
class ProgressETA(TrainerCallback):
"""
Trainer callback that prints progress with ETA on the main local process.
Printed metrics:
- global step / max steps
- steps per second
- ETA in minutes
- optionally loss and learning rate from the trainer logs
"""
def __init__(self):
self.t0 = None
self.local_rank = int(os.environ.get("LOCAL_RANK", "0"))
@@ -108,6 +123,7 @@ class ProgressETA(TrainerCallback):
print(f"[INFO] Training start. max_steps={state.max_steps}", flush=True)
def on_log(self, args, state, control, logs=None, **kwargs):
# Only print on local_rank==0 to reduce duplicated output under DDP
if self.local_rank != 0:
return
if not logs or not state.max_steps:
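# The rest of on_log is elided by this hunk; per the docstring, its ETA
# arithmetic reduces to the following standalone helper (an assumed
# equivalent with an illustrative name, not the commit's exact code):
import time

def _eta_minutes(t0, global_step, max_steps):
    # steps/sec observed so far, then remaining steps converted to minutes
    elapsed = max(time.time() - t0, 1e-9)
    steps_per_sec = global_step / elapsed
    return (max_steps - global_step) / max(steps_per_sec, 1e-9) / 60.0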
@@ -133,26 +149,36 @@ class ProgressETA(TrainerCallback):
# ==========================
@dataclass
class CommonCfg:
"""
Shared configuration for model/tokenizer/quantization used by both SFT and DPO.
"""
base_model_path: str = "/home/hyrenko/Diploma/models/mistral-sk-7b"
target_modules: List[str] = None # set in __post_init__
# QLoRA quant config
# QLoRA quant config (bitsandbytes)
load_in_4bit: bool = True
bnb_4bit_compute_dtype: str = "float16"
bnb_4bit_use_double_quant: bool = True
bnb_4bit_quant_type: str = "nf4"
# Tokenizer
# Tokenizer loading
use_fast_tokenizer: bool = False
trust_remote_code: bool = True
def __post_init__(self):
# Default LoRA target modules for Mistral-style architectures
if self.target_modules is None:
self.target_modules = ["q_proj", "v_proj"]
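# Example override of the q_proj/v_proj default, widening LoRA coverage
# (module names below are an illustrative choice, not a recommendation
# from this commit):
common = CommonCfg(target_modules=["q_proj", "k_proj", "v_proj", "o_proj"])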
@dataclass
class SFTCfg:
"""
SFT stage configuration.
Dataset requirement:
- load_from_disk(data_path) must contain a "text" column.
"""
data_path: str = "/home/hyrenko/Diploma/datasets/pku_saferlhf_sk_sft" # must have "text"
out_dir: str = "/home/hyrenko/Diploma/outputs_sft/mistral-sk-7b-pku-saferlhf-sk-sft-qlora"
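# A dataset satisfying the "text" requirement can be built like this
# (illustrative sketch with a made-up sample; not the actual
# pku_saferlhf_sk_sft builder):
from datasets import Dataset
demo = Dataset.from_dict({"text": ["<s>[INST] Ahoj [/INST] Ahoj, ako vám pomôžem?</s>"]})
demo.save_to_disk("/tmp/sft_demo")  # then point data_path at /tmp/sft_demo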
@@ -168,7 +194,7 @@ class SFTCfg:
save_steps: int = 200
save_total_limit: int = 2
# LoRA
# LoRA hyperparameters
lora_r: int = 16
lora_alpha: int = 32
lora_dropout: float = 0.05
@@ -176,6 +202,16 @@
@dataclass
class DPOCfg:
"""
DPO stage configuration.
Dataset requirement:
- load_from_disk(dpo_data_path) must contain: prompt, chosen, rejected
Adapter behavior:
- Loads the SFT adapter as frozen weights.
- Adds a NEW LoRA adapter for DPO and trains only that adapter.
"""
# SFT adapter input dir (from SFT run)
sft_adapter_dir: str = "/home/hyrenko/Diploma/outputs_sft/mistral-sk-7b-pku-saferlhf-sk-sft-qlora"
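# Minimal illustrative preference dataset with the required columns
# (the sample rows are made up):
from datasets import Dataset
pairs = Dataset.from_dict({
    "prompt": ["Ako sa máš?"],
    "chosen": ["Mám sa dobre, ďakujem za opýtanie."],
    "rejected": ["Na to neodpoviem."],
})
pairs.save_to_disk("/tmp/dpo_demo")  # then point dpo_data_path here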
@@ -208,6 +244,10 @@ class DPOCfg:
# Helpers
# ==========================
def get_local_rank_and_device():
"""
Determines device placement based on LOCAL_RANK.
In DDP runs, each process binds itself to a separate GPU.
"""
local_rank = int(os.environ.get("LOCAL_RANK", "0"))
if torch.cuda.is_available():
torch.cuda.set_device(local_rank)
@@ -219,6 +259,9 @@ def get_local_rank_and_device():
def make_bnb_config(common: CommonCfg) -> BitsAndBytesConfig:
"""
Builds BitsAndBytesConfig for 4-bit QLoRA.
"""
dtype = getattr(torch, common.bnb_4bit_compute_dtype)
return BitsAndBytesConfig(
load_in_4bit=common.load_in_4bit,
@@ -229,6 +272,9 @@ def make_bnb_config(common: CommonCfg) -> BitsAndBytesConfig:
def load_tokenizer(common: CommonCfg):
"""
Loads tokenizer from base_model_path and ensures pad_token is defined.
"""
tokenizer = AutoTokenizer.from_pretrained(
common.base_model_path,
use_fast=common.use_fast_tokenizer,
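# The rest of this call and the pad_token guard are elided by the hunk;
# the guard is commonly (assumed here) a fallback to the EOS token:
#     if tokenizer.pad_token is None:
#         tokenizer.pad_token = tokenizer.eos_token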
@@ -240,6 +286,14 @@ def load_tokenizer(common: CommonCfg):
def load_4bit_model(common: CommonCfg, bnb_config: BitsAndBytesConfig, device: torch.device):
"""
Loads the base model with 4-bit quantization and prepares it for k-bit training.
Key settings:
- device_map binds the model to the current GPU under DDP.
- use_cache is disabled for training.
- gradient checkpointing is disabled here to avoid reentrant issues under DDP.
"""
model = AutoModelForCausalLM.from_pretrained(
common.base_model_path,
quantization_config=bnb_config,
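# The remainder of this call is elided; the k-bit preparation described in
# the docstring typically looks like this (assumed, not the commit's exact
# code):
from peft import prepare_model_for_kbit_training
model.config.use_cache = False  # per the docstring: cache off for training
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=False)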
@@ -258,6 +312,9 @@ def load_4bit_model(common: CommonCfg, bnb_config: BitsAndBytesConfig, device: t
def save_manifest(out_dir: str, local_rank: int, manifest: dict):
"""
Writes a JSON manifest for reproducibility (rank0 only).
"""
if local_rank != 0:
return
os.makedirs(out_dir, exist_ok=True)
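# Illustrative usage (the manifest keys are examples, not the script's
# actual schema):
save_manifest(sft.out_dir, local_rank, {
    "base_model_path": common.base_model_path,
    "stage": "sft",
    "lora_r": sft.lora_r,
})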
@@ -269,6 +326,12 @@ def save_manifest(out_dir: str, local_rank: int, manifest: dict):
# SFT
# ==========================
def run_sft(common: CommonCfg, sft: SFTCfg):
"""
Runs SFT using TRL SFTTrainer with QLoRA adapter training.
Expected dataset:
- load_from_disk(sft.data_path) with a "text" field.
"""
local_rank, device = get_local_rank_and_device()
os.makedirs(sft.out_dir, exist_ok=True)
@@ -281,6 +344,7 @@ def run_sft(common: CommonCfg, sft: SFTCfg):
bnb_config = make_bnb_config(common)
model = load_4bit_model(common, bnb_config, device)
# Trainable LoRA adapter for SFT stage
lora = LoraConfig(
r=sft.lora_r,
lora_alpha=sft.lora_alpha,
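# The remaining LoraConfig fields are elided by this hunk; a typical full
# configuration for a causal-LM QLoRA setup, mirroring the SFTCfg fields
# (assumed, not the commit's exact code):
lora = LoraConfig(
    r=sft.lora_r,
    lora_alpha=sft.lora_alpha,
    lora_dropout=sft.lora_dropout,
    target_modules=common.target_modules,
    bias="none",
    task_type="CAUSAL_LM",
)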
@@ -357,7 +421,18 @@ def run_sft(common: CommonCfg, sft: SFTCfg):
# DPO
# ==========================
def run_dpo(common: CommonCfg, dpo: DPOCfg):
# Patches first (safe no-op if not needed)
"""
Runs DPO using TRL DPOTrainer with QLoRA adapters.
Adapter composition:
- Load base model (4-bit)
- Attach SFT adapter as frozen (is_trainable=False)
- Add a new LoRA adapter for DPO training (trainable)
Expected dataset:
- load_from_disk(dpo.dpo_data_path) with prompt/chosen/rejected columns
"""
# Apply patches early (no-op if already compatible)
patch_trl_dpo_trainer_if_needed()
local_rank, device = get_local_rank_and_device()
@@ -373,9 +448,10 @@ def run_dpo(common: CommonCfg, dpo: DPOCfg):
bnb_config = make_bnb_config(common)
model = load_4bit_model(common, bnb_config, device)
# Load SFT adapter (frozen) + add NEW DPO adapter (trainable)
# Load SFT adapter as frozen weights
model = PeftModel.from_pretrained(model, dpo.sft_adapter_dir, is_trainable=False)
# Trainable LoRA adapter for DPO stage
dpo_lora = LoraConfig(
r=dpo.lora_r,
lora_alpha=dpo.lora_alpha,
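# How the new adapter is attached is elided by this hunk; with PEFT's
# multi-adapter API the assumed continuation is (the adapter name "dpo"
# is illustrative):
model.add_adapter("dpo", dpo_lora)  # register the trainable DPO adapter
model.set_adapter("dpo")            # make it the active adapter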
@@ -450,6 +526,12 @@ def run_dpo(common: CommonCfg, dpo: DPOCfg):
# CLI
# ==========================
def build_parser():
"""
Builds an argparse parser with subcommands:
- sft
- dpo
- both
"""
p = argparse.ArgumentParser(
description="Combined QLoRA training: SFT and DPO (TRL).",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
@@ -457,6 +539,7 @@ def build_parser():
sub = p.add_subparsers(dest="cmd", required=False)
def add_common_args(sp):
# Shared base model/tokenizer arguments
sp.add_argument("--base_model_path", default=CommonCfg.base_model_path)
sp.add_argument("--target_modules", nargs="+", default=["q_proj", "v_proj"])
sp.add_argument("--use_fast_tokenizer", action="store_true", help="If set, use fast tokenizer")
@@ -543,8 +626,7 @@ def build_parser():
def interactive_menu() -> str:
"""
Simple interactive selector when the script is launched without CLI arguments.
Uses defaults for all paths/knobs. If you want to override anything, run with CLI flags.
Interactive selector used when the script is launched without CLI arguments.
"""
print("\n=== Training menu ===")
print("1) SFT")
@@ -566,15 +648,19 @@ def interactive_menu() -> str:
def main():
"""
Entry point:
- Parse CLI args (or use interactive menu when no args are provided)
- Create config objects
- Run the selected stage(s)
"""
parser = build_parser()
# If user runs the script with no arguments, we show interactive menu
# and run with defaults.
# If run without args: show menu and run with defaults
import sys
if len(sys.argv) == 1:
# IMPORTANT for torchrun/DDP:
# In distributed mode, we must NOT prompt on every process.
# Rank0 asks, then writes the choice to a small file; other ranks read it.
# Under DDP, avoid prompting on every process:
# local_rank==0 writes the choice; other ranks read it.
world_size = int(os.environ.get("WORLD_SIZE", "1"))
local_rank = int(os.environ.get("LOCAL_RANK", "0"))
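# The write side of the handshake is elided by the next hunk; a minimal
# sketch (the choice_file path and details here are illustrative, not the
# script's actual values):
choice_file = "/tmp/train_menu_choice.txt"
if local_rank == 0:
    cmd = interactive_menu()
    with open(choice_file, "w") as f:
        f.write(cmd)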
@@ -594,7 +680,7 @@ def main():
except Exception as e:
print(f"[WARN] Could not write choice file {choice_file}: {e}", flush=True)
else:
# Wait for rank0 to write the choice
# Poll for choice_file created by local_rank==0
deadline = time.time() + 600 # seconds
cmd = None
while time.time() < deadline:
@@ -608,7 +694,7 @@ def main():
pass
time.sleep(0.2)
if not cmd:
print("[ERROR] Timed out waiting for rank0 menu selection. "
print("[ERROR] Timed out waiting for menu selection. "
"Run with explicit subcommand: sft/dpo/both.", flush=True)
raise SystemExit(2)