799 lines
28 KiB
Python
799 lines
28 KiB
Python
#!/usr/bin/env python3
|
|
# -*- coding: utf-8 -*-
|
|
|
|
"""
|
|
Combined QLoRA training script for:
|
|
1) SFT (Supervised Fine-Tuning) with TRL SFTTrainer
|
|
2) DPO (Direct Preference Optimization) with TRL DPOTrainer
|
|
|
|
Usage examples (multi-GPU):
|
|
torchrun --nproc_per_node=2 training_sft_dpo_sk_combined.py sft
|
|
torchrun --nproc_per_node=2 training_sft_dpo_sk_combined.py dpo
|
|
torchrun --nproc_per_node=2 training_sft_dpo_sk_combined.py both
|
|
|
|
All paths and hyperparameters can be overridden via CLI flags (see --help).
|
|
"""
|
|
|
|
import os
|
|
import time
|
|
import json
|
|
import inspect
|
|
import argparse
|
|
from dataclasses import dataclass, asdict
|
|
from typing import List, Optional
|
|
|
|
import torch
|
|
from datasets import load_from_disk
|
|
|
|
from transformers import (
|
|
AutoTokenizer,
|
|
AutoModelForCausalLM,
|
|
BitsAndBytesConfig,
|
|
Trainer as HFTrainer,
|
|
TrainerCallback,
|
|
)
|
|
|
|
from peft import (
|
|
PeftModel,
|
|
LoraConfig,
|
|
get_peft_model,
|
|
prepare_model_for_kbit_training,
|
|
)
|
|
|
|
from trl import SFTTrainer, SFTConfig, DPOTrainer, DPOConfig
|
|
|
|
|
|
# ==========================
|
|
# Compatibility patches
|
|
# ==========================
|
|
def patch_trl_dpo_trainer_if_needed():
    """
    Monkey-patch DPOTrainer at runtime to bridge TRL <-> Transformers API drift.

    Some TRL releases implement trainer hooks with older signatures than the
    installed Transformers Trainer expects, which surfaces as TypeErrors mid
    training.  Two known mismatches are handled; each patch is a no-op when the
    installed TRL already looks compatible, and inspection failures are only
    warned about (training may still work).

    Patches applied:
      1) ``get_batch_samples``: old TRL uses ``(self, model, batch, ...)``
         while Transformers calls ``(self, epoch_iterator, num_batches)``;
         the call is rerouted to ``HFTrainer.get_batch_samples``.
      2) ``compute_loss``: newer Transformers passes ``num_items_in_batch``,
         which older TRL does not accept; a wrapper swallows that kwarg.
    """
    # ---- Patch 1: get_batch_samples ----
    try:
        param_names = list(inspect.signature(DPOTrainer.get_batch_samples).parameters)
        # Old TRL signature is (self, model, batch, ...), so the second
        # parameter name is "model".
        looks_old = len(param_names) >= 3 and param_names[1] == "model"
        if looks_old:
            print("[PATCH] DPOTrainer.get_batch_samples signature looks old; patching to HFTrainer.get_batch_samples", flush=True)

            def _get_batch_samples(self, epoch_iterator, num_batches):
                # Delegate straight to the HF base-class implementation.
                return HFTrainer.get_batch_samples(self, epoch_iterator, num_batches)

            DPOTrainer.get_batch_samples = _get_batch_samples
        else:
            print("[INFO] DPOTrainer.get_batch_samples signature looks compatible; no patch needed", flush=True)
    except Exception as e:
        print(f"[WARN] Could not inspect/patch get_batch_samples: {e}", flush=True)

    # ---- Patch 2: compute_loss ----
    try:
        accepts_kwarg = "num_items_in_batch" in inspect.signature(DPOTrainer.compute_loss).parameters
        if accepts_kwarg:
            print("[INFO] DPOTrainer.compute_loss signature looks compatible; no patch needed", flush=True)
        else:
            print("[PATCH] DPOTrainer.compute_loss lacks num_items_in_batch; patching wrapper", flush=True)
            _orig_compute_loss = DPOTrainer.compute_loss

            def _compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
                # Accept and drop num_items_in_batch; forward the rest unchanged.
                return _orig_compute_loss(self, model, inputs, return_outputs=return_outputs)

            DPOTrainer.compute_loss = _compute_loss
    except Exception as e:
        print(f"[WARN] Could not inspect/patch compute_loss: {e}", flush=True)
