DP/translate/translate_do-not_answer.py
2026-02-04 21:07:17 +01:00

106 lines
3.6 KiB
Python

#!/usr/bin/env python3
import argparse
import os
import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
def chunk_list(lst, n):
    """Yield consecutive slices of *lst*, each holding at most *n* items.

    The final slice may be shorter than *n* when ``len(lst)`` is not a
    multiple of *n*.
    """
    for start in range(0, len(lst), n):
        yield lst[start:start + n]
@torch.inference_mode()
def translate_batch(texts, tokenizer, model, src_lang, tgt_lang, max_length, num_beams, device):
    """Translate a list of strings with an NLLB-style seq2seq model.

    Sets the tokenizer's source language, forces the decoder to start
    with the target-language token, and returns the decoded translations
    in the same order as *texts*. Runs under ``torch.inference_mode()``
    so no autograd state is recorded.
    """
    tokenizer.src_lang = src_lang
    # NLLB steers the output language by forcing the first decoded token
    # to be the target-language code token.
    bos_id = tokenizer.convert_tokens_to_ids(tgt_lang)
    encoded = tokenizer(
        texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_length,
    ).to(device)
    gen_kwargs = {
        "forced_bos_token_id": bos_id,
        "max_length": max_length,
        "num_beams": num_beams,
    }
    output_ids = model.generate(**encoded, **gen_kwargs)
    return tokenizer.batch_decode(output_ids, skip_special_tokens=True)
def main():
    """CLI entry point.

    Loads the LibrAI/do-not-answer dataset, translates the requested
    text fields from ``--src_lang`` to ``--tgt_lang`` with an NLLB
    model, appends each translation as a new ``<field>_sk`` column, and
    saves the augmented dataset to ``<base_dir>/<out_name>``.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--base_dir", default="/home/hyrenko/Diploma/datasets",
                   help="Базовая директория для сохранения результата")
    p.add_argument("--out_name", default="do_not_answer_sk",
                   help="Имя папки, которая будет создана внутри base_dir")
    p.add_argument("--model", default="facebook/nllb-200-1.3B")
    p.add_argument("--split", default="train")
    p.add_argument("--translate_fields", default="question",
                   help="Поля для перевода через запятую. Например: question,risk_area,types_of_harm,specific_harms")
    # Generation / performance parameters
    p.add_argument("--batch_size", type=int, default=32)
    p.add_argument("--max_length", type=int, default=256)
    p.add_argument("--num_beams", type=int, default=4)
    # NLLB language codes
    p.add_argument("--src_lang", default="eng_Latn")
    p.add_argument("--tgt_lang", default="slk_Latn")
    args = p.parse_args()

    out_dir = os.path.join(args.base_dir, args.out_name)
    os.makedirs(out_dir, exist_ok=True)

    fields = [x.strip() for x in args.translate_fields.split(",") if x.strip()]

    # 1) Load dataset
    ds = load_dataset("LibrAI/do-not-answer", split=args.split)

    # Fail fast on misspelled field names; otherwise a KeyError would
    # surface only mid-way through the expensive translation pass.
    missing = [f for f in fields if f not in ds.column_names]
    if missing:
        raise ValueError(
            f"Unknown field(s) {missing}; available columns: {ds.column_names}"
        )

    # 2) Load NLLB (FP16 on GPU, FP32 on CPU)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    tok = AutoTokenizer.from_pretrained(args.model)
    mdl = AutoModelForSeq2SeqLM.from_pretrained(
        args.model,
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
        low_cpu_mem_usage=True,
    )
    mdl = mdl.to(device)
    mdl.eval()

    # 3) Translate each requested field, adding a "<field>_sk" column.
    def map_fn(batch):
        out = dict(batch)
        for f in fields:
            texts = batch[f]
            translated_all = []
            # Re-chunk the datasets.map batch into GPU-sized sub-batches.
            for sub in chunk_list(texts, args.batch_size):
                translated_all.extend(
                    translate_batch(
                        sub, tok, mdl,
                        src_lang=args.src_lang,
                        tgt_lang=args.tgt_lang,
                        max_length=args.max_length,
                        num_beams=args.num_beams,
                        device=device,
                    )
                )
            out[f"{f}_sk"] = translated_all
        return out

    # datasets.map batching is independent of the translation batch size:
    # map hands us up to 128 rows, which map_fn re-chunks for the GPU.
    ds_sk = ds.map(map_fn, batched=True, batch_size=128, desc="Translating to Slovak")

    # 4) Save
    ds_sk.save_to_disk(out_dir)
    print(f"Saved translated dataset to: {out_dir}")


if __name__ == "__main__":
    main()