#!/usr/bin/env python3
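"""Translate selected fields of the LibrAI/do-not-answer dataset with an NLLB model.

Example invocation (the script filename and paths here are placeholders):
    python translate_dna.py --base_dir ./datasets --out_name do_not_answer_sk \
        --translate_fields question,risk_area --num_beams 4
"""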

import argparse
import os

import torch

from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM


# -----------------------------------------------------------------------------
# chunk_list(lst, n)
# Purpose:
#   Splits a list into consecutive chunks of size n.
#
# How it works:
#   - Iterates from 0 to len(lst) in steps of n.
#   - Yields slices lst[i:i+n].
#
# Why it exists:
#   - The dataset mapping can provide large batches, but translation should be
#     performed in smaller sub-batches to control GPU memory usage and latency.
# -----------------------------------------------------------------------------
def chunk_list(lst, n):
    for i in range(0, len(lst), n):
        yield lst[i:i+n]
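
# Example (illustrative): list(chunk_list([1, 2, 3, 4, 5], 2)) yields
# [[1, 2], [3, 4], [5]]; the final chunk may be shorter than n.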


# -----------------------------------------------------------------------------
# translate_batch(...)
# Purpose:
#   Translates a list of texts in a single model.generate() call using an NLLB
#   seq2seq model and tokenizer.
#
# Key details:
#   - Sets tokenizer.src_lang to the source language code.
#   - Uses forced_bos_token_id to force generation to start in the target language.
#   - Tokenizes with padding/truncation and moves tensors to the target device.
#   - Calls model.generate() with beam search parameters.
#   - Decodes the generated tokens back into strings.
#
# Torch behavior:
#   - @torch.inference_mode() disables gradient tracking for faster inference
#     and reduced memory usage.
# -----------------------------------------------------------------------------
@torch.inference_mode()
def translate_batch(texts, tokenizer, model, src_lang, tgt_lang, max_length, num_beams, device):
    tokenizer.src_lang = src_lang
    forced_bos_token_id = tokenizer.convert_tokens_to_ids(tgt_lang)

    # Tokenize inputs:
    #   - return_tensors="pt" produces PyTorch tensors
    #   - padding=True aligns batch sequences
    #   - truncation=True ensures max_length is respected
    inputs = tokenizer(
        texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_length,
    ).to(device)

    # Generate translated sequences
    generated = model.generate(
        **inputs,
        forced_bos_token_id=forced_bos_token_id,
        max_length=max_length,
        num_beams=num_beams,
    )

    # Convert generated token IDs into readable strings
    return tokenizer.batch_decode(generated, skip_special_tokens=True)
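
# Illustrative call (hypothetical inputs; exact output depends on the model):
#   translate_batch(["How are you?"], tok, mdl,
#                   src_lang="eng_Latn", tgt_lang="slk_Latn",
#                   max_length=256, num_beams=4, device="cpu")
# returns a one-element list containing the Slovak translation.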


# -----------------------------------------------------------------------------
# main()
# Purpose:
#   End-to-end pipeline to translate selected fields of the "LibrAI/do-not-answer"
#   dataset into Slovak (or any NLLB-supported target language) and save the
#   result to disk.
#
# Pipeline steps:
#   1) Parse CLI arguments (paths, model, split, translation fields, generation params).
#   2) Load the dataset split from the Hugging Face Hub.
#   3) Load the NLLB tokenizer + model and choose a device (CUDA if available).
#   4) Translate the requested fields via datasets.map() in batched mode:
#      - The datasets.map batch size (batch_size=128) is separate from translation sub-batches.
#      - Translation sub-batches are controlled by args.batch_size.
#   5) Save the translated dataset to disk and print the output path.
# -----------------------------------------------------------------------------
def main():
    # CLI arguments
    p = argparse.ArgumentParser()
    p.add_argument(
        "--base_dir",
        default="/home/hyrenko/Diploma/datasets",
        help="Base directory where the translated dataset will be saved",
    )
    p.add_argument(
        "--out_name",
        default="do_not_answer_sk",
        help="Name of the output folder to create inside base_dir",
    )
    p.add_argument("--model", default="facebook/nllb-200-1.3B")
    p.add_argument("--split", default="train")

    p.add_argument(
        "--translate_fields",
        default="question",
        help="Comma-separated list of fields to translate (e.g., question,risk_area,types_of_harm,specific_harms)",
    )

    # Generation / performance parameters
    p.add_argument("--batch_size", type=int, default=32)
    p.add_argument("--max_length", type=int, default=256)
    p.add_argument("--num_beams", type=int, default=4)

    # NLLB language codes
    p.add_argument("--src_lang", default="eng_Latn")
    p.add_argument("--tgt_lang", default="slk_Latn")

    args = p.parse_args()

    # Output directory setup
    out_dir = os.path.join(args.base_dir, args.out_name)
    os.makedirs(out_dir, exist_ok=True)

    # Parse fields to translate
    fields = [x.strip() for x in args.translate_fields.split(",") if x.strip()]

    # 1) Load dataset split from the Hugging Face Hub
    ds = load_dataset("LibrAI/do-not-answer", split=args.split)
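
    # Sanity check (added for illustration): the requested fields must exist
    # as columns in the loaded split.
    for f in fields:
        assert f in ds.column_names, f"Column not found in dataset: {f}"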

    # 2) Load NLLB model/tokenizer
    # If CUDA is available, use FP16 for better performance and lower VRAM usage.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    tok = AutoTokenizer.from_pretrained(args.model)

    mdl = AutoModelForSeq2SeqLM.from_pretrained(
        args.model,
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
        low_cpu_mem_usage=True,
    )
    mdl = mdl.to(device)
    mdl.eval()
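    # Note: @torch.inference_mode() on translate_batch already disables
    # autograd; eval() additionally switches training-time layers such as
    # dropout into inference behavior.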

    # 3) Translate
    # map_fn(batch)
    # Purpose:
    #   For each requested field, translate the field values and store them
    #   into a new column named "<field>_sk".
    #
    # Behavior:
    #   - Copies the incoming batch dict to out.
    #   - For each field in 'fields':
    #       * Takes batch[f] as a list of strings.
    #       * Splits it into translation sub-batches of size args.batch_size.
    #       * Calls translate_batch() and concatenates the results.
    #       * Writes the translated list to out[f"{f}_sk"].
    #   - Returns out, which datasets.map merges into the dataset.
    def map_fn(batch):
        out = dict(batch)
        for f in fields:
            texts = batch[f]
            translated_all = []
            for sub in chunk_list(texts, args.batch_size):
                translated_all.extend(
                    translate_batch(
                        sub, tok, mdl,
                        src_lang=args.src_lang,
                        tgt_lang=args.tgt_lang,
                        max_length=args.max_length,
                        num_beams=args.num_beams,
                        device=device,
                    )
                )
            out[f"{f}_sk"] = translated_all
        return out

    # The datasets.map batch size can be larger than the translation batch_size
    # (these are two different levels of batching).
    ds_sk = ds.map(map_fn, batched=True, batch_size=128, desc="Translating to Slovak")
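    # Worked example of the two batching levels: with map batch_size=128 and
    # the default --batch_size of 32, each map batch is translated in
    # 128 / 32 = 4 sub-batches per field.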

    # 4) Save to disk
    ds_sk.save_to_disk(out_dir)
    print(f"Saved translated dataset to: {out_dir}")


if __name__ == "__main__":
    # Standard entry point guard: run main() only when executed as a script
    main()
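

# The saved Arrow dataset can later be reloaded with datasets.load_from_disk,
# e.g. (path is a placeholder):
#   from datasets import load_from_disk
#   ds_sk = load_from_disk("/path/to/base_dir/do_not_answer_sk")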