ai-lawyer-agent/testing/tests/test_project.py
2026-03-23 09:07:12 +01:00

244 lines
8.4 KiB
Python

import json
import os
import re
import sqlite3
import pytest
from openai import OpenAI
from core.config import OLLAMA_BASE_URL, OLLAMA_API_KEY, LLM_TIMEOUT, DEFAULT_MODEL
# The SQLite fixture database sits one directory above this test module.
_THIS_DIR = os.path.dirname(__file__)
DB_PATH = os.path.join(_THIS_DIR, "..", "test_cases.db")
# System prompt passed as the "system" message by ask_llm(). It instructs the
# model to reply with nothing but a JSON object of the form
# {"tool": ..., "params": {...}}, enumerates the tool names together with the
# parameters each accepts, and lists the formatting rules (date format, ID
# prefixes, arrays).
# NOTE(review): under contract_search, `datumZverejeneniaDo` is spelled
# differently from its sibling `datumZverejneniaOd` — confirm against the real
# API parameter name. The text below is runtime data and is left untouched.
EXTRACTION_SYSTEM_PROMPT = """
You are a parameter extraction engine for the Slovak Ministry of Justice API.
Your ONLY job: read the user query and return a JSON object.
You MUST always return ONLY a JSON object — nothing else.
No explanations. No markdown. No ```json fences. Just the raw JSON.
Return format:
{"tool": "<tool_name>", "params": {<extracted parameters>}}
Available tools and their parameters:
court_search : query, typSuduFacetFilter[], krajFacetFilter[], okresFacetFilter[],
zahrnutZaniknuteSudy, sortProperty, sortDirection, page, size
court_id : id (format: "sud_<number>")
court_autocomplete : query, limit
judge_search : query, funkciaFacetFilter[], typSuduFacetFilter[], krajFacetFilter[],
okresFacetFilter[], stavZapisuFacetFilter[], guidSud, page, size,
sortProperty, sortDirection
judge_id : id (format: "sudca_<number>")
judge_autocomplete : query, guidSud, limit
decision_search : query, typSuduFacetFilter[], krajFacetFilter[], okresFacetFilter[],
formaRozhodnutiaFacetFilter[], vydaniaOd, vydaniaDo,
ecli, spisovaZnacka, guidSudca, guidSud, sortProperty, sortDirection, page, size
decision_id : id (ECLI string, e.g. "ECLI:SK:OSPO:1965:8114010264.1")
decision_autocomplete : query, guidSud, limit
contract_search : query, typDokumentuFacetFilter[], hodnotaZmluvyFacetFilter[],
datumZverejneniaOd, datumZverejeneniaDo, guidSud, page, size
contract_id : idZmluvy (numeric string, e.g. "2156252")
contract_autocomplete : query, guidSud, limit
civil_proceedings_search : query, krajFacetFilter[], usekFacetFilter[],
formaUkonuFacetFilter[], pojednavaniaOd, pojednavaniaDo,
guidSudca, guidSud, verejneVyhlasenie, page, size
civil_proceedings_id : id (UUID string)
civil_proceedings_autocomplete : query, guidSud, guidSudca, verejneVyhlasenie, limit
admin_proceedings_search : query, druhFacetFilter[], datumPravoplatnostiOd,
datumPravoplatnostiDo, sortProperty, sortDirection, page, size
admin_proceedings_id : id (format: "spravneKonanie_<number>")
admin_proceedings_autocomplete : query, limit
Rules:
- Dates MUST be in DD.MM.YYYY format.
- IDs MUST use the correct prefix (sud_, sudca_, spravneKonanie_).
- Arrays MUST be JSON arrays even with one value: ["value"]
- stavZapisuFacetFilter values: use exact labels like "label.sudca.aktivny"
- If a number is given without prefix (e.g. "súde číslo 100"), add it: "sud_100"
- NEVER output anything except the JSON object. No thinking, no prose.
"""
def ollama_available() -> bool:
    """Report whether the configured Ollama endpoint answers a model listing.

    A throwaway client is created with ``timeout=5`` and asked for its model
    list. Any exception whatsoever is interpreted as "not available", so this
    probe itself never raises.
    """
    try:
        probe = OpenAI(base_url=OLLAMA_BASE_URL, api_key=OLLAMA_API_KEY, timeout=5)
        probe.models.list()
    except Exception:
        # Intentionally broad: whatever went wrong, the LLM tests get skipped.
        return False
    else:
        return True
def db_available() -> bool:
    """Return True when something exists on disk at ``DB_PATH``."""
    found = os.path.exists(DB_PATH)
    return found
