Update Training/training_dpo_sft_mistral_sk.py
This commit is contained in:
parent
7bb03b551f
commit
5f05bc2460
@ -1,5 +1,6 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
Combined QLoRA training script for:
|
||||
1) SFT (Supervised Fine-Tuning) with TRL SFTTrainer
|
||||
@ -10,7 +11,7 @@ Usage examples (multi-GPU):
|
||||
torchrun --nproc_per_node=2 training_sft_dpo_sk_combined.py dpo
|
||||
torchrun --nproc_per_node=2 training_sft_dpo_sk_combined.py both
|
||||
|
||||
You can override any path/knob via CLI flags (see --help).
|
||||
All paths and hyperparameters can be overridden via CLI flags (see --help).
|
||||
"""
|
||||
|
||||
import os
|
||||
@ -47,24 +48,29 @@ from trl import SFTTrainer, SFTConfig, DPOTrainer, DPOConfig
|
||||
# ==========================
|
||||
def patch_trl_dpo_trainer_if_needed():
|
||||
"""
|
||||
Fix TRL<->Transformers API mismatches.
|
||||
Applies runtime patches to handle TRL <-> Transformers API incompatibilities.
|
||||
|
||||
Why this exists:
|
||||
- Some TRL versions implement DPOTrainer methods with older signatures.
|
||||
- Transformers Trainer may call methods with a newer signature / extra kwargs.
|
||||
- These mismatches can lead to runtime errors during training.
|
||||
|
||||
Patch set:
|
||||
1) get_batch_samples signature mismatch:
|
||||
Transformers Trainer calls get_batch_samples(epoch_iterator, num_batches)
|
||||
Older TRL DPOTrainer defines get_batch_samples(model, batch)
|
||||
=> 'generator' object has no attribute 'generate'
|
||||
|
||||
Patch: delegate to HFTrainer.get_batch_samples.
|
||||
- Transformers expects: get_batch_samples(epoch_iterator, num_batches)
|
||||
- Some older TRL expects: get_batch_samples(model, batch, ...)
|
||||
- Patch: route the call through HFTrainer.get_batch_samples.
|
||||
|
||||
2) compute_loss signature mismatch:
|
||||
Newer Transformers passes num_items_in_batch kwarg.
|
||||
Patch wrapper to accept it if missing.
|
||||
- Newer Transformers may pass num_items_in_batch
|
||||
- Some TRL versions do not accept that kwarg
|
||||
- Patch: wrap compute_loss to accept num_items_in_batch and delegate.
|
||||
"""
|
||||
# ---- Patch get_batch_samples ----
|
||||
try:
|
||||
sig = inspect.signature(DPOTrainer.get_batch_samples)
|
||||
params = list(sig.parameters.keys()) # includes "self"
|
||||
# Old TRL: (self, model, batch, ...) -> params[1] == "model"
|
||||
# Old TRL often has (self, model, batch, ...) so params[1] == "model"
|
||||
if len(params) >= 3 and params[1] == "model":
|
||||
print("[PATCH] DPOTrainer.get_batch_samples signature looks old; patching to HFTrainer.get_batch_samples", flush=True)
|
||||
|
||||
@ -98,6 +104,15 @@ def patch_trl_dpo_trainer_if_needed():
|
||||
# Callbacks
|
||||
# ==========================
|
||||
class ProgressETA(TrainerCallback):
|
||||
"""
|
||||
Trainer callback that prints progress with ETA on the main local process.
|
||||
|
||||
Printed metrics:
|
||||
- global step / max steps
|
||||
- steps per second
|
||||
- ETA in minutes
|
||||
- optionally loss and learning rate from the trainer logs
|
||||
"""
|
||||
def __init__(self):
|
||||
self.t0 = None
|
||||
self.local_rank = int(os.environ.get("LOCAL_RANK", "0"))
|
||||
@ -108,6 +123,7 @@ class ProgressETA(TrainerCallback):
|
||||
print(f"[INFO] Training start. max_steps={state.max_steps}", flush=True)
|
||||
|
||||
def on_log(self, args, state, control, logs=None, **kwargs):
|
||||
# Only print on local_rank==0 to reduce duplicated output under DDP
|
||||
if self.local_rank != 0:
|
||||
return
|
||||
if not logs or not state.max_steps:
|
||||
@ -133,26 +149,36 @@ class ProgressETA(TrainerCallback):
|
||||
# ==========================
|
||||
@dataclass
|
||||
class CommonCfg:
|
||||
"""
|
||||
Shared configuration for model/tokenizer/quantization used by both SFT and DPO.
|
||||
"""
|
||||
base_model_path: str = "/home/hyrenko/Diploma/models/mistral-sk-7b"
|
||||
target_modules: List[str] = None # set in __post_init__
|
||||
|
||||
# QLoRA quant config
|
||||
# QLoRA quant config (bitsandbytes)
|
||||
load_in_4bit: bool = True
|
||||
bnb_4bit_compute_dtype: str = "float16"
|
||||
bnb_4bit_use_double_quant: bool = True
|
||||
bnb_4bit_quant_type: str = "nf4"
|
||||
|
||||
# Tokenizer
|
||||
# Tokenizer loading
|
||||
use_fast_tokenizer: bool = False
|
||||
trust_remote_code: bool = True
|
||||
|
||||
def __post_init__(self):
|
||||
# Default LoRA target modules for Mistral-style architectures
|
||||
if self.target_modules is None:
|
||||
self.target_modules = ["q_proj", "v_proj"]
|
||||
|
||||
|
||||
@dataclass
|
||||
class SFTCfg:
|
||||
"""
|
||||
SFT stage configuration.
|
||||
|
||||
Dataset requirement:
|
||||
- load_from_disk(data_path) must contain a "text" column.
|
||||
"""
|
||||
data_path: str = "/home/hyrenko/Diploma/datasets/pku_saferlhf_sk_sft" # must have "text"
|
||||
out_dir: str = "/home/hyrenko/Diploma/outputs_sft/mistral-sk-7b-pku-saferlhf-sk-sft-qlora"
|
||||
|
||||
@ -168,7 +194,7 @@ class SFTCfg:
|
||||
save_steps: int = 200
|
||||
save_total_limit: int = 2
|
||||
|
||||
# LoRA
|
||||
# LoRA hyperparameters
|
||||
lora_r: int = 16
|
||||
lora_alpha: int = 32
|
||||
lora_dropout: float = 0.05
|
||||
@ -176,6 +202,16 @@ class SFTCfg:
|
||||
|
||||
@dataclass
|
||||
class DPOCfg:
|
||||
"""
|
||||
DPO stage configuration.
|
||||
|
||||
Dataset requirement:
|
||||
- load_from_disk(dpo_data_path) must contain: prompt, chosen, rejected
|
||||
|
||||
Adapter behavior:
|
||||
- Loads the SFT adapter as frozen weights.
|
||||
- Adds a NEW LoRA adapter for DPO and trains only that adapter.
|
||||
"""
|
||||
# SFT adapter input dir (from SFT run)
|
||||
sft_adapter_dir: str = "/home/hyrenko/Diploma/outputs_sft/mistral-sk-7b-pku-saferlhf-sk-sft-qlora"
|
||||
|
||||
@ -208,6 +244,10 @@ class DPOCfg:
|
||||
# Helpers
|
||||
# ==========================
|
||||
def get_local_rank_and_device():
|
||||
"""
|
||||
Determines device placement based on LOCAL_RANK.
|
||||
In DDP runs, each process binds itself to a separate GPU.
|
||||
"""
|
||||
local_rank = int(os.environ.get("LOCAL_RANK", "0"))
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.set_device(local_rank)
|
||||
@ -219,6 +259,9 @@ def get_local_rank_and_device():
|
||||
|
||||
|
||||
def make_bnb_config(common: CommonCfg) -> BitsAndBytesConfig:
|
||||
"""
|
||||
Builds BitsAndBytesConfig for 4-bit QLoRA.
|
||||
"""
|
||||
dtype = getattr(torch, common.bnb_4bit_compute_dtype)
|
||||
return BitsAndBytesConfig(
|
||||
load_in_4bit=common.load_in_4bit,
|
||||
@ -229,6 +272,9 @@ def make_bnb_config(common: CommonCfg) -> BitsAndBytesConfig:
|
||||
|
||||
|
||||
def load_tokenizer(common: CommonCfg):
|
||||
"""
|
||||
Loads tokenizer from base_model_path and ensures pad_token is defined.
|
||||
"""
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
common.base_model_path,
|
||||
use_fast=common.use_fast_tokenizer,
|
||||
@ -240,6 +286,14 @@ def load_tokenizer(common: CommonCfg):
|
||||
|
||||
|
||||
def load_4bit_model(common: CommonCfg, bnb_config: BitsAndBytesConfig, device: torch.device):
|
||||
"""
|
||||
Loads the base model with 4-bit quantization and prepares it for k-bit training.
|
||||
|
||||
Key settings:
|
||||
- device_map binds the model to the current GPU under DDP.
|
||||
- use_cache is disabled for training.
|
||||
- gradient checkpointing is disabled here to avoid reentrant issues under DDP.
|
||||
"""
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
common.base_model_path,
|
||||
quantization_config=bnb_config,
|
||||
@ -258,6 +312,9 @@ def load_4bit_model(common: CommonCfg, bnb_config: BitsAndBytesConfig, device: t
|
||||
|
||||
|
||||
def save_manifest(out_dir: str, local_rank: int, manifest: dict):
|
||||
"""
|
||||
Writes a JSON manifest for reproducibility (rank0 only).
|
||||
"""
|
||||
if local_rank != 0:
|
||||
return
|
||||
os.makedirs(out_dir, exist_ok=True)
|
||||
@ -269,6 +326,12 @@ def save_manifest(out_dir: str, local_rank: int, manifest: dict):
|
||||
# SFT
|
||||
# ==========================
|
||||
def run_sft(common: CommonCfg, sft: SFTCfg):
|
||||
"""
|
||||
Runs SFT using TRL SFTTrainer with QLoRA adapter training.
|
||||
|
||||
Expected dataset:
|
||||
- load_from_disk(sft.data_path) with a "text" field.
|
||||
"""
|
||||
local_rank, device = get_local_rank_and_device()
|
||||
os.makedirs(sft.out_dir, exist_ok=True)
|
||||
|
||||
@ -281,6 +344,7 @@ def run_sft(common: CommonCfg, sft: SFTCfg):
|
||||
bnb_config = make_bnb_config(common)
|
||||
model = load_4bit_model(common, bnb_config, device)
|
||||
|
||||
# Trainable LoRA adapter for SFT stage
|
||||
lora = LoraConfig(
|
||||
r=sft.lora_r,
|
||||
lora_alpha=sft.lora_alpha,
|
||||
@ -357,7 +421,18 @@ def run_sft(common: CommonCfg, sft: SFTCfg):
|
||||
# DPO
|
||||
# ==========================
|
||||
def run_dpo(common: CommonCfg, dpo: DPOCfg):
|
||||
# Patches first (safe no-op if not needed)
|
||||
"""
|
||||
Runs DPO using TRL DPOTrainer with QLoRA adapters.
|
||||
|
||||
Adapter composition:
|
||||
- Load base model (4-bit)
|
||||
- Attach SFT adapter as frozen (is_trainable=False)
|
||||
- Add a new LoRA adapter for DPO training (trainable)
|
||||
|
||||
Expected dataset:
|
||||
- load_from_disk(dpo.dpo_data_path) with prompt/chosen/rejected columns
|
||||
"""
|
||||
# Apply patches early (no-op if already compatible)
|
||||
patch_trl_dpo_trainer_if_needed()
|
||||
|
||||
local_rank, device = get_local_rank_and_device()
|
||||
@ -373,9 +448,10 @@ def run_dpo(common: CommonCfg, dpo: DPOCfg):
|
||||
bnb_config = make_bnb_config(common)
|
||||
model = load_4bit_model(common, bnb_config, device)
|
||||
|
||||
# Load SFT adapter (frozen) + add NEW DPO adapter (trainable)
|
||||
# Load SFT adapter as frozen weights
|
||||
model = PeftModel.from_pretrained(model, dpo.sft_adapter_dir, is_trainable=False)
|
||||
|
||||
# Trainable LoRA adapter for DPO stage
|
||||
dpo_lora = LoraConfig(
|
||||
r=dpo.lora_r,
|
||||
lora_alpha=dpo.lora_alpha,
|
||||
@ -450,6 +526,12 @@ def run_dpo(common: CommonCfg, dpo: DPOCfg):
|
||||
# CLI
|
||||
# ==========================
|
||||
def build_parser():
|
||||
"""
|
||||
Builds an argparse parser with subcommands:
|
||||
- sft
|
||||
- dpo
|
||||
- both
|
||||
"""
|
||||
p = argparse.ArgumentParser(
|
||||
description="Combined QLoRA training: SFT and DPO (TRL).",
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
||||
@ -457,6 +539,7 @@ def build_parser():
|
||||
sub = p.add_subparsers(dest="cmd", required=False)
|
||||
|
||||
def add_common_args(sp):
|
||||
# Shared base model/tokenizer arguments
|
||||
sp.add_argument("--base_model_path", default=CommonCfg.base_model_path)
|
||||
sp.add_argument("--target_modules", nargs="+", default=["q_proj", "v_proj"])
|
||||
sp.add_argument("--use_fast_tokenizer", action="store_true", help="If set, use fast tokenizer")
|
||||
@ -543,8 +626,7 @@ def build_parser():
|
||||
|
||||
def interactive_menu() -> str:
|
||||
"""
|
||||
Simple interactive selector when the script is launched without CLI arguments.
|
||||
Uses defaults for all paths/knobs. If you want to override anything, run with CLI flags.
|
||||
Interactive selector used when the script is launched without CLI arguments.
|
||||
"""
|
||||
print("\n=== Training menu ===")
|
||||
print("1) SFT")
|
||||
@ -566,15 +648,19 @@ def interactive_menu() -> str:
|
||||
|
||||
|
||||
def main():
|
||||
"""
|
||||
Entry point:
|
||||
- Parse CLI args (or use interactive menu when no args are provided)
|
||||
- Create config objects
|
||||
- Run the selected stage(s)
|
||||
"""
|
||||
parser = build_parser()
|
||||
|
||||
# If user runs the script with no arguments, we show interactive menu
|
||||
# and run with defaults.
|
||||
# If run without args: show menu and run with defaults
|
||||
import sys
|
||||
if len(sys.argv) == 1:
|
||||
# IMPORTANT for torchrun/DDP:
|
||||
# In distributed mode, we must NOT prompt on every process.
|
||||
# Rank0 asks, then writes the choice to a small file; other ranks read it.
|
||||
# Under DDP, avoid prompting on every process:
|
||||
# local_rank==0 writes the choice; other ranks read it.
|
||||
world_size = int(os.environ.get("WORLD_SIZE", "1"))
|
||||
local_rank = int(os.environ.get("LOCAL_RANK", "0"))
|
||||
|
||||
@ -594,7 +680,7 @@ def main():
|
||||
except Exception as e:
|
||||
print(f"[WARN] Could not write choice file {choice_file}: {e}", flush=True)
|
||||
else:
|
||||
# Wait for rank0 to write the choice
|
||||
# Poll for choice_file created by local_rank==0
|
||||
deadline = time.time() + 600 # seconds
|
||||
cmd = None
|
||||
while time.time() < deadline:
|
||||
@ -608,7 +694,7 @@ def main():
|
||||
pass
|
||||
time.sleep(0.2)
|
||||
if not cmd:
|
||||
print("[ERROR] Timed out waiting for rank0 menu selection. "
|
||||
print("[ERROR] Timed out waiting for menu selection. "
|
||||
"Run with explicit subcommand: sft/dpo/both.", flush=True)
|
||||
raise SystemExit(2)
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user