legal-ai-assistant/tests/conftest.py

217 lines
8.1 KiB
Python

from dotenv import load_dotenv
load_dotenv(".env.test", override=True)
import os
import pytest
import httpx
import json
from unittest.mock import AsyncMock, MagicMock, patch
from backend.tools.api.schemas import (
CourtSearch, CourtAutocomplete, JudgeSearch, JudgeAutocomplete,
DecisionSearch, DecisionAutocomplete, ContractSearch, ContractAutocomplete,
CivilProceedingsSearch, CivilProceedingsAutocomplete,
AdminProceedingsSearch, AdminProceedingsAutocomplete,
ExecutorSearch, ExecutorAutocomplete,
)
from backend.agent.sys_prompt import get_system_prompt
from backend.agent.agent import build_agent, make_mcp_server
from backend.agent.response import stream_response
####################################################################################################################
# for test_schemas.py
####################################################################################################################
@pytest.fixture
def all_search() -> list:
    """Provide every ``*Search`` schema class imported from the API schemas module.

    Returns:
        A new list holding the seven search schema classes, in a fixed order.
    """
    search_schemas = [
        CourtSearch,
        JudgeSearch,
        DecisionSearch,
        ContractSearch,
        CivilProceedingsSearch,
        AdminProceedingsSearch,
        ExecutorSearch,
    ]
    return search_schemas
@pytest.fixture
def all_autocomplete() -> list:
    """Provide every ``*Autocomplete`` schema class imported from the API schemas module.

    Returns:
        A new list holding the seven autocomplete schema classes, in a fixed order.
    """
    autocomplete_schemas = [
        CourtAutocomplete,
        JudgeAutocomplete,
        DecisionAutocomplete,
        ContractAutocomplete,
        CivilProceedingsAutocomplete,
        AdminProceedingsAutocomplete,
        ExecutorAutocomplete,
    ]
    return autocomplete_schemas
####################################################################################################################
# for test_http.py
####################################################################################################################
@pytest.fixture
def make_response() -> callable:
    """Provide a factory that builds fake HTTP response objects.

    Returns:
        A callable ``(json_data, status=200)`` that returns an ``AsyncMock``
        with ``status_code``, ``url``, ``json()`` and ``raise_for_status()``
        preconfigured.
    """

    def _factory(json_data: dict, status: int = 200) -> AsyncMock:
        """Build one fake response carrying ``json_data`` and ``status``."""
        response = AsyncMock()
        response.status_code = status
        response.url = "https://api.com/result"
        # Plain MagicMock so these two can be called without awaiting;
        # the default AsyncMock attributes would return coroutines instead.
        response.raise_for_status = MagicMock()
        response.json = MagicMock(return_value=json_data)
        return response

    return _factory
@pytest.fixture
def mock_client():
    """Patch ``get_client`` in the HTTP request handler with a mocked async context manager.

    Yields:
        The mock that entering the patched context manager produces; its
        ``get`` attribute is an ``AsyncMock`` that tests can configure.
    """
    fake_client = MagicMock()
    fake_client.get = AsyncMock()

    # Object returned by the patched get_client(); ``async with`` on it
    # hands back ``fake_client`` and does not suppress exceptions on exit.
    client_ctx = AsyncMock()
    client_ctx.__aexit__ = AsyncMock(return_value=False)
    client_ctx.__aenter__ = AsyncMock(return_value=fake_client)

    target = "backend.tools.api.http_request_handler.get_client"
    with patch(target, return_value=client_ctx):
        yield fake_client
@pytest.fixture(autouse=True)
def clear_cache(request):
    """Swap the HTTP handler's ``CACHE`` for a fresh ``TTLCache`` around every test.

    Declared ``autouse`` so it applies without being requested. The
    ``request`` argument is accepted but not used in the body.

    Yields:
        The empty ``TTLCache(maxsize=100, ttl=60)`` installed for the test.
    """
    from cachetools import TTLCache

    fresh_cache = TTLCache(maxsize=100, ttl=60)
    cache_patch = patch("backend.tools.api.http_request_handler.CACHE", fresh_cache)
    with cache_patch:
        yield fresh_cache
####################################################################################################################
# for test_prompt.py
####################################################################################################################
@pytest.fixture(scope="module")
def sys_prompt() -> str:
    """Return the result of ``get_system_prompt()``, computed once per test module."""
    prompt = get_system_prompt()
    return prompt
