#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Combined QLoRA training script for:

1) SFT (Supervised Fine-Tuning) with TRL SFTTrainer
2) DPO (Direct Preference Optimization) with TRL DPOTrainer

Usage examples (multi-GPU):
    torchrun --nproc_per_node=2 training_sft_dpo_sk_combined.py sft
    torchrun --nproc_per_node=2 training_sft_dpo_sk_combined.py dpo
    torchrun --nproc_per_node=2 training_sft_dpo_sk_combined.py both

You can override any path/knob via CLI flags (see --help).
"""
|
import os
|
|
import time
|
|
import json
|
|
import inspect
|
|
import argparse
|
|
from dataclasses import dataclass, asdict
|
|
from typing import List, Optional
|
|
|
|
import torch
|
|
from datasets import load_from_disk
|
|
|
|
from transformers import (
|
|
AutoTokenizer,
|
|
AutoModelForCausalLM,
|
|
BitsAndBytesConfig,
|
|
Trainer as HFTrainer,
|
|
TrainerCallback,
|
|
)
|
|
|
|
from peft import (
|
|
PeftModel,
|
|
LoraConfig,
|
|
get_peft_model,
|
|
prepare_model_for_kbit_training,
|
|
)
|
|
|
|
from trl import SFTTrainer, SFTConfig, DPOTrainer, DPOConfig
|
|
|
|
|
|
# ==========================
|
|
# Compatibility patches
|
|
# ==========================
|
|
def patch_trl_dpo_trainer_if_needed():
    """Monkeypatch DPOTrainer to fix TRL<->Transformers API mismatches.

    1) get_batch_samples signature mismatch:
       Transformers Trainer calls get_batch_samples(epoch_iterator, num_batches);
       older TRL DPOTrainer defines get_batch_samples(model, batch)
       => "'generator' object has no attribute 'generate'" at runtime.
       Patch: delegate to HFTrainer.get_batch_samples.

    2) compute_loss signature mismatch:
       newer Transformers passes a num_items_in_batch kwarg.
       Patch: wrap with a signature that accepts (and drops) it.

    Both patches are applied in place on the DPOTrainer class and are
    no-ops when the installed TRL version is already compatible.
    """
    # ---- Patch get_batch_samples ----
    try:
        sig = inspect.signature(DPOTrainer.get_batch_samples)
        params = list(sig.parameters.keys())  # includes "self"
        # Old TRL: (self, model, batch, ...) -> params[1] == "model"
        if len(params) >= 3 and params[1] == "model":
            print("[PATCH] DPOTrainer.get_batch_samples signature looks old; patching to HFTrainer.get_batch_samples", flush=True)

            # Delegate to the HF base-class implementation, which has the
            # (epoch_iterator, num_batches) signature Transformers expects.
            def _get_batch_samples(self, epoch_iterator, num_batches):
                return HFTrainer.get_batch_samples(self, epoch_iterator, num_batches)

            DPOTrainer.get_batch_samples = _get_batch_samples
        else:
            print("[INFO] DPOTrainer.get_batch_samples signature looks compatible; no patch needed", flush=True)
    except Exception as e:
        # Best-effort: inspection can fail across library versions; training
        # may still work, so warn instead of raising.
        print(f"[WARN] Could not inspect/patch get_batch_samples: {e}", flush=True)

    # ---- Patch compute_loss ----
    try:
        sig = inspect.signature(DPOTrainer.compute_loss)
        if "num_items_in_batch" not in sig.parameters:
            print("[PATCH] DPOTrainer.compute_loss lacks num_items_in_batch; patching wrapper", flush=True)
            _orig_compute_loss = DPOTrainer.compute_loss

            # Accept the extra kwarg that newer Transformers passes, then
            # forward to the original implementation without it.
            def _compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
                return _orig_compute_loss(self, model, inputs, return_outputs=return_outputs)

            DPOTrainer.compute_loss = _compute_loss
        else:
            print("[INFO] DPOTrainer.compute_loss signature looks compatible; no patch needed", flush=True)
    except Exception as e:
        print(f"[WARN] Could not inspect/patch compute_loss: {e}", flush=True)
|
# ==========================
|
|
# Callbacks
|
|
# ==========================
|
|
class ProgressETA(TrainerCallback):
    """Trainer callback that logs step rate and an ETA estimate (rank 0 only)."""

    def __init__(self):
        self.t0 = None  # wall-clock start, set at on_train_begin
        self.local_rank = int(os.environ.get("LOCAL_RANK", "0"))

    def on_train_begin(self, args, state, control, **kwargs):
        """Record the start time and announce max_steps on rank 0."""
        self.t0 = time.time()
        if self.local_rank == 0:
            print(f"[INFO] Training start. max_steps={state.max_steps}", flush=True)

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Emit a one-line progress report on each logging event (rank 0 only)."""
        if self.local_rank != 0 or not logs or not state.max_steps:
            return

        step, max_steps = state.global_step, state.max_steps
        elapsed = time.time() - (self.t0 or time.time())
        # Guard the denominators: elapsed may be ~0 right after start,
        # and sps may be 0 at step 0.
        sps = step / max(elapsed, 1e-9)
        eta = (max_steps - step) / max(sps, 1e-9)

        pieces = [
            f"[PROGRESS] step {step}/{max_steps}",
            f"{sps:.3f} steps/s",
            f"ETA {eta/60:.1f} min",
        ]
        if "loss" in logs:
            pieces.append(f"loss {logs['loss']:.4f}")
        if "learning_rate" in logs:
            pieces.append(f"lr {logs['learning_rate']:.3e}")
        print(" | ".join(pieces), flush=True)
