# dp-zp-agent/scripts/search_utils.py
#
# 240 lines
# 5.9 KiB
# Python

from __future__ import annotations
import re
import sqlite3
import unicodedata
from collections import Counter, defaultdict
from pathlib import Path
from typing import Any
# Lower-case keywords that mark a query as a technical topic; a query
# containing any of them is never classified as a person's name.
TECHNICAL_TERMS = {
    # language models and assistants
    "rag", "graphrag", "llm", "lm", "nlp", "langchain", "chatbot", "agent",
    # search and retrieval
    "search", "retrieval", "generation", "embedding", "vector", "vectors",
    # question answering
    "qa", "question", "answer",
    # graphs and storage
    "graph", "knowledge", "database", "db", "neo4j",
    # tooling and infrastructure
    "openwebui", "docker", "webhook", "python", "cloud", "api",
}
def normalize_text(text: str) -> str:
    """Reduce *text* to lower-case ASCII words separated by single spaces.

    The text is lower-cased, NFKD-decomposed and stripped of combining
    marks (so ``č`` becomes ``c``). Every maximal run of ``[a-z0-9]`` is
    then kept as a word; everything else, including ``_``, ``/`` and
    ``-``, acts as a separator.
    """
    decomposed = unicodedata.normalize("NFKD", text.lower())
    base_chars = [ch for ch in decomposed if unicodedata.combining(ch) == 0]
    words = re.findall(r"[a-z0-9]+", "".join(base_chars))
    return " ".join(words)
def tokenize(text: str) -> list[str]:
    """Split *text* into normalized words, dropping single-character ones."""
    words = normalize_text(text).split()
    return [candidate for candidate in words if len(candidate) > 1]
def detect_search_mode(tokens: list[str]) -> str:
    """Make a simple guess whether the query names a person or a topic.

    Returns ``"person"`` only for exactly two tokens of which none is a
    known technical term; every other query, including an empty one, is
    treated as ``"topic"``.
    """
    if not tokens:
        return "topic"

    mentions_technology = not TECHNICAL_TERMS.isdisjoint(tokens)
    looks_like_full_name = len(tokens) == 2 and not mentions_technology
    return "person" if looks_like_full_name else "topic"
def contains_all(query_tokens: list[str], field_tokens: list[str]) -> bool:
    """Return True when every query token occurs among the field tokens.

    An empty *query_tokens* list is vacuously contained in any field.
    """
    return set(field_tokens).issuperset(query_tokens)
def score_tokens(
    query_tokens: list[str],
    field_tokens: list[str],
    weight: int,
) -> int:
    """Score a field by how often the query tokens occur in it.

    Each query token contributes its number of occurrences in
    *field_tokens* times *weight*; a token repeated in the query is
    counted once per repetition.
    """
    occurrences = Counter(field_tokens)
    total_hits = 0
    for token in query_tokens:
        # Counter yields 0 for tokens that never occur in the field.
        total_hits += occurrences[token]
    return total_hits * weight
def make_source_url(document_path: str) -> str:
    """Build the public page URL for a repository document path.

    Whole ``pages/`` path segments and a trailing ``/README.md`` are
    removed; what remains is appended to the site root.

    A plain ``str.replace`` would also cut the text out of longer names
    (``homepages/`` would become ``home``), so both removals are anchored:
    ``pages/`` must start the path or follow a ``/``, and ``/README.md``
    must end the path.
    """
    clean_path = re.sub(r"(?:^|(?<=/))pages/", "", document_path)
    clean_path = re.sub(r"/README\.md\Z", "", clean_path)
    return f"https://zp.kemt.fei.tuke.sk/{clean_path}"
def load_labels(
    conn: sqlite3.Connection,
    table: str,
    column: str,
) -> dict[str, list[str]]:
    """Collect all values of *column* in *table*, grouped by ``chunk_id``.

    *table* and *column* are interpolated into the SQL text unescaped, so
    they must be trusted identifiers and never user input.

    Returns a ``defaultdict(list)`` mapping each chunk id to its values in
    the order the rows were returned.
    """
    grouped: dict[str, list[str]] = defaultdict(list)
    statement = f"SELECT chunk_id, {column} FROM {table}"
    for record in conn.execute(statement).fetchall():
        grouped[record[0]].append(record[1])
    return grouped
def person_matches(query_tokens: list[str], item: dict[str, Any]) -> bool:
    """Tell whether a single field of *item* contains every query token.

    The title, document path, author and text are checked one by one; the
    tokens must all occur within the same field. Missing or ``None``
    values are treated as empty text.
    """
    for key in ("title", "document_path", "author", "text"):
        if contains_all(query_tokens, tokenize(item.get(key) or "")):
            return True
    return False
