#!/usr/bin/env python3
# -*- coding: utf-8 -*-
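"""Translate PKU-SafeRLHF-30K from English (eng_Latn) to Slovak (slk_Latn)
with a local NLLB-200 model, using two GPUs in parallel.

Each worker process pins one GPU, translates an interleaved shard of the
dataset, and saves it to disk; the main process then merges the shards back
into the original row order. Only the columns needed for SFT/DPO training
are translated.
"""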
import os
import re
import math
import argparse
import multiprocessing as mp
from typing import List, Dict, Any
import torch
from tqdm import tqdm
from datasets import load_dataset, Dataset, concatenate_datasets
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# =========================
# Fixed settings (machine-specific)
# =========================
DATASET_NAME = "PKU-Alignment/PKU-SafeRLHF-30K"
SPLIT = "train"
NLLB_PATH = "/home/hyrenko/Diploma/models/nllb-200-1.3B"
# Output (translated dataset)
OUT_FINAL_DIR = "/home/hyrenko/Diploma/datasets/PKU-SafeRLHF-30K_slk_Latn_SFT_DPO_ONLY"
SRC_LANG = "eng_Latn"
TGT_LANG = "slk_Latn"
# Speed defaults for 2x NVIDIA TITAN RTX (24 GB each)
MODEL_BATCH_SIZE = 32  # on OOM, step down: 32 -> 24 -> 16
MAP_BATCH_SIZE = 128  # CPU/RAM-side batch size for datasets.map
LONG_THRESHOLD_CHARS = 3500  # higher => fewer slow long-text fallbacks
NUM_BEAMS = 1  # greedy decoding; keep at 1 for speed
MAX_NEW_TOKENS = 128  # can be lowered to 96 for more speed

# =========================
# Utilities
# =========================
def normalize_text(x: Any) -> str:
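    """Coerce a value to str, drop U+FFFD replacement characters, and trim whitespace."""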
if x is None:
return ""
return str(x).replace("\uFFFD", "").strip()

def split_text_safely(text: str, max_chars: int) -> List[str]:
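    """Split text into chunks of at most max_chars, preferring paragraph
    boundaries (blank lines) and hard-wrapping oversized paragraphs."""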
text = normalize_text(text)
if not text:
return [""]
chunks: List[str] = []
def push(piece: str):
piece = piece.strip()
if not piece:
return
if len(piece) <= max_chars:
chunks.append(piece)
else:
for i in range(0, len(piece), max_chars):
part = piece[i:i + max_chars].strip()
if part:
chunks.append(part)
paras = re.split(r"\n{2,}", text)
for p in paras:
p = p.strip()
if p:
push(p)
return chunks if chunks else [""]

def detect_needed_cols_for_sft_dpo(ds: Dataset) -> List[str]:
"""
Auto-detect columns needed for SFT->DPO training.
Priority:
1) prompt + chosen + rejected (classic preference schema)
2) prompt + response_0 + response_1 (pairwise schema)
3) fallback: prompt + any reasonable response fields
"""
cols = set(ds.column_names)
# classic preference format
if {"prompt", "chosen", "rejected"}.issubset(cols):
return ["prompt", "chosen", "rejected"]
# common pairwise format
if {"prompt", "response_0", "response_1"}.issubset(cols):
return ["prompt", "response_0", "response_1"]
# fallback: translate prompt + any text response columns that exist
candidates = []
for c in ds.column_names:
if c == "prompt":
continue
if re.match(r"^(chosen|rejected|response_\d+|answer_\d+|completion_\d+)$", c):
candidates.append(c)
if "prompt" in cols and candidates:
return ["prompt"] + sorted(candidates)
    # last resort: translate only the prompt (still usable for SFT later)
if "prompt" in cols:
return ["prompt"]
raise RuntimeError("Could not detect a 'prompt' column; cannot proceed.")

@torch.inference_mode()
def nllb_translate_batch(tokenizer, model, texts: List[str]) -> List[str]:
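    """Translate a batch of texts with greedy decoding, forcing the
    target-language token (TGT_LANG) as the first generated token.
    Inputs are truncated to 1024 source tokens."""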
tokenizer.src_lang = SRC_LANG
forced_bos = tokenizer.convert_tokens_to_ids(TGT_LANG)
inputs = tokenizer(
texts,
return_tensors="pt",
padding=True,
truncation=True,
max_length=1024,
)
inputs = {k: v.to(model.device) for k, v in inputs.items()}
gen = model.generate(
**inputs,
forced_bos_token_id=forced_bos,
max_new_tokens=MAX_NEW_TOKENS,
num_beams=NUM_BEAMS,
do_sample=False,
use_cache=True,
)
out = tokenizer.batch_decode(gen, skip_special_tokens=True)
return [o.strip() for o in out]

def translate_long_text(tokenizer, model, text: str, sub_batch: int) -> str:
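    """Translate an over-long text by chunking it, translating the chunks
    in sub-batches, and rejoining the results with blank lines."""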
text = normalize_text(text)
if not text:
return ""
chunks = split_text_safely(text, max_chars=LONG_THRESHOLD_CHARS)
out_chunks: List[str] = []
for i in range(0, len(chunks), sub_batch):
out_chunks.extend(nllb_translate_batch(tokenizer, model, chunks[i:i + sub_batch]))
return "\n\n".join([c for c in out_chunks if c])

def _safe_mkdir(p: str):
os.makedirs(p, exist_ok=True)

# =========================
# Worker (one GPU)
# =========================
def worker_translate(
shard_id: int,
gpu_id: int,
total_shards: int,
tmp_dir: str,
cols_to_translate: List[str],
total_len: int,
):
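    """Translate every total_shards-th row of the dataset on one pinned GPU
    and save the translated shard to disk under tmp_dir."""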
os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
torch.set_grad_enabled(False)
torch.backends.cudnn.benchmark = True
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32
print(f"\n[WORKER {shard_id}] GPU={gpu_id} device={device} dtype={dtype}")
print(f"[WORKER {shard_id}] Translating columns: {cols_to_translate}")
ds = load_dataset(DATASET_NAME, split=SPLIT)
n = len(ds)
if n != total_len:
print(f"[WORKER {shard_id}] WARN: dataset length mismatch: got {n}, expected {total_len}")
# strict sharding by index => no missing rows
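    # e.g. with total_shards=2: shard 0 takes rows 0,2,4,... and shard 1 takes rows 1,3,5,...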
indices = [i for i in range(n) if (i % total_shards) == shard_id]
ds_shard = ds.select(indices)
shard_len = len(ds_shard)
tokenizer = AutoTokenizer.from_pretrained(NLLB_PATH, use_fast=True)
model = AutoModelForSeq2SeqLM.from_pretrained(
NLLB_PATH,
torch_dtype=dtype,
low_cpu_mem_usage=True,
).to(device)
model.eval()
total_map_batches = math.ceil(shard_len / MAP_BATCH_SIZE)

    def map_fn(batch: Dict[str, List[Any]]) -> Dict[str, List[str]]:
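        """Translate one map batch: short texts go through the fast batched
        path, over-long texts through the chunked fallback."""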
out: Dict[str, List[str]] = {}
for col in cols_to_translate:
vals = [normalize_text(v) for v in batch[col]]
short_texts: List[str] = []
short_pos: List[int] = []
long_texts: List[str] = []
long_pos: List[int] = []
for j, t in enumerate(vals):
if (not t) or (len(t) <= LONG_THRESHOLD_CHARS):
short_pos.append(j)
short_texts.append(t)
else:
long_pos.append(j)
long_texts.append(t)
translated = [""] * len(vals)
# fast batched translate
if short_texts:
tr_short: List[str] = []
for k in range(0, len(short_texts), MODEL_BATCH_SIZE):
tr_short.extend(nllb_translate_batch(tokenizer, model, short_texts[k:k + MODEL_BATCH_SIZE]))
for pos, tr in zip(short_pos, tr_short):
translated[pos] = tr
# slow fallback for very long texts
for pos, t in zip(long_pos, long_texts):
translated[pos] = translate_long_text(tokenizer, model, t, sub_batch=max(1, MODEL_BATCH_SIZE // 2))
out[col] = translated
return out

    # manual batching loop so tqdm shows a stable per-batch progress bar
parts: List[Dataset] = []
with tqdm(total=total_map_batches, desc=f"GPU{gpu_id} shard{shard_id}", ncols=100) as pbar:
for start in range(0, shard_len, MAP_BATCH_SIZE):
end = min(shard_len, start + MAP_BATCH_SIZE)
part = ds_shard.select(range(start, end)).map(
map_fn,
batched=True,
batch_size=end - start,
desc=None,
)
parts.append(part)
pbar.update(1)
ds_tr = concatenate_datasets(parts)
shard_path = os.path.join(tmp_dir, f"shard_{shard_id:02d}")
_safe_mkdir(shard_path)
ds_tr.save_to_disk(shard_path)
print(f"[WORKER {shard_id}] Saved -> {shard_path}")

# =========================
# Main (one-click)
# =========================
def main():
ap = argparse.ArgumentParser()
ap.add_argument("--resume", action="store_true", help="Skip existing shards and only merge.")
args = ap.parse_args()
if not os.path.isdir(NLLB_PATH):
raise SystemExit(f"[ERROR] NLLB model path not found: {NLLB_PATH}")
if torch.cuda.device_count() < 2:
raise SystemExit(f"[ERROR] Need 2 GPUs; found: {torch.cuda.device_count()}")
_safe_mkdir(OUT_FINAL_DIR)
tmp_dir = os.path.join(OUT_FINAL_DIR, "_tmp_shards")
_safe_mkdir(tmp_dir)
print("[INFO] Loading dataset metadata...")
ds = load_dataset(DATASET_NAME, split=SPLIT)
n = len(ds)
cols_to_translate = detect_needed_cols_for_sft_dpo(ds)
print(f"[INFO] Dataset: {DATASET_NAME} split={SPLIT}")
print(f"[INFO] Total rows: {n}")
print(f"[INFO] Translating ONLY needed cols for SFT->DPO: {cols_to_translate}")
print(f"[INFO] Output: {OUT_FINAL_DIR}")
print(f"[INFO] Params: MODEL_BATCH_SIZE={MODEL_BATCH_SIZE}, MAP_BATCH_SIZE={MAP_BATCH_SIZE}, "
f"LONG_THRESHOLD_CHARS={LONG_THRESHOLD_CHARS}, NUM_BEAMS={NUM_BEAMS}, MAX_NEW_TOKENS={MAX_NEW_TOKENS}")
gpus = [0, 1]
total_shards = 2
mp.set_start_method("spawn", force=True)
procs = []
for shard_id, gpu_id in enumerate(gpus):
shard_path = os.path.join(tmp_dir, f"shard_{shard_id:02d}")
if args.resume and os.path.isdir(shard_path):
print(f"[INFO] Resume: shard exists, skipping worker {shard_id} -> {shard_path}")
continue
p = mp.Process(
target=worker_translate,
args=(shard_id, gpu_id, total_shards, tmp_dir, cols_to_translate, n),
)
p.start()
procs.append(p)
for p in procs:
p.join()
if p.exitcode != 0:
raise SystemExit("[ERROR] A worker failed. See logs above.")
# merge
print("\n[INFO] Merging shards...")
shard_dirs = [os.path.join(tmp_dir, f"shard_{i:02d}") for i in range(total_shards)]
for sd in shard_dirs:
if not os.path.isdir(sd):
raise SystemExit(f"[ERROR] Missing shard directory: {sd}")
shards = [Dataset.load_from_disk(sd) for sd in shard_dirs]
merged = concatenate_datasets(shards)
# restore original order
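    # concatenation order is [shard 0 rows, shard 1 rows] = [0,2,4,...] + [1,3,5,...];
    # tag each row with its original index and sort to restore source order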
merged_indices = []
for shard_id in range(total_shards):
merged_indices.extend([i for i in range(n) if (i % total_shards) == shard_id])
merged = merged.add_column("orig_index", merged_indices)
merged = merged.sort("orig_index").remove_columns(["orig_index"])
if len(merged) != n:
raise SystemExit(f"[ERROR] Length mismatch after merge: merged={len(merged)} expected={n}")
print("[INFO] Saving final dataset...")
merged.save_to_disk(OUT_FINAL_DIR)
print(f"[OK] Done. Saved FULL translated dataset to: {OUT_FINAL_DIR}")
print("[OK] Verified: no missing rows.")
print(f"[OK] Translated columns: {cols_to_translate}")

if __name__ == "__main__":
main()
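
# Example invocations (a sketch; assumes the hard-coded paths above exist
# and that at least two CUDA GPUs are visible):
#   python translate_PKF.py            # full run: translate both shards, then merge
#   python translate_PKF.py --resume   # skip shards already on disk, then merge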