|
# ==========================
|
|
# Config dataclasses
|
|
# ==========================
|
|
@dataclass
class CommonCfg:
    """Settings shared by the SFT and DPO stages.

    Covers the base model path, LoRA target modules, the bitsandbytes
    4-bit (QLoRA) quantization knobs, and tokenizer loading options.
    """

    # Local path to the base model checkpoint.
    base_model_path: str = "/home/hyrenko/Diploma/models/mistral-sk-7b"
    # LoRA target module names. Typed Optional because None is the sentinel
    # for "use the default list" (resolved in __post_init__) — the original
    # annotation `List[str] = None` was type-incorrect, and a literal list
    # default would be shared across instances.
    target_modules: Optional[List[str]] = None

    # QLoRA quant config
    load_in_4bit: bool = True
    bnb_4bit_compute_dtype: str = "float16"  # name of a torch dtype attribute
    bnb_4bit_use_double_quant: bool = True
    bnb_4bit_quant_type: str = "nf4"

    # Tokenizer
    use_fast_tokenizer: bool = False
    trust_remote_code: bool = True

    def __post_init__(self):
        # Resolve the None sentinel to a fresh per-instance default list.
        if self.target_modules is None:
            self.target_modules = ["q_proj", "v_proj"]
|
@dataclass
class SFTCfg:
    """Paths and hyperparameters for the SFT (supervised fine-tuning) stage."""

    # Dataset saved with datasets.save_to_disk; must have a "text" column.
    data_path: str = "/home/hyrenko/Diploma/datasets/pku_saferlhf_sk_sft"  # must have "text"
    out_dir: str = "/home/hyrenko/Diploma/outputs_sft/mistral-sk-7b-pku-saferlhf-sk-sft-qlora"

    # Max tokenized sequence length (packing is enabled in run_sft).
    max_seq_len: int = 1024
    num_epochs: int = 1
    per_device_bs: int = 1
    # Effective batch = per_device_bs * grad_accum * world_size.
    grad_accum: int = 16

    lr: float = 2e-4
    warmup_ratio: float = 0.03

    logging_steps: int = 10
    save_steps: int = 200
    save_total_limit: int = 2

    # LoRA
    lora_r: int = 16
    lora_alpha: int = 32
    lora_dropout: float = 0.05
