Update Training/training_dpo_sft_mistral_sk.py
This commit is contained in:
parent
7bb03b551f
commit
5f05bc2460
@ -1,5 +1,6 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Combined QLoRA training script for:
|
Combined QLoRA training script for:
|
||||||
1) SFT (Supervised Fine-Tuning) with TRL SFTTrainer
|
1) SFT (Supervised Fine-Tuning) with TRL SFTTrainer
|
||||||
@ -10,7 +11,7 @@ Usage examples (multi-GPU):
|
|||||||
torchrun --nproc_per_node=2 training_sft_dpo_sk_combined.py dpo
|
torchrun --nproc_per_node=2 training_sft_dpo_sk_combined.py dpo
|
||||||
torchrun --nproc_per_node=2 training_sft_dpo_sk_combined.py both
|
torchrun --nproc_per_node=2 training_sft_dpo_sk_combined.py both
|
||||||
|
|
||||||
You can override any path/knob via CLI flags (see --help).
|
All paths and hyperparameters can be overridden via CLI flags (see --help).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
@ -47,24 +48,29 @@ from trl import SFTTrainer, SFTConfig, DPOTrainer, DPOConfig
|
|||||||
# ==========================
|
# ==========================
|
||||||
def patch_trl_dpo_trainer_if_needed():
|
def patch_trl_dpo_trainer_if_needed():
|
||||||
"""
|
"""
|
||||||
Fix TRL<->Transformers API mismatches.
|
Applies runtime patches to handle TRL <-> Transformers API incompatibilities.
|
||||||
|
|
||||||
|
Why this exists:
|
||||||
|
- Some TRL versions implement DPOTrainer methods with older signatures.
|
||||||
|
- Transformers Trainer may call methods with a newer signature / extra kwargs.
|
||||||
|
- These mismatches can lead to runtime errors during training.
|
||||||
|
|
||||||
|
Patch set:
|
||||||
1) get_batch_samples signature mismatch:
|
1) get_batch_samples signature mismatch:
|
||||||
Transformers Trainer calls get_batch_samples(epoch_iterator, num_batches)
|
- Transformers expects: get_batch_samples(epoch_iterator, num_batches)
|
||||||
Older TRL DPOTrainer defines get_batch_samples(model, batch)
|
- Some older TRL versions define: get_batch_samples(model, batch, ...)
|
||||||
=> 'generator' object has no attribute 'generate'
|
- Patch: route the call through HFTrainer.get_batch_samples (otherwise training fails with "'generator' object has no attribute 'generate'").
|
||||||
|
|
||||||
Patch: delegate to HFTrainer.get_batch_samples.
|
|
||||||
|
|
||||||
2) compute_loss signature mismatch:
|
2) compute_loss signature mismatch:
|
||||||
Newer Transformers passes num_items_in_batch kwarg.
|
- Newer Transformers may pass num_items_in_batch
|
||||||
Patch wrapper to accept it if missing.
|
- Some TRL versions do not accept that kwarg
|
||||||
|
- Patch: wrap compute_loss to accept num_items_in_batch and delegate.
|
||||||
"""
|
"""
|
||||||
# ---- Patch get_batch_samples ----
|
# ---- Patch get_batch_samples ----
|
||||||
try:
|
try:
|
||||||
sig = inspect.signature(DPOTrainer.get_batch_samples)
|
sig = inspect.signature(DPOTrainer.get_batch_samples)
|
||||||
params = list(sig.parameters.keys()) # includes "self"
|
params = list(sig.parameters.keys()) # includes "self"
|
||||||
# Old TRL: (self, model, batch, ...) -> params[1] == "model"
|
# Old TRL often has (self, model, batch, ...) so params[1] == "model"
|
||||||
if len(params) >= 3 and params[1] == "model":
|
if len(params) >= 3 and params[1] == "model":
|
||||||
print("[PATCH] DPOTrainer.get_batch_samples signature looks old; patching to HFTrainer.get_batch_samples", flush=True)
|
print("[PATCH] DPOTrainer.get_batch_samples signature looks old; patching to HFTrainer.get_batch_samples", flush=True)
|
||||||
|
|
||||||
@ -98,6 +104,15 @@ def patch_trl_dpo_trainer_if_needed():
|
|||||||
# Callbacks
|
# Callbacks
|
||||||
# ==========================
|
# ==========================
|
||||||
class ProgressETA(TrainerCallback):
|
class ProgressETA(TrainerCallback):
|
||||||
|
"""
|
||||||
|
Trainer callback that prints progress with ETA on the main local process.
|
||||||
|
|
||||||
|
Printed metrics:
|
||||||
|
- global step / max steps
|
||||||
|
- steps per second
|
||||||
|
- ETA in minutes
|
||||||
|
- optionally loss and learning rate from the trainer logs
|
||||||
|
"""
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.t0 = None
|
self.t0 = None
|
||||||
self.local_rank = int(os.environ.get("LOCAL_RANK", "0"))
|
self.local_rank = int(os.environ.get("LOCAL_RANK", "0"))
|
||||||
@ -108,6 +123,7 @@ class ProgressETA(TrainerCallback):
|
|||||||
print(f"[INFO] Training start. max_steps={state.max_steps}", flush=True)
|
print(f"[INFO] Training start. max_steps={state.max_steps}", flush=True)
|
||||||
|
|
||||||
def on_log(self, args, state, control, logs=None, **kwargs):
|
def on_log(self, args, state, control, logs=None, **kwargs):
|
||||||
|
# Only print on local_rank==0 to reduce duplicated output under DDP
|
||||||
if self.local_rank != 0:
|
if self.local_rank != 0:
|
||||||
return
|
return
|
||||||
if not logs or not state.max_steps:
|
if not logs or not state.max_steps:
|
||||||
@ -133,26 +149,36 @@ class ProgressETA(TrainerCallback):
|
|||||||
# ==========================
|
# ==========================
|
||||||
@dataclass
|
@dataclass
|
||||||
class CommonCfg:
|
class CommonCfg:
|
||||||
|
"""
|
||||||
|
Shared configuration for model/tokenizer/quantization used by both SFT and DPO.
|
||||||
|
"""
|
||||||
base_model_path: str = "/home/hyrenko/Diploma/models/mistral-sk-7b"
|
base_model_path: str = "/home/hyrenko/Diploma/models/mistral-sk-7b"
|
||||||
target_modules: List[str] = None # set in __post_init__
|
target_modules: List[str] = None # set in __post_init__
|
||||||
|
|
||||||
# QLoRA quant config
|
# QLoRA quant config (bitsandbytes)
|
||||||
load_in_4bit: bool = True
|
load_in_4bit: bool = True
|
||||||
bnb_4bit_compute_dtype: str = "float16"
|
bnb_4bit_compute_dtype: str = "float16"
|
||||||
bnb_4bit_use_double_quant: bool = True
|
bnb_4bit_use_double_quant: bool = True
|
||||||
bnb_4bit_quant_type: str = "nf4"
|
bnb_4bit_quant_type: str = "nf4"
|
||||||
|
|
||||||
# Tokenizer
|
# Tokenizer loading
|
||||||
use_fast_tokenizer: bool = False
|
use_fast_tokenizer: bool = False
|
||||||
trust_remote_code: bool = True
|
trust_remote_code: bool = True
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
|
# Default LoRA target modules for Mistral-style architectures
|
||||||
if self.target_modules is None:
|
if self.target_modules is None:
|
||||||
self.target_modules = ["q_proj", "v_proj"]
|
self.target_modules = ["q_proj", "v_proj"]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class SFTCfg:
|
class SFTCfg:
|
||||||
|
"""
|
||||||
|
SFT stage configuration.
|
||||||
|
|
||||||
|
Dataset requirement:
|
||||||
|
- load_from_disk(data_path) must contain a "text" column.
|
||||||
|
"""
|
||||||
data_path: str = "/home/hyrenko/Diploma/datasets/pku_saferlhf_sk_sft" # must have "text"
|
data_path: str = "/home/hyrenko/Diploma/datasets/pku_saferlhf_sk_sft" # must have "text"
|
||||||
out_dir: str = "/home/hyrenko/Diploma/outputs_sft/mistral-sk-7b-pku-saferlhf-sk-sft-qlora"
|
out_dir: str = "/home/hyrenko/Diploma/outputs_sft/mistral-sk-7b-pku-saferlhf-sk-sft-qlora"
|
||||||
|
|
||||||
@ -168,7 +194,7 @@ class SFTCfg:
|
|||||||
save_steps: int = 200
|
save_steps: int = 200
|
||||||
save_total_limit: int = 2
|
save_total_limit: int = 2
|
||||||
|
|
||||||
# LoRA
|
# LoRA hyperparameters
|
||||||
lora_r: int = 16
|
lora_r: int = 16
|
||||||
lora_alpha: int = 32
|
lora_alpha: int = 32
|
||||||
lora_dropout: float = 0.05
|
lora_dropout: float = 0.05
|
||||||
@ -176,6 +202,16 @@ class SFTCfg:
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class DPOCfg:
|
class DPOCfg:
|
||||||
|
"""
|
||||||
|
DPO stage configuration.
|
||||||
|
|
||||||
|
Dataset requirement:
|
||||||
|
- load_from_disk(dpo_data_path) must contain: prompt, chosen, rejected
|
||||||
|
|
||||||
|
Adapter behavior:
|
||||||
|
- Loads the SFT adapter as frozen weights.
|
||||||
|
- Adds a NEW LoRA adapter for DPO and trains only that adapter.
|
||||||
|
"""
|
||||||
# SFT adapter input dir (from SFT run)
|
# SFT adapter input dir (from SFT run)
|
||||||
sft_adapter_dir: str = "/home/hyrenko/Diploma/outputs_sft/mistral-sk-7b-pku-saferlhf-sk-sft-qlora"
|
sft_adapter_dir: str = "/home/hyrenko/Diploma/outputs_sft/mistral-sk-7b-pku-saferlhf-sk-sft-qlora"
|
||||||
|
|
||||||
@ -208,6 +244,10 @@ class DPOCfg:
|
|||||||
# Helpers
|
# Helpers
|
||||||
# ==========================
|
# ==========================
|
||||||
def get_local_rank_and_device():
|
def get_local_rank_and_device():
|
||||||
|
"""
|
||||||
|
Determines device placement based on LOCAL_RANK.
|
||||||
|
In DDP runs, each process binds itself to a separate GPU.
|
||||||
|
"""
|
||||||
local_rank = int(os.environ.get("LOCAL_RANK", "0"))
|
local_rank = int(os.environ.get("LOCAL_RANK", "0"))
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
torch.cuda.set_device(local_rank)
|
torch.cuda.set_device(local_rank)
|
||||||
@ -219,6 +259,9 @@ def get_local_rank_and_device():
|
|||||||
|
|
||||||
|
|
||||||
def make_bnb_config(common: CommonCfg) -> BitsAndBytesConfig:
|
def make_bnb_config(common: CommonCfg) -> BitsAndBytesConfig:
|
||||||
|
"""
|
||||||
|
Builds BitsAndBytesConfig for 4-bit QLoRA.
|
||||||
|
"""
|
||||||
dtype = getattr(torch, common.bnb_4bit_compute_dtype)
|
dtype = getattr(torch, common.bnb_4bit_compute_dtype)
|
||||||
return BitsAndBytesConfig(
|
return BitsAndBytesConfig(
|
||||||
load_in_4bit=common.load_in_4bit,
|
load_in_4bit=common.load_in_4bit,
|
||||||
@ -229,6 +272,9 @@ def make_bnb_config(common: CommonCfg) -> BitsAndBytesConfig:
|
|||||||
|
|
||||||
|
|
||||||
def load_tokenizer(common: CommonCfg):
|
def load_tokenizer(common: CommonCfg):
|
||||||
|
"""
|
||||||
|
Loads tokenizer from base_model_path and ensures pad_token is defined.
|
||||||
|
"""
|
||||||
tokenizer = AutoTokenizer.from_pretrained(
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
common.base_model_path,
|
common.base_model_path,
|
||||||
use_fast=common.use_fast_tokenizer,
|
use_fast=common.use_fast_tokenizer,
|
||||||
@ -240,6 +286,14 @@ def load_tokenizer(common: CommonCfg):
|
|||||||
|
|
||||||
|
|
||||||
def load_4bit_model(common: CommonCfg, bnb_config: BitsAndBytesConfig, device: torch.device):
|
def load_4bit_model(common: CommonCfg, bnb_config: BitsAndBytesConfig, device: torch.device):
|
||||||
|
"""
|
||||||
|
Loads the base model with 4-bit quantization and prepares it for k-bit training.
|
||||||
|
|
||||||
|
Key settings:
|
||||||
|
- device_map binds the model to the current GPU under DDP.
|
||||||
|
- use_cache is disabled for training.
|
||||||
|
- gradient checkpointing is disabled here to avoid reentrant issues under DDP.
|
||||||
|
"""
|
||||||
model = AutoModelForCausalLM.from_pretrained(
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
common.base_model_path,
|
common.base_model_path,
|
||||||
quantization_config=bnb_config,
|
quantization_config=bnb_config,
|
||||||
@ -258,6 +312,9 @@ def load_4bit_model(common: CommonCfg, bnb_config: BitsAndBytesConfig, device: t
|
|||||||
|
|
||||||
|
|
||||||
def save_manifest(out_dir: str, local_rank: int, manifest: dict):
|
def save_manifest(out_dir: str, local_rank: int, manifest: dict):
|
||||||
|
"""
|
||||||
|
Writes a JSON manifest for reproducibility (only the process with local_rank == 0 writes it).
|
||||||
|
"""
|
||||||
if local_rank != 0:
|
if local_rank != 0:
|
||||||
return
|
return
|
||||||
os.makedirs(out_dir, exist_ok=True)
|
os.makedirs(out_dir, exist_ok=True)
|
||||||
@ -269,6 +326,12 @@ def save_manifest(out_dir: str, local_rank: int, manifest: dict):
|
|||||||
# SFT
|
# SFT
|
||||||
# ==========================
|
# ==========================
|
||||||
def run_sft(common: CommonCfg, sft: SFTCfg):
|
def run_sft(common: CommonCfg, sft: SFTCfg):
|
||||||
|
"""
|
||||||
|
Runs SFT using TRL SFTTrainer with QLoRA adapter training.
|
||||||
|
|
||||||
|
Expected dataset:
|
||||||
|
- load_from_disk(sft.data_path) with a "text" field.
|
||||||
|
"""
|
||||||
local_rank, device = get_local_rank_and_device()
|
local_rank, device = get_local_rank_and_device()
|
||||||
os.makedirs(sft.out_dir, exist_ok=True)
|
os.makedirs(sft.out_dir, exist_ok=True)
|
||||||
|
|
||||||
@ -281,6 +344,7 @@ def run_sft(common: CommonCfg, sft: SFTCfg):
|
|||||||
bnb_config = make_bnb_config(common)
|
bnb_config = make_bnb_config(common)
|
||||||
model = load_4bit_model(common, bnb_config, device)
|
model = load_4bit_model(common, bnb_config, device)
|
||||||
|
|
||||||
|
# Trainable LoRA adapter for SFT stage
|
||||||
lora = LoraConfig(
|
lora = LoraConfig(
|
||||||
r=sft.lora_r,
|
r=sft.lora_r,
|
||||||
lora_alpha=sft.lora_alpha,
|
lora_alpha=sft.lora_alpha,
|
||||||
@ -357,7 +421,18 @@ def run_sft(common: CommonCfg, sft: SFTCfg):
|
|||||||
# DPO
|
# DPO
|
||||||
# ==========================
|
# ==========================
|
||||||
def run_dpo(common: CommonCfg, dpo: DPOCfg):
|
def run_dpo(common: CommonCfg, dpo: DPOCfg):
|
||||||
# Patches first (safe no-op if not needed)
|
"""
|
||||||
|
Runs DPO using TRL DPOTrainer with QLoRA adapters.
|
||||||
|
|
||||||
|
Adapter composition:
|
||||||
|
- Load base model (4-bit)
|
||||||
|
- Attach SFT adapter as frozen (is_trainable=False)
|
||||||
|
- Add a new LoRA adapter for DPO training (trainable)
|
||||||
|
|
||||||
|
Expected dataset:
|
||||||
|
- load_from_disk(dpo.dpo_data_path) with prompt/chosen/rejected columns
|
||||||
|
"""
|
||||||
|
# Apply patches early (no-op if already compatible)
|
||||||
patch_trl_dpo_trainer_if_needed()
|
patch_trl_dpo_trainer_if_needed()
|
||||||
|
|
||||||
local_rank, device = get_local_rank_and_device()
|
local_rank, device = get_local_rank_and_device()
|
||||||
@ -373,9 +448,10 @@ def run_dpo(common: CommonCfg, dpo: DPOCfg):
|
|||||||
bnb_config = make_bnb_config(common)
|
bnb_config = make_bnb_config(common)
|
||||||
model = load_4bit_model(common, bnb_config, device)
|
model = load_4bit_model(common, bnb_config, device)
|
||||||
|
|
||||||
# Load SFT adapter (frozen) + add NEW DPO adapter (trainable)
|
# Load SFT adapter as frozen weights
|
||||||
model = PeftModel.from_pretrained(model, dpo.sft_adapter_dir, is_trainable=False)
|
model = PeftModel.from_pretrained(model, dpo.sft_adapter_dir, is_trainable=False)
|
||||||
|
|
||||||
|
# Trainable LoRA adapter for DPO stage
|
||||||
dpo_lora = LoraConfig(
|
dpo_lora = LoraConfig(
|
||||||
r=dpo.lora_r,
|
r=dpo.lora_r,
|
||||||
lora_alpha=dpo.lora_alpha,
|
lora_alpha=dpo.lora_alpha,
|
||||||
@ -450,6 +526,12 @@ def run_dpo(common: CommonCfg, dpo: DPOCfg):
|
|||||||
# CLI
|
# CLI
|
||||||
# ==========================
|
# ==========================
|
||||||
def build_parser():
|
def build_parser():
|
||||||
|
"""
|
||||||
|
Builds an argparse parser with subcommands:
|
||||||
|
- sft
|
||||||
|
- dpo
|
||||||
|
- both
|
||||||
|
"""
|
||||||
p = argparse.ArgumentParser(
|
p = argparse.ArgumentParser(
|
||||||
description="Combined QLoRA training: SFT and DPO (TRL).",
|
description="Combined QLoRA training: SFT and DPO (TRL).",
|
||||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
||||||
@ -457,6 +539,7 @@ def build_parser():
|
|||||||
sub = p.add_subparsers(dest="cmd", required=False)
|
sub = p.add_subparsers(dest="cmd", required=False)
|
||||||
|
|
||||||
def add_common_args(sp):
|
def add_common_args(sp):
|
||||||
|
# Shared base model/tokenizer arguments
|
||||||
sp.add_argument("--base_model_path", default=CommonCfg.base_model_path)
|
sp.add_argument("--base_model_path", default=CommonCfg.base_model_path)
|
||||||
sp.add_argument("--target_modules", nargs="+", default=["q_proj", "v_proj"])
|
sp.add_argument("--target_modules", nargs="+", default=["q_proj", "v_proj"])
|
||||||
sp.add_argument("--use_fast_tokenizer", action="store_true", help="If set, use fast tokenizer")
|
sp.add_argument("--use_fast_tokenizer", action="store_true", help="If set, use fast tokenizer")
|
||||||
@ -543,8 +626,7 @@ def build_parser():
|
|||||||
|
|
||||||
def interactive_menu() -> str:
|
def interactive_menu() -> str:
|
||||||
"""
|
"""
|
||||||
Simple interactive selector when the script is launched without CLI arguments.
|
Interactive selector used when the script is launched without CLI arguments.
|
||||||
Uses defaults for all paths/knobs. If you want to override anything, run with CLI flags.
|
|
||||||
"""
|
"""
|
||||||
print("\n=== Training menu ===")
|
print("\n=== Training menu ===")
|
||||||
print("1) SFT")
|
print("1) SFT")
|
||||||
@ -566,15 +648,19 @@ def interactive_menu() -> str:
|
|||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
"""
|
||||||
|
Entry point:
|
||||||
|
- Parse CLI args (or use interactive menu when no args are provided)
|
||||||
|
- Create config objects
|
||||||
|
- Run the selected stage(s)
|
||||||
|
"""
|
||||||
parser = build_parser()
|
parser = build_parser()
|
||||||
|
|
||||||
# If user runs the script with no arguments, we show interactive menu
|
# If run without args: show menu and run with defaults
|
||||||
# and run with defaults.
|
|
||||||
import sys
|
import sys
|
||||||
if len(sys.argv) == 1:
|
if len(sys.argv) == 1:
|
||||||
# IMPORTANT for torchrun/DDP:
|
# Under DDP, avoid prompting on every process:
|
||||||
# In distributed mode, we must NOT prompt on every process.
|
# local_rank==0 writes the choice; other ranks read it.
|
||||||
# Rank0 asks, then writes the choice to a small file; other ranks read it.
|
|
||||||
world_size = int(os.environ.get("WORLD_SIZE", "1"))
|
world_size = int(os.environ.get("WORLD_SIZE", "1"))
|
||||||
local_rank = int(os.environ.get("LOCAL_RANK", "0"))
|
local_rank = int(os.environ.get("LOCAL_RANK", "0"))
|
||||||
|
|
||||||
@ -594,7 +680,7 @@ def main():
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"[WARN] Could not write choice file {choice_file}: {e}", flush=True)
|
print(f"[WARN] Could not write choice file {choice_file}: {e}", flush=True)
|
||||||
else:
|
else:
|
||||||
# Wait for rank0 to write the choice
|
# Poll for choice_file created by local_rank==0
|
||||||
deadline = time.time() + 600 # seconds
|
deadline = time.time() + 600 # seconds
|
||||||
cmd = None
|
cmd = None
|
||||||
while time.time() < deadline:
|
while time.time() < deadline:
|
||||||
@ -608,7 +694,7 @@ def main():
|
|||||||
pass
|
pass
|
||||||
time.sleep(0.2)
|
time.sleep(0.2)
|
||||||
if not cmd:
|
if not cmd:
|
||||||
print("[ERROR] Timed out waiting for rank0 menu selection. "
|
print("[ERROR] Timed out waiting for menu selection. "
|
||||||
"Run with explicit subcommand: sft/dpo/both.", flush=True)
|
"Run with explicit subcommand: sft/dpo/both.", flush=True)
|
||||||
raise SystemExit(2)
|
raise SystemExit(2)
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user