#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
|
|
Translate Slovak 'responses.json' outputs into English using a local NLLB model.
|
|
|
|
Core behavior:
|
|
- Auto-detect run directories in /home/hyrenko/Diploma/outputs that:
|
|
- look like "mistral-sk" runs (name heuristic)
|
|
- contain responses.json
|
|
- Translate both 'prompt' and 'response' fields from Slovak -> English
|
|
- Save translated runs into /home/hyrenko/Diploma/outputs_translated
|
|
- Keep the same run folder name, with a special rename rule:
|
|
do_not_answer_sk -> do_not_answer_eng
|
|
- Write a clean EN-only responses.json (prompt/response overwritten with English)
|
|
- Also write translate.log.txt with parameters and paths
|
|
"""
|
|
|
|
import os
|
|
import re
|
|
import json
|
|
import datetime
|
|
from typing import List, Dict, Any, Optional, Tuple
|
|
|
|
import torch
|
|
from tqdm import tqdm
|
|
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
|
|
|
|
|
# -------------------------
# Fixed paths / constants
# -------------------------
NLLB_PATH = "/home/hyrenko/Diploma/models/nllb-200-1.3B"  # local NLLB checkpoint
OUTPUTS_ROOT = "/home/hyrenko/Diploma/outputs"  # source runs live here
OUTPUTS_TRANSLATED_ROOT = "/home/hyrenko/Diploma/outputs_translated"  # translated runs go here

# NLLB (FLORES-200) language codes: Slovak -> English.
SRC_LANG = "slk_Latn"
TGT_LANG = "eng_Latn"

# Per-run file holding the prompt/response records to translate.
RESPONSES_FILENAME = "responses.json"
def safe_mkdir(path: str):
    """Ensure *path* exists as a directory (no-op when it is already there)."""
    os.makedirs(path, exist_ok=True)
def load_json(path: str) -> Any:
    """Read *path* as UTF-8 JSON and return the decoded Python object."""
    with open(path, encoding="utf-8") as fh:
        return json.load(fh)
def save_json(path: str, obj: Any):
    """Serialize *obj* to *path* as pretty, non-ASCII-preserving UTF-8 JSON."""
    with open(path, "w", encoding="utf-8") as fh:
        fh.write(json.dumps(obj, ensure_ascii=False, indent=2))
def normalize_text(s: Any) -> str:
    """Coerce any value into a clean string for translation.

    None becomes "", the Unicode replacement character (U+FFFD) is removed,
    and surrounding whitespace is stripped.
    """
    if s is None:
        return ""
    cleaned = str(s).replace("\uFFFD", "")
    return cleaned.strip()
def split_text_safely(text: str, max_chars: int) -> List[str]:
    """Break *text* into pieces no longer than *max_chars* for stable translation.

    Splitting is hierarchical, coarse to fine:
    - paragraphs (separated by 2+ newlines)
    - single lines, greedily packed up to the budget
    - sentence-ish units at punctuation boundaries
    - hard character cut as a last resort
    """
    text = normalize_text(text)
    if not text:
        return [""]

    out: List[str] = []

    def emit(piece: str):
        # Append *piece*, hard-cutting by character count when it is over budget.
        piece = piece.strip()
        if not piece:
            return
        if len(piece) <= max_chars:
            out.append(piece)
            return
        for start in range(0, len(piece), max_chars):
            fragment = piece[start:start + max_chars].strip()
            if fragment:
                out.append(fragment)

    for block in re.split(r"\n{2,}", text):
        block = block.strip()
        if not block:
            continue
        if len(block) <= max_chars:
            out.append(block)
            continue

        # Paragraph too long: greedily pack its non-empty lines up to the budget.
        acc = ""
        for line in (ln.strip() for ln in block.split("\n") if ln.strip()):
            if not acc:
                acc = line
            elif len(acc) + 1 + len(line) <= max_chars:
                acc = acc + "\n" + line
            else:
                emit(acc)
                acc = line

        if not acc:
            continue
        if len(acc) <= max_chars:
            out.append(acc)
            continue

        # A single over-long line remains: fall back to sentence-ish packing.
        packed = ""
        for sentence in re.split(r"(?<=[\.\!\?])\s+", acc):
            sentence = sentence.strip()
            if not sentence:
                continue
            if not packed:
                packed = sentence
            elif len(packed) + 1 + len(sentence) <= max_chars:
                packed = packed + " " + sentence
            else:
                emit(packed)
                packed = sentence
        if packed:
            emit(packed)

    return out if out else [""]
def batchify(items: List[str], bs: int) -> List[List[str]]:
    """Group *items* into consecutive sublists of at most *bs* elements."""
    batches: List[List[str]] = []
    for start in range(0, len(items), bs):
        batches.append(items[start:start + bs])
    return batches
@torch.inference_mode()
def nllb_translate_batch(tokenizer, model, texts: List[str], max_new_tokens: int, num_beams: int) -> List[str]:
    """Run one NLLB generation pass over *texts* (SRC_LANG -> TGT_LANG).

    The source language is signalled via ``tokenizer.src_lang``; the target
    language is enforced by forcing its token id as the first generated token
    (``forced_bos_token_id``). Decoding is deterministic beam search.
    """
    tokenizer.src_lang = SRC_LANG
    target_id = tokenizer.convert_tokens_to_ids(TGT_LANG)

    encoded = tokenizer(
        texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=1024,
    )
    # Move every input tensor onto the model's device before generating.
    encoded = {name: tensor.to(model.device) for name, tensor in encoded.items()}

    generated = model.generate(
        **encoded,
        forced_bos_token_id=target_id,
        max_new_tokens=max_new_tokens,
        num_beams=num_beams,
        do_sample=False,
    )
    decoded = tokenizer.batch_decode(generated, skip_special_tokens=True)
    return [piece.strip() for piece in decoded]
def translate_text(tokenizer, model, text: str, chunk_chars: int, batch_size: int, max_new_tokens: int, num_beams: int) -> str:
    """Translate one (possibly long) text end to end.

    The text is normalized, chunked by character budget, translated chunk by
    chunk in batches, and the non-empty translated chunks are rejoined with
    blank lines.
    """
    text = normalize_text(text)
    if not text:
        return ""

    pieces = split_text_safely(text, max_chars=chunk_chars)
    translated: List[str] = []
    for group in batchify(pieces, batch_size):
        translated += nllb_translate_batch(tokenizer, model, group, max_new_tokens, num_beams)

    return "\n\n".join(part for part in translated if part)
def is_mistral_sk_dir(name: str) -> bool:
    """Heuristic: a run name mentions both 'mistral' and 'sk' (case-insensitive)
    and contains no path separator."""
    lowered = name.lower()
    if os.path.sep in lowered:
        return False
    return "mistral" in lowered and "sk" in lowered
def list_candidate_run_dirs(outputs_root: str) -> List[str]:
    """Return sorted run directories under *outputs_root* that match the
    mistral-sk naming heuristic and contain a responses.json file."""
    found: List[str] = []
    for entry in os.listdir(outputs_root):
        run_dir = os.path.join(outputs_root, entry)
        if not os.path.isdir(run_dir):
            continue
        if not is_mistral_sk_dir(entry):
            continue
        if os.path.isfile(os.path.join(run_dir, RESPONSES_FILENAME)):
            found.append(run_dir)
    return sorted(found)
def pick_latest_dir(dirs: List[str]) -> Optional[str]:
    """Return the newest directory from *dirs*, or None for an empty list.

    Ordering key: a basename that parses as a 'YYYY-MM-DD_HH-MM-SS' timestamp
    prefix ranks above (and sorts by) a plain filesystem-mtime fallback.
    """
    if not dirs:
        return None

    def sort_key(path: str) -> Tuple[int, float]:
        stem = os.path.basename(path)
        try:
            stamp = datetime.datetime.strptime(stem[:19], "%Y-%m-%d_%H-%M-%S")
        except Exception:
            # No timestamp prefix: fall back to mtime, then to a constant.
            try:
                return (0, os.path.getmtime(path))
            except Exception:
                return (0, 0.0)
        return (1, stamp.timestamp())

    return max(dirs, key=sort_key)
def make_translated_dirname(src_dirname: str) -> str:
    """Map a Slovak run folder name to its English counterpart.

    Special rule: 'do_not_answer_sk' becomes 'do_not_answer_eng'. Otherwise
    any '_sk' token ending at a word boundary is rewritten to '_eng'.
    """
    renamed = src_dirname.replace("do_not_answer_sk", "do_not_answer_eng")
    if renamed != src_dirname:
        return renamed
    return re.sub(r"_sk\b", "_eng", src_dirname)
def interactive_choice() -> str:
    """Prompt for the run-selection mode: '1' = latest run, '2' = all runs.

    Any other input (including blank) defaults to '1'.
    """
    print("\n=== What to translate? ===")
    print("1) Latest mistral-sk run (auto-detected)")
    print("2) All mistral-sk runs (auto-detected)")
    answer = input("Choose 1 or 2 (default 1): ").strip()
    if answer in ("1", "2"):
        return answer
    return "1"
def _ask_positive_int(prompt: str, default: int) -> int:
    """Prompt for a positive integer; return *default* on blank/invalid input."""
    raw = input(prompt).strip()
    return int(raw) if raw.isdigit() and int(raw) > 0 else default


def _load_translator():
    """Load the NLLB tokenizer and model (fp16 on CUDA, fp32 on CPU) in eval mode.

    Returns:
        (tokenizer, model) ready for generation.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if device == "cuda" else torch.float32

    print("\n[INFO] Loading NLLB...")
    tokenizer = AutoTokenizer.from_pretrained(NLLB_PATH, use_fast=True)
    model = AutoModelForSeq2SeqLM.from_pretrained(
        NLLB_PATH,
        torch_dtype=dtype,
        device_map="auto" if device == "cuda" else None,
        low_cpu_mem_usage=True,
    )
    if device == "cpu":
        model = model.to("cpu")
    model.eval()
    return tokenizer, model


