changed name of directory
This commit is contained in:
parent
eb4acde0a1
commit
b72f8b9e42
240
preparation/prepar_dat_pku_dpo.py
Normal file
240
preparation/prepar_dat_pku_dpo.py
Normal file
@ -0,0 +1,240 @@
|
||||
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import os
import json  # NOTE(review): appears unused in this file — confirm before removing
from datasets import load_from_disk, DatasetDict

# Source dataset in Hugging Face "save_to_disk" format to be converted.
SRC_DIR = "/home/hyrenko/Diploma/datasets/PKU-SafeRLHF-30K_slk_Latn_SFT_DPO_ONLY"
# Base directory under which both converted datasets are written.
OUT_DIR = "/home/hyrenko/Diploma/datasets"

# Output location for the SFT variant (single 'text' column).
SFT_OUT = os.path.join(OUT_DIR, "pku_saferlhf_sk_sft")
# Output location for the DPO variant ('prompt'/'chosen'/'rejected' columns).
DPO_OUT = os.path.join(OUT_DIR, "pku_saferlhf_sk_dpo")
||||
def find_schema(ds):
    """Identify which column layout (schema) the dataset uses.

    Returns:
        "chosen_rejected" when the dataset already carries explicit
        'prompt', 'chosen' and 'rejected' columns.
        ("response01", pref_or_None) when it uses the pairwise
        'prompt'/'response_0'/'response_1' layout; the second element is
        the first recognized preference (winner-indicator) column name,
        or None when no known candidate is present.

    Raises:
        RuntimeError: if neither layout matches; the message includes the
        dataset's column list to make debugging easier.
    """
    columns = set(ds.column_names)

    # Best case: nothing needs to be derived.
    if columns >= {"prompt", "chosen", "rejected"}:
        return "chosen_rejected"

    # Pairwise variant: try to locate a winner column among common names.
    if columns >= {"prompt", "response_0", "response_1"}:
        known_pref_names = (
            "choice", "label", "winner", "preferred", "preference",
            "better", "better_response", "better_response_id", "chosen_id",
            "win", "response_choice",
        )
        pref = next((name for name in known_pref_names if name in columns), None)
        return ("response01", pref)

    # Fallback: unrecognized layout.
    raise RuntimeError(f"Unknown schema. Columns: {ds.column_names}")
|
||||
|
||||
def normalize_prompt(p: str) -> str:
    """Return *p* with surrounding whitespace stripped; None (or any
    falsy value) becomes the empty string."""
    if not p:
        return ""
    return p.strip()
|
||||
|
||||
def make_sft_text(prompt: str, answer: str) -> str:
    """Build the single 'text' field used for SFT training.

    Output format (simple, stable template that works even for models
    without chat-instruct tokens; not model-specific):

        ### User:
        <prompt>

        ### Assistant:
        <answer>

    Both fields are whitespace-stripped; None is treated as "".
    """
    user_part = "" if prompt is None else prompt.strip()
    assistant_part = "" if answer is None else answer.strip()
    return "### User:\n" + user_part + "\n\n### Assistant:\n" + assistant_part
