From eb4acde0a1c6ab7048d58d6eb7371d17cd946d28 Mon Sep 17 00:00:00 2001
From: Artur Hyrenko
Date: Wed, 4 Feb 2026 20:15:07 +0000
Subject: [PATCH] Added comments

---
 Preparation/prepar_dat_pku_dpo.py | 105 ++++++++++++++++++++++++++++++
 1 file changed, 105 insertions(+)

diff --git a/Preparation/prepar_dat_pku_dpo.py b/Preparation/prepar_dat_pku_dpo.py
index 2a4c0fb..299b199 100644
--- a/Preparation/prepar_dat_pku_dpo.py
+++ b/Preparation/prepar_dat_pku_dpo.py
@@ -11,6 +11,22 @@
 OUT_DIR = "/home/hyrenko/Diploma/datasets"
 SFT_OUT = os.path.join(OUT_DIR, "pku_saferlhf_sk_sft")
 DPO_OUT = os.path.join(OUT_DIR, "pku_saferlhf_sk_dpo")
 
+# -----------------------------------------------------------------------------
+# find_schema(ds)
+# Purpose:
+#   Detects which column layout (schema) the dataset uses, so the script knows
+#   how to build SFT and DPO outputs.
+#
+# What it checks:
+#   1) If dataset already contains 'prompt', 'chosen', 'rejected' columns:
+#      -> returns "chosen_rejected"
+#   2) If dataset contains 'prompt', 'response_0', 'response_1':
+#      -> tries to detect a preference column (winner indicator) among common names
+#      -> returns ("response01", <preference column name or None>)
+#
+# If nothing matches:
+#   Raises RuntimeError with the dataset column list, to make debugging easier.
+# -----------------------------------------------------------------------------
 def find_schema(ds):
     cols = set(ds.column_names)
@@ -36,10 +52,35 @@ def find_schema(ds):
     # fallback
     raise RuntimeError(f"Unknown schema. Columns: {ds.column_names}")
 
+# -----------------------------------------------------------------------------
+# normalize_prompt(p)
+# Purpose:
+#   Normalizes prompt text to reduce formatting noise.
+#
+# Behavior:
+#   - If p is None, converts it to empty string.
+#   - Strips leading/trailing whitespace.
+# -----------------------------------------------------------------------------
 def normalize_prompt(p: str) -> str:
     p = (p or "").strip()
     return p
 
+# -----------------------------------------------------------------------------
+# make_sft_text(prompt, answer)
+# Purpose:
+#   Builds a single "text" field for SFT training using a simple stable template.
+#
+# Output format:
+#   ### User:
+#   <prompt>
+#
+#   ### Assistant:
+#   <answer>
+#
+# Notes:
+#   - Keeps formatting generic (not model-specific chat tokens).
+#   - Calls normalize_prompt() for prompt and strips answer.
+# -----------------------------------------------------------------------------
 def make_sft_text(prompt: str, answer: str) -> str:
     # Simple, stable template for causal LM (works even if model is not chat-instruct)
     # You can change to Slovak system-style later.
@@ -47,20 +88,51 @@ def make_sft_text(prompt: str, answer: str) -> str:
     answer = (answer or "").strip()
     return f"### User:\n{prompt}\n\n### Assistant:\n{answer}"
 
+# -----------------------------------------------------------------------------
+# main()
+# Purpose:
+#   Orchestrates the full conversion pipeline:
+#     1) Loads dataset from disk (SRC_DIR).
+#     2) If it's a DatasetDict, selects a split (prefers 'train').
+#     3) Detects schema with find_schema().
+#     4) Builds:
+#        - SFT dataset: single column 'text' (prompt + chosen answer)
+#        - DPO dataset: columns 'prompt', 'chosen', 'rejected'
+#     5) Saves both datasets to disk (SFT_OUT, DPO_OUT).
+#
+# Important:
+#   - This function does not validate data quality; it just transforms based on schema.
+# -----------------------------------------------------------------------------
 def main():
+    # Load dataset stored on disk (Hugging Face datasets format).
     ds = load_from_disk(SRC_DIR)
+
+    # If dataset is split into train/valid/test, pick 'train' when available.
     if isinstance(ds, DatasetDict):
         # take train if exists
         split = "train" if "train" in ds else list(ds.keys())[0]
         ds = ds[split]
 
