217 lines
8.1 KiB
Python
217 lines
8.1 KiB
Python
from dotenv import load_dotenv
|
|
# Load test configuration before the project imports below. override=True makes
# values from .env.test take precedence over variables already set in the
# environment. NOTE(review): presumably placed first so backend modules that
# read settings at import time see the test values — confirm.
load_dotenv(".env.test", override=True)
|
|
|
|
import json
import os
from collections.abc import Awaitable, Callable
from unittest.mock import AsyncMock, MagicMock, patch

import httpx
import pytest
|
|
from backend.tools.api.schemas import (
|
|
CourtSearch, CourtAutocomplete, JudgeSearch, JudgeAutocomplete,
|
|
DecisionSearch, DecisionAutocomplete, ContractSearch, ContractAutocomplete,
|
|
CivilProceedingsSearch, CivilProceedingsAutocomplete,
|
|
AdminProceedingsSearch, AdminProceedingsAutocomplete,
|
|
ExecutorSearch, ExecutorAutocomplete,
|
|
)
|
|
from backend.agent.sys_prompt import get_system_prompt
|
|
from backend.agent.agent import build_agent, make_mcp_server
|
|
from backend.agent.response import stream_response
|
|
|
|
####################################################################################################################
|
|
# for test_schemas.py
|
|
####################################################################################################################
|
|
|
|
@pytest.fixture
def all_search() -> list:
    """Provide every ``*Search`` schema class, in a fixed order."""
    search_schemas = [
        CourtSearch,
        JudgeSearch,
        DecisionSearch,
        ContractSearch,
        CivilProceedingsSearch,
        AdminProceedingsSearch,
        ExecutorSearch,
    ]
    return search_schemas
|
|
|
|
@pytest.fixture
def all_autocomplete() -> list:
    """Provide every ``*Autocomplete`` schema class, in a fixed order."""
    autocomplete_schemas = [
        CourtAutocomplete,
        JudgeAutocomplete,
        DecisionAutocomplete,
        ContractAutocomplete,
        CivilProceedingsAutocomplete,
        AdminProceedingsAutocomplete,
        ExecutorAutocomplete,
    ]
    return autocomplete_schemas
|
|
|
|
####################################################################################################################
|
|
# for test_http.py
|
|
####################################################################################################################
|
|
|
|
@pytest.fixture
def make_response() -> Callable[..., AsyncMock]:
    """Return a factory that builds fake HTTP response objects.

    ``_factory(json_data, status=200)`` returns an ``AsyncMock`` whose
    ``json()`` returns *json_data*, whose ``status_code`` is *status*, whose
    ``url`` is a fixed string, and whose ``raise_for_status()`` is a no-op.
    """

    def _factory(json_data: dict, status: int = 200) -> AsyncMock:
        mock = AsyncMock()
        # json()/raise_for_status() are replaced with plain MagicMocks so that
        # calling them returns a value instead of a coroutine (presumably
        # mirroring httpx.Response, where these are ordinary methods).
        mock.json = MagicMock(return_value=json_data)
        mock.status_code = status
        mock.url = "https://api.com/result"
        # NOTE: never raises, whatever *status* is.
        mock.raise_for_status = MagicMock()
        return mock

    return _factory
|
|
|
|
|
|
@pytest.fixture
def mock_client():
    """Patch ``get_client`` and yield the fake client it hands out.

    The patched ``get_client()`` returns an async context manager whose
    ``__aenter__`` resolves to a MagicMock with an AsyncMock ``get``; tests
    can configure ``get.return_value`` / ``get.side_effect`` on the yielded
    object.
    """
    client = MagicMock()
    client.get = AsyncMock()

    context_manager = AsyncMock()
    context_manager.__aenter__ = AsyncMock(return_value=client)
    # A falsy __aexit__ result means exceptions raised inside the block propagate.
    context_manager.__aexit__ = AsyncMock(return_value=False)

    target = "backend.tools.api.http_request_handler.get_client"
    with patch(target, return_value=context_manager):
        yield client
|
|
|
|
|
|
@pytest.fixture(autouse=True)
def clear_cache(request):
    """Give every test a fresh, empty response cache.

    Swaps ``http_request_handler.CACHE`` for a new ``TTLCache``
    (``maxsize=100``, ``ttl=60``) for the duration of the test and yields it;
    the previous object is restored when the patch exits. ``request`` is
    accepted but unused.
    """
    from cachetools import TTLCache

    fresh = TTLCache(maxsize=100, ttl=60)
    target = "backend.tools.api.http_request_handler.CACHE"
    with patch(target, fresh):
        yield fresh
|
|
|
|
####################################################################################################################
|
|
# for test_prompt.py
|
|
####################################################################################################################
|
|
|
|
@pytest.fixture(scope="module")
def sys_prompt() -> str:
    """System prompt text, built once per test module."""
    prompt = get_system_prompt()
    return prompt