|
@dataclass
class DPOCfg:
    """Paths and hyperparameters for the DPO (preference optimization) stage."""

    # SFT adapter input dir (from SFT run)
    sft_adapter_dir: str = "/home/hyrenko/Diploma/outputs_sft/mistral-sk-7b-pku-saferlhf-sk-sft-qlora"

    dpo_data_path: str = "/home/hyrenko/Diploma/datasets/pku_saferlhf_sk_dpo"  # must have prompt/chosen/rejected
    out_dir: str = "/home/hyrenko/Diploma/outputs_dpo/mistral-sk-7b-pku-saferlhf-sk-dpo-qlora"

    # Length caps: full prompt+completion sequence, and the prompt alone.
    max_length: int = 1024
    max_prompt_length: int = 512

    num_epochs: int = 1
    per_device_bs: int = 1
    grad_accum: int = 16

    # DPO uses a much smaller LR than SFT (1e-5 vs 2e-4).
    lr: float = 1e-5
    warmup_ratio: float = 0.03

    logging_steps: int = 10
    save_steps: int = 200
    save_total_limit: int = 2

    # DPO beta: strength of the implicit KL penalty toward the reference.
    beta: float = 0.1

    # NEW LoRA adapter for DPO (separate from SFT)
    lora_r: int = 16
    lora_alpha: int = 32
    lora_dropout: float = 0.05
|
# ==========================
|
|
# Helpers
|
|
# ==========================
|
|
def get_local_rank_and_device():
    """Resolve this process's local rank and the torch device to use.

    Reads LOCAL_RANK from the environment (torchrun sets it; defaults to 0
    for single-process runs). On CUDA machines the matching GPU is selected
    as the current device; otherwise the CPU is used.

    Returns:
        (local_rank, device) tuple.
    """
    rank = int(os.environ.get("LOCAL_RANK", "0"))
    if not torch.cuda.is_available():
        dev = torch.device("cpu")
    else:
        torch.cuda.set_device(rank)
        dev = torch.device("cuda", rank)
    print(f"[INFO] local_rank={rank} device={dev}", flush=True)
    return rank, dev
|
def make_bnb_config(common: CommonCfg) -> BitsAndBytesConfig:
    """Build the bitsandbytes 4-bit quantization config from CommonCfg."""
    # Resolve the dtype *name* (e.g. "float16") to the torch dtype object.
    compute_dtype = getattr(torch, common.bnb_4bit_compute_dtype)
    cfg = BitsAndBytesConfig(
        load_in_4bit=common.load_in_4bit,
        bnb_4bit_quant_type=common.bnb_4bit_quant_type,
        bnb_4bit_use_double_quant=common.bnb_4bit_use_double_quant,
        bnb_4bit_compute_dtype=compute_dtype,
    )
    return cfg
|
def load_tokenizer(common: CommonCfg):
    """Load the base model's tokenizer, guaranteeing a pad token exists."""
    tok = AutoTokenizer.from_pretrained(
        common.base_model_path,
        use_fast=common.use_fast_tokenizer,
        trust_remote_code=common.trust_remote_code,
    )
    # Some tokenizers (e.g. Mistral's) ship without a pad token; reuse EOS
    # so padding works during batched training.
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    return tok
|
def load_4bit_model(common: CommonCfg, bnb_config: BitsAndBytesConfig, device: torch.device):
    """Load the base causal LM in 4-bit and prepare it for k-bit (QLoRA) training."""
    # Pin the whole model onto this process's GPU; CPU runs use no device_map.
    placement = {"": torch.cuda.current_device()} if device.type == "cuda" else None
    model = AutoModelForCausalLM.from_pretrained(
        common.base_model_path,
        quantization_config=bnb_config,
        device_map=placement,
        torch_dtype=torch.float16,
        trust_remote_code=common.trust_remote_code,
    )
    # The KV cache is only useful for generation; disable it for training.
    model.config.use_cache = False

    # Avoid DDP + checkpointing reentrant issues
    if hasattr(model, "gradient_checkpointing_disable"):
        model.gradient_checkpointing_disable()

    return prepare_model_for_kbit_training(model)
