Added comments
This commit is contained in:
parent
cf50efba3c
commit
18cdd5daa6
@ -5,9 +5,7 @@ import os
|
||||
import json
|
||||
from datasets import load_dataset, Dataset
|
||||
|
||||
# =========================
|
||||
# Config (same as in your scripts)
|
||||
# =========================
|
||||
|
||||
DATASET_ID = "PKU-Alignment/PKU-SafeRLHF-30K"
|
||||
SPLIT = "train"
|
||||
|
||||
@ -20,11 +18,47 @@ INSTR_PREFIX = "### Instruction:\n"
|
||||
RESP_PREFIX = "\n\n### Response:\n"
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# wrap_prompt(p)
|
||||
# Purpose:
|
||||
# Wrap a raw prompt string into a stable instruction/response template.
|
||||
#
|
||||
# Behavior:
|
||||
# - Converts None to an empty string
|
||||
# - Strips leading/trailing whitespace
|
||||
# - Returns a formatted prompt with fixed prefixes
|
||||
#
|
||||
# Output format:
|
||||
# ### Instruction:
|
||||
# <prompt>
|
||||
#
|
||||
# ### Response:
|
||||
# -----------------------------------------------------------------------------
|
||||
def wrap_prompt(p: str) -> str:
    """Return the prompt wrapped in the instruction/response template.

    A falsy prompt (``None`` or an empty string) is rendered as empty text;
    otherwise leading and trailing whitespace is stripped before the fixed
    instruction and response prefixes are placed around it.
    """
    if not p:
        body = ""
    else:
        body = p.strip()
    return f"{INSTR_PREFIX}{body}{RESP_PREFIX}"
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# convert_sft(ds, out_jsonl)
|
||||
# Purpose:
|
||||
# Convert the dataset into an SFT jsonl file, where each line is:
|
||||
# {"prompt": <wrapped_prompt>, "response": <safe_response>}
|
||||
#
|
||||
# Expected input schema (per row):
|
||||
# - prompt
|
||||
# - response_0
|
||||
# - response_1
|
||||
# - safer_response_id (0 or 1 indicates which response is considered "safe")
|
||||
#
|
||||
# Filtering rules:
|
||||
# - Drops rows if prompt/response_0/response_1 are empty, or safer_response_id not in (0, 1)
|
||||
#
|
||||
# Side effects:
|
||||
# - Creates the output directory if needed
|
||||
# - Writes JSON Lines to out_jsonl
|
||||
# - Prints basic stats (saved rows, dropped rows)
|
||||
# -----------------------------------------------------------------------------
|
||||
def convert_sft(ds, out_jsonl: str) -> None:
|
||||
os.makedirs(os.path.dirname(out_jsonl) or ".", exist_ok=True)
|
||||
|
||||
@ -54,6 +88,29 @@ def convert_sft(ds, out_jsonl: str) -> None:
|
||||
print("example keys:", rows[0].keys())
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# convert_dpo(ds, out_dir)
|
||||
# Purpose:
|
||||
# Convert the dataset into a DPO dataset directory (Hugging Face datasets format),
|
||||
# with columns:
|
||||
# - prompt (wrapped)
|
||||
# - chosen
|
||||
# - rejected
|
||||
#
|
||||
# Expected input schema (per row):
|
||||
# - prompt
|
||||
# - response_0
|
||||
# - response_1
|
||||
# - safer_response_id (0 or 1 indicates which response is "chosen")
|
||||
#
|
||||
# Filtering rules:
|
||||
# - Drops rows if prompt/response_0/response_1 are empty, or safer_response_id not in (0, 1)
|
||||
#
|
||||
# Side effects:
|
||||
# - Creates the output directory if needed
|
||||
# - Saves the dataset to out_dir using save_to_disk()
|
||||
# - Prints basic stats (saved rows, dropped rows, resulting columns)
|
||||
# -----------------------------------------------------------------------------
|
||||
def convert_dpo(ds, out_dir: str) -> None:
|
||||
os.makedirs(os.path.dirname(out_dir) or ".", exist_ok=True)
|
||||
|
||||
@ -83,6 +140,19 @@ def convert_dpo(ds, out_dir: str) -> None:
|
||||
print("columns:", dpo.column_names)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# ask_choice()
|
||||
# Purpose:
|
||||
# Simple CLI menu to select what to generate:
|
||||
# 1) SFT only
|
||||
# 2) DPO only
|
||||
# 3) Both
|
||||
# 0) Exit
|
||||
#
|
||||
# Behavior:
|
||||
# - Loops until a valid choice is entered
|
||||
# - Returns one of: "0", "1", "2", "3"
|
||||
# -----------------------------------------------------------------------------
|
||||
def ask_choice() -> str:
|
||||
print("\nWhat do you want to generate?")
|
||||
print(" 1) SFT (jsonl: prompt+response)")
|
||||
@ -96,6 +166,14 @@ def ask_choice() -> str:
|
||||
print("Invalid choice. Please enter 1, 2, 3, or 0.")
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# main()
|
||||
# Purpose:
|
||||
# Entry point:
|
||||
# - Loads the dataset from Hugging Face Hub
|
||||
# - Presents an interactive menu
|
||||
# - Generates SFT and/or DPO outputs under ./data
|
||||
# -----------------------------------------------------------------------------
|
||||
def main():
|
||||
print(f"Loading dataset: {DATASET_ID} split={SPLIT}")
|
||||
ds = load_dataset(DATASET_ID, split=SPLIT)
|
||||
@ -105,7 +183,7 @@ def main():
|
||||
print("Exit.")
|
||||
return
|
||||
|
||||
# Всегда создаём ./data как в оригинальных скриптах
|
||||
# Ensure ./data exists (used by the output paths)
|
||||
os.makedirs("./data", exist_ok=True)
|
||||
|
||||
if choice == "1":
|
||||
|
||||
Loading…
Reference in New Issue
Block a user