# Pytest suite: verifies that a locally served LLM (via Ollama's OpenAI-compatible
# API) extracts Slovak Ministry of Justice API tool calls (tool name + params)
# from user queries stored in a SQLite fixture database.
import json
|
|
import os
|
|
import re
|
|
import sqlite3
|
|
|
|
import pytest
|
|
from openai import OpenAI
|
|
|
|
from core.config import OLLAMA_BASE_URL, OLLAMA_API_KEY, LLM_TIMEOUT, DEFAULT_MODEL
|
|
|
|
# SQLite fixture with the curated test cases; expected one directory above this file.
DB_PATH = os.path.join(os.path.dirname(__file__), "..", "test_cases.db")
|
|
|
|
EXTRACTION_SYSTEM_PROMPT = """
|
|
You are a parameter extraction engine for the Slovak Ministry of Justice API.
|
|
|
|
Your ONLY job: read the user query and return a JSON object.
|
|
You MUST always return ONLY a JSON object — nothing else.
|
|
No explanations. No markdown. No ```json fences. Just the raw JSON.
|
|
|
|
Return format:
|
|
{"tool": "<tool_name>", "params": {<extracted parameters>}}
|
|
|
|
Available tools and their parameters:
|
|
|
|
court_search : query, typSuduFacetFilter[], krajFacetFilter[], okresFacetFilter[],
|
|
zahrnutZaniknuteSudy, sortProperty, sortDirection, page, size
|
|
court_id : id (format: "sud_<number>")
|
|
court_autocomplete : query, limit
|
|
|
|
judge_search : query, funkciaFacetFilter[], typSuduFacetFilter[], krajFacetFilter[],
|
|
okresFacetFilter[], stavZapisuFacetFilter[], guidSud, page, size,
|
|
sortProperty, sortDirection
|
|
judge_id : id (format: "sudca_<number>")
|
|
judge_autocomplete : query, guidSud, limit
|
|
|
|
decision_search : query, typSuduFacetFilter[], krajFacetFilter[], okresFacetFilter[],
|
|
formaRozhodnutiaFacetFilter[], vydaniaOd, vydaniaDo,
|
|
ecli, spisovaZnacka, guidSudca, guidSud, sortProperty, sortDirection, page, size
|
|
decision_id : id (ECLI string, e.g. "ECLI:SK:OSPO:1965:8114010264.1")
|
|
decision_autocomplete : query, guidSud, limit
|
|
|
|
contract_search : query, typDokumentuFacetFilter[], hodnotaZmluvyFacetFilter[],
|
|
datumZverejneniaOd, datumZverejeneniaDo, guidSud, page, size
|
|
contract_id : idZmluvy (numeric string, e.g. "2156252")
|
|
contract_autocomplete : query, guidSud, limit
|
|
|
|
civil_proceedings_search : query, krajFacetFilter[], usekFacetFilter[],
|
|
formaUkonuFacetFilter[], pojednavaniaOd, pojednavaniaDo,
|
|
guidSudca, guidSud, verejneVyhlasenie, page, size
|
|
civil_proceedings_id : id (UUID string)
|
|
civil_proceedings_autocomplete : query, guidSud, guidSudca, verejneVyhlasenie, limit
|
|
|
|
admin_proceedings_search : query, druhFacetFilter[], datumPravoplatnostiOd,
|
|
datumPravoplatnostiDo, sortProperty, sortDirection, page, size
|
|
admin_proceedings_id : id (format: "spravneKonanie_<number>")
|
|
admin_proceedings_autocomplete : query, limit
|
|
|
|
Rules:
|
|
- Dates MUST be in DD.MM.YYYY format.
|
|
- IDs MUST use the correct prefix (sud_, sudca_, spravneKonanie_).
|
|
- Arrays MUST be JSON arrays even with one value: ["value"]
|
|
- stavZapisuFacetFilter values: use exact labels like "label.sudca.aktivny"
|
|
- If a number is given without prefix (e.g. "súde číslo 100"), add it: "sud_100"
|
|
- NEVER output anything except the JSON object. No thinking, no prose.
|
|
"""
|
|
|
|
def ollama_available() -> bool:
    """Return True when the Ollama endpoint answers a model-list probe.

    Uses a short 5-second timeout so collection doesn't hang when the
    server is down; any failure at all counts as "not available".
    """
    try:
        probe = OpenAI(base_url=OLLAMA_BASE_URL, api_key=OLLAMA_API_KEY, timeout=5)
        probe.models.list()
    except Exception:
        return False
    return True