|
|
|
|
####################################################################################################################
|
|
# for test_tools.py
|
|
####################################################################################################################
|
|
|
|
@pytest.fixture
def mock_http():
    """Patch the MCP factory's ``http_request`` with an AsyncMock and yield it.

    By default the mock resolves to a minimal payload: a ``url`` string and an
    empty ``data`` dict. Tests may override ``return_value``/``side_effect``.
    """
    target = "backend.tools.mcp.factory.http_request"
    with patch(target, new_callable=AsyncMock) as fake_request:
        fake_request.return_value = {"url": "https://test.com", "data": {}}
        yield fake_request
|
|
|
|
####################################################################################################################
|
|
# for test_format.py
|
|
####################################################################################################################
|
|
|
|
@pytest.fixture
def mock_agent_task():
    """Replace ``run_agent_task`` with a stub that emits a single text event.

    The stub puts ``{"type": "text", "data": "response"}`` on the queue and
    then ``None`` (presumably the end-of-stream sentinel the router waits for).
    """

    async def fake_task(query, queue, messages):
        for item in ({"type": "text", "data": "response"}, None):
            await queue.put(item)

    with patch("backend.routers.run_agent.run_agent_task", fake_task):
        yield
|
|
|
|
####################################################################################################################
|
|
# E2E
|
|
####################################################################################################################
|
|
|
|
# Settings for the real LLM endpoint used by the end-to-end fixtures, read from
# the environment (which load_dotenv(".env.test") populates at import time).
# Every value except LLM_TIMEOUT is None when the variable is unset.
LITELLM_BASE_URL = os.environ.get("LITELLM_BASE_URL")
LITELLM_API_KEY = os.environ.get("LITELLM_API_KEY")
DEFAULT_MODEL = os.environ.get("DEFAULT_MODEL")  # model the agent under test runs on
LLM_TIMEOUT = float(os.environ.get("LLM_TIMEOUT", "300.0"))  # httpx client timeout
TEST_MODEL = os.environ.get("TEST_MODEL")  # model used by the LLM judge
|
|
|
|
@pytest.fixture
def run_agent() -> Callable[[str], Awaitable[tuple[str, list[str], str | None]]]:
    """Return a coroutine function that runs the real agent on one query.

    ``_run(query)`` opens a fresh MCP server, builds the agent for
    ``DEFAULT_MODEL``, streams its response to a single user message and
    collects the events.

    Returns (from ``_run``):
        ``(response_text, actual_tools, error)``: the concatenated data of all
        ``"text"`` events; the ``tool`` of every ``"tool_start"`` event, in
        order (duplicates kept); and the last error seen — an ``"error"``
        event's data or the message of an exception raised while running —
        or ``None`` if there was none.
    """

    async def _run(query: str) -> tuple[str, list[str], str | None]:
        response_text = ""
        actual_tools: list[str] = []
        error: str | None = None
        mcp_server = make_mcp_server()
        try:
            async with mcp_server:
                agent = build_agent(mcp_server=mcp_server, model_name=DEFAULT_MODEL)
                messages = [{"role": "user", "content": query}]
                async for event in stream_response(agent, messages):
                    kind = event["type"]
                    if kind == "text":
                        response_text += event["data"]
                    elif kind == "tool_start":
                        actual_tools.append(event["tool"])
                    elif kind == "error":
                        error = event["data"]
        except Exception as e:
            # Deliberately captured so the E2E test can report it. Some
            # exceptions (e.g. a bare TimeoutError()) stringify to "", which a
            # falsy check would mistake for "no error" — fall back to repr().
            error = str(e) or repr(e)
        return response_text, actual_tools, error

    return _run
