272 lines
7.3 KiB
Python
272 lines
7.3 KiB
Python
from pathlib import Path
|
|
import sqlite3
|
|
import re
|
|
import sys
|
|
import unicodedata
|
|
from collections import Counter
|
|
from rich import print
|
|
|
|
|
|
# SQLite index searched by this script, relative to the current working directory.
DB_FILE = Path("data/zp_index.sqlite")


# Query tokens that mark a search as a topic search rather than a person-name
# search (see detect_search_mode). Entries are compared against tokenize()
# output, so they should be lowercase ASCII and at least two characters long.
TECHNICAL_TERMS = {
    # retrieval / generation
    "rag", "graphrag", "retrieval", "generation", "search",
    "embedding", "vector", "vectors",
    # language models and assistants
    "llm", "lm", "nlp", "chatbot", "agent", "langchain",
    "qa", "question", "answer",
    # knowledge representation and storage
    "graph", "knowledge", "neo4j", "database", "db",
    # tooling and infrastructure
    "python", "docker", "openwebui", "webhook", "cloud", "api",
}
|
|
|
|
|
|
# Latin letters that NFKD does not split into an ASCII base letter plus
# combining marks. Without this table the ASCII filter in normalize_text()
# would delete them outright and break words apart (e.g. "Łukasz" -> "ukasz").
_LATIN_TRANSLITERATION = str.maketrans({
    "ł": "l",
    "đ": "d",
    "ð": "d",
    "ø": "o",
    "æ": "ae",
    "œ": "oe",
    "þ": "th",
    "ı": "i",
    "ħ": "h",
    "ŧ": "t",
})

# Any run of characters outside [a-z0-9] becomes a single word separator.
_NON_ALNUM_RE = re.compile(r"[^a-z0-9]+")


def normalize_text(text: str) -> str:
    """Fold *text* into lowercase ASCII words separated by single spaces.

    Steps:
    1. NFKD-decompose and drop combining marks, so accented letters keep
       their base letter ("Pták" -> "Ptak", "Ľubomír" -> "Lubomir").
    2. Case-fold; this also expands "ß" to "ss".
    3. Transliterate the few Latin letters that have no decomposition
       (ł, ø, æ, ...), so they are not silently dropped.
    4. Replace every run of characters outside ``[a-z0-9]`` (punctuation,
       ``_``, ``/``, ``-``, whitespace, non-Latin scripts) with one space and
       strip the ends.

    The result never has leading/trailing spaces or repeated spaces, which
    callers can rely on for word-boundary phrase checks.
    """
    decomposed = unicodedata.normalize("NFKD", text)
    # Drop the combining marks NFKD split off (acute, caron, ogonek, ...).
    without_marks = "".join(ch for ch in decomposed if not unicodedata.combining(ch))
    # Case-fold after decomposition so characters such as "İ" (-> "I" + dot)
    # and compatibility forms that decompose to uppercase end up lowercase.
    folded = without_marks.casefold().translate(_LATIN_TRANSLITERATION)
    return _NON_ALNUM_RE.sub(" ", folded).strip()
|
|
|
|
|
|
def tokenize(text: str) -> list[str]:
    """Split *text* into normalized word tokens, dropping one-character words.

    Normalization is delegated to normalize_text(); only words of length two
    or more survive, in their original order (duplicates are kept).
    """
    words = normalize_text(text).split()
    return list(filter(lambda word: len(word) > 1, words))
|
|
|
|
|
|
def detect_search_mode(query_tokens: list[str]) -> str:
    """Guess whether the query names a person or describes a topic.

    Returns ``"person"`` when there are exactly two tokens and none of them
    is in ``TECHNICAL_TERMS`` (e.g. "jan ptak", "jan holp", "daniel hladek").
    Everything else, including an empty token list, is ``"topic"``
    (e.g. "rag agent", "knowledge graph", "nlp chatbot").

    NOTE: a two-word topic whose words are not listed in ``TECHNICAL_TERMS``
    is classified as ``"person"`` by this heuristic.
    """
    if len(query_tokens) != 2:
        return "topic"
    if TECHNICAL_TERMS.isdisjoint(query_tokens):
        return "person"
    return "topic"
|
|
|
|
|
|
def score_tokens(query_tokens: list[str], field_tokens: list[str], weight: int) -> int:
    """Award ``weight`` points for every occurrence of a query token in the field.

    A token repeated in the query is counted once per repetition: the query
    ``["rag", "rag"]`` against a field containing "rag" twice scores
    ``4 * weight``. Tokens absent from the field contribute nothing.
    """
    occurrences = Counter(field_tokens)
    return sum(occurrences[token] * weight for token in query_tokens)
