136 lines
4.7 KiB
Python
136 lines
4.7 KiB
Python
#!/usr/bin/env python3
|
|
# -*- coding: utf-8 -*-
|
|
|
|
import os
|
|
import json
|
|
from datasets import load_from_disk, DatasetDict
|
|
|
|
# Location of the source dataset (read with datasets.load_from_disk) and the
# root under which the derived datasets are written. Both can be overridden via
# environment variables so the script is not tied to one machine; the defaults
# are the original hard-coded paths.
SRC_DIR = os.environ.get(
    "PKU_SFT_DPO_SRC_DIR",
    "/home/hyrenko/Diploma/datasets/PKU-SafeRLHF-30K_slk_Latn_SFT_DPO_ONLY",
)
OUT_DIR = os.environ.get("PKU_SFT_DPO_OUT_DIR", "/home/hyrenko/Diploma/datasets")

# Output directories for the SFT ("text" column) and DPO
# ("prompt"/"chosen"/"rejected") datasets.
SFT_OUT = os.path.join(OUT_DIR, "pku_saferlhf_sk_sft")
DPO_OUT = os.path.join(OUT_DIR, "pku_saferlhf_sk_dpo")
|
|
|
|
def find_schema(ds):
    """Classify the column layout of *ds*.

    Returns:
        "chosen_rejected" when the dataset has prompt/chosen/rejected columns.
        Otherwise ("response01", pref) when it has prompt/response_0/response_1
        columns. Here pref is the first column name found from a fixed priority
        list of common preference-column names, or None if none is present.

    Raises:
        RuntimeError: if neither layout matches.
    """
    available = set(ds.column_names)

    # Preferred layout: pairs are already split into chosen/rejected.
    if available >= {"prompt", "chosen", "rejected"}:
        return "chosen_rejected"

    if not available >= {"prompt", "response_0", "response_1"}:
        raise RuntimeError(f"Unknown schema. Columns: {ds.column_names}")

    # Pairwise layout: a separate column (if any) says which response won.
    # Names are checked in priority order; the first present one is used.
    preference_names = (
        "choice", "label", "winner", "preferred", "preference",
        "better", "better_response", "better_response_id", "chosen_id",
        "win", "response_choice",
    )
    preference = next((name for name in preference_names if name in available), None)
    return ("response01", preference)
|
|
|
|
def normalize_prompt(p: str) -> str:
    """Return *p* without surrounding whitespace; falsy input (e.g. None) gives ""."""
    if not p:
        return ""
    return p.strip()
|
|
|
|
def make_sft_text(prompt: str, answer: str) -> str:
    """Render one prompt/answer pair as a plain-text causal-LM training sample.

    Uses a fixed "### User:" / "### Assistant:" layout, so it does not depend
    on any model-specific chat template. The prompt goes through
    normalize_prompt(). The answer is stripped, and a falsy answer (e.g. None)
    becomes an empty string.
    """
    # The template may later be replaced by a Slovak system-style prompt.
    user_turn = normalize_prompt(prompt)
    reply = answer.strip() if answer else ""
    return f"### User:\n{user_turn}\n\n### Assistant:\n{reply}"
|
|
|
|
def _winner_index(value):
    """Map a raw preference-column value to the index of the preferred response.

    Recognised encodings:
      * bool  -> True means response_1 wins, False means response_0 wins
                 (an assumption; the meaning of a boolean column varies).
      * int   -> 0 or 1 as the winner index.
      * float -> 0.0 or 1.0, treated like the corresponding int.
      * str   -> "0"/"response_0"/"r0"/"a" or "1"/"response_1"/"r1"/"b",
                 case-insensitive, surrounding whitespace ignored.

    Returns:
        0 or 1 when the value is recognised, otherwise None.
    """
    # bool must be checked before int: bool is a subclass of int, so an
    # int-first check makes a dedicated bool branch unreachable.
    if isinstance(value, bool):
        return 1 if value else 0
    if isinstance(value, int):
        return value if value in (0, 1) else None
    if isinstance(value, float):
        # Labels are sometimes stored as floats (e.g. 1.0).
        return int(value) if value in (0.0, 1.0) else None
    if isinstance(value, str):
        key = value.strip().lower()
        if key in ("0", "response_0", "r0", "a"):
            return 0
        if key in ("1", "response_1", "r1", "b"):
            return 1
    return None


def main():
    """Build SFT and DPO datasets from the source preference dataset.

    Loads SRC_DIR with load_from_disk. For a DatasetDict it uses the "train"
    split if present, otherwise the first split. The layout is detected with
    find_schema(), and the function writes:
      * SFT_OUT - a single "text" column (prompt + chosen answer);
      * DPO_OUT - "prompt", "chosen" and "rejected" columns.
    """
    ds = load_from_disk(SRC_DIR)
    if isinstance(ds, DatasetDict):
        # take train if exists, otherwise the first available split
        split = "train" if "train" in ds else list(ds.keys())[0]
        ds = ds[split]

    schema = find_schema(ds)
    print("[INFO] Detected schema:", schema)

    if schema == "chosen_rejected":
        def to_sft(ex):
            return {"text": make_sft_text(ex["prompt"], ex["chosen"])}

        def to_dpo(ex):
            return {
                "prompt": normalize_prompt(ex["prompt"]),
                "chosen": (ex["chosen"] or "").strip(),
                "rejected": (ex["rejected"] or "").strip(),
            }

        sft = ds.map(to_sft, remove_columns=ds.column_names, desc="Build SFT")
        dpo = ds.map(
            to_dpo,
            remove_columns=[c for c in ds.column_names if c not in ("prompt", "chosen", "rejected")],
            desc="Build DPO",
        )
    else:
        # response_0/response_1 + maybe preference column
        _, pref_col = schema
        if pref_col is None:
            print("[WARN] Preference column not found. Defaulting chosen=response_0, rejected=response_1 for ALL rows.")
        else:
            print("[INFO] Using preference column:", pref_col)

        def pick_chosen_rejected(ex):
            p = normalize_prompt(ex["prompt"])
            r0 = (ex["response_0"] or "").strip()
            r1 = (ex["response_1"] or "").strip()

            # A missing preference column or an unrecognised value keeps
            # response_0 as chosen (the original default behaviour).
            winner = None if pref_col is None else _winner_index(ex.get(pref_col))
            if winner == 1:
                chosen, rejected = r1, r0
            else:
                chosen, rejected = r0, r1

            return {"prompt": p, "chosen": chosen, "rejected": rejected}

        dpo_tmp = ds.map(pick_chosen_rejected, desc="Derive chosen/rejected")
        sft = dpo_tmp.map(lambda ex: {"text": make_sft_text(ex["prompt"], ex["chosen"])},
                          remove_columns=dpo_tmp.column_names, desc="Build SFT")
        dpo = dpo_tmp.select_columns(["prompt", "chosen", "rejected"])

    os.makedirs(SFT_OUT, exist_ok=True)
    os.makedirs(DPO_OUT, exist_ok=True)

    sft.save_to_disk(SFT_OUT)
    dpo.save_to_disk(DPO_OUT)

    print("[OK] Saved SFT dataset:", SFT_OUT)
    print("[OK] Saved DPO dataset:", DPO_OUT)
    print("[INFO] SFT rows:", len(sft))
    print("[INFO] DPO rows:", len(dpo))
|
|
if __name__ == "__main__":
    # Run the conversion only when executed as a script, not when imported.
    main()
|