|
|
|
|
|
|
# ==========================
|
|
# Callbacks
|
|
# ==========================
|
|
class ProgressETA(TrainerCallback):
    """
    Trainer callback that reports progress with an ETA estimate.

    Output is emitted only from the local main process (LOCAL_RANK == 0) so
    DDP runs do not print duplicated lines.  Each log event prints:
      - global step / max steps
      - steps per second (averaged since training start)
      - remaining time estimate in minutes
      - loss and learning rate when present in the trainer logs
    """

    def __init__(self):
        # Wall-clock start; set on on_train_begin.
        self._start = None
        self._rank = int(os.environ.get("LOCAL_RANK", "0"))

    def on_train_begin(self, args, state, control, **kwargs):
        self._start = time.time()
        if self._rank == 0:
            print(f"[INFO] Training start. max_steps={state.max_steps}", flush=True)

    def on_log(self, args, state, control, logs=None, **kwargs):
        # Only print on local_rank==0 to reduce duplicated output under DDP;
        # also skip when there is nothing to report.
        if self._rank != 0 or not logs or not state.max_steps:
            return

        done = state.global_step
        total = state.max_steps
        # Guard against a missing start time and division by zero.
        elapsed = time.time() - (self._start or time.time())
        rate = done / max(elapsed, 1e-9)
        eta_seconds = (total - done) / max(rate, 1e-9)

        pieces = [f"[PROGRESS] step {done}/{total} | {rate:.3f} steps/s | ETA {eta_seconds/60:.1f} min"]
        if "loss" in logs:
            pieces.append(f"loss {logs['loss']:.4f}")
        if "learning_rate" in logs:
            pieces.append(f"lr {logs['learning_rate']:.3e}")
        print(" | ".join(pieces), flush=True)
|
|
|
|
|
|
# ==========================
|
|
# Config dataclasses
|
|
# ==========================
|
|
@dataclass
class CommonCfg:
    """
    Shared configuration for model/tokenizer/quantization used by both SFT and DPO.
    """
    # Base checkpoint (local path or hub id) for both model and tokenizer.
    base_model_path: str = "/home/hyrenko/Diploma/models/mistral-sk-7b"
    # LoRA target modules.  Annotated Optional because the declared default
    # really is None; __post_init__ replaces None with the Mistral-style
    # defaults.  (A mutable default is not allowed on a dataclass field.)
    target_modules: Optional[List[str]] = None

    # QLoRA quant config (bitsandbytes)
    load_in_4bit: bool = True
    bnb_4bit_compute_dtype: str = "float16"   # name of a torch dtype attribute
    bnb_4bit_use_double_quant: bool = True
    bnb_4bit_quant_type: str = "nf4"

    # Tokenizer loading
    use_fast_tokenizer: bool = False
    trust_remote_code: bool = True

    def __post_init__(self):
        # Default LoRA target modules for Mistral-style architectures
        if self.target_modules is None:
            self.target_modules = ["q_proj", "v_proj"]
|
|
|
|
|
|
@dataclass
class SFTCfg:
    """
    SFT stage configuration.

    Dataset requirement:
      - load_from_disk(data_path) must contain a "text" column.
    """
    # Dataset saved with datasets.save_to_disk; consumed by SFTTrainer.
    data_path: str = "/home/hyrenko/Diploma/datasets/pku_saferlhf_sk_sft"  # must have "text"
    # Where the trained adapter, tokenizer and run manifest are written.
    out_dir: str = "/home/hyrenko/Diploma/outputs_sft/mistral-sk-7b-pku-saferlhf-sk-sft-qlora"

    # Sequence packing length passed to SFTConfig.max_seq_length.
    max_seq_len: int = 1024
    num_epochs: int = 1
    # Per-device batch size; effective batch = per_device_bs * grad_accum * world_size.
    per_device_bs: int = 1
    grad_accum: int = 16

    lr: float = 2e-4
    warmup_ratio: float = 0.03

    logging_steps: int = 10
    save_steps: int = 200
    # Keep only the newest N checkpoints on disk.
    save_total_limit: int = 2

    # LoRA hyperparameters
    lora_r: int = 16
    lora_alpha: int = 32
    lora_dropout: float = 0.05