####################################################################################################################
# for test_tools.py
####################################################################################################################
@pytest.fixture
def mock_http():
    """Replace ``http_request`` in the MCP tool factory with an ``AsyncMock``.

    Yields:
        The ``AsyncMock`` standing in for ``http_request``; awaiting it
        returns a stub payload with a ``url`` and an empty ``data`` dict.
    """
    target = "backend.tools.mcp.factory.http_request"
    with patch(target, new_callable=AsyncMock) as fake_request:
        fake_request.return_value = {"url": "https://test.com", "data": {}}
        yield fake_request
####################################################################################################################
# for test_format.py
####################################################################################################################
@pytest.fixture
def mock_agent_task():
    """Replace ``run_agent_task`` in the run_agent router with a canned coroutine.

    The stand-in ignores ``query`` and ``messages`` and only writes to the
    queue it is given.
    """

    async def fake_task(query, queue, messages):
        """Put one text event on ``queue``, then ``None``."""
        # The trailing None is presumably the end-of-stream marker the
        # router waits for -- confirm against the consumer.
        queued_items = ({"type": "text", "data": "response"}, None)
        for item in queued_items:
            await queue.put(item)

    with patch("backend.routers.run_agent.run_agent_task", fake_task):
        yield
####################################################################################################################
# E2E
####################################################################################################################
# Settings for the end-to-end fixtures, read from the process environment
# (.env.test is loaded at the top of this module). Unset string settings are
# None; LLM_TIMEOUT falls back to 300.0 when the variable is absent.
LITELLM_BASE_URL = os.environ.get("LITELLM_BASE_URL")
LITELLM_API_KEY = os.environ.get("LITELLM_API_KEY")
DEFAULT_MODEL = os.environ.get("DEFAULT_MODEL")
LLM_TIMEOUT = float(os.environ.get("LLM_TIMEOUT", 300.0))
TEST_MODEL = os.environ.get("TEST_MODEL")
@pytest.fixture
def run_agent() -> callable:
    """Provide a coroutine function that runs one query through the agent.

    Returns:
        An async callable ``_run(query)`` returning
        ``(response_text, actual_tools, error)``.
    """

    async def _run(query: str) -> tuple[str, list[str], str | None]:
        """Stream the agent's answer to ``query`` and collect what happened.

        Args:
            query: The user message sent as the only entry of the conversation.

        Returns:
            A tuple of the concatenated ``text`` event payloads, the tool
            names from ``tool_start`` events in arrival order, and the last
            error seen (``None`` when there was none). Text and tools
            collected before a failure are still returned.
        """
        text_parts: list[str] = []
        actual_tools: list[str] = []
        error: str | None = None
        mcp_server = make_mcp_server()
        try:
            async with mcp_server:
                agent = build_agent(mcp_server=mcp_server, model_name=DEFAULT_MODEL)
                messages = [{"role": "user", "content": query}]
                async for event in stream_response(agent, messages):
                    kind = event["type"]
                    if kind == "text":
                        text_parts.append(event["data"])
                    elif kind == "tool_start":
                        actual_tools.append(event["tool"])
                    elif kind == "error":
                        error = event["data"]
        except Exception as e:
            # Deliberate catch-all: a failed run is reported through the
            # returned error instead of aborting the test. Some exceptions
            # (timeouts in particular) carry no message, so fall back to
            # repr() to keep the error non-empty and truthy for callers.
            error = str(e) or repr(e)
        return "".join(text_parts), actual_tools, error

    return _run
