# ai-lawyer-agent/testing/tests/test_llm_compare.py

import time
import pytest
from openai import OpenAI
from core.config import (
    OLLAMA_BASE_URL, OLLAMA_API_KEY, LLM_TIMEOUT,
)
from core.system_prompt import get_system_prompt
MODELS = ["qwen3.5:cloud"]
TEST_QUERIES = [
{
"id": "court_search",
"query": "Nájdi súdy v Bratislavskom kraji.",
"expected_keywords": ["súd", "bratislava", "kraj"],
},
{
"id": "judge_search",
"query": "Vyhľadaj sudcu Novák.",
"expected_keywords": ["sudca", "novák", "novak"],
},
{
"id": "no_legal_advice",
"query": "Mám spor so zamestnávateľom, čo mám robiť?",
"forbidden_keywords": ["musíte", "odporúčam vám podať žalobu", "právne poradenstvo"],
"expected_keywords": ["api", "ministerstvo", "nie som právnik", "právny poradca"],
},
{
"id": "slovak_response",
"query": "What courts exist in Slovakia?",
"expected_keywords": ["súd", "slovensko", "kraj"],
},
]


def ollama_available() -> bool:
    """Return True if an Ollama server responds at OLLAMA_BASE_URL."""
    try:
        client = OpenAI(base_url=OLLAMA_BASE_URL, api_key=OLLAMA_API_KEY)
        client.models.list()
        return True
    except Exception:
        return False


skip_if_no_ollama = pytest.mark.skipif(
    not ollama_available(),
    reason="Ollama is not running",
)


def query_model(model: str, user_message: str) -> tuple[str, float]:
    """Send one chat completion request; return (response text, elapsed seconds)."""
    client = OpenAI(base_url=OLLAMA_BASE_URL, api_key=OLLAMA_API_KEY, timeout=LLM_TIMEOUT)
    start = time.perf_counter()
    response = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": get_system_prompt(model)},
            {"role": "user", "content": user_message},
        ],
        temperature=0.0,  # deterministic output for reproducible assertions
        max_tokens=2048,
    )
    elapsed = time.perf_counter() - start
    text = response.choices[0].message.content or ""
    return text, elapsed


# Per-query latencies collected by TestLLMBenchmark, keyed by model name.
llm_results: dict[str, list[float]] = {m: [] for m in MODELS}


@skip_if_no_ollama
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("case", TEST_QUERIES, ids=[c["id"] for c in TEST_QUERIES])
class TestLLMResponses:
    def test_response_is_not_empty(self, model, case):
        text, _ = query_model(model, case["query"])
        assert len(text.strip()) > 0

    def test_response_in_slovak(self, model, case):
        text, _ = query_model(model, case["query"])
        # Common Slovak words; a crude but cheap language check. Note that an
        # empty string must not appear here, as it would match any text.
        slovak_markers = ["je", "som", "nie", "súd", "sudca", "kraj", "ale", "alebo", "pre"]
        assert any(m in text.lower() for m in slovak_markers)

    def test_expected_keywords_present(self, model, case):
        if "expected_keywords" not in case:
            pytest.skip("No expected_keywords defined")
        text, _ = query_model(model, case["query"])
        assert any(kw.lower() in text.lower() for kw in case["expected_keywords"])

    def test_forbidden_keywords_absent(self, model, case):
        if "forbidden_keywords" not in case:
            pytest.skip("No forbidden_keywords defined")
        text, _ = query_model(model, case["query"])
        for kw in case["forbidden_keywords"]:
            assert kw.lower() not in text.lower(), f"Forbidden keyword found: {kw}"

    def test_response_time_under_threshold(self, model, case):
        _, elapsed = query_model(model, case["query"])
        assert elapsed < float(LLM_TIMEOUT), f"Response took {elapsed:.1f}s"

    def test_response_length_reasonable(self, model, case):
        text, _ = query_model(model, case["query"])
        assert 10 < len(text) < 4000


@skip_if_no_ollama
@pytest.mark.parametrize("model", MODELS)
class TestLLMBenchmark:
    def test_collect_benchmark_data(self, model):
        """Run every query once and record per-query latencies for the model."""
        times = []
        for case in TEST_QUERIES:
            _, elapsed = query_model(model, case["query"])
            times.append(elapsed)
        llm_results[model].extend(times)
        assert len(times) == len(TEST_QUERIES)
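

# Reporting sketch (an addition, not part of the original module): the timings
# in `llm_results` are collected above but never surfaced. pytest calls
# `teardown_module` once after all tests in this file finish, which is a
# convenient place to print a per-model latency summary. Run with `pytest -s`
# so the output is not captured.
def teardown_module(module):
    for model, times in llm_results.items():
        if not times:
            continue  # benchmark was skipped for this model
        mean = sum(times) / len(times)
        print(
            f"\n[benchmark] {model}: n={len(times)} "
            f"mean={mean:.2f}s min={min(times):.2f}s max={max(times):.2f}s"
        )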