|
|
|
|
|
|
def get_tags(conn: sqlite3.Connection, chunk_id: str) -> list[str]:
    """Return the tags recorded in ``chunk_tags`` for one chunk.

    Rows come back in whatever order SQLite yields them (the query has no
    ORDER BY); an unknown *chunk_id* gives an empty list.
    """
    sql = "SELECT tag FROM chunk_tags WHERE chunk_id = ?"
    cursor = conn.execute(sql, (chunk_id,))
    return [tag for (tag,) in cursor]
|
|
|
|
|
|
def get_categories(conn: sqlite3.Connection, chunk_id: str) -> list[str]:
    """Return the categories recorded in ``chunk_categories`` for one chunk.

    Rows come back in whatever order SQLite yields them (the query has no
    ORDER BY); an unknown *chunk_id* gives an empty list.
    """
    sql = "SELECT category FROM chunk_categories WHERE chunk_id = ?"
    cursor = conn.execute(sql, (chunk_id,))
    return [category for (category,) in cursor]
|
|
|
|
|
|
def contains_all_tokens(query_tokens: list[str], field_tokens: list[str]) -> bool:
    """Return True if every query token occurs somewhere in *field_tokens*.

    Position and frequency are ignored. An empty *query_tokens* is trivially
    contained, so the result is True in that case.
    """
    return set(field_tokens).issuperset(query_tokens)
|
|
|
|
|
|
def person_match(query_tokens: list[str], item: dict) -> bool:
    """Return True if all query tokens appear together in one field of *item*.

    Fields are tried in order: ``title``, ``document_path``, ``author`` and
    finally ``text``. Missing or ``None`` fields are treated as empty strings.

    Fields are tokenized lazily. Later fields, including the chunk text, are
    only tokenized when no earlier field has already matched.
    """
    # The chunk text is the weakest signal but is kept as a fallback, e.g.
    # when the name is not in the title but does appear in the content.
    return any(
        contains_all_tokens(query_tokens, tokenize(item.get(field) or ""))
        for field in ("title", "document_path", "author", "text")
    )
|
|
|
|
|
|
def score_item(query: str, query_tokens: list[str], item: dict, mode: str) -> int:
    """Return a relevance score for one chunk *item*; 0 means no match.

    Args:
        query: Raw query string. It is used for the whole-phrase bonus in
            topic mode.
        query_tokens: Expected to be ``tokenize(query)``.
        item: Chunk dict with optional ``title``, ``document_path``,
            ``author``, ``text``, ``tags`` and ``categories`` keys. Missing
            or ``None`` values count as empty.
        mode: ``"person"`` selects the person weighting. Any other value
            (in practice ``"topic"``) selects the topic weighting.

    Person mode weights token hits in title/path/author heavily. It adds
    large bonuses when all query tokens appear in one of those fields.

    Topic mode weights title/path highest, then tags, categories, author and
    text. It adds bonuses when the whole normalized query appears as a
    phrase, or all tokens appear, in the title or path.
    """
    title = item.get("title") or ""
    path = item.get("document_path") or ""

    title_tokens = tokenize(title)
    path_tokens = tokenize(path)
    author_tokens = tokenize(item.get("author") or "")
    text_tokens = tokenize(item.get("text") or "")

    if mode == "person":
        score = (
            score_tokens(query_tokens, title_tokens, 30)
            + score_tokens(query_tokens, path_tokens, 30)
            + score_tokens(query_tokens, author_tokens, 15)
            + score_tokens(query_tokens, text_tokens, 2)
        )
        if contains_all_tokens(query_tokens, title_tokens):
            score += 100
        if contains_all_tokens(query_tokens, path_tokens):
            score += 100
        if contains_all_tokens(query_tokens, author_tokens):
            score += 60
        return score

    # Tags and categories only contribute in topic mode. Falsy entries
    # (presumably NULL values from the database) are skipped rather than
    # letting str.join raise TypeError.
    tags = item.get("tags") or []
    categories = item.get("categories") or []
    tag_tokens = tokenize(" ".join(tag for tag in tags if tag))
    category_tokens = tokenize(" ".join(category for category in categories if category))

    score = (
        score_tokens(query_tokens, title_tokens, 12)
        + score_tokens(query_tokens, path_tokens, 12)
        + score_tokens(query_tokens, tag_tokens, 10)
        + score_tokens(query_tokens, category_tokens, 6)
        + score_tokens(query_tokens, author_tokens, 3)
        + score_tokens(query_tokens, text_tokens, 2)
    )

    # Whole-phrase bonus. normalize_text() returns words separated by single
    # spaces with no leading/trailing space. Padding both sides with a space
    # therefore restricts the match to whole words: "rag" must not score on
    # "storage" or "fragment".
    normalized_query = normalize_text(query)
    if normalized_query:
        padded_query = f" {normalized_query} "
        if padded_query in f" {normalize_text(title)} ":
            score += 30
        if padded_query in f" {normalize_text(path)} ":
            score += 30

    # All-tokens bonus (any order). The explicit non-empty guard is needed
    # because contains_all_tokens() treats an empty token list as contained.
    if query_tokens and contains_all_tokens(query_tokens, title_tokens):
        score += 25
    if query_tokens and contains_all_tokens(query_tokens, path_tokens):
        score += 25

    return score
