#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Combined QLoRA training script for:
1) SFT (Supervised Fine-Tuning) with TRL SFTTrainer
2) DPO (Direct Preference Optimization) with TRL DPOTrainer
Usage examples (multi-GPU):
torchrun --nproc_per_node=2 training_dpo_sft_mistral_sk.py sft
torchrun --nproc_per_node=2 training_dpo_sft_mistral_sk.py dpo
torchrun --nproc_per_node=2 training_dpo_sft_mistral_sk.py both
All paths and hyperparameters can be overridden via CLI flags (see --help).
"""
import os
import time
import json
import inspect
import argparse
from dataclasses import dataclass, asdict
from typing import List, Optional
import torch
from datasets import load_from_disk
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
BitsAndBytesConfig,
Trainer as HFTrainer,
TrainerCallback,
)
from peft import (
PeftModel,
LoraConfig,
get_peft_model,
prepare_model_for_kbit_training,
)
from trl import SFTTrainer, SFTConfig, DPOTrainer, DPOConfig
# ==========================
# Compatibility patches
# ==========================
def patch_trl_dpo_trainer_if_needed():
"""
Applies runtime patches to handle TRL <-> Transformers API incompatibilities.
Why this exists:
- Some TRL versions implement DPOTrainer methods with older signatures.
- Transformers Trainer may call methods with a newer signature / extra kwargs.
- These mismatches can lead to runtime errors during training.
Patch set:
1) get_batch_samples signature mismatch:
- Transformers expects: get_batch_samples(epoch_iterator, num_batches)
- Some older TRL expects: get_batch_samples(model, batch, ...)
- Patch: route the call through HFTrainer.get_batch_samples.
2) compute_loss signature mismatch:
- Newer Transformers may pass num_items_in_batch
- Some TRL versions do not accept that kwarg
- Patch: wrap compute_loss to accept num_items_in_batch and delegate.
"""
# ---- Patch get_batch_samples ----
try:
sig = inspect.signature(DPOTrainer.get_batch_samples)
params = list(sig.parameters.keys()) # includes "self"
# Old TRL often has (self, model, batch, ...) so params[1] == "model"
if len(params) >= 3 and params[1] == "model":
print("[PATCH] DPOTrainer.get_batch_samples signature looks old; patching to HFTrainer.get_batch_samples", flush=True)
def _get_batch_samples(self, epoch_iterator, num_batches):
return HFTrainer.get_batch_samples(self, epoch_iterator, num_batches)
DPOTrainer.get_batch_samples = _get_batch_samples
else:
print("[INFO] DPOTrainer.get_batch_samples signature looks compatible; no patch needed", flush=True)
except Exception as e:
print(f"[WARN] Could not inspect/patch get_batch_samples: {e}", flush=True)
# ---- Patch compute_loss ----
try:
sig = inspect.signature(DPOTrainer.compute_loss)
if "num_items_in_batch" not in sig.parameters:
print("[PATCH] DPOTrainer.compute_loss lacks num_items_in_batch; patching wrapper", flush=True)
_orig_compute_loss = DPOTrainer.compute_loss
def _compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
return _orig_compute_loss(self, model, inputs, return_outputs=return_outputs)
DPOTrainer.compute_loss = _compute_loss
else:
print("[INFO] DPOTrainer.compute_loss signature looks compatible; no patch needed", flush=True)
except Exception as e:
print(f"[WARN] Could not inspect/patch compute_loss: {e}", flush=True)
# ==========================
# Callbacks
# ==========================
class ProgressETA(TrainerCallback):
"""
Trainer callback that prints progress with ETA on the main local process.
Printed metrics:
- global step / max steps
- steps per second
- ETA in minutes
- optionally loss and learning rate from the trainer logs
"""
def __init__(self):
self.t0 = None
self.local_rank = int(os.environ.get("LOCAL_RANK", "0"))
def on_train_begin(self, args, state, control, **kwargs):
self.t0 = time.time()
if self.local_rank == 0:
print(f"[INFO] Training start. max_steps={state.max_steps}", flush=True)
def on_log(self, args, state, control, logs=None, **kwargs):
# Only print on local_rank==0 to reduce duplicated output under DDP
if self.local_rank != 0:
return
if not logs or not state.max_steps:
return
step = state.global_step
max_steps = state.max_steps
elapsed = time.time() - (self.t0 or time.time())
sps = step / max(elapsed, 1e-9)
rem = max_steps - step
eta = rem / max(sps, 1e-9)
msg = f"[PROGRESS] step {step}/{max_steps} | {sps:.3f} steps/s | ETA {eta/60:.1f} min"
if "loss" in logs:
msg += f" | loss {logs['loss']:.4f}"
if "learning_rate" in logs:
msg += f" | lr {logs['learning_rate']:.3e}"
print(msg, flush=True)
# ==========================
# Config dataclasses
# ==========================
@dataclass
class CommonCfg:
"""
Shared configuration for model/tokenizer/quantization used by both SFT and DPO.
"""
base_model_path: str = "/home/hyrenko/Diploma/models/mistral-sk-7b"
    target_modules: Optional[List[str]] = None  # set in __post_init__
# QLoRA quant config (bitsandbytes)
load_in_4bit: bool = True
bnb_4bit_compute_dtype: str = "float16"
bnb_4bit_use_double_quant: bool = True
bnb_4bit_quant_type: str = "nf4"
# Tokenizer loading
use_fast_tokenizer: bool = False
trust_remote_code: bool = True
def __post_init__(self):
# Default LoRA target modules for Mistral-style architectures
if self.target_modules is None:
self.target_modules = ["q_proj", "v_proj"]
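            # q_proj/v_proj is a deliberately light default; Mistral-style
            # blocks also expose k_proj, o_proj, gate_proj, up_proj and
            # down_proj if wider LoRA coverage is wanted.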
@dataclass
class SFTCfg:
"""
SFT stage configuration.
Dataset requirement:
- load_from_disk(data_path) must contain a "text" column.
"""
data_path: str = "/home/hyrenko/Diploma/datasets/pku_saferlhf_sk_sft" # must have "text"
out_dir: str = "/home/hyrenko/Diploma/outputs_sft/mistral-sk-7b-pku-saferlhf-sk-sft-qlora"
max_seq_len: int = 1024
num_epochs: int = 1
per_device_bs: int = 1
grad_accum: int = 16
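    # Effective global batch = per_device_bs * grad_accum * world_size,
    # e.g. 1 * 16 * 2 GPUs = 32 sequences per optimizer step.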
lr: float = 2e-4
warmup_ratio: float = 0.03
logging_steps: int = 10
save_steps: int = 200
save_total_limit: int = 2
# LoRA hyperparameters
lora_r: int = 16
lora_alpha: int = 32
lora_dropout: float = 0.05
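# A minimal sketch of producing a compatible SFT dataset with `datasets`
# (illustrative path and text; the actual prompt formatting is up to the
# data-prep pipeline):
#
#     from datasets import Dataset
#     rows = {"text": ["### Question ...\n### Answer ..."]}
#     Dataset.from_dict(rows).save_to_disk("/path/to/pku_saferlhf_sk_sft")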
@dataclass
class DPOCfg:
"""
DPO stage configuration.
Dataset requirement:
- load_from_disk(dpo_data_path) must contain: prompt, chosen, rejected
Adapter behavior:
- Loads the SFT adapter as frozen weights.
- Adds a NEW LoRA adapter for DPO and trains only that adapter.
"""
# SFT adapter input dir (from SFT run)
sft_adapter_dir: str = "/home/hyrenko/Diploma/outputs_sft/mistral-sk-7b-pku-saferlhf-sk-sft-qlora"
dpo_data_path: str = "/home/hyrenko/Diploma/datasets/pku_saferlhf_sk_dpo" # must have prompt/chosen/rejected
out_dir: str = "/home/hyrenko/Diploma/outputs_dpo/mistral-sk-7b-pku-saferlhf-sk-dpo-qlora"
max_length: int = 1024
max_prompt_length: int = 512
num_epochs: int = 1
per_device_bs: int = 1
grad_accum: int = 16
lr: float = 1e-5
warmup_ratio: float = 0.03
logging_steps: int = 10
save_steps: int = 200
save_total_limit: int = 2
beta: float = 0.1
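    # beta scales the implicit reward in the DPO loss:
    #   L = -log sigmoid(beta * (log pi(y_w|x)/pi_ref(y_w|x)
    #                            - log pi(y_l|x)/pi_ref(y_l|x)))
    # Larger beta keeps the trained policy closer to the reference model.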
# NEW LoRA adapter for DPO (separate from SFT)
lora_r: int = 16
lora_alpha: int = 32
lora_dropout: float = 0.05
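# A minimal sketch of the expected preference layout (illustrative values):
#
#     from datasets import Dataset
#     rows = {
#         "prompt":   ["How do I ...?"],
#         "chosen":   ["A safe, helpful answer."],
#         "rejected": ["A dispreferred answer."],
#     }
#     Dataset.from_dict(rows).save_to_disk("/path/to/pku_saferlhf_sk_dpo")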
# ==========================
# Helpers
# ==========================
def get_local_rank_and_device():
"""
Determines device placement based on LOCAL_RANK.
In DDP runs, each process binds itself to a separate GPU.
"""
local_rank = int(os.environ.get("LOCAL_RANK", "0"))
if torch.cuda.is_available():
torch.cuda.set_device(local_rank)
device = torch.device("cuda", local_rank)
else:
device = torch.device("cpu")
print(f"[INFO] local_rank={local_rank} device={device}", flush=True)
return local_rank, device
def make_bnb_config(common: CommonCfg) -> BitsAndBytesConfig:
"""
Builds BitsAndBytesConfig for 4-bit QLoRA.
"""
dtype = getattr(torch, common.bnb_4bit_compute_dtype)
return BitsAndBytesConfig(
load_in_4bit=common.load_in_4bit,
bnb_4bit_compute_dtype=dtype,
bnb_4bit_use_double_quant=common.bnb_4bit_use_double_quant,
bnb_4bit_quant_type=common.bnb_4bit_quant_type,
)
def load_tokenizer(common: CommonCfg):
"""
Loads tokenizer from base_model_path and ensures pad_token is defined.
"""
tokenizer = AutoTokenizer.from_pretrained(
common.base_model_path,
use_fast=common.use_fast_tokenizer,
trust_remote_code=common.trust_remote_code,
)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
return tokenizer
def load_4bit_model(common: CommonCfg, bnb_config: BitsAndBytesConfig, device: torch.device):
"""
Loads the base model with 4-bit quantization and prepares it for k-bit training.
Key settings:
- device_map binds the model to the current GPU under DDP.
- use_cache is disabled for training.
- gradient checkpointing is disabled here to avoid reentrant issues under DDP.
"""
model = AutoModelForCausalLM.from_pretrained(
common.base_model_path,
quantization_config=bnb_config,
device_map={"": torch.cuda.current_device()} if device.type == "cuda" else None,
torch_dtype=torch.float16,
trust_remote_code=common.trust_remote_code,
)
model.config.use_cache = False
# Avoid DDP + checkpointing reentrant issues
if hasattr(model, "gradient_checkpointing_disable"):
model.gradient_checkpointing_disable()
model = prepare_model_for_kbit_training(model)
return model
def save_manifest(out_dir: str, local_rank: int, manifest: dict):
"""
Writes a JSON manifest for reproducibility (rank0 only).
"""
if local_rank != 0:
return
os.makedirs(out_dir, exist_ok=True)
with open(os.path.join(out_dir, "run_manifest.json"), "w", encoding="utf-8") as f:
json.dump(manifest, f, ensure_ascii=False, indent=2)
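# The written run_manifest.json has roughly this shape (abbreviated):
#   {"stage": "sft", "common": {...}, "sft": {...}, "quant": "4bit nf4 qlora"}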
# ==========================
# SFT
# ==========================
def run_sft(common: CommonCfg, sft: SFTCfg):
"""
Runs SFT using TRL SFTTrainer with QLoRA adapter training.
Expected dataset:
- load_from_disk(sft.data_path) with a "text" field.
"""
local_rank, device = get_local_rank_and_device()
os.makedirs(sft.out_dir, exist_ok=True)
ds = load_from_disk(sft.data_path)
if "text" not in ds.column_names:
raise RuntimeError(f"[SFT] Dataset must have 'text' column. Found: {ds.column_names}")
print(f"[INFO] [SFT] Dataset rows: {len(ds)}", flush=True)
tokenizer = load_tokenizer(common)
bnb_config = make_bnb_config(common)
model = load_4bit_model(common, bnb_config, device)
# Trainable LoRA adapter for SFT stage
lora = LoraConfig(
r=sft.lora_r,
lora_alpha=sft.lora_alpha,
lora_dropout=sft.lora_dropout,
bias="none",
task_type="CAUSAL_LM",
target_modules=common.target_modules,
)
model = get_peft_model(model, lora)
if local_rank == 0:
model.print_trainable_parameters()
cfg = SFTConfig(
output_dir=sft.out_dir,
dataset_text_field="text",
max_seq_length=sft.max_seq_len,
packing=True,
num_train_epochs=sft.num_epochs,
per_device_train_batch_size=sft.per_device_bs,
gradient_accumulation_steps=sft.grad_accum,
learning_rate=sft.lr,
lr_scheduler_type="cosine",
warmup_ratio=sft.warmup_ratio,
weight_decay=0.0,
logging_steps=sft.logging_steps,
save_steps=sft.save_steps,
save_total_limit=sft.save_total_limit,
fp16=True,
bf16=False,
optim="paged_adamw_8bit",
report_to="none",
remove_unused_columns=False,
ddp_find_unused_parameters=False,
gradient_checkpointing=False,
)
trainer = SFTTrainer(
model=model,
tokenizer=tokenizer,
train_dataset=ds,
args=cfg,
)
trainer.add_callback(ProgressETA())
save_manifest(
sft.out_dir,
local_rank,
{
"stage": "sft",
"common": asdict(common),
"sft": asdict(sft),
"quant": "4bit nf4 qlora",
"gradient_checkpointing": False,
},
)
trainer.train()
if local_rank == 0:
trainer.save_model(sft.out_dir) # saves PEFT adapter
tokenizer.save_pretrained(sft.out_dir)
print(f"[OK] [SFT] Finished. Adapter saved to: {sft.out_dir}", flush=True)
print("[OK] [SFT] Base model was NOT overwritten; only LoRA adapter weights were trained/saved.", flush=True)
# ==========================
# DPO
# ==========================
def run_dpo(common: CommonCfg, dpo: DPOCfg):
"""
Runs DPO using TRL DPOTrainer with QLoRA adapters.
Adapter composition:
- Load base model (4-bit)
- Attach SFT adapter as frozen (is_trainable=False)
- Add a new LoRA adapter for DPO training (trainable)
Expected dataset:
- load_from_disk(dpo.dpo_data_path) with prompt/chosen/rejected columns
"""
# Apply patches early (no-op if already compatible)
patch_trl_dpo_trainer_if_needed()
local_rank, device = get_local_rank_and_device()
os.makedirs(dpo.out_dir, exist_ok=True)
ds = load_from_disk(dpo.dpo_data_path)
for c in ("prompt", "chosen", "rejected"):
if c not in ds.column_names:
raise RuntimeError(f"[DPO] Missing '{c}' in DPO dataset. Found: {ds.column_names}")
print(f"[INFO] [DPO] Dataset rows: {len(ds)}", flush=True)
tokenizer = load_tokenizer(common)
bnb_config = make_bnb_config(common)
model = load_4bit_model(common, bnb_config, device)
# Load SFT adapter as frozen weights
model = PeftModel.from_pretrained(model, dpo.sft_adapter_dir, is_trainable=False)
# Trainable LoRA adapter for DPO stage
dpo_lora = LoraConfig(
r=dpo.lora_r,
lora_alpha=dpo.lora_alpha,
lora_dropout=dpo.lora_dropout,
bias="none",
task_type="CAUSAL_LM",
target_modules=common.target_modules,
)
model = get_peft_model(model, dpo_lora)
if local_rank == 0:
model.print_trainable_parameters()
cfg = DPOConfig(
output_dir=dpo.out_dir,
max_length=dpo.max_length,
max_prompt_length=dpo.max_prompt_length,
num_train_epochs=dpo.num_epochs,
per_device_train_batch_size=dpo.per_device_bs,
gradient_accumulation_steps=dpo.grad_accum,
learning_rate=dpo.lr,
lr_scheduler_type="cosine",
warmup_ratio=dpo.warmup_ratio,
logging_steps=dpo.logging_steps,
save_steps=dpo.save_steps,
save_total_limit=dpo.save_total_limit,
fp16=True,
bf16=False,
optim="paged_adamw_8bit",
report_to="none",
beta=dpo.beta,
ddp_find_unused_parameters=False,
gradient_checkpointing=False,
)
trainer = DPOTrainer(
model=model,
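        # ref_model=None: for PEFT models, TRL derives the reference policy by
        # temporarily disabling the adapters instead of holding a second copy
        # of the full model in memory.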
ref_model=None,
args=cfg,
train_dataset=ds,
tokenizer=tokenizer,
)
trainer.add_callback(ProgressETA())
save_manifest(
dpo.out_dir,
local_rank,
{
"stage": "dpo",
"common": asdict(common),
"dpo": asdict(dpo),
"notes": "Includes TRL/Transformers compatibility patches for get_batch_samples + compute_loss",
},
)
trainer.train()
if local_rank == 0:
trainer.save_model(dpo.out_dir)
tokenizer.save_pretrained(dpo.out_dir)
print(f"[OK] [DPO] Finished. Adapter saved to: {dpo.out_dir}", flush=True)
print("[OK] [DPO] Base model was NOT overwritten; only DPO LoRA adapter weights were trained/saved.", flush=True)
# ==========================
# CLI
# ==========================
def build_parser():
"""
Builds an argparse parser with subcommands:
- sft
- dpo
- both
"""
p = argparse.ArgumentParser(
description="Combined QLoRA training: SFT and DPO (TRL).",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
sub = p.add_subparsers(dest="cmd", required=False)
def add_common_args(sp):
# Shared base model/tokenizer arguments
sp.add_argument("--base_model_path", default=CommonCfg.base_model_path)
sp.add_argument("--target_modules", nargs="+", default=["q_proj", "v_proj"])
sp.add_argument("--use_fast_tokenizer", action="store_true", help="If set, use fast tokenizer")
sp.add_argument("--no_trust_remote_code", action="store_true", help="Disable trust_remote_code")
# ---- SFT ----
sft_p = sub.add_parser("sft", help="Run SFT stage only")
add_common_args(sft_p)
sft_p.add_argument("--sft_data_path", default=SFTCfg.data_path)
sft_p.add_argument("--sft_out_dir", default=SFTCfg.out_dir)
sft_p.add_argument("--sft_max_seq_len", type=int, default=SFTCfg.max_seq_len)
sft_p.add_argument("--sft_num_epochs", type=int, default=SFTCfg.num_epochs)
sft_p.add_argument("--sft_per_device_bs", type=int, default=SFTCfg.per_device_bs)
sft_p.add_argument("--sft_grad_accum", type=int, default=SFTCfg.grad_accum)
sft_p.add_argument("--sft_lr", type=float, default=SFTCfg.lr)
sft_p.add_argument("--sft_warmup_ratio", type=float, default=SFTCfg.warmup_ratio)
sft_p.add_argument("--sft_logging_steps", type=int, default=SFTCfg.logging_steps)
sft_p.add_argument("--sft_save_steps", type=int, default=SFTCfg.save_steps)
sft_p.add_argument("--sft_save_total_limit", type=int, default=SFTCfg.save_total_limit)
sft_p.add_argument("--sft_lora_r", type=int, default=SFTCfg.lora_r)
sft_p.add_argument("--sft_lora_alpha", type=int, default=SFTCfg.lora_alpha)
sft_p.add_argument("--sft_lora_dropout", type=float, default=SFTCfg.lora_dropout)
# ---- DPO ----
dpo_p = sub.add_parser("dpo", help="Run DPO stage only")
add_common_args(dpo_p)
dpo_p.add_argument("--sft_adapter_dir", default=DPOCfg.sft_adapter_dir)
dpo_p.add_argument("--dpo_data_path", default=DPOCfg.dpo_data_path)
dpo_p.add_argument("--dpo_out_dir", default=DPOCfg.out_dir)
dpo_p.add_argument("--dpo_max_length", type=int, default=DPOCfg.max_length)
dpo_p.add_argument("--dpo_max_prompt_length", type=int, default=DPOCfg.max_prompt_length)
dpo_p.add_argument("--dpo_num_epochs", type=int, default=DPOCfg.num_epochs)
dpo_p.add_argument("--dpo_per_device_bs", type=int, default=DPOCfg.per_device_bs)
dpo_p.add_argument("--dpo_grad_accum", type=int, default=DPOCfg.grad_accum)
dpo_p.add_argument("--dpo_lr", type=float, default=DPOCfg.lr)
dpo_p.add_argument("--dpo_warmup_ratio", type=float, default=DPOCfg.warmup_ratio)
dpo_p.add_argument("--dpo_logging_steps", type=int, default=DPOCfg.logging_steps)
dpo_p.add_argument("--dpo_save_steps", type=int, default=DPOCfg.save_steps)
dpo_p.add_argument("--dpo_save_total_limit", type=int, default=DPOCfg.save_total_limit)
dpo_p.add_argument("--dpo_beta", type=float, default=DPOCfg.beta)
dpo_p.add_argument("--dpo_lora_r", type=int, default=DPOCfg.lora_r)
dpo_p.add_argument("--dpo_lora_alpha", type=int, default=DPOCfg.lora_alpha)
dpo_p.add_argument("--dpo_lora_dropout", type=float, default=DPOCfg.lora_dropout)
# ---- BOTH ----
both_p = sub.add_parser("both", help="Run SFT then DPO (uses SFT out dir as DPO sft_adapter_dir unless overridden)")
add_common_args(both_p)
# SFT args subset
both_p.add_argument("--sft_data_path", default=SFTCfg.data_path)
both_p.add_argument("--sft_out_dir", default=SFTCfg.out_dir)
both_p.add_argument("--sft_max_seq_len", type=int, default=SFTCfg.max_seq_len)
both_p.add_argument("--sft_num_epochs", type=int, default=SFTCfg.num_epochs)
both_p.add_argument("--sft_per_device_bs", type=int, default=SFTCfg.per_device_bs)
both_p.add_argument("--sft_grad_accum", type=int, default=SFTCfg.grad_accum)
both_p.add_argument("--sft_lr", type=float, default=SFTCfg.lr)
both_p.add_argument("--sft_warmup_ratio", type=float, default=SFTCfg.warmup_ratio)
both_p.add_argument("--sft_logging_steps", type=int, default=SFTCfg.logging_steps)
both_p.add_argument("--sft_save_steps", type=int, default=SFTCfg.save_steps)
both_p.add_argument("--sft_save_total_limit", type=int, default=SFTCfg.save_total_limit)
both_p.add_argument("--sft_lora_r", type=int, default=SFTCfg.lora_r)
both_p.add_argument("--sft_lora_alpha", type=int, default=SFTCfg.lora_alpha)
both_p.add_argument("--sft_lora_dropout", type=float, default=SFTCfg.lora_dropout)
# DPO args subset
both_p.add_argument("--dpo_data_path", default=DPOCfg.dpo_data_path)
both_p.add_argument("--dpo_out_dir", default=DPOCfg.out_dir)
both_p.add_argument("--dpo_max_length", type=int, default=DPOCfg.max_length)
both_p.add_argument("--dpo_max_prompt_length", type=int, default=DPOCfg.max_prompt_length)
both_p.add_argument("--dpo_num_epochs", type=int, default=DPOCfg.num_epochs)
both_p.add_argument("--dpo_per_device_bs", type=int, default=DPOCfg.per_device_bs)
both_p.add_argument("--dpo_grad_accum", type=int, default=DPOCfg.grad_accum)
both_p.add_argument("--dpo_lr", type=float, default=DPOCfg.lr)
both_p.add_argument("--dpo_warmup_ratio", type=float, default=DPOCfg.warmup_ratio)
both_p.add_argument("--dpo_logging_steps", type=int, default=DPOCfg.logging_steps)
both_p.add_argument("--dpo_save_steps", type=int, default=DPOCfg.save_steps)
both_p.add_argument("--dpo_save_total_limit", type=int, default=DPOCfg.save_total_limit)
both_p.add_argument("--dpo_beta", type=float, default=DPOCfg.beta)
both_p.add_argument("--dpo_lora_r", type=int, default=DPOCfg.lora_r)
both_p.add_argument("--dpo_lora_alpha", type=int, default=DPOCfg.lora_alpha)
both_p.add_argument("--dpo_lora_dropout", type=float, default=DPOCfg.lora_dropout)
both_p.add_argument("--sft_adapter_dir", default="", help="If set, use this as DPO SFT adapter dir; otherwise uses --sft_out_dir")
return p
def interactive_menu() -> str:
"""
Interactive selector used when the script is launched without CLI arguments.
"""
print("\n=== Training menu ===")
print("1) SFT")
print("2) DPO")
print("3) SFT + DPO (both)")
print("q) Quit\n")
while True:
choice = input("Select an option [1/2/3/q]: ").strip().lower()
if choice in ("1", "sft"):
return "sft"
if choice in ("2", "dpo"):
return "dpo"
if choice in ("3", "both", "all"):
return "both"
if choice in ("q", "quit", "exit"):
raise SystemExit(0)
print("Invalid choice. Please enter 1, 2, 3, or q.")
def main():
"""
Entry point:
- Parse CLI args (or use interactive menu when no args are provided)
- Create config objects
- Run the selected stage(s)
"""
parser = build_parser()
# If run without args: show menu and run with defaults
import sys
if len(sys.argv) == 1:
# Under DDP, avoid prompting on every process:
# local_rank==0 writes the choice; other ranks read it.
world_size = int(os.environ.get("WORLD_SIZE", "1"))
local_rank = int(os.environ.get("LOCAL_RANK", "0"))
if world_size == 1:
cmd = interactive_menu()
else:
choice_file = os.path.join(
os.getcwd(),
f".train_choice_{os.environ.get('MASTER_PORT', '0')}_{os.environ.get('RUN_ID', '0')}.json",
)
if local_rank == 0:
cmd = interactive_menu()
try:
with open(choice_file, "w", encoding="utf-8") as f:
json.dump({"cmd": cmd}, f)
except Exception as e:
print(f"[WARN] Could not write choice file {choice_file}: {e}", flush=True)
else:
# Poll for choice_file created by local_rank==0
deadline = time.time() + 600 # seconds
cmd = None
while time.time() < deadline:
if os.path.exists(choice_file):
try:
with open(choice_file, "r", encoding="utf-8") as f:
cmd = json.load(f).get("cmd")
if cmd:
break
except Exception:
pass
time.sleep(0.2)
if not cmd:
print("[ERROR] Timed out waiting for menu selection. "
"Run with explicit subcommand: sft/dpo/both.", flush=True)
raise SystemExit(2)
args = parser.parse_args([cmd]) # defaults
else:
args = parser.parse_args()
if not getattr(args, "cmd", None):
parser.print_help()
raise SystemExit(2)
common = CommonCfg(
base_model_path=args.base_model_path,
target_modules=args.target_modules,
use_fast_tokenizer=args.use_fast_tokenizer,
trust_remote_code=not args.no_trust_remote_code,
)
if args.cmd == "sft":
sft = SFTCfg(
data_path=args.sft_data_path,
out_dir=args.sft_out_dir,
max_seq_len=args.sft_max_seq_len,
num_epochs=args.sft_num_epochs,
per_device_bs=args.sft_per_device_bs,
grad_accum=args.sft_grad_accum,
lr=args.sft_lr,
warmup_ratio=args.sft_warmup_ratio,
logging_steps=args.sft_logging_steps,
save_steps=args.sft_save_steps,
save_total_limit=args.sft_save_total_limit,
lora_r=args.sft_lora_r,
lora_alpha=args.sft_lora_alpha,
lora_dropout=args.sft_lora_dropout,
)
run_sft(common, sft)
elif args.cmd == "dpo":
dpo = DPOCfg(
sft_adapter_dir=args.sft_adapter_dir,
dpo_data_path=args.dpo_data_path,
out_dir=args.dpo_out_dir,
max_length=args.dpo_max_length,
max_prompt_length=args.dpo_max_prompt_length,
num_epochs=args.dpo_num_epochs,
per_device_bs=args.dpo_per_device_bs,
grad_accum=args.dpo_grad_accum,
lr=args.dpo_lr,
warmup_ratio=args.dpo_warmup_ratio,
logging_steps=args.dpo_logging_steps,
save_steps=args.dpo_save_steps,
save_total_limit=args.dpo_save_total_limit,
beta=args.dpo_beta,
lora_r=args.dpo_lora_r,
lora_alpha=args.dpo_lora_alpha,
lora_dropout=args.dpo_lora_dropout,
)
run_dpo(common, dpo)
elif args.cmd == "both":
sft = SFTCfg(
data_path=args.sft_data_path,
out_dir=args.sft_out_dir,
max_seq_len=args.sft_max_seq_len,
num_epochs=args.sft_num_epochs,
per_device_bs=args.sft_per_device_bs,
grad_accum=args.sft_grad_accum,
lr=args.sft_lr,
warmup_ratio=args.sft_warmup_ratio,
logging_steps=args.sft_logging_steps,
save_steps=args.sft_save_steps,
save_total_limit=args.sft_save_total_limit,
lora_r=args.sft_lora_r,
lora_alpha=args.sft_lora_alpha,
lora_dropout=args.sft_lora_dropout,
)
run_sft(common, sft)
sft_adapter_dir = args.sft_adapter_dir.strip() or sft.out_dir
dpo = DPOCfg(
sft_adapter_dir=sft_adapter_dir,
dpo_data_path=args.dpo_data_path,
out_dir=args.dpo_out_dir,
max_length=args.dpo_max_length,
max_prompt_length=args.dpo_max_prompt_length,
num_epochs=args.dpo_num_epochs,
per_device_bs=args.dpo_per_device_bs,
grad_accum=args.dpo_grad_accum,
lr=args.dpo_lr,
warmup_ratio=args.dpo_warmup_ratio,
logging_steps=args.dpo_logging_steps,
save_steps=args.dpo_save_steps,
save_total_limit=args.dpo_save_total_limit,
beta=args.dpo_beta,
lora_r=args.dpo_lora_r,
lora_alpha=args.dpo_lora_alpha,
lora_dropout=args.dpo_lora_dropout,
)
run_dpo(common, dpo)
if __name__ == "__main__":
main()