def _translate_run(src_run_dir: str, tokenizer, model, batch_size: int, chunk_chars: int, max_new_tokens: int, num_beams: int):
    """Translate one run directory: write an EN-only responses.json plus a parameter log.

    Raises:
        SystemExit: when the run's responses.json is not a JSON list.
    """
    src_dirname = os.path.basename(src_run_dir)

    dst_run_dir = os.path.join(OUTPUTS_TRANSLATED_ROOT, make_translated_dirname(src_dirname))
    safe_mkdir(dst_run_dir)

    src_json = os.path.join(src_run_dir, RESPONSES_FILENAME)
    dst_json_en_only = os.path.join(dst_run_dir, "responses.json")
    dst_log = os.path.join(dst_run_dir, "translate.log.txt")

    print(f"\n[INFO] Translating:\n src: {src_run_dir}\n dst: {dst_run_dir}\n file: {src_json}")

    data = load_json(src_json)
    if not isinstance(data, list):
        raise SystemExit(f"[ERROR] {src_json} must be a list of objects.")

    # Record the translation parameters/paths before doing any work.
    with open(dst_log, "w", encoding="utf-8") as lf:
        lf.write(f"[INFO] NLLB_PATH={NLLB_PATH}\n")
        lf.write(f"[INFO] SRC_LANG={SRC_LANG} TGT_LANG={TGT_LANG}\n")
        lf.write(f"[INFO] batch_size={batch_size} chunk_chars={chunk_chars} max_new_tokens={max_new_tokens} num_beams={num_beams}\n")
        lf.write(f"[INFO] src_dir={src_run_dir}\n")
        lf.write(f"[INFO] dst_dir={dst_run_dir}\n")

    en_only: List[Dict[str, Any]] = []
    for item in tqdm(data, desc=f"Translating {src_dirname}", total=len(data)):
        if not isinstance(item, dict):
            # Preserve malformed entries instead of dropping them silently.
            en_only.append({"_raw": item})
            continue

        prompt_en = translate_text(tokenizer, model, item.get("prompt", ""), chunk_chars, batch_size, max_new_tokens, num_beams)
        resp_en = translate_text(tokenizer, model, item.get("response", ""), chunk_chars, batch_size, max_new_tokens, num_beams)

        out_item = dict(item)
        out_item["prompt"] = prompt_en
        out_item["response"] = resp_en

        # Strip any per-language duplicate fields so the output is EN-only.
        for key in list(out_item.keys()):
            if key.endswith(("_original", "_sk_original", "_eng", "_en")):
                out_item.pop(key, None)

        en_only.append(out_item)

    save_json(dst_json_en_only, en_only)

    print("[OK] Saved EN-only:")
    print(f" - {dst_json_en_only}")
    print(f" - {dst_log}")