|
|
|
|
|
|
def db_available() -> bool:
    """Tell whether the test-case SQLite file is present at DB_PATH."""
    return os.path.exists(DB_PATH)
|
|
|
|
|
|
def load_cases(db_path=None):
    """Load every test case from the SQLite fixture database.

    Args:
        db_path: Optional path to the database file; defaults to the
            module-level DB_PATH. The parameter is a backward-compatible
            addition that makes the loader testable in isolation.

    Returns:
        List of ``(id, query, expected)`` tuples ordered by id, where
        ``expected`` is the raw JSON string stored in the table.
    """
    path = db_path if db_path is not None else DB_PATH
    conn = sqlite3.connect(path)
    try:
        return conn.execute(
            "SELECT id, query, expected FROM test_cases ORDER BY id"
        ).fetchall()
    finally:
        # Always release the handle — the original leaked the connection
        # when execute()/fetchall() raised.
        conn.close()
|
|
|
|
|
|
def extract_json_from_text(text: str) -> dict:
    """Pull the first JSON object out of a raw LLM reply.

    Tolerates common model noise: ``<think>...</think>`` blocks, markdown
    code fences, and prose surrounding the object.

    Args:
        text: Raw text returned by the model.

    Returns:
        The parsed JSON object.

    Raises:
        ValueError: when the reply is empty or contains no JSON object.
        json.JSONDecodeError: when the candidate span is not valid JSON.
    """
    if not text or not text.strip():
        raise ValueError("LLM returned empty response")

    # Strip chain-of-thought blocks emitted by reasoning models.
    text = re.sub(r"<think>.*?</think>", "", text, flags=re.DOTALL).strip()

    # Drop markdown fences. The single pattern matches both ``` and ```json,
    # so the original's extra .replace("```", "") pass was redundant.
    text = re.sub(r"```(?:json)?", "", text).strip()

    # Greedy match spans from the first "{" to the last "}" so nested
    # objects survive intact.
    match = re.search(r"\{.*\}", text, re.DOTALL)
    if not match:
        raise ValueError(
            f"No JSON object found in LLM response. "
            f"Raw text (first 300 chars): {text[:300]!r}"
        )

    return json.loads(match.group())
|
|
|
|
|
|
def ask_llm(query: str) -> dict:
    """Send *query* to the configured model and parse the JSON tool call it returns.

    Raises:
        ValueError: when the model produces no usable text at all.
    """
    llm = OpenAI(
        base_url=OLLAMA_BASE_URL,
        api_key=OLLAMA_API_KEY,
        timeout=LLM_TIMEOUT,
    )
    completion = llm.chat.completions.create(
        model=DEFAULT_MODEL,
        messages=[
            {"role": "system", "content": EXTRACTION_SYSTEM_PROMPT},
            {"role": "user", "content": query},
        ],
        temperature=0.0,
        max_tokens=1024,
    )

    message = completion.choices[0].message
    text = message.content or ""

    # Some reasoning models leave `content` blank and put their output into a
    # non-standard `reasoning_content` extra field; fall back to it if present.
    if not text.strip():
        try:
            text = message.model_extra.get("reasoning_content", "") or ""
        except (AttributeError, TypeError):
            pass

    if not text.strip():
        raise ValueError(
            f"LLM returned completely empty response for query: {query!r}. "
            "Check if Ollama is running and the model is loaded."
        )

    return extract_json_from_text(text)
