"""Shared pytest fixtures for the backend test suite.

Sections below are grouped by the test module that consumes them
(schemas, HTTP handler, prompt, tools, response format, E2E scenarios).
"""

from dotenv import load_dotenv

# NOTE(review): the test environment is loaded before the `backend` imports
# below, presumably so that modules reading env vars at import time see the
# `.env.test` values -- keep this ordering unless confirmed otherwise.
load_dotenv(".env.test", override=True)

import os
import pytest
import httpx
import json
from collections.abc import Awaitable, Callable
from unittest.mock import AsyncMock, MagicMock, patch

from backend.tools.api.schemas import (
    CourtSearch,
    CourtAutocomplete,
    JudgeSearch,
    JudgeAutocomplete,
    DecisionSearch,
    DecisionAutocomplete,
    ContractSearch,
    ContractAutocomplete,
    CivilProceedingsSearch,
    CivilProceedingsAutocomplete,
    AdminProceedingsSearch,
    AdminProceedingsAutocomplete,
    ExecutorSearch,
    ExecutorAutocomplete,
)
from backend.agent.sys_prompt import get_system_prompt
from backend.agent.agent import build_agent, make_mcp_server
from backend.agent.response import stream_response


####################################################################################################################
# for test_schemas.py
####################################################################################################################

@pytest.fixture
def all_search() -> list:
    """Return every `*Search` schema imported from `backend.tools.api.schemas`."""
    return [
        CourtSearch,
        JudgeSearch,
        DecisionSearch,
        ContractSearch,
        CivilProceedingsSearch,
        AdminProceedingsSearch,
        ExecutorSearch,
    ]


@pytest.fixture
def all_autocomplete() -> list:
    """Return every `*Autocomplete` schema imported from `backend.tools.api.schemas`."""
    return [
        CourtAutocomplete,
        JudgeAutocomplete,
        DecisionAutocomplete,
        ContractAutocomplete,
        CivilProceedingsAutocomplete,
        AdminProceedingsAutocomplete,
        ExecutorAutocomplete,
    ]


####################################################################################################################
# for test_http.py
####################################################################################################################

@pytest.fixture
def make_response() -> Callable[..., AsyncMock]:
    """Return a factory that builds a mocked HTTP response object.

    The factory takes the JSON payload and an optional status code
    (default 200) and returns an `AsyncMock` whose `json()` and
    `raise_for_status()` are synchronous `MagicMock`s.
    """

    def _factory(json_data: dict, status: int = 200) -> AsyncMock:
        mock = AsyncMock()
        # `json` and `raise_for_status` are replaced with plain MagicMocks so
        # that calling them returns a value directly instead of a coroutine.
        mock.json = MagicMock(return_value=json_data)
        mock.status_code = status
        mock.url = "https://api.com/result"
        mock.raise_for_status = MagicMock()
        return mock

    return _factory


@pytest.fixture
def mock_client():
    """Patch `http_request_handler.get_client` and yield the inner mock client.

    `get_client` is replaced so that it returns an async context manager
    whose `__aenter__` resolves to the yielded mock; tests configure
    `mock_client.get` (an `AsyncMock`) to control responses.
    """
    inner = MagicMock()
    inner.get = AsyncMock()

    cm = AsyncMock()
    cm.__aenter__ = AsyncMock(return_value=inner)
    # Returning False from __aexit__ means exceptions are not suppressed.
    cm.__aexit__ = AsyncMock(return_value=False)

    with patch("backend.tools.api.http_request_handler.get_client", return_value=cm):
        yield inner


@pytest.fixture(autouse=True)
def clear_cache(request):
    """Swap `http_request_handler.CACHE` for a fresh `TTLCache` in every test.

    Autouse, so each test starts with an empty cache. Yields the new cache.
    The `request` argument is currently unused.
    """
    from cachetools import TTLCache

    new_cache = TTLCache(maxsize=100, ttl=60)
    with patch("backend.tools.api.http_request_handler.CACHE", new_cache):
        yield new_cache


####################################################################################################################
# for test_prompt.py
####################################################################################################################

@pytest.fixture(scope="module")
def sys_prompt() -> str:
    """Return the result of `get_system_prompt()`, computed once per test module."""
    return get_system_prompt()


####################################################################################################################
# for test_tools.py
####################################################################################################################

@pytest.fixture
def mock_http():
    """Patch `backend.tools.mcp.factory.http_request` with an `AsyncMock`.

    The mock resolves to a fixed payload with a test URL and empty data;
    the mock itself is yielded so tests can inspect calls or override the
    return value.
    """
    with patch(
        "backend.tools.mcp.factory.http_request",
        new_callable=AsyncMock,
    ) as m:
        m.return_value = {"url": "https://test.com", "data": {}}
        yield m


####################################################################################################################
# for test_format.py
####################################################################################################################

@pytest.fixture
def mock_agent_task():
    """Replace `backend.routers.run_agent.run_agent_task` with a canned task.

    The replacement puts one text event on the queue followed by `None`.
    """

    async def fake_task(query, queue, messages):
        await queue.put({"type": "text", "data": "response"})
        await queue.put(None)

    with patch("backend.routers.run_agent.run_agent_task", fake_task):
        yield


####################################################################################################################
# E2E
####################################################################################################################

LITELLM_BASE_URL = os.getenv("LITELLM_BASE_URL")
LITELLM_API_KEY = os.getenv("LITELLM_API_KEY")
DEFAULT_MODEL = os.getenv("DEFAULT_MODEL")
# Default given as a string so both branches of getenv yield the same type.
LLM_TIMEOUT = float(os.getenv("LLM_TIMEOUT", "300.0"))
TEST_MODEL = os.getenv("TEST_MODEL")


