240 lines
5.9 KiB
Python
240 lines
5.9 KiB
Python
from __future__ import annotations
|
|
|
|
import re
|
|
import sqlite3
|
|
import unicodedata
|
|
from collections import Counter, defaultdict
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
|
|
# Lower-case tokens that mark a query as a technical topic rather than a
# person's name (consulted by detect_search_mode). Entries must be in the form
# tokenize() produces: ASCII lowercase letters/digits, at least 2 characters,
# no separators. For example, "open-webui" and "Open WebUI" both tokenize to
# "open" + "webui", so "openwebui" alone would not recognise those spellings.
TECHNICAL_TERMS = {
    # Retrieval-augmented generation and search
    "rag",
    "graphrag",
    "search",
    "retrieval",
    "generation",
    "embedding",
    "embeddings",
    "vector",
    "vectors",
    "qa",
    "question",
    "answer",
    # Language models, agents and NLP
    "ai",
    "llm",
    "llms",
    "lm",
    "nlp",
    "agent",
    "agents",
    "chatbot",
    "chatbots",
    "langchain",
    # Knowledge graphs and storage
    "graph",
    "graphs",
    "knowledge",
    "database",
    "databases",
    "db",
    "neo4j",
    # Tooling and infrastructure
    "openwebui",
    "webui",
    "docker",
    "webhook",
    "webhooks",
    "python",
    "cloud",
    "api",
    "apis",
}
|
|
|
|
|
|
def normalize_text(text: str) -> str:
    """Fold *text* into lowercase ASCII words separated by single spaces.

    Diacritics are removed by NFKD decomposition followed by dropping the
    combining marks, so Slovak input such as "Žilina" becomes "zilina".
    Every run of characters outside ``[a-z0-9]`` (whitespace, punctuation,
    ``_``, ``/``, ``-`` and any non-Latin letters) collapses into one space,
    and the result is stripped.

    >>> normalize_text("Žilina_RAG/Neo4j-Graph")
    'zilina rag neo4j graph'
    """
    # Decompose *before* case folding: some compatibility characters only
    # become plain letters after NFKD (e.g. "ℌ" -> "H"), and lowercasing them
    # first would leave them uppercase, so the regex below would drop them.
    # casefold() (unlike lower()) also maps "ß" to "ss" instead of keeping a
    # character that the regex would delete.
    folded = unicodedata.normalize("NFKD", text).casefold()
    without_marks = "".join(ch for ch in folded if not unicodedata.combining(ch))
    # No separate handling of "_", "/" or "-" is needed: like every other
    # separator, the pattern turns them into spaces.
    return re.sub(r"[^a-z0-9]+", " ", without_marks).strip()
|
|
|
|
|
|
def tokenize(text: str) -> list[str]:
    """Split *text* into normalized words, discarding one-character fragments.

    The text is first passed through ``normalize_text``. The resulting words
    keep their original order, and duplicates are preserved.
    """
    words = normalize_text(text).split()
    return list(filter(lambda word: len(word) > 1, words))
|
|
|
|
|
|
def detect_search_mode(tokens: list[str]) -> str:
    """Roughly guess whether the query is a person's name or a technical topic.

    A query of exactly two tokens, none of which is in ``TECHNICAL_TERMS``,
    is treated as a probable personal name and yields ``"person"``. Every
    other query, including an empty token list, yields ``"topic"``.
    """
    # The length test short-circuits first, so the term lookup only runs for
    # two-token queries (an empty list is simply not length 2).
    looks_like_name = len(tokens) == 2 and TECHNICAL_TERMS.isdisjoint(tokens)
    return "person" if looks_like_name else "topic"
|
|
|
|
|
|
def contains_all(query_tokens: list[str], field_tokens: list[str]) -> bool:
    """Return True when every token of *query_tokens* appears in *field_tokens*.

    Repetitions are ignored, because only presence matters. An empty
    *query_tokens* is vacuously contained, so the result is True.
    """
    return set(query_tokens).issubset(field_tokens)
|
|
|
|
|
|
def score_tokens(
    query_tokens: list[str],
    field_tokens: list[str],
    weight: int,
) -> int:
    """Weight how often the query tokens occur in a field.

    For each token in *query_tokens*, the number of times it occurs in
    *field_tokens* is multiplied by *weight*, and the products are summed.
    A token repeated in the query is counted once per repetition.
    """
    occurrences = Counter(field_tokens)
    total = 0
    for token in query_tokens:
        # Indexing a Counter yields 0 for absent keys without inserting them.
        total += occurrences[token] * weight
    return total
