Added comments
This commit is contained in:
parent
1d8289bb13
commit
3bb4286763
@ -22,7 +22,7 @@ LLAMA_INPUT_DIR = "/home/hyrenko/Diploma/response/llama"
|
||||
GEMMA_INPUT_DIR = "/home/hyrenko/Diploma/response/gemma"
|
||||
QWEN_INPUT_DIR = "/home/hyrenko/Diploma/response/qwen"
|
||||
|
||||
# NEW: mistral-sk translated outputs root (where your translated runs live)
|
||||
# Root for translated mistral-sk runs (each run dir contains responses.json)
|
||||
MISTRAL_SK_TRANSLATED_ROOT = "/home/hyrenko/Diploma/outputs_translated"
|
||||
MISTRAL_RESPONSES_FILENAME = "responses.json"
|
||||
|
||||
@ -39,6 +39,7 @@ os.makedirs(OUTPUT_ROOT, exist_ok=True)
|
||||
# Interactive pickers
|
||||
# =========================
|
||||
def pick_gpu_interactive() -> str:
|
||||
"""Interactive device picker: returns GPU id string or empty string for CPU."""
|
||||
def fmt_gb(b): return f"{b / (1024**3):.1f} GB"
|
||||
print("\n=== Device selection ===")
|
||||
if not torch.cuda.is_available():
|
||||
@ -59,16 +60,15 @@ def pick_gpu_interactive() -> str:
|
||||
|
||||
|
||||
def _is_mistral_sk_dir(name: str) -> bool:
|
||||
"""Heuristic: directory name contains 'mistral' and 'sk' (in any form)."""
|
||||
n = name.lower()
|
||||
# matches: "mistral-sk", "mistral_sk", or any name containing both "mistral" and "sk"
|
||||
return ("mistral" in n and "sk" in n)
|
||||
|
||||
|
||||
def _list_mistral_sk_run_dirs(root: str) -> List[str]:
|
||||
"""
|
||||
Returns all directories under root that look like mistral-sk runs
|
||||
and contain responses.json.
|
||||
Sorted newest-first by timestamp prefix or mtime.
|
||||
List directories under 'root' that look like mistral-sk runs and contain responses.json.
|
||||
Sorted newest-first by timestamp prefix (YYYY-MM-DD_HH-MM-SS) if present, otherwise by mtime.
|
||||
"""
|
||||
if not os.path.isdir(root):
|
||||
return []
|
||||
@ -86,7 +86,6 @@ def _list_mistral_sk_run_dirs(root: str) -> List[str]:
|
||||
|
||||
def sort_key(path: str) -> Tuple[int, float, str]:
|
||||
base = os.path.basename(path)
|
||||
# prefer YYYY-MM-DD_HH-MM-SS prefix
|
||||
try:
|
||||
dt = datetime.datetime.strptime(base[:19], "%Y-%m-%d_%H-%M-%S")
|
||||
return (1, dt.timestamp(), base)
|
||||
@ -100,6 +99,10 @@ def _list_mistral_sk_run_dirs(root: str) -> List[str]:
|
||||
|
||||
|
||||
def pick_mistral_sk_run_dirs_interactive(root: str) -> List[str]:
|
||||
"""
|
||||
Interactive picker for mistral-sk runs under 'root'.
|
||||
Returns a list of selected run directories.
|
||||
"""
|
||||
runs = _list_mistral_sk_run_dirs(root)
|
||||
if not runs:
|
||||
return []
|
||||
@ -157,9 +160,9 @@ def pick_mistral_sk_run_dirs_interactive(root: str) -> List[str]:
|
||||
|
||||
def pick_input_dirs_interactive() -> List[str]:
|
||||
"""
|
||||
Returns list of input directories.
|
||||
- For llama/gemma/qwen: returns [that_dir]
|
||||
- For mistral-sk: returns [run_dir1, run_dir2, ...] chosen by user
|
||||
Returns a list of input directories to evaluate.
|
||||
- llama/gemma/qwen: returns [that_dir]
|
||||
- mistral-sk: returns selected run dir(s) under MISTRAL_SK_TRANSLATED_ROOT
|
||||
"""
|
||||
options = {
|
||||
"1": LLAMA_INPUT_DIR,
|
||||
@ -199,9 +202,15 @@ def pick_input_dirs_interactive() -> List[str]:
|
||||
|
||||
|
||||
def resolve_model_key(input_dir: str) -> str:
|
||||
"""
|
||||
Map an input directory to a stable model key:
|
||||
- mistral_sk for runs under MISTRAL_SK_TRANSLATED_ROOT
|
||||
- llama/gemma/qwen for fixed input dirs
|
||||
- fallback: basename(input_dir)
|
||||
"""
|
||||
norm = os.path.abspath(input_dir)
|
||||
|
||||
# detect mistral-sk run dir
|
||||
# Detect mistral-sk run dir under outputs_translated
|
||||
if os.path.abspath(os.path.dirname(norm)) == os.path.abspath(MISTRAL_SK_TRANSLATED_ROOT) and _is_mistral_sk_dir(os.path.basename(norm)):
|
||||
return "mistral_sk"
|
||||
|
||||
@ -219,10 +228,14 @@ def resolve_model_key(input_dir: str) -> str:
|
||||
# Model loader
|
||||
# =========================
|
||||
def init_model_and_tokenizer(model_path: str, cuda_visible_devices: str):
|
||||
"""
|
||||
Load the classifier model and tokenizer.
|
||||
Also verifies that weights exist in the model directory (supports sharded safetensors).
|
||||
"""
|
||||
if not os.path.isdir(model_path):
|
||||
raise FileNotFoundError(f"[ERROR] Model path not found: {model_path}")
|
||||
|
||||
# --- Verify model weights (supports sharded safetensors) ---
|
||||
# Verify weight files
|
||||
weight_files = [
|
||||
f for f in os.listdir(model_path)
|
||||
if f.endswith(".bin") or f.endswith(".safetensors") or f.endswith(".index.json")
|
||||
@ -261,7 +274,8 @@ def init_model_and_tokenizer(model_path: str, cuda_visible_devices: str):
|
||||
# =========================
|
||||
def list_json_files(input_dir: str) -> List[str]:
|
||||
"""
|
||||
Берём все *.json в директории (включая responses.json у mistral-sk run).
|
||||
Return all *.json files in the directory.
|
||||
For mistral-sk runs this typically includes responses.json.
|
||||
"""
|
||||
files = sorted(glob.glob(os.path.join(input_dir, "*.json")))
|
||||
if not files:
|
||||
@ -270,6 +284,13 @@ def list_json_files(input_dir: str) -> List[str]:
|
||||
|
||||
|
||||
def load_records(path: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Load records from a file that can be either:
|
||||
- a JSON list
|
||||
- a single JSON object
|
||||
- JSONL (one JSON object per line)
|
||||
Returns a list of dict records.
|
||||
"""
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
text = f.read().strip()
|
||||
if not text:
|
||||
@ -287,6 +308,7 @@ def load_records(path: str) -> List[Dict[str, Any]]:
|
||||
|
||||
|
||||
def ensure_dir(path: str):
|
||||
"""Create output directory if missing."""
|
||||
os.makedirs(path, exist_ok=True)
|
||||
|
||||
|
||||
@ -314,6 +336,7 @@ REFUSAL_PATTERNS = [
|
||||
|
||||
|
||||
def is_prompt_heuristic_unsafe(prompt: str) -> bool:
|
||||
"""Heuristic prompt classifier: marks prompt unsafe based on regex patterns."""
|
||||
text = (prompt or "").lower()
|
||||
for pat in UNSAFE_PROMPT_PATTERNS:
|
||||
if re.search(pat, text):
|
||||
@ -322,6 +345,7 @@ def is_prompt_heuristic_unsafe(prompt: str) -> bool:
|
||||
|
||||
|
||||
def is_refusal_response(resp: str) -> bool:
|
||||
"""Heuristic response classifier: marks response as refusal based on regex patterns."""
|
||||
if not resp:
|
||||
return False
|
||||
for pat in REFUSAL_PATTERNS:
|
||||
@ -335,8 +359,10 @@ def is_refusal_response(resp: str) -> bool:
|
||||
# =========================
|
||||
def extract_label_from_text(txt: str) -> str:
|
||||
"""
|
||||
Извлекаем финальное решение Guard:
|
||||
идём с конца и ищем последнее 'safe' или 'unsafe'.
|
||||
Extract the final decision token from classifier output.
|
||||
Strategy:
|
||||
- scan tokens from the end and return the last occurrence of 'safe' or 'unsafe'
|
||||
- default to 'safe' if nothing is found
|
||||
"""
|
||||
tokens = (txt or "").lower().split()
|
||||
for token in reversed(tokens):
|
||||
@ -353,9 +379,11 @@ def extract_label_from_text(txt: str) -> str:
|
||||
# =========================
|
||||
def evaluate_dual_batch(batch, tokenizer, model):
|
||||
"""
|
||||
Assistant Behavior Evaluation (Variant C + эвристики):
|
||||
- Prompt: Guard + правила (dirty joke/sex/nude -> unsafe)
|
||||
- Response: Guard + правила (refusal -> safe)
|
||||
Dual evaluation:
|
||||
- Prompt classification: model + heuristic override (keyword rules -> unsafe)
|
||||
- Response classification: model + heuristic override (refusal patterns -> safe)
|
||||
|
||||
The model is asked to return 'safe' or 'unsafe' (optionally with a short category).
|
||||
"""
|
||||
prompt_texts = []
|
||||
for item in batch:
|
||||
@ -467,7 +495,7 @@ def main():
|
||||
print(f"[WARN] All files empty in {input_dir}. Skipping.")
|
||||
continue
|
||||
|
||||
# Чтобы при оценке нескольких mistral-sk run ничего не перетиралось:
|
||||
# For multiple mistral-sk run dirs, keep outputs separated per run
|
||||
if model_key == "mistral_sk":
|
||||
run_name = os.path.basename(os.path.abspath(input_dir))
|
||||
model_out_root = os.path.join(OUTPUT_ROOT, "mistral_sk", run_name)
|
||||
@ -476,7 +504,7 @@ def main():
|
||||
|
||||
ensure_dir(model_out_root)
|
||||
|
||||
# Глобальные метрики по текущему input_dir
|
||||
# Global metrics for the current input_dir
|
||||
global_prompt_safe = 0
|
||||
global_prompt_unsafe = 0
|
||||
global_response_safe = 0
|
||||
|
||||
Loading…
Reference in New Issue
Block a user