add tests

G0DSEND016 2026-03-23 02:55:42 +01:00
parent 18d7980582
commit 1a7e8aa355
17 changed files with 1189 additions and 311 deletions

app.py (16 lines changed)

@@ -9,15 +9,19 @@ from core.stream_response import stream_response
 from api.fetch_api_data import set_log_callback
 STARTERS = [
-    ("What legal data can the agent find?","magnifying_glass"),
-    ("What is the agent not allowed to do or use?","ban"),
-    ("What are the details of your AI model?","hexagon"),
-    ("What data sources does the agent rely on?","database"),
+    ("What legal data can the agent find?", "magnifying_glass"),
+    ("What is the agent not allowed to do or use?", "ban"),
+    ("What are the details of your AI model?", "hexagon"),
+    ("What data sources does the agent rely on?", "database"),
 ]
 PROFILES = [
-    ("qwen3.5:cloud","Qwen 3.5 CLOUD"),
-    ("gpt-oss:20b-cloud","GPT-OSS 20B CLOUD"),
+    ("qwen3.5:cloud", "Qwen 3.5 CLOUD (in Ollama)"),
+    ("gpt-oss:20b-cloud", "GPT-OSS 20B CLOUD (in Ollama)"),
+    ("gpt-oss:20b", "GPT-OSS 20B (Local LLM)"),
+    ("qwen3:8b", "Qwen3 8B (Local LLM)"),
+    ("gpt-4o", "GPT-4o (OpenAI API)"),
+    ("gpt-4o-mini", "GPT-4o Mini (OpenAI API)"),
 ]
 @cl.set_starters
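The hunk is cut off at the decorator. For orientation, a Chainlit starters hook consuming these (label, icon) tuples would typically look like the sketch below; this is illustrative, not the committed body:

async def set_starters():
    # Map each (label, icon) tuple from STARTERS onto a starter chip.
    return [cl.Starter(label=label, message=label, icon=icon) for label, icon in STARTERS]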

core/config.py

@@ -2,7 +2,13 @@ import os
 DEFAULT_MODEL = os.getenv("DEFAULT_MODEL", "qwen3.5:cloud")
 MAX_HISTORY = int(os.getenv("MAX_HISTORY", "20"))
-AGENT_TEMPERATURE = float(os.getenv("AGENT_TEMPERATURE", "0.7"))
+LLM_TIMEOUT = float(os.getenv("LLM_TIMEOUT", "120.0"))
 OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL", "http://localhost:11434/v1")
 OLLAMA_API_KEY = os.getenv("OLLAMA_API_KEY", "ollama")
-OLLAMA_TIMEOUT = float(os.getenv("OLLAMA_TIMEOUT", "120.0"))
+AGENT_TEMPERATURE = float(os.getenv("AGENT_TEMPERATURE", "0.7"))
+OLLAMA_MODELS = {"qwen3.5:cloud", "gpt-oss:20b-cloud", "gpt-oss:20b", "qwen3:8b"}
+OPENAI_BASE_URL = os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1")
+OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")
+OPENAI_MODELS = {"gpt-4o", "gpt-4o-mini"}
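All of these values are read once at import time, so any override has to be in the environment before core.config is imported. A minimal sketch with illustrative values:

import os
os.environ["DEFAULT_MODEL"] = "gpt-4o-mini"  # hypothetical override
os.environ["LLM_TIMEOUT"] = "60"
from core import config  # the overrides are picked up here, at import time
assert config.DEFAULT_MODEL in config.OPENAI_MODELS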


@@ -2,7 +2,12 @@ from agents import Agent, AgentHooks
 from agents import OpenAIChatCompletionsModel, AsyncOpenAI, ModelSettings
 from agents import set_tracing_disabled
-from core.config import DEFAULT_MODEL, OLLAMA_BASE_URL, OLLAMA_API_KEY, OLLAMA_TIMEOUT, AGENT_TEMPERATURE
+from core.config import (
+    DEFAULT_MODEL, AGENT_TEMPERATURE, LLM_TIMEOUT,
+    OLLAMA_BASE_URL, OLLAMA_API_KEY,
+    OPENAI_BASE_URL, OPENAI_API_KEY,
+    OPENAI_MODELS, OLLAMA_MODELS
+)
 from core.system_prompt import get_system_prompt
 from api.tools import ALL_TOOLS
@@ -15,18 +20,35 @@ class MyAgentHooks(AgentHooks):
     async def on_end(self, context, agent, output):
         print(f"🏁 [AgentHooks] {agent.name} ended.")
-def assistant_agent(model_name: str = DEFAULT_MODEL) -> Agent:
+def _make_client(model_name: str) -> AsyncOpenAI:
+    """Return the correct AsyncOpenAI client based on model name"""
-    client = AsyncOpenAI(
-        base_url=OLLAMA_BASE_URL,
-        api_key=OLLAMA_API_KEY,
-        timeout=OLLAMA_TIMEOUT,
-        max_retries=0
-    )
+    if model_name in OPENAI_MODELS:
+        return AsyncOpenAI(
+            base_url=OPENAI_BASE_URL,
+            api_key=OPENAI_API_KEY,
+            timeout=LLM_TIMEOUT,
+            max_retries=0,
+        )
+    if model_name in OLLAMA_MODELS:
+        return AsyncOpenAI(
+            base_url=OLLAMA_BASE_URL,
+            api_key=OLLAMA_API_KEY,
+            timeout=LLM_TIMEOUT,
+            max_retries=0,
+        )
+    raise ValueError(f"Model {model_name} not supported")
+def assistant_agent(model_name: str = DEFAULT_MODEL) -> Agent:
     """Initialize the assistant agent for legal work"""
+    client = _make_client(model_name)
     model = OpenAIChatCompletionsModel(model=model_name, openai_client=client)
     agent = Agent(
-        name="Assistant",
+        name="AI Lawyer Assistant",
         instructions=get_system_prompt(model_name),
         model=model,
         model_settings=ModelSettings(temperature=AGENT_TEMPERATURE, tool_choice="auto", parallel_tool_calls=False),
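The routing contract of the new _make_client helper is easy to read off; a sketch (the module path core.agent is assumed, since this file's header is missing from the page):

from core.agent import _make_client  # assumed module path

_make_client("gpt-4o")    # -> AsyncOpenAI client bound to OPENAI_BASE_URL
_make_client("qwen3:8b")  # -> AsyncOpenAI client bound to OLLAMA_BASE_URL
_make_client("mistral")   # -> raises ValueError("Model mistral not supported")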


testing/fixtures.py (new file, 9 lines)

@@ -0,0 +1,9 @@
import pytest
import asyncio
@pytest.fixture(scope="session")
def event_loop():
loop = asyncio.new_event_loop()
yield loop
loop.close()
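pytest only auto-discovers fixtures from conftest.py files and plugins, so this module presumably gets registered somewhere; one conventional wiring (hypothetical, not shown in this commit):

# testing/conftest.py (hypothetical)
pytest_plugins = ["testing.fixtures"]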

testing/reports/charts.py (new file, 245 lines)

@@ -0,0 +1,245 @@
import json
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib
matplotlib.use("Agg")  # select the non-interactive backend before pyplot is imported
import matplotlib.pyplot as plt
RESULTS_FILE = Path(__file__).parent.parent / "results.json"
CHARTS_DIR = Path(__file__).parent.parent / "charts"
STATUS_COLORS = {
"passed": "#3ddc84",
"failed": "#ff5c5c",
"skipped": "#e0c84a",
"error": "#ff8a8a",
}
BG_COLOR = "#1a1a1a"
PANEL_COLOR = "#242424"
TEXT_COLOR = "#e0e0e0"
GRID_COLOR = "#333333"
FONT_MONO = "monospace"
def _style_ax(ax, title: str):
ax.set_facecolor(PANEL_COLOR)
ax.set_title(title, color=TEXT_COLOR, fontsize=11, fontweight="bold", pad=10)
ax.tick_params(colors=TEXT_COLOR, labelsize=8)
ax.spines[:].set_color(GRID_COLOR)
ax.yaxis.grid(True, color=GRID_COLOR, linewidth=0.5, linestyle="--")
ax.set_axisbelow(True)
for label in ax.get_xticklabels() + ax.get_yticklabels():
label.set_color(TEXT_COLOR)
label.set_fontfamily(FONT_MONO)
def load_data() -> tuple[pd.DataFrame, dict]:
if not RESULTS_FILE.exists():
raise FileNotFoundError(f"results.json not found at {RESULTS_FILE}")
raw = json.loads(RESULTS_FILE.read_text(encoding="utf-8"))
tests = raw.get("tests", [])
rows = []
for t in tests:
node = t["nodeid"]
parts = node.split("::")
module = parts[0].replace("tests/", "").replace("/", ".").replace(".py", "")
cls = parts[1] if len(parts) >= 3 else "unknown"
name = parts[-1]
duration = t.get("call", {}).get("duration", 0.0) if t.get("call") else 0.0
rows.append({
"module": module,
"class": cls,
"name": name,
"outcome": t["outcome"],
"duration": duration,
})
df = pd.DataFrame(rows)
summary = raw.get("summary", {})
return df, summary
def chart_overall_status(df: pd.DataFrame, ax: plt.Axes):
counts = df["outcome"].value_counts()
colors = [STATUS_COLORS.get(k, "#888") for k in counts.index]
wedges, texts, pcts = ax.pie(
counts.values,
labels=counts.index,
colors=colors,
autopct="%1.1f%%",
startangle=90,
pctdistance=0.78,
wedgeprops={"edgecolor": BG_COLOR, "linewidth": 2},
)
for t in texts:
t.set_color(TEXT_COLOR)
t.set_fontsize(9)
t.set_fontfamily(FONT_MONO)
for p in pcts:
p.set_color("#111")
p.set_fontsize(8)
p.set_fontweight("bold")
ax.set_facecolor(PANEL_COLOR)
ax.set_title("Overall Results", color=TEXT_COLOR, fontsize=11, fontweight="bold", pad=10)
def chart_by_module(df: pd.DataFrame, ax: plt.Axes):
pivot = (
df.groupby(["module", "outcome"])
.size()
.unstack(fill_value=0)
.reindex(columns=["passed", "failed", "skipped", "error"], fill_value=0)
)
x = np.arange(len(pivot))
width = 0.2
offset = -(len(pivot.columns) - 1) / 2 * width
for i, col in enumerate(pivot.columns):
bars = ax.bar(
x + offset + i * width,
pivot[col],
width,
label=col,
color=STATUS_COLORS.get(col, "#888"),
edgecolor=BG_COLOR,
linewidth=0.8,
)
ax.set_xticks(x)
ax.set_xticklabels(pivot.index, rotation=25, ha="right", fontsize=7, fontfamily=FONT_MONO)
ax.set_ylabel("Tests", color=TEXT_COLOR, fontsize=9)
ax.legend(
fontsize=7,
labelcolor=TEXT_COLOR,
facecolor=PANEL_COLOR,
edgecolor=GRID_COLOR,
)
_style_ax(ax, "Results by Module")
def chart_duration_histogram(df: pd.DataFrame, ax: plt.Axes):
durations = df.loc[df["outcome"] != "skipped", "duration"].values * 1000
if len(durations) == 0:
ax.text(0.5, 0.5, "No data", ha="center", va="center", color=TEXT_COLOR)
_style_ax(ax, "Test Duration (ms)")
return
mean_ms = float(np.mean(durations))
median_ms = float(np.median(durations))
p95_ms = float(np.percentile(durations, 95))
ax.hist(durations, bins=20, color="#5cb8ff", edgecolor=BG_COLOR, linewidth=0.6, alpha=0.85)
ax.axvline(mean_ms, color="#3ddc84", linewidth=1.5, linestyle="--", label=f"Mean {mean_ms:.1f} ms")
ax.axvline(median_ms, color="#e0c84a", linewidth=1.5, linestyle=":", label=f"Median {median_ms:.1f} ms")
ax.axvline(p95_ms, color="#ff5c5c", linewidth=1.5, linestyle="-.", label=f"P95 {p95_ms:.1f} ms")
ax.set_xlabel("ms", color=TEXT_COLOR, fontsize=9)
ax.set_ylabel("Tests", color=TEXT_COLOR, fontsize=9)
ax.legend(fontsize=7, labelcolor=TEXT_COLOR, facecolor=PANEL_COLOR, edgecolor=GRID_COLOR)
_style_ax(ax, "Test Duration (ms)")
def chart_slowest_tests(df: pd.DataFrame, ax: plt.Axes):
top = (
df[df["outcome"] != "skipped"]
.nlargest(10, "duration")
.copy()
)
top["label"] = top["class"] + "::" + top["name"]
top["duration"] = top["duration"] * 1000
colors = [STATUS_COLORS.get(o, "#888") for o in top["outcome"]]
bars = ax.barh(top["label"], top["duration"], color=colors, edgecolor=BG_COLOR, linewidth=0.6)
ax.set_xlabel("ms", color=TEXT_COLOR, fontsize=9)
ax.tick_params(axis="y", labelsize=7)
ax.invert_yaxis()
_style_ax(ax, "Top 10 Slowest Tests")
def chart_stats_table(df: pd.DataFrame, summary: dict, ax: plt.Axes):
ax.set_facecolor(PANEL_COLOR)
ax.axis("off")
total = len(df)
passed = summary.get("passed", 0)
failed = summary.get("failed", 0)
skipped = summary.get("skipped", 0)
duration = df["duration"].sum() * 1000
durations = df.loc[df["outcome"] != "skipped", "duration"].values * 1000
rows = [
["Total tests", str(total)],
["Passed", str(passed)],
["Failed", str(failed)],
["Skipped", str(skipped)],
["Pass rate", f"{passed / total * 100:.1f}%" if total else ""],
["Total time", f"{duration:.0f} ms"],
["Mean duration", f"{np.mean(durations):.1f} ms" if len(durations) else ""],
["Median", f"{np.median(durations):.1f} ms" if len(durations) else ""],
["P95", f"{np.percentile(durations, 95):.1f} ms" if len(durations) else ""],
]
table = ax.table(
cellText=rows,
colLabels=["Metric", "Value"],
cellLoc="left",
loc="center",
colWidths=[0.6, 0.4],
)
table.auto_set_font_size(False)
table.set_fontsize(9)
for (row, col), cell in table.get_celld().items():
cell.set_facecolor("#2e2e2e" if row % 2 == 0 else PANEL_COLOR)
cell.set_edgecolor(GRID_COLOR)
cell.set_text_props(color=TEXT_COLOR, fontfamily=FONT_MONO)
ax.set_title("Summary", color=TEXT_COLOR, fontsize=11, fontweight="bold", pad=10)
def generate(output_path: Path | None = None) -> Path:
CHARTS_DIR.mkdir(parents=True, exist_ok=True)
output_path = output_path or CHARTS_DIR / "report.png"
df, summary = load_data()
fig = plt.figure(figsize=(18, 14), facecolor=BG_COLOR)
fig.suptitle(
"Legal AI Assistant — Test Report",
fontsize=16, fontweight="bold", color=TEXT_COLOR,
y=0.98, fontfamily=FONT_MONO,
)
gs = fig.add_gridspec(2, 3, hspace=0.45, wspace=0.35,
top=0.93, bottom=0.05, left=0.06, right=0.97)
chart_overall_status(df, fig.add_subplot(gs[0, 0]))
chart_by_module(df, fig.add_subplot(gs[0, 1:]))
chart_duration_histogram(df, fig.add_subplot(gs[1, 0]))
chart_slowest_tests(df, fig.add_subplot(gs[1, 1]))
chart_stats_table(df, summary, fig.add_subplot(gs[1, 2]))
fig.savefig(output_path, dpi=150, bbox_inches="tight", facecolor=BG_COLOR)
plt.close(fig)
return output_path
if __name__ == "__main__":
out = generate()
print(f"Chart saved: {out}")
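generate() also accepts an explicit target path, which run_charts in run_tests.py below leaves at the default; for a one-off render:

from pathlib import Path
from testing.reports.charts import generate

out = generate(Path("testing/charts/nightly.png"))  # hypothetical output location
print(out)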

testing/run_tests.py (new file, 58 lines)

@@ -0,0 +1,58 @@
import subprocess
import sys
from pathlib import Path
ROOT = Path(__file__).parent.parent
TESTS_DIR = Path(__file__).parent
RESULTS_JSON = TESTS_DIR / "results.json"
REPORT_HTML = TESTS_DIR / "charts" / "report.html"
CHARTS_DIR = TESTS_DIR / "charts"
def run_pytest() -> int:
CHARTS_DIR.mkdir(parents=True, exist_ok=True)
    args = [
        sys.executable, "-m", "pytest",
        str(TESTS_DIR / "unit"),
        str(TESTS_DIR / "integration"),
        str(TESTS_DIR / "e2e"),
        str(TESTS_DIR / "llm"),
        "--json-report",
        f"--json-report-file={RESULTS_JSON}",
        f"--html={REPORT_HTML}",
        "--self-contained-html",
        "--cov=api",
        "--cov=core",
        "--cov-report=term-missing",
        f"--cov-report=html:{CHARTS_DIR / 'coverage'}",
        "-p", "no:terminal",
        "--tb=short",
        "-q",
    ]
return subprocess.run(args, cwd=str(ROOT), check=False).returncode
def run_charts():
from testing.reports.charts import generate
out = generate()
print(f"Charts saved: {out}")
if __name__ == "__main__":
print("Start testing")
code = run_pytest()
if RESULTS_JSON.exists():
run_charts()
print("Stop testing")
sys.exit(code)
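The flags assume the pytest-json-report, pytest-html, pytest-cov and pytest-asyncio plugins are installed. The runner also works programmatically:

from testing.run_tests import run_pytest, run_charts

code = run_pytest()   # writes results.json plus the HTML and coverage reports
if code in (0, 1):    # pytest exit codes: 0 = all passed, 1 = some failures; both produce results.json
    run_charts()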

testing/tests/test_api.py (new file, 152 lines)

@@ -0,0 +1,152 @@
import pytest
import httpx
from api.config import JUSTICE_API_BASE, HTTP_TIMEOUT
HEADERS = {
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) Chrome/124.0.0.0 Safari/537.36",
"Accept": "application/json, text/plain, */*",
"Accept-Language": "sk-SK,sk;q=0.9",
"Referer": "https://obcan.justice.sk/",
"Origin": "https://obcan.justice.sk",
}
def api_available() -> bool:
try:
r = httpx.get(f"{JUSTICE_API_BASE}/v1/sud", params={"size": 1}, headers=HEADERS, timeout=5)
return r.status_code == 200
except Exception:
return False
skip_if_offline = pytest.mark.skipif(
not api_available(),
reason="justice.sk API is not reachable"
)
@pytest.fixture(scope="module")
def client():
with httpx.Client(headers=HEADERS, timeout=HTTP_TIMEOUT, follow_redirects=True) as c:
yield c
@skip_if_offline
class TestCourtsEndpoint:
def test_returns_200(self, client):
r = client.get(f"{JUSTICE_API_BASE}/v1/sud", params={"size": 5})
assert r.status_code == 200
def test_response_has_content_key(self, client):
r = client.get(f"{JUSTICE_API_BASE}/v1/sud", params={"size": 5})
data = r.json()
assert "content" in data
def test_total_elements_is_positive(self, client):
r = client.get(f"{JUSTICE_API_BASE}/v1/sud", params={"size": 1})
data = r.json()
assert data.get("totalElements", 0) > 0
def test_court_by_id_returns_valid_record(self, client):
r = client.get(f"{JUSTICE_API_BASE}/v1/sud/sud_175")
assert r.status_code == 200
data = r.json()
assert "id" in data or "nazov" in data
def test_court_autocomplete_returns_list(self, client):
r = client.get(f"{JUSTICE_API_BASE}/v1/sud/autocomplete", params={"query": "Bratislava", "limit": 5})
assert r.status_code == 200
assert isinstance(r.json(), list)
def test_nonexistent_court_returns_404(self, client):
r = client.get(f"{JUSTICE_API_BASE}/v1/sud/sud_999999999")
assert r.status_code == 404
@skip_if_offline
class TestJudgesEndpoint:
def test_judge_search_returns_200(self, client):
r = client.get(f"{JUSTICE_API_BASE}/v1/sudca", params={"size": 5})
assert r.status_code == 200
def test_judge_autocomplete_returns_results(self, client):
r = client.get(f"{JUSTICE_API_BASE}/v1/sudca/autocomplete", params={"query": "Novák", "limit": 10})
assert r.status_code == 200
def test_judge_search_by_kraj(self, client):
r = client.get(f"{JUSTICE_API_BASE}/v1/sudca", params={
"krajFacetFilter": "Bratislavský kraj",
"size": 5
})
assert r.status_code == 200
def test_judge_search_pagination_page_zero(self, client):
r = client.get(f"{JUSTICE_API_BASE}/v1/sudca", params={"page": 0, "size": 10})
data = r.json()
assert "content" in data
@skip_if_offline
class TestDecisionsEndpoint:
def test_decision_search_returns_200(self, client):
r = client.get(f"{JUSTICE_API_BASE}/v1/rozhodnutie", params={"size": 3})
assert r.status_code == 200
def test_decision_search_with_date_range(self, client):
r = client.get(f"{JUSTICE_API_BASE}/v1/rozhodnutie", params={
"vydaniaOd": "01.01.2023",
"vydaniaDo": "31.12.2023",
"size": 3,
})
assert r.status_code == 200
def test_decision_autocomplete(self, client):
r = client.get(f"{JUSTICE_API_BASE}/v1/rozhodnutie/autocomplete", params={"query": "Rozsudok", "limit": 5})
assert r.status_code == 200
@skip_if_offline
class TestContractsEndpoint:
def test_contract_search_returns_200(self, client):
r = client.get(f"{JUSTICE_API_BASE}/v1/zmluvy", params={"size": 5})
assert r.status_code == 200
def test_contract_search_by_typ_dokumentu(self, client):
r = client.get(f"{JUSTICE_API_BASE}/v1/zmluvy", params={
"typDokumentuFacetFilter": "ZMLUVA",
"size": 3,
})
assert r.status_code == 200
@skip_if_offline
class TestCivilProceedingsEndpoint:
def test_civil_proceedings_returns_200(self, client):
r = client.get(f"{JUSTICE_API_BASE}/v1/obcianPojednavania", params={"size": 3})
assert r.status_code == 200
def test_civil_proceedings_date_filter(self, client):
r = client.get(f"{JUSTICE_API_BASE}/v1/obcianPojednavania", params={
"pojednavaniaOd": "01.01.2024",
"pojednavaniaDo": "31.01.2024",
"size": 3,
})
assert r.status_code == 200
@skip_if_offline
class TestAdminProceedingsEndpoint:
def test_admin_proceedings_returns_200(self, client):
r = client.get(f"{JUSTICE_API_BASE}/v1/spravneKonanie", params={"size": 3})
assert r.status_code == 200
def test_admin_proceedings_autocomplete(self, client):
r = client.get(f"{JUSTICE_API_BASE}/v1/spravneKonanie/autocomplete", params={"query": "test", "limit": 5})
assert r.status_code == 200

testing/tests/test_fetch.py (new file, 113 lines)

@@ -0,0 +1,113 @@
import pytest
import pytest_asyncio
import httpx
import respx
from api.fetch_api_data import fetch_api_data, set_log_callback, _cache
BASE = "https://obcan.justice.sk/pilot/api/ress-isu-service"
@pytest.fixture(autouse=True)
def clear_cache():
_cache.clear()
yield
_cache.clear()
@pytest.mark.asyncio
class TestFetchApiData:
@respx.mock
async def test_successful_request_returns_dict(self):
url = f"{BASE}/v1/sud"
respx.get(url).mock(return_value=httpx.Response(200, json={"content": [], "totalElements": 0}))
result = await fetch_api_data(icon="", url=url, params={})
assert isinstance(result, dict)
assert "totalElements" in result
@respx.mock
async def test_cache_hit_on_second_call(self):
url = f"{BASE}/v1/sud"
mock = respx.get(url).mock(return_value=httpx.Response(200, json={"content": []}))
await fetch_api_data(icon="", url=url, params={})
await fetch_api_data(icon="", url=url, params={})
assert mock.call_count == 1
@respx.mock
async def test_http_404_returns_error_dict(self):
url = f"{BASE}/v1/sud/sud_99999"
respx.get(url).mock(return_value=httpx.Response(404, text="Not Found"))
result = await fetch_api_data(icon="", url=url, params={})
assert result["error"] == "http_error"
assert result["status_code"] == 404
@respx.mock
async def test_http_500_returns_error_dict(self):
url = f"{BASE}/v1/sud"
respx.get(url).mock(return_value=httpx.Response(500, text="Server Error"))
result = await fetch_api_data(icon="", url=url, params={})
assert result["error"] == "http_error"
assert result["status_code"] == 500
@respx.mock
async def test_remove_keys_strips_specified_fields(self):
url = f"{BASE}/v1/sud/sud_1"
respx.get(url).mock(return_value=httpx.Response(200, json={"name": "Súd", "foto": "base64data"}))
result = await fetch_api_data(icon="", url=url, params={}, remove_keys=["foto"])
assert "foto" not in result
assert "name" in result
@respx.mock
async def test_remove_keys_missing_key_no_error(self):
url = f"{BASE}/v1/sud/sud_1"
respx.get(url).mock(return_value=httpx.Response(200, json={"name": "Súd"}))
result = await fetch_api_data(icon="", url=url, params={}, remove_keys=["foto", "nonexistent"])
assert result["name"] == "Súd"
@respx.mock
async def test_log_callback_is_called(self):
url = f"{BASE}/v1/sud"
respx.get(url).mock(return_value=httpx.Response(200, json={}))
log_lines = []
set_log_callback(lambda line: log_lines.append(line))
await fetch_api_data(icon="", url=url, params={})
set_log_callback(None)
assert len(log_lines) > 0
@respx.mock
async def test_params_are_passed_in_request(self):
url = f"{BASE}/v1/sud"
mock = respx.get(url).mock(return_value=httpx.Response(200, json={}))
await fetch_api_data(icon="", url=url, params={"query": "Bratislava", "size": 10})
assert mock.called
sent_params = dict(mock.calls[0].request.url.params)
assert sent_params["query"] == "Bratislava"
assert sent_params["size"] == "10"
@respx.mock
async def test_connect_error_returns_error_dict(self):
url = f"{BASE}/v1/sud"
respx.get(url).mock(side_effect=httpx.ConnectError("Connection refused"))
result = await fetch_api_data(icon="", url=url, params={})
assert result["error"] == "request_error"
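Taken together, these tests pin down fetch_api_data's error contract: failures come back as dicts rather than raised exceptions. A caller sketch under that contract (inside an async function, names as defined in this module):

data = await fetch_api_data(icon="", url=f"{BASE}/v1/sud", params={"size": 5})
if "error" in data:
    print(f"lookup failed: {data['error']} (HTTP {data.get('status_code', '?')})")
else:
    print(f"totalElements: {data.get('totalElements', 0)}")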


@@ -0,0 +1,120 @@
import time
import pytest
from openai import OpenAI
from core.config import (
    OLLAMA_BASE_URL, OLLAMA_API_KEY, LLM_TIMEOUT,
    OPENAI_BASE_URL, OPENAI_API_KEY,
    OLLAMA_MODELS, OPENAI_MODELS
)
from core.system_prompt import get_system_prompt
MODELS = sorted(OLLAMA_MODELS | OPENAI_MODELS)  # sorted for stable parametrize ids
TEST_QUERIES = [
{
"id": "court_search",
"query": "Nájdi súdy v Bratislavskom kraji.",
"expected_keywords": ["súd", "bratislava", "kraj"],
},
{
"id": "judge_search",
"query": "Vyhľadaj sudcu Novák.",
"expected_keywords": ["sudca", "novák", "novak"],
},
{
"id": "no_legal_advice",
"query": "Mám spor so zamestnávateľom, čo mám robiť?",
"forbidden_keywords": ["musíte", "odporúčam vám podať žalobu", "právne poradenstvo"],
"expected_keywords": ["api", "ministerstvo", "nie som právnik", "právny poradca"],
},
{
"id": "slovak_response",
"query": "What courts exist in Slovakia?",
"expected_keywords": ["súd", "slovensko", "kraj"],
},
]
def ollama_available() -> bool:
try:
client = OpenAI(base_url=OLLAMA_BASE_URL, api_key=OLLAMA_API_KEY)
client.models.list()
return True
except Exception:
return False
skip_if_no_ollama = pytest.mark.skipif(
not ollama_available(),
reason="Ollama is not running"
)
def query_model(model: str, user_message: str) -> tuple[str, float]:
    # Route OpenAI-hosted models to the OpenAI endpoint; everything else goes to Ollama.
    if model in OPENAI_MODELS:
        client = OpenAI(base_url=OPENAI_BASE_URL, api_key=OPENAI_API_KEY, timeout=LLM_TIMEOUT)
    else:
        client = OpenAI(base_url=OLLAMA_BASE_URL, api_key=OLLAMA_API_KEY, timeout=LLM_TIMEOUT)
start = time.perf_counter()
response = client.chat.completions.create(
model=model,
messages=[
{"role": "system", "content": get_system_prompt(model)},
{"role": "user", "content": user_message},
],
temperature=0.0,
max_tokens=512,
)
elapsed = time.perf_counter() - start
text = response.choices[0].message.content or ""
return text, elapsed
llm_results: dict[str, list[float]] = {m: [] for m in MODELS}
@skip_if_no_ollama
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("case", TEST_QUERIES, ids=[c["id"] for c in TEST_QUERIES])
class TestLLMResponses:
def test_response_is_not_empty(self, model, case):
text, _ = query_model(model, case["query"])
assert len(text.strip()) > 0
def test_response_in_slovak(self, model, case):
text, _ = query_model(model, case["query"])
        slovak_markers = ["je", "som", "nie", "súd", "sudca", "kraj", "ale", "alebo", "pre"]  # no empty string: it would match any text
assert any(m in text.lower() for m in slovak_markers)
def test_expected_keywords_present(self, model, case):
if "expected_keywords" not in case:
pytest.skip("No expected_keywords defined")
text, _ = query_model(model, case["query"])
assert any(kw.lower() in text.lower() for kw in case["expected_keywords"])
def test_forbidden_keywords_absent(self, model, case):
if "forbidden_keywords" not in case:
pytest.skip("No forbidden_keywords defined")
text, _ = query_model(model, case["query"])
for kw in case["forbidden_keywords"]:
assert kw.lower() not in text.lower(), f"Forbidden keyword found: {kw}"
def test_response_time_under_threshold(self, model, case):
_, elapsed = query_model(model, case["query"])
        assert elapsed < LLM_TIMEOUT, f"Response took {elapsed:.1f}s"
def test_response_length_reasonable(self, model, case):
text, _ = query_model(model, case["query"])
assert 10 < len(text) < 4000
@skip_if_no_ollama
@pytest.mark.parametrize("model", MODELS)
class TestLLMBenchmark:
def test_collect_benchmark_data(self, model):
"""Collects timing data per model — used by charts.py."""
times = []
for case in TEST_QUERIES:
_, elapsed = query_model(model, case["query"])
times.append(elapsed)
llm_results[model].extend(times)
assert len(times) == len(TEST_QUERIES)


@@ -0,0 +1,168 @@
import pytest
from pydantic import ValidationError
from api.schemas import (
CourtByID,
JudgeByID,
AdminProceedingsByID,
CourtSearch,
JudgeSearch,
DecisionSearch,
ContractSearch,
CivilProceedingsSearch,
CourtAutocomplete,
JudgeAutocomplete,
)
class TestCourtByID:
def test_digit_gets_prefix(self):
assert CourtByID(id="175").id == "sud_175"
def test_already_prefixed_unchanged(self):
assert CourtByID(id="sud_175").id == "sud_175"
def test_strips_whitespace(self):
assert CourtByID(id=" 42 ").id == "sud_42"
def test_single_digit(self):
assert CourtByID(id="1").id == "sud_1"
def test_large_number(self):
assert CourtByID(id="9999").id == "sud_9999"
def test_non_numeric_string_unchanged(self):
assert CourtByID(id="some_string").id == "some_string"
def test_whitespace_digit_combination(self):
assert CourtByID(id=" 7 ").id == "sud_7"
class TestJudgeByID:
def test_digit_gets_prefix(self):
assert JudgeByID(id="1").id == "sudca_1"
def test_already_prefixed_unchanged(self):
assert JudgeByID(id="sudca_99").id == "sudca_99"
def test_strips_whitespace(self):
assert JudgeByID(id=" 7 ").id == "sudca_7"
def test_large_number(self):
assert JudgeByID(id="12345").id == "sudca_12345"
def test_non_numeric_string_unchanged(self):
assert JudgeByID(id="sudca_abc").id == "sudca_abc"
class TestAdminProceedingsByID:
def test_digit_gets_prefix(self):
assert AdminProceedingsByID(id="103").id == "spravneKonanie_103"
def test_already_prefixed_unchanged(self):
assert AdminProceedingsByID(id="spravneKonanie_103").id == "spravneKonanie_103"
def test_strips_whitespace(self):
assert AdminProceedingsByID(id=" 55 ").id == "spravneKonanie_55"
def test_single_digit(self):
assert AdminProceedingsByID(id="5").id == "spravneKonanie_5"
def test_non_numeric_string_unchanged(self):
assert AdminProceedingsByID(id="custom_id").id == "custom_id"
class TestPaginationDefaults:
def test_court_search_defaults_are_none(self):
obj = CourtSearch()
assert obj.page is None
assert obj.size is None
def test_sort_direction_default_asc(self):
assert CourtSearch().sortDirection == "ASC"
def test_sort_direction_desc_accepted(self):
assert CourtSearch(sortDirection="DESC").sortDirection == "DESC"
def test_sort_direction_invalid_rejected(self):
with pytest.raises(ValidationError):
CourtSearch(sortDirection="INVALID")
def test_page_cannot_be_negative(self):
with pytest.raises(ValidationError):
CourtSearch(page=-1)
def test_size_cannot_be_zero(self):
with pytest.raises(ValidationError):
CourtSearch(size=0)
def test_size_one_accepted(self):
assert CourtSearch(size=1).size == 1
def test_page_zero_accepted(self):
assert CourtSearch(page=0).page == 0
class TestFacetFilters:
def test_court_search_facet_type_list(self):
obj = CourtSearch(typSuduFacetFilter=["Okresný súd", "Krajský súd"])
assert len(obj.typSuduFacetFilter) == 2
def test_judge_search_facet_kraj(self):
obj = JudgeSearch(krajFacetFilter=["Bratislavský kraj"])
assert obj.krajFacetFilter == ["Bratislavský kraj"]
def test_decision_search_forma_filter(self):
obj = DecisionSearch(formaRozhodnutiaFacetFilter=["Rozsudok"])
assert "Rozsudok" in obj.formaRozhodnutiaFacetFilter
def test_contract_search_typ_dokumentu(self):
obj = ContractSearch(typDokumentuFacetFilter=["ZMLUVA", "DODATOK"])
assert len(obj.typDokumentuFacetFilter) == 2
def test_civil_proceedings_usek_filter(self):
obj = CivilProceedingsSearch(usekFacetFilter=["C", "O"])
assert "C" in obj.usekFacetFilter
class TestAutocomplete:
def test_court_autocomplete_limit_min_one(self):
with pytest.raises(ValidationError):
CourtAutocomplete(limit=0)
def test_court_autocomplete_limit_valid(self):
assert CourtAutocomplete(limit=5).limit == 5
def test_judge_autocomplete_guid_sud(self):
obj = JudgeAutocomplete(guidSud="sud_100", limit=10)
assert obj.guidSud == "sud_100"
def test_autocomplete_empty_query_accepted(self):
assert CourtAutocomplete().query is None
def test_autocomplete_query_string(self):
assert JudgeAutocomplete(query="Novák").query == "Novák"
class TestModelDumpExcludeNone:
def test_excludes_none_fields(self):
dumped = CourtSearch(query="Bratislava").model_dump(exclude_none=True)
assert "query" in dumped
assert "page" not in dumped
assert "size" not in dumped
def test_full_params_included(self):
dumped = JudgeSearch(query="Novák", page=0, size=20).model_dump(exclude_none=True)
assert dumped["query"] == "Novák"
assert dumped["page"] == 0
assert dumped["size"] == 20
def test_empty_schema_dumps_empty_dict(self):
assert CourtSearch().model_dump(exclude_none=True) == {}
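TestModelDumpExcludeNone documents the convention the tools presumably rely on when turning a schema into query parameters: dump with exclude_none=True so unset fields never reach the API. In sketch form:

from api.schemas import JudgeSearch

params = JudgeSearch(query="Novák", size=20).model_dump(exclude_none=True)
# -> {"query": "Novák", "size": 20}; page, facets and other unset fields are omitted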


@@ -0,0 +1,123 @@
import pytest
from core.system_prompt import get_system_prompt
from core.config import OLLAMA_MODELS, OPENAI_MODELS
MODELS = sorted(OLLAMA_MODELS | OPENAI_MODELS)  # sorted so fixture param order is deterministic
@pytest.fixture(params=MODELS)
def prompt(request):
return get_system_prompt(request.param)
@pytest.fixture(params=MODELS)
def prompt_lower(request):
return get_system_prompt(request.param).lower()
class TestPromptContainsModelName:
def test_model_name_appears_in_prompt(self, prompt, request):
model = request.node.callspec.params["prompt"]
assert model in prompt
class TestRequiredSections:
def test_has_role_section(self, prompt_lower):
assert "role" in prompt_lower
def test_has_operational_constraints(self, prompt_lower):
assert "constraint" in prompt_lower or "not allowed" in prompt_lower or "do not" in prompt_lower
def test_has_workflow_steps(self, prompt_lower):
assert "step" in prompt_lower or "workflow" in prompt_lower
def test_has_error_recovery(self, prompt_lower):
assert "error" in prompt_lower or "not found" in prompt_lower
def test_has_response_format(self, prompt_lower):
assert "format" in prompt_lower or "response" in prompt_lower
class TestSupportedDomains:
def test_courts_mentioned(self, prompt_lower):
assert "court" in prompt_lower or "súd" in prompt_lower or "sud" in prompt_lower
def test_judges_mentioned(self, prompt_lower):
assert "judge" in prompt_lower or "sudca" in prompt_lower or "sudcovia" in prompt_lower
def test_decisions_mentioned(self, prompt_lower):
assert "decision" in prompt_lower or "rozhodnut" in prompt_lower
def test_contracts_mentioned(self, prompt_lower):
assert "contract" in prompt_lower or "zmluv" in prompt_lower
def test_civil_proceedings_mentioned(self, prompt_lower):
assert "civil" in prompt_lower or "obcian" in prompt_lower or "pojednavan" in prompt_lower
def test_admin_proceedings_mentioned(self, prompt_lower):
assert "admin" in prompt_lower or "spravne" in prompt_lower or "správne" in prompt_lower
class TestConstraints:
def test_no_legal_advice_constraint(self, prompt_lower):
assert "legal advisor" in prompt_lower or "not a lawyer" in prompt_lower or "legal advice" in prompt_lower
def test_api_only_constraint(self, prompt_lower):
assert "api" in prompt_lower
def test_slovak_language_requirement(self, prompt_lower):
assert "slovak" in prompt_lower or "slovensk" in prompt_lower
def test_no_raw_json_rule(self, prompt_lower):
assert "json" in prompt_lower or "technical" in prompt_lower
def test_no_speculate_rule(self, prompt_lower):
assert "speculate" in prompt_lower or "infer" in prompt_lower or "gaps" in prompt_lower
class TestPaginationRules:
def test_page_starts_at_zero_mentioned(self, prompt_lower):
assert "page" in prompt_lower and "0" in prompt_lower
def test_autocomplete_preferred_mentioned(self, prompt_lower):
assert "autocomplete" in prompt_lower
class TestDateRules:
def test_date_format_dd_mm_yyyy_mentioned(self, prompt):
assert "DD.MM.YYYY" in prompt or "dd.mm.yyyy" in prompt.lower()
def test_civil_date_field_mentioned(self, prompt_lower):
assert "pojednavaniaod" in prompt_lower or "pojednavania" in prompt_lower
def test_decision_date_field_mentioned(self, prompt_lower):
assert "vydaniaod" in prompt_lower or "vydania" in prompt_lower
class TestIDNormalizationRules:
def test_sud_prefix_mentioned(self, prompt_lower):
assert "sud_" in prompt_lower
def test_sudca_prefix_mentioned(self, prompt_lower):
assert "sudca_" in prompt_lower
def test_spravnekonanie_prefix_mentioned(self, prompt_lower):
assert "spravnekonanie_" in prompt_lower or "spravne" in prompt_lower
class TestPromptLength:
def test_prompt_is_not_empty(self, prompt):
assert len(prompt.strip()) > 0
def test_prompt_has_minimum_length(self, prompt):
assert len(prompt) > 500
def test_prompt_has_reasonable_max_length(self, prompt):
assert len(prompt) < 50_000

testing/tests/test_tools.py (new file, 152 lines)

@@ -0,0 +1,152 @@
import pytest
import httpx
import respx
from api.fetch_api_data import _cache
from api.config import JUSTICE_API_BASE
from api.tools import (
court_search,
court_id,
court_autocomplete,
judge_search,
judge_id,
judge_autocomplete,
decision_search,
contract_search,
civil_proceedings_search,
admin_proceedings_search,
)
EMPTY_LIST_RESPONSE = {"content": [], "totalElements": 0, "numberOfElements": 0}
@pytest.fixture(autouse=True)
def clear_cache():
_cache.clear()
yield
_cache.clear()
def make_response(body: dict | None = None):
    return httpx.Response(200, json=body or EMPTY_LIST_RESPONSE)
@pytest.mark.asyncio
class TestCourtTools:
@respx.mock
async def test_court_search_calls_correct_url(self):
mock = respx.get(f"{JUSTICE_API_BASE}/v1/sud").mock(return_value=make_response())
await court_search.on_invoke_tool(None, '{"query": "Bratislava"}')
assert mock.called
@respx.mock
async def test_court_id_calls_correct_url(self):
mock = respx.get(f"{JUSTICE_API_BASE}/v1/sud/sud_175").mock(
return_value=httpx.Response(200, json={"id": "sud_175", "foto": "data"})
)
await court_id.on_invoke_tool(None, '{"id": "175"}')
assert mock.called
@respx.mock
async def test_court_id_removes_foto_key(self):
respx.get(f"{JUSTICE_API_BASE}/v1/sud/sud_1").mock(
return_value=httpx.Response(200, json={"name": "Súd", "foto": "base64"})
)
result = await court_id.on_invoke_tool(None, '{"id": "1"}')
assert "foto" not in str(result)
@respx.mock
async def test_court_autocomplete_calls_correct_url(self):
mock = respx.get(f"{JUSTICE_API_BASE}/v1/sud/autocomplete").mock(return_value=make_response())
await court_autocomplete.on_invoke_tool(None, '{"query": "Kraj"}')
assert mock.called
@pytest.mark.asyncio
class TestJudgeTools:
@respx.mock
async def test_judge_search_calls_correct_url(self):
mock = respx.get(f"{JUSTICE_API_BASE}/v1/sudca").mock(return_value=make_response())
await judge_search.on_invoke_tool(None, '{"query": "Novák"}')
assert mock.called
@respx.mock
async def test_judge_id_normalizes_digit_id(self):
mock = respx.get(f"{JUSTICE_API_BASE}/v1/sudca/sudca_1").mock(
return_value=httpx.Response(200, json={"id": "sudca_1"})
)
await judge_id.on_invoke_tool(None, '{"id": "1"}')
assert mock.called
@respx.mock
async def test_judge_autocomplete_passes_guid_sud(self):
mock = respx.get(f"{JUSTICE_API_BASE}/v1/sudca/autocomplete").mock(return_value=make_response())
await judge_autocomplete.on_invoke_tool(None, '{"query": "Novák", "guidSud": "sud_100"}')
assert mock.called
params = dict(mock.calls[0].request.url.params)
assert params.get("guidSud") == "sud_100"
@pytest.mark.asyncio
class TestDecisionTools:
@respx.mock
async def test_decision_search_with_date_range(self):
mock = respx.get(f"{JUSTICE_API_BASE}/v1/rozhodnutie").mock(return_value=make_response())
await decision_search.on_invoke_tool(
None, '{"vydaniaOd": "01.01.2024", "vydaniaDo": "31.01.2024"}'
)
assert mock.called
params = dict(mock.calls[0].request.url.params)
assert params.get("vydaniaOd") == "01.01.2024"
@respx.mock
async def test_decision_search_with_guid_sudca(self):
mock = respx.get(f"{JUSTICE_API_BASE}/v1/rozhodnutie").mock(return_value=make_response())
await decision_search.on_invoke_tool(None, '{"guidSudca": "sudca_1"}')
params = dict(mock.calls[0].request.url.params)
assert params.get("guidSudca") == "sudca_1"
@pytest.mark.asyncio
class TestContractTools:
@respx.mock
async def test_contract_search_with_guid_sud(self):
mock = respx.get(f"{JUSTICE_API_BASE}/v1/zmluvy").mock(return_value=make_response())
await contract_search.on_invoke_tool(None, '{"guidSud": "sud_7"}')
params = dict(mock.calls[0].request.url.params)
assert params.get("guidSud") == "sud_7"
@respx.mock
async def test_contract_search_typ_dokumentu_filter(self):
mock = respx.get(f"{JUSTICE_API_BASE}/v1/zmluvy").mock(return_value=make_response())
await contract_search.on_invoke_tool(None, '{"typDokumentuFacetFilter": ["ZMLUVA"]}')
assert mock.called
@pytest.mark.asyncio
class TestCivilAndAdminTools:
@respx.mock
async def test_civil_proceedings_search(self):
mock = respx.get(f"{JUSTICE_API_BASE}/v1/obcianPojednavania").mock(return_value=make_response())
await civil_proceedings_search.on_invoke_tool(None, '{"query": "test"}')
assert mock.called
@respx.mock
async def test_admin_proceedings_search(self):
mock = respx.get(f"{JUSTICE_API_BASE}/v1/spravneKonanie").mock(return_value=make_response())
await admin_proceedings_search.on_invoke_tool(None, '{"query": "test"}')
assert mock.called
@respx.mock
async def test_civil_proceedings_date_params(self):
mock = respx.get(f"{JUSTICE_API_BASE}/v1/obcianPojednavania").mock(return_value=make_response())
await civil_proceedings_search.on_invoke_tool(
            None, '{"pojednavaniaOd": "01.01.2024", "pojednavaniaDo": "31.01.2024"}'
)
assert mock.called
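All of these tests drive the tool objects directly through on_invoke_tool(context, args_json); judging by the imports in the agent module, these are FunctionTool instances from the agents SDK, so the first argument is the run context (None here) and the second is the JSON-encoded arguments. The same call works outside pytest, though without respx it would hit the live justice.sk API:

import asyncio
from api.tools import court_search

# Invoke the tool exactly as the tests do: no run context, JSON-encoded arguments.
result = asyncio.run(court_search.on_invoke_tool(None, '{"query": "Bratislava", "size": 5}'))
print(result)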

Binary image deleted (191 KiB).


@@ -1,294 +0,0 @@
import asyncio
import sqlite3
import time
import sys
from datetime import datetime
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
from api.config import JUSTICE_API_BASE
import httpx
DB_PATH = Path(__file__).parent / "test_queries.db"
OUT_DIR = Path(__file__).parent / "results"
REQUEST_TIMEOUT = 15
CONCURRENT_LIMIT = 5
DELAY_BETWEEN = 0.3
HEADERS = {
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) Chrome/124.0.0.0 Safari/537.36",
"Accept": "application/json, text/plain, */*",
"Accept-Language": "sk-SK,sk;q=0.9",
"Referer": "https://obcan.justice.sk/",
"Origin": "https://obcan.justice.sk",
}
TOOL_MAP = {
"judge": (f"{JUSTICE_API_BASE}/v1/sudca","search"),
"judge_id": (f"{JUSTICE_API_BASE}/v1/sudca","id"),
"judge_autocomplete": (f"{JUSTICE_API_BASE}/v1/sudca/autocomplete","autocomplete"),
"court": (f"{JUSTICE_API_BASE}/v1/sud","search"),
"court_id": (f"{JUSTICE_API_BASE}/v1/sud","id"),
"court_autocomplete": (f"{JUSTICE_API_BASE}/v1/sud/autocomplete","autocomplete"),
"decision": (f"{JUSTICE_API_BASE}/v1/rozhodnutie","search"),
"decision_id": (f"{JUSTICE_API_BASE}/v1/rozhodnutie","id"),
"decision_autocomplete": (f"{JUSTICE_API_BASE}/v1/rozhodnutie/autocomplete","autocomplete"),
"contract": (f"{JUSTICE_API_BASE}/v1/zmluvy","search"),
"contract_id": (f"{JUSTICE_API_BASE}/v1/zmluvy","id"),
"contract_autocomplete": (f"{JUSTICE_API_BASE}/v1/zmluvy/autocomplete","autocomplete"),
"civil_proceedings": (f"{JUSTICE_API_BASE}/v1/obcianPojednavania","search"),
"civil_proceedings_id": (f"{JUSTICE_API_BASE}/v1/obcianPojednavania","id"),
"civil_proceedings_autocomplete": (f"{JUSTICE_API_BASE}/v1/obcianPojednavania/autocomplete","autocomplete"),
"admin_proceedings": (f"{JUSTICE_API_BASE}/v1/spravneKonanie","search"),
"admin_proceedings_id": (f"{JUSTICE_API_BASE}/v1/spravneKonanie","id"),
"admin_proceedings_autocomplete": (f"{JUSTICE_API_BASE}/v1/spravneKonanie/autocomplete","autocomplete"),
"mimo_rozsah": (None,"skip"),
}
def load_queries(limit: int = None) -> list[dict]:
conn = sqlite3.connect(DB_PATH)
conn.row_factory = sqlite3.Row
sql = "SELECT * FROM test_queries ORDER BY id"
if limit:
sql += f" LIMIT {limit}"
rows = conn.execute(sql).fetchall()
conn.close()
return [dict(r) for r in rows]
def pick_keyword(row: dict) -> str:
first = (row.get("expected_keywords") or "").split(",")[0].strip()
if len(first) > 2:
return first
words = [w for w in row["query_sk"].split() if len(w) > 3]
return " ".join(words[:2]) or row["query_sk"][:20]
def count_results(data) -> int | None:
if isinstance(data, list):
return len(data)
if isinstance(data, dict):
for key in ("totalElements", "total", "count", "numberOfElements"):
if key in data:
return int(data[key])
if "content" in data:
return len(data["content"])
return 1
return None
async def validate_one(
client: httpx.AsyncClient,
semaphore: asyncio.Semaphore,
row: dict,
idx: int,
total: int,
) -> dict:
result = {
"id": row["id"],
"category": row["category"],
"difficulty": row["difficulty"],
"tool": row["expected_tool"],
"query": row["query_sk"],
"status": "pending",
"http_code": None,
"result_count": None,
"duration_ms": None,
"error": None,
}
endpoint, call_type = TOOL_MAP.get(row["expected_tool"], (None, "skip"))
print(f"Request {idx}/{total}")
if call_type == "skip":
result["status"] = "SKIP"
return result
if call_type == "id":
id_val = (row.get("expected_keywords") or "").split(",")[0].strip()
url, params = f"{endpoint}/{id_val}", {}
elif call_type == "autocomplete":
url, params = endpoint, {"query": pick_keyword(row), "limit": 5}
else:
url, params = endpoint, {"query": pick_keyword(row), "size": 5}
async with semaphore:
await asyncio.sleep(DELAY_BETWEEN)
t0 = time.perf_counter()
try:
resp = await client.get(url, params=params, headers=HEADERS, timeout=REQUEST_TIMEOUT)
result["duration_ms"] = round((time.perf_counter() - t0) * 1000)
result["http_code"] = resp.status_code
if resp.status_code == 200:
try:
cnt = count_results(resp.json())
result["result_count"] = cnt
result["status"] = "OK" if cnt and cnt > 0 else "EMPTY"
except Exception:
result["status"] = "OK"
elif resp.status_code == 404:
result["status"] = "NOT_FOUND"
elif resp.status_code == 403:
result["status"] = "FORBIDDEN"
else:
result["status"] = "HTTP_ERROR"
result["error"] = f"HTTP {resp.status_code}"
except httpx.TimeoutException:
result["status"] = "TIMEOUT"
result["duration_ms"] = REQUEST_TIMEOUT * 1000
except httpx.ConnectError as e:
result["status"] = "CONN_ERROR"
result["error"] = str(e)
except Exception as e:
result["status"] = "ERROR"
result["error"] = str(e)
return result
async def run_all(rows: list[dict]) -> list[dict]:
semaphore = asyncio.Semaphore(CONCURRENT_LIMIT)
async with httpx.AsyncClient(follow_redirects=True) as client:
tasks = [
validate_one(client, semaphore, row, i + 1, len(rows))
for i, row in enumerate(rows)
]
return list(await asyncio.gather(*tasks))
def generate_charts(results: list[dict], ts: str) -> Path:
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import numpy as np
from collections import Counter, defaultdict
STATUS_COLORS = {
"OK": "#1D9E75",
"EMPTY": "#EF9F27",
"SKIP": "#888780",
"NOT_FOUND": "#E24B4A",
"FORBIDDEN": "#D85A30",
"TIMEOUT": "#D4537E",
"HTTP_ERROR": "#E24B4A",
"CONN_ERROR": "#E24B4A",
"ERROR": "#E24B4A",
}
DIFF_COLORS = {"easy": "#1D9E75", "medium": "#EF9F27", "hard": "#E24B4A"}
OUT_DIR.mkdir(parents=True, exist_ok=True)
used_statuses = [s for s in STATUS_COLORS if any(r["status"] == s for r in results)]
fig = plt.figure(figsize=(18, 18), facecolor="#FAFAF8")
fig.suptitle(
f"Validation Report - Legal AI Agent\n{ts.replace('_', ' ')}",
fontsize=15, fontweight="bold", y=0.98, color="#2C2C2A"
)
# 2 rows: [pie + category] | [histogram + difficulty]
gs = fig.add_gridspec(2, 2, hspace=0.45, wspace=0.35,
top=0.94, bottom=0.04, left=0.07, right=0.97)
def styled(ax, title):
ax.set_title(title, fontsize=12, fontweight="bold", color="#2C2C2A")
ax.set_facecolor("#F9F9F7")
ax.spines[["top", "right"]].set_visible(False)
ax.grid(axis="y", alpha=0.3, linestyle="--")
# 1. Pie — overall results
ax1 = fig.add_subplot(gs[0, 0])
cnt = Counter(r["status"] for r in results)
wedges, texts, pcts = ax1.pie(
cnt.values(),
labels=cnt.keys(),
colors=[STATUS_COLORS.get(k, "#888780") for k in cnt],
autopct="%1.1f%%", startangle=90, pctdistance=0.82,
wedgeprops={"edgecolor": "white", "linewidth": 1.5},
)
for t in texts: t.set_fontsize(10)
for t in pcts: t.set_fontsize(9); t.set_color("white"); t.set_fontweight("bold")
ax1.set_title("Overall Results", fontsize=12, fontweight="bold", color="#2C2C2A")
# 2. Bar — results by category
ax2 = fig.add_subplot(gs[0, 1])
categories = sorted(set(r["category"] for r in results))
cat_data = defaultdict(Counter)
for r in results:
cat_data[r["category"]][r["status"]] += 1
x, w = np.arange(len(categories)), 0.8 / max(len(used_statuses), 1)
for i, status in enumerate(used_statuses):
vals = [cat_data[c][status] for c in categories]
offset = (i - len(used_statuses) / 2 + 0.5) * w
ax2.bar(x + offset, vals, w, label=status,
color=STATUS_COLORS.get(status, "#888780"),
edgecolor="white", linewidth=0.5)
ax2.set_xticks(x)
ax2.set_xticklabels([c[:13] for c in categories], rotation=35, ha="right", fontsize=8)
ax2.set_ylabel("Count", fontsize=10)
ax2.legend(fontsize=7, loc="upper right")
styled(ax2, "Results by Category")
# 3. Histogram — response times
ax4 = fig.add_subplot(gs[1, 0])
times = [r["duration_ms"] for r in results
if r["duration_ms"] and r["status"] != "SKIP"]
if times:
ax4.hist(times, bins=20, color="#378ADD", edgecolor="white", linewidth=0.7, alpha=0.85)
avg = sum(times) / len(times)
ax4.axvline(avg, color="#E24B4A", linewidth=1.5, linestyle="--",
label=f"Avg: {avg:.0f} ms")
ax4.set_xlabel("ms", fontsize=10)
ax4.set_ylabel("Requests", fontsize=10)
ax4.legend(fontsize=9)
styled(ax4, "API Response Time Distribution")
# 4. Bar — results by difficulty
ax5 = fig.add_subplot(gs[1, 1])
difficulties = ["easy", "medium", "hard"]
diff_data = defaultdict(Counter)
for r in results:
diff_data[r["difficulty"]][r["status"]] += 1
x2, w2 = np.arange(3), 0.8 / max(len(used_statuses), 1)
for i, status in enumerate(used_statuses):
vals = [diff_data[d][status] for d in difficulties]
offset = (i - len(used_statuses) / 2 + 0.5) * w2
ax5.bar(x2 + offset, vals, w2, label=status,
color=STATUS_COLORS.get(status, "#888780"),
edgecolor="white", linewidth=0.5)
ax5.set_xticks(x2)
ax5.set_xticklabels(difficulties, fontsize=10)
ax5.set_ylabel("Count", fontsize=10)
ax5.legend(fontsize=8)
styled(ax5, "Results by Difficulty")
out = OUT_DIR / f"report_{ts}.png"
fig.savefig(out, dpi=150, bbox_inches="tight", facecolor=fig.get_facecolor())
plt.close(fig)
return out
if __name__ == "__main__":
limit = int(sys.argv[1]) if len(sys.argv) > 1 else None
rows = load_queries(limit)
ts = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
print(f"Start Test Validation")
results = asyncio.run(run_all(rows))
print(f"End Test Validation")
out = generate_charts(results, ts)
print(f"Test image is saved: {out}")
