#!/usr/bin/env python3
|
|
import argparse
|
|
import os
|
|
import torch
|
|
|
|
from datasets import load_dataset
|
|
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
|
|
|
def chunk_list(lst, n):
    """Yield successive slices of *lst* of length *n* (the last may be shorter)."""
    start = 0
    total = len(lst)
    while start < total:
        yield lst[start:start + n]
        start += n
@torch.inference_mode()
def translate_batch(texts, tokenizer, model, src_lang, tgt_lang, max_length, num_beams, device):
    """Translate a list of strings with an NLLB-style seq2seq model.

    The tokenizer's source language is set before encoding, and generation is
    forced to start with the target-language token so the model decodes into
    *tgt_lang*. Inputs are padded/truncated to *max_length*; beam search uses
    *num_beams*. Runs under ``torch.inference_mode`` (no autograd state).

    Returns a list of decoded translations, one per input string.
    """
    tokenizer.src_lang = src_lang
    # NLLB expects the target-language code as the first generated token.
    target_bos_id = tokenizer.convert_tokens_to_ids(tgt_lang)

    encoded = tokenizer(
        texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_length,
    )
    encoded = encoded.to(device)

    output_ids = model.generate(
        **encoded,
        forced_bos_token_id=target_bos_id,
        max_length=max_length,
        num_beams=num_beams,
    )
    return tokenizer.batch_decode(output_ids, skip_special_tokens=True)
def main():
    """Translate selected columns of LibrAI/do-not-answer into Slovak and save to disk."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--base_dir", default="/home/hyrenko/Diploma/datasets",
                        help="Базовая директория для сохранения результата")
    parser.add_argument("--out_name", default="do_not_answer_sk",
                        help="Имя папки, которая будет создана внутри base_dir")
    parser.add_argument("--model", default="facebook/nllb-200-1.3B")
    parser.add_argument("--split", default="train")

    parser.add_argument("--translate_fields", default="question",
                        help="Поля для перевода через запятую. Например: question,risk_area,types_of_harm,specific_harms")

    # Generation / throughput knobs
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--max_length", type=int, default=256)
    parser.add_argument("--num_beams", type=int, default=4)

    # NLLB language codes
    parser.add_argument("--src_lang", default="eng_Latn")
    parser.add_argument("--tgt_lang", default="slk_Latn")

    args = parser.parse_args()

    output_dir = os.path.join(args.base_dir, args.out_name)
    os.makedirs(output_dir, exist_ok=True)

    # Comma-separated column names; ignore empty entries and surrounding spaces.
    field_names = []
    for raw_name in args.translate_fields.split(","):
        name = raw_name.strip()
        if name:
            field_names.append(name)

    # 1) Load the dataset
    dataset = load_dataset("LibrAI/do-not-answer", split=args.split)

    # 2) Load NLLB (half precision on GPU, full precision on CPU)
    if torch.cuda.is_available():
        device = "cuda"
        dtype = torch.float16
    else:
        device = "cpu"
        dtype = torch.float32

    tokenizer = AutoTokenizer.from_pretrained(args.model)

    model = AutoModelForSeq2SeqLM.from_pretrained(
        args.model,
        torch_dtype=dtype,
        low_cpu_mem_usage=True,
    ).to(device)
    model.eval()

    # 3) Translate
    def translate_columns(batch):
        # For each requested column, translate its texts in chunks of
        # args.batch_size and store the result under "<column>_sk".
        result = dict(batch)
        for column in field_names:
            translated_texts = []
            for piece in chunk_list(batch[column], args.batch_size):
                translated_texts.extend(
                    translate_batch(
                        piece, tokenizer, model,
                        src_lang=args.src_lang,
                        tgt_lang=args.tgt_lang,
                        max_length=args.max_length,
                        num_beams=args.num_beams,
                        device=device,
                    )
                )
            result[f"{column}_sk"] = translated_texts
        return result

    # datasets.map batch size is independent of the translation batch size above.
    translated = dataset.map(translate_columns, batched=True, batch_size=128,
                             desc="Translating to Slovak")

    # 4) Save
    out_dir = output_dir
    translated.save_to_disk(out_dir)
    print(f"Saved translated dataset to: {out_dir}")
# Run only when executed as a script, so the module stays importable.
if __name__ == "__main__":
    main()