#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Split the Slovak PKU-SafeRLHF dataset into separate SFT and DPO datasets."""

import os
import json
from datasets import load_from_disk, DatasetDict

SRC_DIR = "/home/hyrenko/Diploma/datasets/PKU-SafeRLHF-30K_slk_Latn_SFT_DPO_ONLY"
OUT_DIR = "/home/hyrenko/Diploma/datasets"

SFT_OUT = os.path.join(OUT_DIR, "pku_saferlhf_sk_sft")
DPO_OUT = os.path.join(OUT_DIR, "pku_saferlhf_sk_dpo")


def find_schema(ds):
    """Detect which column layout (schema) the dataset uses.

    Returns:
        "chosen_rejected" when the dataset already carries explicit
        'prompt'/'chosen'/'rejected' columns, or a tuple
        ("response01", pref_col) for the pairwise
        'response_0'/'response_1' layout — pref_col is the name of a
        detected preference (winner-indicator) column, or None when no
        known candidate name is present.

    Raises:
        RuntimeError: when neither layout matches; the message includes
        the dataset's column list to ease debugging.
    """
    cols = set(ds.column_names)

    # Best case: explicit chosen/rejected columns are already present.
    if {"prompt", "chosen", "rejected"} <= cols:
        return "chosen_rejected"

    # Pairwise variant: scan common names for the winner-indicator column.
    if {"prompt", "response_0", "response_1"} <= cols:
        candidates = (
            "choice", "label", "winner", "preferred", "preference",
            "better", "better_response", "better_response_id", "chosen_id",
            "win", "response_choice",
        )
        pref = next((name for name in candidates if name in cols), None)
        return ("response01", pref)

    raise RuntimeError(f"Unknown schema. Columns: {ds.column_names}")


def normalize_prompt(p: str) -> str:
    """Return the prompt stripped of surrounding whitespace; None becomes ""."""
    return "" if not p else p.strip()


def make_sft_text(prompt: str, answer: str) -> str:
    """Render one SFT training example as a single plain-text field.

    Output format::

        ### User:
        <prompt>

        ### Assistant:
        <answer>

    The template is deliberately generic (no model-specific chat tokens),
    so it works with any causal LM; both sides are whitespace-stripped.
    """
    body = (answer or "").strip()
    return f"### User:\n{normalize_prompt(prompt)}\n\n### Assistant:\n{body}"
-# ----------------------------------------------------------------------------- -def main(): - # Load dataset stored on disk (Hugging Face datasets format). - ds = load_from_disk(SRC_DIR) - - # If dataset is split into train/valid/test, pick 'train' when available. - if isinstance(ds, DatasetDict): - # take train if exists - split = "train" if "train" in ds else list(ds.keys())[0] - ds = ds[split] - - # Detect which column schema is present. - schema = find_schema(ds) - print("[INFO] Detected schema:", schema) - - # ------------------------------------------------------------------------- - # Branch A: Dataset already has explicit chosen/rejected columns. - # ------------------------------------------------------------------------- - if schema == "chosen_rejected": - # to_sft(ex) - # Converts one row into SFT format: - # - Uses ex["prompt"] and ex["chosen"] - # - Produces {"text": } - def to_sft(ex): - return {"text": make_sft_text(ex["prompt"], ex["chosen"])} - - # to_dpo(ex) - # Converts one row into DPO format: - # - Normalizes prompt - # - Strips chosen/rejected strings - # - Produces {"prompt": ..., "chosen": ..., "rejected": ...} - def to_dpo(ex): - return { - "prompt": normalize_prompt(ex["prompt"]), - "chosen": (ex["chosen"] or "").strip(), - "rejected": (ex["rejected"] or "").strip(), - } - - # Build SFT dataset: remove original columns and keep only 'text'. - sft = ds.map(to_sft, remove_columns=ds.column_names, desc="Build SFT") - - # Build DPO dataset: keep only prompt/chosen/rejected and remove everything else. - dpo = ds.map(to_dpo, remove_columns=[c for c in ds.column_names if c not in ("prompt", "chosen", "rejected")], desc="Build DPO") - - # ------------------------------------------------------------------------- - # Branch B: Dataset uses response_0/response_1 pairwise format (+ optional preference). 
- # ------------------------------------------------------------------------- - else: - # response_0/response_1 + maybe preference column - _, pref_col = schema - - # If no preference column exists, the script defaults: - # chosen = response_0 - # rejected = response_1 - # for every example. - if pref_col is None: - print("[WARN] Preference column not found. Defaulting chosen=response_0, rejected=response_1 for ALL rows.") - else: - print("[INFO] Using preference column:", pref_col) - - # pick_chosen_rejected(ex) - # Purpose: - # Converts pairwise responses into explicit chosen/rejected fields. - # - # Inputs: - # - prompt, response_0, response_1 - # - optionally pref_col (winner encoding varies between datasets) - # - # Preference heuristics handled here: - # - int 0/1 meaning which response wins - # - str labels like "response_0"/"response_1" or "A"/"B" - # - bool: assumes True => response_1 is better (dataset-dependent) - def pick_chosen_rejected(ex): - p = normalize_prompt(ex["prompt"]) - r0 = (ex["response_0"] or "").strip() - r1 = (ex["response_1"] or "").strip() - - # If no preference column, keep the default mapping. 
- if pref_col is None: - chosen, rejected = r0, r1 - else: - v = ex.get(pref_col) - # Heuristics for common encodings: - # - 0/1 meaning winner index - # - "response_0"/"response_1" - # - "A"/"B" - # - True means response_1 better (varies) - chosen, rejected = r0, r1 - - if isinstance(v, int): - if v == 0: - chosen, rejected = r0, r1 - elif v == 1: - chosen, rejected = r1, r0 - elif isinstance(v, str): - vv = v.strip().lower() - if vv in ("0", "response_0", "r0", "a"): - chosen, rejected = r0, r1 - elif vv in ("1", "response_1", "r1", "b"): - chosen, rejected = r1, r0 - elif isinstance(v, bool): - # ambiguous; assume True => response_1 wins - if v is True: - chosen, rejected = r1, r0 - else: - chosen, rejected = r0, r1 - - return {"prompt": p, "chosen": chosen, "rejected": rejected} - - # First, derive explicit prompt/chosen/rejected for each row. - dpo_tmp = ds.map(pick_chosen_rejected, desc="Derive chosen/rejected") - - # Build SFT dataset from the chosen responses (single 'text' column). - sft = dpo_tmp.map(lambda ex: {"text": make_sft_text(ex["prompt"], ex["chosen"])}, - remove_columns=dpo_tmp.column_names, desc="Build SFT") - - # Keep only required DPO columns. - dpo = dpo_tmp.select_columns(["prompt", "chosen", "rejected"]) - - # Ensure output directories exist. - os.makedirs(SFT_OUT, exist_ok=True) - os.makedirs(DPO_OUT, exist_ok=True) - - # Save datasets to disk. - sft.save_to_disk(SFT_OUT) - dpo.save_to_disk(DPO_OUT) - - # Print resulting paths and row counts. - print("[OK] Saved SFT dataset:", SFT_OUT) - print("[OK] Saved DPO dataset:", DPO_OUT) - print("[INFO] SFT rows:", len(sft)) - print("[INFO] DPO rows:", len(dpo)) - -if __name__ == "__main__": - # Standard Python entry point guard: run main() only when executed directly. - main()