def load_cases():
    """Load every test case from the SQLite database at ``DB_PATH``.

    Returns:
        A list of ``(id, query, expected)`` row tuples ordered by ``id``.

    Raises:
        sqlite3.Error: propagated unchanged if the query fails (for example
            when the ``test_cases`` table does not exist).
    """
    conn = sqlite3.connect(DB_PATH)
    try:
        return conn.execute(
            "SELECT id, query, expected FROM test_cases ORDER BY id"
        ).fetchall()
    finally:
        # Close on the failure path too, so a broken database does not leak an
        # open handle for every collected test.
        conn.close()
def extract_json_from_text(text: str) -> dict:
    """Extract the first parseable JSON object from a raw LLM reply.

    ``<think>...</think>`` sections and markdown code fences are removed
    first. The remaining text is then scanned for ``{`` characters and the
    first position from which a complete JSON object can be decoded wins.
    Decoding stops at the end of that object, so prose after (or before) the
    object no longer breaks parsing, even when that prose contains braces.

    Args:
        text: Raw reply text from the model.

    Returns:
        The decoded JSON object as a dict.

    Raises:
        ValueError: if *text* is empty/whitespace or contains no ``{`` at all.
        json.JSONDecodeError: (a ``ValueError`` subclass) if a ``{`` is present
            but no position yields a valid JSON object; the error from the
            first attempt is raised.
    """
    if not text or not text.strip():
        raise ValueError("LLM returned empty response")
    # Strip reasoning sections emitted by "thinking" models.
    text = re.sub(r"<think>.*?</think>", "", text, flags=re.DOTALL).strip()
    # Strip markdown fences (with or without a "json" language tag).
    text = re.sub(r"```(?:json)?", "", text).replace("```", "").strip()

    start = text.find("{")
    if start == -1:
        raise ValueError(
            f"No JSON object found in LLM response. "
            f"Raw text (first 300 chars): {text[:300]!r}"
        )

    decoder = json.JSONDecoder()
    first_error = None
    while start != -1:
        try:
            # raw_decode parses exactly one JSON value beginning at `start`
            # and ignores whatever follows it.
            parsed, _end = decoder.raw_decode(text, start)
        except json.JSONDecodeError as err:
            if first_error is None:
                first_error = err
            start = text.find("{", start + 1)
        else:
            return parsed
    # Every candidate failed; surface the error from the earliest brace.
    raise first_error
def ask_llm(query: str) -> dict:
    """Send *query* to the configured model and return its parsed JSON answer.

    The request carries ``EXTRACTION_SYSTEM_PROMPT`` as the system message and
    *query* as the user message, with ``temperature=0.0`` and
    ``max_tokens=1024``. When the message content is blank, the
    ``reasoning_content`` entry of the message's ``model_extra`` is tried as a
    fallback source of text.

    Raises:
        ValueError: if no non-blank text could be obtained, or if
            ``extract_json_from_text`` rejects the reply.
    """
    llm = OpenAI(
        base_url=OLLAMA_BASE_URL,
        api_key=OLLAMA_API_KEY,
        timeout=LLM_TIMEOUT,
    )
    conversation = [
        {"role": "system", "content": EXTRACTION_SYSTEM_PROMPT},
        {"role": "user", "content": query},
    ]
    completion = llm.chat.completions.create(
        model=DEFAULT_MODEL,
        messages=conversation,
        temperature=0.0,
        max_tokens=1024,
    )
    message = completion.choices[0].message
    reply = message.content or ""
    if not reply.strip():
        try:
            reply = message.model_extra.get("reasoning_content", "") or ""
        except (AttributeError, TypeError):
            # Best effort only: no usable model_extra means no fallback text.
            pass
    if reply.strip():
        return extract_json_from_text(reply)
    raise ValueError(
        f"LLM returned completely empty response for query: {query!r}. "
        "Check if Ollama is running and the model is loaded."
    )
def _coerce_bool(value) -> bool:
    """Interpret *value* as a boolean, reading textual false-likes as False."""
    if isinstance(value, str):
        # bool("false") is True in Python, so string answers need real parsing.
        return value.strip().lower() not in ("", "false", "0")
    return bool(value)