@pytest.fixture
def judge() -> callable:
    """Provide a coroutine function that asks an LLM judge to grade a prompt.

    Returns:
        An async callable ``_judge(prompt)`` returning ``(score, reason)``.
    """

    async def _judge(prompt: str) -> tuple[float, str]:
        """Send ``prompt`` to the judge model and normalize its verdict.

        The prompt is POSTed to ``{LITELLM_BASE_URL}/chat/completions`` using
        ``TEST_MODEL`` at temperature 0. The reply content is expected to be a
        JSON object with ``score`` and ``reason`` keys, optionally wrapped in
        Markdown code fences.

        Args:
            prompt: The full grading prompt sent as a single user message.

        Returns:
            ``(score, reason)`` where ``score`` is ``(raw - 1) / 4`` rounded to
            three decimals and clamped to ``[0.0, 1.0]`` (presumably the judge
            answers on a 1-5 scale -- confirm against the prompt). Any
            unusable reply yields ``(0.0, <explanation>)`` instead of raising.

        Raises:
            httpx.HTTPError: Transport-level failures are not caught here.
        """
        async with httpx.AsyncClient(timeout=LLM_TIMEOUT) as client:
            resp = await client.post(
                f"{LITELLM_BASE_URL}/chat/completions",
                headers={"Authorization": f"Bearer {LITELLM_API_KEY}"},
                json={
                    "model": TEST_MODEL,
                    "messages": [{"role": "user", "content": prompt}],
                    "temperature": 0,
                    # "max_tokens": 512,
                },
            )

        try:
            data = resp.json()
        except ValueError:
            # Gateway error pages and empty bodies are not JSON.
            print(f"DEBUG: LiteLLM returned a non-JSON body (HTTP {resp.status_code})")
            return 0.0, f"Judge failed: non-JSON response (HTTP {resp.status_code})"

        if not isinstance(data, dict) or "choices" not in data:
            print(f"DEBUG: LiteLLM Error Response: {data}")
            detail = data.get("error", "Unknown error") if isinstance(data, dict) else "Unknown error"
            return 0.0, f"Judge failed: {detail}"

        raw = data["choices"][0]["message"].get("content")
        if raw is None:
            print(f"DEBUG: Judge returned empty content. Data: {data}")
            return 0.0, "Judge returned empty content"

        # Models often wrap their JSON in Markdown code fences; strip them.
        clean = raw.replace("```json", "").replace("```", "").strip()
        try:
            parsed = json.loads(clean)
            normalized = (float(parsed["score"]) - 1) / 4
            reason = parsed["reason"]
        except (ValueError, KeyError, TypeError) as exc:
            # Malformed JSON, a non-object verdict, or a missing/non-numeric
            # field: report it the same way as the other failure paths.
            print(f"DEBUG: Judge returned unparseable content: {raw!r}")
            return 0.0, f"Judge failed: unparseable verdict ({exc!r})"

        # Keep an out-of-range raw score from leaking outside [0, 1].
        score = round(min(1.0, max(0.0, normalized)), 3)
        return score, reason

    return _judge
####################################################################################################################
# TEST SCENARIOS
####################################################################################################################
# Token prices per model: model name -> {"input": price, "output": price},
# expressed per one million tokens (currency is not stated here -- presumably
# USD, confirm). Built from a (model, input price, output price) table.
COST_PER_1M = {
    model_name: {"input": input_price, "output": output_price}
    for model_name, input_price, output_price in (
        ("gpt-oss-120b", 0.180, 0.800),
        ("llama-3.1-8b", 0.034, 0.075),
        ("qwen-qwq-32b", 0.290, 0.390),
        ("qwen3-235b", 0.600, 1.200),
        ("gemini-2.5-flash", 0.092, 2.500),
        ("gemini-2.5-pro", 0.522, 10.000),
        ("deepseek-r1", 0.700, 2.500),
    )
}
@pytest.fixture
def calculate_cost() -> callable:
    """Provide a function that prices a request from its token counts.

    Returns:
        A callable ``_cost(model, input_tokens, output_tokens)`` returning
        the combined input and output cost as a float.
    """

    def _cost(model: str, input_tokens: int, output_tokens: int) -> float:
        """Price ``input_tokens`` and ``output_tokens`` using ``COST_PER_1M``.

        A model missing from ``COST_PER_1M`` is priced at zero.
        """
        free_of_charge = {"input": 0.0, "output": 0.0}
        rates = COST_PER_1M.get(model, free_of_charge)
        # Rates are per million tokens, so scale the counts down first.
        input_cost = input_tokens / 1_000_000 * rates["input"]
        output_cost = output_tokens / 1_000_000 * rates["output"]
        return input_cost + output_cost

    return _cost
@pytest.fixture
def judge_tools() -> callable:
    """Provide a function that scores called tools against expected tools.

    Returns:
        A callable ``_tools(expected, actual)`` returning
        ``(precision, recall, f1)``, each rounded to three decimals.
    """

    def _tools(expected: list[str], actual: list[str]) -> tuple[float, float, float]:
        """Compare tool names as sets; duplicates and order are ignored."""
        if not expected:
            # Nothing was expected: a perfect score only if nothing was
            # called, otherwise every call counts against precision.
            return (0.0, 1.0, 0.0) if actual else (1.0, 1.0, 1.0)
        if not actual:
            return 0.0, 0.0, 0.0

        wanted, called = set(expected), set(actual)
        hits = len(wanted & called)
        precision = hits / len(called)
        recall = hits / len(wanted)

        denominator = precision + recall
        if denominator > 0:
            f1 = 2 * precision * recall / denominator
        else:
            f1 = 0.0
        return round(precision, 3), round(recall, 3), round(f1, 3)

    return _tools