#!/usr/bin/env python3
|
|
# -*- coding: utf-8 -*-
|
|
|
|
import os
|
|
import json
|
|
from datasets import load_from_disk, DatasetDict
|
|
|
|
# Input: Slovak (slk_Latn) translation of PKU-SafeRLHF-30K, stored in the
# Hugging Face `save_to_disk` layout.
SRC_DIR = "/home/hyrenko/Diploma/datasets/PKU-SafeRLHF-30K_slk_Latn_SFT_DPO_ONLY"
# Root directory under which both converted datasets are written.
OUT_DIR = "/home/hyrenko/Diploma/datasets"

# Output paths:
# - SFT_OUT: single-column ('text') dataset for supervised fine-tuning.
# - DPO_OUT: three-column ('prompt'/'chosen'/'rejected') preference dataset.
SFT_OUT = os.path.join(OUT_DIR, "pku_saferlhf_sk_sft")
DPO_OUT = os.path.join(OUT_DIR, "pku_saferlhf_sk_dpo")
|
|
|
|
# -----------------------------------------------------------------------------
|
|
# find_schema(ds)
|
|
# Purpose:
|
|
# Detects which column layout (schema) the dataset uses, so the script knows
|
|
# how to build SFT and DPO outputs.
|
|
#
|
|
# What it checks:
|
|
# 1) If dataset already contains 'prompt', 'chosen', 'rejected' columns:
|
|
# -> returns "chosen_rejected"
|
|
# 2) If dataset contains 'prompt', 'response_0', 'response_1':
|
|
# -> tries to detect a preference column (winner indicator) among common names
|
|
# -> returns ("response01", <pref_column_or_None>)
|
|
#
|
|
# If nothing matches:
|
|
# Raises RuntimeError with the dataset column list, to make debugging easier.
|
|
# -----------------------------------------------------------------------------
|
|
def find_schema(ds):
    """Identify which column layout *ds* uses.

    Returns:
        "chosen_rejected" when explicit prompt/chosen/rejected columns exist,
        or the tuple ("response01", pref_col_or_None) for the pairwise
        response_0/response_1 layout, where pref_col is the first recognized
        preference (winner-indicator) column name, if any.

    Raises:
        RuntimeError: when neither layout matches; the message includes the
            dataset's column list to ease debugging.
    """
    cols = set(ds.column_names)

    # Best case: preference pairs are already explicit.
    if cols >= {"prompt", "chosen", "rejected"}:
        return "chosen_rejected"

    # Pairwise layout: additionally probe for a winner column among common names.
    if cols >= {"prompt", "response_0", "response_1"}:
        candidates = (
            "choice", "label", "winner", "preferred", "preference",
            "better", "better_response", "better_response_id", "chosen_id",
            "win", "response_choice",
        )
        pref = next((name for name in candidates if name in cols), None)
        return ("response01", pref)

    # Neither layout recognized.
    raise RuntimeError(f"Unknown schema. Columns: {ds.column_names}")
|
|
|
|
# -----------------------------------------------------------------------------
|
|
# normalize_prompt(p)
|
|
# Purpose:
|
|
# Normalizes prompt text to reduce formatting noise.
|
|
#
|
|
# Behavior:
|
|
# - If p is None, converts it to empty string.
|
|
# - Strips leading/trailing whitespace.
|
|
# -----------------------------------------------------------------------------
|
|
def normalize_prompt(p: str) -> str:
    """Return *p* with surrounding whitespace removed; None maps to ""."""
    if not p:
        return ""
    return p.strip()
|
|
|
|
# -----------------------------------------------------------------------------
|
|
# make_sft_text(prompt, answer)
|
|
# Purpose:
|
|
# Builds a single "text" field for SFT training using a simple stable template.
|
|
#
|
|
# Output format:
|
|
# ### User:
|
|
# <prompt>
|
|
#
|
|
# ### Assistant:
|
|
# <answer>
|
|
#
|
|
# Notes:
|
|
# - Keeps formatting generic (not model-specific chat tokens).
|
|
# - Calls normalize_prompt() for prompt and strips answer.
|
|
# -----------------------------------------------------------------------------
|
|
def make_sft_text(prompt: str, answer: str) -> str:
    """Render one SFT training string from a prompt and its answer.

    Uses a simple, stable template for causal LM training (generic sections
    rather than model-specific chat tokens); both fields are whitespace-
    stripped and None is treated as "".

    Output format:
        ### User:
        <prompt>

        ### Assistant:
        <answer>
    """
    user_part = (prompt or "").strip()
    assistant_part = (answer or "").strip()
    return f"### User:\n{user_part}\n\n### Assistant:\n{assistant_part}"
