#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import os
import re
import math
import argparse
import multiprocessing as mp
from typing import List, Dict, Any

import torch
from tqdm import tqdm
from datasets import load_dataset, Dataset, concatenate_datasets
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM


# =========================
# Fixed settings (your setup)
# =========================
DATASET_NAME = "PKU-Alignment/PKU-SafeRLHF-30K"
SPLIT = "train"

NLLB_PATH = "/home/hyrenko/Diploma/models/nllb-200-1.3B"

# Output (translated dataset)
OUT_FINAL_DIR = "/home/hyrenko/Diploma/datasets/PKU-SafeRLHF-30K_slk_Latn_SFT_DPO_ONLY"

SRC_LANG = "eng_Latn"
TGT_LANG = "slk_Latn"

# Speed defaults for 2x RTX Titan 24GB
MODEL_BATCH_SIZE = 32         # if OOM, step down: 24 -> 16
MAP_BATCH_SIZE = 128          # CPU/RAM side
LONG_THRESHOLD_CHARS = 3500   # higher => fewer slow fallbacks
NUM_BEAMS = 1                 # must stay 1 for speed (greedy decoding)
MAX_NEW_TOKENS = 128          # can lower to 96 for speed
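
# NOTE: MAP_BATCH_SIZE rows are pulled per .map() call and then split into
# MODEL_BATCH_SIZE-sized sub-batches for model.generate() (see map_fn below),
# so GPU memory use is governed by MODEL_BATCH_SIZE, not MAP_BATCH_SIZE.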


# =========================
# Utilities
# =========================
def normalize_text(x: Any) -> str:
    """Coerce a value to str, drop U+FFFD replacement characters, and trim."""
    if x is None:
        return ""
    return str(x).replace("\uFFFD", "").strip()


def split_text_safely(text: str, max_chars: int) -> List[str]:
    """Split text into chunks of at most max_chars: first on blank-line
    paragraph boundaries, then by hard-slicing any oversized paragraph."""
    text = normalize_text(text)
    if not text:
        return [""]

    chunks: List[str] = []

    def push(piece: str):
        piece = piece.strip()
        if not piece:
            return
        if len(piece) <= max_chars:
            chunks.append(piece)
        else:
            for i in range(0, len(piece), max_chars):
                part = piece[i:i + max_chars].strip()
                if part:
                    chunks.append(part)

    paras = re.split(r"\n{2,}", text)
    for p in paras:
        p = p.strip()
        if p:
            push(p)

    return chunks if chunks else [""]
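
# Illustrative example (not executed): with max_chars=8,
#   split_text_safely("first para\n\nsecond", 8)
# returns ["first pa", "ra", "second"] -- paragraphs first, hard slices second.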


def detect_needed_cols_for_sft_dpo(ds: Dataset) -> List[str]:
    """
    Auto-detect columns needed for SFT->DPO training.
    Priority:
      1) prompt + chosen + rejected (classic preference schema)
      2) prompt + response_0 + response_1 (pairwise schema)
      3) fallback: prompt + any reasonable response fields
    """
    cols = set(ds.column_names)

    # classic preference format
    if {"prompt", "chosen", "rejected"}.issubset(cols):
        return ["prompt", "chosen", "rejected"]

    # common pairwise format
    if {"prompt", "response_0", "response_1"}.issubset(cols):
        return ["prompt", "response_0", "response_1"]

    # fallback: translate prompt + any text response columns that exist
    candidates = []
    for c in ds.column_names:
        if c == "prompt":
            continue
        if re.match(r"^(chosen|rejected|response_\d+|answer_\d+|completion_\d+)$", c):
            candidates.append(c)

    if "prompt" in cols and candidates:
        return ["prompt"] + sorted(candidates)

    # last resort: translate only prompt (still useful for later)
    if "prompt" in cols:
        return ["prompt"]

    raise RuntimeError("Could not detect a 'prompt' column; cannot proceed.")
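
# On PKU-SafeRLHF-30K this is expected to hit case 2 and return
# ["prompt", "response_0", "response_1"]; label columns (safety flags,
# better_response_id, etc.) pass through untranslated.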


@torch.inference_mode()
def nllb_translate_batch(tokenizer, model, texts: List[str]) -> List[str]:
    tokenizer.src_lang = SRC_LANG
    # NLLB selects the target language by forcing its language token as the
    # first generated (BOS) token.
    forced_bos = tokenizer.convert_tokens_to_ids(TGT_LANG)

    inputs = tokenizer(
        texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=1024,
    )
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    gen = model.generate(
        **inputs,
        forced_bos_token_id=forced_bos,
        max_new_tokens=MAX_NEW_TOKENS,
        num_beams=NUM_BEAMS,
        do_sample=False,
        use_cache=True,
    )
    out = tokenizer.batch_decode(gen, skip_special_tokens=True)
    return [o.strip() for o in out]
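
# Quick sanity check (illustrative only; assumes a tokenizer/model loaded as in
# worker_translate, and the exact output depends on the checkpoint):
#   nllb_translate_batch(tokenizer, model, ["Hello, world!"])  # ~["Ahoj, svet!"]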


def translate_long_text(tokenizer, model, text: str, sub_batch: int) -> str:
    text = normalize_text(text)
    if not text:
        return ""
    chunks = split_text_safely(text, max_chars=LONG_THRESHOLD_CHARS)
    out_chunks: List[str] = []
    for i in range(0, len(chunks), sub_batch):
        out_chunks.extend(nllb_translate_batch(tokenizer, model, chunks[i:i + sub_batch]))
    return "\n\n".join([c for c in out_chunks if c])


def _safe_mkdir(p: str):
    os.makedirs(p, exist_ok=True)
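
# Long inputs are translated chunk-by-chunk and re-joined with blank lines, so
# paragraph structure survives but exact whitespace does not.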


# =========================
# Worker (one GPU)
# =========================
def worker_translate(
    shard_id: int,
    gpu_id: int,
    total_shards: int,
    tmp_dir: str,
    cols_to_translate: List[str],
    total_len: int,
):
    # Pin this spawned process to one GPU; must be set before the first CUDA
    # call in the child process.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)

    torch.set_grad_enabled(False)
    torch.backends.cudnn.benchmark = True

    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if device == "cuda" else torch.float32

    print(f"\n[WORKER {shard_id}] GPU={gpu_id} device={device} dtype={dtype}")
    print(f"[WORKER {shard_id}] Translating columns: {cols_to_translate}")

    ds = load_dataset(DATASET_NAME, split=SPLIT)
    n = len(ds)
    if n != total_len:
        print(f"[WORKER {shard_id}] WARN: dataset length mismatch: got {n}, expected {total_len}")

    # strict sharding by index => no missing rows
    # (with 2 shards: shard 0 takes rows 0, 2, 4, ...; shard 1 takes 1, 3, 5, ...)
    indices = [i for i in range(n) if (i % total_shards) == shard_id]
    ds_shard = ds.select(indices)
    shard_len = len(ds_shard)

    tokenizer = AutoTokenizer.from_pretrained(NLLB_PATH, use_fast=True)
    model = AutoModelForSeq2SeqLM.from_pretrained(
        NLLB_PATH,
        torch_dtype=dtype,
        low_cpu_mem_usage=True,
    ).to(device)
    model.eval()
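
    # Each worker holds its own model copy: at fp16 the 1.3B-parameter NLLB
    # checkpoint is roughly 2.6 GB of weights, leaving headroom on a 24 GB
    # card for activations at MODEL_BATCH_SIZE=32.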

    total_map_batches = math.ceil(shard_len / MAP_BATCH_SIZE)

    def map_fn(batch: Dict[str, List[Any]]) -> Dict[str, List[str]]:
        out: Dict[str, List[str]] = {}

        for col in cols_to_translate:
            vals = [normalize_text(v) for v in batch[col]]

            short_texts: List[str] = []
            short_pos: List[int] = []
            long_texts: List[str] = []
            long_pos: List[int] = []

            for j, t in enumerate(vals):
                if (not t) or (len(t) <= LONG_THRESHOLD_CHARS):
                    short_pos.append(j)
                    short_texts.append(t)
                else:
                    long_pos.append(j)
                    long_texts.append(t)

            translated = [""] * len(vals)

            # fast batched translate
            if short_texts:
                tr_short: List[str] = []
                for k in range(0, len(short_texts), MODEL_BATCH_SIZE):
                    tr_short.extend(nllb_translate_batch(tokenizer, model, short_texts[k:k + MODEL_BATCH_SIZE]))
                for pos, tr in zip(short_pos, tr_short):
                    translated[pos] = tr

            # slow fallback for very long texts
            for pos, t in zip(long_pos, long_texts):
                translated[pos] = translate_long_text(tokenizer, model, t, sub_batch=max(1, MODEL_BATCH_SIZE // 2))

            out[col] = translated

        return out

    # manual loop to show a stable progress bar
    parts: List[Dataset] = []
    with tqdm(total=total_map_batches, desc=f"GPU{gpu_id} shard{shard_id}", ncols=100) as pbar:
        for start in range(0, shard_len, MAP_BATCH_SIZE):
            end = min(shard_len, start + MAP_BATCH_SIZE)
            part = ds_shard.select(range(start, end)).map(
                map_fn,
                batched=True,
                batch_size=end - start,
                desc=None,
            )
            parts.append(part)
            pbar.update(1)

    ds_tr = concatenate_datasets(parts)

    shard_path = os.path.join(tmp_dir, f"shard_{shard_id:02d}")
    _safe_mkdir(shard_path)
    ds_tr.save_to_disk(shard_path)
    print(f"[WORKER {shard_id}] Saved -> {shard_path}")
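
# NOTE: map_fn returns only the translated columns; datasets.map() overwrites
# those in place and carries every other column through unchanged, so each
# saved shard keeps the full original schema.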


# =========================
# Main (one-click)
# =========================
def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--resume", action="store_true", help="Skip existing shards and only merge.")
    args = ap.parse_args()

    if not os.path.isdir(NLLB_PATH):
        raise SystemExit(f"[ERROR] NLLB model path not found: {NLLB_PATH}")

    if torch.cuda.device_count() < 2:
        raise SystemExit(f"[ERROR] Need 2 GPUs; found: {torch.cuda.device_count()}")

    _safe_mkdir(OUT_FINAL_DIR)
    tmp_dir = os.path.join(OUT_FINAL_DIR, "_tmp_shards")
    _safe_mkdir(tmp_dir)

    print("[INFO] Loading dataset metadata...")
    ds = load_dataset(DATASET_NAME, split=SPLIT)
    n = len(ds)

    cols_to_translate = detect_needed_cols_for_sft_dpo(ds)
    print(f"[INFO] Dataset: {DATASET_NAME} split={SPLIT}")
    print(f"[INFO] Total rows: {n}")
    print(f"[INFO] Translating ONLY needed cols for SFT->DPO: {cols_to_translate}")
    print(f"[INFO] Output: {OUT_FINAL_DIR}")
    print(f"[INFO] Params: MODEL_BATCH_SIZE={MODEL_BATCH_SIZE}, MAP_BATCH_SIZE={MAP_BATCH_SIZE}, "
          f"LONG_THRESHOLD_CHARS={LONG_THRESHOLD_CHARS}, NUM_BEAMS={NUM_BEAMS}, MAX_NEW_TOKENS={MAX_NEW_TOKENS}")

    # one shard per GPU (hard-coded for this 2-GPU box)
    gpus = [0, 1]
    total_shards = 2

    mp.set_start_method("spawn", force=True)
    procs = []

    for shard_id, gpu_id in enumerate(gpus):
        shard_path = os.path.join(tmp_dir, f"shard_{shard_id:02d}")
        if args.resume and os.path.isdir(shard_path):
            print(f"[INFO] Resume: shard exists, skipping worker {shard_id} -> {shard_path}")
            continue

        p = mp.Process(
            target=worker_translate,
            args=(shard_id, gpu_id, total_shards, tmp_dir, cols_to_translate, n),
        )
        p.start()
        procs.append(p)

    for p in procs:
        p.join()
        if p.exitcode != 0:
            raise SystemExit("[ERROR] A worker failed. See logs above.")

    # merge
    print("\n[INFO] Merging shards...")
    shard_dirs = [os.path.join(tmp_dir, f"shard_{i:02d}") for i in range(total_shards)]
    for sd in shard_dirs:
        if not os.path.isdir(sd):
            raise SystemExit(f"[ERROR] Missing shard directory: {sd}")

    shards = [Dataset.load_from_disk(sd) for sd in shard_dirs]
    merged = concatenate_datasets(shards)

    # restore original order
    merged_indices = []
    for shard_id in range(total_shards):
        merged_indices.extend([i for i in range(n) if (i % total_shards) == shard_id])

    merged = merged.add_column("orig_index", merged_indices)
    merged = merged.sort("orig_index").remove_columns(["orig_index"])
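
    # Worked example (n=5, 2 shards): concatenation yields original rows
    # [0, 2, 4, 1, 3] and merged_indices == [0, 2, 4, 1, 3], so sorting by
    # "orig_index" restores [0, 1, 2, 3, 4].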

    if len(merged) != n:
        raise SystemExit(f"[ERROR] Length mismatch after merge: merged={len(merged)} expected={n}")

    print("[INFO] Saving final dataset...")
    merged.save_to_disk(OUT_FINAL_DIR)

    print(f"[OK] Done. Saved FULL translated dataset to: {OUT_FINAL_DIR}")
    print("[OK] Verified: no missing rows.")
    print(f"[OK] Translated columns: {cols_to_translate}")


if __name__ == "__main__":
    main()
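
# Usage (sketch; "translate_dataset.py" is a placeholder for this file's name):
#   python translate_dataset.py            # full two-GPU run
#   python translate_dataset.py --resume   # keep finished shards, run the rest, merge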