def compare(llm_result: dict, expected: dict, case_id: int, query: str):
    """Assert that the LLM's extraction matches the expected tool and params.

    Only the keys present in ``expected["params"]`` are checked; extra keys in
    the LLM output are ignored. Comparison rules per expected value type:

    * list: the LLM value must be a list with the same elements, compared
      as strings and ignoring order;
    * bool: the LLM value is coerced to a boolean, with the strings ``""``,
      ``"false"`` and ``"0"`` (case-insensitive) counting as False;
    * anything else: compared by ``str()`` representation.

    Args:
        llm_result: Parsed JSON object returned by the model.
        expected: Expected object with a ``"tool"`` key and optional ``"params"``.
        case_id: Test-case id, used only in failure messages.
        query: Original user query, used only in failure messages.

    Raises:
        AssertionError: on the first mismatch, with a descriptive message.
    """
    assert llm_result.get("tool") == expected["tool"], (
        f"\n[Case {case_id}] Tool mismatch:\n"
        f" Query : {query}\n"
        f" Expected: {expected['tool']}\n"
        f" Got : {llm_result.get('tool')}\n"
        f" Full LLM: {json.dumps(llm_result, ensure_ascii=False)}"
    )
    # `"params": null` must produce a readable assertion, not a TypeError.
    llm_params = llm_result.get("params") or {}
    assert isinstance(llm_params, dict), (
        f"\n[Case {case_id}] 'params' should be an object, "
        f"got {type(llm_params).__name__}:\n Query: {query}"
    )
    exp_params = expected.get("params", {})
    for key, exp_val in exp_params.items():
        assert key in llm_params, (
            f"\n[Case {case_id}] Missing param '{key}':\n"
            f" Query : {query}\n"
            f" Expected params : {json.dumps(exp_params, ensure_ascii=False)}\n"
            f" LLM params : {json.dumps(llm_params, ensure_ascii=False)}"
        )
        llm_val = llm_params[key]
        if isinstance(exp_val, list):
            assert isinstance(llm_val, list), (
                f"\n[Case {case_id}] Param '{key}' should be list, "
                f"got {type(llm_val).__name__}:\n Query: {query}"
            )
            assert sorted(str(v) for v in exp_val) == sorted(str(v) for v in llm_val), (
                f"\n[Case {case_id}] Param '{key}' list mismatch:\n"
                f" Query : {query}\n"
                f" Expected: {exp_val}\n"
                f" Got : {llm_val}"
            )
        elif isinstance(exp_val, bool):
            assert _coerce_bool(llm_val) == exp_val, (
                f"\n[Case {case_id}] Param '{key}' bool mismatch:\n"
                f" Query : {query}\n"
                f" Expected: {exp_val}\n"
                f" Got : {llm_val}"
            )
        else:
            assert str(exp_val) == str(llm_val), (
                f"\n[Case {case_id}] Param '{key}' value mismatch:\n"
                f" Query : {query}\n"
                f" Expected: {exp_val!r}\n"
                f" Got : {llm_val!r}"
            )
# Skip markers. Both conditions are evaluated once, when this module is
# imported, so the Ollama probe runs a single time per import.
skip_if_no_ollama = pytest.mark.skipif(not ollama_available(), reason="Ollama is not running")

skip_if_no_db = pytest.mark.skipif(
    not db_available(), reason=f"Database not found: {DB_PATH}. Copy test_cases.db to testing/"
)
def pytest_generate_tests(metafunc):
    """Parametrize any test requesting ``db_case`` with the database rows.

    Each row becomes one test instance whose id is the zero-padded row id.
    When the database file is absent, the parameter list is empty.
    """
    if "db_case" not in metafunc.fixturenames:
        return
    if not db_available():
        metafunc.parametrize("db_case", [])
        return
    rows = load_cases()
    row_ids = [f"{row[0]:02d}" for row in rows]
    metafunc.parametrize("db_case", rows, ids=row_ids)
# ── tests ─────────────────────────────────────────────────────────────────────
@skip_if_no_ollama
@skip_if_no_db
def test_llm_extracts_params(db_case):
    """Query the LLM for one database case and check its tool and params."""
    row_id, user_query, expected_json = db_case
    # Decode the expectation first so a bad row fails before the LLM is called.
    expected = json.loads(expected_json)
    compare(ask_llm(user_query), expected, row_id, user_query)
@skip_if_no_db
def test_db_has_54_rows():
    """The fixture database is expected to contain exactly 54 cases."""
    row_count = len(load_cases())
    assert row_count == 54, f"Expected 54 rows, got {row_count}"
@skip_if_no_db
def test_db_columns_are_valid():
    """Every row needs a non-blank query and JSON with 'tool' and 'params'."""
    for row_id, user_query, expected_json in load_cases():
        assert user_query.strip(), f"Row {row_id}: empty query"
        try:
            parsed = json.loads(expected_json)
        except json.JSONDecodeError as exc:
            pytest.fail(f"Row {row_id}: invalid JSON in expected — {exc}")
        else:
            for required_key in ("tool", "params"):
                assert required_key in parsed, (
                    f"Row {row_id}: missing '{required_key}' in expected"
                )