|
def save_manifest(out_dir: str, local_rank: int, manifest: dict):
    """Write run_manifest.json describing this run into out_dir (rank 0 only)."""
    # Only one process per node should touch the filesystem.
    if local_rank != 0:
        return
    os.makedirs(out_dir, exist_ok=True)
    target = os.path.join(out_dir, "run_manifest.json")
    with open(target, "w", encoding="utf-8") as fh:
        json.dump(manifest, fh, ensure_ascii=False, indent=2)
|
# ==========================
|
|
# SFT
|
|
# ==========================
|
|
def run_sft(common: CommonCfg, sft: SFTCfg):
    """Run the SFT stage: 4-bit base model + fresh LoRA adapter via SFTTrainer.

    Loads the "text"-column dataset from disk, attaches a new LoRA adapter to
    the quantized base model, trains with TRL's SFTTrainer, and (on rank 0)
    saves the adapter + tokenizer to sft.out_dir.
    """
    local_rank, device = get_local_rank_and_device()
    os.makedirs(sft.out_dir, exist_ok=True)

    # NOTE(review): assumes load_from_disk returns a single Dataset (not a
    # DatasetDict); .column_names on a DatasetDict would be a dict — confirm
    # how the dataset was saved.
    ds = load_from_disk(sft.data_path)
    if "text" not in ds.column_names:
        raise RuntimeError(f"[SFT] Dataset must have 'text' column. Found: {ds.column_names}")
    print(f"[INFO] [SFT] Dataset rows: {len(ds)}", flush=True)

    tokenizer = load_tokenizer(common)
    bnb_config = make_bnb_config(common)
    model = load_4bit_model(common, bnb_config, device)

    # Fresh LoRA adapter for SFT (base weights stay frozen/quantized).
    lora = LoraConfig(
        r=sft.lora_r,
        lora_alpha=sft.lora_alpha,
        lora_dropout=sft.lora_dropout,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=common.target_modules,
    )
    model = get_peft_model(model, lora)

    if local_rank == 0:
        model.print_trainable_parameters()

    cfg = SFTConfig(
        output_dir=sft.out_dir,
        dataset_text_field="text",
        max_seq_length=sft.max_seq_len,
        # Pack multiple short examples into each max_seq_len sequence.
        packing=True,

        num_train_epochs=sft.num_epochs,
        per_device_train_batch_size=sft.per_device_bs,
        gradient_accumulation_steps=sft.grad_accum,

        learning_rate=sft.lr,
        lr_scheduler_type="cosine",
        warmup_ratio=sft.warmup_ratio,
        weight_decay=0.0,

        logging_steps=sft.logging_steps,
        save_steps=sft.save_steps,
        save_total_limit=sft.save_total_limit,

        # fp16 (not bf16) to match the bnb compute dtype set in CommonCfg.
        fp16=True,
        bf16=False,

        # Paged 8-bit AdamW keeps optimizer state off-GPU-peak (QLoRA recipe).
        optim="paged_adamw_8bit",
        report_to="none",
        remove_unused_columns=False,

        # Checkpointing is disabled to avoid DDP reentrant-backward issues
        # (see load_4bit_model).
        ddp_find_unused_parameters=False,
        gradient_checkpointing=False,
    )

    # NOTE(review): newer TRL renamed `tokenizer=` to `processing_class=`;
    # this matches the TRL version the patches in this file target — confirm.
    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=ds,
        args=cfg,
    )
    trainer.add_callback(ProgressETA())

    # Record the full config for reproducibility before training starts.
    save_manifest(
        sft.out_dir,
        local_rank,
        {
            "stage": "sft",
            "common": asdict(common),
            "sft": asdict(sft),
            "quant": "4bit nf4 qlora",
            "gradient_checkpointing": False,
        },
    )

    trainer.train()

    if local_rank == 0:
        trainer.save_model(sft.out_dir)  # saves PEFT adapter
        tokenizer.save_pretrained(sft.out_dir)
        print(f"[OK] [SFT] Finished. Adapter saved to: {sft.out_dir}", flush=True)
        print("[OK] [SFT] Base model was NOT overwritten; only LoRA adapter weights were trained/saved.", flush=True)