|
|
|
|
|
|
def _group_values_by_chunk(conn: sqlite3.Connection, sql: str) -> dict:
    """Run a two-column ``SELECT chunk_id, value`` query and group values per chunk_id."""
    grouped: dict = {}
    for chunk_id, value in conn.execute(sql):
        grouped.setdefault(chunk_id, []).append(value)
    return grouped


def _load_items(conn: sqlite3.Connection) -> list[dict]:
    """Load every row of ``chunks`` as a dict with its tags and categories attached.

    Tags and categories are read with one query per table instead of two
    queries per chunk.

    NOTE(review): rows are matched on Python equality of ``chunk_id``. This
    assumes the column holds the same type in ``chunks``, ``chunk_tags`` and
    ``chunk_categories``. Confirm against the indexer.
    """
    tags_by_chunk = _group_values_by_chunk(conn, "SELECT chunk_id, tag FROM chunk_tags")
    categories_by_chunk = _group_values_by_chunk(
        conn, "SELECT chunk_id, category FROM chunk_categories"
    )

    rows = conn.execute("""
        SELECT chunk_id, document_path, title, author, chunk_index, text, text_length
        FROM chunks
    """).fetchall()

    items = []
    for chunk_id, document_path, title, author, chunk_index, text, text_length in rows:
        items.append({
            "chunk_id": chunk_id,
            "document_path": document_path,
            "title": title,
            "author": author,
            "chunk_index": chunk_index,
            "text": text,
            "text_length": text_length,
            # Copy so each item owns its own list.
            "tags": list(tags_by_chunk.get(chunk_id, ())),
            "categories": list(categories_by_chunk.get(chunk_id, ())),
        })
    return items


def _print_results(query: str, mode: str, results: list[dict], limit: int = 10) -> None:
    """Print the query summary and the first *limit* results.

    The bracketed tags are rich console markup; ``print`` here is the one
    imported from ``rich``.
    """
    print(f"[bold]Dopyt:[/bold] {query}")
    print(f"[bold]Režim:[/bold] {mode}")
    print(f"[bold]Počet výsledkov:[/bold] {len(results)}")
    print("\n[bold]Top výsledky:[/bold]\n")

    for rank, item in enumerate(results[:limit], start=1):
        print(f"[cyan]{rank}. Skóre: {item['score']}[/cyan]")
        print(f"[bold]Názov:[/bold] {item['title']}")
        print(f"[bold]Cesta:[/bold] {item['document_path']}")
        print(f"[bold]Chunk:[/bold] {item['chunk_index']}")
        print(f"[bold]Kategórie:[/bold] {item['categories']}")
        print(f"[bold]Tagy:[/bold] {item['tags']}")
        print(f"[bold]Autor:[/bold] {item['author']}")
        print("[bold]Text:[/bold]")
        print((item["text"] or "")[:700])
        print("-" * 80)


def main():
    """Command-line entry point: search the chunk index for ``sys.argv[1:]``.

    Prints usage and exits with status 1 when no query is given. Exits with
    an error message when ``DB_FILE`` does not exist. Otherwise it scores
    every chunk, sorts matches by descending score and prints the top ten.
    """
    if len(sys.argv) < 2:
        print("[red]Použitie:[/red] python scripts/search_db.py \"rag agent\"")
        raise SystemExit(1)

    if not DB_FILE.exists():
        raise SystemExit(f"Databáza neexistuje: {DB_FILE}")

    query = " ".join(sys.argv[1:])
    query_tokens = tokenize(query)
    mode = detect_search_mode(query_tokens)

    # The connection is only needed for loading. try/finally guarantees it is
    # closed even if one of the queries raises.
    conn = sqlite3.connect(DB_FILE)
    try:
        items = _load_items(conn)
    finally:
        conn.close()

    results = []
    for item in items:
        # Person searches only consider chunks where the whole name occurs
        # within a single field.
        if mode == "person" and not person_match(query_tokens, item):
            continue

        score = score_item(query, query_tokens, item, mode)
        if score > 0:
            item["score"] = score
            results.append(item)

    # Stable sort: equal scores keep database row order.
    results.sort(key=lambda result: result["score"], reverse=True)
    _print_results(query, mode, results)
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the search only when executed as a script, not when imported.
    # main() returns None, which SystemExit maps to exit status 0.
    raise SystemExit(main())
|