|
|
|
|
|
|
@dataclass
class DPOCfg:
    """
    DPO stage configuration.

    Dataset requirement:
      - load_from_disk(dpo_data_path) must contain: prompt, chosen, rejected

    Adapter behavior:
      - Loads the SFT adapter as frozen weights.
      - Adds a NEW LoRA adapter for DPO and trains only that adapter.
    """
    # SFT adapter input dir (from SFT run)
    sft_adapter_dir: str = "/home/hyrenko/Diploma/outputs_sft/mistral-sk-7b-pku-saferlhf-sk-sft-qlora"

    # Preference dataset saved with datasets.save_to_disk.
    dpo_data_path: str = "/home/hyrenko/Diploma/datasets/pku_saferlhf_sk_dpo"  # must have prompt/chosen/rejected
    # Where the trained DPO adapter, tokenizer and run manifest are written.
    out_dir: str = "/home/hyrenko/Diploma/outputs_dpo/mistral-sk-7b-pku-saferlhf-sk-dpo-qlora"

    # Token limits passed to DPOConfig (full sequence and prompt part).
    max_length: int = 1024
    max_prompt_length: int = 512

    num_epochs: int = 1
    # Per-device batch size; effective batch = per_device_bs * grad_accum * world_size.
    per_device_bs: int = 1
    grad_accum: int = 16

    # DPO typically uses a much smaller LR than SFT.
    lr: float = 1e-5
    warmup_ratio: float = 0.03

    logging_steps: int = 10
    save_steps: int = 200
    save_total_limit: int = 2

    # DPO KL-penalty strength (higher = stay closer to the reference policy).
    beta: float = 0.1

    # NEW LoRA adapter for DPO (separate from SFT)
    lora_r: int = 16
    lora_alpha: int = 32
    lora_dropout: float = 0.05
|
|
|
|
|
|
# ==========================
|
|
# Helpers
|
|
# ==========================
|
|
def get_local_rank_and_device():
    """
    Resolve this process's device from the LOCAL_RANK environment variable.

    Under DDP (torchrun) every process binds itself to its own GPU; without
    CUDA the CPU is used.  Returns a (local_rank, torch.device) tuple.
    """
    rank = int(os.environ.get("LOCAL_RANK", "0"))
    if not torch.cuda.is_available():
        device = torch.device("cpu")
    else:
        # Pin the current process to its rank's GPU before any allocation.
        torch.cuda.set_device(rank)
        device = torch.device("cuda", rank)
    print(f"[INFO] local_rank={rank} device={device}", flush=True)
    return rank, device
|
|
|
|
|
|
def make_bnb_config(common: CommonCfg) -> BitsAndBytesConfig:
    """
    Build the bitsandbytes 4-bit quantization config for QLoRA from CommonCfg.

    The compute dtype string (e.g. "float16") is resolved to the matching
    torch dtype attribute before being handed to BitsAndBytesConfig.
    """
    compute_dtype = getattr(torch, common.bnb_4bit_compute_dtype)
    return BitsAndBytesConfig(
        load_in_4bit=common.load_in_4bit,
        bnb_4bit_compute_dtype=compute_dtype,
        bnb_4bit_use_double_quant=common.bnb_4bit_use_double_quant,
        bnb_4bit_quant_type=common.bnb_4bit_quant_type,
    )
|
|
|
|
|
|
def load_tokenizer(common: CommonCfg):
    """
    Load the tokenizer from the base model path, guaranteeing a pad token.

    Mistral-style tokenizers may ship without a pad token; EOS is reused so
    that padded batched training works.
    """
    tok = AutoTokenizer.from_pretrained(
        common.base_model_path,
        use_fast=common.use_fast_tokenizer,
        trust_remote_code=common.trust_remote_code,
    )
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    return tok
|
|
|
|
|
|
def load_4bit_model(common: CommonCfg, bnb_config: BitsAndBytesConfig, device: torch.device):
    """
    Load the base model in 4-bit and prepare it for k-bit (QLoRA) training.

    Key settings:
      - device_map pins the whole model to this process's GPU under DDP
        (None on CPU, letting transformers place it normally).
      - use_cache is turned off: the KV cache is useless during training.
      - gradient checkpointing is explicitly disabled to avoid reentrant
        autograd issues under DDP.
    """
    # Bind the quantized weights to the GPU selected by get_local_rank_and_device.
    placement = {"": torch.cuda.current_device()} if device.type == "cuda" else None

    model = AutoModelForCausalLM.from_pretrained(
        common.base_model_path,
        quantization_config=bnb_config,
        device_map=placement,
        torch_dtype=torch.float16,
        trust_remote_code=common.trust_remote_code,
    )
    model.config.use_cache = False

    # Avoid DDP + checkpointing reentrant issues
    if hasattr(model, "gradient_checkpointing_disable"):
        model.gradient_checkpointing_disable()

    return prepare_model_for_kbit_training(model)
