diff --git a/translate/translate_PKF.py b/translate/translate_PKF.py
index 1565013..729aeb8 100644
--- a/translate/translate_PKF.py
+++ b/translate/translate_PKF.py
@@ -15,20 +15,27 @@ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 
 # =========================
-# Fixed settings (your setup)
-# =========================
+# Dataset source on the Hugging Face Hub and which split to translate.
 DATASET_NAME = "PKU-Alignment/PKU-SafeRLHF-30K"
 SPLIT = "train"
 
+# Local path to the NLLB model (loaded from disk, not from the Hub).
 NLLB_PATH = "/home/hyrenko/Diploma/models/nllb-200-1.3B"
 
-# Output (translated dataset)
+# Output directory where the fully translated dataset will be saved.
 OUT_FINAL_DIR = "/home/hyrenko/Diploma/datasets/PKU-SafeRLHF-30K_slk_Latn_SFT_DPO_ONLY"
 
+# Language codes for NLLB (source and target).
 SRC_LANG = "eng_Latn"
 TGT_LANG = "slk_Latn"
 
-# Speed defaults for 2x RTX Titan 24GB
+# Performance-oriented defaults for a 2x RTX Titan 24GB setup.
+# Notes:
+# - MODEL_BATCH_SIZE controls GPU-side translation batching.
+# - MAP_BATCH_SIZE controls CPU/RAM-side batching when mapping over a dataset shard.
+# - LONG_THRESHOLD_CHARS: texts longer than this are split to avoid slow/unstable generation.
+# - NUM_BEAMS is kept at 1 (greedy decoding) for speed; beam search is expensive.
+# - MAX_NEW_TOKENS caps translation length; reducing it can speed up generation.
 MODEL_BATCH_SIZE = 32        # if OOM: 24 -> 16
 MAP_BATCH_SIZE = 128         # CPU/RAM side
 LONG_THRESHOLD_CHARS = 3500  # higher => fewer slow fallbacks
@@ -39,12 +46,39 @@ MAX_NEW_TOKENS = 128  # can lower to 96 for speed
 # =========================
 # Utilities
 # =========================
+
+# -----------------------------------------------------------------------------
+# normalize_text(x)
+# Purpose:
+#   Converts an arbitrary dataset value into a clean string:
+#   - Handles None values safely.
+#   - Removes the Unicode replacement char (�, U+FFFD), which can appear in noisy text.
+#   - Strips surrounding whitespace.
+#
+# Why it exists:
+#   Dataset fields may contain None, non-string types, or corrupted characters.
+#   Translation should always receive a valid string input.
+# -----------------------------------------------------------------------------
 def normalize_text(x: Any) -> str:
     if x is None:
         return ""
     return str(x).replace("\uFFFD", "").strip()
 
+
+# -----------------------------------------------------------------------------
+# split_text_safely(text, max_chars)
+# Purpose:
+#   Splits long text into smaller chunks to avoid slow generation and memory issues.
+#
+# Strategy:
+#   - Normalizes the input text.
+#   - Splits on paragraph boundaries (2+ newlines).
+#   - Ensures each chunk is <= max_chars.
+#   - If a paragraph itself is too long, it is cut into fixed-size slices.
+#
+# Output:
+#   A list of non-empty chunks; returns [""] for empty input as a safe placeholder.
+# -----------------------------------------------------------------------------
 def split_text_safely(text: str, max_chars: int) -> List[str]:
     text = normalize_text(text)
     if not text:
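
The chunking contract documented above is compact enough to illustrate on its own. Below is a minimal standalone sketch of the described behavior; the name split_sketch and the \n{2,} paragraph regex are illustrative, and the patch's actual implementation continues in the next hunk:

```python
import re
from typing import List

def split_sketch(text: str, max_chars: int) -> List[str]:
    # Paragraph-aware chunking: keep paragraphs whole when they fit,
    # hard-slice any single paragraph longer than max_chars.
    chunks: List[str] = []
    for p in re.split(r"\n{2,}", text.strip()):
        p = p.strip()
        if not p:
            continue
        if len(p) <= max_chars:
            chunks.append(p)
        else:
            chunks.extend(p[i:i + max_chars] for i in range(0, len(p), max_chars))
    return chunks if chunks else [""]

# split_sketch("first paragraph\n\nsecond paragraph", max_chars=100)
# -> ["first paragraph", "second paragraph"]
```
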
@@ -52,6 +86,11 @@ def split_text_safely(text: str, max_chars: int) -> List[str]:
 
     chunks: List[str] = []
 
+    # push(piece)
+    # Internal helper that:
+    #  - ignores empty pieces
+    #  - keeps pieces <= max_chars
+    #  - slices very long pieces into fixed-size parts
     def push(piece: str):
         piece = piece.strip()
         if not piece:
@@ -73,6 +112,21 @@ def split_text_safely(text: str, max_chars: int) -> List[str]:
     return chunks if chunks else [""]
 
+
+# -----------------------------------------------------------------------------
+# detect_needed_cols_for_sft_dpo(ds)
+# Purpose:
+#   Auto-detects which columns should be translated so the output dataset is
+#   immediately usable for SFT and/or DPO preparation.
+#
+# Priority:
+#   1) Classic preference schema: prompt + chosen + rejected
+#   2) Pairwise schema: prompt + response_0 + response_1
+#   3) Fallback: prompt + any response-like columns matching a regex pattern
+#   4) Last resort: prompt only
+#
+# Result:
+#   A list of columns to translate, minimizing work and storage.
+# -----------------------------------------------------------------------------
 def detect_needed_cols_for_sft_dpo(ds: Dataset) -> List[str]:
     """
     Auto-detect columns needed for SFT->DPO training.
@@ -109,6 +163,21 @@ def detect_needed_cols_for_sft_dpo(ds: Dataset) -> List[str]:
     raise RuntimeError("Could not detect a 'prompt' column; cannot proceed.")
 
+
+# -----------------------------------------------------------------------------
+# nllb_translate_batch(tokenizer, model, texts)
+# Purpose:
+#   Performs batched translation of a list of short texts via NLLB.
+#
+# Details:
+#   - Sets tokenizer.src_lang each time to ensure the correct source language.
+#   - Uses forced_bos_token_id to force decoding in the target language.
+#   - Tokenizes to max_length=1024 (encoder side), padding and truncating as needed.
+#   - Calls model.generate() with max_new_tokens, num_beams, and caching for speed.
+#   - Returns stripped decoded strings.
+#
+# Important:
+#   - @torch.inference_mode() disables gradient tracking for speed and lower memory use.
+# -----------------------------------------------------------------------------
 @torch.inference_mode()
 def nllb_translate_batch(tokenizer, model, texts: List[str]) -> List[str]:
     tokenizer.src_lang = SRC_LANG
@@ -135,6 +204,17 @@ def nllb_translate_batch(tokenizer, model, texts: List[str]) -> List[str]:
     return [o.strip() for o in out]
 
+
+# -----------------------------------------------------------------------------
+# translate_long_text(tokenizer, model, text, sub_batch)
+# Purpose:
+#   Safe fallback path for very long texts that may be slow or unstable if
+#   translated in one shot.
+#
+# Strategy:
+#   - Split the text into paragraph-aware chunks of <= LONG_THRESHOLD_CHARS characters.
+#   - Translate the chunks in sub-batches to control VRAM use.
+#   - Join the translated chunks with blank lines to preserve paragraph separation.
+# -----------------------------------------------------------------------------
 def translate_long_text(tokenizer, model, text: str, sub_batch: int) -> str:
     text = normalize_text(text)
     if not text:
@@ -146,6 +226,11 @@ def translate_long_text(tokenizer, model, text: str, sub_batch: int) -> str:
     return "\n\n".join([c for c in out_chunks if c])
 
+
+# -----------------------------------------------------------------------------
+# _safe_mkdir(p)
+# Purpose:
+#   Convenience helper to create directories idempotently.
+# -----------------------------------------------------------------------------
 def _safe_mkdir(p: str):
     os.makedirs(p, exist_ok=True)
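
The forced-BOS mechanics documented for nllb_translate_batch() reduce to a short call pattern. A minimal single-text sketch, assuming a recent transformers release where the NLLB language-token id is looked up with convert_tokens_to_ids (older releases exposed tokenizer.lang_code_to_id instead); translate_one is an illustrative name, not part of the patch:

```python
import torch

@torch.inference_mode()
def translate_one(tokenizer, model, text: str) -> str:
    # NLLB decodes into whatever language the first decoder token announces,
    # so the target language is selected via forced_bos_token_id.
    tokenizer.src_lang = "eng_Latn"
    enc = tokenizer(text, return_tensors="pt", truncation=True, max_length=1024).to(model.device)
    out = model.generate(
        **enc,
        forced_bos_token_id=tokenizer.convert_tokens_to_ids("slk_Latn"),
        max_new_tokens=128,
        num_beams=1,
    )
    return tokenizer.batch_decode(out, skip_special_tokens=True)[0].strip()
```
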
@@ -153,6 +238,23 @@ def _safe_mkdir(p: str):
 
 # =========================
 # Worker (one GPU)
 # =========================
+
+# -----------------------------------------------------------------------------
+# worker_translate(...)
+# Purpose:
+#   Runs translation on one shard of the dataset using a specific GPU.
+#
+# Key ideas:
+#   - Uses CUDA_VISIBLE_DEVICES to bind this worker process to a single GPU.
+#   - Loads the full dataset, then selects only the rows belonging to this shard via
+#       index % total_shards == shard_id
+#     which guarantees a complete partition with no missing rows across workers.
+#   - Translates only the required columns (cols_to_translate).
+#   - Uses a two-path translation:
+#       * fast path for short texts (batched GPU translation)
+#       * slow fallback for very long texts (split + smaller sub-batches)
+#   - Saves this shard's translated dataset to tmp_dir/shard_XX for later merging.
+# -----------------------------------------------------------------------------
 def worker_translate(
     shard_id: int,
     gpu_id: int,
@@ -161,8 +263,10 @@ def worker_translate(
     cols_to_translate: List[str],
     total_len: int,
 ):
+    # Bind this process to a specific GPU.
     os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
 
+    # Global inference optimizations (no gradients needed).
     torch.set_grad_enabled(False)
     torch.backends.cudnn.benchmark = True
@@ -172,16 +276,18 @@
     print(f"\n[WORKER {shard_id}] GPU={gpu_id} device={device} dtype={dtype}")
     print(f"[WORKER {shard_id}] Translating columns: {cols_to_translate}")
 
+    # Load the dataset split (each worker loads it independently).
     ds = load_dataset(DATASET_NAME, split=SPLIT)
     n = len(ds)
     if n != total_len:
         print(f"[WORKER {shard_id}] WARN: dataset length mismatch: got {n}, expected {total_len}")
 
-    # strict sharding by index => no missing rows
+    # Strict sharding by original index to avoid missing rows and keep deterministic coverage.
     indices = [i for i in range(n) if (i % total_shards) == shard_id]
     ds_shard = ds.select(indices)
     shard_len = len(ds_shard)
 
+    # Load the tokenizer + model from the local path (NLLB).
     tokenizer = AutoTokenizer.from_pretrained(NLLB_PATH, use_fast=True)
     model = AutoModelForSeq2SeqLM.from_pretrained(
         NLLB_PATH,
@@ -190,8 +296,20 @@
     ).to(device)
     model.eval()
 
+    # Number of CPU-side map batches (used only for progress-bar granularity).
     total_map_batches = math.ceil(shard_len / MAP_BATCH_SIZE)
 
+    # map_fn(batch)
+    # Purpose:
+    #   Translates the requested columns for a single CPU-side batch.
+    #
+    # Workflow per column:
+    #   1) Normalize all values to strings.
+    #   2) Split them into short_texts and long_texts by LONG_THRESHOLD_CHARS.
+    #   3) Translate the short texts with fast batched translation
+    #      (internal sub-batches of MODEL_BATCH_SIZE).
+    #   4) Translate the long texts individually through translate_long_text().
+    #   5) Return a dict {col_name: translated_list} for each requested column.
     def map_fn(batch: Dict[str, List[Any]]) -> Dict[str, List[str]]:
         out: Dict[str, List[str]] = {}
@@ -203,6 +321,7 @@
             long_texts: List[str] = []
             long_pos: List[int] = []
 
+            # Partition examples by length so most data goes through the fast path.
             for j, t in enumerate(vals):
                 if (not t) or (len(t) <= LONG_THRESHOLD_CHARS):
                     short_pos.append(j)
@@ -213,7 +332,7 @@
 
             translated = [""] * len(vals)
 
-            # fast batched translate
+            # Fast batched translation for short texts (GPU-optimized).
             if short_texts:
                 tr_short: List[str] = []
                 for k in range(0, len(short_texts), MODEL_BATCH_SIZE):
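
The modulo sharding described in the worker docstring is easy to sanity-check: every index lands in exactly one shard, so the shards form a complete, non-overlapping partition. A toy check with a hypothetical n = 7 and two shards:

```python
n, total_shards = 7, 2
shards = [[i for i in range(n) if i % total_shards == s] for s in range(total_shards)]
# shards == [[0, 2, 4, 6], [1, 3, 5]]
assert sorted(shards[0] + shards[1]) == list(range(n))  # nothing missing, nothing duplicated
```
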
@@ -221,7 +340,7 @@
             for pos, tr in zip(short_pos, tr_short):
                 translated[pos] = tr
 
-            # slow fallback for very long texts
+            # Slow fallback for very long texts (split + smaller sub-batches).
             for pos, t in zip(long_pos, long_texts):
                 translated[pos] = translate_long_text(tokenizer, model, t, sub_batch=max(1, MODEL_BATCH_SIZE // 2))
 
@@ -229,7 +348,8 @@
         return out
 
-    # manual loop to show stable progress bar
+    # Manual loop (instead of a single ds_shard.map) to show a stable progress bar.
+    # Each iteration translates one slice [start:end] of the shard.
     parts: List[Dataset] = []
     with tqdm(total=total_map_batches, desc=f"GPU{gpu_id} shard{shard_id}", ncols=100) as pbar:
         for start in range(0, shard_len, MAP_BATCH_SIZE):
@@ -243,8 +363,10 @@
             parts.append(part)
             pbar.update(1)
 
+    # Concatenate the translated parts into a single shard dataset.
     ds_tr = concatenate_datasets(parts)
 
+    # Save this shard to disk for the later merge.
     shard_path = os.path.join(tmp_dir, f"shard_{shard_id:02d}")
     _safe_mkdir(shard_path)
     ds_tr.save_to_disk(shard_path)
@@ -254,21 +376,44 @@
 
 # =========================
 # Main (one-click)
 # =========================
+
+# -----------------------------------------------------------------------------
+# main()
+# Purpose:
+#   One-click launcher for 2-GPU translation plus shard merge.
+#
+# Steps:
+#   1) Parse CLI args:
+#      --resume: skip workers for shards that already exist and only merge.
+#   2) Validate:
+#      - the NLLB model path exists
+#      - at least 2 GPUs are available (the script expects exactly two workers)
+#   3) Load dataset metadata (length + schema) and determine which columns to translate.
+#   4) Spawn 2 worker processes (GPU 0 and GPU 1) using the spawn start method.
+#   5) After the workers finish, merge the shard datasets and restore the original row order.
+#   6) Save the final merged dataset to OUT_FINAL_DIR and verify the row count.
+# -----------------------------------------------------------------------------
 def main():
     ap = argparse.ArgumentParser()
     ap.add_argument("--resume", action="store_true", help="Skip existing shards and only merge.")
     args = ap.parse_args()
 
+    # Ensure the local NLLB path is valid before starting.
     if not os.path.isdir(NLLB_PATH):
         raise SystemExit(f"[ERROR] NLLB model path not found: {NLLB_PATH}")
 
+    # This script is designed for a 2-GPU setup (two workers).
     if torch.cuda.device_count() < 2:
         raise SystemExit(f"[ERROR] Need 2 GPUs; found: {torch.cuda.device_count()}")
 
+    # Prepare the output directories.
     _safe_mkdir(OUT_FINAL_DIR)
     tmp_dir = os.path.join(OUT_FINAL_DIR, "_tmp_shards")
     _safe_mkdir(tmp_dir)
 
+    # Load the dataset once in the main process to:
+    #  - get the total length n
+    #  - auto-detect which columns are needed for SFT/DPO
     print("[INFO] Loading dataset metadata...")
     ds = load_dataset(DATASET_NAME, split=SPLIT)
     n = len(ds)
@@ -281,12 +426,15 @@
     print(f"[INFO] Params: MODEL_BATCH_SIZE={MODEL_BATCH_SIZE}, MAP_BATCH_SIZE={MAP_BATCH_SIZE}, "
           f"LONG_THRESHOLD_CHARS={LONG_THRESHOLD_CHARS}, NUM_BEAMS={NUM_BEAMS}, MAX_NEW_TOKENS={MAX_NEW_TOKENS}")
 
+    # Two workers => two shards, pinned to GPU 0 and GPU 1.
     gpus = [0, 1]
     total_shards = 2
 
+    # Use the "spawn" start method; CUDA is not fork-safe.
     mp.set_start_method("spawn", force=True)
     procs = []
 
+    # Start the workers, unless resume mode skips an existing shard.
     for shard_id, gpu_id in enumerate(gpus):
         shard_path = os.path.join(tmp_dir, f"shard_{shard_id:02d}")
         if args.resume and os.path.isdir(shard_path):
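
For reference, the spawn-based launch pattern that the loop below relies on looks roughly like this; demo_worker is a stand-in for worker_translate, and this is a sketch of the pattern rather than the script's exact wiring:

```python
import multiprocessing as mp

def demo_worker(shard_id: int, gpu_id: int) -> None:
    # Stand-in for worker_translate(); the real worker pins its GPU by
    # setting CUDA_VISIBLE_DEVICES before touching CUDA.
    print(f"shard {shard_id} -> GPU {gpu_id}")

if __name__ == "__main__":
    mp.set_start_method("spawn", force=True)  # fork is unsafe once CUDA is initialized
    procs = [mp.Process(target=demo_worker, args=(s, g)) for s, g in enumerate([0, 1])]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
```
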
@@ -300,12 +448,16 @@
         p.start()
         procs.append(p)
 
+    # Wait for the workers and fail fast if any exits with a non-zero code.
     for p in procs:
         p.join()
         if p.exitcode != 0:
             raise SystemExit("[ERROR] A worker failed. See logs above.")
 
-    # merge
+    # Merge step:
+    #  - load each shard dataset from disk
+    #  - concatenate them
+    #  - restore the original ordering (sharding interleaves the indices)
     print("\n[INFO] Merging shards...")
     shard_dirs = [os.path.join(tmp_dir, f"shard_{i:02d}") for i in range(total_shards)]
     for sd in shard_dirs:
@@ -315,17 +467,21 @@
     shards = [Dataset.load_from_disk(sd) for sd in shard_dirs]
     merged = concatenate_datasets(shards)
 
-    # restore original order
+    # Restore the original order:
+    #   each shard holds the indices satisfying (i % total_shards == shard_id).
     merged_indices = []
     for shard_id in range(total_shards):
         merged_indices.extend([i for i in range(n) if (i % total_shards) == shard_id])
 
+    # Add an ordering column, sort by it, then remove it.
     merged = merged.add_column("orig_index", merged_indices)
     merged = merged.sort("orig_index").remove_columns(["orig_index"])
 
+    # Safety check: ensure no rows went missing during the merge.
     if len(merged) != n:
         raise SystemExit(f"[ERROR] Length mismatch after merge: merged={len(merged)} expected={n}")
 
+    # Save the final translated dataset.
     print("[INFO] Saving final dataset...")
     merged.save_to_disk(OUT_FINAL_DIR)
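
The add-column/sort/remove-columns trick in the merge step restores the interleaved order exactly. A self-contained toy version using the real datasets API, with a hypothetical five-row dataset split across two shards:

```python
from datasets import Dataset, concatenate_datasets

shard0 = Dataset.from_dict({"text": ["r0", "r2", "r4"]})  # original indices 0, 2, 4
shard1 = Dataset.from_dict({"text": ["r1", "r3"]})        # original indices 1, 3
merged = concatenate_datasets([shard0, shard1])
merged = merged.add_column("orig_index", [0, 2, 4, 1, 3])
merged = merged.sort("orig_index").remove_columns(["orig_index"])
assert merged["text"] == ["r0", "r1", "r2", "r3", "r4"]
```
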