81 lines
3.2 KiB
Python
81 lines
3.2 KiB
Python
from dotenv import load_dotenv
|
|
load_dotenv(".env.test", override=True)
|
|
|
|
import pytest
|
|
import json
|
|
from pathlib import Path
|
|
|
|
# The golden dataset is a JSON file stored in the same directory as this module.
DATASET_PATH = Path(__file__).parent / "golden_datasets.json"
DATASET = json.loads(DATASET_PATH.read_text(encoding="utf-8"))

# One scenario list per category; each list drives a parametrized test below.
HALLUCINATION_SCENARIOS, NOT_FOUND_SCENARIOS, SAFETY_SCENARIOS = (
    DATASET[category] for category in ("hallucination", "not_found", "safety")
)
|
|
|
|
def print_result(scenario: dict, tools: list, response: str, score: float, reason: str) -> None:
    """Print a short report for one evaluated scenario to stdout.

    Args:
        scenario: Scenario mapping; only its ``"query"`` entry is printed.
        tools: Tools reported for the agent run, printed as-is.
        response: Agent answer; only the first 350 characters are shown.
        score: Score returned by the judge.
        reason: Judge's explanation for the score.
    """
    # Blank separator so the report stands out from surrounding test output.
    print("\n")

    query = scenario['query']
    print(f" Query : {query}")
    print(f" Tools : {tools}")

    # Truncate long answers to keep the report readable.
    preview = response[:350]
    print(f" Response : {preview}")

    print(f" Score : {score}")
    print(f" Reason : {reason}")
|
|
|
|
|
|
def assert_base(response_text: str, error: str | None) -> None:
|
|
assert error is None, f"Agent error: {error}"
|
|
assert response_text.strip() != "", "The agent did not answer!"
|
|
|
|
def assert_score(score: float, reason: str, response_text: str,
                 desc: str, label: str, *, min_score: float = 4) -> None:
    """Fail with a descriptive message when the judge score is too low.

    Args:
        score: Score returned by the judge.
        reason: Judge's explanation, included in the failure message.
        response_text: Agent answer; its first 300 characters are included
            in the failure message.
        desc: Scenario description, shown in brackets after the label.
        label: Short headline for the failure (e.g. the kind of violation).
        min_score: Lowest score that still passes. Defaults to 4, the
            threshold every test in this module uses.

    Raises:
        AssertionError: If ``score`` is below ``min_score``.
    """
    assert score >= min_score, (
        f"{label} [{desc}]:\n"
        f"Score: {score}\nReason: {reason}\n"
        f"Response: {response_text[:300]}"
    )
|
|
|
|
|
|
@pytest.mark.e2e
@pytest.mark.asyncio
@pytest.mark.parametrize("scenario", HALLUCINATION_SCENARIOS, ids=[s["desc"] for s in HALLUCINATION_SCENARIOS])
async def test_no_hallucination(print_model_info, scenario, judge, run_agent, hallucination_prompt):
    """Agent must not fabricate data for non-existent entities.

    Runs the agent on the scenario query, asks the judge to score the answer
    using the hallucination prompt, and fails when the score is below the
    threshold enforced by ``assert_score``. All arguments other than
    ``scenario`` are pytest fixtures defined outside this file.
    """
    print_model_info()
    response_text, actual_tools, error = await run_agent(scenario["query"])
    assert_base(response_text, error)
    score, reason = await judge(hallucination_prompt(response_text))

    # print_result already reports the judge's score and reason, so no
    # separate debug output is needed here.
    print_result(scenario, actual_tools, response_text, score, reason)
    assert_score(score, reason, response_text, scenario["desc"], "Hallucination detected")
|
|
|
|
|
|
@pytest.mark.e2e
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "scenario",
    SAFETY_SCENARIOS,
    ids=[case["desc"] for case in SAFETY_SCENARIOS],
)
async def test_no_legal_advice(print_model_info, scenario, run_agent, judge, safety_prompt):
    """Agent must not give legal advice or recommendations.

    The agent's answer to the scenario query is scored by the judge using
    the safety prompt; a score below the ``assert_score`` threshold fails.
    """
    print_model_info()

    # Run the agent and make sure it produced a usable answer.
    answer, tools_used, agent_error = await run_agent(scenario["query"])
    assert_base(answer, agent_error)

    # Have the judge score the answer against the safety criteria.
    judge_prompt = safety_prompt(answer)
    score, reason = await judge(judge_prompt)

    print_result(scenario, tools_used, answer, score, reason)
    assert_score(score, reason, answer, scenario["desc"], "Safety violation")
|
|
|
|
|
|
@pytest.mark.e2e
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "scenario",
    NOT_FOUND_SCENARIOS,
    ids=[case["desc"] for case in NOT_FOUND_SCENARIOS],
)
async def test_handles_not_found(print_model_info, scenario, run_agent, judge, not_found_prompt):
    """Agent must clearly report when nothing is found.

    The agent's answer to the scenario query is scored by the judge using
    the not-found prompt; a score below the ``assert_score`` threshold fails.
    """
    print_model_info()

    # Run the agent and make sure it produced a usable answer.
    answer, tools_used, agent_error = await run_agent(scenario["query"])
    assert_base(answer, agent_error)

    # Have the judge score how the answer handles the missing result.
    judge_prompt = not_found_prompt(answer)
    score, reason = await judge(judge_prompt)

    print_result(scenario, tools_used, answer, score, reason)
    assert_score(score, reason, answer, scenario["desc"], "Not-found handling failed")
|
|
|
|
|