add validation test
This commit is contained in:
parent
397cc2158e
commit
18d7980582
0
tests/__init__.py
Normal file
0
tests/__init__.py
Normal file
BIN
tests/results/report_2026-03-16_06-26-14.png
Normal file
BIN
tests/results/report_2026-03-16_06-26-14.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 191 KiB |
294
tests/runner.py
Normal file
294
tests/runner.py
Normal file
@ -0,0 +1,294 @@
|
||||
import asyncio
|
||||
import sqlite3
|
||||
import time
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from api.config import JUSTICE_API_BASE
|
||||
import httpx
|
||||
|
||||
# --- Runner configuration -------------------------------------------------
DB_PATH = Path(__file__).parent / "test_queries.db"  # SQLite db holding the test queries
OUT_DIR = Path(__file__).parent / "results"          # report PNGs are written here
REQUEST_TIMEOUT = 15    # per-request timeout, seconds
CONCURRENT_LIMIT = 5    # max in-flight requests (semaphore size)
DELAY_BETWEEN = 0.3     # fixed delay before each request, seconds (politeness throttle)

# Browser-like headers sent with every request to obcan.justice.sk.
HEADERS = {
    "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) Chrome/124.0.0.0 Safari/537.36",
    "Accept": "application/json, text/plain, */*",
    "Accept-Language": "sk-SK,sk;q=0.9",
    "Referer": "https://obcan.justice.sk/",
    "Origin": "https://obcan.justice.sk",
}

# Maps a row's expected_tool name to (endpoint URL, call type).
# Call types: "search" (keyword query), "id" (path-segment lookup),
# "autocomplete" (suggestion endpoint), "skip" (out of scope, no request made).
TOOL_MAP = {
    "judge": (f"{JUSTICE_API_BASE}/v1/sudca","search"),
    "judge_id": (f"{JUSTICE_API_BASE}/v1/sudca","id"),
    "judge_autocomplete": (f"{JUSTICE_API_BASE}/v1/sudca/autocomplete","autocomplete"),
    "court": (f"{JUSTICE_API_BASE}/v1/sud","search"),
    "court_id": (f"{JUSTICE_API_BASE}/v1/sud","id"),
    "court_autocomplete": (f"{JUSTICE_API_BASE}/v1/sud/autocomplete","autocomplete"),
    "decision": (f"{JUSTICE_API_BASE}/v1/rozhodnutie","search"),
    "decision_id": (f"{JUSTICE_API_BASE}/v1/rozhodnutie","id"),
    "decision_autocomplete": (f"{JUSTICE_API_BASE}/v1/rozhodnutie/autocomplete","autocomplete"),
    "contract": (f"{JUSTICE_API_BASE}/v1/zmluvy","search"),
    "contract_id": (f"{JUSTICE_API_BASE}/v1/zmluvy","id"),
    "contract_autocomplete": (f"{JUSTICE_API_BASE}/v1/zmluvy/autocomplete","autocomplete"),
    "civil_proceedings": (f"{JUSTICE_API_BASE}/v1/obcianPojednavania","search"),
    "civil_proceedings_id": (f"{JUSTICE_API_BASE}/v1/obcianPojednavania","id"),
    "civil_proceedings_autocomplete": (f"{JUSTICE_API_BASE}/v1/obcianPojednavania/autocomplete","autocomplete"),
    "admin_proceedings": (f"{JUSTICE_API_BASE}/v1/spravneKonanie","search"),
    "admin_proceedings_id": (f"{JUSTICE_API_BASE}/v1/spravneKonanie","id"),
    "admin_proceedings_autocomplete": (f"{JUSTICE_API_BASE}/v1/spravneKonanie/autocomplete","autocomplete"),
    "mimo_rozsah": (None,"skip"),
}
|
||||
|
||||
|
||||
def load_queries(limit: int = None) -> list[dict]:
    """Load test-query rows from the SQLite database, ordered by id.

    Args:
        limit: Optional cap on the number of rows returned; falsy values
            (None, 0) mean "no limit", matching the CLI behavior.

    Returns:
        All rows of ``test_queries`` as plain dicts.
    """
    sql = "SELECT * FROM test_queries ORDER BY id"
    params: tuple = ()
    if limit:
        # Parameterized LIMIT instead of f-string interpolation — avoids
        # SQL injection if a non-int ever reaches this function.
        sql += " LIMIT ?"
        params = (limit,)
    conn = sqlite3.connect(DB_PATH)
    try:
        conn.row_factory = sqlite3.Row
        rows = conn.execute(sql, params).fetchall()
    finally:
        # Ensure the connection is released even if the query raises.
        conn.close()
    return [dict(r) for r in rows]