|
||||
|
||||
# -----------------------------------------------------------------------------
# main()
# Purpose:
#   Orchestrates the full conversion pipeline:
#     1) Loads dataset from disk (SRC_DIR).
#     2) If it's a DatasetDict, selects a split (prefers 'train').
#     3) Detects schema with find_schema().
#     4) Builds:
#        - SFT dataset: single column 'text' (prompt + chosen answer)
#        - DPO dataset: columns 'prompt', 'chosen', 'rejected'
#     5) Saves both datasets to disk (SFT_OUT, DPO_OUT).
#
# Important:
#   - This function does not validate data quality; it just transforms based
#     on the detected schema.
# -----------------------------------------------------------------------------
def main():
    # Load dataset stored on disk (Hugging Face datasets format).
    ds = load_from_disk(SRC_DIR)

    # If dataset is split into train/valid/test, pick 'train' when available,
    # otherwise fall back to the first split.
    if isinstance(ds, DatasetDict):
        split = "train" if "train" in ds else list(ds.keys())[0]
        ds = ds[split]

    # Detect which column schema is present.
    schema = find_schema(ds)
    print("[INFO] Detected schema:", schema)

    # -------------------------------------------------------------------------
    # Branch A: Dataset already has explicit chosen/rejected columns.
    # -------------------------------------------------------------------------
    if schema == "chosen_rejected":

        def to_sft(ex):
            # One row -> {"text": templated prompt + chosen answer}.
            return {"text": make_sft_text(ex["prompt"], ex["chosen"])}

        def to_dpo(ex):
            # One row -> normalized prompt plus stripped chosen/rejected.
            return {
                "prompt": normalize_prompt(ex["prompt"]),
                "chosen": (ex["chosen"] or "").strip(),
                "rejected": (ex["rejected"] or "").strip(),
            }

        # Build SFT dataset: drop all original columns, keep only 'text'.
        sft = ds.map(to_sft, remove_columns=ds.column_names, desc="Build SFT")

        # Build DPO dataset: keep only prompt/chosen/rejected.
        dpo = ds.map(
            to_dpo,
            remove_columns=[c for c in ds.column_names if c not in ("prompt", "chosen", "rejected")],
            desc="Build DPO",
        )

    # -------------------------------------------------------------------------
    # Branch B: pairwise response_0/response_1 format (+ optional preference).
    # -------------------------------------------------------------------------
    else:
        _, pref_col = schema

        # Without a preference column we default to chosen=response_0,
        # rejected=response_1 for every example.
        if pref_col is None:
            print("[WARN] Preference column not found. Defaulting chosen=response_0, rejected=response_1 for ALL rows.")
        else:
            print("[INFO] Using preference column:", pref_col)

        def pick_chosen_rejected(ex):
            # Convert pairwise responses into explicit chosen/rejected fields.
            #
            # Preference encodings handled (winner indicator varies by dataset):
            #   - bool: True => response_1 wins (dataset-dependent assumption)
            #   - int 0/1: index of the winning response
            #   - str labels: "0"/"response_0"/"r0"/"a" vs "1"/"response_1"/"r1"/"b"
            # Any unrecognized value falls back to chosen=r0, rejected=r1.
            p = normalize_prompt(ex["prompt"])
            r0 = (ex["response_0"] or "").strip()
            r1 = (ex["response_1"] or "").strip()

            # No preference column: keep the default mapping.
            if pref_col is None:
                return {"prompt": p, "chosen": r0, "rejected": r1}

            v = ex.get(pref_col)
            chosen, rejected = r0, r1  # default when v is unrecognized

            # BUGFIX: bool must be tested BEFORE int — bool is a subclass of
            # int in Python, so the original `isinstance(v, bool)` branch was
            # unreachable dead code (it only behaved correctly by the
            # coincidence that True == 1 and False == 0).
            if isinstance(v, bool):
                # ambiguous; assume True => response_1 wins
                if v:
                    chosen, rejected = r1, r0
            elif isinstance(v, int):
                if v == 1:
                    chosen, rejected = r1, r0
                # v == 0 (or any other int) keeps the default r0/r1 mapping.
            elif isinstance(v, str):
                vv = v.strip().lower()
                if vv in ("1", "response_1", "r1", "b"):
                    chosen, rejected = r1, r0
                # "0"/"response_0"/"r0"/"a" (and unknown labels) keep the default.

            return {"prompt": p, "chosen": chosen, "rejected": rejected}

        # First, derive explicit prompt/chosen/rejected for each row.
        dpo_tmp = ds.map(pick_chosen_rejected, desc="Derive chosen/rejected")

        # Build SFT dataset from the chosen responses (single 'text' column).
        sft = dpo_tmp.map(
            lambda ex: {"text": make_sft_text(ex["prompt"], ex["chosen"])},
            remove_columns=dpo_tmp.column_names,
            desc="Build SFT",
        )

        # Keep only required DPO columns.
        dpo = dpo_tmp.select_columns(["prompt", "chosen", "rejected"])

    # Ensure output directories exist.
    os.makedirs(SFT_OUT, exist_ok=True)
    os.makedirs(DPO_OUT, exist_ok=True)

    # Save datasets to disk.
    sft.save_to_disk(SFT_OUT)
    dpo.save_to_disk(DPO_OUT)

    # Print resulting paths and row counts.
    print("[OK] Saved SFT dataset:", SFT_OUT)
    print("[OK] Saved DPO dataset:", DPO_OUT)
    print("[INFO] SFT rows:", len(sft))
    print("[INFO] DPO rows:", len(dpo))


if __name__ == "__main__":
    # Standard Python entry point guard: run main() only when executed directly.
    main()
|
||||
Loading…
Reference in New Issue
Block a user