|
# ==========================
|
|
# DPO
|
|
# ==========================
|
|
def run_dpo(common: CommonCfg, dpo: DPOCfg):
    """Run the DPO stage on top of a frozen SFT adapter.

    Loads the preference dataset (prompt/chosen/rejected), stacks a frozen
    SFT LoRA adapter onto the 4-bit base model, adds a NEW trainable LoRA
    adapter for DPO, and trains with TRL's DPOTrainer (ref_model=None, so
    TRL derives the reference policy from the model with adapters disabled).
    """
    # Patches first (safe no-op if not needed)
    patch_trl_dpo_trainer_if_needed()

    local_rank, device = get_local_rank_and_device()
    os.makedirs(dpo.out_dir, exist_ok=True)

    ds = load_from_disk(dpo.dpo_data_path)
    for c in ("prompt", "chosen", "rejected"):
        if c not in ds.column_names:
            raise RuntimeError(f"[DPO] Missing '{c}' in DPO dataset. Found: {ds.column_names}")
    print(f"[INFO] [DPO] Dataset rows: {len(ds)}", flush=True)

    tokenizer = load_tokenizer(common)
    bnb_config = make_bnb_config(common)
    model = load_4bit_model(common, bnb_config, device)

    # Load SFT adapter (frozen) + add NEW DPO adapter (trainable).
    # NOTE(review): calling get_peft_model on an existing PeftModel wraps it a
    # second time; verify the SFT adapter's weights are actually applied (not
    # just frozen and bypassed) in the installed peft version.
    model = PeftModel.from_pretrained(model, dpo.sft_adapter_dir, is_trainable=False)

    dpo_lora = LoraConfig(
        r=dpo.lora_r,
        lora_alpha=dpo.lora_alpha,
        lora_dropout=dpo.lora_dropout,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=common.target_modules,
    )
    model = get_peft_model(model, dpo_lora)

    if local_rank == 0:
        model.print_trainable_parameters()

    cfg = DPOConfig(
        output_dir=dpo.out_dir,
        # Truncation limits for (prompt + completion) and prompt alone.
        max_length=dpo.max_length,
        max_prompt_length=dpo.max_prompt_length,

        num_train_epochs=dpo.num_epochs,
        per_device_train_batch_size=dpo.per_device_bs,
        gradient_accumulation_steps=dpo.grad_accum,

        learning_rate=dpo.lr,
        lr_scheduler_type="cosine",
        warmup_ratio=dpo.warmup_ratio,

        logging_steps=dpo.logging_steps,
        save_steps=dpo.save_steps,
        save_total_limit=dpo.save_total_limit,

        # fp16 to match the bnb compute dtype; paged optimizer per QLoRA recipe.
        fp16=True,
        bf16=False,
        optim="paged_adamw_8bit",
        report_to="none",

        # DPO KL-penalty strength.
        beta=dpo.beta,

        ddp_find_unused_parameters=False,
        gradient_checkpointing=False,
    )

    # ref_model=None: TRL uses the policy with adapters disabled as reference.
    trainer = DPOTrainer(
        model=model,
        ref_model=None,
        args=cfg,
        train_dataset=ds,
        tokenizer=tokenizer,
    )
    trainer.add_callback(ProgressETA())

    # Record the full config for reproducibility before training starts.
    save_manifest(
        dpo.out_dir,
        local_rank,
        {
            "stage": "dpo",
            "common": asdict(common),
            "dpo": asdict(dpo),
            "notes": "Includes TRL/Transformers compatibility patches for get_batch_samples + compute_loss",
        },
    )

    trainer.train()

    if local_rank == 0:
        trainer.save_model(dpo.out_dir)
        tokenizer.save_pretrained(dpo.out_dir)
        print(f"[OK] [DPO] Finished. Adapter saved to: {dpo.out_dir}", flush=True)
        print("[OK] [DPO] Base model was NOT overwritten; only DPO LoRA adapter weights were trained/saved.", flush=True)
