200 lines
5.8 KiB
Python
200 lines
5.8 KiB
Python
#!/usr/bin/env python3
|
|
# -*- coding: utf-8 -*-
|
|
|
|
import os
|
|
import json
|
|
from datasets import load_dataset, Dataset
|
|
|
|
|
|
# Dataset identifier and split passed to load_dataset() in main().
DATASET_ID = "PKU-Alignment/PKU-SafeRLHF-30K"
SPLIT = "train"

# Outputs
OUT_SFT_JSONL = "./data/pku_sft.jsonl"  # JSON Lines file written by convert_sft()
OUT_DPO_DIR = "./data/pku_dpo"  # dataset directory written by convert_dpo()

# Prompt template
# wrap_prompt() emits INSTR_PREFIX + <prompt> + RESP_PREFIX, so the wrapped
# prompt ends right after the "### Response:" header line.
INSTR_PREFIX = "### Instruction:\n"
RESP_PREFIX = "\n\n### Response:\n"
|
|
|
|
|
|
# -----------------------------------------------------------------------------
|
|
# wrap_prompt(p)
|
|
# Purpose:
|
|
# Wrap a raw prompt string into a stable instruction/response template.
|
|
#
|
|
# Behavior:
|
|
# - Converts None to an empty string
|
|
# - Strips leading/trailing whitespace
|
|
# - Returns a formatted prompt with fixed prefixes
|
|
#
|
|
# Output format:
|
|
# ### Instruction:
|
|
# <prompt>
|
|
#
|
|
# ### Response:
|
|
# -----------------------------------------------------------------------------
|
|
def wrap_prompt(p: str) -> str:
    """Embed a raw prompt in the fixed instruction/response template.

    Args:
        p: Raw prompt text. Any falsy value (``None``, ``""``) is treated as
            an empty prompt.

    Returns:
        ``INSTR_PREFIX`` + the whitespace-stripped prompt + ``RESP_PREFIX``.
    """
    # Falsy input short-circuits to "" so .strip() is never called on None.
    cleaned = p.strip() if p else ""
    return "{}{}{}".format(INSTR_PREFIX, cleaned, RESP_PREFIX)
|
|
|
|
|
|
# -----------------------------------------------------------------------------
|
|
# convert_sft(ds, out_jsonl)
|
|
# Purpose:
|
|
# Convert the dataset into an SFT jsonl file, where each line is:
|
|
# {"prompt": <wrapped_prompt>, "response": <safe_response>}
|
|
#
|
|
# Expected input schema (per row):
|
|
# - prompt
|
|
# - response_0
|
|
# - response_1
|
|
# - safer_response_id (0 or 1 indicates which response is considered "safe")
|
|
#
|
|
# Filtering rules:
|
|
# - Drops rows if prompt/response_0/response_1 are empty, or safer_response_id not in (0, 1)
|
|
#
|
|
# Side effects:
|
|
# - Creates the output directory if needed
|
|
# - Writes JSON Lines to out_jsonl
|
|
# - Prints basic stats (saved rows, dropped rows)
|
|
# -----------------------------------------------------------------------------
|
|
def convert_sft(ds, out_jsonl: str) -> None:
    """Write the safer response of every usable row to a JSON Lines file.

    Each output line is ``{"prompt": <wrapped prompt>, "response": <safe>}``,
    where the safe response is ``response_0`` when ``safer_response_id`` is 0
    and ``response_1`` otherwise.

    Args:
        ds: Iterable of row mappings (accessed via ``.get``) carrying the keys
            ``prompt``, ``response_0``, ``response_1`` and
            ``safer_response_id``.
        out_jsonl: Destination file path; its parent directory is created
            when missing.

    Rows are skipped (and counted as dropped) when the prompt or either
    response is empty after stripping, or when ``safer_response_id`` is not
    0 or 1. A short summary is printed once the file has been written.
    """
    parent_dir = os.path.dirname(out_jsonl)
    # A bare filename has no directory component; fall back to the CWD.
    os.makedirs(parent_dir if parent_dir else ".", exist_ok=True)

    records = []
    skipped = 0

    for example in ds:
        question = (example.get("prompt") or "").strip()
        first = (example.get("response_0") or "").strip()
        second = (example.get("response_1") or "").strip()
        safer_id = example.get("safer_response_id", None)

        usable = bool(question) and bool(first) and bool(second) and safer_id in (0, 1)
        if usable:
            safe_answer = first if safer_id == 0 else second
            records.append({"prompt": wrap_prompt(question), "response": safe_answer})
        else:
            skipped += 1

    with open(out_jsonl, "w", encoding="utf-8") as sink:
        # ensure_ascii=False keeps non-ASCII text readable in the output file.
        sink.writelines(
            json.dumps(record, ensure_ascii=False) + "\n" for record in records
        )

    print(f"[OK] SFT saved: {out_jsonl}")
    print(f"rows: {len(records)} | dropped: {skipped}")
    if records:
        print("example keys:", records[0].keys())
