diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/results/report_2026-03-16_06-26-14.png b/tests/results/report_2026-03-16_06-26-14.png
new file mode 100644
index 0000000..d1dc63e
Binary files /dev/null and b/tests/results/report_2026-03-16_06-26-14.png differ
diff --git a/tests/runner.py b/tests/runner.py
new file mode 100644
index 0000000..36d92ac
--- /dev/null
+++ b/tests/runner.py
@@ -0,0 +1,294 @@
+import asyncio
+import sqlite3
+import time
+import sys
+from datetime import datetime
+from pathlib import Path
+
+sys.path.insert(0, str(Path(__file__).parent.parent))
+
+from api.config import JUSTICE_API_BASE
+import httpx
+
+DB_PATH = Path(__file__).parent / "test_queries.db"
+OUT_DIR = Path(__file__).parent / "results"
+REQUEST_TIMEOUT = 15
+CONCURRENT_LIMIT = 5
+DELAY_BETWEEN = 0.3
+
+HEADERS = {
+    "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) Chrome/124.0.0.0 Safari/537.36",
+    "Accept": "application/json, text/plain, */*",
+    "Accept-Language": "sk-SK,sk;q=0.9",
+    "Referer": "https://obcan.justice.sk/",
+    "Origin": "https://obcan.justice.sk",
+}
+
+# Maps each expected_tool value from the DB to its endpoint and call style.
+TOOL_MAP = {
+    "judge": (f"{JUSTICE_API_BASE}/v1/sudca", "search"),
+    "judge_id": (f"{JUSTICE_API_BASE}/v1/sudca", "id"),
+    "judge_autocomplete": (f"{JUSTICE_API_BASE}/v1/sudca/autocomplete", "autocomplete"),
+    "court": (f"{JUSTICE_API_BASE}/v1/sud", "search"),
+    "court_id": (f"{JUSTICE_API_BASE}/v1/sud", "id"),
+    "court_autocomplete": (f"{JUSTICE_API_BASE}/v1/sud/autocomplete", "autocomplete"),
+    "decision": (f"{JUSTICE_API_BASE}/v1/rozhodnutie", "search"),
+    "decision_id": (f"{JUSTICE_API_BASE}/v1/rozhodnutie", "id"),
+    "decision_autocomplete": (f"{JUSTICE_API_BASE}/v1/rozhodnutie/autocomplete", "autocomplete"),
+    "contract": (f"{JUSTICE_API_BASE}/v1/zmluvy", "search"),
+    "contract_id": (f"{JUSTICE_API_BASE}/v1/zmluvy", "id"),
+    "contract_autocomplete": (f"{JUSTICE_API_BASE}/v1/zmluvy/autocomplete", "autocomplete"),
+    "civil_proceedings": (f"{JUSTICE_API_BASE}/v1/obcianPojednavania", "search"),
+    "civil_proceedings_id": (f"{JUSTICE_API_BASE}/v1/obcianPojednavania", "id"),
+    "civil_proceedings_autocomplete": (f"{JUSTICE_API_BASE}/v1/obcianPojednavania/autocomplete", "autocomplete"),
+    "admin_proceedings": (f"{JUSTICE_API_BASE}/v1/spravneKonanie", "search"),
+    "admin_proceedings_id": (f"{JUSTICE_API_BASE}/v1/spravneKonanie", "id"),
+    "admin_proceedings_autocomplete": (f"{JUSTICE_API_BASE}/v1/spravneKonanie/autocomplete", "autocomplete"),
+    "mimo_rozsah": (None, "skip"),
+}
+
+
+def load_queries(limit: int | None = None) -> list[dict]:
+    conn = sqlite3.connect(DB_PATH)
+    conn.row_factory = sqlite3.Row
+    sql = "SELECT * FROM test_queries ORDER BY id"
+    if limit:
+        sql += f" LIMIT {limit}"
+    rows = conn.execute(sql).fetchall()
+    conn.close()
+    return [dict(r) for r in rows]
+
+
+def pick_keyword(row: dict) -> str:
+    """Pick a search keyword: first expected keyword, else the query's longer words."""
+    first = (row.get("expected_keywords") or "").split(",")[0].strip()
+    if len(first) > 2:
+        return first
+    words = [w for w in row["query_sk"].split() if len(w) > 3]
+    return " ".join(words[:2]) or row["query_sk"][:20]
+
+
+def count_results(data) -> int | None:
+    """Best-effort result count across the list/paged/object response shapes."""
+    if isinstance(data, list):
+        return len(data)
+    if isinstance(data, dict):
+        for key in ("totalElements", "total", "count", "numberOfElements"):
+            if key in data:
+                return int(data[key])
+        if "content" in data:
+            return len(data["content"])
+        return 1
+    return None
+
+
+async def validate_one(
+    client: httpx.AsyncClient,
+    semaphore: asyncio.Semaphore,
+    row: dict,
+    idx: int,
+    total: int,
+) -> dict:
+    result = {
+        "id": row["id"],
row["id"], + "category": row["category"], + "difficulty": row["difficulty"], + "tool": row["expected_tool"], + "query": row["query_sk"], + "status": "pending", + "http_code": None, + "result_count": None, + "duration_ms": None, + "error": None, + } + + endpoint, call_type = TOOL_MAP.get(row["expected_tool"], (None, "skip")) + + print(f"Request {idx}/{total}") + + if call_type == "skip": + result["status"] = "SKIP" + return result + + if call_type == "id": + id_val = (row.get("expected_keywords") or "").split(",")[0].strip() + url, params = f"{endpoint}/{id_val}", {} + elif call_type == "autocomplete": + url, params = endpoint, {"query": pick_keyword(row), "limit": 5} + else: + url, params = endpoint, {"query": pick_keyword(row), "size": 5} + + async with semaphore: + await asyncio.sleep(DELAY_BETWEEN) + t0 = time.perf_counter() + try: + resp = await client.get(url, params=params, headers=HEADERS, timeout=REQUEST_TIMEOUT) + result["duration_ms"] = round((time.perf_counter() - t0) * 1000) + result["http_code"] = resp.status_code + + if resp.status_code == 200: + try: + cnt = count_results(resp.json()) + result["result_count"] = cnt + result["status"] = "OK" if cnt and cnt > 0 else "EMPTY" + except Exception: + result["status"] = "OK" + elif resp.status_code == 404: + result["status"] = "NOT_FOUND" + elif resp.status_code == 403: + result["status"] = "FORBIDDEN" + else: + result["status"] = "HTTP_ERROR" + result["error"] = f"HTTP {resp.status_code}" + + except httpx.TimeoutException: + result["status"] = "TIMEOUT" + result["duration_ms"] = REQUEST_TIMEOUT * 1000 + except httpx.ConnectError as e: + result["status"] = "CONN_ERROR" + result["error"] = str(e) + except Exception as e: + result["status"] = "ERROR" + result["error"] = str(e) + + return result + + +async def run_all(rows: list[dict]) -> list[dict]: + semaphore = asyncio.Semaphore(CONCURRENT_LIMIT) + async with httpx.AsyncClient(follow_redirects=True) as client: + tasks = [ + validate_one(client, semaphore, row, i + 1, len(rows)) + for i, row in enumerate(rows) + ] + return list(await asyncio.gather(*tasks)) + + +def generate_charts(results: list[dict], ts: str) -> Path: + import matplotlib + matplotlib.use("Agg") + import matplotlib.pyplot as plt + import matplotlib.patches as mpatches + import numpy as np + from collections import Counter, defaultdict + + STATUS_COLORS = { + "OK": "#1D9E75", + "EMPTY": "#EF9F27", + "SKIP": "#888780", + "NOT_FOUND": "#E24B4A", + "FORBIDDEN": "#D85A30", + "TIMEOUT": "#D4537E", + "HTTP_ERROR": "#E24B4A", + "CONN_ERROR": "#E24B4A", + "ERROR": "#E24B4A", + } + DIFF_COLORS = {"easy": "#1D9E75", "medium": "#EF9F27", "hard": "#E24B4A"} + + OUT_DIR.mkdir(parents=True, exist_ok=True) + + used_statuses = [s for s in STATUS_COLORS if any(r["status"] == s for r in results)] + + fig = plt.figure(figsize=(18, 18), facecolor="#FAFAF8") + fig.suptitle( + f"Validation Report - Legal AI Agent\n{ts.replace('_', ' ')}", + fontsize=15, fontweight="bold", y=0.98, color="#2C2C2A" + ) + # 2 rows: [pie + category] | [histogram + difficulty] + gs = fig.add_gridspec(2, 2, hspace=0.45, wspace=0.35, + top=0.94, bottom=0.04, left=0.07, right=0.97) + + def styled(ax, title): + ax.set_title(title, fontsize=12, fontweight="bold", color="#2C2C2A") + ax.set_facecolor("#F9F9F7") + ax.spines[["top", "right"]].set_visible(False) + ax.grid(axis="y", alpha=0.3, linestyle="--") + + # 1. 
+    ax1 = fig.add_subplot(gs[0, 0])
+    cnt = Counter(r["status"] for r in results)
+    wedges, texts, pcts = ax1.pie(
+        cnt.values(),
+        labels=cnt.keys(),
+        colors=[STATUS_COLORS.get(k, "#888780") for k in cnt],
+        autopct="%1.1f%%", startangle=90, pctdistance=0.82,
+        wedgeprops={"edgecolor": "white", "linewidth": 1.5},
+    )
+    for t in texts: t.set_fontsize(10)
+    for t in pcts: t.set_fontsize(9); t.set_color("white"); t.set_fontweight("bold")
+    ax1.set_title("Overall Results", fontsize=12, fontweight="bold", color="#2C2C2A")
+
+    # 2. Bar — results by category
+    ax2 = fig.add_subplot(gs[0, 1])
+    categories = sorted(set(r["category"] for r in results))
+    cat_data = defaultdict(Counter)
+    for r in results:
+        cat_data[r["category"]][r["status"]] += 1
+
+    x, w = np.arange(len(categories)), 0.8 / max(len(used_statuses), 1)
+    for i, status in enumerate(used_statuses):
+        vals = [cat_data[c][status] for c in categories]
+        offset = (i - len(used_statuses) / 2 + 0.5) * w
+        ax2.bar(x + offset, vals, w, label=status,
+                color=STATUS_COLORS.get(status, "#888780"),
+                edgecolor="white", linewidth=0.5)
+
+    ax2.set_xticks(x)
+    ax2.set_xticklabels([c[:13] for c in categories], rotation=35, ha="right", fontsize=8)
+    ax2.set_ylabel("Count", fontsize=10)
+    ax2.legend(fontsize=7, loc="upper right")
+    styled(ax2, "Results by Category")
+
+    # 3. Histogram — response times
+    ax3 = fig.add_subplot(gs[1, 0])
+    times = [r["duration_ms"] for r in results
+             if r["duration_ms"] and r["status"] != "SKIP"]
+    if times:
+        ax3.hist(times, bins=20, color="#378ADD", edgecolor="white", linewidth=0.7, alpha=0.85)
+        avg = sum(times) / len(times)
+        ax3.axvline(avg, color="#E24B4A", linewidth=1.5, linestyle="--",
+                    label=f"Avg: {avg:.0f} ms")
+        ax3.set_xlabel("ms", fontsize=10)
+        ax3.set_ylabel("Requests", fontsize=10)
+        ax3.legend(fontsize=9)
+    styled(ax3, "API Response Time Distribution")
+
+    # 4. Bar — results by difficulty
+    ax4 = fig.add_subplot(gs[1, 1])
+    difficulties = ["easy", "medium", "hard"]
+    diff_data = defaultdict(Counter)
+    for r in results:
+        diff_data[r["difficulty"]][r["status"]] += 1
+
+    x2, w2 = np.arange(3), 0.8 / max(len(used_statuses), 1)
+    for i, status in enumerate(used_statuses):
+        vals = [diff_data[d][status] for d in difficulties]
+        offset = (i - len(used_statuses) / 2 + 0.5) * w2
+        ax4.bar(x2 + offset, vals, w2, label=status,
+                color=STATUS_COLORS.get(status, "#888780"),
+                edgecolor="white", linewidth=0.5)
+
+    ax4.set_xticks(x2)
+    ax4.set_xticklabels(difficulties, fontsize=10)
+    ax4.set_ylabel("Count", fontsize=10)
+    ax4.legend(fontsize=8)
+    styled(ax4, "Results by Difficulty")
+
+    out = OUT_DIR / f"report_{ts}.png"
+    fig.savefig(out, dpi=150, bbox_inches="tight", facecolor=fig.get_facecolor())
+    plt.close(fig)
+    return out
+
+
+if __name__ == "__main__":
+    limit = int(sys.argv[1]) if len(sys.argv) > 1 else None
+
+    rows = load_queries(limit)
+    ts = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
+
+    print("Start Test Validation")
+
+    results = asyncio.run(run_all(rows))
+
+    print("End Test Validation")
+
+    out = generate_charts(results, ts)
+    print(f"Report saved: {out}")
\ No newline at end of file
diff --git a/tests/test_queries.db b/tests/test_queries.db
new file mode 100644
index 0000000..d2b25a1
Binary files /dev/null and b/tests/test_queries.db differ
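Since `tests/test_queries.db` is committed as a binary blob, the schema the runner expects is only implied by the code. The seed script below is a minimal sketch reconstructed from the columns `load_queries` and `validate_one` actually read (`id`, `category`, `difficulty`, `expected_tool`, `query_sk`, `expected_keywords`); the real database may carry extra columns or constraints, and the sample row is invented for illustration.

```python
# seed_queries.py - hypothetical helper; not part of the diff above.
import sqlite3
from pathlib import Path

DB_PATH = Path("tests") / "test_queries.db"

conn = sqlite3.connect(DB_PATH)
# Columns mirror what runner.py reads; the constraints are assumptions.
conn.execute("""
    CREATE TABLE IF NOT EXISTS test_queries (
        id INTEGER PRIMARY KEY,
        category TEXT NOT NULL,
        difficulty TEXT NOT NULL CHECK (difficulty IN ('easy', 'medium', 'hard')),
        expected_tool TEXT NOT NULL,
        query_sk TEXT NOT NULL,
        expected_keywords TEXT
    )
""")
conn.execute(
    "INSERT INTO test_queries (category, difficulty, expected_tool, query_sk, expected_keywords) "
    "VALUES (?, ?, ?, ?, ?)",
    ("courts", "easy", "court_autocomplete", "Okresný súd Bratislava I", "Bratislava"),
)
conn.commit()
conn.close()
```

With at least one row in place, the runner can be exercised end to end via `python tests/runner.py`, or `python tests/runner.py 25` to cap the number of queries.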
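The throttling in `validate_one` is worth calling out as a reusable pattern: tasks are created eagerly and gathered, but a shared `asyncio.Semaphore` caps in-flight requests at `CONCURRENT_LIMIT`, and a fixed sleep inside the guarded section keeps per-slot spacing at `DELAY_BETWEEN` even when the semaphore is saturated. A self-contained sketch of the same idea, with placeholder URLs:

```python
# Standalone sketch of the semaphore-plus-delay throttle used in runner.py.
# The URLs are placeholders, not real endpoints.
import asyncio
import httpx

async def fetch(client: httpx.AsyncClient, sem: asyncio.Semaphore, url: str) -> int:
    async with sem:                    # at most 5 requests in flight
        await asyncio.sleep(0.3)       # per-request spacing (DELAY_BETWEEN)
        resp = await client.get(url, timeout=15)
        return resp.status_code

async def main() -> None:
    sem = asyncio.Semaphore(5)         # mirrors CONCURRENT_LIMIT
    urls = [f"https://example.com/item/{i}" for i in range(20)]
    async with httpx.AsyncClient(follow_redirects=True) as client:
        codes = await asyncio.gather(*(fetch(client, sem, u) for u in urls))
    print(codes)

if __name__ == "__main__":
    asyncio.run(main())
```

Sleeping inside the `async with` block, rather than before it, is what ties the delay to the concurrency cap; moving the sleep outside would let all tasks wait in parallel and then hit the API in a burst.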
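Finally, `count_results` is the piece most likely to break silently if the portal changes its pagination envelope, so a few assertions help pin down the shapes it currently understands. The payloads below are illustrative, not captured API responses; importing `tests.runner` also imports `api.config`, so this assumes it runs from the repository root:

```python
# Illustrative sanity checks for count_results (made-up payloads).
from tests.runner import count_results

assert count_results([1, 2, 3]) == 3                    # bare JSON array
assert count_results({"totalElements": 42}) == 42       # Spring-style page metadata
assert count_results({"content": [{"id": 1}]}) == 1     # page without a total field
assert count_results({"nazov": "Okresný súd"}) == 1     # single entity counts as one hit
assert count_results("unexpected body") is None         # anything else is uncountable
```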