Added comments
This commit is contained in:
parent
18cdd5daa6
commit
7bb03b551f
@ -2,18 +2,14 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
Unified training script: SFT (QLoRA + masked loss) and DPO (TRL) in ONE file.
|
||||
Unified training script for:
|
||||
- SFT (QLoRA + masked loss)
|
||||
- DPO (TRL DPOTrainer)
|
||||
|
||||
UX requirement (as you asked):
|
||||
- You run: python3 train_unified.py
|
||||
- It shows a simple menu 1/2/3
|
||||
- You press 1, 2, or 3
|
||||
- Script relaunches itself with accelerate (2 GPUs) and runs the chosen stage
|
||||
using DEFAULT paths/hparams already embedded.
|
||||
|
||||
Notes:
|
||||
- Menu is shown ONLY once (before accelerate spawns processes).
|
||||
- After relaunch, each process runs normally with LOCAL_RANK set by accelerate.
|
||||
Execution model:
|
||||
- Run the script normally: it shows a simple menu (SFT / DPO / BOTH).
|
||||
- The script relaunches itself via accelerate (multi-process).
|
||||
- After relaunch, each process runs with LOCAL_RANK set by accelerate.
|
||||
"""
|
||||
|
||||
import os
|
||||
@ -47,7 +43,7 @@ from trl import DPOTrainer, DPOConfig
|
||||
|
||||
|
||||
# --------------------------
|
||||
# Defaults (same spirit as your original scripts)
|
||||
# Defaults
|
||||
# --------------------------
|
||||
DEFAULT_MODEL_DIR = "/home/hyrenko/Diploma/models/llama3.1-8b"
|
||||
|
||||
@ -61,17 +57,17 @@ DEFAULT_SFT_OUT = "./out/llama3_1_8b_sft_safety_lora_masked"
|
||||
DEFAULT_DPO_OUT = "./out/llama3_1_8b_dpo_safety_lora"
|
||||
DEFAULT_DPO_DATA = "./data/pku_dpo"
|
||||
|
||||
# Prompt wrappers (keep consistent)
|
||||
# Prompt wrappers
|
||||
INSTR_PREFIX = "### Instruction:\n"
|
||||
RESP_PREFIX = "\n\n### Response:\n"
|
||||
|
||||
|
||||
# --------------------------
|
||||
# Simple interactive menu (shown only before accelerate)
|
||||
# Interactive menu (shown only before accelerate spawns processes)
|
||||
# --------------------------
|
||||
def choose_stage_simple() -> str:
|
||||
"""
|
||||
Returns: 'sft' | 'dpo' | 'both'
|
||||
Returns one of: 'sft' | 'dpo' | 'both'
|
||||
"""
|
||||
print("\n=== TRAIN MENU ===")
|
||||
print("1) SFT (QLoRA masked-loss)")
|
||||
@ -90,8 +86,8 @@ def choose_stage_simple() -> str:
|
||||
|
||||
def relaunch_with_accelerate(num_processes: int, extra_argv: List[str]):
|
||||
"""
|
||||
Relaunches the script via accelerate (2 processes by default).
|
||||
Only runs if NOT already inside accelerate (LOCAL_RANK not set).
|
||||
Relaunch the script via accelerate with the requested number of processes.
|
||||
No-op if already running inside accelerate (LOCAL_RANK is set).
|
||||
"""
|
||||
if os.environ.get("LOCAL_RANK") is not None:
|
||||
return
|
||||
@ -109,7 +105,7 @@ def relaunch_with_accelerate(num_processes: int, extra_argv: List[str]):
|
||||
sys.argv[0],
|
||||
*extra_argv,
|
||||
]
|
||||
print("\n[AUTO] Relaunching with accelerate:\n" + " ".join(cmd) + "\n")
|
||||
print("\n[INFO] Relaunching with accelerate:\n" + " ".join(cmd) + "\n")
|
||||
raise SystemExit(subprocess.call(cmd))
|
||||
|
||||
|
||||
@ -117,6 +113,9 @@ def relaunch_with_accelerate(num_processes: int, extra_argv: List[str]):
|
||||
# DDP helpers
|
||||
# --------------------------
|
||||
def ddp_info():
|
||||
"""
|
||||
Return (local_rank, rank, world_size) read from the LOCAL_RANK, RANK and WORLD_SIZE environment variables (defaults: 0, 0, 1).
|
||||
"""
|
||||
local_rank = int(os.environ.get("LOCAL_RANK", "0"))
|
||||
rank = int(os.environ.get("RANK", "0"))
|
||||
world_size = int(os.environ.get("WORLD_SIZE", "1"))
|
||||
@ -124,15 +123,18 @@ def ddp_info():
|
||||
|
||||
|
||||
def ddp_barrier():
    """
    Synchronization barrier for distributed runs.

    No-op when torch.distributed is unavailable or has not been initialized,
    so it is safe to call from a plain single-process run.
    """
    dist = torch.distributed
    if dist.is_available() and dist.is_initialized():
        dist.barrier()
|
||||
|
||||
|
||||
# --------------------------
|
||||
# TRL <-> Transformers patches (to avoid signature mismatches)
|
||||
# TRL <-> Transformers patches (signature compatibility)
|
||||
# --------------------------
|
||||
def patch_trl_dpo_trainer_if_needed():
|
||||
# Patch get_batch_samples
|
||||
# Patch get_batch_samples to avoid signature mismatches across versions
|
||||
try:
|
||||
sig = list(inspect.signature(DPOTrainer.get_batch_samples).parameters.keys())
|
||||
if len(sig) >= 3 and sig[1] == "model":
|
||||
@ -143,7 +145,7 @@ def patch_trl_dpo_trainer_if_needed():
|
||||
except Exception:
|
||||
print("[PATCH] Could not inspect/patch DPOTrainer.get_batch_samples")
|
||||
|
||||
# Patch compute_loss
|
||||
# Patch compute_loss to accept num_items_in_batch if needed
|
||||
try:
|
||||
sig = inspect.signature(DPOTrainer.compute_loss)
|
||||
if "num_items_in_batch" not in sig.parameters:
|
||||
@ -167,6 +169,9 @@ def patch_trl_dpo_trainer_if_needed():
|
||||
|
||||
|
||||
def assert_is_model(obj, name="model"):
|
||||
"""
|
||||
Minimal runtime check: raise TypeError unless the object exposes .generate(), as a HF CausalLM model does.
|
||||
"""
|
||||
if not hasattr(obj, "generate"):
|
||||
raise TypeError(
|
||||
"[FATAL] {} is not a HF CausalLM model. type={}; missing .generate().".format(name, type(obj))
|
||||
@ -177,14 +182,17 @@ def assert_is_model(obj, name="model"):
|
||||
# SFT helpers
|
||||
# --------------------------
|
||||
def wrap_prompt(p: str) -> str:
    """
    Wrap a raw prompt into the instruction/response template prefix.

    A falsy prompt (None or "") is treated as empty, and surrounding
    whitespace is stripped before the prompt is placed between
    INSTR_PREFIX and RESP_PREFIX.
    """
    p = (p or "").strip()
    return "{}{}{}".format(INSTR_PREFIX, p, RESP_PREFIX)
|
||||
|
||||
|
||||
def ensure_sft_jsonl_exists(rank: int):
|
||||
"""
|
||||
Creates ./data/pku_sft.jsonl from PKU-SafeRLHF-30K if missing.
|
||||
Only rank0 writes; others wait.
|
||||
Create ./data/pku_sft.jsonl from PKU-SafeRLHF-30K if missing.
|
||||
Only rank0 writes; all other ranks wait at a barrier.
|
||||
"""
|
||||
os.makedirs("./data", exist_ok=True)
|
||||
|
||||
@ -225,6 +233,10 @@ def ensure_sft_jsonl_exists(rank: int):
|
||||
|
||||
@dataclass
|
||||
class DataCollatorForCausalLMWithLabels:
|
||||
"""
|
||||
Pads input_ids/attention_mask and aligns labels to the padded length.
|
||||
Labels use -100 for ignored positions (standard LM loss masking).
|
||||
"""
|
||||
tokenizer: Any
|
||||
|
||||
def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
|
||||
@ -252,6 +264,9 @@ class DataCollatorForCausalLMWithLabels:
|
||||
|
||||
|
||||
def build_quant_base(model_dir: str, device_map: Dict[str, int]):
|
||||
"""
|
||||
Load a 4-bit quantized base model using bitsandbytes config.
|
||||
"""
|
||||
bnb = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_quant_type="nf4",
|
||||
@ -282,7 +297,7 @@ def run_sft(
|
||||
world_size: int,
|
||||
) -> str:
|
||||
"""
|
||||
Runs SFT and returns the path to the saved LoRA adapter (out_dir).
|
||||
Run SFT training and return the saved LoRA adapter path (out_dir).
|
||||
"""
|
||||
os.makedirs(out_dir, exist_ok=True)
|
||||
|
||||
@ -309,6 +324,9 @@ def run_sft(
|
||||
model = get_peft_model(base, lora)
|
||||
|
||||
def tokenize_and_mask(example):
|
||||
"""
|
||||
Tokenize prompt and response, and build labels masked for the prompt part.
|
||||
"""
|
||||
prompt = example["prompt"]
|
||||
response = example["response"]
|
||||
|
||||
@ -366,7 +384,7 @@ def run_sft(
|
||||
|
||||
trainer.train()
|
||||
|
||||
# Save only on rank0, then barrier so others wait
|
||||
# Save only on rank0; other ranks wait at a barrier
|
||||
if rank == 0:
|
||||
model.save_pretrained(out_dir)
|
||||
tokenizer.save_pretrained(out_dir)
|
||||
@ -394,6 +412,9 @@ def run_dpo(
|
||||
device_map: Dict[str, int],
|
||||
rank: int,
|
||||
):
|
||||
"""
|
||||
Run DPO training starting from the base model + an SFT LoRA adapter.
|
||||
"""
|
||||
patch_trl_dpo_trainer_if_needed()
|
||||
os.makedirs(out_dir, exist_ok=True)
|
||||
|
||||
@ -406,7 +427,7 @@ def run_dpo(
|
||||
base_model = build_quant_base(model_dir, device_map)
|
||||
assert_is_model(base_model, "base_model")
|
||||
|
||||
# Load SFT adapter and keep training it with DPO
|
||||
# Load SFT adapter as trainable and continue training with DPO
|
||||
model = PeftModel.from_pretrained(base_model, sft_adapter, is_trainable=True)
|
||||
assert_is_model(model, "model(peft)")
|
||||
|
||||
@ -482,15 +503,13 @@ def main():
|
||||
|
||||
args = ap.parse_args()
|
||||
|
||||
# --------------------------
|
||||
# AUTO MENU (only once, BEFORE accelerate spawns processes)
|
||||
# --------------------------
|
||||
# Show menu only in the initial, non-accelerate run
|
||||
if os.environ.get("LOCAL_RANK") is None and os.environ.get("AUTO_ACCELERATE_RELAUNCH") != "1":
|
||||
os.environ["AUTO_ACCELERATE_RELAUNCH"] = "1" # prevent menu duplication
|
||||
|
||||
stage = choose_stage_simple()
|
||||
|
||||
# Build argv using defaults from args (you only press 1/2/3)
|
||||
# Build argv using defaults from args
|
||||
argv = [
|
||||
"--stage",
|
||||
stage,
|
||||
@ -536,17 +555,17 @@ def main():
|
||||
sft_adapter = args.sft_adapter or args.sft_out
|
||||
argv += ["--sft_adapter", sft_adapter]
|
||||
|
||||
# Relaunch under accelerate (2 GPUs)
|
||||
# Relaunch under accelerate (2 processes)
|
||||
relaunch_with_accelerate(num_processes=2, extra_argv=argv)
|
||||
|
||||
# --------------------------
|
||||
# We are inside accelerate now
|
||||
# Running inside accelerate
|
||||
# --------------------------
|
||||
local_rank, rank, world_size = ddp_info()
|
||||
torch.cuda.set_device(local_rank)
|
||||
device_map = {"": torch.cuda.current_device()}
|
||||
|
||||
# Deterministic seed (as in your DPO script style)
|
||||
# Fixed seed
|
||||
torch.manual_seed(42)
|
||||
|
||||
sft_adapter_path = args.sft_adapter or args.sft_out
|
||||
@ -567,7 +586,6 @@ def main():
|
||||
)
|
||||
|
||||
if args.stage in ("dpo", "both"):
|
||||
# If "both", barrier already happened after saving SFT adapter
|
||||
run_dpo(
|
||||
model_dir=args.model,
|
||||
sft_adapter=sft_adapter_path,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user