|
# ==========================
|
|
# CLI
|
|
# ==========================
|
|
def _add_common_args(sp):
    """Arguments shared by every subcommand: base model / tokenizer knobs."""
    sp.add_argument("--base_model_path", default=CommonCfg.base_model_path)
    sp.add_argument("--target_modules", nargs="+", default=["q_proj", "v_proj"])
    sp.add_argument("--use_fast_tokenizer", action="store_true", help="If set, use fast tokenizer")
    sp.add_argument("--no_trust_remote_code", action="store_true", help="Disable trust_remote_code")


def _add_sft_args(sp):
    """SFT-stage arguments; defaults come from SFTCfg."""
    sp.add_argument("--sft_data_path", default=SFTCfg.data_path)
    sp.add_argument("--sft_out_dir", default=SFTCfg.out_dir)
    sp.add_argument("--sft_max_seq_len", type=int, default=SFTCfg.max_seq_len)
    sp.add_argument("--sft_num_epochs", type=int, default=SFTCfg.num_epochs)
    sp.add_argument("--sft_per_device_bs", type=int, default=SFTCfg.per_device_bs)
    sp.add_argument("--sft_grad_accum", type=int, default=SFTCfg.grad_accum)
    sp.add_argument("--sft_lr", type=float, default=SFTCfg.lr)
    sp.add_argument("--sft_warmup_ratio", type=float, default=SFTCfg.warmup_ratio)
    sp.add_argument("--sft_logging_steps", type=int, default=SFTCfg.logging_steps)
    sp.add_argument("--sft_save_steps", type=int, default=SFTCfg.save_steps)
    sp.add_argument("--sft_save_total_limit", type=int, default=SFTCfg.save_total_limit)
    sp.add_argument("--sft_lora_r", type=int, default=SFTCfg.lora_r)
    sp.add_argument("--sft_lora_alpha", type=int, default=SFTCfg.lora_alpha)
    sp.add_argument("--sft_lora_dropout", type=float, default=SFTCfg.lora_dropout)


def _add_dpo_args(sp):
    """DPO-stage arguments, except --sft_adapter_dir whose default differs
    per subcommand (DPOCfg default for 'dpo', empty string for 'both')."""
    sp.add_argument("--dpo_data_path", default=DPOCfg.dpo_data_path)
    sp.add_argument("--dpo_out_dir", default=DPOCfg.out_dir)
    sp.add_argument("--dpo_max_length", type=int, default=DPOCfg.max_length)
    sp.add_argument("--dpo_max_prompt_length", type=int, default=DPOCfg.max_prompt_length)
    sp.add_argument("--dpo_num_epochs", type=int, default=DPOCfg.num_epochs)
    sp.add_argument("--dpo_per_device_bs", type=int, default=DPOCfg.per_device_bs)
    sp.add_argument("--dpo_grad_accum", type=int, default=DPOCfg.grad_accum)
    sp.add_argument("--dpo_lr", type=float, default=DPOCfg.lr)
    sp.add_argument("--dpo_warmup_ratio", type=float, default=DPOCfg.warmup_ratio)
    sp.add_argument("--dpo_logging_steps", type=int, default=DPOCfg.logging_steps)
    sp.add_argument("--dpo_save_steps", type=int, default=DPOCfg.save_steps)
    sp.add_argument("--dpo_save_total_limit", type=int, default=DPOCfg.save_total_limit)
    sp.add_argument("--dpo_beta", type=float, default=DPOCfg.beta)
    sp.add_argument("--dpo_lora_r", type=int, default=DPOCfg.lora_r)
    sp.add_argument("--dpo_lora_alpha", type=int, default=DPOCfg.lora_alpha)
    sp.add_argument("--dpo_lora_dropout", type=float, default=DPOCfg.lora_dropout)


