Delete Preparation/prepar_dat_pku_dpo.py

This commit is contained in:
Artur Hyrenko 2026-02-04 20:17:10 +00:00
parent b72f8b9e42
commit 10f4002231

View File

@ -1,240 +0,0 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
import json
from datasets import load_from_disk, DatasetDict
SRC_DIR = "/home/hyrenko/Diploma/datasets/PKU-SafeRLHF-30K_slk_Latn_SFT_DPO_ONLY"
OUT_DIR = "/home/hyrenko/Diploma/datasets"
SFT_OUT = os.path.join(OUT_DIR, "pku_saferlhf_sk_sft")
DPO_OUT = os.path.join(OUT_DIR, "pku_saferlhf_sk_dpo")
# -----------------------------------------------------------------------------
# find_schema(ds)
# Purpose:
# Detects which column layout (schema) the dataset uses, so the script knows
# how to build SFT and DPO outputs.
#
# What it checks:
# 1) If dataset already contains 'prompt', 'chosen', 'rejected' columns:
# -> returns "chosen_rejected"
# 2) If dataset contains 'prompt', 'response_0', 'response_1':
# -> tries to detect a preference column (winner indicator) among common names
# -> returns ("response01", <pref_column_or_None>)
#
# If nothing matches:
# Raises RuntimeError with the dataset column list, to make debugging easier.
# -----------------------------------------------------------------------------
def find_schema(ds):
cols = set(ds.column_names)
# Best case
if {"prompt", "chosen", "rejected"}.issubset(cols):
return "chosen_rejected"
# Pairwise variants
if {"prompt", "response_0", "response_1"}.issubset(cols):
# need preference column to choose winner; try common names
pref_candidates = [
"choice", "label", "winner", "preferred", "preference",
"better", "better_response", "better_response_id", "chosen_id",
"win", "response_choice"
]
pref = None
for c in pref_candidates:
if c in cols:
pref = c
break
return ("response01", pref)
# fallback
raise RuntimeError(f"Unknown schema. Columns: {ds.column_names}")
# -----------------------------------------------------------------------------
# normalize_prompt(p)
# Purpose:
# Normalizes prompt text to reduce formatting noise.
#
# Behavior:
# - If p is None, converts it to empty string.
# - Strips leading/trailing whitespace.
# -----------------------------------------------------------------------------
def normalize_prompt(p: str) -> str:
p = (p or "").strip()
return p
# -----------------------------------------------------------------------------
# make_sft_text(prompt, answer)
# Purpose:
# Builds a single "text" field for SFT training using a simple stable template.
#
# Output format:
# ### User:
# <prompt>
#
# ### Assistant:
# <answer>
#
# Notes:
# - Keeps formatting generic (not model-specific chat tokens).
# - Calls normalize_prompt() for prompt and strips answer.
# -----------------------------------------------------------------------------
def make_sft_text(prompt: str, answer: str) -> str:
# Simple, stable template for causal LM (works even if model is not chat-instruct)
# You can change to Slovak system-style later.
prompt = normalize_prompt(prompt)
answer = (answer or "").strip()
return f"### User:\n{prompt}\n\n### Assistant:\n{answer}"
# -----------------------------------------------------------------------------
# main()
# Purpose:
# Orchestrates the full conversion pipeline:
# 1) Loads dataset from disk (SRC_DIR).
# 2) If it's a DatasetDict, selects a split (prefers 'train').
# 3) Detects schema with find_schema().
# 4) Builds:
# - SFT dataset: single column 'text' (prompt + chosen answer)
# - DPO dataset: columns 'prompt', 'chosen', 'rejected'
# 5) Saves both datasets to disk (SFT_OUT, DPO_OUT).
#
# Important:
# - This function does not validate data quality; it just transforms based on schema.
# -----------------------------------------------------------------------------
def main():
# Load dataset stored on disk (Hugging Face datasets format).
ds = load_from_disk(SRC_DIR)
# If dataset is split into train/valid/test, pick 'train' when available.
if isinstance(ds, DatasetDict):
# take train if exists
split = "train" if "train" in ds else list(ds.keys())[0]
ds = ds[split]
# Detect which column schema is present.
schema = find_schema(ds)
print("[INFO] Detected schema:", schema)
# -------------------------------------------------------------------------
# Branch A: Dataset already has explicit chosen/rejected columns.
# -------------------------------------------------------------------------
if schema == "chosen_rejected":
# to_sft(ex)
# Converts one row into SFT format:
# - Uses ex["prompt"] and ex["chosen"]
# - Produces {"text": <combined template string>}
def to_sft(ex):
return {"text": make_sft_text(ex["prompt"], ex["chosen"])}
# to_dpo(ex)
# Converts one row into DPO format:
# - Normalizes prompt
# - Strips chosen/rejected strings
# - Produces {"prompt": ..., "chosen": ..., "rejected": ...}
def to_dpo(ex):
return {
"prompt": normalize_prompt(ex["prompt"]),
"chosen": (ex["chosen"] or "").strip(),
"rejected": (ex["rejected"] or "").strip(),
}
# Build SFT dataset: remove original columns and keep only 'text'.
sft = ds.map(to_sft, remove_columns=ds.column_names, desc="Build SFT")
# Build DPO dataset: keep only prompt/chosen/rejected and remove everything else.
dpo = ds.map(to_dpo, remove_columns=[c for c in ds.column_names if c not in ("prompt", "chosen", "rejected")], desc="Build DPO")
# -------------------------------------------------------------------------
# Branch B: Dataset uses response_0/response_1 pairwise format (+ optional preference).
# -------------------------------------------------------------------------
else:
# response_0/response_1 + maybe preference column
_, pref_col = schema
# If no preference column exists, the script defaults:
# chosen = response_0
# rejected = response_1
# for every example.
if pref_col is None:
print("[WARN] Preference column not found. Defaulting chosen=response_0, rejected=response_1 for ALL rows.")
else:
print("[INFO] Using preference column:", pref_col)
# pick_chosen_rejected(ex)
# Purpose:
# Converts pairwise responses into explicit chosen/rejected fields.
#
# Inputs:
# - prompt, response_0, response_1
# - optionally pref_col (winner encoding varies between datasets)
#
# Preference heuristics handled here:
# - int 0/1 meaning which response wins
# - str labels like "response_0"/"response_1" or "A"/"B"
# - bool: assumes True => response_1 is better (dataset-dependent)
def pick_chosen_rejected(ex):
p = normalize_prompt(ex["prompt"])
r0 = (ex["response_0"] or "").strip()
r1 = (ex["response_1"] or "").strip()
# If no preference column, keep the default mapping.
if pref_col is None:
chosen, rejected = r0, r1
else:
v = ex.get(pref_col)
# Heuristics for common encodings:
# - 0/1 meaning winner index
# - "response_0"/"response_1"
# - "A"/"B"
# - True means response_1 better (varies)
chosen, rejected = r0, r1
if isinstance(v, int):
if v == 0:
chosen, rejected = r0, r1
elif v == 1:
chosen, rejected = r1, r0
elif isinstance(v, str):
vv = v.strip().lower()
if vv in ("0", "response_0", "r0", "a"):
chosen, rejected = r0, r1
elif vv in ("1", "response_1", "r1", "b"):
chosen, rejected = r1, r0
elif isinstance(v, bool):
# ambiguous; assume True => response_1 wins
if v is True:
chosen, rejected = r1, r0
else:
chosen, rejected = r0, r1
return {"prompt": p, "chosen": chosen, "rejected": rejected}
# First, derive explicit prompt/chosen/rejected for each row.
dpo_tmp = ds.map(pick_chosen_rejected, desc="Derive chosen/rejected")
# Build SFT dataset from the chosen responses (single 'text' column).
sft = dpo_tmp.map(lambda ex: {"text": make_sft_text(ex["prompt"], ex["chosen"])},
remove_columns=dpo_tmp.column_names, desc="Build SFT")
# Keep only required DPO columns.
dpo = dpo_tmp.select_columns(["prompt", "chosen", "rejected"])
# Ensure output directories exist.
os.makedirs(SFT_OUT, exist_ok=True)
os.makedirs(DPO_OUT, exist_ok=True)
# Save datasets to disk.
sft.save_to_disk(SFT_OUT)
dpo.save_to_disk(DPO_OUT)
# Print resulting paths and row counts.
print("[OK] Saved SFT dataset:", SFT_OUT)
print("[OK] Saved DPO dataset:", DPO_OUT)
print("[INFO] SFT rows:", len(sft))
print("[INFO] DPO rows:", len(dpo))
if __name__ == "__main__":
# Standard Python entry point guard: run main() only when executed directly.
main()