@pytest.fixture
def run_agent() -> Callable[[str], Awaitable[tuple[str, list[str], str | None]]]:
    """Return an async helper that runs the agent on a single user query.

    The helper returns `(response_text, tools_started, error)`:
    concatenated `text` events, the `tool` field of every `tool_start`
    event in order, and the last error (event payload or stringified
    exception), or `None` when no error occurred.
    """

    async def _run(query: str) -> tuple[str, list[str], str | None]:
        response_text = ""
        actual_tools = []
        error = None
        mcp_server = make_mcp_server()
        try:
            async with mcp_server:
                agent = build_agent(mcp_server=mcp_server, model_name=DEFAULT_MODEL)
                async for event in stream_response(agent, [{"role": "user", "content": query}]):
                    if event["type"] == "text":
                        response_text += event["data"]
                    elif event["type"] == "tool_start":
                        actual_tools.append(event["tool"])
                    elif event["type"] == "error":
                        error = event["data"]
        except Exception as e:
            # Test-harness boundary: report the failure through the returned
            # tuple so the calling test can assert on it.
            error = str(e)
        return response_text, actual_tools, error

    return _run


@pytest.fixture
def judge() -> Callable[[str], Awaitable[tuple[float, str]]]:
    """Return an async helper that asks `TEST_MODEL` to grade a prompt.

    The helper posts to `{LITELLM_BASE_URL}/chat/completions` and expects
    the model to answer with JSON containing `score` and `reason`. The
    score is rescaled with `(score - 1) / 4` and rounded to 3 decimals.

    Every failure mode (non-JSON HTTP body, missing `choices`, empty
    content, unparsable verdict) yields `(0.0, <explanation>)` rather
    than raising.
    """

    async def _judge(prompt: str) -> tuple[float, str]:
        async with httpx.AsyncClient(timeout=LLM_TIMEOUT) as client:
            resp = await client.post(
                f"{LITELLM_BASE_URL}/chat/completions",
                headers={"Authorization": f"Bearer {LITELLM_API_KEY}"},
                json={
                    "model": TEST_MODEL,
                    "messages": [{"role": "user", "content": prompt}],
                    "temperature": 0,
                    # "max_tokens": 512,
                },
            )

        try:
            data = resp.json()
        except ValueError:
            # The body was not JSON at all; report it like the other
            # judge failures instead of crashing.
            print(f"DEBUG: LiteLLM returned a non-JSON body (HTTP {resp.status_code})")
            return 0.0, f"Judge failed: non-JSON response (HTTP {resp.status_code})"

        if "choices" not in data:
            print(f"DEBUG: LiteLLM Error Response: {data}")
            return 0.0, f"Judge failed: {data.get('error', 'Unknown error')}"

        raw = data["choices"][0]["message"].get("content")
        if raw is None:
            print(f"DEBUG: Judge returned empty content. Data: {data}")
            return 0.0, "Judge returned empty content"

        # Strip Markdown code fences the model may wrap around its JSON.
        clean = raw.replace("```json", "").replace("```", "").strip()
        try:
            parsed = json.loads(clean)
            score = round((float(parsed["score"]) - 1) / 4, 3)
            reason = parsed["reason"]
        except (ValueError, KeyError, TypeError) as exc:
            # Malformed JSON, missing keys, or a non-numeric score.
            print(f"DEBUG: Judge returned an unparsable verdict: {raw!r}")
            return 0.0, f"Judge returned an unparsable verdict: {exc!r}"
        return score, reason

    return _judge


####################################################################################################################
# TEST SCENARIOS
####################################################################################################################

# Per-model prices, keyed by model name, applied per 1,000,000 tokens by
# `calculate_cost` (currency is not stated in this file).
COST_PER_1M = {
    "gpt-oss-120b": {"input": 0.180, "output": 0.800},
    "llama-3.1-8b": {"input": 0.034, "output": 0.075},
    "qwen-qwq-32b": {"input": 0.290, "output": 0.390},
    "qwen3-235b": {"input": 0.600, "output": 1.200},
    "gemini-2.5-flash": {"input": 0.092, "output": 2.500},
    "gemini-2.5-pro": {"input": 0.522, "output": 10.000},
    "deepseek-r1": {"input": 0.700, "output": 2.500},
}


@pytest.fixture
def calculate_cost() -> Callable[[str, int, int], float]:
    """Return a helper computing request cost from token counts.

    Models missing from `COST_PER_1M` are priced at zero.
    """

    def _cost(model: str, input_tokens: int, output_tokens: int) -> float:
        prices = COST_PER_1M.get(model, {"input": 0.0, "output": 0.0})
        return (
            input_tokens / 1_000_000 * prices["input"]
            + output_tokens / 1_000_000 * prices["output"]
        )

    return _cost


@pytest.fixture
def judge_tools() -> Callable[[list[str], list[str]], tuple[float, float, float]]:
    """Return a helper scoring tool usage as `(precision, recall, f1)`.

    Both lists are compared as sets, so duplicates and order are ignored.
    Special cases: both empty -> all 1.0; nothing expected but tools were
    called -> `(0.0, 1.0, 0.0)`; tools expected but none called -> all 0.0.
    """

    def _tools(expected: list[str], actual: list[str]) -> tuple[float, float, float]:
        if not expected and not actual:
            return 1.0, 1.0, 1.0
        if not expected:
            return 0.0, 1.0, 0.0
        if not actual:
            return 0.0, 0.0, 0.0

        expected_set = set(expected)
        actual_set = set(actual)
        intersection = expected_set & actual_set

        precision = len(intersection) / len(actual_set)
        recall = len(intersection) / len(expected_set)
        # Guard against division by zero when the sets are disjoint.
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0
        return round(precision, 3), round(recall, 3), round(f1, 3)

    return _tools