Added script that converts PKF
This commit is contained in:
parent
b1b83e313f
commit
4e68925da5
121
Training/convert_dpo_sft.py
Normal file
121
Training/convert_dpo_sft.py
Normal file
@ -0,0 +1,121 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import os
|
||||
import json
|
||||
from datasets import load_dataset, Dataset
|
||||
|
||||
# =========================
|
||||
# Config (как в твоих скриптах)
|
||||
# =========================
|
||||
DATASET_ID = "PKU-Alignment/PKU-SafeRLHF-30K"
|
||||
SPLIT = "train"
|
||||
|
||||
# Outputs
|
||||
OUT_SFT_JSONL = "./data/pku_sft.jsonl"
|
||||
OUT_DPO_DIR = "./data/pku_dpo"
|
||||
|
||||
# Prompt template
|
||||
INSTR_PREFIX = "### Instruction:\n"
|
||||
RESP_PREFIX = "\n\n### Response:\n"
|
||||
|
||||
|
||||
def wrap_prompt(p: str) -> str:
|
||||
p = (p or "").strip()
|
||||
return f"{INSTR_PREFIX}{p}{RESP_PREFIX}"
|
||||
|
||||
|
||||
def convert_sft(ds, out_jsonl: str) -> None:
|
||||
os.makedirs(os.path.dirname(out_jsonl) or ".", exist_ok=True)
|
||||
|
||||
rows = []
|
||||
dropped = 0
|
||||
|
||||
for ex in ds:
|
||||
prompt = (ex.get("prompt") or "").strip()
|
||||
r0 = (ex.get("response_0") or "").strip()
|
||||
r1 = (ex.get("response_1") or "").strip()
|
||||
sid = ex.get("safer_response_id", None)
|
||||
|
||||
if not prompt or not r0 or not r1 or sid not in (0, 1):
|
||||
dropped += 1
|
||||
continue
|
||||
|
||||
safe = r0 if sid == 0 else r1
|
||||
rows.append({"prompt": wrap_prompt(prompt), "response": safe})
|
||||
|
||||
with open(out_jsonl, "w", encoding="utf-8") as f:
|
||||
for r in rows:
|
||||
f.write(json.dumps(r, ensure_ascii=False) + "\n")
|
||||
|
||||
print(f"[OK] SFT saved: {out_jsonl}")
|
||||
print(f"rows: {len(rows)} | dropped: {dropped}")
|
||||
if rows:
|
||||
print("example keys:", rows[0].keys())
|
||||
|
||||
|
||||
def convert_dpo(ds, out_dir: str) -> None:
|
||||
os.makedirs(os.path.dirname(out_dir) or ".", exist_ok=True)
|
||||
|
||||
rows = []
|
||||
dropped = 0
|
||||
|
||||
for ex in ds:
|
||||
prompt = (ex.get("prompt") or "").strip()
|
||||
r0 = (ex.get("response_0") or "").strip()
|
||||
r1 = (ex.get("response_1") or "").strip()
|
||||
sid = ex.get("safer_response_id", None)
|
||||
|
||||
if not prompt or not r0 or not r1 or sid not in (0, 1):
|
||||
dropped += 1
|
||||
continue
|
||||
|
||||
chosen = r0 if sid == 0 else r1
|
||||
rejected = r1 if sid == 0 else r0
|
||||
|
||||
rows.append({"prompt": wrap_prompt(prompt), "chosen": chosen, "rejected": rejected})
|
||||
|
||||
dpo = Dataset.from_list(rows)
|
||||
dpo.save_to_disk(out_dir)
|
||||
|
||||
print(f"[OK] DPO saved to: {out_dir}")
|
||||
print(f"rows: {len(rows)} | dropped: {dropped}")
|
||||
print("columns:", dpo.column_names)
|
||||
|
||||
|
||||
def ask_choice() -> str:
|
||||
print("\nWhat do you want to generate?")
|
||||
print(" 1) SFT (jsonl: prompt+response)")
|
||||
print(" 2) DPO (dataset dir: prompt+chosen+rejected)")
|
||||
print(" 3) BOTH")
|
||||
print(" 0) Exit")
|
||||
while True:
|
||||
choice = input("Select [1/2/3/0]: ").strip()
|
||||
if choice in {"0", "1", "2", "3"}:
|
||||
return choice
|
||||
print("Invalid choice. Please enter 1, 2, 3, or 0.")
|
||||
|
||||
|
||||
def main():
|
||||
print(f"Loading dataset: {DATASET_ID} split={SPLIT}")
|
||||
ds = load_dataset(DATASET_ID, split=SPLIT)
|
||||
|
||||
choice = ask_choice()
|
||||
if choice == "0":
|
||||
print("Exit.")
|
||||
return
|
||||
|
||||
# Всегда создаём ./data как в оригинальных скриптах
|
||||
os.makedirs("./data", exist_ok=True)
|
||||
|
||||
if choice == "1":
|
||||
convert_sft(ds, OUT_SFT_JSONL)
|
||||
elif choice == "2":
|
||||
convert_dpo(ds, OUT_DPO_DIR)
|
||||
elif choice == "3":
|
||||
convert_sft(ds, OUT_SFT_JSONL)
|
||||
convert_dpo(ds, OUT_DPO_DIR)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Loading…
Reference in New Issue
Block a user