|
|
|
|
|
|
def make_source_url(document_path: str) -> str:
    """Build the public page URL for a document path.

    A leading ``pages/`` and a trailing ``/README.md`` are removed, and the
    rest is appended to the site root. For example,
    ``pages/students/novak/README.md`` maps to
    ``https://zp.kemt.fei.tuke.sk/students/novak``.
    """
    prefix = "pages/"
    suffix = "/README.md"

    # Strip only at the ends. str.replace() would also delete "pages/" or
    # "/README.md" occurring in the middle of the path and corrupt the URL.
    clean_path = document_path
    if clean_path.startswith(prefix):
        clean_path = clean_path[len(prefix):]
    if clean_path.endswith(suffix):
        clean_path = clean_path[: -len(suffix)]

    return f"https://zp.kemt.fei.tuke.sk/{clean_path}"
|
|
|
|
|
|
def load_labels(
    conn: sqlite3.Connection,
    table: str,
    column: str,
) -> dict[str, list[str]]:
    """Group the values of ``<table>.<column>`` by ``chunk_id``.

    Args:
        conn: Open SQLite connection.
        table: Name of a table that has a ``chunk_id`` column.
        column: Name of the label column to collect.

    Returns:
        A ``defaultdict(list)`` mapping each ``chunk_id`` to its label values,
        in the order SQLite returned the rows. Chunks without labels are
        absent; indexing a missing key yields (and inserts) an empty list.

    Note:
        *table* and *column* are interpolated into the SQL text, because
        identifiers cannot be bound as parameters. They must therefore be
        trusted constants, never user input.
    """
    sql = f"SELECT chunk_id, {column} FROM {table}"
    grouped: dict[str, list[str]] = defaultdict(list)
    for chunk_id, label in conn.execute(sql).fetchall():
        grouped[chunk_id].append(label)
    return grouped
|
|
|
|
|
|
def person_matches(query_tokens: list[str], item: dict[str, Any]) -> bool:
    """Tell whether a single identifying field of *item* holds every query token.

    The fields ``title``, ``document_path``, ``author`` and ``text`` are
    checked in that order, with missing or ``None`` values treated as empty.
    The item matches as soon as one field contains all of *query_tokens*.
    Tags and categories are not consulted.
    """
    for field_name in ("title", "document_path", "author", "text"):
        field_tokens = tokenize(item.get(field_name) or "")
        if contains_all(query_tokens, field_tokens):
            return True
    return False
|
|
|
|
|
|
def score_item(
    query: str,
    query_tokens: list[str],
    item: dict[str, Any],
    mode: str,
) -> int:
    """Compute the relevance score of one chunk for a query.

    Args:
        query: The raw query string. It is used only in topic mode, for the
            whole-phrase bonus.
        query_tokens: The tokenized query (``tokenize(query)``).
        item: A chunk as a dict. The optional keys ``title``,
            ``document_path``, ``author`` and ``text`` (strings) and ``tags``
            and ``categories`` (lists of strings) are read; missing or
            ``None`` values count as empty.
        mode: ``"person"`` selects person-name weighting. Any other value
            uses topic weighting.

    Returns:
        A non-negative integer. Higher means more relevant, and 0 means no
        query token was found in any scored field.
    """
    title = item.get("title") or ""
    document_path = item.get("document_path") or ""

    title_tokens = tokenize(title)
    path_tokens = tokenize(document_path)
    author_tokens = tokenize(item.get("author") or "")
    text_tokens = tokenize(item.get("text") or "")

    if mode == "person":
        # Token hits are weighted heavily in title/path. Large flat bonuses
        # go to fields that contain every query token, i.e. the whole name.
        score = (
            score_tokens(query_tokens, title_tokens, 30)
            + score_tokens(query_tokens, path_tokens, 30)
            + score_tokens(query_tokens, author_tokens, 15)
            + score_tokens(query_tokens, text_tokens, 2)
        )
        if contains_all(query_tokens, title_tokens):
            score += 100
        if contains_all(query_tokens, path_tokens):
            score += 100
        if contains_all(query_tokens, author_tokens):
            score += 60
        return score

    # Skip NULL labels so a missing tag/category value cannot break join().
    tag_tokens = tokenize(" ".join(t for t in (item.get("tags") or []) if t))
    category_tokens = tokenize(
        " ".join(c for c in (item.get("categories") or []) if c)
    )

    score = (
        score_tokens(query_tokens, title_tokens, 12)
        + score_tokens(query_tokens, path_tokens, 12)
        + score_tokens(query_tokens, tag_tokens, 10)
        + score_tokens(query_tokens, category_tokens, 6)
        + score_tokens(query_tokens, author_tokens, 3)
        + score_tokens(query_tokens, text_tokens, 2)
    )

    # Bonus when the whole query appears as a contiguous run of words.
    normalized_query = normalize_text(query)
    if _contains_phrase(normalized_query, normalize_text(title)):
        score += 30
    if _contains_phrase(normalized_query, normalize_text(document_path)):
        score += 30

    # contains_all() is vacuously true for an empty token list, so require
    # at least one query token before awarding the all-words bonus.
    if query_tokens and contains_all(query_tokens, title_tokens):
        score += 25
    if query_tokens and contains_all(query_tokens, path_tokens):
        score += 25

    return score