+    # Detect which column schema is present.
     schema = find_schema(ds)
     print("[INFO] Detected schema:", schema)
 
+    # -------------------------------------------------------------------------
+    # Branch A: Dataset already has explicit chosen/rejected columns.
+    # -------------------------------------------------------------------------
     if schema == "chosen_rejected":
+        # to_sft(ex)
+        #   Converts one row into SFT format:
+        #   - Uses ex["prompt"] and ex["chosen"]
+        #   - Produces {"text": <formatted training text>}
         def to_sft(ex):
             return {"text": make_sft_text(ex["prompt"], ex["chosen"])}
 
+        # to_dpo(ex)
+        #   Converts one row into DPO format:
+        #   - Normalizes prompt
+        #   - Strips chosen/rejected strings
+        #   - Produces {"prompt": ..., "chosen": ..., "rejected": ...}
         def to_dpo(ex):
             return {
                 "prompt": normalize_prompt(ex["prompt"]),
@@ -68,22 +140,46 @@ def main():
                 "rejected": (ex["rejected"] or "").strip(),
             }
 
+        # Build SFT dataset: remove original columns and keep only 'text'.
         sft = ds.map(to_sft, remove_columns=ds.column_names, desc="Build SFT")
+
+        # Build DPO dataset: keep only prompt/chosen/rejected and remove everything else.
         dpo = ds.map(to_dpo,
                      remove_columns=[c for c in ds.column_names if c not in ("prompt", "chosen", "rejected")],
                      desc="Build DPO")
 
+    # -------------------------------------------------------------------------
+    # Branch B: Dataset uses response_0/response_1 pairwise format (+ optional preference).
+    # -------------------------------------------------------------------------
     else:
         # response_0/response_1 + maybe preference column
         _, pref_col = schema
+
+        # If no preference column exists, the script defaults:
+        #   chosen   = response_0
+        #   rejected = response_1
+        # for every example.
         if pref_col is None:
             print("[WARN] Preference column not found. Defaulting chosen=response_0, rejected=response_1 for ALL rows.")
         else:
             print("[INFO] Using preference column:", pref_col)
 
+        # pick_chosen_rejected(ex)
+        # Purpose:
+        #   Converts pairwise responses into explicit chosen/rejected fields.
+        #
+        # Inputs:
+        #   - prompt, response_0, response_1
+        #   - optionally pref_col (winner encoding varies between datasets)
+        #
+        # Preference heuristics handled here:
+        #   - int 0/1 meaning which response wins
+        #   - str labels like "response_0"/"response_1" or "A"/"B"
+        #   - bool: assumes True => response_1 is better (dataset-dependent)
         def pick_chosen_rejected(ex):
             p = normalize_prompt(ex["prompt"])
             r0 = (ex["response_0"] or "").strip()
             r1 = (ex["response_1"] or "").strip()
 
+            # If no preference column, keep the default mapping.
             if pref_col is None:
                 chosen, rejected = r0, r1
             else:
@@ -115,21 +211,30 @@ def main():
             return {"prompt": p, "chosen": chosen, "rejected": rejected}
 
+        # First, derive explicit prompt/chosen/rejected for each row.
         dpo_tmp = ds.map(pick_chosen_rejected, desc="Derive chosen/rejected")
+
+        # Build SFT dataset from the chosen responses (single 'text' column).
         sft = dpo_tmp.map(lambda ex: {"text": make_sft_text(ex["prompt"], ex["chosen"])},
                           remove_columns=dpo_tmp.column_names, desc="Build SFT")
+
+        # Keep only required DPO columns.
         dpo = dpo_tmp.select_columns(["prompt", "chosen", "rejected"])
 
+    # Ensure output directories exist.
     os.makedirs(SFT_OUT, exist_ok=True)
     os.makedirs(DPO_OUT, exist_ok=True)
 
+    # Save datasets to disk.
     sft.save_to_disk(SFT_OUT)
     dpo.save_to_disk(DPO_OUT)
 
+    # Print resulting paths and row counts.
     print("[OK] Saved SFT dataset:", SFT_OUT)
     print("[OK] Saved DPO dataset:", DPO_OUT)
     print("[INFO] SFT rows:", len(sft))
     print("[INFO] DPO rows:", len(dpo))
 
 if __name__ == "__main__":
+    # Standard Python entry point guard: run main() only when executed directly.
     main()