#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Build SFT and DPO datasets from a Slovak PKU-SafeRLHF preference dataset.

Loads the dataset saved (with ``datasets``' ``save_to_disk``) at ``SRC_DIR``
and writes two new datasets:

* ``SFT_OUT`` -- a single ``text`` column: a ``### User`` / ``### Assistant``
  transcript whose answer is the preferred ("chosen") response.
* ``DPO_OUT`` -- ``prompt``, ``chosen`` and ``rejected`` columns.

Two input layouts are understood (see ``find_schema``): explicit
``prompt``/``chosen``/``rejected`` columns, or pairwise
``prompt``/``response_0``/``response_1`` columns plus an optional preference
column naming the winning response.
"""
import os
import json  # NOTE(review): not used by this script.

from datasets import load_from_disk, DatasetDict

SRC_DIR = "/home/hyrenko/Diploma/datasets/PKU-SafeRLHF-30K_slk_Latn_SFT_DPO_ONLY"
OUT_DIR = "/home/hyrenko/Diploma/datasets"
SFT_OUT = os.path.join(OUT_DIR, "pku_saferlhf_sk_sft")
DPO_OUT = os.path.join(OUT_DIR, "pku_saferlhf_sk_dpo")

# Column names that may say which of response_0 / response_1 is preferred.
# They are checked in this order; the first one present in the dataset is used.
PREF_CANDIDATES = (
    "choice", "label", "winner", "preferred", "preference", "better",
    "better_response", "better_response_id", "chosen_id", "win",
    "response_choice",
)

# Lower-cased, stripped string encodings meaning "response_0 wins" /
# "response_1 wins".
_PREF_STR_R0 = frozenset({"0", "response_0", "r0", "a"})
_PREF_STR_R1 = frozenset({"1", "response_1", "r1", "b"})


def find_schema(ds):
    """Detect which column layout *ds* uses.

    Returns:
        ``"chosen_rejected"`` when ``prompt``, ``chosen`` and ``rejected`` are
        all columns of *ds*. Otherwise ``("response01", pref_col)`` when
        ``prompt``, ``response_0`` and ``response_1`` are columns; ``pref_col``
        is the first name from ``PREF_CANDIDATES`` that is a column, or
        ``None`` if none is.

    Raises:
        RuntimeError: if neither layout matches.
    """
    cols = set(ds.column_names)
    if {"prompt", "chosen", "rejected"} <= cols:
        return "chosen_rejected"
    if {"prompt", "response_0", "response_1"} <= cols:
        pref = next((c for c in PREF_CANDIDATES if c in cols), None)
        return ("response01", pref)
    raise RuntimeError(f"Unknown schema. Columns: {ds.column_names}")


def normalize_prompt(p: str) -> str:
    """Return *p* with surrounding whitespace removed; ``None`` becomes ``""``."""
    return (p or "").strip()


def _clean_text(text):
    """Return a response string stripped of surrounding whitespace (``None`` -> ``""``)."""
    return (text or "").strip()


def make_sft_text(prompt: str, answer: str) -> str:
    """Render one prompt/answer pair as a plain-text SFT training example.

    The format is ``"### User:\\n{prompt}\\n\\n### Assistant:\\n{answer}"``,
    with both parts whitespace-stripped.
    """
    # Simple, stable template for a causal LM. It works even if the model is
    # not chat/instruct tuned; it could be switched to a Slovak system-style
    # prompt later.
    prompt = normalize_prompt(prompt)
    answer = _clean_text(answer)
    return f"### User:\n{prompt}\n\n### Assistant:\n{answer}"


def winner_index(value):
    """Decode a preference value into the index of the winning response.

    Recognised encodings:

    * ``bool``: ``True`` -> 1, ``False`` -> 0. This follows the original
      heuristic that ``True`` means response_1 is better. NOTE(review): this
      varies between datasets; confirm against the source dataset.
    * integers 0 / 1, including integral floats such as ``1.0``. Integer
      columns can turn into floats after a pandas round-trip with missing
      values.
    * strings ``"0"``, ``"response_0"``, ``"r0"``, ``"a"`` and ``"1"``,
      ``"response_1"``, ``"r1"``, ``"b"`` (case-insensitive, whitespace
      ignored).

    Returns:
        0 if response_0 wins, 1 if response_1 wins, or ``None`` if *value*
        is not a recognised encoding.
    """
    # bool must be tested before int: bool is a subclass of int, so an int
    # check placed first would swallow True/False.
    if isinstance(value, bool):
        return 1 if value else 0
    if isinstance(value, float) and value.is_integer():
        value = int(value)
    if isinstance(value, int):
        return value if value in (0, 1) else None
    if isinstance(value, str):
        key = value.strip().lower()
        if key in _PREF_STR_R0:
            return 0
        if key in _PREF_STR_R1:
            return 1
    return None


def _build_from_chosen_rejected(ds):
    """Return ``(sft, dpo)`` built from a dataset that already has prompt/chosen/rejected."""
    def to_sft(ex):
        return {"text": make_sft_text(ex["prompt"], ex["chosen"])}

    def to_dpo(ex):
        return {
            "prompt": normalize_prompt(ex["prompt"]),
            "chosen": _clean_text(ex["chosen"]),
            "rejected": _clean_text(ex["rejected"]),
        }

    sft = ds.map(to_sft, remove_columns=ds.column_names, desc="Build SFT")
    # Keep only prompt/chosen/rejected; those three are overwritten with cleaned text.
    extra_cols = [c for c in ds.column_names if c not in ("prompt", "chosen", "rejected")]
    dpo = ds.map(to_dpo, remove_columns=extra_cols, desc="Build DPO")
    return sft, dpo


def _warn_unrecognized_preferences(ds, pref_col):
    """Print a warning if any value in *pref_col* cannot be decoded by ``winner_index``."""
    bad = [v for v in ds[pref_col] if winner_index(v) is None]
    if bad:
        examples = sorted({repr(v) for v in bad})[:5]
        print(
            f"[WARN] {len(bad)} row(s) have an unrecognized value in {pref_col!r} "
            f"(e.g. {', '.join(examples)}); those rows default to "
            "chosen=response_0, rejected=response_1."
        )


def _build_from_pairwise(ds, pref_col):
    """Return ``(sft, dpo)`` from response_0/response_1 columns, using *pref_col* to pick the winner.

    When *pref_col* is ``None``, or a row's value is unrecognised,
    response_0 is treated as chosen and response_1 as rejected.
    """
    if pref_col is None:
        print("[WARN] Preference column not found. Defaulting chosen=response_0, rejected=response_1 for ALL rows.")
    else:
        print("[INFO] Using preference column:", pref_col)
        _warn_unrecognized_preferences(ds, pref_col)

    def pick_chosen_rejected(ex):
        r0 = _clean_text(ex["response_0"])
        r1 = _clean_text(ex["response_1"])
        winner = None if pref_col is None else winner_index(ex[pref_col])
        # response_1 is chosen only on an explicit "response_1 wins" value;
        # everything else keeps the response_0 default.
        chosen, rejected = (r1, r0) if winner == 1 else (r0, r1)
        return {"prompt": normalize_prompt(ex["prompt"]), "chosen": chosen, "rejected": rejected}

    dpo_tmp = ds.map(pick_chosen_rejected, desc="Derive chosen/rejected")
    # NOTE(review): if the source is PKU-SafeRLHF, the "better" response is not
    # necessarily the safer one. Consider a safety-based column if SFT should
    # only learn safe answers.
    sft = dpo_tmp.map(
        lambda ex: {"text": make_sft_text(ex["prompt"], ex["chosen"])},
        remove_columns=dpo_tmp.column_names,
        desc="Build SFT",
    )
    dpo = dpo_tmp.select_columns(["prompt", "chosen", "rejected"])
    return sft, dpo


def main():
    """Load ``SRC_DIR``, derive SFT and DPO datasets, and save them to ``SFT_OUT`` / ``DPO_OUT``.

    Raises:
        RuntimeError: if ``SRC_DIR`` holds a ``DatasetDict`` with no splits,
            or the column layout is not recognised (see ``find_schema``).
    """
    ds = load_from_disk(SRC_DIR)
    if isinstance(ds, DatasetDict):
        splits = list(ds.keys())
        if not splits:
            raise RuntimeError(f"No splits found in dataset at {SRC_DIR}")
        # Only one split is converted: "train" if present, else the first split.
        split = "train" if "train" in ds else splits[0]
        print("[INFO] Using split:", split)
        ignored = [s for s in splits if s != split]
        if ignored:
            print("[WARN] Ignoring other splits:", ignored)
        ds = ds[split]

    schema = find_schema(ds)
    print("[INFO] Detected schema:", schema)

    if schema == "chosen_rejected":
        sft, dpo = _build_from_chosen_rejected(ds)
    else:
        _, pref_col = schema
        sft, dpo = _build_from_pairwise(ds, pref_col)

    os.makedirs(SFT_OUT, exist_ok=True)
    os.makedirs(DPO_OUT, exist_ok=True)

    sft.save_to_disk(SFT_OUT)
    dpo.save_to_disk(DPO_OUT)

    print("[OK] Saved SFT dataset:", SFT_OUT)
    print("[OK] Saved DPO dataset:", DPO_OUT)
    print("[INFO] SFT rows:", len(sft))
    print("[INFO] DPO rows:", len(dpo))


if __name__ == "__main__":
    main()