DP/translate/Translate_sk_to_eng.py
2026-02-04 20:31:34 +00:00

384 lines
12 KiB
Python

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Translate Slovak 'responses.json' outputs into English using a local NLLB model.
Core behavior:
- Auto-detect run directories in /home/hyrenko/Diploma/outputs that:
- look like "mistral-sk" runs (name heuristic)
- contain responses.json
- Translate both 'prompt' and 'response' fields from Slovak -> English
- Save translated runs into /home/hyrenko/Diploma/outputs_translated
- Keep the same run folder name, with a special rename rule:
do_not_answer_sk -> do_not_answer_eng
- Write a clean EN-only responses.json (prompt/response overwritten with English)
- Also write translate.log.txt with parameters and paths
"""
import os
import re
import json
import datetime
from typing import List, Dict, Any, Optional, Tuple
import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# -------------------------
# Fixed paths / constants
# -------------------------
NLLB_PATH = "/home/hyrenko/Diploma/models/nllb-200-1.3B"
OUTPUTS_ROOT = "/home/hyrenko/Diploma/outputs"
OUTPUTS_TRANSLATED_ROOT = "/home/hyrenko/Diploma/outputs_translated"
SRC_LANG = "slk_Latn"
TGT_LANG = "eng_Latn"
RESPONSES_FILENAME = "responses.json"
def safe_mkdir(path: str):
    """Make sure *path* exists as a directory, creating parents as needed.

    Relies on ``exist_ok`` so repeated calls are harmless.
    """
    os.makedirs(path, exist_ok=True)
def load_json(path: str) -> Any:
    """Read *path* as UTF-8 JSON and return the decoded Python object."""
    with open(path, "r", encoding="utf-8") as fh:
        return json.load(fh)
def save_json(path: str, obj: Any):
    """Serialize *obj* to *path* as 2-space-indented UTF-8 JSON.

    Non-ASCII characters are written literally (``ensure_ascii=False``).
    """
    rendered = json.dumps(obj, ensure_ascii=False, indent=2)
    with open(path, "w", encoding="utf-8") as fh:
        fh.write(rendered)
def normalize_text(s: Any) -> str:
    """Coerce *s* into a string safe to hand to the translator.

    ``None`` maps to the empty string; the Unicode replacement character
    (U+FFFD, a marker of earlier decoding damage) is dropped and the
    result is stripped of surrounding whitespace.
    """
    if s is None:
        return ""
    cleaned = str(s).replace("\uFFFD", "")
    return cleaned.strip()
def split_text_safely(text: str, max_chars: int) -> List[str]:
    """
    Split text into chunks of at most max_chars characters so long inputs
    stay stable under translation.

    Strategy (coarse -> fine):
    - paragraphs (2+ newlines)
    - lines joined greedily up to max_chars
    - sentence-ish boundaries by punctuation regex
    - hard character split fallback

    Always returns at least one element ([""] for empty input).
    """
    text = normalize_text(text)
    if not text:
        # Preserve a one-element shape so callers can zip/iterate uniformly.
        return [""]
    chunks: List[str] = []

    def push(piece: str):
        # Append piece to chunks, hard-splitting by character count when it
        # still exceeds max_chars (last-resort fallback).
        piece = piece.strip()
        if not piece:
            return
        if len(piece) <= max_chars:
            chunks.append(piece)
        else:
            for i in range(0, len(piece), max_chars):
                part = piece[i:i + max_chars].strip()
                if part:
                    chunks.append(part)

    # Coarsest level: paragraphs separated by blank lines.
    paras = re.split(r"\n{2,}", text)
    for para in paras:
        para = para.strip()
        if not para:
            continue
        if len(para) <= max_chars:
            # Whole paragraph fits -> keep it intact.
            chunks.append(para)
            continue
        # Paragraph too long: pack its lines greedily into max_chars buffers.
        lines = [ln.strip() for ln in para.split("\n") if ln.strip()]
        buf = ""
        for ln in lines:
            if not buf:
                buf = ln
            elif len(buf) + 1 + len(ln) <= max_chars:
                buf += "\n" + ln
            else:
                # Buffer full: flush it (push hard-splits oversized buffers).
                push(buf)
                buf = ln
        if buf:
            # Flush the trailing buffer; if it is still too long, fall back
            # to sentence-ish splitting on ./!/? followed by whitespace.
            if len(buf) <= max_chars:
                chunks.append(buf)
            else:
                sents = re.split(r"(?<=[\.\!\?])\s+", buf)
                tmp = ""
                for s in sents:
                    s = s.strip()
                    if not s:
                        continue
                    if not tmp:
                        tmp = s
                    elif len(tmp) + 1 + len(s) <= max_chars:
                        tmp += " " + s
                    else:
                        push(tmp)
                        tmp = s
                if tmp:
                    push(tmp)
    return chunks if chunks else [""]
def batchify(items: List[str], bs: int) -> List[List[str]]:
    """Partition *items* into consecutive slices of at most *bs* elements."""
    starts = range(0, len(items), bs)
    return [items[start:start + bs] for start in starts]
@torch.inference_mode()
def nllb_translate_batch(tokenizer, model, texts: List[str], max_new_tokens: int, num_beams: int) -> List[str]:
    """
    Run one NLLB generation pass over *texts*, SRC_LANG -> TGT_LANG.

    Setting ``tokenizer.src_lang`` selects the source language, while
    ``forced_bos_token_id`` pins the first generated token to the target
    language code. Decoding is deterministic (beam search, no sampling).
    """
    tokenizer.src_lang = SRC_LANG
    target_token_id = tokenizer.convert_tokens_to_ids(TGT_LANG)
    encoded = tokenizer(
        texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=1024,
    )
    # Move every input tensor onto the device the model lives on.
    encoded = {name: tensor.to(model.device) for name, tensor in encoded.items()}
    generated = model.generate(
        **encoded,
        forced_bos_token_id=target_token_id,
        max_new_tokens=max_new_tokens,
        num_beams=num_beams,
        do_sample=False,
    )
    decoded = tokenizer.batch_decode(generated, skip_special_tokens=True)
    return [piece.strip() for piece in decoded]
def translate_text(tokenizer, model, text: str, chunk_chars: int, batch_size: int, max_new_tokens: int, num_beams: int) -> str:
    """
    Translate one (possibly long) text safely.

    Pipeline: normalize -> chunk by character budget -> translate the
    chunks batch by batch -> rejoin the non-empty translations with
    blank lines.
    """
    cleaned = normalize_text(text)
    if not cleaned:
        return ""
    pieces = split_text_safely(cleaned, max_chars=chunk_chars)
    results: List[str] = []
    for group in batchify(pieces, batch_size):
        results.extend(nllb_translate_batch(tokenizer, model, group, max_new_tokens, num_beams))
    return "\n\n".join(part for part in results if part)
def is_mistral_sk_dir(name: str) -> bool:
    """Heuristic filter for run folders: the (path-free) name must mention
    both 'mistral' and 'sk', case-insensitively."""
    lowered = name.lower()
    if os.path.sep in lowered:
        return False
    return "mistral" in lowered and "sk" in lowered
def list_candidate_run_dirs(outputs_root: str) -> List[str]:
    """
    Return every directory directly under *outputs_root* that both matches
    the mistral-sk naming heuristic and contains a responses.json file.

    Results are full paths, sorted lexicographically.
    """
    found: List[str] = []
    for entry in os.listdir(outputs_root):
        run_dir = os.path.join(outputs_root, entry)
        if not os.path.isdir(run_dir):
            continue
        if not is_mistral_sk_dir(entry):
            continue
        if os.path.isfile(os.path.join(run_dir, RESPONSES_FILENAME)):
            found.append(run_dir)
    return sorted(found)
def pick_latest_dir(dirs: List[str]) -> Optional[str]:
    """
    Return the newest directory from *dirs*, or None for an empty list.

    Ordering: directories whose basename starts with a
    'YYYY-MM-DD_HH-MM-SS' stamp outrank the rest and compare by that
    stamp; everything else falls back to filesystem mtime, and finally
    to 0.0 when the path cannot be stat'ed.
    """
    if not dirs:
        return None

    def sort_key(path: str) -> Tuple[int, float]:
        stamp = os.path.basename(path)[:19]
        try:
            parsed = datetime.datetime.strptime(stamp, "%Y-%m-%d_%H-%M-%S")
            return (1, parsed.timestamp())
        except Exception:
            pass
        try:
            return (0, os.path.getmtime(path))
        except Exception:
            return (0, 0.0)

    return max(dirs, key=sort_key)
def make_translated_dirname(src_dirname: str) -> str:
    """
    Map a Slovak run folder name to its English counterpart.

    Special rule: any occurrence of 'do_not_answer_sk' becomes
    'do_not_answer_eng'. Otherwise '_sk' at a regex word boundary is
    rewritten to '_eng'. NOTE(review): because '_' is itself a word
    character, '_sk' followed by another word char (e.g. '_sk_run') is
    NOT matched by the \\b anchor and stays unchanged.
    """
    marker = "do_not_answer_sk"
    if marker in src_dirname:
        return src_dirname.replace(marker, "do_not_answer_eng")
    return re.sub(r"_sk\b", "_eng", src_dirname)
def interactive_choice() -> str:
    """
    Prompt the operator for the translation scope and return '1' or '2'.

    '1' = latest detected run only, '2' = all detected runs. Any other
    answer (including empty input) defaults to '1'.
    """
    print("\n=== What to translate? ===")
    print("1) Latest mistral-sk run (auto-detected)")
    print("2) All mistral-sk runs (auto-detected)")
    answer = input("Choose 1 or 2 (default 1): ").strip()
    if answer in ("1", "2"):
        return answer
    return "1"
def _prompt_int(label: str, default: int) -> int:
    """Ask for a positive integer named *label*; return *default* when the
    answer is empty, non-numeric, or not strictly positive."""
    raw = input(f"{label} (default {default}): ").strip()
    return int(raw) if raw.isdigit() and int(raw) > 0 else default


def _translate_run(tokenizer, model, src_run_dir: str, batch_size: int,
                   chunk_chars: int, max_new_tokens: int, num_beams: int) -> None:
    """Translate one run directory's responses.json (prompt/response SK->EN)
    and write the EN-only copy plus a translate.log.txt beside it."""
    src_dirname = os.path.basename(src_run_dir)
    dst_run_dir = os.path.join(OUTPUTS_TRANSLATED_ROOT, make_translated_dirname(src_dirname))
    safe_mkdir(dst_run_dir)
    src_json = os.path.join(src_run_dir, RESPONSES_FILENAME)
    # Use the shared constant for the destination filename too (was a
    # hard-coded "responses.json" literal, inconsistent with the source side).
    dst_json_en_only = os.path.join(dst_run_dir, RESPONSES_FILENAME)
    dst_log = os.path.join(dst_run_dir, "translate.log.txt")
    print(f"\n[INFO] Translating:\n src: {src_run_dir}\n dst: {dst_run_dir}\n file: {src_json}")
    data = load_json(src_json)
    if not isinstance(data, list):
        raise SystemExit(f"[ERROR] {src_json} must be a list of objects.")
    # Record the parameters used for this run next to its outputs.
    with open(dst_log, "w", encoding="utf-8") as lf:
        lf.write(f"[INFO] NLLB_PATH={NLLB_PATH}\n")
        lf.write(f"[INFO] SRC_LANG={SRC_LANG} TGT_LANG={TGT_LANG}\n")
        lf.write(f"[INFO] batch_size={batch_size} chunk_chars={chunk_chars} max_new_tokens={max_new_tokens} num_beams={num_beams}\n")
        lf.write(f"[INFO] src_dir={src_run_dir}\n")
        lf.write(f"[INFO] dst_dir={dst_run_dir}\n")
    en_only: List[Dict[str, Any]] = []
    for item in tqdm(data, desc=f"Translating {src_dirname}", total=len(data)):
        if not isinstance(item, dict):
            # Preserve unexpected non-dict entries instead of dropping them.
            en_only.append({"_raw": item})
            continue
        prompt_en = translate_text(tokenizer, model, item.get("prompt", ""),
                                   chunk_chars, batch_size, max_new_tokens, num_beams)
        resp_en = translate_text(tokenizer, model, item.get("response", ""),
                                 chunk_chars, batch_size, max_new_tokens, num_beams)
        out_item = dict(item)
        out_item["prompt"] = prompt_en
        out_item["response"] = resp_en
        # Drop stale side-channel keys so the output is EN-only. A single
        # tuple endswith suffices: "_sk_original" already ends in "_original".
        for key in list(out_item.keys()):
            if key.endswith(("_original", "_eng", "_en")):
                out_item.pop(key, None)
        en_only.append(out_item)
    save_json(dst_json_en_only, en_only)
    print("[OK] Saved EN-only:")
    print(f" - {dst_json_en_only}")
    print(f" - {dst_log}")


def main():
    """Entry point: detect mistral-sk runs, ask which to translate and with
    what parameters, load the local NLLB model, then translate each run."""
    # Validate required directories before any interactive prompting.
    if not os.path.isdir(NLLB_PATH):
        raise SystemExit(f"[ERROR] NLLB model path not found: {NLLB_PATH}")
    if not os.path.isdir(OUTPUTS_ROOT):
        raise SystemExit(f"[ERROR] outputs root not found: {OUTPUTS_ROOT}")
    safe_mkdir(OUTPUTS_TRANSLATED_ROOT)
    run_dirs = list_candidate_run_dirs(OUTPUTS_ROOT)
    if not run_dirs:
        raise SystemExit(f"[ERROR] No mistral-sk run dirs with {RESPONSES_FILENAME} found in: {OUTPUTS_ROOT}")
    if interactive_choice() == "1":
        latest = pick_latest_dir(run_dirs)
        run_dirs = [latest] if latest else []
    # Robustness: never fall through to model loading with nothing to do.
    if not run_dirs:
        raise SystemExit("[ERROR] No runs selected for translation.")
    print(f"\n[INFO] Runs to translate: {len(run_dirs)}")
    for d in run_dirs:
        print(f" - {d}")
    # Interactive translation parameters (all positive ints with defaults).
    print("\n=== Translation parameters ===")
    batch_size = _prompt_int("batch_size", 8)
    chunk_chars = _prompt_int("chunk_chars", 1800)
    max_new_tokens = _prompt_int("max_new_tokens", 256)
    num_beams = _prompt_int("num_beams", 4)
    # Load NLLB: fp16 on GPU (with device_map sharding), fp32 on CPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if device == "cuda" else torch.float32
    print("\n[INFO] Loading NLLB...")
    tokenizer = AutoTokenizer.from_pretrained(NLLB_PATH, use_fast=True)
    model = AutoModelForSeq2SeqLM.from_pretrained(
        NLLB_PATH,
        torch_dtype=dtype,
        device_map="auto" if device == "cuda" else None,
        low_cpu_mem_usage=True,
    )
    if device == "cpu":
        model = model.to("cpu")
    model.eval()
    for src_run_dir in run_dirs:
        _translate_run(tokenizer, model, src_run_dir,
                       batch_size, chunk_chars, max_new_tokens, num_beams)
    print("\n✔ Done.")


if __name__ == "__main__":
    main()