#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Convert the Slovak PKU-SafeRLHF dump into separate SFT and DPO datasets.

Reads a Hugging Face dataset from SRC_DIR, detects its column schema and
writes two derived datasets to disk:
  * SFT: a single 'text' column (prompt + chosen answer, simple template)
  * DPO: explicit 'prompt' / 'chosen' / 'rejected' columns
"""

import os
import json
from typing import Optional, Tuple, Union

from datasets import load_from_disk, DatasetDict

SRC_DIR = "/home/hyrenko/Diploma/datasets/PKU-SafeRLHF-30K_slk_Latn_SFT_DPO_ONLY"
OUT_DIR = "/home/hyrenko/Diploma/datasets"

SFT_OUT = os.path.join(OUT_DIR, "pku_saferlhf_sk_sft")
DPO_OUT = os.path.join(OUT_DIR, "pku_saferlhf_sk_dpo")

# Preference-column (winner indicator) names seen across PKU-SafeRLHF-style
# exports; hoisted to module level so find_schema() does not rebuild the list.
_PREF_CANDIDATES = (
    "choice", "label", "winner", "preferred", "preference",
    "better", "better_response", "better_response_id", "chosen_id",
    "win", "response_choice",
)


def find_schema(ds) -> Union[str, Tuple[str, Optional[str]]]:
    """Detect which column layout (schema) the dataset uses.

    Args:
        ds: a ``datasets.Dataset`` (anything exposing ``column_names``).

    Returns:
        ``"chosen_rejected"`` when the dataset already has explicit
        'prompt'/'chosen'/'rejected' columns, or the tuple
        ``("response01", pref_col)`` for the pairwise
        'response_0'/'response_1' layout, where ``pref_col`` is the first
        matching preference-column name (``None`` if no known candidate
        is present).

    Raises:
        RuntimeError: if neither known schema matches; the message includes
            the dataset's column list to ease debugging.
    """
    cols = set(ds.column_names)

    # Best case: explicit preference pairs are already materialized.
    if {"prompt", "chosen", "rejected"}.issubset(cols):
        return "chosen_rejected"

    # Pairwise layout: a winner indicator is needed to order the pair.
    if {"prompt", "response_0", "response_1"}.issubset(cols):
        pref = next((c for c in _PREF_CANDIDATES if c in cols), None)
        return ("response01", pref)

    raise RuntimeError(f"Unknown schema. Columns: {ds.column_names}")


def normalize_prompt(p: Optional[str]) -> str:
    """Return *p* stripped of surrounding whitespace; ``None`` becomes ''."""
    # Accepts None on purpose: upstream rows may carry missing prompts.
    return (p or "").strip()


def make_sft_text(prompt: Optional[str], answer: Optional[str]) -> str:
    """Build the single 'text' field used for SFT training.

    Uses a simple, stable template that is not model-specific (no chat
    tokens), so it works even for non-chat-instruct causal LMs:

        ### User:
        <prompt>

        ### Assistant:
        <answer>
    """
    prompt = normalize_prompt(prompt)
    answer = (answer or "").strip()
    return f"### User:\n{prompt}\n\n### Assistant:\n{answer}"
# -----------------------------------------------------------------------------
# Conversion pipeline: load -> detect schema -> build SFT/DPO -> save.
# -----------------------------------------------------------------------------
def resolve_preference(v, r0, r1):
    """Map a raw preference value onto the ``(chosen, rejected)`` pair.

    Handles the common winner encodings found in pairwise datasets:
      * bool: ``True`` means response_1 wins (dataset-dependent assumption
        inherited from the original heuristics — confirm per dataset)
      * int 0/1: index of the winning response
      * str: "0"/"response_0"/"r0"/"a" vs "1"/"response_1"/"r1"/"b"
        (case-insensitive, whitespace ignored)
    Anything unrecognized (including ``None``) falls back to ``(r0, r1)``.
    """
    # NOTE: bool must be tested BEFORE int -- bool is a subclass of int in
    # Python, so an `isinstance(v, int)` check placed first would always
    # consume True/False and leave the bool branch dead (the original had
    # exactly that unreachable branch; the outcome coincided by luck).
    if isinstance(v, bool):
        return (r1, r0) if v else (r0, r1)
    if isinstance(v, int):
        if v == 0:
            return r0, r1
        if v == 1:
            return r1, r0
    elif isinstance(v, str):
        vv = v.strip().lower()
        if vv in ("0", "response_0", "r0", "a"):
            return r0, r1
        if vv in ("1", "response_1", "r1", "b"):
            return r1, r0
    # Unrecognized encoding: keep the default mapping.
    return r0, r1


def build_from_chosen_rejected(ds):
    """Build ``(sft, dpo)`` from a dataset with explicit chosen/rejected columns."""

    def to_sft(ex):
        # Single 'text' column: prompt + chosen answer in the SFT template.
        return {"text": make_sft_text(ex["prompt"], ex["chosen"])}

    def to_dpo(ex):
        # Normalized prompt plus stripped preference pair.
        return {
            "prompt": normalize_prompt(ex["prompt"]),
            "chosen": (ex["chosen"] or "").strip(),
            "rejected": (ex["rejected"] or "").strip(),
        }

    # SFT keeps only 'text'; DPO keeps only prompt/chosen/rejected.
    sft = ds.map(to_sft, remove_columns=ds.column_names, desc="Build SFT")
    extra = [c for c in ds.column_names if c not in ("prompt", "chosen", "rejected")]
    dpo = ds.map(to_dpo, remove_columns=extra, desc="Build DPO")
    return sft, dpo


def build_from_pairwise(ds, pref_col):
    """Build ``(sft, dpo)`` from a response_0/response_1 pairwise dataset.

    When ``pref_col`` is None every row defaults to chosen=response_0,
    rejected=response_1.
    """
    if pref_col is None:
        print("[WARN] Preference column not found. Defaulting chosen=response_0, rejected=response_1 for ALL rows.")
    else:
        print("[INFO] Using preference column:", pref_col)

    def pick_chosen_rejected(ex):
        p = normalize_prompt(ex["prompt"])
        r0 = (ex["response_0"] or "").strip()
        r1 = (ex["response_1"] or "").strip()
        if pref_col is None:
            # No winner indicator: keep the default mapping for every row.
            chosen, rejected = r0, r1
        else:
            chosen, rejected = resolve_preference(ex.get(pref_col), r0, r1)
        return {"prompt": p, "chosen": chosen, "rejected": rejected}

    # First derive explicit prompt/chosen/rejected, then project both outputs.
    dpo_tmp = ds.map(pick_chosen_rejected, desc="Derive chosen/rejected")
    sft = dpo_tmp.map(
        lambda ex: {"text": make_sft_text(ex["prompt"], ex["chosen"])},
        remove_columns=dpo_tmp.column_names,
        desc="Build SFT",
    )
    dpo = dpo_tmp.select_columns(["prompt", "chosen", "rejected"])
    return sft, dpo


def main():
    """Run the full conversion: load, detect schema, build and save SFT/DPO.

    Does not validate data quality; it only transforms based on the
    detected schema.
    """
    # Load dataset stored on disk (Hugging Face datasets format).
    ds = load_from_disk(SRC_DIR)

    # If the dataset is split into train/valid/test, prefer 'train',
    # otherwise fall back to the first available split.
    if isinstance(ds, DatasetDict):
        split = "train" if "train" in ds else list(ds.keys())[0]
        ds = ds[split]

    schema = find_schema(ds)
    print("[INFO] Detected schema:", schema)

    if schema == "chosen_rejected":
        sft, dpo = build_from_chosen_rejected(ds)
    else:
        # schema is the tuple ("response01", pref_col) here.
        _, pref_col = schema
        sft, dpo = build_from_pairwise(ds, pref_col)

    # save_to_disk creates directories itself, but stay explicit for clarity.
    os.makedirs(SFT_OUT, exist_ok=True)
    os.makedirs(DPO_OUT, exist_ok=True)
    sft.save_to_disk(SFT_OUT)
    dpo.save_to_disk(DPO_OUT)

    # Report resulting paths and row counts.
    print("[OK] Saved SFT dataset:", SFT_OUT)
    print("[OK] Saved DPO dataset:", DPO_OUT)
    print("[INFO] SFT rows:", len(sft))
    print("[INFO] DPO rows:", len(dpo))


if __name__ == "__main__":
    # Standard entry-point guard: run main() only when executed directly.
    main()