legal-ai-assistant/tests/evals/test_scenarios.py

116 lines
4.0 KiB
Python

from dotenv import load_dotenv
load_dotenv(".env.test", override=True)
import os
import time
import json
import pytest
from pathlib import Path
from backend.agent.agent import build_agent, make_mcp_server
from backend.agent.response import stream_response
# Model name handed to build_agent/calculate_cost; None when the variable is
# missing from the environment.
DEFAULT_MODEL = os.environ.get("DEFAULT_MODEL")

# The scenario definitions are stored next to this test module.
REQUESTS_PATH = Path(__file__).parent / "requests.json"

# Parsed once at import time because the parametrize decorator needs the list
# while the module is being collected.
with REQUESTS_PATH.open(encoding="utf-8") as f:
    SCENARIOS = json.loads(f.read())
class ScenarioStats:
    """Mutable accumulator for everything observed while one scenario runs."""

    def __init__(self):
        # Names of the tools the agent started, in the order they were seen.
        self.actual_tools: list = []
        # Concatenation of every streamed text chunk.
        self.response_text: str = ""
        # Token counters, summed over all usage events.
        self.input_tokens: int = 0
        self.output_tokens: int = 0
        # Stays None unless an error was reported or an exception was caught.
        self.error = None
class TotalStats:
    """Run-wide totals, kept as class attributes and never instantiated here.

    The counters are updated directly on the class, so they accumulate across
    every scenario executed within the same Python process.
    """

    # Sum of the per-scenario elapsed times, in seconds.
    total_time: float = 0.0
    # Token counters summed over all scenarios.
    total_input_tokens: int = 0
    total_output_tokens: int = 0
    # Sum of the per-scenario costs.
    total_cost: float = 0.0
    # Number of scenarios that contributed to the totals above.
    scenarios_count: int = 0
def print_report(scenario: dict, stats: ScenarioStats, cost: float,
                 elapsed: float, f1: float, precision: float, recall: float) -> None:
    """Print the report for one scenario followed by the running totals.

    Args:
        scenario: Scenario definition; the keys ``resource``, ``level`` and
            ``query`` are required, ``expected_tools`` is optional.
        stats: Data collected while the scenario ran.
        cost: Cost attributed to this scenario.
        elapsed: Duration of this scenario in seconds.
        f1: F1 score of the tool selection.
        precision: Precision of the tool selection.
        recall: Recall of the tool selection.
    """
    sum_token = stats.input_tokens + stats.output_tokens

    # Avoid ZeroDivisionError if a report is printed before any scenario has
    # been added to the totals.
    count = TotalStats.scenarios_count
    avg_time = TotalStats.total_time / count if count else 0.0

    # The original rule was an empty string repeated 60 times, i.e. nothing;
    # use a visible ASCII rule so reports are delimited in the output.
    separator = "=" * 60

    print(f"\n{separator}")
    print(f"\tResource : {scenario['resource']} / {scenario['level']}")
    print(f"\tQuery : {scenario['query'][:100]}...")
    print(f"\tExpected : {scenario.get('expected_tools', [])}")
    print(f"\tActual : {stats.actual_tools}")
    print(f"\tF1 : {f1:.2f} precision={precision:.2f} recall={recall:.2f}")
    print(f"\tTokens : input={stats.input_tokens} output={stats.output_tokens} sum={sum_token}")
    print(f"\tCost : ${cost:.6f}")
    print(f"\tElapsed : {elapsed}s")
    # chr(10) is "\n": keep the response preview on a single line.
    print(f"\tResponse : {stats.response_text[:100].replace(chr(10), ' ')}...")
    print(f"\n\t[TOTAL PROGRESS | Scenarios: {TotalStats.scenarios_count}]")
    print(f"\tAccumulated Time : {TotalStats.total_time:.2f}s (avg: {avg_time:.2f}s/req)")
    print(f"\tAccumulated Cost : ${TotalStats.total_cost:.6f}")
    print(f"\tAccumulated Tokens: In={TotalStats.total_input_tokens} Out={TotalStats.total_output_tokens}")
    # Compare against None rather than truthiness: an empty error string still
    # fails the test, so it must be visible in the report as well.
    if stats.error is not None:
        print(f"\tERROR : {stats.error}")
    print(separator)
@pytest.mark.evals
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "scenario",
    SCENARIOS,
    ids=[f"{s['resource']}-{s['level']}" for s in SCENARIOS],
)
async def test_scenarios(scenario: dict, judge_tools, calculate_cost) -> None:
    """Run one scenario through the agent, report its metrics, fail on error.

    The agent's response to ``scenario["query"]`` is streamed; text chunks,
    started tools, token usage and errors are collected into a
    ``ScenarioStats``. The result is added to ``TotalStats`` and printed.
    Only a recorded error fails the test; the tool-selection scores are
    reported but not asserted.

    Args:
        scenario: One entry of ``SCENARIOS``. Reads ``query``, ``resource``,
            ``level`` and the optional ``expected_tools``.
        judge_tools: Called as ``judge_tools(expected, actual)`` and unpacked
            as ``(precision, recall, f1)``. Presumably a pytest fixture
            defined outside this file; confirm.
        calculate_cost: Called as
            ``calculate_cost(model, input_tokens, output_tokens)``; its result
            is added to the running cost. Presumably a pytest fixture defined
            outside this file; confirm.
    """
    stats = ScenarioStats()
    query = scenario["query"]
    expected_tools = scenario.get("expected_tools", [])
    mcp_server = make_mcp_server()
    pure_agent_time = 0.0

    try:
        async with mcp_server:
            agent = build_agent(mcp_server=mcp_server, model_name=DEFAULT_MODEL)
            async for event in stream_response(agent, [{"role": "user", "content": query}]):
                if event["type"] == "text":
                    stats.response_text += event["data"]
                elif event["type"] == "tool_start":
                    stats.actual_tools.append(event["tool"])
                elif event["type"] == "usage":
                    stats.input_tokens += event["input_tokens"]
                    stats.output_tokens += event["output_tokens"]
                    # NOTE(review): overwritten (not summed) on every usage
                    # event; assumes the last event carries the full
                    # duration. TODO confirm against stream_response.
                    pure_agent_time = event.get("pure_duration", 0.0)
                elif event["type"] == "error":
                    stats.error = event["data"]
    except Exception as e:
        # Deliberately broad: any failure is recorded so the report is still
        # printed before the assertion below fails the test. Include the
        # exception type because str(e) is empty for message-less exceptions
        # (e.g. a bare TimeoutError), which would leave the error blank.
        stats.error = f"{type(e).__name__}: {e}"

    elapsed = round(pure_agent_time, 2)
    cost = calculate_cost(DEFAULT_MODEL, stats.input_tokens, stats.output_tokens)
    precision, recall, f1 = judge_tools(expected_tools, stats.actual_tools)

    # Update the totals before reporting so the report includes this scenario.
    TotalStats.total_time += elapsed
    TotalStats.total_input_tokens += stats.input_tokens
    TotalStats.total_output_tokens += stats.output_tokens
    TotalStats.total_cost += cost
    TotalStats.scenarios_count += 1

    print_report(
        scenario=scenario,
        stats=stats,
        cost=cost,
        elapsed=elapsed,
        f1=f1,
        precision=precision,
        recall=recall,
    )

    assert stats.error is None, f"Agent error [{scenario['resource']} {scenario['level']}]: {stats.error}"