Added comments

This commit is contained in:
Artur Hyrenko 2026-02-04 20:34:11 +00:00
parent cf50efba3c
commit 18cdd5daa6

View File

@ -5,9 +5,7 @@ import os
import json
from datasets import load_dataset, Dataset
# =========================
# Config (same as in your scripts)
# =========================
DATASET_ID = "PKU-Alignment/PKU-SafeRLHF-30K"
SPLIT = "train"
@ -20,11 +18,47 @@ INSTR_PREFIX = "### Instruction:\n"
RESP_PREFIX = "\n\n### Response:\n"
# -----------------------------------------------------------------------------
# wrap_prompt(p)
# Purpose:
# Wrap a raw prompt string into a stable instruction/response template.
#
# Behavior:
# - Converts None to an empty string
# - Strips leading/trailing whitespace
# - Returns a formatted prompt with fixed prefixes
#
# Output format:
# ### Instruction:
# <prompt>
#
# ### Response:
# -----------------------------------------------------------------------------
def wrap_prompt(p: str) -> str:
    """Wrap a raw prompt in the fixed instruction/response template.

    A falsy prompt (for example ``None`` or ``""``) is treated as an empty
    string; otherwise leading and trailing whitespace is stripped. The
    cleaned text is placed between ``INSTR_PREFIX`` and ``RESP_PREFIX``.
    """
    # Normalise first so the template always receives a stripped value.
    cleaned = p.strip() if p else ""
    return "{}{}{}".format(INSTR_PREFIX, cleaned, RESP_PREFIX)
# -----------------------------------------------------------------------------
# convert_sft(ds, out_jsonl)
# Purpose:
# Convert the dataset into an SFT jsonl file, where each line is:
# {"prompt": <wrapped_prompt>, "response": <safe_response>}
#
# Expected input schema (per row):
# - prompt
# - response_0
# - response_1
# - safer_response_id (0 or 1 indicates which response is considered "safe")
#
# Filtering rules:
# - Drops rows if prompt/response_0/response_1 are empty, or safer_response_id not in (0, 1)
#
# Side effects:
# - Creates the output directory if needed
# - Writes JSON Lines to out_jsonl
# - Prints basic stats (saved rows, dropped rows)
# -----------------------------------------------------------------------------
def convert_sft(ds, out_jsonl: str) -> None:
os.makedirs(os.path.dirname(out_jsonl) or ".", exist_ok=True)
@ -54,6 +88,29 @@ def convert_sft(ds, out_jsonl: str) -> None:
print("example keys:", rows[0].keys())
# -----------------------------------------------------------------------------
# convert_dpo(ds, out_dir)
# Purpose:
# Convert the dataset into a DPO dataset directory (Hugging Face datasets format),
# with columns:
# - prompt (wrapped)
# - chosen
# - rejected
#
# Expected input schema (per row):
# - prompt
# - response_0
# - response_1
# - safer_response_id (0 or 1 indicates which response is "chosen")
#
# Filtering rules:
# - Drops rows if prompt/response_0/response_1 are empty, or safer_response_id not in (0, 1)
#
# Side effects:
# - Creates the output directory if needed
# - Saves the dataset to out_dir using save_to_disk()
# - Prints basic stats (saved rows, dropped rows, resulting columns)
# -----------------------------------------------------------------------------
def convert_dpo(ds, out_dir: str) -> None:
os.makedirs(os.path.dirname(out_dir) or ".", exist_ok=True)
@ -83,6 +140,19 @@ def convert_dpo(ds, out_dir: str) -> None:
print("columns:", dpo.column_names)
# -----------------------------------------------------------------------------
# ask_choice()
# Purpose:
# Simple CLI menu to select what to generate:
# 1) SFT only
# 2) DPO only
# 3) Both
# 0) Exit
#
# Behavior:
# - Loops until a valid choice is entered
# - Returns one of: "0", "1", "2", "3"
# -----------------------------------------------------------------------------
def ask_choice() -> str:
print("\nWhat do you want to generate?")
print(" 1) SFT (jsonl: prompt+response)")
@ -96,6 +166,14 @@ def ask_choice() -> str:
print("Invalid choice. Please enter 1, 2, 3, or 0.")
# -----------------------------------------------------------------------------
# main()
# Purpose:
# Entry point:
# - Loads the dataset from Hugging Face Hub
# - Presents an interactive menu
# - Generates SFT and/or DPO outputs under ./data
# -----------------------------------------------------------------------------
def main():
print(f"Loading dataset: {DATASET_ID} split={SPLIT}")
ds = load_dataset(DATASET_ID, split=SPLIT)
@ -105,7 +183,7 @@ def main():
print("Exit.")
return
# Всегда создаём ./data как в оригинальных скриптах
# Ensure ./data exists (used by the output paths)
os.makedirs("./data", exist_ok=True)
if choice == "1":