def _contains_phrase(phrase: str, haystack: str) -> bool:
    """Return True if *phrase* occurs in *haystack* as a run of whole words.

    Both arguments are expected to be ``normalize_text`` output (words
    separated by single spaces). Padding each side with a space anchors the
    match to word boundaries, so "rag" matches "graph rag" but not "storage".
    An empty phrase never matches.
    """
    return bool(phrase) and f" {phrase} " in f" {haystack} "
|
|
|
|
|
|
def search_database(
    db_file: Path,
    query: str,
    limit: int = 10,
) -> tuple[str, list[dict[str, Any]]]:
    """Search every chunk stored in the SQLite database *db_file*.

    The query is tokenized and classified as a ``"person"`` or ``"topic"``
    search by ``detect_search_mode``. In person mode, chunks are first
    filtered with ``person_matches``. Each remaining chunk is scored with
    ``score_item``, and chunks scoring 0 or less are dropped.

    Args:
        db_file: Path to an existing database with ``chunks``, ``chunk_tags``
            and ``chunk_categories`` tables.
        query: Free-text query.
        limit: Maximum number of results. Negative values are treated as 0.

    Returns:
        ``(mode, results)``. ``results`` holds at most *limit* chunk dicts,
        sorted by descending score. Each dict contains the selected row
        columns plus ``tags``, ``categories``, ``score`` and ``source_url``.
        Ties keep the order in which the rows were read.

    Raises:
        FileNotFoundError: If *db_file* does not exist. This is checked up
            front because ``sqlite3.connect`` would otherwise silently
            create a new, empty database file.
    """
    if not db_file.exists():
        raise FileNotFoundError(f"Databáza neexistuje: {db_file}")

    query_tokens = tokenize(query)
    mode = detect_search_mode(query_tokens)

    results: list[dict[str, Any]] = []
    for item in _load_chunks(db_file):
        if mode == "person" and not person_matches(query_tokens, item):
            continue

        score = score_item(query, query_tokens, item, mode)
        if score <= 0:
            continue

        item["score"] = score
        # Guard against a NULL document_path, which would break the URL helper.
        item["source_url"] = make_source_url(item["document_path"] or "")
        results.append(item)

    # list.sort is stable, so equally scored chunks keep their read order.
    results.sort(key=lambda result: result["score"], reverse=True)
    return mode, results[: max(limit, 0)]


def _load_chunks(db_file: Path) -> list[dict[str, Any]]:
    """Read all chunk rows from *db_file* as dicts with ``tags``/``categories`` attached."""
    conn = sqlite3.connect(db_file)
    try:
        conn.row_factory = sqlite3.Row
        tags_by_chunk = load_labels(conn, "chunk_tags", "tag")
        categories_by_chunk = load_labels(conn, "chunk_categories", "category")
        items = [
            dict(row)
            for row in conn.execute(
                """
                SELECT
                    chunk_id,
                    document_path,
                    title,
                    author,
                    chunk_index,
                    text,
                    text_length
                FROM chunks
                """
            ).fetchall()
        ]
    finally:
        # The sqlite3.Connection context manager only commits or rolls back;
        # it never closes the connection, so close it explicitly here.
        conn.close()

    for item in items:
        chunk_id = item["chunk_id"]
        item["tags"] = tags_by_chunk.get(chunk_id, [])
        item["categories"] = categories_by_chunk.get(chunk_id, [])
    return items
|