def score_item(
    query: str,
    query_tokens: list[str],
    item: dict[str, Any],
    mode: str,
) -> int:
    """Compute the relevance score of one chunk for a query.

    In ``"person"`` mode only title, path, author and text count, with
    large bonuses when a field contains every query token. In any other
    mode tags and categories count as well, and title/path earn bonuses
    for containing the normalized query as a substring (+30) and for
    containing every query token (+25).

    Args:
        query: The raw query string; used only outside person mode.
        query_tokens: The tokenized query.
        item: Chunk record with optional ``title``, ``document_path``,
            ``author``, ``text``, ``tags`` and ``categories`` entries.
        mode: ``"person"`` or anything else for topic scoring.

    Returns:
        The accumulated score; 0 means nothing matched.
    """
    # All fields are tokenized up front regardless of the active mode.
    field_tokens = {
        "title": tokenize(item.get("title") or ""),
        "document_path": tokenize(item.get("document_path") or ""),
        "author": tokenize(item.get("author") or ""),
        "text": tokenize(item.get("text") or ""),
        "tags": tokenize(" ".join(item.get("tags") or [])),
        "categories": tokenize(" ".join(item.get("categories") or [])),
    }

    if mode == "person":
        occurrence_weights = {
            "title": 30,
            "document_path": 30,
            "author": 15,
            "text": 2,
        }
        full_match_bonus = {"title": 100, "document_path": 100, "author": 60}

        total = 0
        for name, weight in occurrence_weights.items():
            total += score_tokens(query_tokens, field_tokens[name], weight)
        for name, bonus in full_match_bonus.items():
            if contains_all(query_tokens, field_tokens[name]):
                total += bonus
        return total

    occurrence_weights = {
        "title": 12,
        "document_path": 12,
        "tags": 10,
        "categories": 6,
        "author": 3,
        "text": 2,
    }
    total = 0
    for name, weight in occurrence_weights.items():
        total += score_tokens(query_tokens, field_tokens[name], weight)

    # Substring bonus: the whole normalized query appears in title / path.
    phrase = normalize_text(query)
    if phrase:
        for name in ("title", "document_path"):
            if phrase in normalize_text(item.get(name) or ""):
                total += 30

    # Coverage bonus: title / path contains every query token.
    if query_tokens:
        for name in ("title", "document_path"):
            if contains_all(query_tokens, field_tokens[name]):
                total += 25

    return total
def search_database(
    db_file: Path,
    query: str,
    limit: int = 10,
) -> tuple[str, list[dict[str, Any]]]:
    """Score every chunk in the SQLite database against *query*.

    Args:
        db_file: Path to the SQLite file holding the ``chunks``,
            ``chunk_tags`` and ``chunk_categories`` tables.
        query: Free-text query.
        limit: Maximum number of results; values below 0 are treated as 0.

    Returns:
        ``(mode, results)`` where *mode* is the detected search mode and
        *results* are the chunk rows as dicts, extended with ``tags``,
        ``categories``, ``score`` and ``source_url``, ordered by descending
        score (ties keep database row order).

    Raises:
        FileNotFoundError: If *db_file* does not exist.
    """
    if not db_file.exists():
        raise FileNotFoundError(f"Databáza neexistuje: {db_file}")

    query_tokens = tokenize(query)
    mode = detect_search_mode(query_tokens)

    # ``with sqlite3.connect(...)`` only commits or rolls back; it never
    # closes the connection. Close it explicitly so repeated searches do
    # not leak file handles.
    conn = sqlite3.connect(db_file)
    try:
        conn.row_factory = sqlite3.Row
        tags_by_chunk = load_labels(conn, "chunk_tags", "tag")
        categories_by_chunk = load_labels(conn, "chunk_categories", "category")
        rows = conn.execute(
            """
            SELECT
                chunk_id,
                document_path,
                title,
                author,
                chunk_index,
                text,
                text_length
            FROM chunks
            """
        ).fetchall()
    finally:
        conn.close()

    results = []
    for row in rows:
        item = dict(row)
        chunk_id = item["chunk_id"]
        item["tags"] = tags_by_chunk.get(chunk_id, [])
        item["categories"] = categories_by_chunk.get(chunk_id, [])

        # In person mode, drop chunks where no single field holds the
        # whole name before doing any scoring.
        if mode == "person" and not person_matches(query_tokens, item):
            continue

        score = score_item(query, query_tokens, item, mode)
        if score <= 0:
            continue

        item["score"] = score
        # A NULL document_path must not abort the whole search.
        item["source_url"] = make_source_url(item["document_path"] or "")
        results.append(item)

    results.sort(key=lambda item: item["score"], reverse=True)
    # A negative limit would otherwise slice from the end of the list.
    return mode, results[: max(limit, 0)]