|
|
|
|
# -----------------------------------------------------------------------------
|
|
# main()
|
|
# Purpose:
|
|
# Orchestrates the full conversion pipeline:
|
|
# 1) Loads dataset from disk (SRC_DIR).
|
|
# 2) If it's a DatasetDict, selects a split (prefers 'train').
|
|
# 3) Detects schema with find_schema().
|
|
# 4) Builds:
|
|
# - SFT dataset: single column 'text' (prompt + chosen answer)
|
|
# - DPO dataset: columns 'prompt', 'chosen', 'rejected'
|
|
# 5) Saves both datasets to disk (SFT_OUT, DPO_OUT).
|
|
#
|
|
# Important:
|
|
# - This function does not validate data quality; it just transforms based on schema.
|
|
# -----------------------------------------------------------------------------
|
|
def main():
    """Convert the on-disk PKU-SafeRLHF dataset into SFT and DPO datasets.

    Pipeline:
      1) Load the dataset from SRC_DIR (Hugging Face `save_to_disk` layout).
      2) If it is a DatasetDict, select the 'train' split (or the first split).
      3) Detect the column schema with find_schema().
      4) Build:
         - SFT dataset: single 'text' column (prompt + chosen answer).
         - DPO dataset: 'prompt', 'chosen', 'rejected' columns.
      5) Save both datasets to SFT_OUT and DPO_OUT.

    Note: no data-quality validation is performed; rows are transformed
    purely according to the detected schema.
    """
    ds = load_from_disk(SRC_DIR)

    # If the dataset is split into train/valid/test, prefer 'train'.
    if isinstance(ds, DatasetDict):
        split = "train" if "train" in ds else list(ds.keys())[0]
        ds = ds[split]

    schema = find_schema(ds)
    print("[INFO] Detected schema:", schema)

    # -------------------------------------------------------------------------
    # Branch A: dataset already has explicit chosen/rejected columns.
    # -------------------------------------------------------------------------
    if schema == "chosen_rejected":

        def to_sft(ex):
            # One row -> {"text": templated prompt + chosen answer}.
            return {"text": make_sft_text(ex["prompt"], ex["chosen"])}

        def to_dpo(ex):
            # One row -> normalized prompt plus stripped chosen/rejected.
            return {
                "prompt": normalize_prompt(ex["prompt"]),
                "chosen": (ex["chosen"] or "").strip(),
                "rejected": (ex["rejected"] or "").strip(),
            }

        # SFT dataset: drop all original columns, keep only 'text'.
        sft = ds.map(to_sft, remove_columns=ds.column_names, desc="Build SFT")

        # DPO dataset: keep only prompt/chosen/rejected.
        dpo = ds.map(
            to_dpo,
            remove_columns=[c for c in ds.column_names if c not in ("prompt", "chosen", "rejected")],
            desc="Build DPO",
        )

    # -------------------------------------------------------------------------
    # Branch B: pairwise response_0/response_1 format (+ optional preference).
    # -------------------------------------------------------------------------
    else:
        _, pref_col = schema

        if pref_col is None:
            print("[WARN] Preference column not found. Defaulting chosen=response_0, rejected=response_1 for ALL rows.")
        else:
            print("[INFO] Using preference column:", pref_col)

        def pick_chosen_rejected(ex):
            """Map pairwise responses to explicit chosen/rejected fields.

            Preference encodings handled:
              - bool: True => response_1 wins (dataset-dependent assumption;
                TODO confirm against the source dataset's documentation)
              - int 0/1: index of the winning response
              - str: "0"/"response_0"/"r0"/"a" vs "1"/"response_1"/"r1"/"b"
            Any unrecognized encoding falls back to chosen=response_0.
            """
            p = normalize_prompt(ex["prompt"])
            r0 = (ex["response_0"] or "").strip()
            r1 = (ex["response_1"] or "").strip()

            # Default mapping; also used when there is no preference column
            # or the encoding is unrecognized.
            chosen, rejected = r0, r1

            if pref_col is not None:
                v = ex.get(pref_col)
                # BUGFIX: test bool BEFORE int. In Python, bool is a subclass
                # of int, so the original order routed booleans through the
                # int branch and made the explicit bool heuristic unreachable
                # (it only worked by coincidence because True == 1).
                if isinstance(v, bool):
                    # Ambiguous encoding; assume True => response_1 wins.
                    if v:
                        chosen, rejected = r1, r0
                elif isinstance(v, int):
                    if v == 1:
                        chosen, rejected = r1, r0
                elif isinstance(v, str):
                    vv = v.strip().lower()
                    if vv in ("1", "response_1", "r1", "b"):
                        chosen, rejected = r1, r0

            return {"prompt": p, "chosen": chosen, "rejected": rejected}

        # First derive explicit prompt/chosen/rejected for each row.
        dpo_tmp = ds.map(pick_chosen_rejected, desc="Derive chosen/rejected")

        # SFT dataset from the chosen responses (single 'text' column).
        sft = dpo_tmp.map(
            lambda ex: {"text": make_sft_text(ex["prompt"], ex["chosen"])},
            remove_columns=dpo_tmp.column_names,
            desc="Build SFT",
        )

        # Keep only the required DPO columns.
        dpo = dpo_tmp.select_columns(["prompt", "chosen", "rejected"])

    # Ensure output directories exist, then persist both datasets.
    os.makedirs(SFT_OUT, exist_ok=True)
    os.makedirs(DPO_OUT, exist_ok=True)
    sft.save_to_disk(SFT_OUT)
    dpo.save_to_disk(DPO_OUT)

    # Report resulting paths and row counts.
    print("[OK] Saved SFT dataset:", SFT_OUT)
    print("[OK] Saved DPO dataset:", DPO_OUT)
    print("[INFO] SFT rows:", len(sft))
    print("[INFO] DPO rows:", len(dpo))
|
|
|
|
if __name__ == "__main__":
    # Standard Python entry point guard: run main() only when executed directly.
    main()
|