Added script that converts PKU-SafeRLHF to SFT/DPO format

This commit is contained in:
Artur Hyrenko 2026-02-02 15:07:53 +01:00
parent b1b83e313f
commit 4e68925da5

121
Training/convert_dpo_sft.py Normal file
View File

@ -0,0 +1,121 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
import json
from datasets import load_dataset, Dataset
# =========================
# Config (same as in your other scripts)
# =========================
# Dataset identifier and split passed to datasets.load_dataset() in main().
DATASET_ID = "PKU-Alignment/PKU-SafeRLHF-30K"
SPLIT = "train"
# Outputs
# Path of the JSONL file written by convert_sft().
OUT_SFT_JSONL = "./data/pku_sft.jsonl"
# Directory handed to Dataset.save_to_disk() by convert_dpo().
OUT_DPO_DIR = "./data/pku_dpo"
# Prompt template
# wrap_prompt() places the stripped prompt between these two markers.
INSTR_PREFIX = "### Instruction:\n"
RESP_PREFIX = "\n\n### Response:\n"
def wrap_prompt(p: str) -> str:
    """Wrap a raw prompt in the instruction/response template.

    A falsy prompt (``None`` or ``""``) is treated as an empty string;
    otherwise leading and trailing whitespace is stripped before the
    text is placed between ``INSTR_PREFIX`` and ``RESP_PREFIX``.
    """
    body = p.strip() if p else ""
    return INSTR_PREFIX + body + RESP_PREFIX
def convert_sft(ds, out_jsonl: str) -> None:
    """Export prompt/response pairs for SFT as a JSON Lines file.

    Each example is read via ``.get`` for the keys ``prompt``,
    ``response_0``, ``response_1`` and ``safer_response_id``.  An example
    is dropped when any of the three texts is empty after stripping or
    when ``safer_response_id`` is not 0 or 1.  For kept examples the
    response selected by ``safer_response_id`` is paired with the
    template-wrapped prompt.

    Args:
        ds: Iterable of mapping-like examples.
        out_jsonl: Destination file path; its parent directory is created
            if missing.
    """
    parent_dir = os.path.dirname(out_jsonl)
    os.makedirs(parent_dir if parent_dir else ".", exist_ok=True)

    records = []
    skipped = 0
    for example in ds:
        question = (example.get("prompt") or "").strip()
        first = (example.get("response_0") or "").strip()
        second = (example.get("response_1") or "").strip()
        safer_id = example.get("safer_response_id", None)

        usable = bool(question) and bool(first) and bool(second) and safer_id in (0, 1)
        if usable:
            picked = first if safer_id == 0 else second
            records.append({"prompt": wrap_prompt(question), "response": picked})
        else:
            skipped += 1

    # All rows are collected first, then written one JSON object per line.
    with open(out_jsonl, "w", encoding="utf-8") as sink:
        sink.writelines(
            json.dumps(record, ensure_ascii=False) + "\n" for record in records
        )

    print(f"[OK] SFT saved: {out_jsonl}")
    print(f"rows: {len(records)} | dropped: {skipped}")
    if records:
        print("example keys:", records[0].keys())
def convert_dpo(ds, out_dir: str) -> None:
    """Export prompt/chosen/rejected triples for DPO and save them to disk.

    Examples are filtered exactly as described by the guard below: the
    stripped ``prompt``, ``response_0`` and ``response_1`` must all be
    non-empty and ``safer_response_id`` must be 0 or 1.  The response
    selected by ``safer_response_id`` becomes ``chosen`` and the other one
    becomes ``rejected``.  The rows are turned into a ``Dataset`` with
    ``Dataset.from_list`` and written with ``save_to_disk``.

    Args:
        ds: Iterable of mapping-like examples.
        out_dir: Directory passed to ``save_to_disk``; its parent directory
            is created if missing.
    """
    parent_dir = os.path.dirname(out_dir)
    os.makedirs(parent_dir if parent_dir else ".", exist_ok=True)

    pairs = []
    skipped = 0
    for example in ds:
        question = (example.get("prompt") or "").strip()
        first = (example.get("response_0") or "").strip()
        second = (example.get("response_1") or "").strip()
        safer_id = example.get("safer_response_id", None)

        if not (question and first and second and safer_id in (0, 1)):
            skipped += 1
            continue

        # The safer response is preferred; the remaining one is rejected.
        if safer_id == 0:
            preferred, other = first, second
        else:
            preferred, other = second, first
        pairs.append(
            {"prompt": wrap_prompt(question), "chosen": preferred, "rejected": other}
        )

    dpo = Dataset.from_list(pairs)
    dpo.save_to_disk(out_dir)

    print(f"[OK] DPO saved to: {out_dir}")
    print(f"rows: {len(pairs)} | dropped: {skipped}")
    print("columns:", dpo.column_names)
def ask_choice() -> str:
    """Show the interactive menu and return the selected option.

    Re-prompts until the user enters one of "0", "1", "2" or "3"
    (surrounding whitespace is ignored).  If standard input is closed or
    exhausted (EOF), "0" is returned so the caller can exit cleanly
    instead of crashing with an unhandled ``EOFError``.

    Returns:
        The chosen option as a one-character string.
    """
    print("\nWhat do you want to generate?")
    print(" 1) SFT (jsonl: prompt+response)")
    print(" 2) DPO (dataset dir: prompt+chosen+rejected)")
    print(" 3) BOTH")
    print(" 0) Exit")
    valid_choices = {"0", "1", "2", "3"}
    while True:
        try:
            choice = input("Select [1/2/3/0]: ").strip()
        except EOFError:
            # No more input will ever arrive (piped/closed stdin); looping
            # again would only raise again, so treat it as "Exit".
            print("\nNo input available (EOF). Exiting.")
            return "0"
        if choice in valid_choices:
            return choice
        print("Invalid choice. Please enter 1, 2, 3, or 0.")
def main():
    """Interactive entry point: ask what to generate, then run the conversion(s).

    The menu is shown before the dataset is loaded, so choosing "0" (Exit)
    returns immediately without loading the dataset.  Choice "1" writes
    the SFT JSONL file, "2" writes the DPO dataset directory, and "3"
    does both (SFT first, then DPO).
    """
    choice = ask_choice()
    if choice == "0":
        print("Exit.")
        return

    # Load only once we know there is work to do.
    print(f"Loading dataset: {DATASET_ID} split={SPLIT}")
    ds = load_dataset(DATASET_ID, split=SPLIT)

    # Always create ./data, as in the original scripts
    os.makedirs("./data", exist_ok=True)

    if choice in ("1", "3"):
        convert_sft(ds, OUT_SFT_JSONL)
    if choice in ("2", "3"):
        convert_dpo(ds, OUT_DPO_DIR)


if __name__ == "__main__":
    main()