diff --git a/Training/convert_dpo_sft.py b/Training/convert_dpo_sft.py index 4148e36..16f1712 100644 --- a/Training/convert_dpo_sft.py +++ b/Training/convert_dpo_sft.py @@ -5,9 +5,7 @@ import os import json from datasets import load_dataset, Dataset -# ========================= -# Config (как в твоих скриптах) -# ========================= + DATASET_ID = "PKU-Alignment/PKU-SafeRLHF-30K" SPLIT = "train" @@ -20,11 +18,47 @@ INSTR_PREFIX = "### Instruction:\n" RESP_PREFIX = "\n\n### Response:\n" +# ----------------------------------------------------------------------------- +# wrap_prompt(p) +# Purpose: +# Wrap a raw prompt string into a stable instruction/response template. +# +# Behavior: +# - Converts None to an empty string +# - Strips leading/trailing whitespace +# - Returns a formatted prompt with fixed prefixes +# +# Output format: +# ### Instruction: +# <prompt> +# +# ### Response: +# ----------------------------------------------------------------------------- def wrap_prompt(p: str) -> str: p = (p or "").strip() return f"{INSTR_PREFIX}{p}{RESP_PREFIX}" +# ----------------------------------------------------------------------------- +# convert_sft(ds, out_jsonl) +# Purpose: +# Convert the dataset into an SFT jsonl file, where each line is: +# {"prompt": <prompt>, "response": <response>} +# +# Expected input schema (per row): +# - prompt +# - response_0 +# - response_1 +# - safer_response_id (0 or 1 indicates which response is considered "safe") +# +# Filtering rules: +# - Drops rows if prompt/response_0/response_1 are empty, or safer_response_id not in (0, 1) +# +# Side effects: +# - Creates the output directory if needed +# - Writes JSON Lines to out_jsonl +# - Prints basic stats (saved rows, dropped rows) +# ----------------------------------------------------------------------------- def convert_sft(ds, out_jsonl: str) -> None: os.makedirs(os.path.dirname(out_jsonl) or ".", exist_ok=True) @@ -54,6 +88,29 @@ def convert_sft(ds, out_jsonl: str) -> None: print("example keys:", 
rows[0].keys()) +# ----------------------------------------------------------------------------- +# convert_dpo(ds, out_dir) +# Purpose: +# Convert the dataset into a DPO dataset directory (Hugging Face datasets format), +# with columns: +# - prompt (wrapped) +# - chosen +# - rejected +# +# Expected input schema (per row): +# - prompt +# - response_0 +# - response_1 +# - safer_response_id (0 or 1 indicates which response is "chosen") +# +# Filtering rules: +# - Drops rows if prompt/response_0/response_1 are empty, or safer_response_id not in (0, 1) +# +# Side effects: +# - Creates the output directory if needed +# - Saves the dataset to out_dir using save_to_disk() +# - Prints basic stats (saved rows, dropped rows, resulting columns) +# ----------------------------------------------------------------------------- def convert_dpo(ds, out_dir: str) -> None: os.makedirs(os.path.dirname(out_dir) or ".", exist_ok=True) @@ -83,6 +140,19 @@ def convert_dpo(ds, out_dir: str) -> None: print("columns:", dpo.column_names) +# ----------------------------------------------------------------------------- +# ask_choice() +# Purpose: +# Simple CLI menu to select what to generate: +# 1) SFT only +# 2) DPO only +# 3) Both +# 0) Exit +# +# Behavior: +# - Loops until a valid choice is entered +# - Returns one of: "0", "1", "2", "3" +# ----------------------------------------------------------------------------- def ask_choice() -> str: print("\nWhat do you want to generate?") print(" 1) SFT (jsonl: prompt+response)") @@ -96,6 +166,14 @@ def ask_choice() -> str: print("Invalid choice. 
Please enter 1, 2, 3, or 0.") +# ----------------------------------------------------------------------------- +# main() +# Purpose: +# Entry point: +# - Loads the dataset from Hugging Face Hub +# - Presents an interactive menu +# - Generates SFT and/or DPO outputs under ./data +# ----------------------------------------------------------------------------- def main(): print(f"Loading dataset: {DATASET_ID} split={SPLIT}") ds = load_dataset(DATASET_ID, split=SPLIT) @@ -105,7 +183,7 @@ def main(): print("Exit.") return - # Всегда создаём ./data как в оригинальных скриптах + # Ensure ./data exists (used by the output paths) os.makedirs("./data", exist_ok=True) if choice == "1":