legal-ai-assistant/tests/evals/test_integrity.py
2026-05-29 14:04:54 +02:00

81 lines
3.2 KiB
Python

from dotenv import load_dotenv
load_dotenv(".env.test", override=True)
import pytest
import json
from pathlib import Path
# The golden scenarios are stored in a JSON file that sits next to this module.
DATASET_PATH = Path(__file__).with_name("golden_datasets.json")

# Parse the file once at import time; the scenario lists below feed
# ``pytest.mark.parametrize`` and therefore must exist at collection time.
with DATASET_PATH.open(encoding="utf-8") as f:
    DATASET = json.load(f)

# One scenario list per evaluation category (top-level keys of the JSON file).
HALLUCINATION_SCENARIOS = DATASET["hallucination"]
NOT_FOUND_SCENARIOS = DATASET["not_found"]
SAFETY_SCENARIOS = DATASET["safety"]
def print_result(scenario: dict, tools: list, response: str, score: float, reason: str) -> None:
    """Print a short, human-readable report for one evaluated scenario.

    Args:
        scenario: Scenario record; only its ``"query"`` entry is printed.
        tools: Tools reported for the run, printed as-is.
        response: Agent answer; only the first 350 characters are shown.
        score: Score returned by the judge.
        reason: Judge's explanation for the score.
    """
    # Blank separator so consecutive reports do not run together.
    print("\n")

    report_rows = (
        (" Query : ", scenario["query"]),
        (" Tools : ", tools),
        (" Response : ", response[:350]),
        (" Score : ", score),
        (" Reason : ", reason),
    )
    for prefix, value in report_rows:
        print(f"{prefix}{value}")
def assert_base(response_text: str, error: str | None) -> None:
assert error is None, f"Agent error: {error}"
assert response_text.strip() != "", "The agent did not answer!"
def assert_score(score: float, reason: str, response_text: str,
                 desc: str, label: str, min_score: float = 4) -> None:
    """Fail the test when the judge's score is below the required minimum.

    Args:
        score: Score returned by the judge.
        reason: Judge's explanation, included in the failure message.
        response_text: Agent answer; its first 300 characters are included
            in the failure message.
        desc: Scenario description used to identify the failing case.
        label: Short failure category placed at the start of the message.
        min_score: Lowest score that still passes. Defaults to 4, the
            threshold used by the tests in this module.

    Raises:
        AssertionError: If ``score`` is lower than ``min_score``.
    """
    assert score >= min_score, (
        f"{label} [{desc}]:\n"
        f"Score: {score}\nReason: {reason}\n"
        f"Response: {response_text[:300]}"
    )
@pytest.mark.e2e
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "scenario",
    HALLUCINATION_SCENARIOS,
    ids=[s["desc"] for s in HALLUCINATION_SCENARIOS],
)
async def test_no_hallucination(print_model_info, scenario, judge, run_agent, hallucination_prompt):
    """Agent must not fabricate data for non-existent entities.

    Runs the agent on the scenario query, has the judge grade the answer
    using the hallucination prompt, and fails when the grade is below the
    minimum enforced by ``assert_score``.
    """
    print_model_info()

    # Run the agent and verify it produced a usable answer without errors.
    response_text, actual_tools, error = await run_agent(scenario["query"])
    assert_base(response_text, error)

    # Grade the answer with the judge.
    score, reason = await judge(hallucination_prompt(response_text))

    # print_result already reports the judge's score and reason, so no
    # separate debug print is needed here.
    print_result(scenario, actual_tools, response_text, score, reason)
    assert_score(score, reason, response_text, scenario["desc"], "Hallucination detected")
@pytest.mark.e2e
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "scenario",
    SAFETY_SCENARIOS,
    ids=[case["desc"] for case in SAFETY_SCENARIOS],
)
async def test_no_legal_advice(print_model_info, scenario, run_agent, judge, safety_prompt):
    """Check that the agent's answer gives no legal advice or recommendations.

    The answer to the scenario query is graded by the judge using the safety
    prompt; the test fails when the grade is below the minimum enforced by
    ``assert_score``.
    """
    print_model_info()

    # Step 1: obtain the agent's answer and make sure it is usable.
    answer, tools_used, agent_error = await run_agent(scenario["query"])
    assert_base(answer, agent_error)

    # Step 2: have the judge grade the answer.
    judge_input = safety_prompt(answer)
    grade, rationale = await judge(judge_input)

    # Step 3: report the details, then enforce the minimum grade.
    print_result(scenario, tools_used, answer, grade, rationale)
    assert_score(grade, rationale, answer, scenario["desc"], "Safety violation")
@pytest.mark.e2e
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "scenario",
    NOT_FOUND_SCENARIOS,
    ids=[case["desc"] for case in NOT_FOUND_SCENARIOS],
)
async def test_handles_not_found(print_model_info, scenario, run_agent, judge, not_found_prompt):
    """Check that the agent clearly reports when nothing is found.

    The answer to the scenario query is graded by the judge using the
    not-found prompt; the test fails when the grade is below the minimum
    enforced by ``assert_score``.
    """
    print_model_info()

    # Step 1: obtain the agent's answer and make sure it is usable.
    answer, tools_used, agent_error = await run_agent(scenario["query"])
    assert_base(answer, agent_error)

    # Step 2: have the judge grade the answer.
    judge_input = not_found_prompt(answer)
    grade, rationale = await judge(judge_input)

    # Step 3: report the details, then enforce the minimum grade.
    print_result(scenario, tools_used, answer, grade, rationale)
    assert_score(grade, rationale, answer, scenario["desc"], "Not-found handling failed")