#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
import json
from datasets import load_from_disk, DatasetDict

# Source dataset (Hugging Face `datasets` on-disk format) and output locations.
SRC_DIR = "/home/hyrenko/Diploma/datasets/PKU-SafeRLHF-30K_slk_Latn_SFT_DPO_ONLY"
OUT_DIR = "/home/hyrenko/Diploma/datasets"
SFT_OUT = os.path.join(OUT_DIR, "pku_saferlhf_sk_sft")
DPO_OUT = os.path.join(OUT_DIR, "pku_saferlhf_sk_dpo")


def find_schema(ds):
    """Detect which column layout (schema) the dataset uses.

    Args:
        ds: a dataset-like object exposing a ``column_names`` list.

    Returns:
        - the string ``"chosen_rejected"`` if the dataset already contains
          'prompt', 'chosen' and 'rejected' columns;
        - the tuple ``("response01", pref)`` if it contains 'prompt',
          'response_0' and 'response_1', where ``pref`` is the name of a
          detected preference (winner-indicator) column, or None if no
          known preference column was found.

    Raises:
        RuntimeError: if neither layout matches; the message includes the
            dataset's column list to ease debugging.
    """
    cols = set(ds.column_names)

    # Best case: explicit chosen/rejected columns already present.
    if {"prompt", "chosen", "rejected"}.issubset(cols):
        return "chosen_rejected"

    # Pairwise variant: need a preference column to pick the winner;
    # try a list of names commonly used by preference datasets.
    if {"prompt", "response_0", "response_1"}.issubset(cols):
        pref_candidates = [
            "choice", "label", "winner", "preferred", "preference",
            "better", "better_response", "better_response_id",
            "chosen_id", "win", "response_choice"
        ]
        pref = None
        for c in pref_candidates:
            if c in cols:
                pref = c
                break
        return ("response01", pref)

    # Fallback: unknown layout — fail loudly with the actual columns.
    raise RuntimeError(f"Unknown schema. Columns: {ds.column_names}")
def normalize_prompt(p: str) -> str:
    """Normalize prompt text to reduce formatting noise.

    Converts None to an empty string and strips leading/trailing whitespace.
    """
    return (p or "").strip()


def make_sft_text(prompt: str, answer: str) -> str:
    """Build a single "text" field for SFT training with a simple stable template.

    Output format::

        ### User:
        <prompt>

        ### Assistant:
        <answer>

    Notes:
        - Keeps formatting generic (not model-specific chat tokens); works
          even if the model is not chat-instruct tuned.
        - Normalizes the prompt and strips the answer.
    """
    prompt = normalize_prompt(prompt)
    answer = (answer or "").strip()
    return f"### User:\n{prompt}\n\n### Assistant:\n{answer}"


