legal-ai-assistant/tests/integration/test_tools.py

119 lines
3.9 KiB
Python

import pytest
from pydantic import ValidationError
from backend.tools.mcp.factory import create_tool
from backend.tools.api.schemas import (
CourtSearch,
CourtByID,
JudgeSearch,
)
class TestToolsMetadata:
    """
    Verify that a tool generated from a schema mirrors that schema's
    metadata: its name, its docstring, and its ordered argument list.
    """

    @staticmethod
    def _court_search_tool():
        """Build the `/sud` tool backed by the CourtSearch schema."""
        return create_tool("/sud", CourtSearch)

    def test_tool_name_matches_schema(self) -> None:
        built = self._court_search_tool()
        expected_name = CourtSearch.__name__
        assert built.__name__ == expected_name

    def test_tool_doc_matches_schema(self) -> None:
        built = self._court_search_tool()
        expected_doc = CourtSearch.__doc__
        assert built.__doc__ == expected_doc

    def test_tool_signature_schema(self) -> None:
        built = self._court_search_tool()
        # Compared as lists, so the tool's parameters must appear in the
        # same order as the schema's declared fields, not just the same set.
        param_names = [*built.__signature__.parameters]
        field_names = [*CourtSearch.model_fields]
        assert param_names == field_names
class TestToolCall:
    """
    Verify how a tool turns a call into an HTTP request: the route it
    targets, the query parameters it sends, how a path ``id`` is
    substituted into the route, and how ``remove_keys`` is forwarded.

    The ``mock_http`` fixture (defined outside this module) stands in for
    the HTTP layer and records the call made by the tool.
    """

    @staticmethod
    async def _http_kwargs(mock_http, tool, **tool_args):
        """Await ``tool(**tool_args)``; return the keyword arguments it sent
        to the mocked HTTP layer."""
        await tool(**tool_args)
        _, recorded = mock_http.call_args
        return recorded

    @pytest.mark.asyncio
    async def test_tool_calls_http_request(self, mock_http) -> None:
        await create_tool("/sud", CourtSearch)()
        mock_http.assert_called_once()

    @pytest.mark.asyncio
    async def test_tool_passes_correct_route(self, mock_http) -> None:
        tool = create_tool("/sud", CourtSearch)
        sent = await self._http_kwargs(mock_http, tool)
        assert sent["route"] == "/sud"

    @pytest.mark.asyncio
    async def test_tool_passes_params(self, mock_http) -> None:
        tool = create_tool("/sud", CourtSearch)
        sent = await self._http_kwargs(mock_http, tool, query="Košice")
        assert sent["params"]["query"] == "Košice"

    @pytest.mark.asyncio
    async def test_tool_excludes_none_params(self, mock_http) -> None:
        tool = create_tool("/sud", CourtSearch)
        sent = await self._http_kwargs(mock_http, tool, query="Košice")
        params = sent["params"]
        # Filters that were never supplied must not appear in the request.
        for unset_filter in ("typSuduFacetFilter", "krajFacetFilter"):
            assert unset_filter not in params

    @pytest.mark.asyncio
    async def test_tool_empty_params(self, mock_http) -> None:
        tool = create_tool("/sud", CourtSearch)
        sent = await self._http_kwargs(mock_http, tool)
        assert sent["params"] == {}

    @pytest.mark.asyncio
    async def test_id_in_params(self, mock_http) -> None:
        # Despite the name, this checks that ``id`` fills the ``{id}``
        # placeholder in the route template.
        tool = create_tool("/sud/{id}", CourtByID)
        sent = await self._http_kwargs(mock_http, tool, id="sud_175")
        assert sent["route"] == "/sud/sud_175"

    @pytest.mark.asyncio
    async def test_id_not_in_params(self, mock_http) -> None:
        # Once used in the route, ``id`` must not also be sent as a
        # query parameter.
        tool = create_tool("/sud/{id}", CourtByID)
        sent = await self._http_kwargs(mock_http, tool, id="sud_175")
        assert "id" not in sent["params"]

    @pytest.mark.asyncio
    async def test_without_remove_keys(self, mock_http) -> None:
        tool = create_tool("/sudca", JudgeSearch)
        sent = await self._http_kwargs(mock_http, tool)
        assert sent["remove_keys"] is None

    @pytest.mark.asyncio
    async def test_with_remove_keys(self, mock_http) -> None:
        tool = create_tool("/sudca", JudgeSearch, remove_keys=["sudcaMapList"])
        sent = await self._http_kwargs(mock_http, tool)
        assert sent["remove_keys"] == ["sudcaMapList"]
class TestToolValidationSchemas:
    """
    Check that invalid parameters stop execution with a validation error
    before any HTTP request is made, and that valid ones pass validation
    and call HTTP.
    """

    @pytest.mark.asyncio
    async def test_invalid_params(self, mock_http) -> None:
        tool = create_tool("/sudca", JudgeSearch)
        # A negative page number is expected to be rejected by JudgeSearch.
        with pytest.raises(ValidationError):
            await tool(page=-1)
        # Validation must short-circuit the call: nothing may reach the
        # HTTP layer when the arguments are rejected.
        mock_http.assert_not_called()

    @pytest.mark.asyncio
    async def test_valid_params(self, mock_http) -> None:
        tool = create_tool("/sudca", JudgeSearch)
        await tool(query="Novák", page=0, size=10)
        mock_http.assert_called_once()