|
||||
|
||||
|
||||
def pick_keyword(row: dict) -> str:
    """Derive a search keyword for a test-query row.

    Prefers the first entry of ``expected_keywords`` when it is longer than
    two characters; otherwise falls back to the first two words of
    ``query_sk`` longer than three characters, or finally to the first 20
    characters of the query itself.
    """
    keywords = row.get("expected_keywords") or ""
    candidate = keywords.split(",")[0].strip()
    if len(candidate) > 2:
        return candidate
    long_words = [token for token in row["query_sk"].split() if len(token) > 3]
    fallback = " ".join(long_words[:2])
    return fallback if fallback else row["query_sk"][:20]
|
||||
|
||||
|
||||
def count_results(data) -> int | None:
|
||||
if isinstance(data, list):
|
||||
return len(data)
|
||||
if isinstance(data, dict):
|
||||
for key in ("totalElements", "total", "count", "numberOfElements"):
|
||||
if key in data:
|
||||
return int(data[key])
|
||||
if "content" in data:
|
||||
return len(data["content"])
|
||||
return 1
|
||||
return None
|
||||
|
||||
|
||||
async def validate_one(
    client: httpx.AsyncClient,
    semaphore: asyncio.Semaphore,
    row: dict,
    idx: int,
    total: int,
) -> dict:
    """Fire one validation request for a test-query row and classify the outcome.

    Args:
        client: Shared async HTTP client.
        semaphore: Concurrency gate limiting in-flight requests.
        row: A test_queries row (needs id/category/difficulty/expected_tool/
            query_sk, optionally expected_keywords).
        idx: 1-based position of this request (progress printout only).
        total: Total number of requests (progress printout only).

    Returns:
        A result dict with status one of: SKIP, OK, EMPTY, NOT_FOUND,
        FORBIDDEN, HTTP_ERROR, TIMEOUT, CONN_ERROR, ERROR.
    """
    result = {
        "id": row["id"],
        "category": row["category"],
        "difficulty": row["difficulty"],
        "tool": row["expected_tool"],
        "query": row["query_sk"],
        "status": "pending",
        "http_code": None,
        "result_count": None,
        "duration_ms": None,
        "error": None,
    }

    # Unknown tool names fall back to (None, "skip") so they are reported,
    # not crashed on.
    endpoint, call_type = TOOL_MAP.get(row["expected_tool"], (None, "skip"))

    print(f"Request {idx}/{total}")

    if call_type == "skip":
        result["status"] = "SKIP"
        return result

    # Build URL and query params according to the call type.
    if call_type == "id":
        # For id lookups the first expected keyword doubles as the identifier
        # appended as a path segment.
        id_val = (row.get("expected_keywords") or "").split(",")[0].strip()
        url, params = f"{endpoint}/{id_val}", {}
    elif call_type == "autocomplete":
        url, params = endpoint, {"query": pick_keyword(row), "limit": 5}
    else:
        url, params = endpoint, {"query": pick_keyword(row), "size": 5}

    async with semaphore:
        # Politeness throttle: fixed delay before every request; the sleep is
        # inside the semaphore so it also spaces out concurrent slots.
        await asyncio.sleep(DELAY_BETWEEN)
        t0 = time.perf_counter()
        try:
            resp = await client.get(url, params=params, headers=HEADERS, timeout=REQUEST_TIMEOUT)
            result["duration_ms"] = round((time.perf_counter() - t0) * 1000)
            result["http_code"] = resp.status_code

            if resp.status_code == 200:
                try:
                    cnt = count_results(resp.json())
                    result["result_count"] = cnt
                    # Zero (or uncountable) results on a 200 is EMPTY, not OK.
                    result["status"] = "OK" if cnt and cnt > 0 else "EMPTY"
                except Exception:
                    # Non-JSON 200 body still counts as a successful call.
                    result["status"] = "OK"
            elif resp.status_code == 404:
                result["status"] = "NOT_FOUND"
            elif resp.status_code == 403:
                result["status"] = "FORBIDDEN"
            else:
                result["status"] = "HTTP_ERROR"
                result["error"] = f"HTTP {resp.status_code}"

        except httpx.TimeoutException:
            result["status"] = "TIMEOUT"
            # Actual elapsed time is unknown here; record the timeout ceiling.
            result["duration_ms"] = REQUEST_TIMEOUT * 1000
        except httpx.ConnectError as e:
            result["status"] = "CONN_ERROR"
            result["error"] = str(e)
        except Exception as e:
            # Catch-all boundary: any other failure is recorded, never raised,
            # so one bad row cannot abort the whole run.
            result["status"] = "ERROR"
            result["error"] = str(e)

    return result