|
|
|
|
|
|
def save_manifest(out_dir: str, local_rank: int, manifest: dict):
    """
    Write a JSON run manifest into out_dir for reproducibility.

    Only the main local process (local_rank == 0) writes; all other ranks
    return immediately so DDP runs do not race on the same file.
    """
    if local_rank == 0:
        os.makedirs(out_dir, exist_ok=True)
        manifest_path = os.path.join(out_dir, "run_manifest.json")
        with open(manifest_path, "w", encoding="utf-8") as f:
            json.dump(manifest, f, ensure_ascii=False, indent=2)
|
|
|
|
|
|
# ==========================
|
|
# SFT
|
|
# ==========================
|
|
def run_sft(common: CommonCfg, sft: SFTCfg):
    """
    Run the SFT stage: QLoRA adapter training with TRL's SFTTrainer.

    Pipeline:
      1) Resolve device / local rank, load dataset (must have a "text" column).
      2) Load tokenizer + 4-bit base model, attach a fresh trainable LoRA adapter.
      3) Train with SFTTrainer; rank 0 saves the adapter and tokenizer to sft.out_dir.

    The base model weights are never modified; only LoRA adapter weights train.
    """
    local_rank, device = get_local_rank_and_device()
    os.makedirs(sft.out_dir, exist_ok=True)

    # Fail fast if the dataset does not match the SFT contract.
    ds = load_from_disk(sft.data_path)
    if "text" not in ds.column_names:
        raise RuntimeError(f"[SFT] Dataset must have 'text' column. Found: {ds.column_names}")
    print(f"[INFO] [SFT] Dataset rows: {len(ds)}", flush=True)

    tokenizer = load_tokenizer(common)
    bnb_config = make_bnb_config(common)
    model = load_4bit_model(common, bnb_config, device)

    # Trainable LoRA adapter for SFT stage
    lora = LoraConfig(
        r=sft.lora_r,
        lora_alpha=sft.lora_alpha,
        lora_dropout=sft.lora_dropout,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=common.target_modules,
    )
    model = get_peft_model(model, lora)

    if local_rank == 0:
        # Sanity log: should show a tiny trainable fraction (LoRA only).
        model.print_trainable_parameters()

    # NOTE(review): dataset_text_field / max_seq_length / packing as SFTConfig
    # kwargs match older TRL releases; newer TRL moved some of these — confirm
    # against the pinned TRL version before upgrading.
    cfg = SFTConfig(
        output_dir=sft.out_dir,
        dataset_text_field="text",
        max_seq_length=sft.max_seq_len,
        packing=True,

        num_train_epochs=sft.num_epochs,
        per_device_train_batch_size=sft.per_device_bs,
        gradient_accumulation_steps=sft.grad_accum,

        learning_rate=sft.lr,
        lr_scheduler_type="cosine",
        warmup_ratio=sft.warmup_ratio,
        weight_decay=0.0,

        logging_steps=sft.logging_steps,
        save_steps=sft.save_steps,
        save_total_limit=sft.save_total_limit,

        # fp16 (not bf16) to match the 4-bit compute dtype configured above.
        fp16=True,
        bf16=False,

        # Paged 8-bit AdamW keeps optimizer state memory low for QLoRA.
        optim="paged_adamw_8bit",
        report_to="none",
        remove_unused_columns=False,

        ddp_find_unused_parameters=False,
        # Checkpointing off: avoids reentrant autograd issues under DDP.
        gradient_checkpointing=False,
    )

    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=ds,
        args=cfg,
    )
    trainer.add_callback(ProgressETA())

    # Record the full configuration next to the outputs (rank 0 only).
    save_manifest(
        sft.out_dir,
        local_rank,
        {
            "stage": "sft",
            "common": asdict(common),
            "sft": asdict(sft),
            "quant": "4bit nf4 qlora",
            "gradient_checkpointing": False,
        },
    )

    trainer.train()

    if local_rank == 0:
        trainer.save_model(sft.out_dir)  # saves PEFT adapter
        tokenizer.save_pretrained(sft.out_dir)
        print(f"[OK] [SFT] Finished. Adapter saved to: {sft.out_dir}", flush=True)
        print("[OK] [SFT] Base model was NOT overwritten; only LoRA adapter weights were trained/saved.", flush=True)
