diff --git a/Training/convert_dpo_sft.py b/Training/convert_dpo_sft.py new file mode 100644 index 0000000..4148e36 --- /dev/null +++ b/Training/convert_dpo_sft.py @@ -0,0 +1,121 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +import os +import json +from datasets import load_dataset, Dataset + +# ========================= +# Config (как в твоих скриптах) +# ========================= +DATASET_ID = "PKU-Alignment/PKU-SafeRLHF-30K" +SPLIT = "train" + +# Outputs +OUT_SFT_JSONL = "./data/pku_sft.jsonl" +OUT_DPO_DIR = "./data/pku_dpo" + +# Prompt template +INSTR_PREFIX = "### Instruction:\n" +RESP_PREFIX = "\n\n### Response:\n" + + +def wrap_prompt(p: str) -> str: + p = (p or "").strip() + return f"{INSTR_PREFIX}{p}{RESP_PREFIX}" + + +def convert_sft(ds, out_jsonl: str) -> None: + os.makedirs(os.path.dirname(out_jsonl) or ".", exist_ok=True) + + rows = [] + dropped = 0 + + for ex in ds: + prompt = (ex.get("prompt") or "").strip() + r0 = (ex.get("response_0") or "").strip() + r1 = (ex.get("response_1") or "").strip() + sid = ex.get("safer_response_id", None) + + if not prompt or not r0 or not r1 or sid not in (0, 1): + dropped += 1 + continue + + safe = r0 if sid == 0 else r1 + rows.append({"prompt": wrap_prompt(prompt), "response": safe}) + + with open(out_jsonl, "w", encoding="utf-8") as f: + for r in rows: + f.write(json.dumps(r, ensure_ascii=False) + "\n") + + print(f"[OK] SFT saved: {out_jsonl}") + print(f"rows: {len(rows)} | dropped: {dropped}") + if rows: + print("example keys:", rows[0].keys()) + + +def convert_dpo(ds, out_dir: str) -> None: + os.makedirs(os.path.dirname(out_dir) or ".", exist_ok=True) + + rows = [] + dropped = 0 + + for ex in ds: + prompt = (ex.get("prompt") or "").strip() + r0 = (ex.get("response_0") or "").strip() + r1 = (ex.get("response_1") or "").strip() + sid = ex.get("safer_response_id", None) + + if not prompt or not r0 or not r1 or sid not in (0, 1): + dropped += 1 + continue + + chosen = r0 if sid == 0 else r1 + rejected = r1 if sid == 0 else r0 + + rows.append({"prompt": wrap_prompt(prompt), "chosen": chosen, "rejected": rejected}) + + dpo = Dataset.from_list(rows) + dpo.save_to_disk(out_dir) + + print(f"[OK] DPO saved to: {out_dir}") + print(f"rows: {len(rows)} | dropped: {dropped}") + print("columns:", dpo.column_names) + + +def ask_choice() -> str: + print("\nWhat do you want to generate?") + print(" 1) SFT (jsonl: prompt+response)") + print(" 2) DPO (dataset dir: prompt+chosen+rejected)") + print(" 3) BOTH") + print(" 0) Exit") + while True: + choice = input("Select [1/2/3/0]: ").strip() + if choice in {"0", "1", "2", "3"}: + return choice + print("Invalid choice. Please enter 1, 2, 3, or 0.") + + +def main(): + print(f"Loading dataset: {DATASET_ID} split={SPLIT}") + ds = load_dataset(DATASET_ID, split=SPLIT) + + choice = ask_choice() + if choice == "0": + print("Exit.") + return + + # Всегда создаём ./data как в оригинальных скриптах + os.makedirs("./data", exist_ok=True) + + if choice == "1": + convert_sft(ds, OUT_SFT_JSONL) + elif choice == "2": + convert_dpo(ds, OUT_DPO_DIR) + elif choice == "3": + convert_sft(ds, OUT_SFT_JSONL) + convert_dpo(ds, OUT_DPO_DIR) + + +if __name__ == "__main__": + main()