Added comments
This commit is contained in:
parent
de2336183b
commit
4ef5f31659
@ -15,20 +15,27 @@ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
|||||||
|
|
||||||
|
|
||||||
# =========================
|
# =========================
|
||||||
# Fixed settings (your setup)
|
# Dataset source on Hugging Face Hub and which split to translate.
|
||||||
# =========================
|
|
||||||
DATASET_NAME = "PKU-Alignment/PKU-SafeRLHF-30K"
|
DATASET_NAME = "PKU-Alignment/PKU-SafeRLHF-30K"
|
||||||
SPLIT = "train"
|
SPLIT = "train"
|
||||||
|
|
||||||
|
# Local path to the NLLB model (loaded from disk, not from the Hub).
|
||||||
NLLB_PATH = "/home/hyrenko/Diploma/models/nllb-200-1.3B"
|
NLLB_PATH = "/home/hyrenko/Diploma/models/nllb-200-1.3B"
|
||||||
|
|
||||||
# Output (translated dataset)
|
# Output directory where the fully translated dataset will be saved.
|
||||||
OUT_FINAL_DIR = "/home/hyrenko/Diploma/datasets/PKU-SafeRLHF-30K_slk_Latn_SFT_DPO_ONLY"
|
OUT_FINAL_DIR = "/home/hyrenko/Diploma/datasets/PKU-SafeRLHF-30K_slk_Latn_SFT_DPO_ONLY"
|
||||||
|
|
||||||
|
# Language codes for NLLB (source and target).
|
||||||
SRC_LANG = "eng_Latn"
|
SRC_LANG = "eng_Latn"
|
||||||
TGT_LANG = "slk_Latn"
|
TGT_LANG = "slk_Latn"
|
||||||
|
|
||||||
# Speed defaults for 2x RTX Titan 24GB
|
# Performance-oriented defaults for a 2x RTX Titan 24GB setup.
|
||||||
|
# Notes:
|
||||||
|
# - MODEL_BATCH_SIZE controls GPU-side translation batching.
|
||||||
|
# - MAP_BATCH_SIZE controls CPU/RAM side batching when mapping over a dataset shard.
|
||||||
|
# - LONG_THRESHOLD_CHARS splits very long texts to avoid slow/unstable generation.
|
||||||
|
# - NUM_BEAMS should stay at 1 for speed (beam search is expensive).
|
||||||
|
# - MAX_NEW_TOKENS caps translation length; reducing it can speed up generation.
|
||||||
MODEL_BATCH_SIZE = 32 # if OOM: 24 -> 16
|
MODEL_BATCH_SIZE = 32 # if OOM: 24 -> 16
|
||||||
MAP_BATCH_SIZE = 128 # CPU/RAM side
|
MAP_BATCH_SIZE = 128 # CPU/RAM side
|
||||||
LONG_THRESHOLD_CHARS = 3500 # higher => fewer slow fallbacks
|
LONG_THRESHOLD_CHARS = 3500 # higher => fewer slow fallbacks
|
||||||
@ -39,12 +46,39 @@ MAX_NEW_TOKENS = 128 # can lower to 96 for speed
|
|||||||
# =========================
|
# =========================
|
||||||
# Utilities
|
# Utilities
|
||||||
# =========================
|
# =========================
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# normalize_text(x)
|
||||||
|
# Purpose:
|
||||||
|
# Converts arbitrary dataset values into a clean string:
|
||||||
|
# - Handles None values safely.
|
||||||
|
# - Removes the Unicode replacement character (U+FFFD), which can appear in noisy text.
|
||||||
|
# - Strips whitespace.
|
||||||
|
#
|
||||||
|
# Why it exists:
|
||||||
|
# Dataset fields may contain None / non-string types / corrupted characters.
|
||||||
|
# Translation should always receive a valid string input.
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
def normalize_text(x: Any) -> str:
|
def normalize_text(x: Any) -> str:
|
||||||
if x is None:
|
if x is None:
|
||||||
return ""
|
return ""
|
||||||
return str(x).replace("\uFFFD", "").strip()
|
return str(x).replace("\uFFFD", "").strip()
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# split_text_safely(text, max_chars)
|
||||||
|
# Purpose:
|
||||||
|
# Splits long text into smaller chunks to avoid slow generation / memory issues.
|
||||||
|
#
|
||||||
|
# Strategy:
|
||||||
|
# - Normalizes input text.
|
||||||
|
# - Splits by paragraph boundaries (2+ newlines).
|
||||||
|
# - Ensures each chunk is <= max_chars.
|
||||||
|
# - If a paragraph itself is too long, it is cut into fixed-size slices.
|
||||||
|
#
|
||||||
|
# Output:
|
||||||
|
# A list of non-empty chunks; returns [""] for empty input as a safe placeholder.
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
def split_text_safely(text: str, max_chars: int) -> List[str]:
|
def split_text_safely(text: str, max_chars: int) -> List[str]:
|
||||||
text = normalize_text(text)
|
text = normalize_text(text)
|
||||||
if not text:
|
if not text:
|
||||||
@ -52,6 +86,11 @@ def split_text_safely(text: str, max_chars: int) -> List[str]:
|
|||||||
|
|
||||||
chunks: List[str] = []
|
chunks: List[str] = []
|
||||||
|
|
||||||
|
# push(piece)
|
||||||
|
# Internal helper that:
|
||||||
|
# - ignores empty pieces
|
||||||
|
# - keeps pieces <= max_chars
|
||||||
|
# - slices very long pieces into fixed-size parts
|
||||||
def push(piece: str):
|
def push(piece: str):
|
||||||
piece = piece.strip()
|
piece = piece.strip()
|
||||||
if not piece:
|
if not piece:
|
||||||
@ -73,6 +112,21 @@ def split_text_safely(text: str, max_chars: int) -> List[str]:
|
|||||||
return chunks if chunks else [""]
|
return chunks if chunks else [""]
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# detect_needed_cols_for_sft_dpo(ds)
|
||||||
|
# Purpose:
|
||||||
|
# Auto-detect which columns should be translated so the output dataset is
|
||||||
|
# immediately usable for SFT and/or DPO preparation.
|
||||||
|
#
|
||||||
|
# Priority:
|
||||||
|
# 1) Classic preference schema: prompt + chosen + rejected
|
||||||
|
# 2) Pairwise schema: prompt + response_0 + response_1
|
||||||
|
# 3) Fallback: prompt + any response-like columns matching a regex pattern
|
||||||
|
# 4) Last resort: prompt only
|
||||||
|
#
|
||||||
|
# Result:
|
||||||
|
# A list of columns to translate, minimizing work and storage.
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
def detect_needed_cols_for_sft_dpo(ds: Dataset) -> List[str]:
|
def detect_needed_cols_for_sft_dpo(ds: Dataset) -> List[str]:
|
||||||
"""
|
"""
|
||||||
Auto-detect columns needed for SFT->DPO training.
|
Auto-detect columns needed for SFT->DPO training.
|
||||||
@ -109,6 +163,21 @@ def detect_needed_cols_for_sft_dpo(ds: Dataset) -> List[str]:
|
|||||||
raise RuntimeError("Could not detect a 'prompt' column; cannot proceed.")
|
raise RuntimeError("Could not detect a 'prompt' column; cannot proceed.")
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# nllb_translate_batch(tokenizer, model, texts)
|
||||||
|
# Purpose:
|
||||||
|
# Performs batched translation for a list of short texts via NLLB.
|
||||||
|
#
|
||||||
|
# Details:
|
||||||
|
# - Sets tokenizer.src_lang each time to ensure correct source language.
|
||||||
|
# - Uses forced_bos_token_id to force decoding in the target language.
|
||||||
|
# - Tokenizes to max_length=1024 (encoder-side), pads and truncates as needed.
|
||||||
|
# - Uses model.generate() with max_new_tokens, num_beams, caching for speed.
|
||||||
|
# - Returns stripped decoded strings.
|
||||||
|
#
|
||||||
|
# Important:
|
||||||
|
# - @torch.inference_mode disables gradients for speed and lower memory.
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def nllb_translate_batch(tokenizer, model, texts: List[str]) -> List[str]:
|
def nllb_translate_batch(tokenizer, model, texts: List[str]) -> List[str]:
|
||||||
tokenizer.src_lang = SRC_LANG
|
tokenizer.src_lang = SRC_LANG
|
||||||
@ -135,6 +204,17 @@ def nllb_translate_batch(tokenizer, model, texts: List[str]) -> List[str]:
|
|||||||
return [o.strip() for o in out]
|
return [o.strip() for o in out]
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# translate_long_text(tokenizer, model, text, sub_batch)
|
||||||
|
# Purpose:
|
||||||
|
# Safe fallback path for very long texts that may be slow or unstable if
|
||||||
|
# translated in one shot.
|
||||||
|
#
|
||||||
|
# Strategy:
|
||||||
|
# - Split the text into <= LONG_THRESHOLD_CHARS chunks (paragraph-aware).
|
||||||
|
# - Translate chunks in sub-batches to control VRAM use.
|
||||||
|
# - Join translated chunks with blank lines to preserve paragraph separation.
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
def translate_long_text(tokenizer, model, text: str, sub_batch: int) -> str:
|
def translate_long_text(tokenizer, model, text: str, sub_batch: int) -> str:
|
||||||
text = normalize_text(text)
|
text = normalize_text(text)
|
||||||
if not text:
|
if not text:
|
||||||
@ -146,6 +226,11 @@ def translate_long_text(tokenizer, model, text: str, sub_batch: int) -> str:
|
|||||||
return "\n\n".join([c for c in out_chunks if c])
|
return "\n\n".join([c for c in out_chunks if c])
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# _safe_mkdir(p)
|
||||||
|
# Purpose:
|
||||||
|
# Convenience helper to create directories idempotently.
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
def _safe_mkdir(p: str):
|
def _safe_mkdir(p: str):
|
||||||
os.makedirs(p, exist_ok=True)
|
os.makedirs(p, exist_ok=True)
|
||||||
|
|
||||||
@ -153,6 +238,23 @@ def _safe_mkdir(p: str):
|
|||||||
# =========================
|
# =========================
|
||||||
# Worker (one GPU)
|
# Worker (one GPU)
|
||||||
# =========================
|
# =========================
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# worker_translate(...)
|
||||||
|
# Purpose:
|
||||||
|
# Runs translation on a shard of the dataset using a specific GPU.
|
||||||
|
#
|
||||||
|
# Key ideas:
|
||||||
|
# - Uses CUDA_VISIBLE_DEVICES to bind this worker process to a single GPU.
|
||||||
|
# - Loads the full dataset, then selects only rows belonging to this shard via:
|
||||||
|
# index % total_shards == shard_id
|
||||||
|
# This guarantees a complete partition with no missing rows across workers.
|
||||||
|
# - Translates only the required columns (cols_to_translate).
|
||||||
|
# - Uses a two-path translation:
|
||||||
|
# * fast path for short texts (batched GPU translation)
|
||||||
|
# * slow fallback for very long texts (split + smaller sub-batches)
|
||||||
|
# - Saves this shard’s translated dataset to tmp_dir/shard_XX for later merging.
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
def worker_translate(
|
def worker_translate(
|
||||||
shard_id: int,
|
shard_id: int,
|
||||||
gpu_id: int,
|
gpu_id: int,
|
||||||
@ -161,8 +263,10 @@ def worker_translate(
|
|||||||
cols_to_translate: List[str],
|
cols_to_translate: List[str],
|
||||||
total_len: int,
|
total_len: int,
|
||||||
):
|
):
|
||||||
|
# Bind this process to a specific GPU.
|
||||||
os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
|
os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
|
||||||
|
|
||||||
|
# Global inference optimizations (no gradients).
|
||||||
torch.set_grad_enabled(False)
|
torch.set_grad_enabled(False)
|
||||||
torch.backends.cudnn.benchmark = True
|
torch.backends.cudnn.benchmark = True
|
||||||
|
|
||||||
@ -172,16 +276,18 @@ def worker_translate(
|
|||||||
print(f"\n[WORKER {shard_id}] GPU={gpu_id} device={device} dtype={dtype}")
|
print(f"\n[WORKER {shard_id}] GPU={gpu_id} device={device} dtype={dtype}")
|
||||||
print(f"[WORKER {shard_id}] Translating columns: {cols_to_translate}")
|
print(f"[WORKER {shard_id}] Translating columns: {cols_to_translate}")
|
||||||
|
|
||||||
|
# Load the dataset split (each worker loads it independently).
|
||||||
ds = load_dataset(DATASET_NAME, split=SPLIT)
|
ds = load_dataset(DATASET_NAME, split=SPLIT)
|
||||||
n = len(ds)
|
n = len(ds)
|
||||||
if n != total_len:
|
if n != total_len:
|
||||||
print(f"[WORKER {shard_id}] WARN: dataset length mismatch: got {n}, expected {total_len}")
|
print(f"[WORKER {shard_id}] WARN: dataset length mismatch: got {n}, expected {total_len}")
|
||||||
|
|
||||||
# strict sharding by index => no missing rows
|
# Strict sharding by original index to avoid missing rows and keep deterministic coverage.
|
||||||
indices = [i for i in range(n) if (i % total_shards) == shard_id]
|
indices = [i for i in range(n) if (i % total_shards) == shard_id]
|
||||||
ds_shard = ds.select(indices)
|
ds_shard = ds.select(indices)
|
||||||
shard_len = len(ds_shard)
|
shard_len = len(ds_shard)
|
||||||
|
|
||||||
|
# Load tokenizer + model from local path (NLLB).
|
||||||
tokenizer = AutoTokenizer.from_pretrained(NLLB_PATH, use_fast=True)
|
tokenizer = AutoTokenizer.from_pretrained(NLLB_PATH, use_fast=True)
|
||||||
model = AutoModelForSeq2SeqLM.from_pretrained(
|
model = AutoModelForSeq2SeqLM.from_pretrained(
|
||||||
NLLB_PATH,
|
NLLB_PATH,
|
||||||
@ -190,8 +296,20 @@ def worker_translate(
|
|||||||
).to(device)
|
).to(device)
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
|
# Number of CPU-side map batches (used only for progress bar granularity).
|
||||||
total_map_batches = math.ceil(shard_len / MAP_BATCH_SIZE)
|
total_map_batches = math.ceil(shard_len / MAP_BATCH_SIZE)
|
||||||
|
|
||||||
|
# map_fn(batch)
|
||||||
|
# Purpose:
|
||||||
|
# Translates the requested columns for a single CPU-side batch.
|
||||||
|
#
|
||||||
|
# Workflow per column:
|
||||||
|
# 1) Normalize all values to strings.
|
||||||
|
# 2) Split into short_texts and long_texts by LONG_THRESHOLD_CHARS.
|
||||||
|
# 3) Translate short texts with fast batched translation:
|
||||||
|
# - internal sub-batches of MODEL_BATCH_SIZE
|
||||||
|
# 4) Translate long texts individually through translate_long_text().
|
||||||
|
# 5) Return a dict: {col_name: translated_list} for each requested column.
|
||||||
def map_fn(batch: Dict[str, List[Any]]) -> Dict[str, List[str]]:
|
def map_fn(batch: Dict[str, List[Any]]) -> Dict[str, List[str]]:
|
||||||
out: Dict[str, List[str]] = {}
|
out: Dict[str, List[str]] = {}
|
||||||
|
|
||||||
@ -203,6 +321,7 @@ def worker_translate(
|
|||||||
long_texts: List[str] = []
|
long_texts: List[str] = []
|
||||||
long_pos: List[int] = []
|
long_pos: List[int] = []
|
||||||
|
|
||||||
|
# Partition examples by length so most data goes through the fast path.
|
||||||
for j, t in enumerate(vals):
|
for j, t in enumerate(vals):
|
||||||
if (not t) or (len(t) <= LONG_THRESHOLD_CHARS):
|
if (not t) or (len(t) <= LONG_THRESHOLD_CHARS):
|
||||||
short_pos.append(j)
|
short_pos.append(j)
|
||||||
@ -213,7 +332,7 @@ def worker_translate(
|
|||||||
|
|
||||||
translated = [""] * len(vals)
|
translated = [""] * len(vals)
|
||||||
|
|
||||||
# fast batched translate
|
# Fast batched translate for short texts (GPU-optimized).
|
||||||
if short_texts:
|
if short_texts:
|
||||||
tr_short: List[str] = []
|
tr_short: List[str] = []
|
||||||
for k in range(0, len(short_texts), MODEL_BATCH_SIZE):
|
for k in range(0, len(short_texts), MODEL_BATCH_SIZE):
|
||||||
@ -221,7 +340,7 @@ def worker_translate(
|
|||||||
for pos, tr in zip(short_pos, tr_short):
|
for pos, tr in zip(short_pos, tr_short):
|
||||||
translated[pos] = tr
|
translated[pos] = tr
|
||||||
|
|
||||||
# slow fallback for very long texts
|
# Slow fallback for very long texts (split + smaller batches).
|
||||||
for pos, t in zip(long_pos, long_texts):
|
for pos, t in zip(long_pos, long_texts):
|
||||||
translated[pos] = translate_long_text(tokenizer, model, t, sub_batch=max(1, MODEL_BATCH_SIZE // 2))
|
translated[pos] = translate_long_text(tokenizer, model, t, sub_batch=max(1, MODEL_BATCH_SIZE // 2))
|
||||||
|
|
||||||
@ -229,7 +348,8 @@ def worker_translate(
|
|||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
# manual loop to show stable progress bar
|
# Manual loop (instead of a single ds_shard.map) to show a stable progress bar.
|
||||||
|
# Each loop translates a slice [start:end] of the shard.
|
||||||
parts: List[Dataset] = []
|
parts: List[Dataset] = []
|
||||||
with tqdm(total=total_map_batches, desc=f"GPU{gpu_id} shard{shard_id}", ncols=100) as pbar:
|
with tqdm(total=total_map_batches, desc=f"GPU{gpu_id} shard{shard_id}", ncols=100) as pbar:
|
||||||
for start in range(0, shard_len, MAP_BATCH_SIZE):
|
for start in range(0, shard_len, MAP_BATCH_SIZE):
|
||||||
@ -243,8 +363,10 @@ def worker_translate(
|
|||||||
parts.append(part)
|
parts.append(part)
|
||||||
pbar.update(1)
|
pbar.update(1)
|
||||||
|
|
||||||
|
# Concatenate translated parts into a single shard dataset.
|
||||||
ds_tr = concatenate_datasets(parts)
|
ds_tr = concatenate_datasets(parts)
|
||||||
|
|
||||||
|
# Save this shard to disk for later merge.
|
||||||
shard_path = os.path.join(tmp_dir, f"shard_{shard_id:02d}")
|
shard_path = os.path.join(tmp_dir, f"shard_{shard_id:02d}")
|
||||||
_safe_mkdir(shard_path)
|
_safe_mkdir(shard_path)
|
||||||
ds_tr.save_to_disk(shard_path)
|
ds_tr.save_to_disk(shard_path)
|
||||||
@ -254,21 +376,44 @@ def worker_translate(
|
|||||||
# =========================
|
# =========================
|
||||||
# Main (one-click)
|
# Main (one-click)
|
||||||
# =========================
|
# =========================
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# main()
|
||||||
|
# Purpose:
|
||||||
|
# One-click launcher for 2-GPU translation + shard merge.
|
||||||
|
#
|
||||||
|
# Steps:
|
||||||
|
# 1) Parse CLI args:
|
||||||
|
# --resume: skip the worker for any shard whose output directory already exists; missing shards are still translated before the merge.
|
||||||
|
# 2) Validate:
|
||||||
|
# - NLLB model path exists
|
||||||
|
# - at least 2 GPUs are available (script expects exactly 2 workers here)
|
||||||
|
# 3) Load dataset metadata (length + schema) and determine which columns to translate.
|
||||||
|
# 4) Spawn 2 worker processes (GPU 0 and GPU 1) using spawn start method.
|
||||||
|
# 5) After workers finish, merge shard datasets and restore original row order.
|
||||||
|
# 6) Save the final merged dataset to OUT_FINAL_DIR and verify row count.
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
def main():
|
def main():
|
||||||
ap = argparse.ArgumentParser()
|
ap = argparse.ArgumentParser()
|
||||||
ap.add_argument("--resume", action="store_true", help="Skip existing shards and only merge.")
|
ap.add_argument("--resume", action="store_true", help="Skip existing shards and only merge.")
|
||||||
args = ap.parse_args()
|
args = ap.parse_args()
|
||||||
|
|
||||||
|
# Ensure the local NLLB path is valid before starting.
|
||||||
if not os.path.isdir(NLLB_PATH):
|
if not os.path.isdir(NLLB_PATH):
|
||||||
raise SystemExit(f"[ERROR] NLLB model path not found: {NLLB_PATH}")
|
raise SystemExit(f"[ERROR] NLLB model path not found: {NLLB_PATH}")
|
||||||
|
|
||||||
|
# This script is designed for a 2-GPU setup (two workers).
|
||||||
if torch.cuda.device_count() < 2:
|
if torch.cuda.device_count() < 2:
|
||||||
raise SystemExit(f"[ERROR] Need 2 GPUs; found: {torch.cuda.device_count()}")
|
raise SystemExit(f"[ERROR] Need 2 GPUs; found: {torch.cuda.device_count()}")
|
||||||
|
|
||||||
|
# Prepare output directories.
|
||||||
_safe_mkdir(OUT_FINAL_DIR)
|
_safe_mkdir(OUT_FINAL_DIR)
|
||||||
tmp_dir = os.path.join(OUT_FINAL_DIR, "_tmp_shards")
|
tmp_dir = os.path.join(OUT_FINAL_DIR, "_tmp_shards")
|
||||||
_safe_mkdir(tmp_dir)
|
_safe_mkdir(tmp_dir)
|
||||||
|
|
||||||
|
# Load dataset once in the main process to:
|
||||||
|
# - get total length n
|
||||||
|
# - auto-detect which columns are necessary for SFT/DPO
|
||||||
print("[INFO] Loading dataset metadata...")
|
print("[INFO] Loading dataset metadata...")
|
||||||
ds = load_dataset(DATASET_NAME, split=SPLIT)
|
ds = load_dataset(DATASET_NAME, split=SPLIT)
|
||||||
n = len(ds)
|
n = len(ds)
|
||||||
@ -281,12 +426,15 @@ def main():
|
|||||||
print(f"[INFO] Params: MODEL_BATCH_SIZE={MODEL_BATCH_SIZE}, MAP_BATCH_SIZE={MAP_BATCH_SIZE}, "
|
print(f"[INFO] Params: MODEL_BATCH_SIZE={MODEL_BATCH_SIZE}, MAP_BATCH_SIZE={MAP_BATCH_SIZE}, "
|
||||||
f"LONG_THRESHOLD_CHARS={LONG_THRESHOLD_CHARS}, NUM_BEAMS={NUM_BEAMS}, MAX_NEW_TOKENS={MAX_NEW_TOKENS}")
|
f"LONG_THRESHOLD_CHARS={LONG_THRESHOLD_CHARS}, NUM_BEAMS={NUM_BEAMS}, MAX_NEW_TOKENS={MAX_NEW_TOKENS}")
|
||||||
|
|
||||||
|
# Two workers => two shards, pinned to GPU 0 and GPU 1.
|
||||||
gpus = [0, 1]
|
gpus = [0, 1]
|
||||||
total_shards = 2
|
total_shards = 2
|
||||||
|
|
||||||
|
# Use "spawn" for CUDA safety in multiprocessing.
|
||||||
mp.set_start_method("spawn", force=True)
|
mp.set_start_method("spawn", force=True)
|
||||||
procs = []
|
procs = []
|
||||||
|
|
||||||
|
# Start workers unless resume mode skips an existing shard.
|
||||||
for shard_id, gpu_id in enumerate(gpus):
|
for shard_id, gpu_id in enumerate(gpus):
|
||||||
shard_path = os.path.join(tmp_dir, f"shard_{shard_id:02d}")
|
shard_path = os.path.join(tmp_dir, f"shard_{shard_id:02d}")
|
||||||
if args.resume and os.path.isdir(shard_path):
|
if args.resume and os.path.isdir(shard_path):
|
||||||
@ -300,12 +448,16 @@ def main():
|
|||||||
p.start()
|
p.start()
|
||||||
procs.append(p)
|
procs.append(p)
|
||||||
|
|
||||||
|
# Wait for workers and fail fast if any worker exits with a non-zero code.
|
||||||
for p in procs:
|
for p in procs:
|
||||||
p.join()
|
p.join()
|
||||||
if p.exitcode != 0:
|
if p.exitcode != 0:
|
||||||
raise SystemExit("[ERROR] A worker failed. See logs above.")
|
raise SystemExit("[ERROR] A worker failed. See logs above.")
|
||||||
|
|
||||||
# merge
|
# Merge step:
|
||||||
|
# - load each shard dataset from disk
|
||||||
|
# - concatenate them
|
||||||
|
# - restore original ordering (because sharding interleaves indices)
|
||||||
print("\n[INFO] Merging shards...")
|
print("\n[INFO] Merging shards...")
|
||||||
shard_dirs = [os.path.join(tmp_dir, f"shard_{i:02d}") for i in range(total_shards)]
|
shard_dirs = [os.path.join(tmp_dir, f"shard_{i:02d}") for i in range(total_shards)]
|
||||||
for sd in shard_dirs:
|
for sd in shard_dirs:
|
||||||
@ -315,17 +467,21 @@ def main():
|
|||||||
shards = [Dataset.load_from_disk(sd) for sd in shard_dirs]
|
shards = [Dataset.load_from_disk(sd) for sd in shard_dirs]
|
||||||
merged = concatenate_datasets(shards)
|
merged = concatenate_datasets(shards)
|
||||||
|
|
||||||
# restore original order
|
# Restore original order:
|
||||||
|
# Each shard contains indices satisfying (i % total_shards == shard_id).
|
||||||
merged_indices = []
|
merged_indices = []
|
||||||
for shard_id in range(total_shards):
|
for shard_id in range(total_shards):
|
||||||
merged_indices.extend([i for i in range(n) if (i % total_shards) == shard_id])
|
merged_indices.extend([i for i in range(n) if (i % total_shards) == shard_id])
|
||||||
|
|
||||||
|
# Add an ordering column, sort by it, then remove it.
|
||||||
merged = merged.add_column("orig_index", merged_indices)
|
merged = merged.add_column("orig_index", merged_indices)
|
||||||
merged = merged.sort("orig_index").remove_columns(["orig_index"])
|
merged = merged.sort("orig_index").remove_columns(["orig_index"])
|
||||||
|
|
||||||
|
# Safety check: ensure no missing rows after merge.
|
||||||
if len(merged) != n:
|
if len(merged) != n:
|
||||||
raise SystemExit(f"[ERROR] Length mismatch after merge: merged={len(merged)} expected={n}")
|
raise SystemExit(f"[ERROR] Length mismatch after merge: merged={len(merged)} expected={n}")
|
||||||
|
|
||||||
|
# Save the final translated dataset.
|
||||||
print("[INFO] Saving final dataset...")
|
print("[INFO] Saving final dataset...")
|
||||||
merged.save_to_disk(OUT_FINAL_DIR)
|
merged.save_to_disk(OUT_FINAL_DIR)
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user