|
|
|
|
|
|
def compare(llm_result: dict, expected: dict, case_id: int, query: str):
    """Assert that the LLM's tool call matches the expected one.

    The tool name must match exactly. Every expected parameter must be present
    in the LLM output with an equivalent value: lists are compared
    order-insensitively on stringified elements, booleans by truthiness, and
    everything else by string form. Extra LLM parameters are tolerated.
    """
    got_tool = llm_result.get("tool")
    assert got_tool == expected["tool"], (
        f"\n[Case {case_id}] Tool mismatch:\n"
        f" Query : {query}\n"
        f" Expected: {expected['tool']}\n"
        f" Got : {got_tool}\n"
        f" Full LLM: {json.dumps(llm_result, ensure_ascii=False)}"
    )

    got_params = llm_result.get("params", {})
    want_params = expected.get("params", {})

    for key, want in want_params.items():
        assert key in got_params, (
            f"\n[Case {case_id}] Missing param '{key}':\n"
            f" Query : {query}\n"
            f" Expected params : {json.dumps(want_params, ensure_ascii=False)}\n"
            f" LLM params : {json.dumps(got_params, ensure_ascii=False)}"
        )

        got = got_params[key]

        if isinstance(want, list):
            assert isinstance(got, list), (
                f"\n[Case {case_id}] Param '{key}' should be list, "
                f"got {type(got).__name__}:\n Query: {query}"
            )
            # Order-insensitive: both sides sorted on their string forms.
            assert sorted(str(v) for v in want) == sorted(str(v) for v in got), (
                f"\n[Case {case_id}] Param '{key}' list mismatch:\n"
                f" Query : {query}\n"
                f" Expected: {want}\n"
                f" Got : {got}"
            )
        elif isinstance(want, bool):
            assert bool(got) == want, (
                f"\n[Case {case_id}] Param '{key}' bool mismatch:\n"
                f" Query : {query}\n"
                f" Expected: {want}\n"
                f" Got : {got}"
            )
        else:
            assert str(want) == str(got), (
                f"\n[Case {case_id}] Param '{key}' value mismatch:\n"
                f" Query : {query}\n"
                f" Expected: {want!r}\n"
                f" Got : {got!r}"
            )
|
|
|
|
|
|
# Skip markers evaluated once at import time: LLM tests require a reachable
# Ollama server, DB tests require the SQLite fixture file to exist.
skip_if_no_ollama = pytest.mark.skipif(
    not ollama_available(),
    reason="Ollama is not running",
)

skip_if_no_db = pytest.mark.skipif(
    not db_available(),
    reason=f"Database not found: {DB_PATH}. Copy test_cases.db to testing/",
)
|
|
|
|
|
|
def pytest_generate_tests(metafunc):
    """Parametrize every test requesting the `db_case` fixture with DB rows.

    Each row becomes one test instance with a zero-padded id. When the
    database is missing, an empty parameter set is supplied so no cases
    are generated.
    """
    if "db_case" not in metafunc.fixturenames:
        return
    if not db_available():
        metafunc.parametrize("db_case", [])
        return
    rows = load_cases()
    metafunc.parametrize("db_case", rows, ids=[f"{r[0]:02d}" for r in rows])
|
|
|
|
|
|
# ── tests ─────────────────────────────────────────────────────────────────────
|
|
|
|
@skip_if_no_ollama
@skip_if_no_db
def test_llm_extracts_params(db_case):
    """End-to-end: for each DB row the model must produce the expected tool call."""
    case_id, query, expected_json = db_case
    compare(ask_llm(query), json.loads(expected_json), case_id, query)
|
|
|
|
|
|
@skip_if_no_db
def test_db_has_54_rows():
    """The fixture database must contain exactly the 54 curated cases."""
    count = len(load_cases())
    assert count == 54, f"Expected 54 rows, got {count}"
|
|
|
|
|
|
@skip_if_no_db
def test_db_columns_are_valid():
    """Every DB row must carry a non-blank query and well-formed expected JSON."""
    for case_id, case_query, raw_expected in load_cases():
        assert case_query.strip(), f"Row {case_id}: empty query"
        try:
            parsed = json.loads(raw_expected)
        except json.JSONDecodeError as e:
            pytest.fail(f"Row {case_id}: invalid JSON in expected — {e}")
        # The expected payload must follow the {"tool": ..., "params": ...} contract.
        assert "tool" in parsed, f"Row {case_id}: missing 'tool' in expected"
        assert "params" in parsed, f"Row {case_id}: missing 'params' in expected"
|