|
|
|
|
|
|
# ==========================
|
|
# DPO
|
|
# ==========================
|
|
def run_dpo(common: CommonCfg, dpo: DPOCfg):
    """
    Run the DPO stage: preference optimization with TRL's DPOTrainer on QLoRA.

    Adapter composition:
      - Load base model (4-bit).
      - Attach the SFT adapter as frozen weights (is_trainable=False).
      - Stack a NEW trainable LoRA adapter on top for DPO.

    Expected dataset:
      - load_from_disk(dpo.dpo_data_path) with prompt/chosen/rejected columns.

    The base model and the SFT adapter are never modified; only the new DPO
    LoRA adapter trains.
    """
    # Apply patches early (no-op if already compatible)
    patch_trl_dpo_trainer_if_needed()

    local_rank, device = get_local_rank_and_device()
    os.makedirs(dpo.out_dir, exist_ok=True)

    # Fail fast if any required preference column is missing.
    ds = load_from_disk(dpo.dpo_data_path)
    for c in ("prompt", "chosen", "rejected"):
        if c not in ds.column_names:
            raise RuntimeError(f"[DPO] Missing '{c}' in DPO dataset. Found: {ds.column_names}")
    print(f"[INFO] [DPO] Dataset rows: {len(ds)}", flush=True)

    tokenizer = load_tokenizer(common)
    bnb_config = make_bnb_config(common)
    model = load_4bit_model(common, bnb_config, device)

    # Load SFT adapter as frozen weights
    model = PeftModel.from_pretrained(model, dpo.sft_adapter_dir, is_trainable=False)

    # Trainable LoRA adapter for DPO stage
    dpo_lora = LoraConfig(
        r=dpo.lora_r,
        lora_alpha=dpo.lora_alpha,
        lora_dropout=dpo.lora_dropout,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=common.target_modules,
    )
    model = get_peft_model(model, dpo_lora)

    if local_rank == 0:
        # Sanity log: only the new DPO adapter should be trainable.
        model.print_trainable_parameters()

    cfg = DPOConfig(
        output_dir=dpo.out_dir,
        max_length=dpo.max_length,
        max_prompt_length=dpo.max_prompt_length,

        num_train_epochs=dpo.num_epochs,
        per_device_train_batch_size=dpo.per_device_bs,
        gradient_accumulation_steps=dpo.grad_accum,

        learning_rate=dpo.lr,
        lr_scheduler_type="cosine",
        warmup_ratio=dpo.warmup_ratio,

        logging_steps=dpo.logging_steps,
        save_steps=dpo.save_steps,
        save_total_limit=dpo.save_total_limit,

        # fp16 (not bf16) to match the 4-bit compute dtype configured above.
        fp16=True,
        bf16=False,
        optim="paged_adamw_8bit",
        report_to="none",

        # DPO KL-penalty strength.
        beta=dpo.beta,

        ddp_find_unused_parameters=False,
        gradient_checkpointing=False,
    )

    # ref_model=None lets TRL derive reference logits itself; with a PEFT
    # model this is presumably done by disabling the adapter — TODO confirm
    # for the pinned TRL version.
    trainer = DPOTrainer(
        model=model,
        ref_model=None,
        args=cfg,
        train_dataset=ds,
        tokenizer=tokenizer,
    )
    trainer.add_callback(ProgressETA())

    # Record the full configuration next to the outputs (rank 0 only).
    save_manifest(
        dpo.out_dir,
        local_rank,
        {
            "stage": "dpo",
            "common": asdict(common),
            "dpo": asdict(dpo),
            "notes": "Includes TRL/Transformers compatibility patches for get_batch_samples + compute_loss",
        },
    )

    trainer.train()

    if local_rank == 0:
        trainer.save_model(dpo.out_dir)
        tokenizer.save_pretrained(dpo.out_dir)
        print(f"[OK] [DPO] Finished. Adapter saved to: {dpo.out_dir}", flush=True)
        print("[OK] [DPO] Base model was NOT overwritten; only DPO LoRA adapter weights were trained/saved.", flush=True)
