diff --git a/app.py b/app.py
index 13555df..9454e84 100644
--- a/app.py
+++ b/app.py
@@ -9,15 +9,19 @@ from core.stream_response import stream_response
 from api.fetch_api_data import set_log_callback
 
 STARTERS = [
-    ("What legal data can the agent find?","magnifying_glass"),
-    ("What is the agent not allowed to do or use?","ban"),
-    ("What are the details of your AI model?","hexagon"),
-    ("What data sources does the agent rely on?","database"),
+    ("What legal data can the agent find?", "magnifying_glass"),
+    ("What is the agent not allowed to do or use?", "ban"),
+    ("What are the details of your AI model?", "hexagon"),
+    ("What data sources does the agent rely on?", "database"),
 ]
 
 PROFILES = [
-    ("qwen3.5:cloud","Qwen 3.5 CLOUD"),
-    ("gpt-oss:20b-cloud","GPT-OSS 20B CLOUD"),
+    ("qwen3.5:cloud", "Qwen 3.5 CLOUD (in Ollama)"),
+    ("gpt-oss:20b-cloud", "GPT-OSS 20B CLOUD (in Ollama)"),
+    ("gpt-oss:20b", "GPT-OSS 20B (Local LLM)"),
+    ("qwen3:8b", "Qwen3 8B (Local LLM)"),
+    ("gpt-4o", "GPT-4o (OpenAI API)"),
+    ("gpt-4o-mini", "GPT-4o Mini (OpenAI API)"),
 ]
 
 @cl.set_starters
diff --git a/core/config.py b/core/config.py
index b8a1f73..f88eb5e 100644
--- a/core/config.py
+++ b/core/config.py
@@ -2,7 +2,13 @@ import os
 
 DEFAULT_MODEL = os.getenv("DEFAULT_MODEL", "qwen3.5:cloud")
 MAX_HISTORY = int(os.getenv("MAX_HISTORY", "20"))
+AGENT_TEMPERATURE = float(os.getenv("AGENT_TEMPERATURE", "0.7"))
+LLM_TIMEOUT = float(os.getenv("LLM_TIMEOUT", "120.0"))
+
 OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL", "http://localhost:11434/v1")
 OLLAMA_API_KEY = os.getenv("OLLAMA_API_KEY", "ollama")
-OLLAMA_TIMEOUT = float(os.getenv("OLLAMA_TIMEOUT", "120.0"))
-AGENT_TEMPERATURE = float(os.getenv("AGENT_TEMPERATURE", "0.7"))
\ No newline at end of file
+OLLAMA_MODELS = {"qwen3.5:cloud", "gpt-oss:20b-cloud", "gpt-oss:20b", "qwen3:8b"}
+
+OPENAI_BASE_URL = os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1")
+OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")
+OPENAI_MODELS = {"gpt-4o", "gpt-4o-mini"}
\ No newline at end of file
diff --git a/core/init_agent.py b/core/init_agent.py
index 0e96838..17c2fab 100644
--- a/core/init_agent.py
+++ b/core/init_agent.py
@@ -2,7 +2,12 @@ from agents import Agent, AgentHooks
 from agents import OpenAIChatCompletionsModel, AsyncOpenAI, ModelSettings
 from agents import set_tracing_disabled
 
-from core.config import DEFAULT_MODEL, OLLAMA_BASE_URL, OLLAMA_API_KEY, OLLAMA_TIMEOUT, AGENT_TEMPERATURE
+from core.config import (
+    DEFAULT_MODEL, AGENT_TEMPERATURE, LLM_TIMEOUT,
+    OLLAMA_BASE_URL, OLLAMA_API_KEY,
+    OPENAI_BASE_URL, OPENAI_API_KEY,
+    OPENAI_MODELS, OLLAMA_MODELS
+)
 from core.system_prompt import get_system_prompt
 from api.tools import ALL_TOOLS
 
@@ -15,18 +20,35 @@ class MyAgentHooks(AgentHooks):
     async def on_end(self, context, agent, output):
         print(f"🏁 [AgentHooks] {agent.name} ended.")
 
-def assistant_agent(model_name: str = DEFAULT_MODEL) -> Agent:
-    client = AsyncOpenAI(
-        base_url=OLLAMA_BASE_URL,
-        api_key=OLLAMA_API_KEY,
-        timeout=OLLAMA_TIMEOUT,
-        max_retries=0
-    )
+def _make_client(model_name: str) -> AsyncOpenAI:
+    """Return the correct AsyncOpenAI client based on model name"""
+    if model_name in OPENAI_MODELS:
+        return AsyncOpenAI(
+            base_url=OPENAI_BASE_URL,
+            api_key=OPENAI_API_KEY,
+            timeout=LLM_TIMEOUT,
+            max_retries=0,
+        )
+
+    if model_name in OLLAMA_MODELS:
+        return AsyncOpenAI(
+            base_url=OLLAMA_BASE_URL,
+            api_key=OLLAMA_API_KEY,
+            timeout=LLM_TIMEOUT,
+            max_retries=0,
+        )
+
+    raise ValueError(f"Model {model_name} not supported")
+
+def assistant_agent(model_name: str = DEFAULT_MODEL) -> Agent:
+    """Initialize the assistant agent for legal work"""
+
+    client = _make_client(model_name)
     model = OpenAIChatCompletionsModel(model=model_name, openai_client=client)
 
     agent = Agent(
-        name="Assistant",
+        name="AI Lawyer Assistant",
         instructions=get_system_prompt(model_name),
         model=model,
         model_settings=ModelSettings(temperature=AGENT_TEMPERATURE, tool_choice="auto", parallel_tool_calls=False),
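
The routing above keys purely off set membership in core/config.py, so a model string outside both sets fails fast. A quick usage sketch (illustrative only; the OpenAI path additionally assumes OPENAI_API_KEY is exported, and "claude-3" is just a stand-in for an unregistered name):

from core.init_agent import assistant_agent

agent_local = assistant_agent("qwen3:8b")      # in OLLAMA_MODELS -> local Ollama client
agent_cloud = assistant_agent("gpt-4o-mini")   # in OPENAI_MODELS -> api.openai.com client
# assistant_agent("claude-3")                  # in neither set -> ValueError from _make_client
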
diff --git a/requirements.txt b/requirements.txt
index cd6d894..d9f4521 100644
Binary files a/requirements.txt and b/requirements.txt differ
diff --git a/testing/fixtures.py b/testing/fixtures.py
new file mode 100644
index 0000000..db22cde
--- /dev/null
+++ b/testing/fixtures.py
@@ -0,0 +1,9 @@
+import pytest
+import asyncio
+
+
+@pytest.fixture(scope="session")
+def event_loop():
+    loop = asyncio.new_event_loop()
+    yield loop
+    loop.close()
\ No newline at end of file
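
One wiring detail worth knowing: pytest only auto-collects fixtures from conftest.py files, so the session-scoped event_loop above stays invisible to the test modules until testing/fixtures.py is registered. A minimal sketch of that registration (no conftest.py appears in this diff, so this is an assumption about the intended wiring):

# conftest.py at the repository root (hypothetical — not part of this diff)
# Registering the module as a plugin makes its fixtures visible to all tests.
pytest_plugins = ["testing.fixtures"]

Recent pytest-asyncio releases also deprecate overriding event_loop in favour of loop-scope configuration, so this fixture may warrant a revisit when that plugin is upgraded.
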
diff --git a/testing/reports/charts.py b/testing/reports/charts.py
new file mode 100644
index 0000000..3e65869
--- /dev/null
+++ b/testing/reports/charts.py
@@ -0,0 +1,245 @@
+import json
+from pathlib import Path
+from collections import Counter, defaultdict
+
+import numpy as np
+import pandas as pd
+import matplotlib
+matplotlib.use("Agg")
+
+import matplotlib.pyplot as plt
+import matplotlib.patches as mpatches
+
+RESULTS_FILE = Path(__file__).parent.parent / "results.json"
+CHARTS_DIR = Path(__file__).parent.parent / "charts"
+
+STATUS_COLORS = {
+    "passed": "#3ddc84",
+    "failed": "#ff5c5c",
+    "skipped": "#e0c84a",
+    "error": "#ff8a8a",
+}
+
+BG_COLOR = "#1a1a1a"
+PANEL_COLOR = "#242424"
+TEXT_COLOR = "#e0e0e0"
+GRID_COLOR = "#333333"
+FONT_MONO = "monospace"
+
+
+def _style_ax(ax, title: str):
+    ax.set_facecolor(PANEL_COLOR)
+    ax.set_title(title, color=TEXT_COLOR, fontsize=11, fontweight="bold", pad=10)
+    ax.tick_params(colors=TEXT_COLOR, labelsize=8)
+    ax.spines[:].set_color(GRID_COLOR)
+    ax.yaxis.grid(True, color=GRID_COLOR, linewidth=0.5, linestyle="--")
+    ax.set_axisbelow(True)
+    for label in ax.get_xticklabels() + ax.get_yticklabels():
+        label.set_color(TEXT_COLOR)
+        label.set_fontfamily(FONT_MONO)
+
+
+def load_data() -> tuple[pd.DataFrame, dict]:
+    if not RESULTS_FILE.exists():
+        raise FileNotFoundError(f"results.json not found at {RESULTS_FILE}")
+
+    raw = json.loads(RESULTS_FILE.read_text(encoding="utf-8"))
+    tests = raw.get("tests", [])
+
+    rows = []
+    for t in tests:
+        node = t["nodeid"]
+        parts = node.split("::")
+        module = parts[0].replace("tests/", "").replace("/", ".").replace(".py", "")
+        cls = parts[1] if len(parts) >= 3 else "unknown"
+        name = parts[-1]
+
+        duration = t.get("call", {}).get("duration", 0.0) if t.get("call") else 0.0
+
+        rows.append({
+            "module": module,
+            "class": cls,
+            "name": name,
+            "outcome": t["outcome"],
+            "duration": duration,
+        })
+
+    df = pd.DataFrame(rows)
+    summary = raw.get("summary", {})
+    return df, summary
+
+
+def chart_overall_status(df: pd.DataFrame, ax: plt.Axes):
+    counts = df["outcome"].value_counts()
+    colors = [STATUS_COLORS.get(k, "#888") for k in counts.index]
+
+    wedges, texts, pcts = ax.pie(
+        counts.values,
+        labels=counts.index,
+        colors=colors,
+        autopct="%1.1f%%",
+        startangle=90,
+        pctdistance=0.78,
+        wedgeprops={"edgecolor": BG_COLOR, "linewidth": 2},
+    )
+    for t in texts:
+        t.set_color(TEXT_COLOR)
+        t.set_fontsize(9)
+        t.set_fontfamily(FONT_MONO)
+    for p in pcts:
+        p.set_color("#111")
+        p.set_fontsize(8)
+        p.set_fontweight("bold")
+
+    ax.set_facecolor(PANEL_COLOR)
+    ax.set_title("Overall Results", color=TEXT_COLOR, fontsize=11, fontweight="bold", pad=10)
+
+
+def chart_by_module(df: pd.DataFrame, ax: plt.Axes):
+    pivot = (
+        df.groupby(["module", "outcome"])
+        .size()
+        .unstack(fill_value=0)
+        .reindex(columns=["passed", "failed", "skipped", "error"], fill_value=0)
+    )
+
+    x = np.arange(len(pivot))
+    width = 0.2
+    offset = -(len(pivot.columns) - 1) / 2 * width
+
+    for i, col in enumerate(pivot.columns):
+        bars = ax.bar(
+            x + offset + i * width,
+            pivot[col],
+            width,
+            label=col,
+            color=STATUS_COLORS.get(col, "#888"),
+            edgecolor=BG_COLOR,
+            linewidth=0.8,
+        )
+
+    ax.set_xticks(x)
+    ax.set_xticklabels(pivot.index, rotation=25, ha="right", fontsize=7, fontfamily=FONT_MONO)
+    ax.set_ylabel("Tests", color=TEXT_COLOR, fontsize=9)
+    ax.legend(
+        fontsize=7,
+        labelcolor=TEXT_COLOR,
+        facecolor=PANEL_COLOR,
+        edgecolor=GRID_COLOR,
+    )
+    _style_ax(ax, "Results by Module")
+
+
+def chart_duration_histogram(df: pd.DataFrame, ax: plt.Axes):
+    durations = df.loc[df["outcome"] != "skipped", "duration"].values * 1000
+
+    if len(durations) == 0:
+        ax.text(0.5, 0.5, "No data", ha="center", va="center", color=TEXT_COLOR)
+        _style_ax(ax, "Test Duration (ms)")
+        return
+
+    mean_ms = float(np.mean(durations))
+    median_ms = float(np.median(durations))
+    p95_ms = float(np.percentile(durations, 95))
+
+    ax.hist(durations, bins=20, color="#5cb8ff", edgecolor=BG_COLOR, linewidth=0.6, alpha=0.85)
+    ax.axvline(mean_ms, color="#3ddc84", linewidth=1.5, linestyle="--", label=f"Mean {mean_ms:.1f} ms")
+    ax.axvline(median_ms, color="#e0c84a", linewidth=1.5, linestyle=":", label=f"Median {median_ms:.1f} ms")
+    ax.axvline(p95_ms, color="#ff5c5c", linewidth=1.5, linestyle="-.", label=f"P95 {p95_ms:.1f} ms")
+
+    ax.set_xlabel("ms", color=TEXT_COLOR, fontsize=9)
+    ax.set_ylabel("Tests", color=TEXT_COLOR, fontsize=9)
+    ax.legend(fontsize=7, labelcolor=TEXT_COLOR, facecolor=PANEL_COLOR, edgecolor=GRID_COLOR)
+    _style_ax(ax, "Test Duration (ms)")
+
+
+def chart_slowest_tests(df: pd.DataFrame, ax: plt.Axes):
+    top = (
+        df[df["outcome"] != "skipped"]
+        .nlargest(10, "duration")
+        .copy()
+    )
+    top["label"] = top["class"] + "::" + top["name"]
+    top["duration"] = top["duration"] * 1000
+
+    colors = [STATUS_COLORS.get(o, "#888") for o in top["outcome"]]
+    bars = ax.barh(top["label"], top["duration"], color=colors, edgecolor=BG_COLOR, linewidth=0.6)
+
+    ax.set_xlabel("ms", color=TEXT_COLOR, fontsize=9)
+    ax.tick_params(axis="y", labelsize=7)
+    ax.invert_yaxis()
+    _style_ax(ax, "Top 10 Slowest Tests")
+
+
+def chart_stats_table(df: pd.DataFrame, summary: dict, ax: plt.Axes):
+    ax.set_facecolor(PANEL_COLOR)
+    ax.axis("off")
+
+    total = len(df)
+    passed = summary.get("passed", 0)
+    failed = summary.get("failed", 0)
+    skipped = summary.get("skipped", 0)
+    duration = df["duration"].sum() * 1000
+
+    durations = df.loc[df["outcome"] != "skipped", "duration"].values * 1000
+
+    rows = [
+        ["Total tests", str(total)],
+        ["Passed", str(passed)],
+        ["Failed", str(failed)],
+        ["Skipped", str(skipped)],
+        ["Pass rate", f"{passed / total * 100:.1f}%" if total else "—"],
+        ["Total time", f"{duration:.0f} ms"],
+        ["Mean duration", f"{np.mean(durations):.1f} ms" if len(durations) else "—"],
+        ["Median", f"{np.median(durations):.1f} ms" if len(durations) else "—"],
+        ["P95", f"{np.percentile(durations, 95):.1f} ms" if len(durations) else "—"],
+    ]
+
+    table = ax.table(
+        cellText=rows,
+        colLabels=["Metric", "Value"],
+        cellLoc="left",
+        loc="center",
+        colWidths=[0.6, 0.4],
+    )
+    table.auto_set_font_size(False)
+    table.set_fontsize(9)
+
+    for (row, col), cell in table.get_celld().items():
+        cell.set_facecolor("#2e2e2e" if row % 2 == 0 else PANEL_COLOR)
+        cell.set_edgecolor(GRID_COLOR)
+        cell.set_text_props(color=TEXT_COLOR, fontfamily=FONT_MONO)
+
+    ax.set_title("Summary", color=TEXT_COLOR, fontsize=11, fontweight="bold", pad=10)
+
+
+def generate(output_path: Path | None = None) -> Path:
+    CHARTS_DIR.mkdir(parents=True, exist_ok=True)
+    output_path = output_path or CHARTS_DIR / "report.png"
+
+    df, summary = load_data()
+
+    fig = plt.figure(figsize=(18, 14), facecolor=BG_COLOR)
+    fig.suptitle(
+        "Legal AI Assistant — Test Report",
+        fontsize=16, fontweight="bold", color=TEXT_COLOR,
+        y=0.98, fontfamily=FONT_MONO,
+    )
+
+    gs = fig.add_gridspec(2, 3, hspace=0.45, wspace=0.35,
+                          top=0.93, bottom=0.05, left=0.06, right=0.97)
+
+    chart_overall_status(df, fig.add_subplot(gs[0, 0]))
+    chart_by_module(df, fig.add_subplot(gs[0, 1:]))
+    chart_duration_histogram(df, fig.add_subplot(gs[1, 0]))
+    chart_slowest_tests(df, fig.add_subplot(gs[1, 1]))
+    chart_stats_table(df, summary, fig.add_subplot(gs[1, 2]))
+
+    fig.savefig(output_path, dpi=150, bbox_inches="tight", facecolor=BG_COLOR)
+    plt.close(fig)
+    return output_path
+
+
+if __name__ == "__main__":
+    out = generate()
+    print(f"Chart saved: {out}")
\ No newline at end of file
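
load_data() above is written against the pytest-json-report output: a top-level "tests" list carrying a nodeid, an outcome, and per-phase timings, plus a "summary" block. A trimmed sketch of the expected testing/results.json, expressed as a Python literal (the field names are the plugin's; the values are made up):

# Illustrative shape of testing/results.json as produced by pytest-json-report.
# Only the fields that load_data() actually reads are shown.
EXAMPLE_RESULTS = {
    "summary": {"passed": 41, "failed": 2, "skipped": 5},
    "tests": [
        {
            "nodeid": "tests/test_schemas.py::TestCourtByID::test_digit_gets_prefix",
            "outcome": "passed",
            "call": {"duration": 0.0021},  # seconds; charts convert to ms
        },
    ],
}
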
diff --git a/testing/run_tests.py b/testing/run_tests.py
new file mode 100644
index 0000000..81764a4
--- /dev/null
+++ b/testing/run_tests.py
@@ -0,0 +1,58 @@
+import subprocess
+import sys
+from pathlib import Path
+
+ROOT = Path(__file__).parent.parent
+TESTS_DIR = Path(__file__).parent
+RESULTS_JSON = TESTS_DIR / "results.json"
+REPORT_HTML = TESTS_DIR / "charts" / "report.html"
+CHARTS_DIR = TESTS_DIR / "charts"
+
+
+def run_pytest() -> int:
+    CHARTS_DIR.mkdir(parents=True, exist_ok=True)
+
+    args = [
+        sys.executable, "-m", "pytest",
+
+        str(TESTS_DIR / "unit"),
+        str(TESTS_DIR / "integration"),
+        str(TESTS_DIR / "e2e"),
+        str(TESTS_DIR / "llm"),
+
+        "--json-report",
+        f"--json-report-file={RESULTS_JSON}",
+
+        f"--html={REPORT_HTML}",
+        "--self-contained-html",
+
+        "--cov=api",
+        "--cov=core",
+        "--cov-report=term-missing",
+        f"--cov-report=html:{CHARTS_DIR / 'coverage'}",
+
+        "-p", "no:terminal",
+        "--tb=short",
+        "-q",
+    ]
+
+    return subprocess.run(args, cwd=str(ROOT), check=False).returncode
+
+
+def run_charts():
+    from testing.reports.charts import generate
+    out = generate()
+    print(f"Charts saved: {out}")
+
+
+if __name__ == "__main__":
+    print("Start testing")
+
+    code = run_pytest()
+
+    if RESULTS_JSON.exists():
+        run_charts()
+
+    print("Stop testing")
+
+    sys.exit(code)
\ No newline at end of file
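
The flags passed to pytest above imply several plugins beyond pytest itself. The requirements.txt change in this diff is binary, so the exact pins are not visible, but the runner needs at least the following (standard PyPI names; versions omitted, treat as an assumed dev-dependency list):

pytest
pytest-asyncio      # the @pytest.mark.asyncio tests
pytest-json-report  # --json-report / results.json
pytest-html         # --html / --self-contained-html
pytest-cov          # --cov and the coverage HTML output
respx               # httpx mocking used by the unit tests
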
diff --git a/testing/tests/test_api.py b/testing/tests/test_api.py
new file mode 100644
index 0000000..36be2de
--- /dev/null
+++ b/testing/tests/test_api.py
@@ -0,0 +1,152 @@
+import pytest
+import httpx
+
+from api.config import JUSTICE_API_BASE, HTTP_TIMEOUT
+
+HEADERS = {
+    "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) Chrome/124.0.0.0 Safari/537.36",
+    "Accept": "application/json, text/plain, */*",
+    "Accept-Language": "sk-SK,sk;q=0.9",
+    "Referer": "https://obcan.justice.sk/",
+    "Origin": "https://obcan.justice.sk",
+}
+
+
+def api_available() -> bool:
+    try:
+        r = httpx.get(f"{JUSTICE_API_BASE}/v1/sud", params={"size": 1}, headers=HEADERS, timeout=5)
+        return r.status_code == 200
+    except Exception:
+        return False
+
+
+skip_if_offline = pytest.mark.skipif(
+    not api_available(),
+    reason="justice.sk API is not reachable"
+)
+
+
+@pytest.fixture(scope="module")
+def client():
+    with httpx.Client(headers=HEADERS, timeout=HTTP_TIMEOUT, follow_redirects=True) as c:
+        yield c
+
+
+@skip_if_offline
+class TestCourtsEndpoint:
+
+    def test_returns_200(self, client):
+        r = client.get(f"{JUSTICE_API_BASE}/v1/sud", params={"size": 5})
+        assert r.status_code == 200
+
+    def test_response_has_content_key(self, client):
+        r = client.get(f"{JUSTICE_API_BASE}/v1/sud", params={"size": 5})
+        data = r.json()
+        assert "content" in data
+
+    def test_total_elements_is_positive(self, client):
+        r = client.get(f"{JUSTICE_API_BASE}/v1/sud", params={"size": 1})
+        data = r.json()
+        assert data.get("totalElements", 0) > 0
+
+    def test_court_by_id_returns_valid_record(self, client):
+        r = client.get(f"{JUSTICE_API_BASE}/v1/sud/sud_175")
+        assert r.status_code == 200
+        data = r.json()
+        assert "id" in data or "nazov" in data
+
+    def test_court_autocomplete_returns_list(self, client):
+        r = client.get(f"{JUSTICE_API_BASE}/v1/sud/autocomplete", params={"query": "Bratislava", "limit": 5})
+        assert r.status_code == 200
+        assert isinstance(r.json(), list)
+
+    def test_nonexistent_court_returns_404(self, client):
+        r = client.get(f"{JUSTICE_API_BASE}/v1/sud/sud_999999999")
+        assert r.status_code == 404
+
+
+@skip_if_offline
+class TestJudgesEndpoint:
+
+    def test_judge_search_returns_200(self, client):
+        r = client.get(f"{JUSTICE_API_BASE}/v1/sudca", params={"size": 5})
+        assert r.status_code == 200
+
+    def test_judge_autocomplete_returns_results(self, client):
+        r = client.get(f"{JUSTICE_API_BASE}/v1/sudca/autocomplete", params={"query": "Novák", "limit": 10})
+        assert r.status_code == 200
+
+    def test_judge_search_by_kraj(self, client):
+        r = client.get(f"{JUSTICE_API_BASE}/v1/sudca", params={
+            "krajFacetFilter": "Bratislavský kraj",
+            "size": 5
+        })
+        assert r.status_code == 200
+
+    def test_judge_search_pagination_page_zero(self, client):
+        r = client.get(f"{JUSTICE_API_BASE}/v1/sudca", params={"page": 0, "size": 10})
+        data = r.json()
+        assert "content" in data
+
+
+@skip_if_offline
+class TestDecisionsEndpoint:
+
+    def test_decision_search_returns_200(self, client):
+        r = client.get(f"{JUSTICE_API_BASE}/v1/rozhodnutie", params={"size": 3})
+        assert r.status_code == 200
+
+    def test_decision_search_with_date_range(self, client):
+        r = client.get(f"{JUSTICE_API_BASE}/v1/rozhodnutie", params={
+            "vydaniaOd": "01.01.2023",
+            "vydaniaDo": "31.12.2023",
+            "size": 3,
+        })
+        assert r.status_code == 200
+
+    def test_decision_autocomplete(self, client):
+        r = client.get(f"{JUSTICE_API_BASE}/v1/rozhodnutie/autocomplete", params={"query": "Rozsudok", "limit": 5})
+        assert r.status_code == 200
+
+
+@skip_if_offline
+class TestContractsEndpoint:
+
+    def test_contract_search_returns_200(self, client):
+        r = client.get(f"{JUSTICE_API_BASE}/v1/zmluvy", params={"size": 5})
+        assert r.status_code == 200
+
+    def test_contract_search_by_typ_dokumentu(self, client):
+        r = client.get(f"{JUSTICE_API_BASE}/v1/zmluvy", params={
+            "typDokumentuFacetFilter": "ZMLUVA",
+            "size": 3,
+        })
+        assert r.status_code == 200
+
+
+@skip_if_offline
+class TestCivilProceedingsEndpoint:
+
+    def test_civil_proceedings_returns_200(self, client):
+        r = client.get(f"{JUSTICE_API_BASE}/v1/obcianPojednavania", params={"size": 3})
+        assert r.status_code == 200
+
+    def test_civil_proceedings_date_filter(self, client):
+        r = client.get(f"{JUSTICE_API_BASE}/v1/obcianPojednavania", params={
+            "pojednavaniaOd": "01.01.2024",
+            "pojednavaniaDo": "31.01.2024",
+            "size": 3,
+        })
+        assert r.status_code == 200
+
+
+@skip_if_offline
+class TestAdminProceedingsEndpoint:
+
+    def test_admin_proceedings_returns_200(self, client):
+        r = client.get(f"{JUSTICE_API_BASE}/v1/spravneKonanie", params={"size": 3})
+        assert r.status_code == 200
+
+    def test_admin_proceedings_autocomplete(self, client):
+        r = client.get(f"{JUSTICE_API_BASE}/v1/spravneKonanie/autocomplete", params={"query": "test", "limit": 5})
+        assert r.status_code == 200
\ No newline at end of file
diff --git a/testing/tests/test_fetch.py b/testing/tests/test_fetch.py
new file mode 100644
index 0000000..b0ead81
--- /dev/null
+++ b/testing/tests/test_fetch.py
@@ -0,0 +1,113 @@
+import pytest
+import pytest_asyncio
+import httpx
+import respx
+
+from api.fetch_api_data import fetch_api_data, set_log_callback, _cache
+
+
+BASE = "https://obcan.justice.sk/pilot/api/ress-isu-service"
+
+
+@pytest.fixture(autouse=True)
+def clear_cache():
+    _cache.clear()
+    yield
+    _cache.clear()
+
+
+@pytest.mark.asyncio
+class TestFetchApiData:
+
+    @respx.mock
+    async def test_successful_request_returns_dict(self):
+        url = f"{BASE}/v1/sud"
+        respx.get(url).mock(return_value=httpx.Response(200, json={"content": [], "totalElements": 0}))
+
+        result = await fetch_api_data(icon="", url=url, params={})
+
+        assert isinstance(result, dict)
+        assert "totalElements" in result
+
+    @respx.mock
+    async def test_cache_hit_on_second_call(self):
+        url = f"{BASE}/v1/sud"
+        mock = respx.get(url).mock(return_value=httpx.Response(200, json={"content": []}))
+
+        await fetch_api_data(icon="", url=url, params={})
+        await fetch_api_data(icon="", url=url, params={})
+
+        assert mock.call_count == 1
+
+    @respx.mock
+    async def test_http_404_returns_error_dict(self):
+        url = f"{BASE}/v1/sud/sud_99999"
+        respx.get(url).mock(return_value=httpx.Response(404, text="Not Found"))
+
+        result = await fetch_api_data(icon="", url=url, params={})
+
+        assert result["error"] == "http_error"
+        assert result["status_code"] == 404
+
+    @respx.mock
+    async def test_http_500_returns_error_dict(self):
+        url = f"{BASE}/v1/sud"
+        respx.get(url).mock(return_value=httpx.Response(500, text="Server Error"))
+
+        result = await fetch_api_data(icon="", url=url, params={})
+
+        assert result["error"] == "http_error"
+        assert result["status_code"] == 500
+
+    @respx.mock
+    async def test_remove_keys_strips_specified_fields(self):
+        url = f"{BASE}/v1/sud/sud_1"
+        respx.get(url).mock(return_value=httpx.Response(200, json={"name": "Súd", "foto": "base64data"}))
+
+        result = await fetch_api_data(icon="", url=url, params={}, remove_keys=["foto"])
+
+        assert "foto" not in result
+        assert "name" in result
+
+    @respx.mock
+    async def test_remove_keys_missing_key_no_error(self):
+        url = f"{BASE}/v1/sud/sud_1"
+        respx.get(url).mock(return_value=httpx.Response(200, json={"name": "Súd"}))
+
+        result = await fetch_api_data(icon="", url=url, params={}, remove_keys=["foto", "nonexistent"])
+
+        assert result["name"] == "Súd"
+
+    @respx.mock
+    async def test_log_callback_is_called(self):
+        url = f"{BASE}/v1/sud"
+        respx.get(url).mock(return_value=httpx.Response(200, json={}))
+
+        log_lines = []
+        set_log_callback(lambda line: log_lines.append(line))
+
+        await fetch_api_data(icon="", url=url, params={})
+
+        set_log_callback(None)
+        assert len(log_lines) > 0
+
+    @respx.mock
+    async def test_params_are_passed_in_request(self):
+        url = f"{BASE}/v1/sud"
+        mock = respx.get(url).mock(return_value=httpx.Response(200, json={}))
+
+        await fetch_api_data(icon="", url=url, params={"query": "Bratislava", "size": 10})
+
+        assert mock.called
+        sent_params = dict(mock.calls[0].request.url.params)
+        assert sent_params["query"] == "Bratislava"
+        assert sent_params["size"] == "10"
+
+    @respx.mock
+    async def test_connect_error_returns_error_dict(self):
+        url = f"{BASE}/v1/sud"
+        respx.get(url).mock(side_effect=httpx.ConnectError("Connection refused"))
+
+        result = await fetch_api_data(icon="", url=url, params={})
+
+        assert result["error"] == "request_error"
\ No newline at end of file
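
api/fetch_api_data.py itself is not touched by this diff, but the tests above pin its contract down precisely. The interface they encode, summarized as a commented sketch (inferred from the assertions, not from the module source):

# Contract of api.fetch_api_data as exercised by these tests (inferred):
#   _cache: dict                     -- keyed per URL + params; cleared between tests
#   set_log_callback(cb | None)      -- cb receives one formatted line per request
#   async def fetch_api_data(icon, url, params, remove_keys=None) -> dict
#       * HTTP 200      -> parsed JSON body, with every key in remove_keys dropped
#       * HTTP 4xx/5xx  -> {"error": "http_error", "status_code": <int>, ...}
#       * transport err -> {"error": "request_error", ...}
#       * a repeated (url, params) call is served from _cache (no second request)
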
diff --git a/testing/tests/test_llm_compare.py b/testing/tests/test_llm_compare.py
new file mode 100644
index 0000000..13a3368
--- /dev/null
+++ b/testing/tests/test_llm_compare.py
@@ -0,0 +1,124 @@
+import time
+import pytest
+from openai import OpenAI
+
+from core.config import (
+    OLLAMA_BASE_URL, OLLAMA_API_KEY, OPENAI_BASE_URL, OPENAI_API_KEY,
+    LLM_TIMEOUT, OLLAMA_MODELS, OPENAI_MODELS
+)
+from core.system_prompt import get_system_prompt
+
+MODELS = OLLAMA_MODELS | OPENAI_MODELS
+
+TEST_QUERIES = [
+    {
+        "id": "court_search",
+        "query": "Nájdi súdy v Bratislavskom kraji.",
+        "expected_keywords": ["súd", "bratislava", "kraj"],
+    },
+    {
+        "id": "judge_search",
+        "query": "Vyhľadaj sudcu Novák.",
+        "expected_keywords": ["sudca", "novák", "novak"],
+    },
+    {
+        "id": "no_legal_advice",
+        "query": "Mám spor so zamestnávateľom, čo mám robiť?",
+        "forbidden_keywords": ["musíte", "odporúčam vám podať žalobu", "právne poradenstvo"],
+        "expected_keywords": ["api", "ministerstvo", "nie som právnik", "právny poradca"],
+    },
+    {
+        "id": "slovak_response",
+        "query": "What courts exist in Slovakia?",
+        "expected_keywords": ["súd", "slovensko", "kraj"],
+    },
+]
+
+
+def ollama_available() -> bool:
+    try:
+        client = OpenAI(base_url=OLLAMA_BASE_URL, api_key=OLLAMA_API_KEY)
+        client.models.list()
+        return True
+    except Exception:
+        return False
+
+
+skip_if_no_ollama = pytest.mark.skipif(
+    not ollama_available(),
+    reason="Ollama is not running"
+)
+
+
+def query_model(model: str, user_message: str) -> tuple[str, float]:
+    # OpenAI-hosted models must not be routed to the local Ollama endpoint.
+    if model in OPENAI_MODELS:
+        client = OpenAI(base_url=OPENAI_BASE_URL, api_key=OPENAI_API_KEY, timeout=LLM_TIMEOUT)
+    else:
+        client = OpenAI(base_url=OLLAMA_BASE_URL, api_key=OLLAMA_API_KEY, timeout=LLM_TIMEOUT)
+    start = time.perf_counter()
+    response = client.chat.completions.create(
+        model=model,
+        messages=[
+            {"role": "system", "content": get_system_prompt(model)},
+            {"role": "user", "content": user_message},
+        ],
+        temperature=0.0,
+        max_tokens=512,
+    )
+    elapsed = time.perf_counter() - start
+    text = response.choices[0].message.content or ""
+    return text, elapsed
+
+
+llm_results: dict[str, list[float]] = {m: [] for m in MODELS}
+
+
+@skip_if_no_ollama
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("case", TEST_QUERIES, ids=[c["id"] for c in TEST_QUERIES])
+class TestLLMResponses:
+
+    def test_response_is_not_empty(self, model, case):
+        text, _ = query_model(model, case["query"])
+        assert len(text.strip()) > 0
+
+    def test_response_in_slovak(self, model, case):
+        text, _ = query_model(model, case["query"])
+        slovak_markers = ["je", "sú", "som", "nie", "súd", "sudca", "kraj", "ale", "alebo", "pre"]
+        assert any(m in text.lower() for m in slovak_markers)
+
+    def test_expected_keywords_present(self, model, case):
+        if "expected_keywords" not in case:
+            pytest.skip("No expected_keywords defined")
+        text, _ = query_model(model, case["query"])
+        assert any(kw.lower() in text.lower() for kw in case["expected_keywords"])
+
+    def test_forbidden_keywords_absent(self, model, case):
+        if "forbidden_keywords" not in case:
+            pytest.skip("No forbidden_keywords defined")
+        text, _ = query_model(model, case["query"])
+        for kw in case["forbidden_keywords"]:
+            assert kw.lower() not in text.lower(), f"Forbidden keyword found: {kw}"
+
+    def test_response_time_under_threshold(self, model, case):
+        _, elapsed = query_model(model, case["query"])
+        assert elapsed < float(LLM_TIMEOUT), f"Response took {elapsed:.1f}s"
+
+    def test_response_length_reasonable(self, model, case):
+        text, _ = query_model(model, case["query"])
+        assert 10 < len(text) < 4000
+
+
+@skip_if_no_ollama
+@pytest.mark.parametrize("model", MODELS)
+class TestLLMBenchmark:
+
+    def test_collect_benchmark_data(self, model):
+        """Collects per-model timing data in llm_results for ad-hoc comparison."""
+        times = []
+        for case in TEST_QUERIES:
+            _, elapsed = query_model(model, case["query"])
+            times.append(elapsed)
+        llm_results[model].extend(times)
+        assert len(times) == len(TEST_QUERIES)
\ No newline at end of file
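
Each of the six tests in TestLLMResponses re-queries the model, so one (model, case) pair costs up to six live completions. Since temperature is pinned at 0.0 the responses are close to deterministic, and a memoized wrapper would cut the run time roughly six-fold; a possible sketch (an optional optimization, not part of the diff):

from functools import lru_cache

@lru_cache(maxsize=None)
def query_model_cached(model: str, user_message: str) -> tuple[str, float]:
    # Reuses the first response and its latency for repeat (model, query) pairs.
    return query_model(model, user_message)
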
diff --git a/testing/tests/test_schemas.py b/testing/tests/test_schemas.py
new file mode 100644
index 0000000..e144338
--- /dev/null
+++ b/testing/tests/test_schemas.py
@@ -0,0 +1,168 @@
+import pytest
+from pydantic import ValidationError
+
+from api.schemas import (
+    CourtByID,
+    JudgeByID,
+    AdminProceedingsByID,
+    CourtSearch,
+    JudgeSearch,
+    DecisionSearch,
+    ContractSearch,
+    CivilProceedingsSearch,
+    CourtAutocomplete,
+    JudgeAutocomplete,
+)
+
+
+class TestCourtByID:
+
+    def test_digit_gets_prefix(self):
+        assert CourtByID(id="175").id == "sud_175"
+
+    def test_already_prefixed_unchanged(self):
+        assert CourtByID(id="sud_175").id == "sud_175"
+
+    def test_strips_whitespace(self):
+        assert CourtByID(id=" 42 ").id == "sud_42"
+
+    def test_single_digit(self):
+        assert CourtByID(id="1").id == "sud_1"
+
+    def test_large_number(self):
+        assert CourtByID(id="9999").id == "sud_9999"
+
+    def test_non_numeric_string_unchanged(self):
+        assert CourtByID(id="some_string").id == "some_string"
+
+    def test_whitespace_digit_combination(self):
+        assert CourtByID(id=" 7 ").id == "sud_7"
+
+
+class TestJudgeByID:
+
+    def test_digit_gets_prefix(self):
+        assert JudgeByID(id="1").id == "sudca_1"
+
+    def test_already_prefixed_unchanged(self):
+        assert JudgeByID(id="sudca_99").id == "sudca_99"
+
+    def test_strips_whitespace(self):
+        assert JudgeByID(id=" 7 ").id == "sudca_7"
+
+    def test_large_number(self):
+        assert JudgeByID(id="12345").id == "sudca_12345"
+
+    def test_non_numeric_string_unchanged(self):
+        assert JudgeByID(id="sudca_abc").id == "sudca_abc"
+
+
+class TestAdminProceedingsByID:
+
+    def test_digit_gets_prefix(self):
+        assert AdminProceedingsByID(id="103").id == "spravneKonanie_103"
+
+    def test_already_prefixed_unchanged(self):
+        assert AdminProceedingsByID(id="spravneKonanie_103").id == "spravneKonanie_103"
+
+    def test_strips_whitespace(self):
+        assert AdminProceedingsByID(id=" 55 ").id == "spravneKonanie_55"
+
+    def test_single_digit(self):
+        assert AdminProceedingsByID(id="5").id == "spravneKonanie_5"
+
+    def test_non_numeric_string_unchanged(self):
+        assert AdminProceedingsByID(id="custom_id").id == "custom_id"
+
+
+class TestPaginationDefaults:
+
+    def test_court_search_defaults_are_none(self):
+        obj = CourtSearch()
+        assert obj.page is None
+        assert obj.size is None
+
+    def test_sort_direction_default_asc(self):
+        assert CourtSearch().sortDirection == "ASC"
+
+    def test_sort_direction_desc_accepted(self):
+        assert CourtSearch(sortDirection="DESC").sortDirection == "DESC"
+
+    def test_sort_direction_invalid_rejected(self):
+        with pytest.raises(ValidationError):
+            CourtSearch(sortDirection="INVALID")
+
+    def test_page_cannot_be_negative(self):
+        with pytest.raises(ValidationError):
+            CourtSearch(page=-1)
+
+    def test_size_cannot_be_zero(self):
+        with pytest.raises(ValidationError):
+            CourtSearch(size=0)
+
+    def test_size_one_accepted(self):
+        assert CourtSearch(size=1).size == 1
+
+    def test_page_zero_accepted(self):
+        assert CourtSearch(page=0).page == 0
+
+
+class TestFacetFilters:
+
+    def test_court_search_facet_type_list(self):
+        obj = CourtSearch(typSuduFacetFilter=["Okresný súd", "Krajský súd"])
+        assert len(obj.typSuduFacetFilter) == 2
+
+    def test_judge_search_facet_kraj(self):
+        obj = JudgeSearch(krajFacetFilter=["Bratislavský kraj"])
+        assert obj.krajFacetFilter == ["Bratislavský kraj"]
+
+    def test_decision_search_forma_filter(self):
+        obj = DecisionSearch(formaRozhodnutiaFacetFilter=["Rozsudok"])
+        assert "Rozsudok" in obj.formaRozhodnutiaFacetFilter
+
+    def test_contract_search_typ_dokumentu(self):
+        obj = ContractSearch(typDokumentuFacetFilter=["ZMLUVA", "DODATOK"])
+        assert len(obj.typDokumentuFacetFilter) == 2
+
+    def test_civil_proceedings_usek_filter(self):
+        obj = CivilProceedingsSearch(usekFacetFilter=["C", "O"])
+        assert "C" in obj.usekFacetFilter
+
+
+class TestAutocomplete:
+
+    def test_court_autocomplete_limit_min_one(self):
+        with pytest.raises(ValidationError):
+            CourtAutocomplete(limit=0)
+
+    def test_court_autocomplete_limit_valid(self):
+        assert CourtAutocomplete(limit=5).limit == 5
+
+    def test_judge_autocomplete_guid_sud(self):
+        obj = JudgeAutocomplete(guidSud="sud_100", limit=10)
+        assert obj.guidSud == "sud_100"
+
+    def test_autocomplete_empty_query_accepted(self):
+        assert CourtAutocomplete().query is None
+
+    def test_autocomplete_query_string(self):
+        assert JudgeAutocomplete(query="Novák").query == "Novák"
+
+
+class TestModelDumpExcludeNone:
+
+    def test_excludes_none_fields(self):
+        dumped = CourtSearch(query="Bratislava").model_dump(exclude_none=True)
+        assert "query" in dumped
+        assert "page" not in dumped
+        assert "size" not in dumped
+
+    def test_full_params_included(self):
+        dumped = JudgeSearch(query="Novák", page=0, size=20).model_dump(exclude_none=True)
+        assert dumped["query"] == "Novák"
+        assert dumped["page"] == 0
+        assert dumped["size"] == 20
+
+    def test_empty_schema_dumps_empty_dict(self):
+        assert CourtSearch().model_dump(exclude_none=True) == {}
\ No newline at end of file
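
The ID classes exercised above encode one normalization rule: a bare numeric ID gains the entity prefix, anything else passes through untouched. api/schemas.py is not part of this diff, so as an assumption, a validator shape consistent with these assertions would look like:

from pydantic import BaseModel, field_validator

class CourtByID(BaseModel):
    id: str

    @field_validator("id")
    @classmethod
    def _normalize(cls, v: str) -> str:
        v = v.strip()
        # " 175 " -> "sud_175"; "sud_175" and arbitrary strings pass through.
        return f"sud_{v}" if v.isdigit() else v
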
diff --git a/testing/tests/test_sys_prompt.py b/testing/tests/test_sys_prompt.py
new file mode 100644
index 0000000..2c01b3f
--- /dev/null
+++ b/testing/tests/test_sys_prompt.py
@@ -0,0 +1,123 @@
+import pytest
+from core.system_prompt import get_system_prompt
+from core.config import OLLAMA_MODELS, OPENAI_MODELS
+
+MODELS = OLLAMA_MODELS | OPENAI_MODELS
+
+
+@pytest.fixture(params=MODELS)
+def prompt(request):
+    return get_system_prompt(request.param)
+
+
+@pytest.fixture(params=MODELS)
+def prompt_lower(request):
+    return get_system_prompt(request.param).lower()
+
+
+class TestPromptContainsModelName:
+
+    def test_model_name_appears_in_prompt(self, prompt, request):
+        model = request.node.callspec.params["prompt"]
+        assert model in prompt
+
+
+class TestRequiredSections:
+
+    def test_has_role_section(self, prompt_lower):
+        assert "role" in prompt_lower
+
+    def test_has_operational_constraints(self, prompt_lower):
+        assert "constraint" in prompt_lower or "not allowed" in prompt_lower or "do not" in prompt_lower
+
+    def test_has_workflow_steps(self, prompt_lower):
+        assert "step" in prompt_lower or "workflow" in prompt_lower
+
+    def test_has_error_recovery(self, prompt_lower):
+        assert "error" in prompt_lower or "not found" in prompt_lower
+
+    def test_has_response_format(self, prompt_lower):
+        assert "format" in prompt_lower or "response" in prompt_lower
+
+
+class TestSupportedDomains:
+
+    def test_courts_mentioned(self, prompt_lower):
+        assert "court" in prompt_lower or "súd" in prompt_lower or "sud" in prompt_lower
+
+    def test_judges_mentioned(self, prompt_lower):
+        assert "judge" in prompt_lower or "sudca" in prompt_lower or "sudcovia" in prompt_lower
+
+    def test_decisions_mentioned(self, prompt_lower):
+        assert "decision" in prompt_lower or "rozhodnut" in prompt_lower
+
+    def test_contracts_mentioned(self, prompt_lower):
+        assert "contract" in prompt_lower or "zmluv" in prompt_lower
+
+    def test_civil_proceedings_mentioned(self, prompt_lower):
+        assert "civil" in prompt_lower or "obcian" in prompt_lower or "pojednavan" in prompt_lower
+
+    def test_admin_proceedings_mentioned(self, prompt_lower):
+        assert "admin" in prompt_lower or "spravne" in prompt_lower or "správne" in prompt_lower
+
+
+class TestConstraints:
+
+    def test_no_legal_advice_constraint(self, prompt_lower):
+        assert "legal advisor" in prompt_lower or "not a lawyer" in prompt_lower or "legal advice" in prompt_lower
+
+    def test_api_only_constraint(self, prompt_lower):
+        assert "api" in prompt_lower
+
+    def test_slovak_language_requirement(self, prompt_lower):
+        assert "slovak" in prompt_lower or "slovensk" in prompt_lower
+
+    def test_no_raw_json_rule(self, prompt_lower):
+        assert "json" in prompt_lower or "technical" in prompt_lower
+
+    def test_no_speculate_rule(self, prompt_lower):
+        assert "speculate" in prompt_lower or "infer" in prompt_lower or "gaps" in prompt_lower
+
+
+class TestPaginationRules:
+
+    def test_page_starts_at_zero_mentioned(self, prompt_lower):
+        assert "page" in prompt_lower and "0" in prompt_lower
+
+    def test_autocomplete_preferred_mentioned(self, prompt_lower):
+        assert "autocomplete" in prompt_lower
+
+
+class TestDateRules:
+
+    def test_date_format_dd_mm_yyyy_mentioned(self, prompt):
+        assert "DD.MM.YYYY" in prompt or "dd.mm.yyyy" in prompt.lower()
+
+    def test_civil_date_field_mentioned(self, prompt_lower):
+        assert "pojednavaniaod" in prompt_lower or "pojednavania" in prompt_lower
+
+    def test_decision_date_field_mentioned(self, prompt_lower):
+        assert "vydaniaod" in prompt_lower or "vydania" in prompt_lower
+
+
+class TestIDNormalizationRules:
+
+    def test_sud_prefix_mentioned(self, prompt_lower):
+        assert "sud_" in prompt_lower
+
+    def test_sudca_prefix_mentioned(self, prompt_lower):
+        assert "sudca_" in prompt_lower
+
+    def test_spravnekonanie_prefix_mentioned(self, prompt_lower):
+        assert "spravnekonanie_" in prompt_lower or "spravne" in prompt_lower
+
+
+class TestPromptLength:
+
+    def test_prompt_is_not_empty(self, prompt):
+        assert len(prompt.strip()) > 0
+
+    def test_prompt_has_minimum_length(self, prompt):
+        assert len(prompt) > 500
+
+    def test_prompt_has_reasonable_max_length(self, prompt):
+        assert len(prompt) < 50_000
\ No newline at end of file
diff --git a/testing/tests/test_tools.py b/testing/tests/test_tools.py
new file mode 100644
index 0000000..55cd4c9
--- /dev/null
+++ b/testing/tests/test_tools.py
@@ -0,0 +1,152 @@
+import pytest
+import httpx
+import respx
+
+from api.fetch_api_data import _cache
+from api.config import JUSTICE_API_BASE
+from api.tools import (
+    court_search,
+    court_id,
+    court_autocomplete,
+    judge_search,
+    judge_id,
+    judge_autocomplete,
+    decision_search,
+    contract_search,
+    civil_proceedings_search,
+    admin_proceedings_search,
+)
+
+
+EMPTY_LIST_RESPONSE = {"content": [], "totalElements": 0, "numberOfElements": 0}
+
+
+@pytest.fixture(autouse=True)
+def clear_cache():
+    _cache.clear()
+    yield
+    _cache.clear()
+
+
+def make_response(body: dict | None = None):
+    return httpx.Response(200, json=body or EMPTY_LIST_RESPONSE)
+
+
+@pytest.mark.asyncio
+class TestCourtTools:
+
+    @respx.mock
+    async def test_court_search_calls_correct_url(self):
+        mock = respx.get(f"{JUSTICE_API_BASE}/v1/sud").mock(return_value=make_response())
+        await court_search.on_invoke_tool(None, '{"query": "Bratislava"}')
+        assert mock.called
+
+    @respx.mock
+    async def test_court_id_calls_correct_url(self):
+        mock = respx.get(f"{JUSTICE_API_BASE}/v1/sud/sud_175").mock(
+            return_value=httpx.Response(200, json={"id": "sud_175", "foto": "data"})
+        )
+        await court_id.on_invoke_tool(None, '{"id": "175"}')
+        assert mock.called
+
+    @respx.mock
+    async def test_court_id_removes_foto_key(self):
+        respx.get(f"{JUSTICE_API_BASE}/v1/sud/sud_1").mock(
+            return_value=httpx.Response(200, json={"name": "Súd", "foto": "base64"})
+        )
+        result = await court_id.on_invoke_tool(None, '{"id": "1"}')
+        assert "foto" not in str(result)
+
+    @respx.mock
+    async def test_court_autocomplete_calls_correct_url(self):
+        mock = respx.get(f"{JUSTICE_API_BASE}/v1/sud/autocomplete").mock(return_value=make_response())
+        await court_autocomplete.on_invoke_tool(None, '{"query": "Kraj"}')
+        assert mock.called
+
+
+@pytest.mark.asyncio
+class TestJudgeTools:
+
+    @respx.mock
+    async def test_judge_search_calls_correct_url(self):
+        mock = respx.get(f"{JUSTICE_API_BASE}/v1/sudca").mock(return_value=make_response())
+        await judge_search.on_invoke_tool(None, '{"query": "Novák"}')
+        assert mock.called
+
+    @respx.mock
+    async def test_judge_id_normalizes_digit_id(self):
+        mock = respx.get(f"{JUSTICE_API_BASE}/v1/sudca/sudca_1").mock(
+            return_value=httpx.Response(200, json={"id": "sudca_1"})
+        )
+        await judge_id.on_invoke_tool(None, '{"id": "1"}')
+        assert mock.called
+
+    @respx.mock
+    async def test_judge_autocomplete_passes_guid_sud(self):
+        mock = respx.get(f"{JUSTICE_API_BASE}/v1/sudca/autocomplete").mock(return_value=make_response())
+        await judge_autocomplete.on_invoke_tool(None, '{"query": "Novák", "guidSud": "sud_100"}')
+        assert mock.called
+        params = dict(mock.calls[0].request.url.params)
+        assert params.get("guidSud") == "sud_100"
+
+
+@pytest.mark.asyncio
+class TestDecisionTools:
+
+    @respx.mock
+    async def test_decision_search_with_date_range(self):
+        mock = respx.get(f"{JUSTICE_API_BASE}/v1/rozhodnutie").mock(return_value=make_response())
+        await decision_search.on_invoke_tool(
+            None, '{"vydaniaOd": "01.01.2024", "vydaniaDo": "31.01.2024"}'
+        )
+        assert mock.called
+        params = dict(mock.calls[0].request.url.params)
+        assert params.get("vydaniaOd") == "01.01.2024"
+
+    @respx.mock
+    async def test_decision_search_with_guid_sudca(self):
+        mock = respx.get(f"{JUSTICE_API_BASE}/v1/rozhodnutie").mock(return_value=make_response())
+        await decision_search.on_invoke_tool(None, '{"guidSudca": "sudca_1"}')
+        params = dict(mock.calls[0].request.url.params)
+        assert params.get("guidSudca") == "sudca_1"
+
+
+@pytest.mark.asyncio
+class TestContractTools:
+
+    @respx.mock
+    async def test_contract_search_with_guid_sud(self):
+        mock = respx.get(f"{JUSTICE_API_BASE}/v1/zmluvy").mock(return_value=make_response())
+        await contract_search.on_invoke_tool(None, '{"guidSud": "sud_7"}')
+        params = dict(mock.calls[0].request.url.params)
+        assert params.get("guidSud") == "sud_7"
+
+    @respx.mock
+    async def test_contract_search_typ_dokumentu_filter(self):
+        mock = respx.get(f"{JUSTICE_API_BASE}/v1/zmluvy").mock(return_value=make_response())
+        await contract_search.on_invoke_tool(None, '{"typDokumentuFacetFilter": ["ZMLUVA"]}')
+        assert mock.called
+
+
+@pytest.mark.asyncio
+class TestCivilAndAdminTools:
+
+    @respx.mock
+    async def test_civil_proceedings_search(self):
+        mock = respx.get(f"{JUSTICE_API_BASE}/v1/obcianPojednavania").mock(return_value=make_response())
+        await civil_proceedings_search.on_invoke_tool(None, '{"query": "test"}')
+        assert mock.called
+
+    @respx.mock
+    async def test_admin_proceedings_search(self):
+        mock = respx.get(f"{JUSTICE_API_BASE}/v1/spravneKonanie").mock(return_value=make_response())
+        await admin_proceedings_search.on_invoke_tool(None, '{"query": "test"}')
+        assert mock.called
+
+    @respx.mock
+    async def test_civil_proceedings_date_params(self):
+        mock = respx.get(f"{JUSTICE_API_BASE}/v1/obcianPojednavania").mock(return_value=make_response())
+        await civil_proceedings_search.on_invoke_tool(
+            None, '{"pojednavaniaOd": "01.01.2024", "pojednavaniaDo": "31.01.2024"}'
+        )
+        assert mock.called
\ No newline at end of file
diff --git a/tests/__init__.py b/tests/__init__.py
deleted file mode 100644
index e69de29..0000000
diff --git a/tests/results/report_2026-03-16_06-26-14.png b/tests/results/report_2026-03-16_06-26-14.png
deleted file mode 100644
index d1dc63e..0000000
Binary files a/tests/results/report_2026-03-16_06-26-14.png and /dev/null differ
diff --git a/tests/runner.py b/tests/runner.py
deleted file mode 100644
index 36d92ac..0000000
--- a/tests/runner.py
+++ /dev/null
@@ -1,294 +0,0 @@
-import asyncio
-import sqlite3
-import time
-import sys
-from datetime import datetime
-from pathlib import Path
-
-sys.path.insert(0, str(Path(__file__).parent.parent))
-
-from api.config import JUSTICE_API_BASE
-import httpx
-
-DB_PATH = Path(__file__).parent / "test_queries.db"
-OUT_DIR = Path(__file__).parent / "results"
-REQUEST_TIMEOUT = 15
-CONCURRENT_LIMIT = 5
-DELAY_BETWEEN = 0.3
-
-HEADERS = {
-    "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) Chrome/124.0.0.0 Safari/537.36",
-    "Accept": "application/json, text/plain, */*",
-    "Accept-Language": "sk-SK,sk;q=0.9",
-    "Referer": "https://obcan.justice.sk/",
-    "Origin": "https://obcan.justice.sk",
-}
-
-TOOL_MAP = {
-    "judge": (f"{JUSTICE_API_BASE}/v1/sudca","search"),
-    "judge_id": (f"{JUSTICE_API_BASE}/v1/sudca","id"),
-    "judge_autocomplete": (f"{JUSTICE_API_BASE}/v1/sudca/autocomplete","autocomplete"),
-    "court": (f"{JUSTICE_API_BASE}/v1/sud","search"),
-    "court_id": (f"{JUSTICE_API_BASE}/v1/sud","id"),
-    "court_autocomplete": (f"{JUSTICE_API_BASE}/v1/sud/autocomplete","autocomplete"),
-    "decision": (f"{JUSTICE_API_BASE}/v1/rozhodnutie","search"),
-    "decision_id": (f"{JUSTICE_API_BASE}/v1/rozhodnutie","id"),
-    "decision_autocomplete": (f"{JUSTICE_API_BASE}/v1/rozhodnutie/autocomplete","autocomplete"),
-    "contract": (f"{JUSTICE_API_BASE}/v1/zmluvy","search"),
-    "contract_id": (f"{JUSTICE_API_BASE}/v1/zmluvy","id"),
-    "contract_autocomplete": (f"{JUSTICE_API_BASE}/v1/zmluvy/autocomplete","autocomplete"),
-    "civil_proceedings": (f"{JUSTICE_API_BASE}/v1/obcianPojednavania","search"),
-    "civil_proceedings_id": (f"{JUSTICE_API_BASE}/v1/obcianPojednavania","id"),
-    "civil_proceedings_autocomplete": (f"{JUSTICE_API_BASE}/v1/obcianPojednavania/autocomplete","autocomplete"),
-    "admin_proceedings": (f"{JUSTICE_API_BASE}/v1/spravneKonanie","search"),
-    "admin_proceedings_id": (f"{JUSTICE_API_BASE}/v1/spravneKonanie","id"),
-    "admin_proceedings_autocomplete": (f"{JUSTICE_API_BASE}/v1/spravneKonanie/autocomplete","autocomplete"),
-    "mimo_rozsah": (None,"skip"),
-}
-
-
-def load_queries(limit: int = None) -> list[dict]:
-    conn = sqlite3.connect(DB_PATH)
-    conn.row_factory = sqlite3.Row
-    sql = "SELECT * FROM test_queries ORDER BY id"
-    if limit:
-        sql += f" LIMIT {limit}"
-    rows = conn.execute(sql).fetchall()
-    conn.close()
-    return [dict(r) for r in rows]
-
-
-def pick_keyword(row: dict) -> str:
-    first = (row.get("expected_keywords") or "").split(",")[0].strip()
-    if len(first) > 2:
-        return first
-    words = [w for w in row["query_sk"].split() if len(w) > 3]
-    return " ".join(words[:2]) or row["query_sk"][:20]
-
-
-def count_results(data) -> int | None:
-    if isinstance(data, list):
-        return len(data)
-    if isinstance(data, dict):
-        for key in ("totalElements", "total", "count", "numberOfElements"):
-            if key in data:
-                return int(data[key])
-        if "content" in data:
-            return len(data["content"])
-        return 1
-    return None
-
-
-async def validate_one(
-    client: httpx.AsyncClient,
-    semaphore: asyncio.Semaphore,
-    row: dict,
-    idx: int,
-    total: int,
-) -> dict:
-    result = {
-        "id": row["id"],
-        "category": row["category"],
-        "difficulty": row["difficulty"],
-        "tool": row["expected_tool"],
-        "query": row["query_sk"],
-        "status": "pending",
-        "http_code": None,
-        "result_count": None,
-        "duration_ms": None,
-        "error": None,
-    }
-
-    endpoint, call_type = TOOL_MAP.get(row["expected_tool"], (None, "skip"))
-
-    print(f"Request {idx}/{total}")
-
-    if call_type == "skip":
-        result["status"] = "SKIP"
-        return result
-
-    if call_type == "id":
-        id_val = (row.get("expected_keywords") or "").split(",")[0].strip()
-        url, params = f"{endpoint}/{id_val}", {}
-    elif call_type == "autocomplete":
-        url, params = endpoint, {"query": pick_keyword(row), "limit": 5}
-    else:
-        url, params = endpoint, {"query": pick_keyword(row), "size": 5}
-
-    async with semaphore:
-        await asyncio.sleep(DELAY_BETWEEN)
-        t0 = time.perf_counter()
-        try:
-            resp = await client.get(url, params=params, headers=HEADERS, timeout=REQUEST_TIMEOUT)
-            result["duration_ms"] = round((time.perf_counter() - t0) * 1000)
-            result["http_code"] = resp.status_code
-
-            if resp.status_code == 200:
-                try:
-                    cnt = count_results(resp.json())
-                    result["result_count"] = cnt
-                    result["status"] = "OK" if cnt and cnt > 0 else "EMPTY"
-                except Exception:
-                    result["status"] = "OK"
-            elif resp.status_code == 404:
-                result["status"] = "NOT_FOUND"
-            elif resp.status_code == 403:
-                result["status"] = "FORBIDDEN"
-            else:
-                result["status"] = "HTTP_ERROR"
-                result["error"] = f"HTTP {resp.status_code}"
-
-        except httpx.TimeoutException:
-            result["status"] = "TIMEOUT"
-            result["duration_ms"] = REQUEST_TIMEOUT * 1000
-        except httpx.ConnectError as e:
-            result["status"] = "CONN_ERROR"
-            result["error"] = str(e)
-        except Exception as e:
-            result["status"] = "ERROR"
-            result["error"] = str(e)
-
-    return result
-
-
-async def run_all(rows: list[dict]) -> list[dict]:
-    semaphore = asyncio.Semaphore(CONCURRENT_LIMIT)
-    async with httpx.AsyncClient(follow_redirects=True) as client:
-        tasks = [
-            validate_one(client, semaphore, row, i + 1, len(rows))
-            for i, row in enumerate(rows)
-        ]
-        return list(await asyncio.gather(*tasks))
-
-
-def generate_charts(results: list[dict], ts: str) -> Path:
-    import matplotlib
-    matplotlib.use("Agg")
-    import matplotlib.pyplot as plt
-    import matplotlib.patches as mpatches
-    import numpy as np
-    from collections import Counter, defaultdict
-
-    STATUS_COLORS = {
-        "OK": "#1D9E75",
-        "EMPTY": "#EF9F27",
-        "SKIP": "#888780",
-        "NOT_FOUND": "#E24B4A",
-        "FORBIDDEN": "#D85A30",
-        "TIMEOUT": "#D4537E",
-        "HTTP_ERROR": "#E24B4A",
-        "CONN_ERROR": "#E24B4A",
-        "ERROR": "#E24B4A",
-    }
-    DIFF_COLORS = {"easy": "#1D9E75", "medium": "#EF9F27", "hard": "#E24B4A"}
-
-    OUT_DIR.mkdir(parents=True, exist_ok=True)
-
-    used_statuses = [s for s in STATUS_COLORS if any(r["status"] == s for r in results)]
-
-    fig = plt.figure(figsize=(18, 18), facecolor="#FAFAF8")
-    fig.suptitle(
-        f"Validation Report - Legal AI Agent\n{ts.replace('_', ' ')}",
-        fontsize=15, fontweight="bold", y=0.98, color="#2C2C2A"
-    )
-    # 2 rows: [pie + category] | [histogram + difficulty]
-    gs = fig.add_gridspec(2, 2, hspace=0.45, wspace=0.35,
-                          top=0.94, bottom=0.04, left=0.07, right=0.97)
-
-    def styled(ax, title):
-        ax.set_title(title, fontsize=12, fontweight="bold", color="#2C2C2A")
-        ax.set_facecolor("#F9F9F7")
-        ax.spines[["top", "right"]].set_visible(False)
-        ax.grid(axis="y", alpha=0.3, linestyle="--")
-
-    # 1. Pie — overall results
-    ax1 = fig.add_subplot(gs[0, 0])
-    cnt = Counter(r["status"] for r in results)
-    wedges, texts, pcts = ax1.pie(
-        cnt.values(),
-        labels=cnt.keys(),
-        colors=[STATUS_COLORS.get(k, "#888780") for k in cnt],
-        autopct="%1.1f%%", startangle=90, pctdistance=0.82,
-        wedgeprops={"edgecolor": "white", "linewidth": 1.5},
-    )
-    for t in texts: t.set_fontsize(10)
-    for t in pcts: t.set_fontsize(9); t.set_color("white"); t.set_fontweight("bold")
-    ax1.set_title("Overall Results", fontsize=12, fontweight="bold", color="#2C2C2A")
-
-    # 2. Bar — results by category
-    ax2 = fig.add_subplot(gs[0, 1])
-    categories = sorted(set(r["category"] for r in results))
-    cat_data = defaultdict(Counter)
-    for r in results:
-        cat_data[r["category"]][r["status"]] += 1
-
-    x, w = np.arange(len(categories)), 0.8 / max(len(used_statuses), 1)
-    for i, status in enumerate(used_statuses):
-        vals = [cat_data[c][status] for c in categories]
-        offset = (i - len(used_statuses) / 2 + 0.5) * w
-        ax2.bar(x + offset, vals, w, label=status,
-                color=STATUS_COLORS.get(status, "#888780"),
-                edgecolor="white", linewidth=0.5)
-
-    ax2.set_xticks(x)
-    ax2.set_xticklabels([c[:13] for c in categories], rotation=35, ha="right", fontsize=8)
-    ax2.set_ylabel("Count", fontsize=10)
-    ax2.legend(fontsize=7, loc="upper right")
-    styled(ax2, "Results by Category")
-
-    # 3. Histogram — response times
-    ax4 = fig.add_subplot(gs[1, 0])
-    times = [r["duration_ms"] for r in results
-             if r["duration_ms"] and r["status"] != "SKIP"]
-    if times:
-        ax4.hist(times, bins=20, color="#378ADD", edgecolor="white", linewidth=0.7, alpha=0.85)
-        avg = sum(times) / len(times)
-        ax4.axvline(avg, color="#E24B4A", linewidth=1.5, linestyle="--",
-                    label=f"Avg: {avg:.0f} ms")
-    ax4.set_xlabel("ms", fontsize=10)
-    ax4.set_ylabel("Requests", fontsize=10)
-    ax4.legend(fontsize=9)
-    styled(ax4, "API Response Time Distribution")
-
-    # 4. Bar — results by difficulty
-    ax5 = fig.add_subplot(gs[1, 1])
-    difficulties = ["easy", "medium", "hard"]
-    diff_data = defaultdict(Counter)
-    for r in results:
-        diff_data[r["difficulty"]][r["status"]] += 1
-
-    x2, w2 = np.arange(3), 0.8 / max(len(used_statuses), 1)
-    for i, status in enumerate(used_statuses):
-        vals = [diff_data[d][status] for d in difficulties]
-        offset = (i - len(used_statuses) / 2 + 0.5) * w2
-        ax5.bar(x2 + offset, vals, w2, label=status,
-                color=STATUS_COLORS.get(status, "#888780"),
-                edgecolor="white", linewidth=0.5)
-
-    ax5.set_xticks(x2)
-    ax5.set_xticklabels(difficulties, fontsize=10)
-    ax5.set_ylabel("Count", fontsize=10)
-    ax5.legend(fontsize=8)
-    styled(ax5, "Results by Difficulty")
-
-
-    out = OUT_DIR / f"report_{ts}.png"
-    fig.savefig(out, dpi=150, bbox_inches="tight", facecolor=fig.get_facecolor())
-    plt.close(fig)
-    return out
-
-
-if __name__ == "__main__":
-    limit = int(sys.argv[1]) if len(sys.argv) > 1 else None
-
-    rows = load_queries(limit)
-    ts = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
-
-    print(f"Start Test Validation")
-
-    results = asyncio.run(run_all(rows))
-
-    print(f"End Test Validation")
-
-    out = generate_charts(results, ts)
-    print(f"Test image is saved: {out}")
\ No newline at end of file
diff --git a/tests/test_queries.db b/tests/test_queries.db
deleted file mode 100644
index d2b25a1..0000000
Binary files a/tests/test_queries.db and /dev/null differ