DP/translate/Translate_sk_to_eng.py
2026-02-04 21:07:17 +01:00

351 lines
11 KiB
Python

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Interactive translator for Mistral-SK outputs -> English using LOCAL NLLB.
Behavior summary:
- Auto-detect Mistral-SK run dirs in:
/home/hyrenko/Diploma/outputs
(folders containing 'mistral' and 'sk' and a responses.json inside)
- Translate BOTH 'prompt' and 'response' from Slovak -> English using local NLLB:
/home/hyrenko/Diploma/models/nllb-200-1.3B
- Write results into:
/home/hyrenko/Diploma/outputs_translated
- Create the SAME run folder name, except:
do_not_answer_sk -> do_not_answer_eng
Example:
2026-01-25_06-33-31-mistral-sk-7b-do_not_answer_sk-prompt:939-4bit
-> 2026-01-25_06-33-31-mistral-sk-7b-do_not_answer_eng-prompt:939-4bit
- In the destination folder, the main file 'responses.json' contains ONLY ENGLISH:
{
"id": ...,
"category": ...,
"prompt": "<ENGLISH>",
"response": "<ENGLISH>",
"refusal": ...,
...
}
No *_original fields, no duplicates. Clean EN-only.
- Optional: also saves a translate.log.txt for reproducibility.
"""
import os
import re
import json
import datetime
from typing import List, Dict, Any, Optional, Tuple
import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# -------------------------
# Your fixed paths
# -------------------------
# Local NLLB-200 (1.3B) checkpoint used for Slovak -> English translation.
NLLB_PATH = "/home/hyrenko/Diploma/models/nllb-200-1.3B"
# Root directory scanned for Mistral-SK run folders.
OUTPUTS_ROOT = "/home/hyrenko/Diploma/outputs"
# Mirrored run folders with EN-only responses are written here.
OUTPUTS_TRANSLATED_ROOT = "/home/hyrenko/Diploma/outputs_translated"
# FLORES-200 language codes used by NLLB: Slovak (Latin) -> English (Latin).
SRC_LANG = "slk_Latn"
TGT_LANG = "eng_Latn"
# Filename expected inside each source run dir and written to each destination dir.
RESPONSES_FILENAME = "responses.json"
def safe_mkdir(path: str):
    """Create *path* (including parents) if it does not already exist."""
    os.makedirs(path, exist_ok=True)
def load_json(path: str) -> Any:
    """Parse and return the JSON document stored at *path* (UTF-8)."""
    with open(path, encoding="utf-8") as fh:
        return json.load(fh)
def save_json(path: str, obj: Any):
    """Write *obj* to *path* as pretty-printed UTF-8 JSON (non-ASCII kept verbatim)."""
    with open(path, "w", encoding="utf-8") as fh:
        json.dump(obj, fh, ensure_ascii=False, indent=2)
def normalize_text(s: Any) -> str:
    """Coerce *s* to str, drop Unicode replacement chars (U+FFFD), trim whitespace.

    None maps to the empty string.
    """
    if s is None:
        return ""
    cleaned = str(s).replace("\uFFFD", "")
    return cleaned.strip()
def split_text_safely(text: str, max_chars: int) -> List[str]:
    """Split *text* into chunks of at most *max_chars* characters.

    Fallback order: blank-line paragraphs -> single lines -> sentences ->
    hard character slices.  Character-based splitting is stable for NLLB.
    Always returns at least one element ([""] for empty input).
    """
    # Same normalization as normalize_text(): stringify, drop U+FFFD, trim.
    text = "" if text is None else str(text).replace("\uFFFD", "").strip()
    if not text:
        return [""]

    pieces: List[str] = []

    def emit(fragment: str) -> None:
        # Record *fragment*, hard-slicing anything longer than max_chars.
        fragment = fragment.strip()
        if not fragment:
            return
        if len(fragment) <= max_chars:
            pieces.append(fragment)
            return
        for start in range(0, len(fragment), max_chars):
            part = fragment[start:start + max_chars].strip()
            if part:
                pieces.append(part)

    for para in re.split(r"\n{2,}", text):
        para = para.strip()
        if not para:
            continue
        if len(para) <= max_chars:
            pieces.append(para)
            continue
        # Oversized paragraph: greedily pack its non-empty lines.
        acc = ""
        for line in (raw.strip() for raw in para.split("\n") if raw.strip()):
            if not acc:
                acc = line
            elif len(acc) + 1 + len(line) <= max_chars:
                acc = acc + "\n" + line
            else:
                emit(acc)
                acc = line
        if not acc:
            continue
        if len(acc) <= max_chars:
            pieces.append(acc)
            continue
        # Remaining buffer is a single over-long line: pack sentence by sentence.
        packed = ""
        for sent in re.split(r"(?<=[\.\!\?])\s+", acc):
            sent = sent.strip()
            if not sent:
                continue
            if not packed:
                packed = sent
            elif len(packed) + 1 + len(sent) <= max_chars:
                packed = packed + " " + sent
            else:
                emit(packed)
                packed = sent
        if packed:
            emit(packed)

    return pieces if pieces else [""]
def batchify(items: List[str], bs: int) -> List[List[str]]:
    """Partition *items* into consecutive groups of at most *bs* elements."""
    groups: List[List[str]] = []
    for start in range(0, len(items), bs):
        groups.append(items[start:start + bs])
    return groups
@torch.inference_mode()
def nllb_translate_batch(tokenizer, model, texts: List[str], max_new_tokens: int, num_beams: int) -> List[str]:
    """Translate one batch of Slovak strings to English with the NLLB model.

    Sets the tokenizer's source language, forces the English BOS token on
    generation (deterministic beam search), and returns the decoded strings
    with special tokens stripped and whitespace trimmed.
    """
    tokenizer.src_lang = SRC_LANG
    target_bos = tokenizer.convert_tokens_to_ids(TGT_LANG)
    encoded = tokenizer(
        texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=1024,
    )
    # Move every input tensor onto the model's device.
    encoded = {name: tensor.to(model.device) for name, tensor in encoded.items()}
    output_ids = model.generate(
        **encoded,
        forced_bos_token_id=target_bos,
        max_new_tokens=max_new_tokens,
        num_beams=num_beams,
        do_sample=False,  # deterministic output
    )
    decoded = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    return [line.strip() for line in decoded]
def translate_text(tokenizer, model, text: str, chunk_chars: int, batch_size: int, max_new_tokens: int, num_beams: int) -> str:
    """Translate one (possibly long) text: chunk it, run batches through NLLB,
    and rejoin the translated chunks with blank lines.  Empty input -> ""."""
    text = normalize_text(text)
    if not text:
        return ""
    pieces = split_text_safely(text, max_chars=chunk_chars)
    results: List[str] = []
    for group in batchify(pieces, batch_size):
        results.extend(nllb_translate_batch(tokenizer, model, group, max_new_tokens, num_beams))
    return "\n\n".join(part for part in results if part)
def is_mistral_sk_dir(name: str) -> bool:
    """True when *name* looks like a Mistral-SK run folder name.

    Case-insensitive substring match on "mistral" and "sk"; anything
    containing a path separator is rejected (must be a bare dir name).
    """
    lowered = name.lower()
    if os.path.sep in lowered:
        return False
    return "mistral" in lowered and "sk" in lowered
def list_candidate_run_dirs(outputs_root: str) -> List[str]:
    """Return sorted full paths of mistral-sk run dirs that contain a responses file."""
    found: List[str] = []
    for entry in os.listdir(outputs_root):
        full_path = os.path.join(outputs_root, entry)
        if not os.path.isdir(full_path) or not is_mistral_sk_dir(entry):
            continue
        # Only dirs that actually hold a responses.json are candidates.
        if os.path.isfile(os.path.join(full_path, RESPONSES_FILENAME)):
            found.append(full_path)
    return sorted(found)
def pick_latest_dir(dirs: List[str]) -> Optional[str]:
    """Return the most recent run dir, or None for an empty list.

    Dirs whose basename starts with a "%Y-%m-%d_%H-%M-%S" timestamp always
    outrank dirs without one; within each tier, newer wins (mtime is the
    fallback ordering, 0.0 if even that is unavailable).
    """
    if not dirs:
        return None

    def sort_key(path: str) -> Tuple[int, float]:
        stem = os.path.basename(path)[:19]
        try:
            parsed = datetime.datetime.strptime(stem, "%Y-%m-%d_%H-%M-%S")
        except Exception:
            pass
        else:
            return (1, parsed.timestamp())
        try:
            return (0, os.path.getmtime(path))
        except Exception:
            return (0, 0.0)

    return max(dirs, key=sort_key)
def make_translated_dirname(src_dirname: str) -> str:
    """Map a source run dir name to its English-output counterpart.

    Primary rule: do_not_answer_sk -> do_not_answer_eng.  Fallback: rewrite
    any word-final "_sk" token to "_eng".  Names with no match pass through.
    """
    marker = "do_not_answer_sk"
    if marker in src_dirname:
        return src_dirname.replace(marker, "do_not_answer_eng")
    return re.sub(r"_sk\b", "_eng", src_dirname)
def interactive_choice() -> str:
    """Ask whether to translate the latest run or all runs; returns "1" or "2".

    Any input other than "1" or "2" falls back to "1" (latest run).
    """
    print("\n=== What to translate? ===")
    print("1) Latest mistral-sk run (auto-detected)")
    print("2) All mistral-sk runs (auto-detected)")
    answer = input("Choose 1 or 2 (default 1): ").strip()
    if answer in ("1", "2"):
        return answer
    return "1"
def _prompt_positive_int(prompt: str, default: int) -> int:
    """Read a positive integer from stdin; return *default* for anything else."""
    raw = input(prompt).strip()
    return int(raw) if raw.isdigit() and int(raw) > 0 else default


def _translate_run(tokenizer, model, src_run_dir: str, batch_size: int,
                   chunk_chars: int, max_new_tokens: int, num_beams: int) -> None:
    """Translate one run directory into its mirrored EN-only directory.

    Writes RESPONSES_FILENAME (EN-only items, same keys as the source) plus a
    translate.log.txt recording the parameters, for reproducibility.

    Raises SystemExit if the source responses file is not a JSON list.
    """
    src_dirname = os.path.basename(src_run_dir)
    dst_run_dir = os.path.join(OUTPUTS_TRANSLATED_ROOT, make_translated_dirname(src_dirname))
    safe_mkdir(dst_run_dir)
    src_json = os.path.join(src_run_dir, RESPONSES_FILENAME)
    # Use the shared constant (was a hard-coded "responses.json" literal).
    dst_json_en_only = os.path.join(dst_run_dir, RESPONSES_FILENAME)
    dst_log = os.path.join(dst_run_dir, "translate.log.txt")
    print(f"\n[INFO] Translating:\n src: {src_run_dir}\n dst: {dst_run_dir}\n file: {src_json}")

    data = load_json(src_json)
    if not isinstance(data, list):
        raise SystemExit(f"[ERROR] {src_json} must be a list of objects.")

    # Record run parameters so the translation is reproducible.
    with open(dst_log, "w", encoding="utf-8") as lf:
        lf.write(f"[INFO] NLLB_PATH={NLLB_PATH}\n")
        lf.write(f"[INFO] SRC_LANG={SRC_LANG} TGT_LANG={TGT_LANG}\n")
        lf.write(f"[INFO] batch_size={batch_size} chunk_chars={chunk_chars} max_new_tokens={max_new_tokens} num_beams={num_beams}\n")
        lf.write(f"[INFO] src_dir={src_run_dir}\n")
        lf.write(f"[INFO] dst_dir={dst_run_dir}\n")

    en_only: List[Dict[str, Any]] = []
    for item in tqdm(data, desc=f"Translating {src_dirname}", total=len(data)):
        if not isinstance(item, dict):
            # Keep non-dict items (rare) wrapped so nothing is silently dropped.
            en_only.append({"_raw": item})
            continue
        prompt_en = translate_text(tokenizer, model, item.get("prompt", ""),
                                   chunk_chars, batch_size, max_new_tokens, num_beams)
        resp_en = translate_text(tokenizer, model, item.get("response", ""),
                                 chunk_chars, batch_size, max_new_tokens, num_beams)
        # EN-only output: copy everything, then replace prompt/response.
        out_item = dict(item)
        out_item["prompt"] = prompt_en
        out_item["response"] = resp_en
        # Drop leftover artifacts from any earlier translation pass.
        # ("_sk_original" was checked separately before, but "_original" subsumes it.)
        for k in list(out_item.keys()):
            if k.endswith(("_original", "_eng", "_en")):
                out_item.pop(k, None)
        en_only.append(out_item)

    save_json(dst_json_en_only, en_only)
    print("[OK] Saved EN-only:")
    print(f" - {dst_json_en_only}")
    print(f" - {dst_log}")


def main():
    """Interactively translate detected mistral-sk runs from Slovak to English."""
    if not os.path.isdir(NLLB_PATH):
        raise SystemExit(f"[ERROR] NLLB model path not found: {NLLB_PATH}")
    if not os.path.isdir(OUTPUTS_ROOT):
        raise SystemExit(f"[ERROR] outputs root not found: {OUTPUTS_ROOT}")
    safe_mkdir(OUTPUTS_TRANSLATED_ROOT)

    run_dirs = list_candidate_run_dirs(OUTPUTS_ROOT)
    if not run_dirs:
        raise SystemExit(f"[ERROR] No mistral-sk run dirs with {RESPONSES_FILENAME} found in: {OUTPUTS_ROOT}")
    if interactive_choice() == "1":
        latest = pick_latest_dir(run_dirs)
        run_dirs = [latest] if latest else []
    print(f"\n[INFO] Runs to translate: {len(run_dirs)}")
    for d in run_dirs:
        print(f" - {d}")

    # Interactive translation params (simple); non-numeric input keeps defaults.
    print("\n=== Translation parameters ===")
    batch_size = _prompt_positive_int("batch_size (default 8): ", 8)
    chunk_chars = _prompt_positive_int("chunk_chars (default 1800): ", 1800)
    max_new_tokens = _prompt_positive_int("max_new_tokens (default 256): ", 256)
    num_beams = _prompt_positive_int("num_beams (default 4): ", 4)

    # Load the local NLLB model: fp16 on GPU (device_map="auto"), fp32 on CPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if device == "cuda" else torch.float32
    print("\n[INFO] Loading NLLB...")
    tokenizer = AutoTokenizer.from_pretrained(NLLB_PATH, use_fast=True)
    model = AutoModelForSeq2SeqLM.from_pretrained(
        NLLB_PATH,
        torch_dtype=dtype,
        device_map="auto" if device == "cuda" else None,
        low_cpu_mem_usage=True,
    )
    if device == "cpu":
        model = model.to("cpu")
    model.eval()

    for src_run_dir in run_dirs:
        _translate_run(tokenizer, model, src_run_dir, batch_size,
                       chunk_chars, max_new_tokens, num_beams)
    print("\n✔ Done.")


if __name__ == "__main__":
    main()