def main():
    """Entry point: validate paths, select runs, ask parameters, load NLLB, translate each run."""
    # Validate required directories
    if not os.path.isdir(NLLB_PATH):
        raise SystemExit(f"[ERROR] NLLB model path not found: {NLLB_PATH}")
    if not os.path.isdir(OUTPUTS_ROOT):
        raise SystemExit(f"[ERROR] outputs root not found: {OUTPUTS_ROOT}")

    safe_mkdir(OUTPUTS_TRANSLATED_ROOT)

    run_dirs = list_candidate_run_dirs(OUTPUTS_ROOT)
    if not run_dirs:
        raise SystemExit(f"[ERROR] No mistral-sk run dirs with {RESPONSES_FILENAME} found in: {OUTPUTS_ROOT}")

    if interactive_choice() == "1":
        latest = pick_latest_dir(run_dirs)
        run_dirs = [latest] if latest else []

    print(f"\n[INFO] Runs to translate: {len(run_dirs)}")
    for d in run_dirs:
        print(f" - {d}")

    # Interactive translation parameters
    print("\n=== Translation parameters ===")
    batch_size = _ask_positive_int("batch_size (default 8): ", 8)
    chunk_chars = _ask_positive_int("chunk_chars (default 1800): ", 1800)
    max_new_tokens = _ask_positive_int("max_new_tokens (default 256): ", 256)
    num_beams = _ask_positive_int("num_beams (default 4): ", 4)

    tokenizer, model = _load_translator()

    for src_run_dir in run_dirs:
        _translate_run(src_run_dir, tokenizer, model, batch_size, chunk_chars, max_new_tokens, num_beams)

    print("\n✔ Done.")
if __name__ == "__main__":
|
|
main()
|