#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
import re
import math
import argparse
import multiprocessing as mp
from typing import List, Dict, Any

import torch
from tqdm import tqdm
from datasets import load_dataset, Dataset, concatenate_datasets
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# =========================
# Fixed settings (your setup)
# =========================
DATASET_NAME = "PKU-Alignment/PKU-SafeRLHF-30K"
SPLIT = "train"
NLLB_PATH = "/home/hyrenko/Diploma/models/nllb-200-1.3B"

# Output (translated dataset)
OUT_FINAL_DIR = "/home/hyrenko/Diploma/datasets/PKU-SafeRLHF-30K_slk_Latn_SFT_DPO_ONLY"

SRC_LANG = "eng_Latn"
TGT_LANG = "slk_Latn"

# Speed defaults for 2x RTX Titan 24GB
MODEL_BATCH_SIZE = 32        # on OOM, reduce to 24 or 16
MAP_BATCH_SIZE = 128         # CPU/RAM-side batch size
LONG_THRESHOLD_CHARS = 3500  # higher => fewer slow long-text fallbacks
NUM_BEAMS = 1                # keep at 1 (greedy decoding) for speed
MAX_NEW_TOKENS = 128         # can lower to 96 for more speed


# =========================
# Utilities
# =========================
def normalize_text(x: Any) -> str:
    """Coerce any value to a clean string; drop replacement chars and outer whitespace."""
    if x is None:
        return ""
    return str(x).replace("\uFFFD", "").strip()


def split_text_safely(text: str, max_chars: int) -> List[str]:
    """Split text into chunks of at most max_chars, preferring paragraph boundaries."""
    text = normalize_text(text)
    if not text:
        return [""]

    chunks: List[str] = []

    def push(piece: str):
        piece = piece.strip()
        if not piece:
            return
        if len(piece) <= max_chars:
            chunks.append(piece)
        else:
            # Hard-split oversized paragraphs at max_chars boundaries.
            for i in range(0, len(piece), max_chars):
                part = piece[i:i + max_chars].strip()
                if part:
                    chunks.append(part)

    paras = re.split(r"\n{2,}", text)
    for p in paras:
        p = p.strip()
        if p:
            push(p)

    return chunks if chunks else [""]


def detect_needed_cols_for_sft_dpo(ds: Dataset) -> List[str]:
    """
    Auto-detect the columns needed for SFT -> DPO training.

    Priority:
      1) prompt + chosen + rejected (classic preference schema)
      2) prompt + response_0 + response_1 (pairwise schema)
      3) fallback: prompt + any reasonable response fields
    """
    cols = set(ds.column_names)

    # classic preference format
    if {"prompt", "chosen", "rejected"}.issubset(cols):
        return ["prompt", "chosen", "rejected"]

    # common pairwise format
    if {"prompt", "response_0", "response_1"}.issubset(cols):
        return ["prompt", "response_0", "response_1"]

    # fallback: translate prompt + any text response columns that exist
    candidates = []
    for c in ds.column_names:
        if c == "prompt":
            continue
        if re.match(r"^(chosen|rejected|response_\d+|answer_\d+|completion_\d+)$", c):
            candidates.append(c)
    if "prompt" in cols and candidates:
        return ["prompt"] + sorted(candidates)

    # last resort: translate only the prompt (still useful later)
    if "prompt" in cols:
        return ["prompt"]

    raise RuntimeError("Could not detect a 'prompt' column; cannot proceed.")


@torch.inference_mode()
def nllb_translate_batch(tokenizer, model, texts: List[str]) -> List[str]:
    """Translate a batch of texts SRC_LANG -> TGT_LANG with greedy decoding."""
    tokenizer.src_lang = SRC_LANG
    forced_bos = tokenizer.convert_tokens_to_ids(TGT_LANG)

    inputs = tokenizer(
        texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=1024,
    )
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    gen = model.generate(
        **inputs,
        forced_bos_token_id=forced_bos,
        max_new_tokens=MAX_NEW_TOKENS,
        num_beams=NUM_BEAMS,
        do_sample=False,
        use_cache=True,
    )
    out = tokenizer.batch_decode(gen, skip_special_tokens=True)
    return [o.strip() for o in out]


def translate_long_text(tokenizer, model, text: str, sub_batch: int) -> str:
    """Chunk an over-long text, translate the chunks, and rejoin with blank lines."""
    text = normalize_text(text)
    if not text:
        return ""
    chunks = split_text_safely(text, max_chars=LONG_THRESHOLD_CHARS)
    out_chunks: List[str] = []
    for i in range(0, len(chunks), sub_batch):
        out_chunks.extend(nllb_translate_batch(tokenizer, model, chunks[i:i + sub_batch]))
    return "\n\n".join([c for c in out_chunks if c])
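
# Illustration (not executed): how the helpers above compose. The texts, the
# "cuda" device, and very_long_text are placeholders; outputs depend on the model.
#
#   tok = AutoTokenizer.from_pretrained(NLLB_PATH, use_fast=True)
#   mdl = AutoModelForSeq2SeqLM.from_pretrained(NLLB_PATH, torch_dtype=torch.float16).to("cuda")
#   nllb_translate_batch(tok, mdl, ["Hello.", "How are you?"])   # -> 2 Slovak strings
#   translate_long_text(tok, mdl, very_long_text, sub_batch=16)  # chunk, translate, rejoin
#   split_text_safely("para1\n\npara2", max_chars=3500)          # -> ["para1", "para2"]
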
def _safe_mkdir(p: str):
    os.makedirs(p, exist_ok=True)


# =========================
# Worker (one GPU)
# =========================
def worker_translate(
    shard_id: int,
    gpu_id: int,
    total_shards: int,
    tmp_dir: str,
    cols_to_translate: List[str],
    total_len: int,
):
    """Translate every total_shards-th row (offset shard_id) on one GPU; save the shard."""
    # Pin this process to one GPU before CUDA is initialized (spawn start method).
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
    torch.set_grad_enabled(False)
    torch.backends.cudnn.benchmark = True

    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if device == "cuda" else torch.float32

    print(f"\n[WORKER {shard_id}] GPU={gpu_id} device={device} dtype={dtype}")
    print(f"[WORKER {shard_id}] Translating columns: {cols_to_translate}")

    ds = load_dataset(DATASET_NAME, split=SPLIT)
    n = len(ds)
    if n != total_len:
        print(f"[WORKER {shard_id}] WARN: dataset length mismatch: got {n}, expected {total_len}")

    # strict sharding by index => no missing rows
    indices = [i for i in range(n) if (i % total_shards) == shard_id]
    ds_shard = ds.select(indices)
    shard_len = len(ds_shard)

    tokenizer = AutoTokenizer.from_pretrained(NLLB_PATH, use_fast=True)
    model = AutoModelForSeq2SeqLM.from_pretrained(
        NLLB_PATH,
        torch_dtype=dtype,
        low_cpu_mem_usage=True,
    ).to(device)
    model.eval()

    total_map_batches = math.ceil(shard_len / MAP_BATCH_SIZE)

    def map_fn(batch: Dict[str, List[Any]]) -> Dict[str, List[str]]:
        out: Dict[str, List[str]] = {}
        for col in cols_to_translate:
            vals = [normalize_text(v) for v in batch[col]]

            # Partition into short texts (fast batched path) and long texts (slow fallback).
            short_texts: List[str] = []
            short_pos: List[int] = []
            long_texts: List[str] = []
            long_pos: List[int] = []
            for j, t in enumerate(vals):
                if (not t) or (len(t) <= LONG_THRESHOLD_CHARS):
                    short_pos.append(j)
                    short_texts.append(t)
                else:
                    long_pos.append(j)
                    long_texts.append(t)

            translated = [""] * len(vals)

            # fast batched translate
            if short_texts:
                tr_short: List[str] = []
                for k in range(0, len(short_texts), MODEL_BATCH_SIZE):
                    tr_short.extend(nllb_translate_batch(
                        tokenizer, model, short_texts[k:k + MODEL_BATCH_SIZE]))
                for pos, tr in zip(short_pos, tr_short):
                    translated[pos] = tr

            # slow fallback for very long texts
            for pos, t in zip(long_pos, long_texts):
                translated[pos] = translate_long_text(
                    tokenizer, model, t, sub_batch=max(1, MODEL_BATCH_SIZE // 2))

            out[col] = translated
        return out

    # manual loop so the progress bar stays stable
    parts: List[Dataset] = []
    with tqdm(total=total_map_batches, desc=f"GPU{gpu_id} shard{shard_id}", ncols=100) as pbar:
        for start in range(0, shard_len, MAP_BATCH_SIZE):
            end = min(shard_len, start + MAP_BATCH_SIZE)
            part = ds_shard.select(range(start, end)).map(
                map_fn,
                batched=True,
                batch_size=end - start,
                desc=None,
            )
            parts.append(part)
            pbar.update(1)

    ds_tr = concatenate_datasets(parts)
    shard_path = os.path.join(tmp_dir, f"shard_{shard_id:02d}")
    _safe_mkdir(shard_path)
    ds_tr.save_to_disk(shard_path)
    print(f"[WORKER {shard_id}] Saved -> {shard_path}")
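
# Sanity check for a finished shard (illustrative; run separately in a Python
# shell, the shard path shown is an example):
#
#   from datasets import Dataset
#   shard = Dataset.load_from_disk(OUT_FINAL_DIR + "/_tmp_shards/shard_00")
#   print(len(shard), shard.column_names)
#   print(shard[0]["prompt"][:200])
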
# =========================
# Main (one-click)
# =========================
def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--resume", action="store_true",
                    help="Skip shards that already exist on disk and only merge.")
    args = ap.parse_args()

    if not os.path.isdir(NLLB_PATH):
        raise SystemExit(f"[ERROR] NLLB model path not found: {NLLB_PATH}")
    if torch.cuda.device_count() < 2:
        raise SystemExit(f"[ERROR] Need 2 GPUs; found: {torch.cuda.device_count()}")

    _safe_mkdir(OUT_FINAL_DIR)
    tmp_dir = os.path.join(OUT_FINAL_DIR, "_tmp_shards")
    _safe_mkdir(tmp_dir)

    print("[INFO] Loading dataset metadata...")
    ds = load_dataset(DATASET_NAME, split=SPLIT)
    n = len(ds)
    cols_to_translate = detect_needed_cols_for_sft_dpo(ds)

    print(f"[INFO] Dataset: {DATASET_NAME} split={SPLIT}")
    print(f"[INFO] Total rows: {n}")
    print(f"[INFO] Translating ONLY the columns needed for SFT->DPO: {cols_to_translate}")
    print(f"[INFO] Output: {OUT_FINAL_DIR}")
    print(f"[INFO] Params: MODEL_BATCH_SIZE={MODEL_BATCH_SIZE}, MAP_BATCH_SIZE={MAP_BATCH_SIZE}, "
          f"LONG_THRESHOLD_CHARS={LONG_THRESHOLD_CHARS}, NUM_BEAMS={NUM_BEAMS}, "
          f"MAX_NEW_TOKENS={MAX_NEW_TOKENS}")

    gpus = [0, 1]
    total_shards = 2

    mp.set_start_method("spawn", force=True)
    procs = []
    for shard_id, gpu_id in enumerate(gpus):
        shard_path = os.path.join(tmp_dir, f"shard_{shard_id:02d}")
        if args.resume and os.path.isdir(shard_path):
            print(f"[INFO] Resume: shard exists, skipping worker {shard_id} -> {shard_path}")
            continue
        p = mp.Process(
            target=worker_translate,
            args=(shard_id, gpu_id, total_shards, tmp_dir, cols_to_translate, n),
        )
        p.start()
        procs.append(p)

    for p in procs:
        p.join()
        if p.exitcode != 0:
            raise SystemExit("[ERROR] A worker failed. See logs above.")

    # merge
    print("\n[INFO] Merging shards...")
    shard_dirs = [os.path.join(tmp_dir, f"shard_{i:02d}") for i in range(total_shards)]
    for sd in shard_dirs:
        if not os.path.isdir(sd):
            raise SystemExit(f"[ERROR] Missing shard directory: {sd}")

    shards = [Dataset.load_from_disk(sd) for sd in shard_dirs]
    merged = concatenate_datasets(shards)

    # restore original row order: shard k holds rows with index % total_shards == k
    # in ascending order, so rebuilding that index list and sorting undoes the sharding
    merged_indices = []
    for shard_id in range(total_shards):
        merged_indices.extend([i for i in range(n) if (i % total_shards) == shard_id])
    merged = merged.add_column("orig_index", merged_indices)
    merged = merged.sort("orig_index").remove_columns(["orig_index"])

    if len(merged) != n:
        raise SystemExit(f"[ERROR] Length mismatch after merge: merged={len(merged)} expected={n}")

    print("[INFO] Saving final dataset...")
    merged.save_to_disk(OUT_FINAL_DIR)
    print(f"[OK] Done. Saved FULL translated dataset to: {OUT_FINAL_DIR}")
    print("[OK] Verified: no missing rows.")
    print(f"[OK] Translated columns: {cols_to_translate}")


if __name__ == "__main__":
    main()
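
# Usage (the script filename below is illustrative):
#   python translate_pku_to_slk.py            # run both GPU workers, then merge
#   python translate_pku_to_slk.py --resume   # keep finished shards, run the rest, merge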