def build_parser():
    """Build the CLI parser with 'sft', 'dpo' and 'both' subcommands.

    The previous version duplicated every SFT/DPO flag verbatim between the
    single-stage subparsers and 'both'; the shared _add_*_args helpers keep
    the three subcommands in sync. The exposed flags, defaults, and help
    text are unchanged.

    Returns:
        argparse.ArgumentParser with cmd in {"sft", "dpo", "both"} (optional,
        so main() can fall back to the interactive menu).
    """
    p = argparse.ArgumentParser(
        description="Combined QLoRA training: SFT and DPO (TRL).",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    sub = p.add_subparsers(dest="cmd", required=False)

    # ---- SFT ----
    sft_p = sub.add_parser("sft", help="Run SFT stage only")
    _add_common_args(sft_p)
    _add_sft_args(sft_p)

    # ---- DPO ----
    dpo_p = sub.add_parser("dpo", help="Run DPO stage only")
    _add_common_args(dpo_p)
    dpo_p.add_argument("--sft_adapter_dir", default=DPOCfg.sft_adapter_dir)
    _add_dpo_args(dpo_p)

    # ---- BOTH ----
    both_p = sub.add_parser("both", help="Run SFT then DPO (uses SFT out dir as DPO sft_adapter_dir unless overridden)")
    _add_common_args(both_p)
    _add_sft_args(both_p)
    _add_dpo_args(both_p)
    # Empty default => main() falls back to --sft_out_dir for the adapter dir.
    both_p.add_argument("--sft_adapter_dir", default="", help="If set, use this as DPO SFT adapter dir; otherwise uses --sft_out_dir")

    return p
|
def interactive_menu() -> str:
    """
    Simple interactive selector when the script is launched without CLI arguments.
    Uses defaults for all paths/knobs. If you want to override anything, run with CLI flags.
    """
    print("\n=== Training menu ===")
    print("1) SFT")
    print("2) DPO")
    print("3) SFT + DPO (both)")
    print("q) Quit\n")

    # Map every accepted spelling to its canonical subcommand.
    aliases = {
        "1": "sft", "sft": "sft",
        "2": "dpo", "dpo": "dpo",
        "3": "both", "both": "both", "all": "both",
    }
    while True:
        raw = input("Select an option [1/2/3/q]: ").strip().lower()
        if raw in ("q", "quit", "exit"):
            raise SystemExit(0)
        if raw in aliases:
            return aliases[raw]
        print("Invalid choice. Please enter 1, 2, 3, or q.")
|
def main():
    """Entry point: resolve the subcommand (CLI or interactive), build the
    configs, and dispatch to run_sft / run_dpo / both."""
    parser = build_parser()

    # If user runs the script with no arguments, we show interactive menu
    # and run with defaults.
    import sys
    if len(sys.argv) == 1:
        # IMPORTANT for torchrun/DDP:
        # In distributed mode, we must NOT prompt on every process.
        # Rank0 asks, then writes the choice to a small file; other ranks read it.
        world_size = int(os.environ.get("WORLD_SIZE", "1"))
        local_rank = int(os.environ.get("LOCAL_RANK", "0"))

        if world_size == 1:
            cmd = interactive_menu()
        else:
            # File name is keyed on MASTER_PORT/RUN_ID so concurrent torchrun
            # jobs in the same cwd don't read each other's choice.
            # NOTE(review): the choice file is never deleted after the run;
            # a stale file could satisfy a later job's wait loop — confirm.
            choice_file = os.path.join(
                os.getcwd(),
                f".train_choice_{os.environ.get('MASTER_PORT', '0')}_{os.environ.get('RUN_ID', '0')}.json",
            )

            if local_rank == 0:
                cmd = interactive_menu()
                try:
                    with open(choice_file, "w", encoding="utf-8") as f:
                        json.dump({"cmd": cmd}, f)
                except Exception as e:
                    print(f"[WARN] Could not write choice file {choice_file}: {e}", flush=True)
            else:
                # Wait for rank0 to write the choice
                deadline = time.time() + 600  # seconds
                cmd = None
                while time.time() < deadline:
                    if os.path.exists(choice_file):
                        try:
                            with open(choice_file, "r", encoding="utf-8") as f:
                                cmd = json.load(f).get("cmd")
                            if cmd:
                                break
                        except Exception:
                            # File may be mid-write by rank0; retry below.
                            pass
                    time.sleep(0.2)
                if not cmd:
                    print("[ERROR] Timed out waiting for rank0 menu selection. "
                          "Run with explicit subcommand: sft/dpo/both.", flush=True)
                    raise SystemExit(2)

        # Re-enter argparse with only the chosen subcommand => all defaults.
        args = parser.parse_args([cmd])  # defaults
    else:
        args = parser.parse_args()
        if not getattr(args, "cmd", None):
            parser.print_help()
            raise SystemExit(2)

    common = CommonCfg(
        base_model_path=args.base_model_path,
        target_modules=args.target_modules,
        use_fast_tokenizer=args.use_fast_tokenizer,
        trust_remote_code=not args.no_trust_remote_code,
    )

    if args.cmd == "sft":
        sft = SFTCfg(
            data_path=args.sft_data_path,
            out_dir=args.sft_out_dir,
            max_seq_len=args.sft_max_seq_len,
            num_epochs=args.sft_num_epochs,
            per_device_bs=args.sft_per_device_bs,
            grad_accum=args.sft_grad_accum,
            lr=args.sft_lr,
            warmup_ratio=args.sft_warmup_ratio,
            logging_steps=args.sft_logging_steps,
            save_steps=args.sft_save_steps,
            save_total_limit=args.sft_save_total_limit,
            lora_r=args.sft_lora_r,
            lora_alpha=args.sft_lora_alpha,
            lora_dropout=args.sft_lora_dropout,
        )
        run_sft(common, sft)

    elif args.cmd == "dpo":
        dpo = DPOCfg(
            sft_adapter_dir=args.sft_adapter_dir,
            dpo_data_path=args.dpo_data_path,
            out_dir=args.dpo_out_dir,
            max_length=args.dpo_max_length,
            max_prompt_length=args.dpo_max_prompt_length,
            num_epochs=args.dpo_num_epochs,
            per_device_bs=args.dpo_per_device_bs,
            grad_accum=args.dpo_grad_accum,
            lr=args.dpo_lr,
            warmup_ratio=args.dpo_warmup_ratio,
            logging_steps=args.dpo_logging_steps,
            save_steps=args.dpo_save_steps,
            save_total_limit=args.dpo_save_total_limit,
            beta=args.dpo_beta,
            lora_r=args.dpo_lora_r,
            lora_alpha=args.dpo_lora_alpha,
            lora_dropout=args.dpo_lora_dropout,
        )
        run_dpo(common, dpo)

    elif args.cmd == "both":
        sft = SFTCfg(
            data_path=args.sft_data_path,
            out_dir=args.sft_out_dir,
            max_seq_len=args.sft_max_seq_len,
            num_epochs=args.sft_num_epochs,
            per_device_bs=args.sft_per_device_bs,
            grad_accum=args.sft_grad_accum,
            lr=args.sft_lr,
            warmup_ratio=args.sft_warmup_ratio,
            logging_steps=args.sft_logging_steps,
            save_steps=args.sft_save_steps,
            save_total_limit=args.sft_save_total_limit,
            lora_r=args.sft_lora_r,
            lora_alpha=args.sft_lora_alpha,
            lora_dropout=args.sft_lora_dropout,
        )
        run_sft(common, sft)

        # Chain the stages: DPO reads the adapter SFT just wrote, unless the
        # user pointed --sft_adapter_dir somewhere else.
        sft_adapter_dir = args.sft_adapter_dir.strip() or sft.out_dir
        dpo = DPOCfg(
            sft_adapter_dir=sft_adapter_dir,
            dpo_data_path=args.dpo_data_path,
            out_dir=args.dpo_out_dir,
            max_length=args.dpo_max_length,
            max_prompt_length=args.dpo_max_prompt_length,
            num_epochs=args.dpo_num_epochs,
            per_device_bs=args.dpo_per_device_bs,
            grad_accum=args.dpo_grad_accum,
            lr=args.dpo_lr,
            warmup_ratio=args.dpo_warmup_ratio,
            logging_steps=args.dpo_logging_steps,
            save_steps=args.dpo_save_steps,
            save_total_limit=args.dpo_save_total_limit,
            beta=args.dpo_beta,
            lora_r=args.dpo_lora_r,
            lora_alpha=args.dpo_lora_alpha,
            lora_dropout=args.dpo_lora_dropout,
        )
        run_dpo(common, dpo)


if __name__ == "__main__":
    main()
|