|
|
|
|
|
|
# ==========================
|
|
# CLI
|
|
# ==========================
|
|
def build_parser():
    """
    Build the argparse CLI with subcommands:
      - sft:  run the SFT stage only
      - dpo:  run the DPO stage only
      - both: run SFT, then DPO

    All per-stage flags default to the values declared on the CommonCfg /
    SFTCfg / DPOCfg dataclasses so the CLI cannot drift from the configs.
    The duplicated flag definitions for the "both" subcommand are produced
    by the same helpers used for "sft" and "dpo".
    """
    p = argparse.ArgumentParser(
        description="Combined QLoRA training: SFT and DPO (TRL).",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    # required=False so a bare invocation can fall through to the
    # interactive menu in main().
    sub = p.add_subparsers(dest="cmd", required=False)

    def add_common_args(sp):
        # Shared base model/tokenizer arguments
        sp.add_argument("--base_model_path", default=CommonCfg.base_model_path)
        sp.add_argument("--target_modules", nargs="+", default=["q_proj", "v_proj"])
        sp.add_argument("--use_fast_tokenizer", action="store_true", help="If set, use fast tokenizer")
        sp.add_argument("--no_trust_remote_code", action="store_true", help="Disable trust_remote_code")

    def add_sft_args(sp):
        # SFT hyperparameters (shared by the "sft" and "both" subcommands)
        sp.add_argument("--sft_data_path", default=SFTCfg.data_path)
        sp.add_argument("--sft_out_dir", default=SFTCfg.out_dir)
        sp.add_argument("--sft_max_seq_len", type=int, default=SFTCfg.max_seq_len)
        sp.add_argument("--sft_num_epochs", type=int, default=SFTCfg.num_epochs)
        sp.add_argument("--sft_per_device_bs", type=int, default=SFTCfg.per_device_bs)
        sp.add_argument("--sft_grad_accum", type=int, default=SFTCfg.grad_accum)
        sp.add_argument("--sft_lr", type=float, default=SFTCfg.lr)
        sp.add_argument("--sft_warmup_ratio", type=float, default=SFTCfg.warmup_ratio)
        sp.add_argument("--sft_logging_steps", type=int, default=SFTCfg.logging_steps)
        sp.add_argument("--sft_save_steps", type=int, default=SFTCfg.save_steps)
        sp.add_argument("--sft_save_total_limit", type=int, default=SFTCfg.save_total_limit)
        sp.add_argument("--sft_lora_r", type=int, default=SFTCfg.lora_r)
        sp.add_argument("--sft_lora_alpha", type=int, default=SFTCfg.lora_alpha)
        sp.add_argument("--sft_lora_dropout", type=float, default=SFTCfg.lora_dropout)

    def add_dpo_args(sp):
        # DPO hyperparameters (shared by the "dpo" and "both" subcommands).
        # --sft_adapter_dir is intentionally NOT here: its default and help
        # text differ between the two subcommands.
        sp.add_argument("--dpo_data_path", default=DPOCfg.dpo_data_path)
        sp.add_argument("--dpo_out_dir", default=DPOCfg.out_dir)
        sp.add_argument("--dpo_max_length", type=int, default=DPOCfg.max_length)
        sp.add_argument("--dpo_max_prompt_length", type=int, default=DPOCfg.max_prompt_length)
        sp.add_argument("--dpo_num_epochs", type=int, default=DPOCfg.num_epochs)
        sp.add_argument("--dpo_per_device_bs", type=int, default=DPOCfg.per_device_bs)
        sp.add_argument("--dpo_grad_accum", type=int, default=DPOCfg.grad_accum)
        sp.add_argument("--dpo_lr", type=float, default=DPOCfg.lr)
        sp.add_argument("--dpo_warmup_ratio", type=float, default=DPOCfg.warmup_ratio)
        sp.add_argument("--dpo_logging_steps", type=int, default=DPOCfg.logging_steps)
        sp.add_argument("--dpo_save_steps", type=int, default=DPOCfg.save_steps)
        sp.add_argument("--dpo_save_total_limit", type=int, default=DPOCfg.save_total_limit)
        sp.add_argument("--dpo_beta", type=float, default=DPOCfg.beta)
        sp.add_argument("--dpo_lora_r", type=int, default=DPOCfg.lora_r)
        sp.add_argument("--dpo_lora_alpha", type=int, default=DPOCfg.lora_alpha)
        sp.add_argument("--dpo_lora_dropout", type=float, default=DPOCfg.lora_dropout)

    # ---- SFT ----
    sft_p = sub.add_parser("sft", help="Run SFT stage only")
    add_common_args(sft_p)
    add_sft_args(sft_p)

    # ---- DPO ----
    dpo_p = sub.add_parser("dpo", help="Run DPO stage only")
    add_common_args(dpo_p)
    dpo_p.add_argument("--sft_adapter_dir", default=DPOCfg.sft_adapter_dir)
    add_dpo_args(dpo_p)

    # ---- BOTH ----
    both_p = sub.add_parser("both", help="Run SFT then DPO (uses SFT out dir as DPO sft_adapter_dir unless overridden)")
    add_common_args(both_p)
    add_sft_args(both_p)
    add_dpo_args(both_p)
    both_p.add_argument("--sft_adapter_dir", default="", help="If set, use this as DPO SFT adapter dir; otherwise uses --sft_out_dir")

    return p
|
|
|
|
|
|
def interactive_menu() -> str:
    """
    Interactive stage selector used when no CLI arguments are provided.

    Loops until the user picks a valid option; returns "sft", "dpo" or
    "both", or exits the process on q/quit/exit.
    """
    print("\n=== Training menu ===")
    print("1) SFT")
    print("2) DPO")
    print("3) SFT + DPO (both)")
    print("q) Quit\n")

    # Map every accepted spelling to its canonical subcommand.
    aliases = {
        "1": "sft", "sft": "sft",
        "2": "dpo", "dpo": "dpo",
        "3": "both", "both": "both", "all": "both",
    }
    while True:
        choice = input("Select an option [1/2/3/q]: ").strip().lower()
        if choice in aliases:
            return aliases[choice]
        if choice in ("q", "quit", "exit"):
            raise SystemExit(0)
        print("Invalid choice. Please enter 1, 2, 3, or q.")
|
|
|
|
|
|
def main():
    """
    Entry point:
      - Parse CLI args (or use the interactive menu when no args are provided)
      - Build config objects from the parsed values
      - Run the selected stage(s)

    When launched under DDP without arguments, only local_rank 0 prompts the
    user; the chosen subcommand is shared with the other ranks through a
    small JSON file in the current working directory.
    """
    parser = build_parser()

    # If run without args: show menu and run with defaults
    import sys
    if len(sys.argv) == 1:
        # Under DDP, avoid prompting on every process:
        # local_rank==0 writes the choice; other ranks read it.
        world_size = int(os.environ.get("WORLD_SIZE", "1"))
        local_rank = int(os.environ.get("LOCAL_RANK", "0"))

        if world_size == 1:
            cmd = interactive_menu()
        else:
            # File name keyed on MASTER_PORT/RUN_ID so concurrent runs from
            # the same directory do not read each other's choice.
            choice_file = os.path.join(
                os.getcwd(),
                f".train_choice_{os.environ.get('MASTER_PORT', '0')}_{os.environ.get('RUN_ID', '0')}.json",
            )

            if local_rank == 0:
                cmd = interactive_menu()
                try:
                    with open(choice_file, "w", encoding="utf-8") as f:
                        json.dump({"cmd": cmd}, f)
                except Exception as e:
                    # Best-effort: other ranks will time out below if this fails.
                    print(f"[WARN] Could not write choice file {choice_file}: {e}", flush=True)
            else:
                # Poll for choice_file created by local_rank==0
                deadline = time.time() + 600  # seconds
                cmd = None
                while time.time() < deadline:
                    if os.path.exists(choice_file):
                        try:
                            with open(choice_file, "r", encoding="utf-8") as f:
                                cmd = json.load(f).get("cmd")
                            if cmd:
                                break
                        except Exception:
                            # File may be mid-write; retry on the next poll.
                            pass
                    time.sleep(0.2)
                if not cmd:
                    print("[ERROR] Timed out waiting for menu selection. "
                          "Run with explicit subcommand: sft/dpo/both.", flush=True)
                    raise SystemExit(2)

        args = parser.parse_args([cmd])  # defaults
    else:
        args = parser.parse_args()
        if not getattr(args, "cmd", None):
            parser.print_help()
            raise SystemExit(2)

    common = CommonCfg(
        base_model_path=args.base_model_path,
        target_modules=args.target_modules,
        use_fast_tokenizer=args.use_fast_tokenizer,
        trust_remote_code=not args.no_trust_remote_code,
    )

    if args.cmd == "sft":
        sft = SFTCfg(
            data_path=args.sft_data_path,
            out_dir=args.sft_out_dir,
            max_seq_len=args.sft_max_seq_len,
            num_epochs=args.sft_num_epochs,
            per_device_bs=args.sft_per_device_bs,
            grad_accum=args.sft_grad_accum,
            lr=args.sft_lr,
            warmup_ratio=args.sft_warmup_ratio,
            logging_steps=args.sft_logging_steps,
            save_steps=args.sft_save_steps,
            save_total_limit=args.sft_save_total_limit,
            lora_r=args.sft_lora_r,
            lora_alpha=args.sft_lora_alpha,
            lora_dropout=args.sft_lora_dropout,
        )
        run_sft(common, sft)

    elif args.cmd == "dpo":
        dpo = DPOCfg(
            sft_adapter_dir=args.sft_adapter_dir,
            dpo_data_path=args.dpo_data_path,
            out_dir=args.dpo_out_dir,
            max_length=args.dpo_max_length,
            max_prompt_length=args.dpo_max_prompt_length,
            num_epochs=args.dpo_num_epochs,
            per_device_bs=args.dpo_per_device_bs,
            grad_accum=args.dpo_grad_accum,
            lr=args.dpo_lr,
            warmup_ratio=args.dpo_warmup_ratio,
            logging_steps=args.dpo_logging_steps,
            save_steps=args.dpo_save_steps,
            save_total_limit=args.dpo_save_total_limit,
            beta=args.dpo_beta,
            lora_r=args.dpo_lora_r,
            lora_alpha=args.dpo_lora_alpha,
            lora_dropout=args.dpo_lora_dropout,
        )
        run_dpo(common, dpo)

    elif args.cmd == "both":
        sft = SFTCfg(
            data_path=args.sft_data_path,
            out_dir=args.sft_out_dir,
            max_seq_len=args.sft_max_seq_len,
            num_epochs=args.sft_num_epochs,
            per_device_bs=args.sft_per_device_bs,
            grad_accum=args.sft_grad_accum,
            lr=args.sft_lr,
            warmup_ratio=args.sft_warmup_ratio,
            logging_steps=args.sft_logging_steps,
            save_steps=args.sft_save_steps,
            save_total_limit=args.sft_save_total_limit,
            lora_r=args.sft_lora_r,
            lora_alpha=args.sft_lora_alpha,
            lora_dropout=args.sft_lora_dropout,
        )
        run_sft(common, sft)

        # DPO consumes the freshly trained SFT adapter unless an explicit
        # --sft_adapter_dir was given.
        sft_adapter_dir = args.sft_adapter_dir.strip() or sft.out_dir
        dpo = DPOCfg(
            sft_adapter_dir=sft_adapter_dir,
            dpo_data_path=args.dpo_data_path,
            out_dir=args.dpo_out_dir,
            max_length=args.dpo_max_length,
            max_prompt_length=args.dpo_max_prompt_length,
            num_epochs=args.dpo_num_epochs,
            per_device_bs=args.dpo_per_device_bs,
            grad_accum=args.dpo_grad_accum,
            lr=args.dpo_lr,
            warmup_ratio=args.dpo_warmup_ratio,
            logging_steps=args.dpo_logging_steps,
            save_steps=args.dpo_save_steps,
            save_total_limit=args.dpo_save_total_limit,
            beta=args.dpo_beta,
            lora_r=args.dpo_lora_r,
            lora_alpha=args.dpo_lora_alpha,
            lora_dropout=args.dpo_lora_dropout,
        )
        run_dpo(common, dpo)
|
|
|
|
if __name__ == "__main__":
    # Script entry point; works for both plain `python` and `torchrun` launches.
    main()
|