|
|
|
|
|
|
# -----------------------------------------------------------------------------
|
|
# convert_dpo(ds, out_dir)
|
|
# Purpose:
|
|
# Convert the dataset into a DPO dataset directory (Hugging Face datasets format),
|
|
# with columns:
|
|
# - prompt (wrapped)
|
|
# - chosen
|
|
# - rejected
|
|
#
|
|
# Expected input schema (per row):
|
|
# - prompt
|
|
# - response_0
|
|
# - response_1
|
|
# - safer_response_id (0 or 1 indicates which response is "chosen")
|
|
#
|
|
# Filtering rules:
|
|
# - Drops rows if prompt/response_0/response_1 are empty, or safer_response_id not in (0, 1)
|
|
#
|
|
# Side effects:
|
|
# - Creates the parent directory of out_dir if needed
|
|
# - Saves the dataset to out_dir using save_to_disk()
|
|
# - Prints basic stats (saved rows, dropped rows, resulting columns)
|
|
# -----------------------------------------------------------------------------
|
|
def convert_dpo(ds, out_dir: str) -> None:
    """Build a prompt/chosen/rejected dataset and save it with save_to_disk().

    For every usable row, the response selected by ``safer_response_id``
    becomes ``chosen`` and the other one becomes ``rejected``; the prompt is
    passed through ``wrap_prompt``.

    Args:
        ds: Iterable of row mappings (accessed via ``.get``) carrying the keys
            ``prompt``, ``response_0``, ``response_1`` and
            ``safer_response_id``.
        out_dir: Target directory handed to ``Dataset.save_to_disk``; its
            parent directory is created when missing.

    Rows are skipped (and counted as dropped) when the prompt or either
    response is empty after stripping, or when ``safer_response_id`` is not
    0 or 1. A short summary is printed after saving.
    """
    parent_dir = os.path.dirname(out_dir)
    # Only the parent is created here; out_dir itself is left to save_to_disk().
    os.makedirs(parent_dir if parent_dir else ".", exist_ok=True)

    pairs = []
    skipped = 0

    for example in ds:
        question = (example.get("prompt") or "").strip()
        first = (example.get("response_0") or "").strip()
        second = (example.get("response_1") or "").strip()
        safer_id = example.get("safer_response_id", None)

        if not (question and first and second and safer_id in (0, 1)):
            skipped += 1
            continue

        if safer_id == 0:
            preferred, other = first, second
        else:
            preferred, other = second, first

        pairs.append(
            {
                "prompt": wrap_prompt(question),
                "chosen": preferred,
                "rejected": other,
            }
        )

    dpo_dataset = Dataset.from_list(pairs)
    dpo_dataset.save_to_disk(out_dir)

    print(f"[OK] DPO saved to: {out_dir}")
    print(f"rows: {len(pairs)} | dropped: {skipped}")
    print("columns:", dpo_dataset.column_names)
|
|
|
|
|
|
# -----------------------------------------------------------------------------
|
|
# ask_choice()
|
|
# Purpose:
|
|
# Simple CLI menu to select what to generate:
|
|
# 1) SFT only
|
|
# 2) DPO only
|
|
# 3) Both
|
|
# 0) Exit
|
|
#
|
|
# Behavior:
|
|
# - Loops until a valid choice is entered
|
|
# - Returns one of: "0", "1", "2", "3"
|
|
# -----------------------------------------------------------------------------
|
|
def ask_choice() -> str:
    """Show the generation menu and read a valid selection from stdin.

    Returns:
        One of ``"0"``, ``"1"``, ``"2"`` or ``"3"``. The prompt is repeated,
        with an error message, until the stripped input is one of those.
    """
    menu_lines = (
        "\nWhat do you want to generate?",
        " 1) SFT (jsonl: prompt+response)",
        " 2) DPO (dataset dir: prompt+chosen+rejected)",
        " 3) BOTH",
        " 0) Exit",
    )
    for line in menu_lines:
        print(line)

    accepted = ("0", "1", "2", "3")
    answer = input("Select [1/2/3/0]: ").strip()
    while answer not in accepted:
        print("Invalid choice. Please enter 1, 2, 3, or 0.")
        answer = input("Select [1/2/3/0]: ").strip()
    return answer
|
|
|
|
|
|
# -----------------------------------------------------------------------------
|
|
# main()
|
|
# Purpose:
|
|
# Entry point:
|
|
# - Loads the dataset from Hugging Face Hub
|
|
# - Presents an interactive menu
|
|
# - Generates SFT and/or DPO outputs under ./data
|
|
# -----------------------------------------------------------------------------
|
|
def main() -> None:
    """Interactive entry point: pick the outputs, then load and convert.

    The menu is shown before the dataset is loaded so that choosing
    ``"0"`` (Exit) returns immediately instead of first paying for a dataset
    load that would be thrown away.

    Depending on the selection, writes the SFT JSON Lines file
    (``OUT_SFT_JSONL``), the DPO dataset directory (``OUT_DPO_DIR``), or both.
    """
    choice = ask_choice()
    if choice == "0":
        print("Exit.")
        return

    # Load only once we know at least one output will be generated.
    print(f"Loading dataset: {DATASET_ID} split={SPLIT}")
    ds = load_dataset(DATASET_ID, split=SPLIT)

    # Ensure ./data exists (used by the output paths)
    os.makedirs("./data", exist_ok=True)

    # "3" means both, so it matches each membership check; SFT runs first.
    if choice in ("1", "3"):
        convert_sft(ds, OUT_SFT_JSONL)
    if choice in ("2", "3"):
        convert_dpo(ds, OUT_DPO_DIR)
|
|
|
|
|
|
# Run the interactive converter only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|