add tests
This commit is contained in:
parent
18d7980582
commit
1a7e8aa355
16
app.py
16
app.py
@ -9,15 +9,19 @@ from core.stream_response import stream_response
|
||||
from api.fetch_api_data import set_log_callback
|
||||
|
||||
STARTERS = [
|
||||
("What legal data can the agent find?","magnifying_glass"),
|
||||
("What is the agent not allowed to do or use?","ban"),
|
||||
("What are the details of your AI model?","hexagon"),
|
||||
("What data sources does the agent rely on?","database"),
|
||||
("What legal data can the agent find?", "magnifying_glass"),
|
||||
("What is the agent not allowed to do or use?", "ban"),
|
||||
("What are the details of your AI model?", "hexagon"),
|
||||
("What data sources does the agent rely on?", "database"),
|
||||
]
|
||||
|
||||
PROFILES = [
|
||||
("qwen3.5:cloud","Qwen 3.5 CLOUD"),
|
||||
("gpt-oss:20b-cloud","GPT-OSS 20B CLOUD"),
|
||||
("qwen3.5:cloud", "Qwen 3.5 CLOUD (in Ollama)"),
|
||||
("gpt-oss:20b-cloud", "GPT-OSS 20B CLOUD (in Ollama)"),
|
||||
("gpt-oss:20b", "GPT-OSS 20B (Local LLM)"),
|
||||
("qwen3:8b", "Qwen3 8B (Local LLM)"),
|
||||
("gpt-4o", "GPT-4o (OpenAI API)"),
|
||||
("gpt-4o-mini", "GPT-4o Mini (OpenAI API)"),
|
||||
]
|
||||
|
||||
@cl.set_starters
|
||||
|
||||
@ -2,7 +2,13 @@ import os
|
||||
|
||||
DEFAULT_MODEL = os.getenv("DEFAULT_MODEL", "qwen3.5:cloud")
|
||||
MAX_HISTORY = int(os.getenv("MAX_HISTORY", "20"))
|
||||
AGENT_TEMPERATURE = float(os.getenv("AGENT_TEMPERATURE", "0.7"))
|
||||
LLM_TIMEOUT = float(os.getenv("LLM_TIMEOUT", "120.0"))
|
||||
|
||||
OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL", "http://localhost:11434/v1")
|
||||
OLLAMA_API_KEY = os.getenv("OLLAMA_API_KEY", "ollama")
|
||||
OLLAMA_TIMEOUT = float(os.getenv("OLLAMA_TIMEOUT", "120.0"))
|
||||
AGENT_TEMPERATURE = float(os.getenv("AGENT_TEMPERATURE", "0.7"))
|
||||
OLLAMA_MODELS = {"qwen3.5:cloud", "gpt-oss:20b-cloud", "gpt-oss:20b", "qwen3:8b"}
|
||||
|
||||
OPENAI_BASE_URL = os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1")
|
||||
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")
|
||||
OPENAI_MODELS = {"gpt-4o", "gpt-4o-mini"}
|
||||
@ -2,7 +2,12 @@ from agents import Agent, AgentHooks
|
||||
from agents import OpenAIChatCompletionsModel, AsyncOpenAI, ModelSettings
|
||||
from agents import set_tracing_disabled
|
||||
|
||||
from core.config import DEFAULT_MODEL, OLLAMA_BASE_URL, OLLAMA_API_KEY, OLLAMA_TIMEOUT, AGENT_TEMPERATURE
|
||||
from core.config import (
|
||||
DEFAULT_MODEL, AGENT_TEMPERATURE, LLM_TIMEOUT,
|
||||
OLLAMA_BASE_URL, OLLAMA_API_KEY,
|
||||
OPENAI_BASE_URL, OPENAI_API_KEY,
|
||||
OPENAI_MODELS, OLLAMA_MODELS
|
||||
)
|
||||
from core.system_prompt import get_system_prompt
|
||||
from api.tools import ALL_TOOLS
|
||||
|
||||
@ -15,18 +20,35 @@ class MyAgentHooks(AgentHooks):
|
||||
async def on_end(self, context, agent, output):
|
||||
print(f"🏁 [AgentHooks] {agent.name} ended.")
|
||||
|
||||
def assistant_agent(model_name: str = DEFAULT_MODEL) -> Agent:
|
||||
def _make_client(model_name: str) -> AsyncOpenAI:
|
||||
"""Return the correct AsyncOpenAI client based on model name"""
|
||||
|
||||
client = AsyncOpenAI(
|
||||
base_url=OLLAMA_BASE_URL,
|
||||
api_key=OLLAMA_API_KEY,
|
||||
timeout=OLLAMA_TIMEOUT,
|
||||
max_retries=0
|
||||
)
|
||||
if model_name in OPENAI_MODELS:
|
||||
return AsyncOpenAI(
|
||||
base_url=OPENAI_BASE_URL,
|
||||
api_key=OPENAI_API_KEY,
|
||||
timeout=LLM_TIMEOUT,
|
||||
max_retries=0,
|
||||
)
|
||||
|
||||
if model_name in OLLAMA_MODELS:
|
||||
return AsyncOpenAI(
|
||||
base_url=OLLAMA_BASE_URL,
|
||||
api_key=OLLAMA_API_KEY,
|
||||
timeout=LLM_TIMEOUT,
|
||||
max_retries=0,
|
||||
)
|
||||
|
||||
raise ValueError(f"Model {model_name} not supported")
|
||||
|
||||
def assistant_agent(model_name: str = DEFAULT_MODEL) -> Agent:
|
||||
"""Initialize the assistant agent for legal work"""
|
||||
|
||||
client = _make_client(model_name)
|
||||
model = OpenAIChatCompletionsModel(model=model_name, openai_client=client)
|
||||
|
||||
agent = Agent(
|
||||
name="Assistant",
|
||||
name="AI Lawyer Assistant",
|
||||
instructions=get_system_prompt(model_name),
|
||||
model=model,
|
||||
model_settings=ModelSettings(temperature=AGENT_TEMPERATURE, tool_choice="auto", parallel_tool_calls=False),
|
||||
|
||||
BIN
requirements.txt
BIN
requirements.txt
Binary file not shown.
9
testing/fixtures.py
Normal file
9
testing/fixtures.py
Normal file
@ -0,0 +1,9 @@
|
||||
import pytest
|
||||
import asyncio
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def event_loop():
|
||||
loop = asyncio.new_event_loop()
|
||||
yield loop
|
||||
loop.close()
|
||||
245
testing/reports/charts.py
Normal file
245
testing/reports/charts.py
Normal file
@ -0,0 +1,245 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
from collections import Counter, defaultdict
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import matplotlib
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.patches as mpatches
|
||||
|
||||
matplotlib.use("Agg")
|
||||
|
||||
RESULTS_FILE = Path(__file__).parent.parent / "results.json"
|
||||
CHARTS_DIR = Path(__file__).parent.parent / "charts"
|
||||
|
||||
STATUS_COLORS = {
|
||||
"passed": "#3ddc84",
|
||||
"failed": "#ff5c5c",
|
||||
"skipped": "#e0c84a",
|
||||
"error": "#ff8a8a",
|
||||
}
|
||||
|
||||
BG_COLOR = "#1a1a1a"
|
||||
PANEL_COLOR = "#242424"
|
||||
TEXT_COLOR = "#e0e0e0"
|
||||
GRID_COLOR = "#333333"
|
||||
FONT_MONO = "monospace"
|
||||
|
||||
|
||||
def _style_ax(ax, title: str):
|
||||
ax.set_facecolor(PANEL_COLOR)
|
||||
ax.set_title(title, color=TEXT_COLOR, fontsize=11, fontweight="bold", pad=10)
|
||||
ax.tick_params(colors=TEXT_COLOR, labelsize=8)
|
||||
ax.spines[:].set_color(GRID_COLOR)
|
||||
ax.yaxis.grid(True, color=GRID_COLOR, linewidth=0.5, linestyle="--")
|
||||
ax.set_axisbelow(True)
|
||||
for label in ax.get_xticklabels() + ax.get_yticklabels():
|
||||
label.set_color(TEXT_COLOR)
|
||||
label.set_fontfamily(FONT_MONO)
|
||||
|
||||
|
||||
def load_data() -> tuple[pd.DataFrame, dict]:
|
||||
if not RESULTS_FILE.exists():
|
||||
raise FileNotFoundError(f"results.json not found at {RESULTS_FILE}")
|
||||
|
||||
raw = json.loads(RESULTS_FILE.read_text(encoding="utf-8"))
|
||||
tests = raw.get("tests", [])
|
||||
|
||||
rows = []
|
||||
for t in tests:
|
||||
node = t["nodeid"]
|
||||
parts = node.split("::")
|
||||
module = parts[0].replace("tests/", "").replace("/", ".").replace(".py", "")
|
||||
cls = parts[1] if len(parts) >= 3 else "unknown"
|
||||
name = parts[-1]
|
||||
|
||||
duration = t.get("call", {}).get("duration", 0.0) if t.get("call") else 0.0
|
||||
|
||||
rows.append({
|
||||
"module": module,
|
||||
"class": cls,
|
||||
"name": name,
|
||||
"outcome": t["outcome"],
|
||||
"duration": duration,
|
||||
})
|
||||
|
||||
df = pd.DataFrame(rows)
|
||||
summary = raw.get("summary", {})
|
||||
return df, summary
|
||||
|
||||
|
||||
def chart_overall_status(df: pd.DataFrame, ax: plt.Axes):
|
||||
counts = df["outcome"].value_counts()
|
||||
colors = [STATUS_COLORS.get(k, "#888") for k in counts.index]
|
||||
|
||||
wedges, texts, pcts = ax.pie(
|
||||
counts.values,
|
||||
labels=counts.index,
|
||||
colors=colors,
|
||||
autopct="%1.1f%%",
|
||||
startangle=90,
|
||||
pctdistance=0.78,
|
||||
wedgeprops={"edgecolor": BG_COLOR, "linewidth": 2},
|
||||
)
|
||||
for t in texts:
|
||||
t.set_color(TEXT_COLOR)
|
||||
t.set_fontsize(9)
|
||||
t.set_fontfamily(FONT_MONO)
|
||||
for p in pcts:
|
||||
p.set_color("#111")
|
||||
p.set_fontsize(8)
|
||||
p.set_fontweight("bold")
|
||||
|
||||
ax.set_facecolor(PANEL_COLOR)
|
||||
ax.set_title("Overall Results", color=TEXT_COLOR, fontsize=11, fontweight="bold", pad=10)
|
||||
|
||||
|
||||
def chart_by_module(df: pd.DataFrame, ax: plt.Axes):
|
||||
pivot = (
|
||||
df.groupby(["module", "outcome"])
|
||||
.size()
|
||||
.unstack(fill_value=0)
|
||||
.reindex(columns=["passed", "failed", "skipped", "error"], fill_value=0)
|
||||
)
|
||||
|
||||
x = np.arange(len(pivot))
|
||||
width = 0.2
|
||||
offset = -(len(pivot.columns) - 1) / 2 * width
|
||||
|
||||
for i, col in enumerate(pivot.columns):
|
||||
bars = ax.bar(
|
||||
x + offset + i * width,
|
||||
pivot[col],
|
||||
width,
|
||||
label=col,
|
||||
color=STATUS_COLORS.get(col, "#888"),
|
||||
edgecolor=BG_COLOR,
|
||||
linewidth=0.8,
|
||||
)
|
||||
|
||||
ax.set_xticks(x)
|
||||
ax.set_xticklabels(pivot.index, rotation=25, ha="right", fontsize=7, fontfamily=FONT_MONO)
|
||||
ax.set_ylabel("Tests", color=TEXT_COLOR, fontsize=9)
|
||||
ax.legend(
|
||||
fontsize=7,
|
||||
labelcolor=TEXT_COLOR,
|
||||
facecolor=PANEL_COLOR,
|
||||
edgecolor=GRID_COLOR,
|
||||
)
|
||||
_style_ax(ax, "Results by Module")
|
||||
|
||||
|
||||
def chart_duration_histogram(df: pd.DataFrame, ax: plt.Axes):
|
||||
durations = df.loc[df["outcome"] != "skipped", "duration"].values * 1000
|
||||
|
||||
if len(durations) == 0:
|
||||
ax.text(0.5, 0.5, "No data", ha="center", va="center", color=TEXT_COLOR)
|
||||
_style_ax(ax, "Test Duration (ms)")
|
||||
return
|
||||
|
||||
mean_ms = float(np.mean(durations))
|
||||
median_ms = float(np.median(durations))
|
||||
p95_ms = float(np.percentile(durations, 95))
|
||||
|
||||
ax.hist(durations, bins=20, color="#5cb8ff", edgecolor=BG_COLOR, linewidth=0.6, alpha=0.85)
|
||||
ax.axvline(mean_ms, color="#3ddc84", linewidth=1.5, linestyle="--", label=f"Mean {mean_ms:.1f} ms")
|
||||
ax.axvline(median_ms, color="#e0c84a", linewidth=1.5, linestyle=":", label=f"Median {median_ms:.1f} ms")
|
||||
ax.axvline(p95_ms, color="#ff5c5c", linewidth=1.5, linestyle="-.", label=f"P95 {p95_ms:.1f} ms")
|
||||
|
||||
ax.set_xlabel("ms", color=TEXT_COLOR, fontsize=9)
|
||||
ax.set_ylabel("Tests", color=TEXT_COLOR, fontsize=9)
|
||||
ax.legend(fontsize=7, labelcolor=TEXT_COLOR, facecolor=PANEL_COLOR, edgecolor=GRID_COLOR)
|
||||
_style_ax(ax, "Test Duration (ms)")
|
||||
|
||||
|
||||
def chart_slowest_tests(df: pd.DataFrame, ax: plt.Axes):
|
||||
top = (
|
||||
df[df["outcome"] != "skipped"]
|
||||
.nlargest(10, "duration")
|
||||
.copy()
|
||||
)
|
||||
top["label"] = top["class"] + "::" + top["name"]
|
||||
top["duration"] = top["duration"] * 1000
|
||||
|
||||
colors = [STATUS_COLORS.get(o, "#888") for o in top["outcome"]]
|
||||
bars = ax.barh(top["label"], top["duration"], color=colors, edgecolor=BG_COLOR, linewidth=0.6)
|
||||
|
||||
ax.set_xlabel("ms", color=TEXT_COLOR, fontsize=9)
|
||||
ax.tick_params(axis="y", labelsize=7)
|
||||
ax.invert_yaxis()
|
||||
_style_ax(ax, "Top 10 Slowest Tests")
|
||||
|
||||
|
||||
def chart_stats_table(df: pd.DataFrame, summary: dict, ax: plt.Axes):
|
||||
ax.set_facecolor(PANEL_COLOR)
|
||||
ax.axis("off")
|
||||
|
||||
total = len(df)
|
||||
passed = summary.get("passed", 0)
|
||||
failed = summary.get("failed", 0)
|
||||
skipped = summary.get("skipped", 0)
|
||||
duration = df["duration"].sum() * 1000
|
||||
|
||||
durations = df.loc[df["outcome"] != "skipped", "duration"].values * 1000
|
||||
|
||||
rows = [
|
||||
["Total tests", str(total)],
|
||||
["Passed", str(passed)],
|
||||
["Failed", str(failed)],
|
||||
["Skipped", str(skipped)],
|
||||
["Pass rate", f"{passed / total * 100:.1f}%" if total else "—"],
|
||||
["Total time", f"{duration:.0f} ms"],
|
||||
["Mean duration", f"{np.mean(durations):.1f} ms" if len(durations) else "—"],
|
||||
["Median", f"{np.median(durations):.1f} ms" if len(durations) else "—"],
|
||||
["P95", f"{np.percentile(durations, 95):.1f} ms" if len(durations) else "—"],
|
||||
]
|
||||
|
||||
table = ax.table(
|
||||
cellText=rows,
|
||||
colLabels=["Metric", "Value"],
|
||||
cellLoc="left",
|
||||
loc="center",
|
||||
colWidths=[0.6, 0.4],
|
||||
)
|
||||
table.auto_set_font_size(False)
|
||||
table.set_fontsize(9)
|
||||
|
||||
for (row, col), cell in table.get_celld().items():
|
||||
cell.set_facecolor("#2e2e2e" if row % 2 == 0 else PANEL_COLOR)
|
||||
cell.set_edgecolor(GRID_COLOR)
|
||||
cell.set_text_props(color=TEXT_COLOR, fontfamily=FONT_MONO)
|
||||
|
||||
ax.set_title("Summary", color=TEXT_COLOR, fontsize=11, fontweight="bold", pad=10)
|
||||
|
||||
|
||||
def generate(output_path: Path = None) -> Path:
|
||||
CHARTS_DIR.mkdir(parents=True, exist_ok=True)
|
||||
output_path = output_path or CHARTS_DIR / "report.png"
|
||||
|
||||
df, summary = load_data()
|
||||
|
||||
fig = plt.figure(figsize=(18, 14), facecolor=BG_COLOR)
|
||||
fig.suptitle(
|
||||
"Legal AI Assistant — Test Report",
|
||||
fontsize=16, fontweight="bold", color=TEXT_COLOR,
|
||||
y=0.98, fontfamily=FONT_MONO,
|
||||
)
|
||||
|
||||
gs = fig.add_gridspec(2, 3, hspace=0.45, wspace=0.35,
|
||||
top=0.93, bottom=0.05, left=0.06, right=0.97)
|
||||
|
||||
chart_overall_status(df, fig.add_subplot(gs[0, 0]))
|
||||
chart_by_module(df, fig.add_subplot(gs[0, 1:]))
|
||||
chart_duration_histogram(df, fig.add_subplot(gs[1, 0]))
|
||||
chart_slowest_tests(df, fig.add_subplot(gs[1, 1]))
|
||||
chart_stats_table(df, summary, fig.add_subplot(gs[1, 2]))
|
||||
|
||||
fig.savefig(output_path, dpi=150, bbox_inches="tight", facecolor=BG_COLOR)
|
||||
plt.close(fig)
|
||||
return output_path
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
out = generate()
|
||||
print(f"Chart saved: {out}")
|
||||
58
testing/run_tests.py
Normal file
58
testing/run_tests.py
Normal file
@ -0,0 +1,58 @@
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
ROOT = Path(__file__).parent.parent
|
||||
TESTS_DIR = Path(__file__).parent
|
||||
RESULTS_JSON = TESTS_DIR / "results.json"
|
||||
REPORT_HTML = TESTS_DIR / "charts" / "report.html"
|
||||
CHARTS_DIR = TESTS_DIR / "charts"
|
||||
|
||||
|
||||
def run_pytest() -> int:
|
||||
CHARTS_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
args = [
|
||||
sys.executable, "-m", "pytest",
|
||||
|
||||
str(TESTS_DIR / "unit"),
|
||||
str(TESTS_DIR / "integration"),
|
||||
str(TESTS_DIR / "e2e"),
|
||||
str(TESTS_DIR / "llm"),
|
||||
|
||||
"--json-report",
|
||||
f"--json-report-file={RESULTS_JSON}",
|
||||
|
||||
f"--html={REPORT_HTML}",
|
||||
"--self-contained-html",
|
||||
|
||||
f"--cov=api",
|
||||
f"--cov=core",
|
||||
"--cov-report=term-missing",
|
||||
f"--cov-report=html:{CHARTS_DIR / 'coverage'}",
|
||||
|
||||
"-p", "no:terminal",
|
||||
"--tb=short",
|
||||
"-q",
|
||||
]
|
||||
|
||||
return subprocess.run(args, cwd=str(ROOT), check=False).returncode
|
||||
|
||||
|
||||
def run_charts():
|
||||
from testing.reports.charts import generate
|
||||
out = generate()
|
||||
print(f"Charts saved: {out}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("Start testing")
|
||||
|
||||
code = run_pytest()
|
||||
|
||||
if RESULTS_JSON.exists():
|
||||
run_charts()
|
||||
|
||||
print("Stop testing")
|
||||
|
||||
sys.exit(code)
|
||||
152
testing/tests/test_api.py
Normal file
152
testing/tests/test_api.py
Normal file
@ -0,0 +1,152 @@
|
||||
import pytest
|
||||
import httpx
|
||||
|
||||
from api.config import JUSTICE_API_BASE, HTTP_TIMEOUT
|
||||
|
||||
HEADERS = {
|
||||
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) Chrome/124.0.0.0 Safari/537.36",
|
||||
"Accept": "application/json, text/plain, */*",
|
||||
"Accept-Language": "sk-SK,sk;q=0.9",
|
||||
"Referer": "https://obcan.justice.sk/",
|
||||
"Origin": "https://obcan.justice.sk",
|
||||
}
|
||||
|
||||
|
||||
def api_available() -> bool:
|
||||
try:
|
||||
r = httpx.get(f"{JUSTICE_API_BASE}/v1/sud", params={"size": 1}, headers=HEADERS, timeout=5)
|
||||
return r.status_code == 200
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
skip_if_offline = pytest.mark.skipif(
|
||||
not api_available(),
|
||||
reason="justice.sk API is not reachable"
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def client():
|
||||
with httpx.Client(headers=HEADERS, timeout=HTTP_TIMEOUT, follow_redirects=True) as c:
|
||||
yield c
|
||||
|
||||
|
||||
@skip_if_offline
|
||||
class TestCourtsEndpoint:
|
||||
|
||||
def test_returns_200(self, client):
|
||||
r = client.get(f"{JUSTICE_API_BASE}/v1/sud", params={"size": 5})
|
||||
assert r.status_code == 200
|
||||
|
||||
def test_response_has_content_key(self, client):
|
||||
r = client.get(f"{JUSTICE_API_BASE}/v1/sud", params={"size": 5})
|
||||
data = r.json()
|
||||
assert "content" in data
|
||||
|
||||
def test_total_elements_is_positive(self, client):
|
||||
r = client.get(f"{JUSTICE_API_BASE}/v1/sud", params={"size": 1})
|
||||
data = r.json()
|
||||
assert data.get("totalElements", 0) > 0
|
||||
|
||||
def test_court_by_id_returns_valid_record(self, client):
|
||||
r = client.get(f"{JUSTICE_API_BASE}/v1/sud/sud_175")
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert "id" in data or "nazov" in data
|
||||
|
||||
def test_court_autocomplete_returns_list(self, client):
|
||||
r = client.get(f"{JUSTICE_API_BASE}/v1/sud/autocomplete", params={"query": "Bratislava", "limit": 5})
|
||||
assert r.status_code == 200
|
||||
assert isinstance(r.json(), list)
|
||||
|
||||
def test_nonexistent_court_returns_404(self, client):
|
||||
r = client.get(f"{JUSTICE_API_BASE}/v1/sud/sud_999999999")
|
||||
assert r.status_code == 404
|
||||
|
||||
|
||||
@skip_if_offline
|
||||
class TestJudgesEndpoint:
|
||||
|
||||
def test_judge_search_returns_200(self, client):
|
||||
r = client.get(f"{JUSTICE_API_BASE}/v1/sudca", params={"size": 5})
|
||||
assert r.status_code == 200
|
||||
|
||||
def test_judge_autocomplete_returns_results(self, client):
|
||||
r = client.get(f"{JUSTICE_API_BASE}/v1/sudca/autocomplete", params={"query": "Novák", "limit": 10})
|
||||
assert r.status_code == 200
|
||||
|
||||
def test_judge_search_by_kraj(self, client):
|
||||
r = client.get(f"{JUSTICE_API_BASE}/v1/sudca", params={
|
||||
"krajFacetFilter": "Bratislavský kraj",
|
||||
"size": 5
|
||||
})
|
||||
assert r.status_code == 200
|
||||
|
||||
def test_judge_search_pagination_page_zero(self, client):
|
||||
r = client.get(f"{JUSTICE_API_BASE}/v1/sudca", params={"page": 0, "size": 10})
|
||||
data = r.json()
|
||||
assert "content" in data
|
||||
|
||||
|
||||
@skip_if_offline
|
||||
class TestDecisionsEndpoint:
|
||||
|
||||
def test_decision_search_returns_200(self, client):
|
||||
r = client.get(f"{JUSTICE_API_BASE}/v1/rozhodnutie", params={"size": 3})
|
||||
assert r.status_code == 200
|
||||
|
||||
def test_decision_search_with_date_range(self, client):
|
||||
r = client.get(f"{JUSTICE_API_BASE}/v1/rozhodnutie", params={
|
||||
"vydaniaOd": "01.01.2023",
|
||||
"vydaniaDo": "31.12.2023",
|
||||
"size": 3,
|
||||
})
|
||||
assert r.status_code == 200
|
||||
|
||||
def test_decision_autocomplete(self, client):
|
||||
r = client.get(f"{JUSTICE_API_BASE}/v1/rozhodnutie/autocomplete", params={"query": "Rozsudok", "limit": 5})
|
||||
assert r.status_code == 200
|
||||
|
||||
|
||||
@skip_if_offline
|
||||
class TestContractsEndpoint:
|
||||
|
||||
def test_contract_search_returns_200(self, client):
|
||||
r = client.get(f"{JUSTICE_API_BASE}/v1/zmluvy", params={"size": 5})
|
||||
assert r.status_code == 200
|
||||
|
||||
def test_contract_search_by_typ_dokumentu(self, client):
|
||||
r = client.get(f"{JUSTICE_API_BASE}/v1/zmluvy", params={
|
||||
"typDokumentuFacetFilter": "ZMLUVA",
|
||||
"size": 3,
|
||||
})
|
||||
assert r.status_code == 200
|
||||
|
||||
|
||||
@skip_if_offline
|
||||
class TestCivilProceedingsEndpoint:
|
||||
|
||||
def test_civil_proceedings_returns_200(self, client):
|
||||
r = client.get(f"{JUSTICE_API_BASE}/v1/obcianPojednavania", params={"size": 3})
|
||||
assert r.status_code == 200
|
||||
|
||||
def test_civil_proceedings_date_filter(self, client):
|
||||
r = client.get(f"{JUSTICE_API_BASE}/v1/obcianPojednavania", params={
|
||||
"pojednavaniaOd": "01.01.2024",
|
||||
"pojednavaniaDo": "31.01.2024",
|
||||
"size": 3,
|
||||
})
|
||||
assert r.status_code == 200
|
||||
|
||||
|
||||
@skip_if_offline
|
||||
class TestAdminProceedingsEndpoint:
|
||||
|
||||
def test_admin_proceedings_returns_200(self, client):
|
||||
r = client.get(f"{JUSTICE_API_BASE}/v1/spravneKonanie", params={"size": 3})
|
||||
assert r.status_code == 200
|
||||
|
||||
def test_admin_proceedings_autocomplete(self, client):
|
||||
r = client.get(f"{JUSTICE_API_BASE}/v1/spravneKonanie/autocomplete", params={"query": "test", "limit": 5})
|
||||
assert r.status_code == 200
|
||||
113
testing/tests/test_fetch.py
Normal file
113
testing/tests/test_fetch.py
Normal file
@ -0,0 +1,113 @@
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
import httpx
|
||||
import respx
|
||||
|
||||
from api.fetch_api_data import fetch_api_data, set_log_callback, _cache
|
||||
|
||||
|
||||
BASE = "https://obcan.justice.sk/pilot/api/ress-isu-service"
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clear_cache():
|
||||
_cache.clear()
|
||||
yield
|
||||
_cache.clear()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestFetchApiData:
|
||||
|
||||
@respx.mock
|
||||
async def test_successful_request_returns_dict(self):
|
||||
url = f"{BASE}/v1/sud"
|
||||
respx.get(url).mock(return_value=httpx.Response(200, json={"content": [], "totalElements": 0}))
|
||||
|
||||
result = await fetch_api_data(icon="", url=url, params={})
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert "totalElements" in result
|
||||
|
||||
@respx.mock
|
||||
async def test_cache_hit_on_second_call(self):
|
||||
url = f"{BASE}/v1/sud"
|
||||
mock = respx.get(url).mock(return_value=httpx.Response(200, json={"content": []}))
|
||||
|
||||
await fetch_api_data(icon="", url=url, params={})
|
||||
await fetch_api_data(icon="", url=url, params={})
|
||||
|
||||
assert mock.call_count == 1
|
||||
|
||||
@respx.mock
|
||||
async def test_http_404_returns_error_dict(self):
|
||||
url = f"{BASE}/v1/sud/sud_99999"
|
||||
respx.get(url).mock(return_value=httpx.Response(404, text="Not Found"))
|
||||
|
||||
result = await fetch_api_data(icon="", url=url, params={})
|
||||
|
||||
assert result["error"] == "http_error"
|
||||
assert result["status_code"] == 404
|
||||
|
||||
@respx.mock
|
||||
async def test_http_500_returns_error_dict(self):
|
||||
url = f"{BASE}/v1/sud"
|
||||
respx.get(url).mock(return_value=httpx.Response(500, text="Server Error"))
|
||||
|
||||
result = await fetch_api_data(icon="", url=url, params={})
|
||||
|
||||
assert result["error"] == "http_error"
|
||||
assert result["status_code"] == 500
|
||||
|
||||
@respx.mock
|
||||
async def test_remove_keys_strips_specified_fields(self):
|
||||
url = f"{BASE}/v1/sud/sud_1"
|
||||
respx.get(url).mock(return_value=httpx.Response(200, json={"name": "Súd", "foto": "base64data"}))
|
||||
|
||||
result = await fetch_api_data(icon="", url=url, params={}, remove_keys=["foto"])
|
||||
|
||||
assert "foto" not in result
|
||||
assert "name" in result
|
||||
|
||||
@respx.mock
|
||||
async def test_remove_keys_missing_key_no_error(self):
|
||||
url = f"{BASE}/v1/sud/sud_1"
|
||||
respx.get(url).mock(return_value=httpx.Response(200, json={"name": "Súd"}))
|
||||
|
||||
result = await fetch_api_data(icon="", url=url, params={}, remove_keys=["foto", "nonexistent"])
|
||||
|
||||
assert result["name"] == "Súd"
|
||||
|
||||
@respx.mock
|
||||
async def test_log_callback_is_called(self):
|
||||
url = f"{BASE}/v1/sud"
|
||||
respx.get(url).mock(return_value=httpx.Response(200, json={}))
|
||||
|
||||
log_lines = []
|
||||
set_log_callback(lambda line: log_lines.append(line))
|
||||
|
||||
await fetch_api_data(icon="", url=url, params={})
|
||||
|
||||
set_log_callback(None)
|
||||
assert len(log_lines) > 0
|
||||
|
||||
@respx.mock
|
||||
async def test_params_are_passed_in_request(self):
|
||||
url = f"{BASE}/v1/sud"
|
||||
mock = respx.get(url).mock(return_value=httpx.Response(200, json={}))
|
||||
|
||||
await fetch_api_data(icon="", url=url, params={"query": "Bratislava", "size": 10})
|
||||
|
||||
assert mock.called
|
||||
sent_params = dict(mock.calls[0].request.url.params)
|
||||
assert sent_params["query"] == "Bratislava"
|
||||
assert sent_params["size"] == "10"
|
||||
|
||||
@respx.mock
|
||||
async def test_connect_error_returns_error_dict(self):
|
||||
url = f"{BASE}/v1/sud"
|
||||
respx.get(url).mock(side_effect=httpx.ConnectError("Connection refused"))
|
||||
|
||||
result = await fetch_api_data(icon="", url=url, params={})
|
||||
|
||||
assert result["error"] == "request_error"
|
||||
120
testing/tests/test_llm_compare.py
Normal file
120
testing/tests/test_llm_compare.py
Normal file
@ -0,0 +1,120 @@
|
||||
import time
|
||||
import pytest
|
||||
from openai import OpenAI
|
||||
|
||||
from core.config import (
|
||||
OLLAMA_BASE_URL, OLLAMA_API_KEY, LLM_TIMEOUT,
|
||||
OLLAMA_MODELS, OPENAI_MODELS
|
||||
)
|
||||
from core.system_prompt import get_system_prompt
|
||||
|
||||
MODELS = OLLAMA_MODELS | OPENAI_MODELS
|
||||
|
||||
TEST_QUERIES = [
|
||||
{
|
||||
"id": "court_search",
|
||||
"query": "Nájdi súdy v Bratislavskom kraji.",
|
||||
"expected_keywords": ["súd", "bratislava", "kraj"],
|
||||
},
|
||||
{
|
||||
"id": "judge_search",
|
||||
"query": "Vyhľadaj sudcu Novák.",
|
||||
"expected_keywords": ["sudca", "novák", "novak"],
|
||||
},
|
||||
{
|
||||
"id": "no_legal_advice",
|
||||
"query": "Mám spor so zamestnávateľom, čo mám robiť?",
|
||||
"forbidden_keywords": ["musíte", "odporúčam vám podať žalobu", "právne poradenstvo"],
|
||||
"expected_keywords": ["api", "ministerstvo", "nie som právnik", "právny poradca"],
|
||||
},
|
||||
{
|
||||
"id": "slovak_response",
|
||||
"query": "What courts exist in Slovakia?",
|
||||
"expected_keywords": ["súd", "slovensko", "kraj"],
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def ollama_available() -> bool:
|
||||
try:
|
||||
client = OpenAI(base_url=OLLAMA_BASE_URL, api_key=OLLAMA_API_KEY)
|
||||
client.models.list()
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
skip_if_no_ollama = pytest.mark.skipif(
|
||||
not ollama_available(),
|
||||
reason="Ollama is not running"
|
||||
)
|
||||
|
||||
|
||||
def query_model(model: str, user_message: str) -> tuple[str, float]:
|
||||
client = OpenAI(base_url=OLLAMA_BASE_URL, api_key=OLLAMA_API_KEY, timeout=LLM_TIMEOUT)
|
||||
start = time.perf_counter()
|
||||
response = client.chat.completions.create(
|
||||
model=model,
|
||||
messages=[
|
||||
{"role": "system", "content": get_system_prompt(model)},
|
||||
{"role": "user", "content": user_message},
|
||||
],
|
||||
temperature=0.0,
|
||||
max_tokens=512,
|
||||
)
|
||||
elapsed = time.perf_counter() - start
|
||||
text = response.choices[0].message.content or ""
|
||||
return text, elapsed
|
||||
|
||||
|
||||
llm_results: dict[str, list[dict]] = {m: [] for m in MODELS}
|
||||
|
||||
|
||||
@skip_if_no_ollama
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("case", TEST_QUERIES, ids=[c["id"] for c in TEST_QUERIES])
|
||||
class TestLLMResponses:
|
||||
|
||||
def test_response_is_not_empty(self, model, case):
|
||||
text, _ = query_model(model, case["query"])
|
||||
assert len(text.strip()) > 0
|
||||
|
||||
def test_response_in_slovak(self, model, case):
|
||||
text, _ = query_model(model, case["query"])
|
||||
slovak_markers = ["je", "sú", "som", "nie", "súd", "sudca", "kraj", "ale", "alebo", "pre"]
|
||||
assert any(m in text.lower() for m in slovak_markers)
|
||||
|
||||
def test_expected_keywords_present(self, model, case):
|
||||
if "expected_keywords" not in case:
|
||||
pytest.skip("No expected_keywords defined")
|
||||
text, _ = query_model(model, case["query"])
|
||||
assert any(kw.lower() in text.lower() for kw in case["expected_keywords"])
|
||||
|
||||
def test_forbidden_keywords_absent(self, model, case):
|
||||
if "forbidden_keywords" not in case:
|
||||
pytest.skip("No forbidden_keywords defined")
|
||||
text, _ = query_model(model, case["query"])
|
||||
for kw in case["forbidden_keywords"]:
|
||||
assert kw.lower() not in text.lower(), f"Forbidden keyword found: {kw}"
|
||||
|
||||
def test_response_time_under_threshold(self, model, case):
|
||||
_, elapsed = query_model(model, case["query"])
|
||||
assert elapsed < float(LLM_TIMEOUT), f"Response took {elapsed:.1f}s"
|
||||
|
||||
def test_response_length_reasonable(self, model, case):
|
||||
text, _ = query_model(model, case["query"])
|
||||
assert 10 < len(text) < 4000
|
||||
|
||||
|
||||
@skip_if_no_ollama
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
class TestLLMBenchmark:
|
||||
|
||||
def test_collect_benchmark_data(self, model):
|
||||
"""Collects timing data per model — used by charts.py."""
|
||||
times = []
|
||||
for case in TEST_QUERIES:
|
||||
_, elapsed = query_model(model, case["query"])
|
||||
times.append(elapsed)
|
||||
llm_results[model].extend(times)
|
||||
assert len(times) == len(TEST_QUERIES)
|
||||
168
testing/tests/test_schemas.py
Normal file
168
testing/tests/test_schemas.py
Normal file
@ -0,0 +1,168 @@
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from api.schemas import (
|
||||
CourtByID,
|
||||
JudgeByID,
|
||||
AdminProceedingsByID,
|
||||
CourtSearch,
|
||||
JudgeSearch,
|
||||
DecisionSearch,
|
||||
ContractSearch,
|
||||
CivilProceedingsSearch,
|
||||
CourtAutocomplete,
|
||||
JudgeAutocomplete,
|
||||
)
|
||||
|
||||
|
||||
class TestCourtByID:
|
||||
|
||||
def test_digit_gets_prefix(self):
|
||||
assert CourtByID(id="175").id == "sud_175"
|
||||
|
||||
def test_already_prefixed_unchanged(self):
|
||||
assert CourtByID(id="sud_175").id == "sud_175"
|
||||
|
||||
def test_strips_whitespace(self):
|
||||
assert CourtByID(id=" 42 ").id == "sud_42"
|
||||
|
||||
def test_single_digit(self):
|
||||
assert CourtByID(id="1").id == "sud_1"
|
||||
|
||||
def test_large_number(self):
|
||||
assert CourtByID(id="9999").id == "sud_9999"
|
||||
|
||||
def test_non_numeric_string_unchanged(self):
|
||||
assert CourtByID(id="some_string").id == "some_string"
|
||||
|
||||
def test_whitespace_digit_combination(self):
|
||||
assert CourtByID(id=" 7 ").id == "sud_7"
|
||||
|
||||
|
||||
class TestJudgeByID:
|
||||
|
||||
def test_digit_gets_prefix(self):
|
||||
assert JudgeByID(id="1").id == "sudca_1"
|
||||
|
||||
def test_already_prefixed_unchanged(self):
|
||||
assert JudgeByID(id="sudca_99").id == "sudca_99"
|
||||
|
||||
def test_strips_whitespace(self):
|
||||
assert JudgeByID(id=" 7 ").id == "sudca_7"
|
||||
|
||||
def test_large_number(self):
|
||||
assert JudgeByID(id="12345").id == "sudca_12345"
|
||||
|
||||
def test_non_numeric_string_unchanged(self):
|
||||
assert JudgeByID(id="sudca_abc").id == "sudca_abc"
|
||||
|
||||
|
||||
class TestAdminProceedingsByID:
|
||||
|
||||
def test_digit_gets_prefix(self):
|
||||
assert AdminProceedingsByID(id="103").id == "spravneKonanie_103"
|
||||
|
||||
def test_already_prefixed_unchanged(self):
|
||||
assert AdminProceedingsByID(id="spravneKonanie_103").id == "spravneKonanie_103"
|
||||
|
||||
def test_strips_whitespace(self):
|
||||
assert AdminProceedingsByID(id=" 55 ").id == "spravneKonanie_55"
|
||||
|
||||
def test_single_digit(self):
|
||||
assert AdminProceedingsByID(id="5").id == "spravneKonanie_5"
|
||||
|
||||
def test_non_numeric_string_unchanged(self):
|
||||
assert AdminProceedingsByID(id="custom_id").id == "custom_id"
|
||||
|
||||
|
||||
class TestPaginationDefaults:
    """Pagination and sort fields of CourtSearch: defaults and validation."""

    def test_court_search_defaults_are_none(self):
        search = CourtSearch()
        assert search.page is None
        assert search.size is None

    def test_sort_direction_default_asc(self):
        search = CourtSearch()
        assert search.sortDirection == "ASC"

    def test_sort_direction_desc_accepted(self):
        search = CourtSearch(sortDirection="DESC")
        assert search.sortDirection == "DESC"

    def test_sort_direction_invalid_rejected(self):
        with pytest.raises(ValidationError):
            CourtSearch(sortDirection="INVALID")

    def test_page_cannot_be_negative(self):
        with pytest.raises(ValidationError):
            CourtSearch(page=-1)

    def test_size_cannot_be_zero(self):
        with pytest.raises(ValidationError):
            CourtSearch(size=0)

    def test_size_one_accepted(self):
        search = CourtSearch(size=1)
        assert search.size == 1

    def test_page_zero_accepted(self):
        search = CourtSearch(page=0)
        assert search.page == 0
|
||||
|
||||
|
||||
class TestFacetFilters:
    """Facet-filter list fields should be accepted and stored as given."""

    def test_court_search_facet_type_list(self):
        filters = ["Okresný súd", "Krajský súd"]
        search = CourtSearch(typSuduFacetFilter=filters)
        assert len(search.typSuduFacetFilter) == 2

    def test_judge_search_facet_kraj(self):
        search = JudgeSearch(krajFacetFilter=["Bratislavský kraj"])
        assert search.krajFacetFilter == ["Bratislavský kraj"]

    def test_decision_search_forma_filter(self):
        search = DecisionSearch(formaRozhodnutiaFacetFilter=["Rozsudok"])
        assert "Rozsudok" in search.formaRozhodnutiaFacetFilter

    def test_contract_search_typ_dokumentu(self):
        search = ContractSearch(typDokumentuFacetFilter=["ZMLUVA", "DODATOK"])
        assert len(search.typDokumentuFacetFilter) == 2

    def test_civil_proceedings_usek_filter(self):
        search = CivilProceedingsSearch(usekFacetFilter=["C", "O"])
        assert "C" in search.usekFacetFilter
|
||||
|
||||
|
||||
class TestAutocomplete:
    """Autocomplete schemas: limit validation and optional query/guid fields."""

    def test_court_autocomplete_limit_min_one(self):
        with pytest.raises(ValidationError):
            CourtAutocomplete(limit=0)

    def test_court_autocomplete_limit_valid(self):
        ac = CourtAutocomplete(limit=5)
        assert ac.limit == 5

    def test_judge_autocomplete_guid_sud(self):
        ac = JudgeAutocomplete(guidSud="sud_100", limit=10)
        assert ac.guidSud == "sud_100"

    def test_autocomplete_empty_query_accepted(self):
        ac = CourtAutocomplete()
        assert ac.query is None

    def test_autocomplete_query_string(self):
        ac = JudgeAutocomplete(query="Novák")
        assert ac.query == "Novák"
|
||||
|
||||
|
||||
class TestModelDumpExcludeNone:
    """model_dump(exclude_none=True) should drop unset optional fields."""

    def test_excludes_none_fields(self):
        search = CourtSearch(query="Bratislava")
        dumped = search.model_dump(exclude_none=True)
        assert "query" in dumped
        assert "page" not in dumped
        assert "size" not in dumped

    def test_full_params_included(self):
        search = JudgeSearch(query="Novák", page=0, size=20)
        dumped = search.model_dump(exclude_none=True)
        assert dumped["query"] == "Novák"
        assert dumped["page"] == 0
        assert dumped["size"] == 20

    def test_empty_schema_dumps_empty_dict(self):
        dumped = CourtSearch().model_dump(exclude_none=True)
        assert dumped == {}
|
||||
123
testing/tests/test_sys_prompt.py
Normal file
123
testing/tests/test_sys_prompt.py
Normal file
@ -0,0 +1,123 @@
|
||||
import pytest
|
||||
from core.system_prompt import get_system_prompt
|
||||
from core.config import OLLAMA_MODELS, OPENAI_MODELS
|
||||
|
||||
# Sorted so the pytest parametrize ids derived from this collection are
# stable across runs: iterating a raw set varies with hash randomization,
# which breaks test-id stability (e.g. for pytest-xdist or --lf).
MODELS = sorted(OLLAMA_MODELS | OPENAI_MODELS)
|
||||
|
||||
@pytest.fixture(params=MODELS)
def prompt(request):
    """System prompt rendered for each supported model name."""
    model_name = request.param
    return get_system_prompt(model_name)
|
||||
|
||||
|
||||
@pytest.fixture(params=MODELS)
def prompt_lower(request):
    """Lower-cased system prompt, for case-insensitive keyword checks."""
    model_name = request.param
    rendered = get_system_prompt(model_name)
    return rendered.lower()
|
||||
|
||||
|
||||
class TestPromptContainsModelName:
    """The rendered prompt must mention the model it was generated for."""

    def test_model_name_appears_in_prompt(self, prompt, request):
        # The fixture param (the model name) is recoverable via callspec.
        model_name = request.node.callspec.params["prompt"]
        assert model_name in prompt
|
||||
|
||||
|
||||
class TestRequiredSections:
    """The prompt must contain the core structural sections."""

    def test_has_role_section(self, prompt_lower):
        assert "role" in prompt_lower

    def test_has_operational_constraints(self, prompt_lower):
        assert any(k in prompt_lower for k in ("constraint", "not allowed", "do not"))

    def test_has_workflow_steps(self, prompt_lower):
        assert any(k in prompt_lower for k in ("step", "workflow"))

    def test_has_error_recovery(self, prompt_lower):
        assert any(k in prompt_lower for k in ("error", "not found"))

    def test_has_response_format(self, prompt_lower):
        assert any(k in prompt_lower for k in ("format", "response"))
|
||||
|
||||
|
||||
class TestSupportedDomains:
    """Every supported legal-data domain must be mentioned in the prompt."""

    def test_courts_mentioned(self, prompt_lower):
        assert any(k in prompt_lower for k in ("court", "súd", "sud"))

    def test_judges_mentioned(self, prompt_lower):
        assert any(k in prompt_lower for k in ("judge", "sudca", "sudcovia"))

    def test_decisions_mentioned(self, prompt_lower):
        assert any(k in prompt_lower for k in ("decision", "rozhodnut"))

    def test_contracts_mentioned(self, prompt_lower):
        assert any(k in prompt_lower for k in ("contract", "zmluv"))

    def test_civil_proceedings_mentioned(self, prompt_lower):
        assert any(k in prompt_lower for k in ("civil", "obcian", "pojednavan"))

    def test_admin_proceedings_mentioned(self, prompt_lower):
        assert any(k in prompt_lower for k in ("admin", "spravne", "správne"))
|
||||
|
||||
|
||||
class TestConstraints:
    """Operational constraints the prompt must impose on the agent."""

    def test_no_legal_advice_constraint(self, prompt_lower):
        markers = ("legal advisor", "not a lawyer", "legal advice")
        assert any(m in prompt_lower for m in markers)

    def test_api_only_constraint(self, prompt_lower):
        assert "api" in prompt_lower

    def test_slovak_language_requirement(self, prompt_lower):
        assert any(m in prompt_lower for m in ("slovak", "slovensk"))

    def test_no_raw_json_rule(self, prompt_lower):
        assert any(m in prompt_lower for m in ("json", "technical"))

    def test_no_speculate_rule(self, prompt_lower):
        assert any(m in prompt_lower for m in ("speculate", "infer", "gaps"))
|
||||
|
||||
|
||||
class TestPaginationRules:
    """Pagination guidance the prompt must carry."""

    def test_page_starts_at_zero_mentioned(self, prompt_lower):
        assert "page" in prompt_lower
        assert "0" in prompt_lower

    def test_autocomplete_preferred_mentioned(self, prompt_lower):
        assert "autocomplete" in prompt_lower
|
||||
|
||||
|
||||
class TestDateRules:
    """Date-format and date-field guidance the prompt must carry."""

    def test_date_format_dd_mm_yyyy_mentioned(self, prompt):
        # Checking the lower-cased prompt covers both spellings at once.
        assert "dd.mm.yyyy" in prompt.lower()

    def test_civil_date_field_mentioned(self, prompt_lower):
        assert any(k in prompt_lower for k in ("pojednavaniaod", "pojednavania"))

    def test_decision_date_field_mentioned(self, prompt_lower):
        assert any(k in prompt_lower for k in ("vydaniaod", "vydania"))
|
||||
|
||||
|
||||
class TestIDNormalizationRules:
    """The prompt must document the id-prefix conventions."""

    def test_sud_prefix_mentioned(self, prompt_lower):
        assert "sud_" in prompt_lower

    def test_sudca_prefix_mentioned(self, prompt_lower):
        assert "sudca_" in prompt_lower

    def test_spravnekonanie_prefix_mentioned(self, prompt_lower):
        assert any(k in prompt_lower for k in ("spravnekonanie_", "spravne"))
|
||||
|
||||
|
||||
class TestPromptLength:
    """Sanity bounds on the rendered prompt's size."""

    def test_prompt_is_not_empty(self, prompt):
        assert prompt.strip() != ""

    def test_prompt_has_minimum_length(self, prompt):
        size = len(prompt)
        assert size > 500

    def test_prompt_has_reasonable_max_length(self, prompt):
        size = len(prompt)
        assert size < 50_000
|
||||
152
testing/tests/test_tools.py
Normal file
152
testing/tests/test_tools.py
Normal file
@ -0,0 +1,152 @@
|
||||
import pytest
|
||||
import httpx
|
||||
import respx
|
||||
|
||||
from api.fetch_api_data import _cache
|
||||
from api.config import JUSTICE_API_BASE
|
||||
from api.tools import (
|
||||
court_search,
|
||||
court_id,
|
||||
court_autocomplete,
|
||||
judge_search,
|
||||
judge_id,
|
||||
judge_autocomplete,
|
||||
decision_search,
|
||||
contract_search,
|
||||
civil_proceedings_search,
|
||||
admin_proceedings_search,
|
||||
)
|
||||
|
||||
|
||||
# Minimal empty page payload used as the default body for mocked endpoints.
EMPTY_LIST_RESPONSE = {"content": [], "totalElements": 0, "numberOfElements": 0}
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def clear_cache():
    """Empty the API response cache before and after every test.

    Autouse: without this, a cached response from one test could satisfy
    a request in another and skip the respx mock entirely.
    """
    _cache.clear()
    yield
    _cache.clear()
|
||||
|
||||
|
||||
def make_response(body: dict | None = None) -> httpx.Response:
    """Build a 200 JSON response for a mocked route.

    Args:
        body: JSON payload; falsy values fall back to EMPTY_LIST_RESPONSE.
              (Annotation fixed: the default is None, so the type is
              ``dict | None``, not ``dict``.)
    """
    return httpx.Response(200, json=body or EMPTY_LIST_RESPONSE)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
class TestCourtTools:
    """HTTP-routing tests for the court tools (all traffic mocked via respx)."""

    @respx.mock
    async def test_court_search_calls_correct_url(self):
        route = respx.get(f"{JUSTICE_API_BASE}/v1/sud").mock(return_value=make_response())
        await court_search.on_invoke_tool(None, '{"query": "Bratislava"}')
        assert route.called

    @respx.mock
    async def test_court_id_calls_correct_url(self):
        detail = httpx.Response(200, json={"id": "sud_175", "foto": "data"})
        route = respx.get(f"{JUSTICE_API_BASE}/v1/sud/sud_175").mock(return_value=detail)
        await court_id.on_invoke_tool(None, '{"id": "175"}')
        assert route.called

    @respx.mock
    async def test_court_id_removes_foto_key(self):
        detail = httpx.Response(200, json={"name": "Súd", "foto": "base64"})
        respx.get(f"{JUSTICE_API_BASE}/v1/sud/sud_1").mock(return_value=detail)
        tool_output = await court_id.on_invoke_tool(None, '{"id": "1"}')
        assert "foto" not in str(tool_output)

    @respx.mock
    async def test_court_autocomplete_calls_correct_url(self):
        route = respx.get(f"{JUSTICE_API_BASE}/v1/sud/autocomplete").mock(return_value=make_response())
        await court_autocomplete.on_invoke_tool(None, '{"query": "Kraj"}')
        assert route.called
|
||||
|
||||
|
||||
@pytest.mark.asyncio
class TestJudgeTools:
    """HTTP-routing tests for the judge tools (all traffic mocked via respx)."""

    @respx.mock
    async def test_judge_search_calls_correct_url(self):
        route = respx.get(f"{JUSTICE_API_BASE}/v1/sudca").mock(return_value=make_response())
        await judge_search.on_invoke_tool(None, '{"query": "Novák"}')
        assert route.called

    @respx.mock
    async def test_judge_id_normalizes_digit_id(self):
        # A bare "1" must be requested as "sudca_1".
        detail = httpx.Response(200, json={"id": "sudca_1"})
        route = respx.get(f"{JUSTICE_API_BASE}/v1/sudca/sudca_1").mock(return_value=detail)
        await judge_id.on_invoke_tool(None, '{"id": "1"}')
        assert route.called

    @respx.mock
    async def test_judge_autocomplete_passes_guid_sud(self):
        route = respx.get(f"{JUSTICE_API_BASE}/v1/sudca/autocomplete").mock(return_value=make_response())
        await judge_autocomplete.on_invoke_tool(None, '{"query": "Novák", "guidSud": "sud_100"}')
        assert route.called
        sent = dict(route.calls[0].request.url.params)
        assert sent.get("guidSud") == "sud_100"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
class TestDecisionTools:
    """HTTP-routing tests for the decision search tool."""

    @respx.mock
    async def test_decision_search_with_date_range(self):
        route = respx.get(f"{JUSTICE_API_BASE}/v1/rozhodnutie").mock(return_value=make_response())
        payload = '{"vydaniaOd": "01.01.2024", "vydaniaDo": "31.01.2024"}'
        await decision_search.on_invoke_tool(None, payload)
        assert route.called
        sent = dict(route.calls[0].request.url.params)
        assert sent.get("vydaniaOd") == "01.01.2024"

    @respx.mock
    async def test_decision_search_with_guid_sudca(self):
        route = respx.get(f"{JUSTICE_API_BASE}/v1/rozhodnutie").mock(return_value=make_response())
        await decision_search.on_invoke_tool(None, '{"guidSudca": "sudca_1"}')
        sent = dict(route.calls[0].request.url.params)
        assert sent.get("guidSudca") == "sudca_1"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
class TestContractTools:
    """HTTP-routing tests for the contract search tool."""

    @respx.mock
    async def test_contract_search_with_guid_sud(self):
        route = respx.get(f"{JUSTICE_API_BASE}/v1/zmluvy").mock(return_value=make_response())
        await contract_search.on_invoke_tool(None, '{"guidSud": "sud_7"}')
        sent = dict(route.calls[0].request.url.params)
        assert sent.get("guidSud") == "sud_7"

    @respx.mock
    async def test_contract_search_typ_dokumentu_filter(self):
        route = respx.get(f"{JUSTICE_API_BASE}/v1/zmluvy").mock(return_value=make_response())
        await contract_search.on_invoke_tool(None, '{"typDokumentuFacetFilter": ["ZMLUVA"]}')
        assert route.called
|
||||
|
||||
|
||||
@pytest.mark.asyncio
class TestCivilAndAdminTools:
    """HTTP-routing tests for civil- and admin-proceedings search tools."""

    @respx.mock
    async def test_civil_proceedings_search(self):
        mock = respx.get(f"{JUSTICE_API_BASE}/v1/obcianPojednavania").mock(return_value=make_response())
        await civil_proceedings_search.on_invoke_tool(None, '{"query": "test"}')
        assert mock.called

    @respx.mock
    async def test_admin_proceedings_search(self):
        mock = respx.get(f"{JUSTICE_API_BASE}/v1/spravneKonanie").mock(return_value=make_response())
        await admin_proceedings_search.on_invoke_tool(None, '{"query": "test"}')
        assert mock.called

    @respx.mock
    async def test_civil_proceedings_date_params(self):
        mock = respx.get(f"{JUSTICE_API_BASE}/v1/obcianPojednavania").mock(return_value=make_response())
        # Fix: the end-of-range key was misspelled "jednotnavaniaDo";
        # the field used elsewhere in the suite is "pojednavaniaDo",
        # so the original payload never exercised the end-date parameter.
        await civil_proceedings_search.on_invoke_tool(
            None, '{"pojednavaniaOd": "01.01.2024", "pojednavaniaDo": "31.01.2024"}'
        )
        assert mock.called
        # Also verify the start-of-range value actually reaches the API.
        params = dict(mock.calls[0].request.url.params)
        assert params.get("pojednavaniaOd") == "01.01.2024"
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 191 KiB |
294
tests/runner.py
294
tests/runner.py
@ -1,294 +0,0 @@
|
||||
import asyncio
|
||||
import sqlite3
|
||||
import time
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from api.config import JUSTICE_API_BASE
|
||||
import httpx
|
||||
|
||||
# SQLite file holding the test-query corpus (table: test_queries).
DB_PATH = Path(__file__).parent / "test_queries.db"
# Directory where generated PNG reports are written.
OUT_DIR = Path(__file__).parent / "results"
REQUEST_TIMEOUT = 15  # per-request timeout, seconds
CONCURRENT_LIMIT = 5  # max in-flight requests
DELAY_BETWEEN = 0.3  # polite delay before each request, seconds

# Browser-like headers sent with every request.
# NOTE(review): presumably the justice.sk API rejects requests without a
# browser Referer/Origin — confirm this is still required.
HEADERS = {
    "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) Chrome/124.0.0.0 Safari/537.36",
    "Accept": "application/json, text/plain, */*",
    "Accept-Language": "sk-SK,sk;q=0.9",
    "Referer": "https://obcan.justice.sk/",
    "Origin": "https://obcan.justice.sk",
}
|
||||
|
||||
# Maps an ``expected_tool`` label from the DB to (endpoint URL, call type).
# Call types: "search" = list endpoint with a query param, "id" = GET by id,
# "autocomplete" = autocomplete endpoint, "skip" = out of scope (no request).
TOOL_MAP = {
    "judge": (f"{JUSTICE_API_BASE}/v1/sudca","search"),
    "judge_id": (f"{JUSTICE_API_BASE}/v1/sudca","id"),
    "judge_autocomplete": (f"{JUSTICE_API_BASE}/v1/sudca/autocomplete","autocomplete"),
    "court": (f"{JUSTICE_API_BASE}/v1/sud","search"),
    "court_id": (f"{JUSTICE_API_BASE}/v1/sud","id"),
    "court_autocomplete": (f"{JUSTICE_API_BASE}/v1/sud/autocomplete","autocomplete"),
    "decision": (f"{JUSTICE_API_BASE}/v1/rozhodnutie","search"),
    "decision_id": (f"{JUSTICE_API_BASE}/v1/rozhodnutie","id"),
    "decision_autocomplete": (f"{JUSTICE_API_BASE}/v1/rozhodnutie/autocomplete","autocomplete"),
    "contract": (f"{JUSTICE_API_BASE}/v1/zmluvy","search"),
    "contract_id": (f"{JUSTICE_API_BASE}/v1/zmluvy","id"),
    "contract_autocomplete": (f"{JUSTICE_API_BASE}/v1/zmluvy/autocomplete","autocomplete"),
    "civil_proceedings": (f"{JUSTICE_API_BASE}/v1/obcianPojednavania","search"),
    "civil_proceedings_id": (f"{JUSTICE_API_BASE}/v1/obcianPojednavania","id"),
    "civil_proceedings_autocomplete": (f"{JUSTICE_API_BASE}/v1/obcianPojednavania/autocomplete","autocomplete"),
    "admin_proceedings": (f"{JUSTICE_API_BASE}/v1/spravneKonanie","search"),
    "admin_proceedings_id": (f"{JUSTICE_API_BASE}/v1/spravneKonanie","id"),
    "admin_proceedings_autocomplete": (f"{JUSTICE_API_BASE}/v1/spravneKonanie/autocomplete","autocomplete"),
    "mimo_rozsah": (None,"skip"),
}
|
||||
|
||||
|
||||
def load_queries(limit: int | None = None) -> list[dict]:
    """Load test-query rows from the SQLite DB, ordered by id.

    Args:
        limit: optional cap on the number of rows returned.

    Returns:
        One plain dict per row of the ``test_queries`` table.
    """
    conn = sqlite3.connect(DB_PATH)
    try:
        conn.row_factory = sqlite3.Row
        sql = "SELECT * FROM test_queries ORDER BY id"
        params: tuple = ()
        if limit is not None:
            # Bind LIMIT as a parameter instead of f-string interpolation:
            # keeps the SQL static and rejects non-integer values cleanly.
            sql += " LIMIT ?"
            params = (limit,)
        rows = conn.execute(sql, params).fetchall()
    finally:
        # Close even if the query fails (the original leaked on error).
        conn.close()
    return [dict(r) for r in rows]
|
||||
|
||||
|
||||
def pick_keyword(row: dict) -> str:
    """Choose a search keyword for a query row.

    Preference order: the first expected keyword (when longer than two
    characters), then the first two 4+-letter words of the query text,
    then the query's first 20 characters.
    """
    keywords = row.get("expected_keywords") or ""
    candidate = keywords.split(",")[0].strip()
    if len(candidate) > 2:
        return candidate
    long_words = [word for word in row["query_sk"].split() if len(word) > 3]
    fallback = " ".join(long_words[:2])
    return fallback or row["query_sk"][:20]
|
||||
|
||||
|
||||
def count_results(data) -> int | None:
|
||||
if isinstance(data, list):
|
||||
return len(data)
|
||||
if isinstance(data, dict):
|
||||
for key in ("totalElements", "total", "count", "numberOfElements"):
|
||||
if key in data:
|
||||
return int(data[key])
|
||||
if "content" in data:
|
||||
return len(data["content"])
|
||||
return 1
|
||||
return None
|
||||
|
||||
|
||||
async def validate_one(
    client: httpx.AsyncClient,
    semaphore: asyncio.Semaphore,
    row: dict,
    idx: int,
    total: int,
) -> dict:
    """Fire one live API request for a test-query row and classify the outcome.

    The endpoint and call style come from TOOL_MAP via the row's
    ``expected_tool``. Concurrency is bounded by ``semaphore``; ``idx`` and
    ``total`` are only used for progress output.

    Returns a result dict whose ``status`` is one of: OK, EMPTY, SKIP,
    NOT_FOUND, FORBIDDEN, TIMEOUT, HTTP_ERROR, CONN_ERROR, ERROR.
    """
    result = {
        "id": row["id"],
        "category": row["category"],
        "difficulty": row["difficulty"],
        "tool": row["expected_tool"],
        "query": row["query_sk"],
        "status": "pending",
        "http_code": None,
        "result_count": None,
        "duration_ms": None,
        "error": None,
    }

    # Unknown tools fall through to "skip" rather than raising.
    endpoint, call_type = TOOL_MAP.get(row["expected_tool"], (None, "skip"))

    print(f"Request {idx}/{total}")

    if call_type == "skip":
        result["status"] = "SKIP"
        return result

    # Build URL + query params per call style.
    if call_type == "id":
        # For id lookups the first expected keyword is treated as the id.
        id_val = (row.get("expected_keywords") or "").split(",")[0].strip()
        url, params = f"{endpoint}/{id_val}", {}
    elif call_type == "autocomplete":
        url, params = endpoint, {"query": pick_keyword(row), "limit": 5}
    else:
        url, params = endpoint, {"query": pick_keyword(row), "size": 5}

    async with semaphore:
        # Rate-limit: small delay before every request under the semaphore.
        await asyncio.sleep(DELAY_BETWEEN)
        t0 = time.perf_counter()
        try:
            resp = await client.get(url, params=params, headers=HEADERS, timeout=REQUEST_TIMEOUT)
            result["duration_ms"] = round((time.perf_counter() - t0) * 1000)
            result["http_code"] = resp.status_code

            if resp.status_code == 200:
                try:
                    cnt = count_results(resp.json())
                    result["result_count"] = cnt
                    result["status"] = "OK" if cnt and cnt > 0 else "EMPTY"
                except Exception:
                    # Non-JSON 200 body still counts as a successful call.
                    result["status"] = "OK"
            elif resp.status_code == 404:
                result["status"] = "NOT_FOUND"
            elif resp.status_code == 403:
                result["status"] = "FORBIDDEN"
            else:
                result["status"] = "HTTP_ERROR"
                result["error"] = f"HTTP {resp.status_code}"

        except httpx.TimeoutException:
            result["status"] = "TIMEOUT"
            # No measured duration on timeout; record the configured ceiling.
            result["duration_ms"] = REQUEST_TIMEOUT * 1000
        except httpx.ConnectError as e:
            result["status"] = "CONN_ERROR"
            result["error"] = str(e)
        except Exception as e:
            result["status"] = "ERROR"
            result["error"] = str(e)

    return result
|
||||
|
||||
|
||||
async def run_all(rows: list[dict]) -> list[dict]:
    """Validate every query row concurrently, bounded by CONCURRENT_LIMIT.

    Returns one result dict per input row, in input order.
    """
    sem = asyncio.Semaphore(CONCURRENT_LIMIT)
    total = len(rows)
    async with httpx.AsyncClient(follow_redirects=True) as client:
        pending = [
            validate_one(client, sem, row, position + 1, total)
            for position, row in enumerate(rows)
        ]
        gathered = await asyncio.gather(*pending)
    return list(gathered)
|
||||
|
||||
|
||||
def generate_charts(results: list[dict], ts: str) -> Path:
    """Render a 2x2 PNG validation report and return its path.

    Panels: overall status pie, per-category status bars, response-time
    histogram, per-difficulty status bars. ``ts`` is the timestamp string
    used in the title and the output filename. Writes to OUT_DIR.
    """
    # matplotlib is imported lazily (and with the Agg backend) so the
    # runner has no GUI dependency and no import cost unless charts are made.
    import matplotlib
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt
    import matplotlib.patches as mpatches
    import numpy as np
    from collections import Counter, defaultdict

    STATUS_COLORS = {
        "OK": "#1D9E75",
        "EMPTY": "#EF9F27",
        "SKIP": "#888780",
        "NOT_FOUND": "#E24B4A",
        "FORBIDDEN": "#D85A30",
        "TIMEOUT": "#D4537E",
        "HTTP_ERROR": "#E24B4A",
        "CONN_ERROR": "#E24B4A",
        "ERROR": "#E24B4A",
    }
    DIFF_COLORS = {"easy": "#1D9E75", "medium": "#EF9F27", "hard": "#E24B4A"}

    OUT_DIR.mkdir(parents=True, exist_ok=True)

    # Only statuses that actually occur, in STATUS_COLORS order (stable legend).
    used_statuses = [s for s in STATUS_COLORS if any(r["status"] == s for r in results)]

    fig = plt.figure(figsize=(18, 18), facecolor="#FAFAF8")
    fig.suptitle(
        f"Validation Report - Legal AI Agent\n{ts.replace('_', ' ')}",
        fontsize=15, fontweight="bold", y=0.98, color="#2C2C2A"
    )
    # 2 rows: [pie + category] | [histogram + difficulty]
    gs = fig.add_gridspec(2, 2, hspace=0.45, wspace=0.35,
                          top=0.94, bottom=0.04, left=0.07, right=0.97)

    def styled(ax, title):
        # Shared axis cosmetics for the bar/histogram panels.
        ax.set_title(title, fontsize=12, fontweight="bold", color="#2C2C2A")
        ax.set_facecolor("#F9F9F7")
        ax.spines[["top", "right"]].set_visible(False)
        ax.grid(axis="y", alpha=0.3, linestyle="--")

    # 1. Pie — overall results
    ax1 = fig.add_subplot(gs[0, 0])
    cnt = Counter(r["status"] for r in results)
    wedges, texts, pcts = ax1.pie(
        cnt.values(),
        labels=cnt.keys(),
        colors=[STATUS_COLORS.get(k, "#888780") for k in cnt],
        autopct="%1.1f%%", startangle=90, pctdistance=0.82,
        wedgeprops={"edgecolor": "white", "linewidth": 1.5},
    )
    for t in texts: t.set_fontsize(10)
    for t in pcts: t.set_fontsize(9); t.set_color("white"); t.set_fontweight("bold")
    ax1.set_title("Overall Results", fontsize=12, fontweight="bold", color="#2C2C2A")

    # 2. Bar — results by category
    ax2 = fig.add_subplot(gs[0, 1])
    categories = sorted(set(r["category"] for r in results))
    cat_data = defaultdict(Counter)
    for r in results:
        cat_data[r["category"]][r["status"]] += 1

    # Grouped bars: one sub-bar per status within each category slot.
    x, w = np.arange(len(categories)), 0.8 / max(len(used_statuses), 1)
    for i, status in enumerate(used_statuses):
        vals = [cat_data[c][status] for c in categories]
        offset = (i - len(used_statuses) / 2 + 0.5) * w
        ax2.bar(x + offset, vals, w, label=status,
                color=STATUS_COLORS.get(status, "#888780"),
                edgecolor="white", linewidth=0.5)

    ax2.set_xticks(x)
    ax2.set_xticklabels([c[:13] for c in categories], rotation=35, ha="right", fontsize=8)
    ax2.set_ylabel("Count", fontsize=10)
    ax2.legend(fontsize=7, loc="upper right")
    styled(ax2, "Results by Category")

    # 3. Histogram — response times
    ax4 = fig.add_subplot(gs[1, 0])
    times = [r["duration_ms"] for r in results
             if r["duration_ms"] and r["status"] != "SKIP"]
    if times:
        ax4.hist(times, bins=20, color="#378ADD", edgecolor="white", linewidth=0.7, alpha=0.85)
        avg = sum(times) / len(times)
        ax4.axvline(avg, color="#E24B4A", linewidth=1.5, linestyle="--",
                    label=f"Avg: {avg:.0f} ms")
        ax4.set_xlabel("ms", fontsize=10)
        ax4.set_ylabel("Requests", fontsize=10)
        ax4.legend(fontsize=9)
    styled(ax4, "API Response Time Distribution")

    # 4. Bar — results by difficulty
    ax5 = fig.add_subplot(gs[1, 1])
    difficulties = ["easy", "medium", "hard"]
    diff_data = defaultdict(Counter)
    for r in results:
        diff_data[r["difficulty"]][r["status"]] += 1

    x2, w2 = np.arange(3), 0.8 / max(len(used_statuses), 1)
    for i, status in enumerate(used_statuses):
        vals = [diff_data[d][status] for d in difficulties]
        offset = (i - len(used_statuses) / 2 + 0.5) * w2
        ax5.bar(x2 + offset, vals, w2, label=status,
                color=STATUS_COLORS.get(status, "#888780"),
                edgecolor="white", linewidth=0.5)

    ax5.set_xticks(x2)
    ax5.set_xticklabels(difficulties, fontsize=10)
    ax5.set_ylabel("Count", fontsize=10)
    ax5.legend(fontsize=8)
    styled(ax5, "Results by Difficulty")


    out = OUT_DIR / f"report_{ts}.png"
    fig.savefig(out, dpi=150, bbox_inches="tight", facecolor=fig.get_facecolor())
    plt.close(fig)
    return out
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Optional CLI argument: cap the number of queries to validate.
    limit = int(sys.argv[1]) if len(sys.argv) > 1 else None

    rows = load_queries(limit)
    # Timestamp used both in the report title and the output filename.
    ts = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

    print(f"Start Test Validation")

    results = asyncio.run(run_all(rows))

    print(f"End Test Validation")

    out = generate_charts(results, ts)
    print(f"Test image is saved: {out}")
|
||||
Binary file not shown.
Loading…
Reference in New Issue
Block a user