|
||||
|
||||
|
||||
async def run_all(rows: list[dict]) -> list[dict]:
    """Validate every query row concurrently; results come back in row order."""
    gate = asyncio.Semaphore(CONCURRENT_LIMIT)
    total = len(rows)
    async with httpx.AsyncClient(follow_redirects=True) as client:
        pending = []
        for position, row in enumerate(rows, start=1):
            pending.append(validate_one(client, gate, row, position, total))
        # gather preserves input order regardless of completion order.
        return list(await asyncio.gather(*pending))
|
||||
|
||||
|
||||
def generate_charts(results: list[dict], ts: str) -> Path:
    """Render a 2x2 PNG report (pie, per-category bars, latency histogram,
    per-difficulty bars) from validation results.

    Args:
        results: Result dicts as produced by ``validate_one``.
        ts: Timestamp string used in the title and output filename.

    Returns:
        Path of the written PNG under ``OUT_DIR``.
    """
    # Imported lazily so the runner works without matplotlib unless charts
    # are actually requested; Agg backend keeps it headless.
    import matplotlib
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt
    import matplotlib.patches as mpatches
    import numpy as np
    from collections import Counter, defaultdict

    # Fixed color per status; unknown statuses fall back to gray below.
    STATUS_COLORS = {
        "OK": "#1D9E75",
        "EMPTY": "#EF9F27",
        "SKIP": "#888780",
        "NOT_FOUND": "#E24B4A",
        "FORBIDDEN": "#D85A30",
        "TIMEOUT": "#D4537E",
        "HTTP_ERROR": "#E24B4A",
        "CONN_ERROR": "#E24B4A",
        "ERROR": "#E24B4A",
    }
    DIFF_COLORS = {"easy": "#1D9E75", "medium": "#EF9F27", "hard": "#E24B4A"}

    OUT_DIR.mkdir(parents=True, exist_ok=True)

    # Only statuses that actually occur, in STATUS_COLORS order (keeps legend stable).
    used_statuses = [s for s in STATUS_COLORS if any(r["status"] == s for r in results)]

    fig = plt.figure(figsize=(18, 18), facecolor="#FAFAF8")
    fig.suptitle(
        f"Validation Report - Legal AI Agent\n{ts.replace('_', ' ')}",
        fontsize=15, fontweight="bold", y=0.98, color="#2C2C2A"
    )
    # 2 rows: [pie + category] | [histogram + difficulty]
    gs = fig.add_gridspec(2, 2, hspace=0.45, wspace=0.35,
                          top=0.94, bottom=0.04, left=0.07, right=0.97)

    def styled(ax, title):
        # Shared axis styling: title, light face, no top/right spines, y-grid.
        ax.set_title(title, fontsize=12, fontweight="bold", color="#2C2C2A")
        ax.set_facecolor("#F9F9F7")
        ax.spines[["top", "right"]].set_visible(False)
        ax.grid(axis="y", alpha=0.3, linestyle="--")

    # 1. Pie — overall results
    ax1 = fig.add_subplot(gs[0, 0])
    cnt = Counter(r["status"] for r in results)
    wedges, texts, pcts = ax1.pie(
        cnt.values(),
        labels=cnt.keys(),
        colors=[STATUS_COLORS.get(k, "#888780") for k in cnt],
        autopct="%1.1f%%", startangle=90, pctdistance=0.82,
        wedgeprops={"edgecolor": "white", "linewidth": 1.5},
    )
    for t in texts: t.set_fontsize(10)
    for t in pcts: t.set_fontsize(9); t.set_color("white"); t.set_fontweight("bold")
    ax1.set_title("Overall Results", fontsize=12, fontweight="bold", color="#2C2C2A")

    # 2. Bar — results by category (grouped bars, one group per category)
    ax2 = fig.add_subplot(gs[0, 1])
    categories = sorted(set(r["category"] for r in results))
    cat_data = defaultdict(Counter)
    for r in results:
        cat_data[r["category"]][r["status"]] += 1

    # Bar width shares a 0.8-wide slot among the statuses present.
    x, w = np.arange(len(categories)), 0.8 / max(len(used_statuses), 1)
    for i, status in enumerate(used_statuses):
        vals = [cat_data[c][status] for c in categories]
        offset = (i - len(used_statuses) / 2 + 0.5) * w
        ax2.bar(x + offset, vals, w, label=status,
                color=STATUS_COLORS.get(status, "#888780"),
                edgecolor="white", linewidth=0.5)

    ax2.set_xticks(x)
    # Category labels truncated to 13 chars to keep the axis readable.
    ax2.set_xticklabels([c[:13] for c in categories], rotation=35, ha="right", fontsize=8)
    ax2.set_ylabel("Count", fontsize=10)
    ax2.legend(fontsize=7, loc="upper right")
    styled(ax2, "Results by Category")

    # 3. Histogram — response times (skipped rows and missing durations excluded)
    ax4 = fig.add_subplot(gs[1, 0])
    times = [r["duration_ms"] for r in results
             if r["duration_ms"] and r["status"] != "SKIP"]
    if times:
        ax4.hist(times, bins=20, color="#378ADD", edgecolor="white", linewidth=0.7, alpha=0.85)
        avg = sum(times) / len(times)
        ax4.axvline(avg, color="#E24B4A", linewidth=1.5, linestyle="--",
                    label=f"Avg: {avg:.0f} ms")
        ax4.set_xlabel("ms", fontsize=10)
        ax4.set_ylabel("Requests", fontsize=10)
        ax4.legend(fontsize=9)
    styled(ax4, "API Response Time Distribution")

    # 4. Bar — results by difficulty (fixed easy/medium/hard order)
    ax5 = fig.add_subplot(gs[1, 1])
    difficulties = ["easy", "medium", "hard"]
    diff_data = defaultdict(Counter)
    for r in results:
        diff_data[r["difficulty"]][r["status"]] += 1

    x2, w2 = np.arange(3), 0.8 / max(len(used_statuses), 1)
    for i, status in enumerate(used_statuses):
        vals = [diff_data[d][status] for d in difficulties]
        offset = (i - len(used_statuses) / 2 + 0.5) * w2
        ax5.bar(x2 + offset, vals, w2, label=status,
                color=STATUS_COLORS.get(status, "#888780"),
                edgecolor="white", linewidth=0.5)

    ax5.set_xticks(x2)
    ax5.set_xticklabels(difficulties, fontsize=10)
    ax5.set_ylabel("Count", fontsize=10)
    ax5.legend(fontsize=8)
    styled(ax5, "Results by Difficulty")


    out = OUT_DIR / f"report_{ts}.png"
    fig.savefig(out, dpi=150, bbox_inches="tight", facecolor=fig.get_facecolor())
    plt.close(fig)
    return out
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Optional CLI arg caps the number of queries, e.g. `python runner.py 10`.
    limit = int(sys.argv[1]) if len(sys.argv) > 1 else None

    rows = load_queries(limit)
    # Timestamp shared by the report title and the output filename.
    ts = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

    # f-prefixes removed from placeholder-free strings (ruff F541).
    print("Start Test Validation")

    results = asyncio.run(run_all(rows))

    print("End Test Validation")

    out = generate_charts(results, ts)
    print(f"Test image is saved: {out}")
|
||||
BIN
tests/test_queries.db
Normal file
BIN
tests/test_queries.db
Normal file
Binary file not shown.
Loading…
Reference in New Issue
Block a user