Update Training/training_dpo_sft_mistral_sk.py

This commit is contained in:
Artur Hyrenko 2026-02-04 20:42:52 +00:00
parent 7bb03b551f
commit 5f05bc2460

View File

@ -1,5 +1,6 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
""" """
Combined QLoRA training script for: Combined QLoRA training script for:
1) SFT (Supervised Fine-Tuning) with TRL SFTTrainer 1) SFT (Supervised Fine-Tuning) with TRL SFTTrainer
@ -10,7 +11,7 @@ Usage examples (multi-GPU):
torchrun --nproc_per_node=2 training_sft_dpo_sk_combined.py dpo torchrun --nproc_per_node=2 training_sft_dpo_sk_combined.py dpo
torchrun --nproc_per_node=2 training_sft_dpo_sk_combined.py both torchrun --nproc_per_node=2 training_sft_dpo_sk_combined.py both
You can override any path/knob via CLI flags (see --help). All paths and hyperparameters can be overridden via CLI flags (see --help).
""" """
import os import os
@ -47,24 +48,29 @@ from trl import SFTTrainer, SFTConfig, DPOTrainer, DPOConfig
# ========================== # ==========================
def patch_trl_dpo_trainer_if_needed(): def patch_trl_dpo_trainer_if_needed():
""" """
Fix TRL<->Transformers API mismatches. Applies runtime patches to handle TRL <-> Transformers API incompatibilities.
Why this exists:
- Some TRL versions implement DPOTrainer methods with older signatures.
- Transformers Trainer may call methods with a newer signature / extra kwargs.
- These mismatches can lead to runtime errors during training.
Patch set:
1) get_batch_samples signature mismatch: 1) get_batch_samples signature mismatch:
Transformers Trainer calls get_batch_samples(epoch_iterator, num_batches) - Transformers expects: get_batch_samples(epoch_iterator, num_batches)
Older TRL DPOTrainer defines get_batch_samples(model, batch) - Some older TRL expects: get_batch_samples(model, batch, ...)
=> 'generator' object has no attribute 'generate' - Patch: route the call through HFTrainer.get_batch_samples.
Patch: delegate to HFTrainer.get_batch_samples.
2) compute_loss signature mismatch: 2) compute_loss signature mismatch:
Newer Transformers passes num_items_in_batch kwarg. - Newer Transformers may pass num_items_in_batch
Patch wrapper to accept it if missing. - Some TRL versions do not accept that kwarg
- Patch: wrap compute_loss to accept num_items_in_batch and delegate.
""" """
# ---- Patch get_batch_samples ---- # ---- Patch get_batch_samples ----
try: try:
sig = inspect.signature(DPOTrainer.get_batch_samples) sig = inspect.signature(DPOTrainer.get_batch_samples)
params = list(sig.parameters.keys()) # includes "self" params = list(sig.parameters.keys()) # includes "self"
# Old TRL: (self, model, batch, ...) -> params[1] == "model" # Old TRL often has (self, model, batch, ...) so params[1] == "model"
if len(params) >= 3 and params[1] == "model": if len(params) >= 3 and params[1] == "model":
print("[PATCH] DPOTrainer.get_batch_samples signature looks old; patching to HFTrainer.get_batch_samples", flush=True) print("[PATCH] DPOTrainer.get_batch_samples signature looks old; patching to HFTrainer.get_batch_samples", flush=True)
@ -98,6 +104,15 @@ def patch_trl_dpo_trainer_if_needed():
# Callbacks # Callbacks
# ========================== # ==========================
class ProgressETA(TrainerCallback): class ProgressETA(TrainerCallback):
"""
Trainer callback that prints progress with ETA on the main local process.
Printed metrics:
- global step / max steps
- steps per second
- ETA in minutes
- optionally loss and learning rate from the trainer logs
"""
def __init__(self): def __init__(self):
self.t0 = None self.t0 = None
self.local_rank = int(os.environ.get("LOCAL_RANK", "0")) self.local_rank = int(os.environ.get("LOCAL_RANK", "0"))
@ -108,6 +123,7 @@ class ProgressETA(TrainerCallback):
print(f"[INFO] Training start. max_steps={state.max_steps}", flush=True) print(f"[INFO] Training start. max_steps={state.max_steps}", flush=True)
def on_log(self, args, state, control, logs=None, **kwargs): def on_log(self, args, state, control, logs=None, **kwargs):
# Only print on local_rank==0 to reduce duplicated output under DDP
if self.local_rank != 0: if self.local_rank != 0:
return return
if not logs or not state.max_steps: if not logs or not state.max_steps:
@ -133,26 +149,36 @@ class ProgressETA(TrainerCallback):
# ========================== # ==========================
@dataclass @dataclass
class CommonCfg: class CommonCfg:
"""
Shared configuration for model/tokenizer/quantization used by both SFT and DPO.
"""
base_model_path: str = "/home/hyrenko/Diploma/models/mistral-sk-7b" base_model_path: str = "/home/hyrenko/Diploma/models/mistral-sk-7b"
target_modules: List[str] = None # set in __post_init__ target_modules: List[str] = None # set in __post_init__
# QLoRA quant config # QLoRA quant config (bitsandbytes)
load_in_4bit: bool = True load_in_4bit: bool = True
bnb_4bit_compute_dtype: str = "float16" bnb_4bit_compute_dtype: str = "float16"
bnb_4bit_use_double_quant: bool = True bnb_4bit_use_double_quant: bool = True
bnb_4bit_quant_type: str = "nf4" bnb_4bit_quant_type: str = "nf4"
# Tokenizer # Tokenizer loading
use_fast_tokenizer: bool = False use_fast_tokenizer: bool = False
trust_remote_code: bool = True trust_remote_code: bool = True
def __post_init__(self): def __post_init__(self):
# Default LoRA target modules for Mistral-style architectures
if self.target_modules is None: if self.target_modules is None:
self.target_modules = ["q_proj", "v_proj"] self.target_modules = ["q_proj", "v_proj"]
@dataclass @dataclass
class SFTCfg: class SFTCfg:
"""
SFT stage configuration.
Dataset requirement:
- load_from_disk(data_path) must contain a "text" column.
"""
data_path: str = "/home/hyrenko/Diploma/datasets/pku_saferlhf_sk_sft" # must have "text" data_path: str = "/home/hyrenko/Diploma/datasets/pku_saferlhf_sk_sft" # must have "text"
out_dir: str = "/home/hyrenko/Diploma/outputs_sft/mistral-sk-7b-pku-saferlhf-sk-sft-qlora" out_dir: str = "/home/hyrenko/Diploma/outputs_sft/mistral-sk-7b-pku-saferlhf-sk-sft-qlora"
@ -168,7 +194,7 @@ class SFTCfg:
save_steps: int = 200 save_steps: int = 200
save_total_limit: int = 2 save_total_limit: int = 2
# LoRA # LoRA hyperparameters
lora_r: int = 16 lora_r: int = 16
lora_alpha: int = 32 lora_alpha: int = 32
lora_dropout: float = 0.05 lora_dropout: float = 0.05
@ -176,6 +202,16 @@ class SFTCfg:
@dataclass @dataclass
class DPOCfg: class DPOCfg:
"""
DPO stage configuration.
Dataset requirement:
- load_from_disk(dpo_data_path) must contain: prompt, chosen, rejected
Adapter behavior:
- Loads the SFT adapter as frozen weights.
- Adds a NEW LoRA adapter for DPO and trains only that adapter.
"""
# SFT adapter input dir (from SFT run) # SFT adapter input dir (from SFT run)
sft_adapter_dir: str = "/home/hyrenko/Diploma/outputs_sft/mistral-sk-7b-pku-saferlhf-sk-sft-qlora" sft_adapter_dir: str = "/home/hyrenko/Diploma/outputs_sft/mistral-sk-7b-pku-saferlhf-sk-sft-qlora"
@ -208,6 +244,10 @@ class DPOCfg:
# Helpers # Helpers
# ========================== # ==========================
def get_local_rank_and_device(): def get_local_rank_and_device():
"""
Determines device placement based on LOCAL_RANK.
In DDP runs, each process binds itself to a separate GPU.
"""
local_rank = int(os.environ.get("LOCAL_RANK", "0")) local_rank = int(os.environ.get("LOCAL_RANK", "0"))
if torch.cuda.is_available(): if torch.cuda.is_available():
torch.cuda.set_device(local_rank) torch.cuda.set_device(local_rank)
@ -219,6 +259,9 @@ def get_local_rank_and_device():
def make_bnb_config(common: CommonCfg) -> BitsAndBytesConfig: def make_bnb_config(common: CommonCfg) -> BitsAndBytesConfig:
"""
Builds BitsAndBytesConfig for 4-bit QLoRA.
"""
dtype = getattr(torch, common.bnb_4bit_compute_dtype) dtype = getattr(torch, common.bnb_4bit_compute_dtype)
return BitsAndBytesConfig( return BitsAndBytesConfig(
load_in_4bit=common.load_in_4bit, load_in_4bit=common.load_in_4bit,
@ -229,6 +272,9 @@ def make_bnb_config(common: CommonCfg) -> BitsAndBytesConfig:
def load_tokenizer(common: CommonCfg): def load_tokenizer(common: CommonCfg):
"""
Loads tokenizer from base_model_path and ensures pad_token is defined.
"""
tokenizer = AutoTokenizer.from_pretrained( tokenizer = AutoTokenizer.from_pretrained(
common.base_model_path, common.base_model_path,
use_fast=common.use_fast_tokenizer, use_fast=common.use_fast_tokenizer,
@ -240,6 +286,14 @@ def load_tokenizer(common: CommonCfg):
def load_4bit_model(common: CommonCfg, bnb_config: BitsAndBytesConfig, device: torch.device): def load_4bit_model(common: CommonCfg, bnb_config: BitsAndBytesConfig, device: torch.device):
"""
Loads the base model with 4-bit quantization and prepares it for k-bit training.
Key settings:
- device_map binds the model to the current GPU under DDP.
- use_cache is disabled for training.
- gradient checkpointing is disabled here to avoid reentrant issues under DDP.
"""
model = AutoModelForCausalLM.from_pretrained( model = AutoModelForCausalLM.from_pretrained(
common.base_model_path, common.base_model_path,
quantization_config=bnb_config, quantization_config=bnb_config,
@ -258,6 +312,9 @@ def load_4bit_model(common: CommonCfg, bnb_config: BitsAndBytesConfig, device: t
def save_manifest(out_dir: str, local_rank: int, manifest: dict): def save_manifest(out_dir: str, local_rank: int, manifest: dict):
"""
Writes a JSON manifest for reproducibility (rank0 only).
"""
if local_rank != 0: if local_rank != 0:
return return
os.makedirs(out_dir, exist_ok=True) os.makedirs(out_dir, exist_ok=True)
@ -269,6 +326,12 @@ def save_manifest(out_dir: str, local_rank: int, manifest: dict):
# SFT # SFT
# ========================== # ==========================
def run_sft(common: CommonCfg, sft: SFTCfg): def run_sft(common: CommonCfg, sft: SFTCfg):
"""
Runs SFT using TRL SFTTrainer with QLoRA adapter training.
Expected dataset:
- load_from_disk(sft.data_path) with a "text" field.
"""
local_rank, device = get_local_rank_and_device() local_rank, device = get_local_rank_and_device()
os.makedirs(sft.out_dir, exist_ok=True) os.makedirs(sft.out_dir, exist_ok=True)
@ -281,6 +344,7 @@ def run_sft(common: CommonCfg, sft: SFTCfg):
bnb_config = make_bnb_config(common) bnb_config = make_bnb_config(common)
model = load_4bit_model(common, bnb_config, device) model = load_4bit_model(common, bnb_config, device)
# Trainable LoRA adapter for SFT stage
lora = LoraConfig( lora = LoraConfig(
r=sft.lora_r, r=sft.lora_r,
lora_alpha=sft.lora_alpha, lora_alpha=sft.lora_alpha,
@ -357,7 +421,18 @@ def run_sft(common: CommonCfg, sft: SFTCfg):
# DPO # DPO
# ========================== # ==========================
def run_dpo(common: CommonCfg, dpo: DPOCfg): def run_dpo(common: CommonCfg, dpo: DPOCfg):
# Patches first (safe no-op if not needed) """
Runs DPO using TRL DPOTrainer with QLoRA adapters.
Adapter composition:
- Load base model (4-bit)
- Attach SFT adapter as frozen (is_trainable=False)
- Add a new LoRA adapter for DPO training (trainable)
Expected dataset:
- load_from_disk(dpo.dpo_data_path) with prompt/chosen/rejected columns
"""
# Apply patches early (no-op if already compatible)
patch_trl_dpo_trainer_if_needed() patch_trl_dpo_trainer_if_needed()
local_rank, device = get_local_rank_and_device() local_rank, device = get_local_rank_and_device()
@ -373,9 +448,10 @@ def run_dpo(common: CommonCfg, dpo: DPOCfg):
bnb_config = make_bnb_config(common) bnb_config = make_bnb_config(common)
model = load_4bit_model(common, bnb_config, device) model = load_4bit_model(common, bnb_config, device)
# Load SFT adapter (frozen) + add NEW DPO adapter (trainable) # Load SFT adapter as frozen weights
model = PeftModel.from_pretrained(model, dpo.sft_adapter_dir, is_trainable=False) model = PeftModel.from_pretrained(model, dpo.sft_adapter_dir, is_trainable=False)
# Trainable LoRA adapter for DPO stage
dpo_lora = LoraConfig( dpo_lora = LoraConfig(
r=dpo.lora_r, r=dpo.lora_r,
lora_alpha=dpo.lora_alpha, lora_alpha=dpo.lora_alpha,
@ -450,6 +526,12 @@ def run_dpo(common: CommonCfg, dpo: DPOCfg):
# CLI # CLI
# ========================== # ==========================
def build_parser(): def build_parser():
"""
Builds an argparse parser with subcommands:
- sft
- dpo
- both
"""
p = argparse.ArgumentParser( p = argparse.ArgumentParser(
description="Combined QLoRA training: SFT and DPO (TRL).", description="Combined QLoRA training: SFT and DPO (TRL).",
formatter_class=argparse.ArgumentDefaultsHelpFormatter, formatter_class=argparse.ArgumentDefaultsHelpFormatter,
@ -457,6 +539,7 @@ def build_parser():
sub = p.add_subparsers(dest="cmd", required=False) sub = p.add_subparsers(dest="cmd", required=False)
def add_common_args(sp): def add_common_args(sp):
# Shared base model/tokenizer arguments
sp.add_argument("--base_model_path", default=CommonCfg.base_model_path) sp.add_argument("--base_model_path", default=CommonCfg.base_model_path)
sp.add_argument("--target_modules", nargs="+", default=["q_proj", "v_proj"]) sp.add_argument("--target_modules", nargs="+", default=["q_proj", "v_proj"])
sp.add_argument("--use_fast_tokenizer", action="store_true", help="If set, use fast tokenizer") sp.add_argument("--use_fast_tokenizer", action="store_true", help="If set, use fast tokenizer")
@ -543,8 +626,7 @@ def build_parser():
def interactive_menu() -> str: def interactive_menu() -> str:
""" """
Simple interactive selector when the script is launched without CLI arguments. Interactive selector used when the script is launched without CLI arguments.
Uses defaults for all paths/knobs. If you want to override anything, run with CLI flags.
""" """
print("\n=== Training menu ===") print("\n=== Training menu ===")
print("1) SFT") print("1) SFT")
@ -566,15 +648,19 @@ def interactive_menu() -> str:
def main(): def main():
"""
Entry point:
- Parse CLI args (or use interactive menu when no args are provided)
- Create config objects
- Run the selected stage(s)
"""
parser = build_parser() parser = build_parser()
# If user runs the script with no arguments, we show interactive menu # If run without args: show menu and run with defaults
# and run with defaults.
import sys import sys
if len(sys.argv) == 1: if len(sys.argv) == 1:
# IMPORTANT for torchrun/DDP: # Under DDP, avoid prompting on every process:
# In distributed mode, we must NOT prompt on every process. # local_rank==0 writes the choice; other ranks read it.
# Rank0 asks, then writes the choice to a small file; other ranks read it.
world_size = int(os.environ.get("WORLD_SIZE", "1")) world_size = int(os.environ.get("WORLD_SIZE", "1"))
local_rank = int(os.environ.get("LOCAL_RANK", "0")) local_rank = int(os.environ.get("LOCAL_RANK", "0"))
@ -594,7 +680,7 @@ def main():
except Exception as e: except Exception as e:
print(f"[WARN] Could not write choice file {choice_file}: {e}", flush=True) print(f"[WARN] Could not write choice file {choice_file}: {e}", flush=True)
else: else:
# Wait for rank0 to write the choice # Poll for choice_file created by local_rank==0
deadline = time.time() + 600 # seconds deadline = time.time() + 600 # seconds
cmd = None cmd = None
while time.time() < deadline: while time.time() < deadline:
@ -608,7 +694,7 @@ def main():
pass pass
time.sleep(0.2) time.sleep(0.2)
if not cmd: if not cmd:
print("[ERROR] Timed out waiting for rank0 menu selection. " print("[ERROR] Timed out waiting for menu selection. "
"Run with explicit subcommand: sft/dpo/both.", flush=True) "Run with explicit subcommand: sft/dpo/both.", flush=True)
raise SystemExit(2) raise SystemExit(2)