# DP/Training/convert_dpo_sft.py
# Captured 2026-02-04 20:34:11 +00:00 — ~200 lines, 5.8 KiB, Python.
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
import json
from datasets import load_dataset, Dataset
# Source dataset on the Hugging Face Hub and the split to convert.
DATASET_ID = "PKU-Alignment/PKU-SafeRLHF-30K"
SPLIT = "train"

# Output locations.
OUT_SFT_JSONL = "./data/pku_sft.jsonl"  # SFT: JSON Lines, one {"prompt", "response"} per line
OUT_DPO_DIR = "./data/pku_dpo"          # DPO: Hugging Face dataset directory

# Fixed pieces of the instruction/response prompt template.
INSTR_PREFIX = "### Instruction:\n"
RESP_PREFIX = "\n\n### Response:\n"


def wrap_prompt(p: str) -> str:
    """Wrap a raw prompt in the fixed instruction/response template.

    ``None`` is treated as an empty string, and surrounding whitespace is
    stripped before formatting. The result looks like::

        ### Instruction:
        <prompt>

        ### Response:
    """
    cleaned = "" if p is None else p.strip()
    return INSTR_PREFIX + cleaned + RESP_PREFIX
def convert_sft(ds, out_jsonl: str) -> None:
    """Convert the dataset into an SFT JSON Lines file.

    Each output line is ``{"prompt": <wrapped prompt>, "response": <safer response>}``,
    where the safer response is selected by ``safer_response_id``.

    Expected per-row schema: ``prompt``, ``response_0``, ``response_1``,
    ``safer_response_id`` (0 or 1).

    Rows with an empty prompt/response, or a ``safer_response_id`` outside
    ``(0, 1)``, are dropped.

    Side effects: creates the output directory if needed, writes JSON Lines
    to ``out_jsonl``, and prints basic stats (kept rows, dropped rows, and
    the first kept row's keys).
    """
    os.makedirs(os.path.dirname(out_jsonl) or ".", exist_ok=True)
    kept = 0
    dropped = 0
    first = None  # first kept row, retained only for the example printout
    # Stream rows straight to disk instead of materializing the whole list
    # in memory first — same output, bounded memory for large datasets.
    with open(out_jsonl, "w", encoding="utf-8") as f:
        for ex in ds:
            prompt = (ex.get("prompt") or "").strip()
            r0 = (ex.get("response_0") or "").strip()
            r1 = (ex.get("response_1") or "").strip()
            sid = ex.get("safer_response_id", None)
            if not prompt or not r0 or not r1 or sid not in (0, 1):
                dropped += 1
                continue
            row = {"prompt": wrap_prompt(prompt), "response": r0 if sid == 0 else r1}
            f.write(json.dumps(row, ensure_ascii=False) + "\n")
            if first is None:
                first = row
            kept += 1
    print(f"[OK] SFT saved: {out_jsonl}")
    print(f"rows: {kept} | dropped: {dropped}")
    if first is not None:
        print("example keys:", first.keys())
def convert_dpo(ds, out_dir: str) -> None:
    """Convert the dataset into a DPO dataset directory (HF ``datasets`` format).

    Resulting columns: ``prompt`` (wrapped), ``chosen``, ``rejected``. The
    response flagged by ``safer_response_id`` becomes ``chosen`` and the other
    becomes ``rejected``.

    Expected per-row schema: ``prompt``, ``response_0``, ``response_1``,
    ``safer_response_id`` (0 or 1). Rows with empty fields or an invalid id
    are dropped.

    Side effects: creates the output directory if needed, saves the dataset
    with ``save_to_disk()``, and prints basic stats.
    """
    os.makedirs(os.path.dirname(out_dir) or ".", exist_ok=True)
    pairs = []
    dropped = 0
    for example in ds:
        prompt = (example.get("prompt") or "").strip()
        first = (example.get("response_0") or "").strip()
        second = (example.get("response_1") or "").strip()
        safer = example.get("safer_response_id", None)
        if not prompt or not first or not second or safer not in (0, 1):
            dropped += 1
            continue
        if safer == 0:
            chosen, rejected = first, second
        else:
            chosen, rejected = second, first
        pairs.append(
            {"prompt": wrap_prompt(prompt), "chosen": chosen, "rejected": rejected}
        )
    dpo = Dataset.from_list(pairs)
    dpo.save_to_disk(out_dir)
    print(f"[OK] DPO saved to: {out_dir}")
    print(f"rows: {len(pairs)} | dropped: {dropped}")
    print("columns:", dpo.column_names)
def ask_choice() -> str:
    """Present a CLI menu and return the user's selection.

    Options: "1" (SFT only), "2" (DPO only), "3" (both), "0" (exit).
    Loops until one of those strings is entered.
    """
    print("\nWhat do you want to generate?")
    print(" 1) SFT (jsonl: prompt+response)")
    print(" 2) DPO (dataset dir: prompt+chosen+rejected)")
    print(" 3) BOTH")
    print(" 0) Exit")
    valid = {"0", "1", "2", "3"}
    while True:
        answer = input("Select [1/2/3/0]: ").strip()
        if answer in valid:
            return answer
        print("Invalid choice. Please enter 1, 2, 3, or 0.")
def main():
    """Entry point: load the source dataset, ask what to build, and write
    SFT and/or DPO outputs under ``./data``."""
    print(f"Loading dataset: {DATASET_ID} split={SPLIT}")
    ds = load_dataset(DATASET_ID, split=SPLIT)
    choice = ask_choice()
    if choice == "0":
        print("Exit.")
        return
    # Both output paths live under ./data; make sure it exists up front.
    os.makedirs("./data", exist_ok=True)
    want_sft = choice in ("1", "3")
    want_dpo = choice in ("2", "3")
    if want_sft:
        convert_sft(ds, OUT_SFT_JSONL)
    if want_dpo:
        convert_dpo(ds, OUT_DPO_DIR)


if __name__ == "__main__":
    main()