#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
import re
import math
import argparse
import multiprocessing as mp
from typing import List, Dict, Any
import torch
from tqdm import tqdm
from datasets import load_dataset, Dataset, concatenate_datasets
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# =========================
# Config
# =========================
# Dataset source on Hugging Face Hub and which split to translate.
DATASET_NAME = "PKU-Alignment/PKU-SafeRLHF-30K"
SPLIT = "train"
# Local path to the NLLB model (loaded from disk, not from the Hub).
NLLB_PATH = "/home/hyrenko/Diploma/models/nllb-200-1.3B"
# Output directory where the fully translated dataset will be saved.
OUT_FINAL_DIR = "/home/hyrenko/Diploma/datasets/PKU-SafeRLHF-30K_slk_Latn_SFT_DPO_ONLY"
# Language codes for NLLB (source and target).
SRC_LANG = "eng_Latn"
TGT_LANG = "slk_Latn"
# Performance-oriented defaults for a 2x Titan RTX (24 GB) setup.
# Notes:
# - MODEL_BATCH_SIZE controls GPU-side translation batching.
# - MAP_BATCH_SIZE controls CPU/RAM side batching when mapping over a dataset shard.
# - LONG_THRESHOLD_CHARS splits very long texts to avoid slow/unstable generation.
# - NUM_BEAMS must be 1 for speed (beam search is expensive).
# - MAX_NEW_TOKENS caps translation length; reducing it can speed up generation.
MODEL_BATCH_SIZE = 32 # if OOM: 24 -> 16
MAP_BATCH_SIZE = 128 # CPU/RAM side
LONG_THRESHOLD_CHARS = 3500 # higher => fewer slow fallbacks
NUM_BEAMS = 1 # MUST be 1 for speed
MAX_NEW_TOKENS = 128 # can lower to 96 for speed
# =========================
# Utilities
# =========================
# -----------------------------------------------------------------------------
# normalize_text(x)
# Purpose:
# Converts arbitrary dataset values into a clean string:
# - Handles None values safely.
# - Removes the Unicode replacement character (U+FFFD), which can appear in noisy text.
# - Strips whitespace.
#
# Why it exists:
# Dataset fields may contain None / non-string types / corrupted characters.
# Translation should always receive a valid string input.
# -----------------------------------------------------------------------------
def normalize_text(x: Any) -> str:
    if x is None:
        return ""
    return str(x).replace("\uFFFD", "").strip()
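# Example behavior (illustrative, added for documentation):
#   normalize_text(None)         -> ""
#   normalize_text(42)           -> "42"
#   normalize_text(" a\uFFFDb ") -> "ab"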
# -----------------------------------------------------------------------------
# split_text_safely(text, max_chars)
# Purpose:
# Splits long text into smaller chunks to avoid slow generation / memory issues.
#
# Strategy:
# - Normalizes input text.
# - Splits by paragraph boundaries (2+ newlines).
# - Ensures each chunk is <= max_chars.
# - If a paragraph itself is too long, it is cut into fixed-size slices.
#
# Output:
# A list of non-empty chunks; returns [""] for empty input as a safe placeholder.
# -----------------------------------------------------------------------------
def split_text_safely(text: str, max_chars: int) -> List[str]:
    text = normalize_text(text)
    if not text:
        return [""]
    chunks: List[str] = []

    # push(piece)
    # Internal helper that:
    #   - ignores empty pieces
    #   - keeps pieces <= max_chars
    #   - slices very long pieces into fixed-size parts
    def push(piece: str):
        piece = piece.strip()
        if not piece:
            return
        if len(piece) <= max_chars:
            chunks.append(piece)
        else:
            for i in range(0, len(piece), max_chars):
                part = piece[i:i + max_chars].strip()
                if part:
                    chunks.append(part)

    paras = re.split(r"\n{2,}", text)
    for p in paras:
        p = p.strip()
        if p:
            push(p)
    return chunks if chunks else [""]
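# Example (illustrative): with max_chars=10,
#   split_text_safely("para one\n\npara two", 10) -> ["para one", "para two"]
# and a single 25-char paragraph is sliced into chunks of 10 + 10 + 5 chars.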
# -----------------------------------------------------------------------------
# detect_needed_cols_for_sft_dpo(ds)
# Purpose:
# Auto-detect which columns should be translated so the output dataset is
# immediately usable for SFT and/or DPO preparation.
#
# Priority:
# 1) Classic preference schema: prompt + chosen + rejected
# 2) Pairwise schema: prompt + response_0 + response_1
# 3) Fallback: prompt + any response-like columns matching a regex pattern
# 4) Last resort: prompt only
#
# Result:
# A list of columns to translate, minimizing work and storage.
# -----------------------------------------------------------------------------
def detect_needed_cols_for_sft_dpo(ds: Dataset) -> List[str]:
    """
    Auto-detect columns needed for SFT -> DPO training.
    Priority:
      1) prompt + chosen + rejected (classic preference schema)
      2) prompt + response_0 + response_1 (pairwise schema)
      3) fallback: prompt + any reasonable response fields
    """
    cols = set(ds.column_names)
    # classic preference format
    if {"prompt", "chosen", "rejected"}.issubset(cols):
        return ["prompt", "chosen", "rejected"]
    # common pairwise format
    if {"prompt", "response_0", "response_1"}.issubset(cols):
        return ["prompt", "response_0", "response_1"]
    # fallback: translate prompt + any text response columns that exist
    candidates = []
    for c in ds.column_names:
        if c == "prompt":
            continue
        if re.match(r"^(chosen|rejected|response_\d+|answer_\d+|completion_\d+)$", c):
            candidates.append(c)
    if "prompt" in cols and candidates:
        return ["prompt"] + sorted(candidates)
    # last resort: translate only prompt (still useful for later)
    if "prompt" in cols:
        return ["prompt"]
    raise RuntimeError("Could not detect a 'prompt' column; cannot proceed.")
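# Example (illustrative): for a pairwise schema whose columns include
#   {"prompt", "response_0", "response_1", "better_response_id"},
# rule 2 fires and the function returns ["prompt", "response_0", "response_1"];
# metadata columns such as "better_response_id" pass through untranslated,
# because datasets.map() keeps any column the map function does not return.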
# -----------------------------------------------------------------------------
# nllb_translate_batch(tokenizer, model, texts)
# Purpose:
# Performs batched translation for a list of short texts via NLLB.
#
# Details:
# - Sets tokenizer.src_lang each time to ensure correct source language.
# - Uses forced_bos_token_id to force decoding in the target language.
# - Tokenizes to max_length=1024 (encoder-side), pads and truncates as needed.
# - Uses model.generate() with max_new_tokens, num_beams, caching for speed.
# - Returns stripped decoded strings.
#
# Important:
# - @torch.inference_mode disables gradients for speed and lower memory.
# -----------------------------------------------------------------------------
@torch.inference_mode()
def nllb_translate_batch(tokenizer, model, texts: List[str]) -> List[str]:
    tokenizer.src_lang = SRC_LANG
    forced_bos = tokenizer.convert_tokens_to_ids(TGT_LANG)
    inputs = tokenizer(
        texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=1024,
    )
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    gen = model.generate(
        **inputs,
        forced_bos_token_id=forced_bos,
        max_new_tokens=MAX_NEW_TOKENS,
        num_beams=NUM_BEAMS,
        do_sample=False,
        use_cache=True,
    )
    out = tokenizer.batch_decode(gen, skip_special_tokens=True)
    return [o.strip() for o in out]
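# Example call (illustrative): given a loaded NLLB tokenizer/model pair,
#   nllb_translate_batch(tok, mdl, ["Hello world", "How are you?"])
# returns a list of two Slovak strings, in the same order as the inputs.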
# -----------------------------------------------------------------------------
# translate_long_text(tokenizer, model, text, sub_batch)
# Purpose:
# Safe fallback path for very long texts that may be slow or unstable if
# translated in one shot.
#
# Strategy:
# - Split the text into <= LONG_THRESHOLD_CHARS chunks (paragraph-aware).
# - Translate chunks in sub-batches to control VRAM use.
# - Join translated chunks with blank lines to preserve paragraph separation.
# -----------------------------------------------------------------------------
def translate_long_text(tokenizer, model, text: str, sub_batch: int) -> str:
    text = normalize_text(text)
    if not text:
        return ""
    chunks = split_text_safely(text, max_chars=LONG_THRESHOLD_CHARS)
    out_chunks: List[str] = []
    for i in range(0, len(chunks), sub_batch):
        out_chunks.extend(nllb_translate_batch(tokenizer, model, chunks[i:i + sub_batch]))
    return "\n\n".join([c for c in out_chunks if c])
# -----------------------------------------------------------------------------
# _safe_mkdir(p)
# Purpose:
# Convenience helper to create directories idempotently.
# -----------------------------------------------------------------------------
def _safe_mkdir(p: str):
    os.makedirs(p, exist_ok=True)
# =========================
# Worker (one GPU)
# =========================
# -----------------------------------------------------------------------------
# worker_translate(...)
# Purpose:
# Runs translation on a shard of the dataset using a specific GPU.
#
# Key ideas:
# - Uses CUDA_VISIBLE_DEVICES to bind this worker process to a single GPU.
# - Loads the full dataset, then selects only rows belonging to this shard via:
# index % total_shards == shard_id
# This guarantees a complete partition with no missing rows across workers.
# - Translates only the required columns (cols_to_translate).
# - Uses a two-path translation:
# * fast path for short texts (batched GPU translation)
# * slow fallback for very long texts (split + smaller sub-batches)
# - Saves this shard's translated dataset to tmp_dir/shard_XX for later merging.
# -----------------------------------------------------------------------------
def worker_translate(
    shard_id: int,
    gpu_id: int,
    total_shards: int,
    tmp_dir: str,
    cols_to_translate: List[str],
    total_len: int,
):
    # Bind this process to a specific GPU.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
    # Global inference optimizations (no gradients).
    torch.set_grad_enabled(False)
    torch.backends.cudnn.benchmark = True
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if device == "cuda" else torch.float32
    print(f"\n[WORKER {shard_id}] GPU={gpu_id} device={device} dtype={dtype}")
    print(f"[WORKER {shard_id}] Translating columns: {cols_to_translate}")
    # Load the dataset split (each worker loads it independently).
    ds = load_dataset(DATASET_NAME, split=SPLIT)
    n = len(ds)
    if n != total_len:
        print(f"[WORKER {shard_id}] WARN: dataset length mismatch: got {n}, expected {total_len}")
    # Strict sharding by original index to avoid missing rows and keep deterministic coverage.
    indices = [i for i in range(n) if (i % total_shards) == shard_id]
    ds_shard = ds.select(indices)
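    # Example (illustrative): with total_shards=2, shard 0 receives original
    # rows 0, 2, 4, ... and shard 1 receives rows 1, 3, 5, ..., so the two
    # shards together cover every row exactly once.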
    shard_len = len(ds_shard)
    # Load tokenizer + model from local path (NLLB).
    tokenizer = AutoTokenizer.from_pretrained(NLLB_PATH, use_fast=True)
    model = AutoModelForSeq2SeqLM.from_pretrained(
        NLLB_PATH,
        torch_dtype=dtype,
        low_cpu_mem_usage=True,
    ).to(device)
    model.eval()
    # Number of CPU-side map batches (used only for progress bar granularity).
    total_map_batches = math.ceil(shard_len / MAP_BATCH_SIZE)

    # map_fn(batch)
    # Purpose:
    #   Translates the requested columns for a single CPU-side batch.
    #
    # Workflow per column:
    #   1) Normalize all values to strings.
    #   2) Split into short_texts and long_texts by LONG_THRESHOLD_CHARS.
    #   3) Translate short texts with fast batched translation
    #      (internal sub-batches of MODEL_BATCH_SIZE).
    #   4) Translate long texts individually through translate_long_text().
    #   5) Return a dict {col_name: translated_list} for each requested column.
    def map_fn(batch: Dict[str, List[Any]]) -> Dict[str, List[str]]:
        out: Dict[str, List[str]] = {}
        for col in cols_to_translate:
            vals = [normalize_text(v) for v in batch[col]]
            short_texts: List[str] = []
            short_pos: List[int] = []
            long_texts: List[str] = []
            long_pos: List[int] = []
            # Partition examples by length so most data goes through the fast path.
            for j, t in enumerate(vals):
                if (not t) or (len(t) <= LONG_THRESHOLD_CHARS):
                    short_pos.append(j)
                    short_texts.append(t)
                else:
                    long_pos.append(j)
                    long_texts.append(t)
            translated = [""] * len(vals)
            # Fast batched translate for short texts (GPU-optimized).
            if short_texts:
                tr_short: List[str] = []
                for k in range(0, len(short_texts), MODEL_BATCH_SIZE):
                    tr_short.extend(nllb_translate_batch(tokenizer, model, short_texts[k:k + MODEL_BATCH_SIZE]))
                for pos, tr in zip(short_pos, tr_short):
                    translated[pos] = tr
            # Slow fallback for very long texts (split + smaller batches).
            for pos, t in zip(long_pos, long_texts):
                translated[pos] = translate_long_text(tokenizer, model, t, sub_batch=max(1, MODEL_BATCH_SIZE // 2))
            out[col] = translated
        return out
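    # Example of map_fn's scatter/gather (illustrative): for one column with
    # vals = ["hi", <a 4,000-char text>, "ok"], the partition gives
    # short_pos=[0, 2] and long_pos=[1]; translations are written back into
    # `translated` by position, so output order always matches input order.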
    # Manual loop (instead of a single ds_shard.map) to show a stable progress bar.
    # Each iteration translates a slice [start:end] of the shard.
    parts: List[Dataset] = []
    with tqdm(total=total_map_batches, desc=f"GPU{gpu_id} shard{shard_id}", ncols=100) as pbar:
        for start in range(0, shard_len, MAP_BATCH_SIZE):
            end = min(shard_len, start + MAP_BATCH_SIZE)
            part = ds_shard.select(range(start, end)).map(
                map_fn,
                batched=True,
                batch_size=end - start,
                desc=None,
            )
            parts.append(part)
            pbar.update(1)
    # Concatenate translated parts into a single shard dataset.
    ds_tr = concatenate_datasets(parts)
    # Save this shard to disk for the later merge.
    shard_path = os.path.join(tmp_dir, f"shard_{shard_id:02d}")
    _safe_mkdir(shard_path)
    ds_tr.save_to_disk(shard_path)
    print(f"[WORKER {shard_id}] Saved -> {shard_path}")
# =========================
# Main (one-click)
# =========================
# -----------------------------------------------------------------------------
# main()
# Purpose:
# One-click launcher for 2-GPU translation + shard merge.
#
# Steps:
# 1) Parse CLI args:
# --resume: skip workers for shards that already exist and only merge.
# 2) Validate:
# - NLLB model path exists
# - at least 2 GPUs are available (script expects exactly 2 workers here)
# 3) Load dataset metadata (length + schema) and determine which columns to translate.
# 4) Spawn 2 worker processes (GPU 0 and GPU 1) using spawn start method.
# 5) After workers finish, merge shard datasets and restore original row order.
# 6) Save the final merged dataset to OUT_FINAL_DIR and verify row count.
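#
# Example invocation (illustrative; the config paths above must exist):
#   python translate_PKF.py            # fresh 2-GPU run
#   python translate_PKF.py --resume   # skip finished shards, then merge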
# -----------------------------------------------------------------------------
def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--resume", action="store_true", help="Skip existing shards and only merge.")
    args = ap.parse_args()
    # Ensure the local NLLB path is valid before starting.
    if not os.path.isdir(NLLB_PATH):
        raise SystemExit(f"[ERROR] NLLB model path not found: {NLLB_PATH}")
    # This script is designed for a 2-GPU setup (two workers).
    if torch.cuda.device_count() < 2:
        raise SystemExit(f"[ERROR] Need 2 GPUs; found: {torch.cuda.device_count()}")
    # Prepare output directories.
    _safe_mkdir(OUT_FINAL_DIR)
    tmp_dir = os.path.join(OUT_FINAL_DIR, "_tmp_shards")
    _safe_mkdir(tmp_dir)
    # Load the dataset once in the main process to:
    #   - get the total length n
    #   - auto-detect which columns are necessary for SFT/DPO
    print("[INFO] Loading dataset metadata...")
    ds = load_dataset(DATASET_NAME, split=SPLIT)
    n = len(ds)
    cols_to_translate = detect_needed_cols_for_sft_dpo(ds)
    print(f"[INFO] Dataset: {DATASET_NAME} split={SPLIT}")
    print(f"[INFO] Total rows: {n}")
    print(f"[INFO] Translating ONLY needed cols for SFT->DPO: {cols_to_translate}")
    print(f"[INFO] Output: {OUT_FINAL_DIR}")
    print(f"[INFO] Params: MODEL_BATCH_SIZE={MODEL_BATCH_SIZE}, MAP_BATCH_SIZE={MAP_BATCH_SIZE}, "
          f"LONG_THRESHOLD_CHARS={LONG_THRESHOLD_CHARS}, NUM_BEAMS={NUM_BEAMS}, MAX_NEW_TOKENS={MAX_NEW_TOKENS}")
    # Two workers => two shards, pinned to GPU 0 and GPU 1.
    gpus = [0, 1]
    total_shards = 2
    # Use "spawn" for CUDA safety in multiprocessing.
    mp.set_start_method("spawn", force=True)
    procs = []
    # Start workers unless resume mode skips an existing shard.
    for shard_id, gpu_id in enumerate(gpus):
        shard_path = os.path.join(tmp_dir, f"shard_{shard_id:02d}")
        if args.resume and os.path.isdir(shard_path):
            print(f"[INFO] Resume: shard exists, skipping worker {shard_id} -> {shard_path}")
            continue
        p = mp.Process(
            target=worker_translate,
            args=(shard_id, gpu_id, total_shards, tmp_dir, cols_to_translate, n),
        )
        p.start()
        procs.append(p)
    # Wait for workers and fail fast if any worker exits with a non-zero code.
    for p in procs:
        p.join()
        if p.exitcode != 0:
            raise SystemExit("[ERROR] A worker failed. See logs above.")
    # Merge step:
    #   - load each shard dataset from disk
    #   - concatenate them
    #   - restore original ordering (because sharding interleaves indices)
    print("\n[INFO] Merging shards...")
    shard_dirs = [os.path.join(tmp_dir, f"shard_{i:02d}") for i in range(total_shards)]
    for sd in shard_dirs:
        if not os.path.isdir(sd):
            raise SystemExit(f"[ERROR] Missing shard directory: {sd}")
    shards = [Dataset.load_from_disk(sd) for sd in shard_dirs]
    merged = concatenate_datasets(shards)
    # Restore original order:
    # each shard contains the indices satisfying (i % total_shards == shard_id).
    merged_indices = []
    for shard_id in range(total_shards):
        merged_indices.extend([i for i in range(n) if (i % total_shards) == shard_id])
    # Add an ordering column, sort by it, then remove it.
    merged = merged.add_column("orig_index", merged_indices)
    merged = merged.sort("orig_index").remove_columns(["orig_index"])
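    # Example (illustrative): with n=5 and total_shards=2, the concatenated
    # shards hold rows in order [0, 2, 4, 1, 3]; sorting by "orig_index"
    # restores [0, 1, 2, 3, 4].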
    # Safety check: ensure no missing rows after merge.
    if len(merged) != n:
        raise SystemExit(f"[ERROR] Length mismatch after merge: merged={len(merged)} expected={n}")
    # Save the final translated dataset.
    print("[INFO] Saving final dataset...")
    merged.save_to_disk(OUT_FINAL_DIR)
    print(f"[OK] Done. Saved FULL translated dataset to: {OUT_FINAL_DIR}")
    print("[OK] Verified: no missing rows.")
    print(f"[OK] Translated columns: {cols_to_translate}")


if __name__ == "__main__":
    main()