|
|
|
|
@pytest.fixture
def judge() -> Callable[[str], Awaitable[tuple[float, str]]]:
    """Return a coroutine function that scores text with an LLM judge.

    ``_judge(prompt)`` posts *prompt* as a single user message to
    ``{LITELLM_BASE_URL}/chat/completions`` using ``TEST_MODEL`` at
    temperature 0. The reply content is expected to be a JSON object with
    ``score`` and ``reason`` keys; a Markdown ``json`` code fence or prose
    around the object is tolerated.

    Returns (from ``_judge``):
        ``(score, reason)``. ``score`` is the judge's raw score rescaled with
        ``(s - 1) / 4`` (so 1 -> 0.0 and 5 -> 1.0), rounded to 3 places. On
        failure (non-JSON HTTP body, error payload, empty content, or content
        that cannot be parsed) ``score`` is ``0.0`` and ``reason`` describes
        the failure.
    """

    async def _judge(prompt: str) -> tuple[float, str]:
        async with httpx.AsyncClient(timeout=LLM_TIMEOUT) as client:
            resp = await client.post(
                f"{LITELLM_BASE_URL}/chat/completions",
                headers={"Authorization": f"Bearer {LITELLM_API_KEY}"},
                json={
                    "model": TEST_MODEL,
                    "messages": [{"role": "user", "content": prompt}],
                    "temperature": 0,
                    # "max_tokens": 512,
                },
            )

        try:
            data = resp.json()
        except ValueError:
            # e.g. an HTML error page from a proxy instead of a JSON payload.
            print(f"DEBUG: LiteLLM non-JSON response ({resp.status_code}): {resp.text[:500]}")
            return 0.0, f"Judge failed: non-JSON response (HTTP {resp.status_code})"

        if "choices" not in data:
            print(f"DEBUG: LiteLLM Error Response: {data}")
            return 0.0, f"Judge failed: {data.get('error', 'Unknown error')}"

        raw = data["choices"][0]["message"].get("content")
        if raw is None:
            print(f"DEBUG: Judge returned empty content. Data: {data}")
            return 0.0, "Judge returned empty content"

        clean = raw.replace("```json", "").replace("```", "").strip()
        # Tolerate prose around the object by keeping the outermost {...} span;
        # a clean JSON object is left unchanged by this slice.
        start, end = clean.find("{"), clean.rfind("}")
        if start != -1 and end > start:
            clean = clean[start : end + 1]

        try:
            parsed = json.loads(clean)
            score = round((float(parsed["score"]) - 1) / 4, 3)
            reason = parsed["reason"]
        except (ValueError, KeyError, TypeError) as exc:
            # ValueError covers JSONDecodeError and a non-numeric score;
            # KeyError a missing key; TypeError a non-object payload or null score.
            print(f"DEBUG: Could not parse judge output: {raw!r}")
            return 0.0, f"Judge returned unparseable content: {exc!r}"
        return score, reason

    return _judge
|
|
|
|
####################################################################################################################
|
|
# TEST SCENARIOS
|
|
####################################################################################################################
|
|
|
|
# Token prices per model name: cost per 1,000,000 input and output tokens.
# NOTE(review): currency/units presumably USD — confirm against the provider's
# price sheet. Keys must match the model names passed to calculate_cost.
COST_PER_1M = {
    model: {"input": input_price, "output": output_price}
    for model, input_price, output_price in (
        ("gpt-oss-120b", 0.180, 0.800),
        ("llama-3.1-8b", 0.034, 0.075),
        ("qwen-qwq-32b", 0.290, 0.390),
        ("qwen3-235b", 0.600, 1.200),
        ("gemini-2.5-flash", 0.092, 2.500),
        ("gemini-2.5-pro", 0.522, 10.000),
        ("deepseek-r1", 0.700, 2.500),
    )
}
|
|
|
|
@pytest.fixture
def calculate_cost() -> Callable[[str, int, int], float]:
    """Return a function that prices a model call from its token counts.

    ``_cost(model, input_tokens, output_tokens)`` looks *model* up in
    ``COST_PER_1M`` and returns
    ``input_tokens / 1e6 * input_price + output_tokens / 1e6 * output_price``.
    Models missing from the table are priced at 0.0 (no error is raised).
    """

    def _cost(model: str, input_tokens: int, output_tokens: int) -> float:
        prices = COST_PER_1M.get(model, {"input": 0.0, "output": 0.0})
        return (
            input_tokens / 1_000_000 * prices["input"] +
            output_tokens / 1_000_000 * prices["output"]
        )

    return _cost
|
|
|
|
@pytest.fixture
def judge_tools() -> Callable[[list[str], list[str]], tuple[float, float, float]]:
    """Return a function scoring the tools an agent called against expectations.

    ``_tools(expected, actual)`` compares the *distinct* tool names (duplicates
    are ignored) and returns ``(precision, recall, f1)``, each rounded to 3
    places:

    - precision: share of distinct called tools that were expected;
    - recall: share of distinct expected tools that were called;
    - f1: their harmonic mean.

    Edge cases: both lists empty -> ``(1.0, 1.0, 1.0)``; nothing expected but
    tools called -> ``(0.0, 1.0, 0.0)``; tools expected but none called ->
    ``(0.0, 0.0, 0.0)``.
    """

    def _tools(expected: list[str], actual: list[str]) -> tuple[float, float, float]:
        if not expected and not actual:
            return 1.0, 1.0, 1.0
        if not expected:
            return 0.0, 1.0, 0.0
        if not actual:
            return 0.0, 0.0, 0.0
        expected_set = set(expected)
        actual_set = set(actual)
        intersection = expected_set & actual_set
        precision = len(intersection) / len(actual_set)
        recall = len(intersection) / len(expected_set)
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0
        return round(precision, 3), round(recall, 3), round(f1, 3)

    return _tools
|