DP/Preparation/prepar_dat_pku_dpo.py
2026-02-04 21:07:17 +01:00

136 lines
4.7 KiB
Python

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
import json
from datasets import load_from_disk, DatasetDict
# Source dataset: PKU-SafeRLHF-30K translated to Slovak (Latin script),
# pre-filtered to the rows relevant for SFT/DPO preparation.
SRC_DIR = "/home/hyrenko/Diploma/datasets/PKU-SafeRLHF-30K_slk_Latn_SFT_DPO_ONLY"
OUT_DIR = "/home/hyrenko/Diploma/datasets"
# Output locations: flat-"text" SFT dataset and (prompt, chosen, rejected) DPO dataset.
SFT_OUT = os.path.join(OUT_DIR, "pku_saferlhf_sk_sft")
DPO_OUT = os.path.join(OUT_DIR, "pku_saferlhf_sk_dpo")
def find_schema(ds):
    """Detect which preference schema the dataset columns follow.

    Returns the string ``"chosen_rejected"`` when explicit
    prompt/chosen/rejected columns exist, otherwise a tuple
    ``("response01", pref_col)`` where ``pref_col`` is the first
    recognized preference-column name found (or ``None`` if absent).

    Raises:
        RuntimeError: if neither known layout matches.
    """
    columns = set(ds.column_names)

    # Best case: explicit chosen/rejected columns are already present.
    if {"prompt", "chosen", "rejected"} <= columns:
        return "chosen_rejected"

    # Pairwise variant: two responses plus (hopefully) a winner column.
    if {"prompt", "response_0", "response_1"} <= columns:
        known_pref_names = (
            "choice", "label", "winner", "preferred", "preference",
            "better", "better_response", "better_response_id", "chosen_id",
            "win", "response_choice",
        )
        pref = next((name for name in known_pref_names if name in columns), None)
        return ("response01", pref)

    raise RuntimeError(f"Unknown schema. Columns: {ds.column_names}")
def normalize_prompt(p: str) -> str:
    """Return the prompt with surrounding whitespace removed.

    ``None`` is treated as an empty prompt.
    """
    if p is None:
        return ""
    return p.strip()
def make_sft_text(prompt: str, answer: str) -> str:
    """Render one SFT training example as plain text.

    Uses a simple, stable template for causal LMs (works even if the
    model is not chat-instruct tuned); can be swapped for a Slovak
    system-style template later. ``None`` values become empty strings.
    """
    cleaned_prompt = (prompt or "").strip()
    cleaned_answer = (answer or "").strip()
    return "\n".join([
        "### User:",
        cleaned_prompt,
        "",
        "### Assistant:",
        cleaned_answer,
    ])
def main():
    """Convert the translated PKU-SafeRLHF dataset into SFT and DPO datasets.

    Loads SRC_DIR, detects the column schema, derives (prompt, chosen,
    rejected) rows for DPO and flat {"text": ...} rows for SFT, then saves
    both datasets under SFT_OUT / DPO_OUT.
    """
    ds = load_from_disk(SRC_DIR)
    if isinstance(ds, DatasetDict):
        # Prefer the train split; otherwise fall back to the first split.
        split = "train" if "train" in ds else list(ds.keys())[0]
        ds = ds[split]

    schema = find_schema(ds)
    print("[INFO] Detected schema:", schema)

    if schema == "chosen_rejected":
        def to_sft(ex):
            return {"text": make_sft_text(ex["prompt"], ex["chosen"])}

        def to_dpo(ex):
            return {
                "prompt": normalize_prompt(ex["prompt"]),
                "chosen": (ex["chosen"] or "").strip(),
                "rejected": (ex["rejected"] or "").strip(),
            }

        sft = ds.map(to_sft, remove_columns=ds.column_names, desc="Build SFT")
        dpo = ds.map(
            to_dpo,
            remove_columns=[c for c in ds.column_names if c not in ("prompt", "chosen", "rejected")],
            desc="Build DPO",
        )
    else:
        # response_0/response_1 plus (maybe) a preference column.
        _, pref_col = schema
        if pref_col is None:
            print("[WARN] Preference column not found. Defaulting chosen=response_0, rejected=response_1 for ALL rows.")
        else:
            print("[INFO] Using preference column:", pref_col)

        def pick_chosen_rejected(ex):
            p = normalize_prompt(ex["prompt"])
            r0 = (ex["response_0"] or "").strip()
            r1 = (ex["response_1"] or "").strip()
            if pref_col is None:
                chosen, rejected = r0, r1
            else:
                v = ex.get(pref_col)
                # Heuristics for common winner encodings:
                #   - 0/1 winner index (PKU-SafeRLHF uses better_response_id)
                #   - "response_0"/"response_1", "r0"/"r1", "a"/"b"
                #   - bool: ambiguous; assume True => response_1 wins
                # Default: response_0 wins unless a recognized value says otherwise.
                chosen, rejected = r0, r1
                # BUGFIX: bool must be tested BEFORE int — bool is a subclass
                # of int in Python, so the original order made the bool branch
                # unreachable (True happened to behave like winner index 1).
                if isinstance(v, bool):
                    if v:
                        chosen, rejected = r1, r0
                elif isinstance(v, int):
                    if v == 1:
                        chosen, rejected = r1, r0
                elif isinstance(v, str):
                    vv = v.strip().lower()
                    if vv in ("1", "response_1", "r1", "b"):
                        chosen, rejected = r1, r0
            return {"prompt": p, "chosen": chosen, "rejected": rejected}

        dpo_tmp = ds.map(pick_chosen_rejected, desc="Derive chosen/rejected")
        sft = dpo_tmp.map(
            lambda ex: {"text": make_sft_text(ex["prompt"], ex["chosen"])},
            remove_columns=dpo_tmp.column_names,
            desc="Build SFT",
        )
        dpo = dpo_tmp.select_columns(["prompt", "chosen", "rejected"])

    os.makedirs(SFT_OUT, exist_ok=True)
    os.makedirs(DPO_OUT, exist_ok=True)
    sft.save_to_disk(SFT_OUT)
    dpo.save_to_disk(DPO_OUT)
    print("[OK] Saved SFT dataset:", SFT_OUT)
    print("[OK] Saved DPO dataset:", DPO_OUT)
    print("[INFO] SFT rows:", len(sft))
    print("[INFO] DPO rows:", len(dpo))


if __name__ == "__main__":
    main()