def main():
    """Run the full conversion pipeline.

    1) Loads the dataset from disk (SRC_DIR).
    2) If it is a DatasetDict, selects a split (prefers 'train').
    3) Detects the schema with find_schema().
    4) Builds:
       - SFT dataset: single column 'text' (prompt + chosen answer)
       - DPO dataset: columns 'prompt', 'chosen', 'rejected'
    5) Saves both datasets to disk (SFT_OUT, DPO_OUT).

    Note: this function does not validate data quality; it only transforms
    rows based on the detected schema.
    """
    # Load dataset stored on disk (Hugging Face datasets format).
    ds = load_from_disk(SRC_DIR)

    # If dataset is split into train/valid/test, pick 'train' when available.
    if isinstance(ds, DatasetDict):
        split = "train" if "train" in ds else list(ds.keys())[0]
        ds = ds[split]

    # Detect which column schema is present.
    schema = find_schema(ds)
    print("[INFO] Detected schema:", schema)

    # -------------------------------------------------------------------------
    # Branch A: Dataset already has explicit chosen/rejected columns.
    # -------------------------------------------------------------------------
    if schema == "chosen_rejected":

        def to_sft(ex):
            # One row -> {"text": <SFT training text from prompt + chosen>}
            return {"text": make_sft_text(ex["prompt"], ex["chosen"])}

        def to_dpo(ex):
            # One row -> normalized {"prompt", "chosen", "rejected"} triple.
            return {
                "prompt": normalize_prompt(ex["prompt"]),
                "chosen": (ex["chosen"] or "").strip(),
                "rejected": (ex["rejected"] or "").strip(),
            }

        # SFT dataset: drop original columns, keep only 'text'.
        sft = ds.map(to_sft, remove_columns=ds.column_names, desc="Build SFT")

        # DPO dataset: keep only prompt/chosen/rejected, remove everything else.
        dpo = ds.map(
            to_dpo,
            remove_columns=[c for c in ds.column_names
                            if c not in ("prompt", "chosen", "rejected")],
            desc="Build DPO",
        )

    # -------------------------------------------------------------------------
    # Branch B: response_0/response_1 pairwise format (+ optional preference).
    # -------------------------------------------------------------------------
    else:
        _, pref_col = schema

        # Without a preference column we default to chosen=response_0,
        # rejected=response_1 for every example.
        if pref_col is None:
            print("[WARN] Preference column not found. Defaulting chosen=response_0, rejected=response_1 for ALL rows.")
        else:
            print("[INFO] Using preference column:", pref_col)

        def pick_chosen_rejected(ex):
            """Convert pairwise responses into explicit chosen/rejected fields.

            Preference encodings handled (winner indicator varies by dataset):
            - bool: True => response_1 is better (dataset-dependent assumption)
            - int 0/1: index of the winning response
            - str labels: "response_0"/"response_1", "r0"/"r1", "a"/"b", "0"/"1"
            Anything unrecognized falls back to chosen=response_0.
            """
            p = normalize_prompt(ex["prompt"])
            r0 = (ex["response_0"] or "").strip()
            r1 = (ex["response_1"] or "").strip()

            # No preference column: keep the default mapping.
            if pref_col is None:
                return {"prompt": p, "chosen": r0, "rejected": r1}

            v = ex.get(pref_col)
            chosen, rejected = r0, r1

            # BUGFIX: bool must be tested BEFORE int — bool is a subclass of
            # int in Python, so the original bool branch was unreachable.
            if isinstance(v, bool):
                # Ambiguous encoding; assume True => response_1 wins.
                if v:
                    chosen, rejected = r1, r0
            elif isinstance(v, int):
                if v == 1:
                    chosen, rejected = r1, r0
            elif isinstance(v, str):
                vv = v.strip().lower()
                if vv in ("1", "response_1", "r1", "b"):
                    chosen, rejected = r1, r0

            return {"prompt": p, "chosen": chosen, "rejected": rejected}

        # First, derive explicit prompt/chosen/rejected for each row.
        dpo_tmp = ds.map(pick_chosen_rejected, desc="Derive chosen/rejected")

        # SFT dataset from the chosen responses (single 'text' column).
        sft = dpo_tmp.map(
            lambda ex: {"text": make_sft_text(ex["prompt"], ex["chosen"])},
            remove_columns=dpo_tmp.column_names,
            desc="Build SFT",
        )

        # Keep only the required DPO columns.
        dpo = dpo_tmp.select_columns(["prompt", "chosen", "rejected"])

    # Ensure output directories exist, then save both datasets.
    os.makedirs(SFT_OUT, exist_ok=True)
    os.makedirs(DPO_OUT, exist_ok=True)
    sft.save_to_disk(SFT_OUT)
    dpo.save_to_disk(DPO_OUT)

    # Report resulting paths and row counts.
    print("[OK] Saved SFT dataset:", SFT_OUT)
    print("[OK] Saved DPO dataset:", DPO_OUT)
    print("[INFO] SFT rows:", len(sft))
    print("[INFO] DPO rows:", len(dpo))


if __name__ == "__main__":
    # Standard Python entry point guard: run main() only when executed directly.
    main()