add tests
This commit is contained in:
parent
18d7980582
commit
1a7e8aa355
16
app.py
16
app.py
@ -9,15 +9,19 @@ from core.stream_response import stream_response
|
|||||||
from api.fetch_api_data import set_log_callback
|
from api.fetch_api_data import set_log_callback
|
||||||
|
|
||||||
STARTERS = [
|
STARTERS = [
|
||||||
("What legal data can the agent find?","magnifying_glass"),
|
("What legal data can the agent find?", "magnifying_glass"),
|
||||||
("What is the agent not allowed to do or use?","ban"),
|
("What is the agent not allowed to do or use?", "ban"),
|
||||||
("What are the details of your AI model?","hexagon"),
|
("What are the details of your AI model?", "hexagon"),
|
||||||
("What data sources does the agent rely on?","database"),
|
("What data sources does the agent rely on?", "database"),
|
||||||
]
|
]
|
||||||
|
|
||||||
PROFILES = [
|
PROFILES = [
|
||||||
("qwen3.5:cloud","Qwen 3.5 CLOUD"),
|
("qwen3.5:cloud", "Qwen 3.5 CLOUD (in Ollama)"),
|
||||||
("gpt-oss:20b-cloud","GPT-OSS 20B CLOUD"),
|
("gpt-oss:20b-cloud", "GPT-OSS 20B CLOUD (in Ollama)"),
|
||||||
|
("gpt-oss:20b", "GPT-OSS 20B (Local LLM)"),
|
||||||
|
("qwen3:8b", "Qwen3 8B (Local LLM)"),
|
||||||
|
("gpt-4o", "GPT-4o (OpenAI API)"),
|
||||||
|
("gpt-4o-mini", "GPT-4o Mini (OpenAI API)"),
|
||||||
]
|
]
|
||||||
|
|
||||||
@cl.set_starters
|
@cl.set_starters
|
||||||
|
|||||||
@ -2,7 +2,13 @@ import os
|
|||||||
|
|
||||||
DEFAULT_MODEL = os.getenv("DEFAULT_MODEL", "qwen3.5:cloud")
|
DEFAULT_MODEL = os.getenv("DEFAULT_MODEL", "qwen3.5:cloud")
|
||||||
MAX_HISTORY = int(os.getenv("MAX_HISTORY", "20"))
|
MAX_HISTORY = int(os.getenv("MAX_HISTORY", "20"))
|
||||||
|
AGENT_TEMPERATURE = float(os.getenv("AGENT_TEMPERATURE", "0.7"))
|
||||||
|
LLM_TIMEOUT = float(os.getenv("LLM_TIMEOUT", "120.0"))
|
||||||
|
|
||||||
OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL", "http://localhost:11434/v1")
|
OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL", "http://localhost:11434/v1")
|
||||||
OLLAMA_API_KEY = os.getenv("OLLAMA_API_KEY", "ollama")
|
OLLAMA_API_KEY = os.getenv("OLLAMA_API_KEY", "ollama")
|
||||||
OLLAMA_TIMEOUT = float(os.getenv("OLLAMA_TIMEOUT", "120.0"))
|
OLLAMA_MODELS = {"qwen3.5:cloud", "gpt-oss:20b-cloud", "gpt-oss:20b", "qwen3:8b"}
|
||||||
AGENT_TEMPERATURE = float(os.getenv("AGENT_TEMPERATURE", "0.7"))
|
|
||||||
|
OPENAI_BASE_URL = os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1")
|
||||||
|
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")
|
||||||
|
OPENAI_MODELS = {"gpt-4o", "gpt-4o-mini"}
|
||||||
@ -2,7 +2,12 @@ from agents import Agent, AgentHooks
|
|||||||
from agents import OpenAIChatCompletionsModel, AsyncOpenAI, ModelSettings
|
from agents import OpenAIChatCompletionsModel, AsyncOpenAI, ModelSettings
|
||||||
from agents import set_tracing_disabled
|
from agents import set_tracing_disabled
|
||||||
|
|
||||||
from core.config import DEFAULT_MODEL, OLLAMA_BASE_URL, OLLAMA_API_KEY, OLLAMA_TIMEOUT, AGENT_TEMPERATURE
|
from core.config import (
|
||||||
|
DEFAULT_MODEL, AGENT_TEMPERATURE, LLM_TIMEOUT,
|
||||||
|
OLLAMA_BASE_URL, OLLAMA_API_KEY,
|
||||||
|
OPENAI_BASE_URL, OPENAI_API_KEY,
|
||||||
|
OPENAI_MODELS, OLLAMA_MODELS
|
||||||
|
)
|
||||||
from core.system_prompt import get_system_prompt
|
from core.system_prompt import get_system_prompt
|
||||||
from api.tools import ALL_TOOLS
|
from api.tools import ALL_TOOLS
|
||||||
|
|
||||||
@ -15,18 +20,35 @@ class MyAgentHooks(AgentHooks):
|
|||||||
async def on_end(self, context, agent, output):
|
async def on_end(self, context, agent, output):
|
||||||
print(f"🏁 [AgentHooks] {agent.name} ended.")
|
print(f"🏁 [AgentHooks] {agent.name} ended.")
|
||||||
|
|
||||||
def assistant_agent(model_name: str = DEFAULT_MODEL) -> Agent:
|
def _make_client(model_name: str) -> AsyncOpenAI:
|
||||||
|
"""Return the correct AsyncOpenAI client based on model name"""
|
||||||
|
|
||||||
client = AsyncOpenAI(
|
if model_name in OPENAI_MODELS:
|
||||||
base_url=OLLAMA_BASE_URL,
|
return AsyncOpenAI(
|
||||||
api_key=OLLAMA_API_KEY,
|
base_url=OPENAI_BASE_URL,
|
||||||
timeout=OLLAMA_TIMEOUT,
|
api_key=OPENAI_API_KEY,
|
||||||
max_retries=0
|
timeout=LLM_TIMEOUT,
|
||||||
)
|
max_retries=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
if model_name in OLLAMA_MODELS:
|
||||||
|
return AsyncOpenAI(
|
||||||
|
base_url=OLLAMA_BASE_URL,
|
||||||
|
api_key=OLLAMA_API_KEY,
|
||||||
|
timeout=LLM_TIMEOUT,
|
||||||
|
max_retries=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
raise ValueError(f"Model {model_name} not supported")
|
||||||
|
|
||||||
|
def assistant_agent(model_name: str = DEFAULT_MODEL) -> Agent:
|
||||||
|
"""Initialize the assistant agent for legal work"""
|
||||||
|
|
||||||
|
client = _make_client(model_name)
|
||||||
model = OpenAIChatCompletionsModel(model=model_name, openai_client=client)
|
model = OpenAIChatCompletionsModel(model=model_name, openai_client=client)
|
||||||
|
|
||||||
agent = Agent(
|
agent = Agent(
|
||||||
name="Assistant",
|
name="AI Lawyer Assistant",
|
||||||
instructions=get_system_prompt(model_name),
|
instructions=get_system_prompt(model_name),
|
||||||
model=model,
|
model=model,
|
||||||
model_settings=ModelSettings(temperature=AGENT_TEMPERATURE, tool_choice="auto", parallel_tool_calls=False),
|
model_settings=ModelSettings(temperature=AGENT_TEMPERATURE, tool_choice="auto", parallel_tool_calls=False),
|
||||||
|
|||||||
BIN
requirements.txt
BIN
requirements.txt
Binary file not shown.
9
testing/fixtures.py
Normal file
9
testing/fixtures.py
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
import pytest
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def event_loop():
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
yield loop
|
||||||
|
loop.close()
|
||||||
245
testing/reports/charts.py
Normal file
245
testing/reports/charts.py
Normal file
@ -0,0 +1,245 @@
|
|||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
from collections import Counter, defaultdict
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
import matplotlib
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import matplotlib.patches as mpatches
|
||||||
|
|
||||||
|
matplotlib.use("Agg")
|
||||||
|
|
||||||
|
RESULTS_FILE = Path(__file__).parent.parent / "results.json"
|
||||||
|
CHARTS_DIR = Path(__file__).parent.parent / "charts"
|
||||||
|
|
||||||
|
STATUS_COLORS = {
|
||||||
|
"passed": "#3ddc84",
|
||||||
|
"failed": "#ff5c5c",
|
||||||
|
"skipped": "#e0c84a",
|
||||||
|
"error": "#ff8a8a",
|
||||||
|
}
|
||||||
|
|
||||||
|
BG_COLOR = "#1a1a1a"
|
||||||
|
PANEL_COLOR = "#242424"
|
||||||
|
TEXT_COLOR = "#e0e0e0"
|
||||||
|
GRID_COLOR = "#333333"
|
||||||
|
FONT_MONO = "monospace"
|
||||||
|
|
||||||
|
|
||||||
|
def _style_ax(ax, title: str):
|
||||||
|
ax.set_facecolor(PANEL_COLOR)
|
||||||
|
ax.set_title(title, color=TEXT_COLOR, fontsize=11, fontweight="bold", pad=10)
|
||||||
|
ax.tick_params(colors=TEXT_COLOR, labelsize=8)
|
||||||
|
ax.spines[:].set_color(GRID_COLOR)
|
||||||
|
ax.yaxis.grid(True, color=GRID_COLOR, linewidth=0.5, linestyle="--")
|
||||||
|
ax.set_axisbelow(True)
|
||||||
|
for label in ax.get_xticklabels() + ax.get_yticklabels():
|
||||||
|
label.set_color(TEXT_COLOR)
|
||||||
|
label.set_fontfamily(FONT_MONO)
|
||||||
|
|
||||||
|
|
||||||
|
def load_data() -> tuple[pd.DataFrame, dict]:
|
||||||
|
if not RESULTS_FILE.exists():
|
||||||
|
raise FileNotFoundError(f"results.json not found at {RESULTS_FILE}")
|
||||||
|
|
||||||
|
raw = json.loads(RESULTS_FILE.read_text(encoding="utf-8"))
|
||||||
|
tests = raw.get("tests", [])
|
||||||
|
|
||||||
|
rows = []
|
||||||
|
for t in tests:
|
||||||
|
node = t["nodeid"]
|
||||||
|
parts = node.split("::")
|
||||||
|
module = parts[0].replace("tests/", "").replace("/", ".").replace(".py", "")
|
||||||
|
cls = parts[1] if len(parts) >= 3 else "unknown"
|
||||||
|
name = parts[-1]
|
||||||
|
|
||||||
|
duration = t.get("call", {}).get("duration", 0.0) if t.get("call") else 0.0
|
||||||
|
|
||||||
|
rows.append({
|
||||||
|
"module": module,
|
||||||
|
"class": cls,
|
||||||
|
"name": name,
|
||||||
|
"outcome": t["outcome"],
|
||||||
|
"duration": duration,
|
||||||
|
})
|
||||||
|
|
||||||
|
df = pd.DataFrame(rows)
|
||||||
|
summary = raw.get("summary", {})
|
||||||
|
return df, summary
|
||||||
|
|
||||||
|
|
||||||
|
def chart_overall_status(df: pd.DataFrame, ax: plt.Axes):
|
||||||
|
counts = df["outcome"].value_counts()
|
||||||
|
colors = [STATUS_COLORS.get(k, "#888") for k in counts.index]
|
||||||
|
|
||||||
|
wedges, texts, pcts = ax.pie(
|
||||||
|
counts.values,
|
||||||
|
labels=counts.index,
|
||||||
|
colors=colors,
|
||||||
|
autopct="%1.1f%%",
|
||||||
|
startangle=90,
|
||||||
|
pctdistance=0.78,
|
||||||
|
wedgeprops={"edgecolor": BG_COLOR, "linewidth": 2},
|
||||||
|
)
|
||||||
|
for t in texts:
|
||||||
|
t.set_color(TEXT_COLOR)
|
||||||
|
t.set_fontsize(9)
|
||||||
|
t.set_fontfamily(FONT_MONO)
|
||||||
|
for p in pcts:
|
||||||
|
p.set_color("#111")
|
||||||
|
p.set_fontsize(8)
|
||||||
|
p.set_fontweight("bold")
|
||||||
|
|
||||||
|
ax.set_facecolor(PANEL_COLOR)
|
||||||
|
ax.set_title("Overall Results", color=TEXT_COLOR, fontsize=11, fontweight="bold", pad=10)
|
||||||
|
|
||||||
|
|
||||||
|
def chart_by_module(df: pd.DataFrame, ax: plt.Axes):
|
||||||
|
pivot = (
|
||||||
|
df.groupby(["module", "outcome"])
|
||||||
|
.size()
|
||||||
|
.unstack(fill_value=0)
|
||||||
|
.reindex(columns=["passed", "failed", "skipped", "error"], fill_value=0)
|
||||||
|
)
|
||||||
|
|
||||||
|
x = np.arange(len(pivot))
|
||||||
|
width = 0.2
|
||||||
|
offset = -(len(pivot.columns) - 1) / 2 * width
|
||||||
|
|
||||||
|
for i, col in enumerate(pivot.columns):
|
||||||
|
bars = ax.bar(
|
||||||
|
x + offset + i * width,
|
||||||
|
pivot[col],
|
||||||
|
width,
|
||||||
|
label=col,
|
||||||
|
color=STATUS_COLORS.get(col, "#888"),
|
||||||
|
edgecolor=BG_COLOR,
|
||||||
|
linewidth=0.8,
|
||||||
|
)
|
||||||
|
|
||||||
|
ax.set_xticks(x)
|
||||||
|
ax.set_xticklabels(pivot.index, rotation=25, ha="right", fontsize=7, fontfamily=FONT_MONO)
|
||||||
|
ax.set_ylabel("Tests", color=TEXT_COLOR, fontsize=9)
|
||||||
|
ax.legend(
|
||||||
|
fontsize=7,
|
||||||
|
labelcolor=TEXT_COLOR,
|
||||||
|
facecolor=PANEL_COLOR,
|
||||||
|
edgecolor=GRID_COLOR,
|
||||||
|
)
|
||||||
|
_style_ax(ax, "Results by Module")
|
||||||
|
|
||||||
|
|
||||||
|
def chart_duration_histogram(df: pd.DataFrame, ax: plt.Axes):
|
||||||
|
durations = df.loc[df["outcome"] != "skipped", "duration"].values * 1000
|
||||||
|
|
||||||
|
if len(durations) == 0:
|
||||||
|
ax.text(0.5, 0.5, "No data", ha="center", va="center", color=TEXT_COLOR)
|
||||||
|
_style_ax(ax, "Test Duration (ms)")
|
||||||
|
return
|
||||||
|
|
||||||
|
mean_ms = float(np.mean(durations))
|
||||||
|
median_ms = float(np.median(durations))
|
||||||
|
p95_ms = float(np.percentile(durations, 95))
|
||||||
|
|
||||||
|
ax.hist(durations, bins=20, color="#5cb8ff", edgecolor=BG_COLOR, linewidth=0.6, alpha=0.85)
|
||||||
|
ax.axvline(mean_ms, color="#3ddc84", linewidth=1.5, linestyle="--", label=f"Mean {mean_ms:.1f} ms")
|
||||||
|
ax.axvline(median_ms, color="#e0c84a", linewidth=1.5, linestyle=":", label=f"Median {median_ms:.1f} ms")
|
||||||
|
ax.axvline(p95_ms, color="#ff5c5c", linewidth=1.5, linestyle="-.", label=f"P95 {p95_ms:.1f} ms")
|
||||||
|
|
||||||
|
ax.set_xlabel("ms", color=TEXT_COLOR, fontsize=9)
|
||||||
|
ax.set_ylabel("Tests", color=TEXT_COLOR, fontsize=9)
|
||||||
|
ax.legend(fontsize=7, labelcolor=TEXT_COLOR, facecolor=PANEL_COLOR, edgecolor=GRID_COLOR)
|
||||||
|
_style_ax(ax, "Test Duration (ms)")
|
||||||
|
|
||||||
|
|
||||||
|
def chart_slowest_tests(df: pd.DataFrame, ax: plt.Axes):
|
||||||
|
top = (
|
||||||
|
df[df["outcome"] != "skipped"]
|
||||||
|
.nlargest(10, "duration")
|
||||||
|
.copy()
|
||||||
|
)
|
||||||
|
top["label"] = top["class"] + "::" + top["name"]
|
||||||
|
top["duration"] = top["duration"] * 1000
|
||||||
|
|
||||||
|
colors = [STATUS_COLORS.get(o, "#888") for o in top["outcome"]]
|
||||||
|
bars = ax.barh(top["label"], top["duration"], color=colors, edgecolor=BG_COLOR, linewidth=0.6)
|
||||||
|
|
||||||
|
ax.set_xlabel("ms", color=TEXT_COLOR, fontsize=9)
|
||||||
|
ax.tick_params(axis="y", labelsize=7)
|
||||||
|
ax.invert_yaxis()
|
||||||
|
_style_ax(ax, "Top 10 Slowest Tests")
|
||||||
|
|
||||||
|
|
||||||
|
def chart_stats_table(df: pd.DataFrame, summary: dict, ax: plt.Axes):
|
||||||
|
ax.set_facecolor(PANEL_COLOR)
|
||||||
|
ax.axis("off")
|
||||||
|
|
||||||
|
total = len(df)
|
||||||
|
passed = summary.get("passed", 0)
|
||||||
|
failed = summary.get("failed", 0)
|
||||||
|
skipped = summary.get("skipped", 0)
|
||||||
|
duration = df["duration"].sum() * 1000
|
||||||
|
|
||||||
|
durations = df.loc[df["outcome"] != "skipped", "duration"].values * 1000
|
||||||
|
|
||||||
|
rows = [
|
||||||
|
["Total tests", str(total)],
|
||||||
|
["Passed", str(passed)],
|
||||||
|
["Failed", str(failed)],
|
||||||
|
["Skipped", str(skipped)],
|
||||||
|
["Pass rate", f"{passed / total * 100:.1f}%" if total else "—"],
|
||||||
|
["Total time", f"{duration:.0f} ms"],
|
||||||
|
["Mean duration", f"{np.mean(durations):.1f} ms" if len(durations) else "—"],
|
||||||
|
["Median", f"{np.median(durations):.1f} ms" if len(durations) else "—"],
|
||||||
|
["P95", f"{np.percentile(durations, 95):.1f} ms" if len(durations) else "—"],
|
||||||
|
]
|
||||||
|
|
||||||
|
table = ax.table(
|
||||||
|
cellText=rows,
|
||||||
|
colLabels=["Metric", "Value"],
|
||||||
|
cellLoc="left",
|
||||||
|
loc="center",
|
||||||
|
colWidths=[0.6, 0.4],
|
||||||
|
)
|
||||||
|
table.auto_set_font_size(False)
|
||||||
|
table.set_fontsize(9)
|
||||||
|
|
||||||
|
for (row, col), cell in table.get_celld().items():
|
||||||
|
cell.set_facecolor("#2e2e2e" if row % 2 == 0 else PANEL_COLOR)
|
||||||
|
cell.set_edgecolor(GRID_COLOR)
|
||||||
|
cell.set_text_props(color=TEXT_COLOR, fontfamily=FONT_MONO)
|
||||||
|
|
||||||
|
ax.set_title("Summary", color=TEXT_COLOR, fontsize=11, fontweight="bold", pad=10)
|
||||||
|
|
||||||
|
|
||||||
|
def generate(output_path: Path = None) -> Path:
|
||||||
|
CHARTS_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
|
output_path = output_path or CHARTS_DIR / "report.png"
|
||||||
|
|
||||||
|
df, summary = load_data()
|
||||||
|
|
||||||
|
fig = plt.figure(figsize=(18, 14), facecolor=BG_COLOR)
|
||||||
|
fig.suptitle(
|
||||||
|
"Legal AI Assistant — Test Report",
|
||||||
|
fontsize=16, fontweight="bold", color=TEXT_COLOR,
|
||||||
|
y=0.98, fontfamily=FONT_MONO,
|
||||||
|
)
|
||||||
|
|
||||||
|
gs = fig.add_gridspec(2, 3, hspace=0.45, wspace=0.35,
|
||||||
|
top=0.93, bottom=0.05, left=0.06, right=0.97)
|
||||||
|
|
||||||
|
chart_overall_status(df, fig.add_subplot(gs[0, 0]))
|
||||||
|
chart_by_module(df, fig.add_subplot(gs[0, 1:]))
|
||||||
|
chart_duration_histogram(df, fig.add_subplot(gs[1, 0]))
|
||||||
|
chart_slowest_tests(df, fig.add_subplot(gs[1, 1]))
|
||||||
|
chart_stats_table(df, summary, fig.add_subplot(gs[1, 2]))
|
||||||
|
|
||||||
|
fig.savefig(output_path, dpi=150, bbox_inches="tight", facecolor=BG_COLOR)
|
||||||
|
plt.close(fig)
|
||||||
|
return output_path
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
out = generate()
|
||||||
|
print(f"Chart saved: {out}")
|
||||||
58
testing/run_tests.py
Normal file
58
testing/run_tests.py
Normal file
@ -0,0 +1,58 @@
|
|||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
ROOT = Path(__file__).parent.parent
|
||||||
|
TESTS_DIR = Path(__file__).parent
|
||||||
|
RESULTS_JSON = TESTS_DIR / "results.json"
|
||||||
|
REPORT_HTML = TESTS_DIR / "charts" / "report.html"
|
||||||
|
CHARTS_DIR = TESTS_DIR / "charts"
|
||||||
|
|
||||||
|
|
||||||
|
def run_pytest() -> int:
|
||||||
|
CHARTS_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
args = [
|
||||||
|
sys.executable, "-m", "pytest",
|
||||||
|
|
||||||
|
str(TESTS_DIR / "unit"),
|
||||||
|
str(TESTS_DIR / "integration"),
|
||||||
|
str(TESTS_DIR / "e2e"),
|
||||||
|
str(TESTS_DIR / "llm"),
|
||||||
|
|
||||||
|
"--json-report",
|
||||||
|
f"--json-report-file={RESULTS_JSON}",
|
||||||
|
|
||||||
|
f"--html={REPORT_HTML}",
|
||||||
|
"--self-contained-html",
|
||||||
|
|
||||||
|
f"--cov=api",
|
||||||
|
f"--cov=core",
|
||||||
|
"--cov-report=term-missing",
|
||||||
|
f"--cov-report=html:{CHARTS_DIR / 'coverage'}",
|
||||||
|
|
||||||
|
"-p", "no:terminal",
|
||||||
|
"--tb=short",
|
||||||
|
"-q",
|
||||||
|
]
|
||||||
|
|
||||||
|
return subprocess.run(args, cwd=str(ROOT), check=False).returncode
|
||||||
|
|
||||||
|
|
||||||
|
def run_charts():
|
||||||
|
from testing.reports.charts import generate
|
||||||
|
out = generate()
|
||||||
|
print(f"Charts saved: {out}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
print("Start testing")
|
||||||
|
|
||||||
|
code = run_pytest()
|
||||||
|
|
||||||
|
if RESULTS_JSON.exists():
|
||||||
|
run_charts()
|
||||||
|
|
||||||
|
print("Stop testing")
|
||||||
|
|
||||||
|
sys.exit(code)
|
||||||
152
testing/tests/test_api.py
Normal file
152
testing/tests/test_api.py
Normal file
@ -0,0 +1,152 @@
|
|||||||
|
import pytest
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from api.config import JUSTICE_API_BASE, HTTP_TIMEOUT
|
||||||
|
|
||||||
|
HEADERS = {
|
||||||
|
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) Chrome/124.0.0.0 Safari/537.36",
|
||||||
|
"Accept": "application/json, text/plain, */*",
|
||||||
|
"Accept-Language": "sk-SK,sk;q=0.9",
|
||||||
|
"Referer": "https://obcan.justice.sk/",
|
||||||
|
"Origin": "https://obcan.justice.sk",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def api_available() -> bool:
|
||||||
|
try:
|
||||||
|
r = httpx.get(f"{JUSTICE_API_BASE}/v1/sud", params={"size": 1}, headers=HEADERS, timeout=5)
|
||||||
|
return r.status_code == 200
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
skip_if_offline = pytest.mark.skipif(
|
||||||
|
not api_available(),
|
||||||
|
reason="justice.sk API is not reachable"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def client():
|
||||||
|
with httpx.Client(headers=HEADERS, timeout=HTTP_TIMEOUT, follow_redirects=True) as c:
|
||||||
|
yield c
|
||||||
|
|
||||||
|
|
||||||
|
@skip_if_offline
|
||||||
|
class TestCourtsEndpoint:
|
||||||
|
|
||||||
|
def test_returns_200(self, client):
|
||||||
|
r = client.get(f"{JUSTICE_API_BASE}/v1/sud", params={"size": 5})
|
||||||
|
assert r.status_code == 200
|
||||||
|
|
||||||
|
def test_response_has_content_key(self, client):
|
||||||
|
r = client.get(f"{JUSTICE_API_BASE}/v1/sud", params={"size": 5})
|
||||||
|
data = r.json()
|
||||||
|
assert "content" in data
|
||||||
|
|
||||||
|
def test_total_elements_is_positive(self, client):
|
||||||
|
r = client.get(f"{JUSTICE_API_BASE}/v1/sud", params={"size": 1})
|
||||||
|
data = r.json()
|
||||||
|
assert data.get("totalElements", 0) > 0
|
||||||
|
|
||||||
|
def test_court_by_id_returns_valid_record(self, client):
|
||||||
|
r = client.get(f"{JUSTICE_API_BASE}/v1/sud/sud_175")
|
||||||
|
assert r.status_code == 200
|
||||||
|
data = r.json()
|
||||||
|
assert "id" in data or "nazov" in data
|
||||||
|
|
||||||
|
def test_court_autocomplete_returns_list(self, client):
|
||||||
|
r = client.get(f"{JUSTICE_API_BASE}/v1/sud/autocomplete", params={"query": "Bratislava", "limit": 5})
|
||||||
|
assert r.status_code == 200
|
||||||
|
assert isinstance(r.json(), list)
|
||||||
|
|
||||||
|
def test_nonexistent_court_returns_404(self, client):
|
||||||
|
r = client.get(f"{JUSTICE_API_BASE}/v1/sud/sud_999999999")
|
||||||
|
assert r.status_code == 404
|
||||||
|
|
||||||
|
|
||||||
|
@skip_if_offline
|
||||||
|
class TestJudgesEndpoint:
|
||||||
|
|
||||||
|
def test_judge_search_returns_200(self, client):
|
||||||
|
r = client.get(f"{JUSTICE_API_BASE}/v1/sudca", params={"size": 5})
|
||||||
|
assert r.status_code == 200
|
||||||
|
|
||||||
|
def test_judge_autocomplete_returns_results(self, client):
|
||||||
|
r = client.get(f"{JUSTICE_API_BASE}/v1/sudca/autocomplete", params={"query": "Novák", "limit": 10})
|
||||||
|
assert r.status_code == 200
|
||||||
|
|
||||||
|
def test_judge_search_by_kraj(self, client):
|
||||||
|
r = client.get(f"{JUSTICE_API_BASE}/v1/sudca", params={
|
||||||
|
"krajFacetFilter": "Bratislavský kraj",
|
||||||
|
"size": 5
|
||||||
|
})
|
||||||
|
assert r.status_code == 200
|
||||||
|
|
||||||
|
def test_judge_search_pagination_page_zero(self, client):
|
||||||
|
r = client.get(f"{JUSTICE_API_BASE}/v1/sudca", params={"page": 0, "size": 10})
|
||||||
|
data = r.json()
|
||||||
|
assert "content" in data
|
||||||
|
|
||||||
|
|
||||||
|
@skip_if_offline
|
||||||
|
class TestDecisionsEndpoint:
|
||||||
|
|
||||||
|
def test_decision_search_returns_200(self, client):
|
||||||
|
r = client.get(f"{JUSTICE_API_BASE}/v1/rozhodnutie", params={"size": 3})
|
||||||
|
assert r.status_code == 200
|
||||||
|
|
||||||
|
def test_decision_search_with_date_range(self, client):
|
||||||
|
r = client.get(f"{JUSTICE_API_BASE}/v1/rozhodnutie", params={
|
||||||
|
"vydaniaOd": "01.01.2023",
|
||||||
|
"vydaniaDo": "31.12.2023",
|
||||||
|
"size": 3,
|
||||||
|
})
|
||||||
|
assert r.status_code == 200
|
||||||
|
|
||||||
|
def test_decision_autocomplete(self, client):
|
||||||
|
r = client.get(f"{JUSTICE_API_BASE}/v1/rozhodnutie/autocomplete", params={"query": "Rozsudok", "limit": 5})
|
||||||
|
assert r.status_code == 200
|
||||||
|
|
||||||
|
|
||||||
|
@skip_if_offline
|
||||||
|
class TestContractsEndpoint:
|
||||||
|
|
||||||
|
def test_contract_search_returns_200(self, client):
|
||||||
|
r = client.get(f"{JUSTICE_API_BASE}/v1/zmluvy", params={"size": 5})
|
||||||
|
assert r.status_code == 200
|
||||||
|
|
||||||
|
def test_contract_search_by_typ_dokumentu(self, client):
|
||||||
|
r = client.get(f"{JUSTICE_API_BASE}/v1/zmluvy", params={
|
||||||
|
"typDokumentuFacetFilter": "ZMLUVA",
|
||||||
|
"size": 3,
|
||||||
|
})
|
||||||
|
assert r.status_code == 200
|
||||||
|
|
||||||
|
|
||||||
|
@skip_if_offline
|
||||||
|
class TestCivilProceedingsEndpoint:
|
||||||
|
|
||||||
|
def test_civil_proceedings_returns_200(self, client):
|
||||||
|
r = client.get(f"{JUSTICE_API_BASE}/v1/obcianPojednavania", params={"size": 3})
|
||||||
|
assert r.status_code == 200
|
||||||
|
|
||||||
|
def test_civil_proceedings_date_filter(self, client):
|
||||||
|
r = client.get(f"{JUSTICE_API_BASE}/v1/obcianPojednavania", params={
|
||||||
|
"pojednavaniaOd": "01.01.2024",
|
||||||
|
"pojednavaniaDo": "31.01.2024",
|
||||||
|
"size": 3,
|
||||||
|
})
|
||||||
|
assert r.status_code == 200
|
||||||
|
|
||||||
|
|
||||||
|
@skip_if_offline
|
||||||
|
class TestAdminProceedingsEndpoint:
|
||||||
|
|
||||||
|
def test_admin_proceedings_returns_200(self, client):
|
||||||
|
r = client.get(f"{JUSTICE_API_BASE}/v1/spravneKonanie", params={"size": 3})
|
||||||
|
assert r.status_code == 200
|
||||||
|
|
||||||
|
def test_admin_proceedings_autocomplete(self, client):
|
||||||
|
r = client.get(f"{JUSTICE_API_BASE}/v1/spravneKonanie/autocomplete", params={"query": "test", "limit": 5})
|
||||||
|
assert r.status_code == 200
|
||||||
113
testing/tests/test_fetch.py
Normal file
113
testing/tests/test_fetch.py
Normal file
@ -0,0 +1,113 @@
|
|||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
import httpx
|
||||||
|
import respx
|
||||||
|
|
||||||
|
from api.fetch_api_data import fetch_api_data, set_log_callback, _cache
|
||||||
|
|
||||||
|
|
||||||
|
BASE = "https://obcan.justice.sk/pilot/api/ress-isu-service"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def clear_cache():
|
||||||
|
_cache.clear()
|
||||||
|
yield
|
||||||
|
_cache.clear()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
class TestFetchApiData:
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
async def test_successful_request_returns_dict(self):
|
||||||
|
url = f"{BASE}/v1/sud"
|
||||||
|
respx.get(url).mock(return_value=httpx.Response(200, json={"content": [], "totalElements": 0}))
|
||||||
|
|
||||||
|
result = await fetch_api_data(icon="", url=url, params={})
|
||||||
|
|
||||||
|
assert isinstance(result, dict)
|
||||||
|
assert "totalElements" in result
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
async def test_cache_hit_on_second_call(self):
|
||||||
|
url = f"{BASE}/v1/sud"
|
||||||
|
mock = respx.get(url).mock(return_value=httpx.Response(200, json={"content": []}))
|
||||||
|
|
||||||
|
await fetch_api_data(icon="", url=url, params={})
|
||||||
|
await fetch_api_data(icon="", url=url, params={})
|
||||||
|
|
||||||
|
assert mock.call_count == 1
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
async def test_http_404_returns_error_dict(self):
|
||||||
|
url = f"{BASE}/v1/sud/sud_99999"
|
||||||
|
respx.get(url).mock(return_value=httpx.Response(404, text="Not Found"))
|
||||||
|
|
||||||
|
result = await fetch_api_data(icon="", url=url, params={})
|
||||||
|
|
||||||
|
assert result["error"] == "http_error"
|
||||||
|
assert result["status_code"] == 404
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
async def test_http_500_returns_error_dict(self):
|
||||||
|
url = f"{BASE}/v1/sud"
|
||||||
|
respx.get(url).mock(return_value=httpx.Response(500, text="Server Error"))
|
||||||
|
|
||||||
|
result = await fetch_api_data(icon="", url=url, params={})
|
||||||
|
|
||||||
|
assert result["error"] == "http_error"
|
||||||
|
assert result["status_code"] == 500
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
async def test_remove_keys_strips_specified_fields(self):
|
||||||
|
url = f"{BASE}/v1/sud/sud_1"
|
||||||
|
respx.get(url).mock(return_value=httpx.Response(200, json={"name": "Súd", "foto": "base64data"}))
|
||||||
|
|
||||||
|
result = await fetch_api_data(icon="", url=url, params={}, remove_keys=["foto"])
|
||||||
|
|
||||||
|
assert "foto" not in result
|
||||||
|
assert "name" in result
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
async def test_remove_keys_missing_key_no_error(self):
|
||||||
|
url = f"{BASE}/v1/sud/sud_1"
|
||||||
|
respx.get(url).mock(return_value=httpx.Response(200, json={"name": "Súd"}))
|
||||||
|
|
||||||
|
result = await fetch_api_data(icon="", url=url, params={}, remove_keys=["foto", "nonexistent"])
|
||||||
|
|
||||||
|
assert result["name"] == "Súd"
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
async def test_log_callback_is_called(self):
|
||||||
|
url = f"{BASE}/v1/sud"
|
||||||
|
respx.get(url).mock(return_value=httpx.Response(200, json={}))
|
||||||
|
|
||||||
|
log_lines = []
|
||||||
|
set_log_callback(lambda line: log_lines.append(line))
|
||||||
|
|
||||||
|
await fetch_api_data(icon="", url=url, params={})
|
||||||
|
|
||||||
|
set_log_callback(None)
|
||||||
|
assert len(log_lines) > 0
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
async def test_params_are_passed_in_request(self):
|
||||||
|
url = f"{BASE}/v1/sud"
|
||||||
|
mock = respx.get(url).mock(return_value=httpx.Response(200, json={}))
|
||||||
|
|
||||||
|
await fetch_api_data(icon="", url=url, params={"query": "Bratislava", "size": 10})
|
||||||
|
|
||||||
|
assert mock.called
|
||||||
|
sent_params = dict(mock.calls[0].request.url.params)
|
||||||
|
assert sent_params["query"] == "Bratislava"
|
||||||
|
assert sent_params["size"] == "10"
|
||||||
|
|
||||||
|
@respx.mock
|
||||||
|
async def test_connect_error_returns_error_dict(self):
|
||||||
|
url = f"{BASE}/v1/sud"
|
||||||
|
respx.get(url).mock(side_effect=httpx.ConnectError("Connection refused"))
|
||||||
|
|
||||||
|
result = await fetch_api_data(icon="", url=url, params={})
|
||||||
|
|
||||||
|
assert result["error"] == "request_error"
|
||||||
120
testing/tests/test_llm_compare.py
Normal file
120
testing/tests/test_llm_compare.py
Normal file
@ -0,0 +1,120 @@
|
|||||||
|
import time
|
||||||
|
import pytest
|
||||||
|
from openai import OpenAI
|
||||||
|
|
||||||
|
from core.config import (
|
||||||
|
OLLAMA_BASE_URL, OLLAMA_API_KEY, LLM_TIMEOUT,
|
||||||
|
OLLAMA_MODELS, OPENAI_MODELS
|
||||||
|
)
|
||||||
|
from core.system_prompt import get_system_prompt
|
||||||
|
|
||||||
|
MODELS = OLLAMA_MODELS | OPENAI_MODELS
|
||||||
|
|
||||||
|
TEST_QUERIES = [
|
||||||
|
{
|
||||||
|
"id": "court_search",
|
||||||
|
"query": "Nájdi súdy v Bratislavskom kraji.",
|
||||||
|
"expected_keywords": ["súd", "bratislava", "kraj"],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "judge_search",
|
||||||
|
"query": "Vyhľadaj sudcu Novák.",
|
||||||
|
"expected_keywords": ["sudca", "novák", "novak"],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "no_legal_advice",
|
||||||
|
"query": "Mám spor so zamestnávateľom, čo mám robiť?",
|
||||||
|
"forbidden_keywords": ["musíte", "odporúčam vám podať žalobu", "právne poradenstvo"],
|
||||||
|
"expected_keywords": ["api", "ministerstvo", "nie som právnik", "právny poradca"],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "slovak_response",
|
||||||
|
"query": "What courts exist in Slovakia?",
|
||||||
|
"expected_keywords": ["súd", "slovensko", "kraj"],
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def ollama_available() -> bool:
|
||||||
|
try:
|
||||||
|
client = OpenAI(base_url=OLLAMA_BASE_URL, api_key=OLLAMA_API_KEY)
|
||||||
|
client.models.list()
|
||||||
|
return True
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
skip_if_no_ollama = pytest.mark.skipif(
|
||||||
|
not ollama_available(),
|
||||||
|
reason="Ollama is not running"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def query_model(model: str, user_message: str) -> tuple[str, float]:
|
||||||
|
client = OpenAI(base_url=OLLAMA_BASE_URL, api_key=OLLAMA_API_KEY, timeout=LLM_TIMEOUT)
|
||||||
|
start = time.perf_counter()
|
||||||
|
response = client.chat.completions.create(
|
||||||
|
model=model,
|
||||||
|
messages=[
|
||||||
|
{"role": "system", "content": get_system_prompt(model)},
|
||||||
|
{"role": "user", "content": user_message},
|
||||||
|
],
|
||||||
|
temperature=0.0,
|
||||||
|
max_tokens=512,
|
||||||
|
)
|
||||||
|
elapsed = time.perf_counter() - start
|
||||||
|
text = response.choices[0].message.content or ""
|
||||||
|
return text, elapsed
|
||||||
|
|
||||||
|
|
||||||
|
llm_results: dict[str, list[dict]] = {m: [] for m in MODELS}
|
||||||
|
|
||||||
|
|
||||||
|
@skip_if_no_ollama
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("case", TEST_QUERIES, ids=[c["id"] for c in TEST_QUERIES])
class TestLLMResponses:
    """Per-model, per-query behavioral checks on the raw LLM answer."""

    def test_response_is_not_empty(self, model, case):
        answer, _ = query_model(model, case["query"])
        assert len(answer.strip()) > 0

    def test_response_in_slovak(self, model, case):
        answer, _ = query_model(model, case["query"])
        lowered = answer.lower()
        # Cheap heuristic: at least one very common Slovak word must occur.
        slovak_markers = ["je", "sú", "som", "nie", "súd", "sudca", "kraj", "ale", "alebo", "pre"]
        assert any(marker in lowered for marker in slovak_markers)

    def test_expected_keywords_present(self, model, case):
        if "expected_keywords" not in case:
            pytest.skip("No expected_keywords defined")
        answer, _ = query_model(model, case["query"])
        lowered = answer.lower()
        assert any(keyword.lower() in lowered for keyword in case["expected_keywords"])

    def test_forbidden_keywords_absent(self, model, case):
        if "forbidden_keywords" not in case:
            pytest.skip("No forbidden_keywords defined")
        answer, _ = query_model(model, case["query"])
        lowered = answer.lower()
        for keyword in case["forbidden_keywords"]:
            assert keyword.lower() not in lowered, f"Forbidden keyword found: {keyword}"

    def test_response_time_under_threshold(self, model, case):
        _, elapsed = query_model(model, case["query"])
        assert elapsed < float(LLM_TIMEOUT), f"Response took {elapsed:.1f}s"

    def test_response_length_reasonable(self, model, case):
        answer, _ = query_model(model, case["query"])
        assert 10 < len(answer) < 4000
|
||||||
|
|
||||||
|
|
||||||
|
@skip_if_no_ollama
@pytest.mark.parametrize("model", MODELS)
class TestLLMBenchmark:
    """Timing-collection pass whose data feeds the chart generation."""

    def test_collect_benchmark_data(self, model):
        """Collects timing data per model — used by charts.py."""
        # query_model returns (text, elapsed); only the timing is kept here.
        times = [query_model(model, case["query"])[1] for case in TEST_QUERIES]
        llm_results[model].extend(times)
        assert len(times) == len(TEST_QUERIES)
|
||||||
# ===== New file: testing/tests/test_schemas.py (+168 lines) =====
|
|||||||
|
import pytest
|
||||||
|
from pydantic import ValidationError
|
||||||
|
|
||||||
|
from api.schemas import (
|
||||||
|
CourtByID,
|
||||||
|
JudgeByID,
|
||||||
|
AdminProceedingsByID,
|
||||||
|
CourtSearch,
|
||||||
|
JudgeSearch,
|
||||||
|
DecisionSearch,
|
||||||
|
ContractSearch,
|
||||||
|
CivilProceedingsSearch,
|
||||||
|
CourtAutocomplete,
|
||||||
|
JudgeAutocomplete,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestCourtByID:
    """Court ID normalization: bare digits gain the ``sud_`` prefix."""

    def test_digit_gets_prefix(self):
        court = CourtByID(id="175")
        assert court.id == "sud_175"

    def test_already_prefixed_unchanged(self):
        court = CourtByID(id="sud_175")
        assert court.id == "sud_175"

    def test_strips_whitespace(self):
        court = CourtByID(id=" 42 ")
        assert court.id == "sud_42"

    def test_single_digit(self):
        court = CourtByID(id="1")
        assert court.id == "sud_1"

    def test_large_number(self):
        court = CourtByID(id="9999")
        assert court.id == "sud_9999"

    def test_non_numeric_string_unchanged(self):
        court = CourtByID(id="some_string")
        assert court.id == "some_string"

    def test_whitespace_digit_combination(self):
        court = CourtByID(id=" 7 ")
        assert court.id == "sud_7"
|
||||||
|
|
||||||
|
|
||||||
|
class TestJudgeByID:
    """Judge ID normalization: bare digits gain the ``sudca_`` prefix."""

    def test_digit_gets_prefix(self):
        judge = JudgeByID(id="1")
        assert judge.id == "sudca_1"

    def test_already_prefixed_unchanged(self):
        judge = JudgeByID(id="sudca_99")
        assert judge.id == "sudca_99"

    def test_strips_whitespace(self):
        judge = JudgeByID(id=" 7 ")
        assert judge.id == "sudca_7"

    def test_large_number(self):
        judge = JudgeByID(id="12345")
        assert judge.id == "sudca_12345"

    def test_non_numeric_string_unchanged(self):
        judge = JudgeByID(id="sudca_abc")
        assert judge.id == "sudca_abc"
|
||||||
|
|
||||||
|
|
||||||
|
class TestAdminProceedingsByID:
    """Admin-proceeding ID normalization: digits gain ``spravneKonanie_``."""

    def test_digit_gets_prefix(self):
        proceeding = AdminProceedingsByID(id="103")
        assert proceeding.id == "spravneKonanie_103"

    def test_already_prefixed_unchanged(self):
        proceeding = AdminProceedingsByID(id="spravneKonanie_103")
        assert proceeding.id == "spravneKonanie_103"

    def test_strips_whitespace(self):
        proceeding = AdminProceedingsByID(id=" 55 ")
        assert proceeding.id == "spravneKonanie_55"

    def test_single_digit(self):
        proceeding = AdminProceedingsByID(id="5")
        assert proceeding.id == "spravneKonanie_5"

    def test_non_numeric_string_unchanged(self):
        proceeding = AdminProceedingsByID(id="custom_id")
        assert proceeding.id == "custom_id"
|
||||||
|
|
||||||
|
|
||||||
|
class TestPaginationDefaults:
    """Defaults and validation bounds for the shared pagination fields."""

    def test_court_search_defaults_are_none(self):
        search = CourtSearch()
        assert search.page is None
        assert search.size is None

    def test_sort_direction_default_asc(self):
        search = CourtSearch()
        assert search.sortDirection == "ASC"

    def test_sort_direction_desc_accepted(self):
        search = CourtSearch(sortDirection="DESC")
        assert search.sortDirection == "DESC"

    def test_sort_direction_invalid_rejected(self):
        with pytest.raises(ValidationError):
            CourtSearch(sortDirection="INVALID")

    def test_page_cannot_be_negative(self):
        with pytest.raises(ValidationError):
            CourtSearch(page=-1)

    def test_size_cannot_be_zero(self):
        with pytest.raises(ValidationError):
            CourtSearch(size=0)

    def test_size_one_accepted(self):
        search = CourtSearch(size=1)
        assert search.size == 1

    def test_page_zero_accepted(self):
        search = CourtSearch(page=0)
        assert search.page == 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestFacetFilters:
    """List-valued facet filters round-trip through each search schema."""

    def test_court_search_facet_type_list(self):
        search = CourtSearch(typSuduFacetFilter=["Okresný súd", "Krajský súd"])
        assert len(search.typSuduFacetFilter) == 2

    def test_judge_search_facet_kraj(self):
        search = JudgeSearch(krajFacetFilter=["Bratislavský kraj"])
        assert search.krajFacetFilter == ["Bratislavský kraj"]

    def test_decision_search_forma_filter(self):
        search = DecisionSearch(formaRozhodnutiaFacetFilter=["Rozsudok"])
        assert "Rozsudok" in search.formaRozhodnutiaFacetFilter

    def test_contract_search_typ_dokumentu(self):
        search = ContractSearch(typDokumentuFacetFilter=["ZMLUVA", "DODATOK"])
        assert len(search.typDokumentuFacetFilter) == 2

    def test_civil_proceedings_usek_filter(self):
        search = CivilProceedingsSearch(usekFacetFilter=["C", "O"])
        assert "C" in search.usekFacetFilter
|
||||||
|
|
||||||
|
|
||||||
|
class TestAutocomplete:
    """Autocomplete schemas: limit bounds, optional query, scoping GUID."""

    def test_court_autocomplete_limit_min_one(self):
        with pytest.raises(ValidationError):
            CourtAutocomplete(limit=0)

    def test_court_autocomplete_limit_valid(self):
        ac = CourtAutocomplete(limit=5)
        assert ac.limit == 5

    def test_judge_autocomplete_guid_sud(self):
        ac = JudgeAutocomplete(guidSud="sud_100", limit=10)
        assert ac.guidSud == "sud_100"

    def test_autocomplete_empty_query_accepted(self):
        ac = CourtAutocomplete()
        assert ac.query is None

    def test_autocomplete_query_string(self):
        ac = JudgeAutocomplete(query="Novák")
        assert ac.query == "Novák"
|
||||||
|
|
||||||
|
|
||||||
|
class TestModelDumpExcludeNone:
    """``model_dump(exclude_none=True)`` must drop unset fields entirely."""

    def test_excludes_none_fields(self):
        params = CourtSearch(query="Bratislava").model_dump(exclude_none=True)
        assert "query" in params
        assert "page" not in params
        assert "size" not in params

    def test_full_params_included(self):
        search = JudgeSearch(query="Novák", page=0, size=20)
        params = search.model_dump(exclude_none=True)
        assert params["query"] == "Novák"
        assert params["page"] == 0
        assert params["size"] == 20

    def test_empty_schema_dumps_empty_dict(self):
        params = CourtSearch().model_dump(exclude_none=True)
        assert params == {}
|
||||||
# ===== New file: testing/tests/test_sys_prompt.py (+123 lines) =====
|
|||||||
|
import pytest
|
||||||
|
from core.system_prompt import get_system_prompt
|
||||||
|
from core.config import OLLAMA_MODELS, OPENAI_MODELS
|
||||||
|
|
||||||
|
# Union of every configured model identifier (Ollama-hosted + OpenAI API);
# each one gets its own parametrized prompt fixtures below.
MODELS = OLLAMA_MODELS | OPENAI_MODELS
|
||||||
|
|
||||||
|
@pytest.fixture(params=MODELS)
def prompt(request):
    """System prompt text for each configured model (one param per model)."""
    model_name = request.param
    return get_system_prompt(model_name)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(params=MODELS)
def prompt_lower(request):
    """Lower-cased system prompt, for case-insensitive substring checks."""
    model_name = request.param
    return get_system_prompt(model_name).lower()
|
||||||
|
|
||||||
|
|
||||||
|
class TestPromptContainsModelName:
    """The prompt must self-identify with the exact model name."""

    def test_model_name_appears_in_prompt(self, prompt, request):
        # Recover the model name from the parametrized fixture's callspec.
        expected_model = request.node.callspec.params["prompt"]
        assert expected_model in prompt
|
||||||
|
|
||||||
|
|
||||||
|
class TestRequiredSections:
    """Every structural section of the prompt must be present."""

    def test_has_role_section(self, prompt_lower):
        assert "role" in prompt_lower

    def test_has_operational_constraints(self, prompt_lower):
        assert any(tok in prompt_lower for tok in ("constraint", "not allowed", "do not"))

    def test_has_workflow_steps(self, prompt_lower):
        assert any(tok in prompt_lower for tok in ("step", "workflow"))

    def test_has_error_recovery(self, prompt_lower):
        assert any(tok in prompt_lower for tok in ("error", "not found"))

    def test_has_response_format(self, prompt_lower):
        assert any(tok in prompt_lower for tok in ("format", "response"))
|
||||||
|
|
||||||
|
|
||||||
|
class TestSupportedDomains:
    """Every supported data domain must be mentioned (English or Slovak)."""

    def test_courts_mentioned(self, prompt_lower):
        assert any(tok in prompt_lower for tok in ("court", "súd", "sud"))

    def test_judges_mentioned(self, prompt_lower):
        assert any(tok in prompt_lower for tok in ("judge", "sudca", "sudcovia"))

    def test_decisions_mentioned(self, prompt_lower):
        assert any(tok in prompt_lower for tok in ("decision", "rozhodnut"))

    def test_contracts_mentioned(self, prompt_lower):
        assert any(tok in prompt_lower for tok in ("contract", "zmluv"))

    def test_civil_proceedings_mentioned(self, prompt_lower):
        assert any(tok in prompt_lower for tok in ("civil", "obcian", "pojednavan"))

    def test_admin_proceedings_mentioned(self, prompt_lower):
        assert any(tok in prompt_lower for tok in ("admin", "spravne", "správne"))
|
||||||
|
|
||||||
|
|
||||||
|
class TestConstraints:
    """The prompt must state the agent's hard behavioral limits."""

    def test_no_legal_advice_constraint(self, prompt_lower):
        assert any(tok in prompt_lower for tok in ("legal advisor", "not a lawyer", "legal advice"))

    def test_api_only_constraint(self, prompt_lower):
        assert "api" in prompt_lower

    def test_slovak_language_requirement(self, prompt_lower):
        assert any(tok in prompt_lower for tok in ("slovak", "slovensk"))

    def test_no_raw_json_rule(self, prompt_lower):
        assert any(tok in prompt_lower for tok in ("json", "technical"))

    def test_no_speculate_rule(self, prompt_lower):
        assert any(tok in prompt_lower for tok in ("speculate", "infer", "gaps"))
|
||||||
|
|
||||||
|
|
||||||
|
class TestPaginationRules:
    """Pagination conventions the agent must follow when calling tools."""

    def test_page_starts_at_zero_mentioned(self, prompt_lower):
        has_page = "page" in prompt_lower
        has_zero = "0" in prompt_lower
        assert has_page and has_zero

    def test_autocomplete_preferred_mentioned(self, prompt_lower):
        assert "autocomplete" in prompt_lower
|
||||||
|
|
||||||
|
|
||||||
|
class TestDateRules:
    """Date format and date-field naming rules must appear in the prompt."""

    def test_date_format_dd_mm_yyyy_mentioned(self, prompt):
        lowered = prompt.lower()
        assert "DD.MM.YYYY" in prompt or "dd.mm.yyyy" in lowered

    def test_civil_date_field_mentioned(self, prompt_lower):
        assert any(tok in prompt_lower for tok in ("pojednavaniaod", "pojednavania"))

    def test_decision_date_field_mentioned(self, prompt_lower):
        assert any(tok in prompt_lower for tok in ("vydaniaod", "vydania"))
|
||||||
|
|
||||||
|
|
||||||
|
class TestIDNormalizationRules:
    """Entity-ID prefixes must be documented for the agent."""

    def test_sud_prefix_mentioned(self, prompt_lower):
        assert "sud_" in prompt_lower

    def test_sudca_prefix_mentioned(self, prompt_lower):
        assert "sudca_" in prompt_lower

    def test_spravnekonanie_prefix_mentioned(self, prompt_lower):
        assert any(tok in prompt_lower for tok in ("spravnekonanie_", "spravne"))
|
||||||
|
|
||||||
|
|
||||||
|
class TestPromptLength:
    """Sanity bounds on the prompt size."""

    def test_prompt_is_not_empty(self, prompt):
        stripped = prompt.strip()
        assert len(stripped) > 0

    def test_prompt_has_minimum_length(self, prompt):
        # A real system prompt is at least a few paragraphs long.
        assert len(prompt) > 500

    def test_prompt_has_reasonable_max_length(self, prompt):
        # Guard against accidentally embedding huge content in the prompt.
        assert len(prompt) < 50_000
|
||||||
# ===== New file: testing/tests/test_tools.py (+152 lines) =====
|
|||||||
|
import pytest
|
||||||
|
import httpx
|
||||||
|
import respx
|
||||||
|
|
||||||
|
from api.fetch_api_data import _cache
|
||||||
|
from api.config import JUSTICE_API_BASE
|
||||||
|
from api.tools import (
|
||||||
|
court_search,
|
||||||
|
court_id,
|
||||||
|
court_autocomplete,
|
||||||
|
judge_search,
|
||||||
|
judge_id,
|
||||||
|
judge_autocomplete,
|
||||||
|
decision_search,
|
||||||
|
contract_search,
|
||||||
|
civil_proceedings_search,
|
||||||
|
admin_proceedings_search,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Minimal search-result page in the justice.sk API response shape: zero hits.
EMPTY_LIST_RESPONSE = {"content": [], "totalElements": 0, "numberOfElements": 0}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
def clear_cache():
    """Empty the fetch-layer response cache before and after every test.

    Autouse so that no test sees responses cached by a previous one.
    """
    _cache.clear()
    yield
    _cache.clear()
|
||||||
|
|
||||||
|
|
||||||
|
def make_response(body: dict | None = None) -> "httpx.Response":
    """Build a 200 JSON response for respx mocks.

    Args:
        body: JSON payload; when omitted, an empty search-result page is used.

    Returns:
        An ``httpx.Response`` with status 200 and the given JSON body.
    """
    # Compare against None (not truthiness) so an explicit empty dict ``{}``
    # is honored instead of being silently replaced by EMPTY_LIST_RESPONSE.
    payload = EMPTY_LIST_RESPONSE if body is None else body
    return httpx.Response(200, json=payload)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
class TestCourtTools:
    """Court tools must hit the expected justice.sk endpoints."""

    @respx.mock
    async def test_court_search_calls_correct_url(self):
        route = respx.get(f"{JUSTICE_API_BASE}/v1/sud")
        route.mock(return_value=make_response())
        await court_search.on_invoke_tool(None, '{"query": "Bratislava"}')
        assert route.called

    @respx.mock
    async def test_court_id_calls_correct_url(self):
        detail = httpx.Response(200, json={"id": "sud_175", "foto": "data"})
        route = respx.get(f"{JUSTICE_API_BASE}/v1/sud/sud_175").mock(return_value=detail)
        await court_id.on_invoke_tool(None, '{"id": "175"}')
        assert route.called

    @respx.mock
    async def test_court_id_removes_foto_key(self):
        detail = httpx.Response(200, json={"name": "Súd", "foto": "base64"})
        respx.get(f"{JUSTICE_API_BASE}/v1/sud/sud_1").mock(return_value=detail)
        result = await court_id.on_invoke_tool(None, '{"id": "1"}')
        # The tool strips the base64 photo before handing data to the LLM.
        assert "foto" not in str(result)

    @respx.mock
    async def test_court_autocomplete_calls_correct_url(self):
        route = respx.get(f"{JUSTICE_API_BASE}/v1/sud/autocomplete")
        route.mock(return_value=make_response())
        await court_autocomplete.on_invoke_tool(None, '{"query": "Kraj"}')
        assert route.called
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
class TestJudgeTools:
    """Judge tools: endpoint routing, ID normalization, query params."""

    @respx.mock
    async def test_judge_search_calls_correct_url(self):
        route = respx.get(f"{JUSTICE_API_BASE}/v1/sudca")
        route.mock(return_value=make_response())
        await judge_search.on_invoke_tool(None, '{"query": "Novák"}')
        assert route.called

    @respx.mock
    async def test_judge_id_normalizes_digit_id(self):
        detail = httpx.Response(200, json={"id": "sudca_1"})
        route = respx.get(f"{JUSTICE_API_BASE}/v1/sudca/sudca_1").mock(return_value=detail)
        # A bare "1" must be expanded to "sudca_1" before the request.
        await judge_id.on_invoke_tool(None, '{"id": "1"}')
        assert route.called

    @respx.mock
    async def test_judge_autocomplete_passes_guid_sud(self):
        route = respx.get(f"{JUSTICE_API_BASE}/v1/sudca/autocomplete")
        route.mock(return_value=make_response())
        await judge_autocomplete.on_invoke_tool(None, '{"query": "Novák", "guidSud": "sud_100"}')
        assert route.called
        sent_params = dict(route.calls[0].request.url.params)
        assert sent_params.get("guidSud") == "sud_100"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
class TestDecisionTools:
    """Decision search: date-range and judge-GUID query parameters."""

    @respx.mock
    async def test_decision_search_with_date_range(self):
        route = respx.get(f"{JUSTICE_API_BASE}/v1/rozhodnutie")
        route.mock(return_value=make_response())
        await decision_search.on_invoke_tool(
            None, '{"vydaniaOd": "01.01.2024", "vydaniaDo": "31.01.2024"}'
        )
        assert route.called
        sent_params = dict(route.calls[0].request.url.params)
        assert sent_params.get("vydaniaOd") == "01.01.2024"

    @respx.mock
    async def test_decision_search_with_guid_sudca(self):
        route = respx.get(f"{JUSTICE_API_BASE}/v1/rozhodnutie")
        route.mock(return_value=make_response())
        await decision_search.on_invoke_tool(None, '{"guidSudca": "sudca_1"}')
        sent_params = dict(route.calls[0].request.url.params)
        assert sent_params.get("guidSudca") == "sudca_1"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
class TestContractTools:
    """Contract search: court scoping and document-type facet filter."""

    @respx.mock
    async def test_contract_search_with_guid_sud(self):
        route = respx.get(f"{JUSTICE_API_BASE}/v1/zmluvy")
        route.mock(return_value=make_response())
        await contract_search.on_invoke_tool(None, '{"guidSud": "sud_7"}')
        sent_params = dict(route.calls[0].request.url.params)
        assert sent_params.get("guidSud") == "sud_7"

    @respx.mock
    async def test_contract_search_typ_dokumentu_filter(self):
        route = respx.get(f"{JUSTICE_API_BASE}/v1/zmluvy")
        route.mock(return_value=make_response())
        await contract_search.on_invoke_tool(None, '{"typDokumentuFacetFilter": ["ZMLUVA"]}')
        assert route.called
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
class TestCivilAndAdminTools:
    """Civil- and administrative-proceedings tools hit the right endpoints."""

    @respx.mock
    async def test_civil_proceedings_search(self):
        mock = respx.get(f"{JUSTICE_API_BASE}/v1/obcianPojednavania").mock(return_value=make_response())
        await civil_proceedings_search.on_invoke_tool(None, '{"query": "test"}')
        assert mock.called

    @respx.mock
    async def test_admin_proceedings_search(self):
        mock = respx.get(f"{JUSTICE_API_BASE}/v1/spravneKonanie").mock(return_value=make_response())
        await admin_proceedings_search.on_invoke_tool(None, '{"query": "test"}')
        assert mock.called

    @respx.mock
    async def test_civil_proceedings_date_params(self):
        mock = respx.get(f"{JUSTICE_API_BASE}/v1/obcianPojednavania").mock(return_value=make_response())
        # Fixed typo: the upper-bound key is "pojednavaniaDo" (was
        # "jednotnavaniaDo"), matching the "pojednavaniaOd" lower bound —
        # otherwise the date-range path is never actually exercised.
        await civil_proceedings_search.on_invoke_tool(
            None, '{"pojednavaniaOd": "01.01.2024", "pojednavaniaDo": "31.01.2024"}'
        )
        assert mock.called
|
||||||
# (binary image changed — 191 KiB, not shown)
# ===== Deleted file: tests/runner.py (-294 lines) =====
|
|||||||
import asyncio
|
|
||||||
import sqlite3
|
|
||||||
import time
|
|
||||||
import sys
|
|
||||||
from datetime import datetime
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
|
||||||
|
|
||||||
from api.config import JUSTICE_API_BASE
|
|
||||||
import httpx
|
|
||||||
|
|
||||||
# SQLite database holding the test queries (sibling of this script).
DB_PATH = Path(__file__).parent / "test_queries.db"
# Output directory for generated report images.
OUT_DIR = Path(__file__).parent / "results"
# Per-request timeout in seconds, passed to httpx.
REQUEST_TIMEOUT = 15
# Maximum number of concurrent in-flight requests (semaphore size).
CONCURRENT_LIMIT = 5
# Pause in seconds before each request, to throttle load on the API.
DELAY_BETWEEN = 0.3
|
|
||||||
|
|
||||||
# Browser-like headers sent with every request.
# NOTE(review): presumably the public API behaves differently without a
# plausible User-Agent/Referer — confirm before simplifying.
HEADERS = {
    "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) Chrome/124.0.0.0 Safari/537.36",
    "Accept": "application/json, text/plain, */*",
    "Accept-Language": "sk-SK,sk;q=0.9",
    "Referer": "https://obcan.justice.sk/",
    "Origin": "https://obcan.justice.sk",
}
|
|
||||||
|
|
||||||
# Maps an ``expected_tool`` name from the test DB to (endpoint URL, call type).
# Call types: "search" (query + size), "id" (GET endpoint/{id}),
# "autocomplete" (query + limit), "skip" (out-of-scope query — no API call).
TOOL_MAP = {
    "judge": (f"{JUSTICE_API_BASE}/v1/sudca","search"),
    "judge_id": (f"{JUSTICE_API_BASE}/v1/sudca","id"),
    "judge_autocomplete": (f"{JUSTICE_API_BASE}/v1/sudca/autocomplete","autocomplete"),
    "court": (f"{JUSTICE_API_BASE}/v1/sud","search"),
    "court_id": (f"{JUSTICE_API_BASE}/v1/sud","id"),
    "court_autocomplete": (f"{JUSTICE_API_BASE}/v1/sud/autocomplete","autocomplete"),
    "decision": (f"{JUSTICE_API_BASE}/v1/rozhodnutie","search"),
    "decision_id": (f"{JUSTICE_API_BASE}/v1/rozhodnutie","id"),
    "decision_autocomplete": (f"{JUSTICE_API_BASE}/v1/rozhodnutie/autocomplete","autocomplete"),
    "contract": (f"{JUSTICE_API_BASE}/v1/zmluvy","search"),
    "contract_id": (f"{JUSTICE_API_BASE}/v1/zmluvy","id"),
    "contract_autocomplete": (f"{JUSTICE_API_BASE}/v1/zmluvy/autocomplete","autocomplete"),
    "civil_proceedings": (f"{JUSTICE_API_BASE}/v1/obcianPojednavania","search"),
    "civil_proceedings_id": (f"{JUSTICE_API_BASE}/v1/obcianPojednavania","id"),
    "civil_proceedings_autocomplete": (f"{JUSTICE_API_BASE}/v1/obcianPojednavania/autocomplete","autocomplete"),
    "admin_proceedings": (f"{JUSTICE_API_BASE}/v1/spravneKonanie","search"),
    "admin_proceedings_id": (f"{JUSTICE_API_BASE}/v1/spravneKonanie","id"),
    "admin_proceedings_autocomplete": (f"{JUSTICE_API_BASE}/v1/spravneKonanie/autocomplete","autocomplete"),
    "mimo_rozsah": (None,"skip"),
}
|
|
||||||
|
|
||||||
|
|
||||||
def load_queries(limit: int | None = None) -> list[dict]:
    """Load test-query rows from the SQLite DB, ordered by id.

    Args:
        limit: Optional cap on the number of rows returned.

    Returns:
        One plain dict per row (column name -> value).
    """
    sql = "SELECT * FROM test_queries ORDER BY id"
    params: tuple = ()
    if limit:
        # Bind the limit as a parameter instead of interpolating it into SQL.
        sql += " LIMIT ?"
        params = (limit,)
    conn = sqlite3.connect(DB_PATH)
    try:
        conn.row_factory = sqlite3.Row
        rows = conn.execute(sql, params).fetchall()
    finally:
        # Close even if execute raises (the original leaked the connection).
        conn.close()
    return [dict(r) for r in rows]
|
|
||||||
|
|
||||||
|
|
||||||
def pick_keyword(row: dict) -> str:
    """Choose a search keyword for a test-query row.

    Prefers the first entry of ``expected_keywords`` (comma-separated) when it
    is longer than two characters; otherwise falls back to the first two
    longer-than-three-character words of ``query_sk``, and finally to the
    query's first 20 characters.
    """
    keywords_raw = row.get("expected_keywords") or ""
    candidate = keywords_raw.split(",")[0].strip()
    if len(candidate) > 2:
        return candidate
    query = row["query_sk"]
    long_words = [word for word in query.split() if len(word) > 3]
    fallback = " ".join(long_words[:2])
    return fallback or query[:20]
|
|
||||||
|
|
||||||
|
|
||||||
def count_results(data) -> int | None:
|
|
||||||
if isinstance(data, list):
|
|
||||||
return len(data)
|
|
||||||
if isinstance(data, dict):
|
|
||||||
for key in ("totalElements", "total", "count", "numberOfElements"):
|
|
||||||
if key in data:
|
|
||||||
return int(data[key])
|
|
||||||
if "content" in data:
|
|
||||||
return len(data["content"])
|
|
||||||
return 1
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
async def validate_one(
    client: httpx.AsyncClient,
    semaphore: asyncio.Semaphore,
    row: dict,
    idx: int,
    total: int,
) -> dict:
    """Fire one API request for a test-query row and classify the outcome.

    Builds the URL/params from TOOL_MAP according to the row's expected tool,
    performs a throttled GET, and returns a result dict with a status of
    SKIP / OK / EMPTY / NOT_FOUND / FORBIDDEN / HTTP_ERROR / TIMEOUT /
    CONN_ERROR / ERROR plus timing and count metadata.
    """
    result = {
        "id": row["id"],
        "category": row["category"],
        "difficulty": row["difficulty"],
        "tool": row["expected_tool"],
        "query": row["query_sk"],
        "status": "pending",
        "http_code": None,
        "result_count": None,
        "duration_ms": None,
        "error": None,
    }

    # Unknown tools fall back to (None, "skip") and make no request.
    endpoint, call_type = TOOL_MAP.get(row["expected_tool"], (None, "skip"))

    print(f"Request {idx}/{total}")

    if call_type == "skip":
        result["status"] = "SKIP"
        return result

    if call_type == "id":
        # For ID lookups the first expected keyword doubles as the entity ID.
        id_val = (row.get("expected_keywords") or "").split(",")[0].strip()
        url, params = f"{endpoint}/{id_val}", {}
    elif call_type == "autocomplete":
        url, params = endpoint, {"query": pick_keyword(row), "limit": 5}
    else:
        url, params = endpoint, {"query": pick_keyword(row), "size": 5}

    # Semaphore bounds concurrency; the sleep throttles request rate.
    async with semaphore:
        await asyncio.sleep(DELAY_BETWEEN)
        t0 = time.perf_counter()
        try:
            resp = await client.get(url, params=params, headers=HEADERS, timeout=REQUEST_TIMEOUT)
            result["duration_ms"] = round((time.perf_counter() - t0) * 1000)
            result["http_code"] = resp.status_code

            if resp.status_code == 200:
                try:
                    cnt = count_results(resp.json())
                    result["result_count"] = cnt
                    result["status"] = "OK" if cnt and cnt > 0 else "EMPTY"
                except Exception:
                    # Non-JSON 200 body still counts as a successful call.
                    result["status"] = "OK"
            elif resp.status_code == 404:
                result["status"] = "NOT_FOUND"
            elif resp.status_code == 403:
                result["status"] = "FORBIDDEN"
            else:
                result["status"] = "HTTP_ERROR"
                result["error"] = f"HTTP {resp.status_code}"

        except httpx.TimeoutException:
            result["status"] = "TIMEOUT"
            # No measured duration on timeout; record the configured ceiling.
            result["duration_ms"] = REQUEST_TIMEOUT * 1000
        except httpx.ConnectError as e:
            result["status"] = "CONN_ERROR"
            result["error"] = str(e)
        except Exception as e:
            result["status"] = "ERROR"
            result["error"] = str(e)

    return result
|
|
||||||
|
|
||||||
|
|
||||||
async def run_all(rows: list[dict]) -> list[dict]:
    """Validate all query rows concurrently; returns one result dict per row."""
    gate = asyncio.Semaphore(CONCURRENT_LIMIT)
    total = len(rows)
    async with httpx.AsyncClient(follow_redirects=True) as client:
        coros = [
            validate_one(client, gate, row, position, total)
            for position, row in enumerate(rows, start=1)
        ]
        # gather preserves input order, so results align with rows.
        return list(await asyncio.gather(*coros))
|
|
||||||
|
|
||||||
|
|
||||||
def generate_charts(results: list[dict], ts: str) -> Path:
    """Render the 2x2 validation report (pie, category bars, latency
    histogram, difficulty bars) and save it under OUT_DIR.

    Args:
        results: Result dicts produced by ``validate_one``.
        ts: Timestamp string used in the title and output filename.

    Returns:
        Path of the written PNG.
    """
    # Imported lazily so the runner works headless and without matplotlib
    # unless charts are actually requested.  (Removed the unused
    # matplotlib.patches import and the unused DIFF_COLORS table.)
    import matplotlib
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt
    import numpy as np
    from collections import Counter, defaultdict

    STATUS_COLORS = {
        "OK": "#1D9E75",
        "EMPTY": "#EF9F27",
        "SKIP": "#888780",
        "NOT_FOUND": "#E24B4A",
        "FORBIDDEN": "#D85A30",
        "TIMEOUT": "#D4537E",
        "HTTP_ERROR": "#E24B4A",
        "CONN_ERROR": "#E24B4A",
        "ERROR": "#E24B4A",
    }

    OUT_DIR.mkdir(parents=True, exist_ok=True)

    # Only statuses that actually occur get bars/legend entries.
    used_statuses = [s for s in STATUS_COLORS if any(r["status"] == s for r in results)]

    fig = plt.figure(figsize=(18, 18), facecolor="#FAFAF8")
    fig.suptitle(
        f"Validation Report - Legal AI Agent\n{ts.replace('_', ' ')}",
        fontsize=15, fontweight="bold", y=0.98, color="#2C2C2A"
    )
    # 2 rows: [pie + category] | [histogram + difficulty]
    gs = fig.add_gridspec(2, 2, hspace=0.45, wspace=0.35,
                          top=0.94, bottom=0.04, left=0.07, right=0.97)

    def styled(ax, title):
        # Shared axis cosmetics for the bar/histogram panels.
        ax.set_title(title, fontsize=12, fontweight="bold", color="#2C2C2A")
        ax.set_facecolor("#F9F9F7")
        ax.spines[["top", "right"]].set_visible(False)
        ax.grid(axis="y", alpha=0.3, linestyle="--")

    # 1. Pie — overall results
    ax1 = fig.add_subplot(gs[0, 0])
    cnt = Counter(r["status"] for r in results)
    wedges, texts, pcts = ax1.pie(
        cnt.values(),
        labels=cnt.keys(),
        colors=[STATUS_COLORS.get(k, "#888780") for k in cnt],
        autopct="%1.1f%%", startangle=90, pctdistance=0.82,
        wedgeprops={"edgecolor": "white", "linewidth": 1.5},
    )
    for t in texts: t.set_fontsize(10)
    for t in pcts: t.set_fontsize(9); t.set_color("white"); t.set_fontweight("bold")
    ax1.set_title("Overall Results", fontsize=12, fontweight="bold", color="#2C2C2A")

    # 2. Bar — results by category (grouped bars, one group per status)
    ax2 = fig.add_subplot(gs[0, 1])
    categories = sorted(set(r["category"] for r in results))
    cat_data = defaultdict(Counter)
    for r in results:
        cat_data[r["category"]][r["status"]] += 1

    x, w = np.arange(len(categories)), 0.8 / max(len(used_statuses), 1)
    for i, status in enumerate(used_statuses):
        vals = [cat_data[c][status] for c in categories]
        offset = (i - len(used_statuses) / 2 + 0.5) * w
        ax2.bar(x + offset, vals, w, label=status,
                color=STATUS_COLORS.get(status, "#888780"),
                edgecolor="white", linewidth=0.5)

    ax2.set_xticks(x)
    ax2.set_xticklabels([c[:13] for c in categories], rotation=35, ha="right", fontsize=8)
    ax2.set_ylabel("Count", fontsize=10)
    ax2.legend(fontsize=7, loc="upper right")
    styled(ax2, "Results by Category")

    # 3. Histogram — response times (SKIP rows carry no timing)
    ax4 = fig.add_subplot(gs[1, 0])
    times = [r["duration_ms"] for r in results
             if r["duration_ms"] and r["status"] != "SKIP"]
    if times:
        ax4.hist(times, bins=20, color="#378ADD", edgecolor="white", linewidth=0.7, alpha=0.85)
        avg = sum(times) / len(times)
        ax4.axvline(avg, color="#E24B4A", linewidth=1.5, linestyle="--",
                    label=f"Avg: {avg:.0f} ms")
    ax4.set_xlabel("ms", fontsize=10)
    ax4.set_ylabel("Requests", fontsize=10)
    ax4.legend(fontsize=9)
    styled(ax4, "API Response Time Distribution")

    # 4. Bar — results by difficulty
    ax5 = fig.add_subplot(gs[1, 1])
    difficulties = ["easy", "medium", "hard"]
    diff_data = defaultdict(Counter)
    for r in results:
        diff_data[r["difficulty"]][r["status"]] += 1

    x2, w2 = np.arange(3), 0.8 / max(len(used_statuses), 1)
    for i, status in enumerate(used_statuses):
        vals = [diff_data[d][status] for d in difficulties]
        offset = (i - len(used_statuses) / 2 + 0.5) * w2
        ax5.bar(x2 + offset, vals, w2, label=status,
                color=STATUS_COLORS.get(status, "#888780"),
                edgecolor="white", linewidth=0.5)

    ax5.set_xticks(x2)
    ax5.set_xticklabels(difficulties, fontsize=10)
    ax5.set_ylabel("Count", fontsize=10)
    ax5.legend(fontsize=8)
    styled(ax5, "Results by Difficulty")

    out = OUT_DIR / f"report_{ts}.png"
    fig.savefig(out, dpi=150, bbox_inches="tight", facecolor=fig.get_facecolor())
    plt.close(fig)
    return out
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
    # Optional CLI arg: cap the number of queries validated.
    limit = int(sys.argv[1]) if len(sys.argv) > 1 else None

    rows = load_queries(limit)
    ts = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

    # Placeholder-free literals need no f-prefix.
    print("Start Test Validation")
    results = asyncio.run(run_all(rows))
    print("End Test Validation")

    out = generate_charts(results, ts)
    print(f"Test image is saved: {out}")
|
|
||||||
# (binary file changed — not shown; end of diff)