290 lines
7.3 KiB
Python
290 lines
7.3 KiB
Python
import re
import sqlite3
import unicodedata
from collections import Counter
from pathlib import Path

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
|
|
|
|
|
|
# Path of the SQLite index file that the search and health endpoints read.
DB_FILE = Path("data") / "zp_index.sqlite"
|
|
|
|
|
|
# Normalized tokens treated as technical vocabulary. A query containing any of
# them is never classified as a person search.
TECHNICAL_TERMS = {
    # retrieval-augmented generation and search
    "rag", "graphrag", "retrieval", "generation", "search",
    "embedding", "vector", "vectors",
    # language models and conversational systems
    "nlp", "llm", "lm", "chatbot", "agent",
    "qa", "question", "answer",
    # knowledge graphs and storage
    "graph", "knowledge", "neo4j", "database", "db",
    # tooling and infrastructure
    "python", "langchain", "openwebui", "docker",
    "webhook", "cloud", "api",
}
|
|
|
|
|
|
# Application object; the route decorators in this module register on it.
app = FastAPI(
    version="0.1.0",
    title="ZP Agent API",
    description="API pre vyhľadávanie v repozitári záverečných prác zpwiki.",
)
|
|
|
|
|
|
class SearchRequest(BaseModel):
    """Request body for a search: the query text and a result limit."""

    # Free-text query; required, at least one character long.
    query: str = Field(..., min_length=1)

    # Maximum number of results to return: 1 to 50, 10 when omitted.
    limit: int = Field(10, ge=1, le=50)
|
|
|
|
|
|
class SearchResult(BaseModel):
    """Schema of a single search hit (one scored chunk of a document)."""

    # Relevance score; higher means a better match.
    score: int

    # Identity and location of the chunk.
    chunk_id: str
    document_path: str
    source_url: str

    # Document metadata; title and author may be null.
    title: str | None
    author: str | None
    chunk_index: int
    categories: list[str]
    tags: list[str]

    # Chunk content and its stored length.
    text: str
    text_length: int
|
|
|
|
|
|
def normalize_text(text: str) -> str:
    """Fold *text* into lowercase ASCII words separated by single spaces.

    Underscores, slashes and hyphens become spaces, diacritics are removed
    via NFKD decomposition, and every remaining run of characters outside
    ``a-z0-9`` collapses into one space. The result has no leading or
    trailing whitespace.
    """
    lowered = text.lower()
    for separator in ("_", "/", "-"):
        lowered = lowered.replace(separator, " ")

    # Decompose accented letters so the combining marks can be dropped.
    decomposed = unicodedata.normalize("NFKD", lowered)
    base_chars = [ch for ch in decomposed if not unicodedata.combining(ch)]

    collapsed = re.sub(r"[^a-z0-9]+", " ", "".join(base_chars))
    return collapsed.strip()
|
|
|
|
|
|
def tokenize(text: str) -> list[str]:
    """Split *text* into normalized words, dropping single-character ones."""
    words = normalize_text(text).split()
    return [candidate for candidate in words if len(candidate) > 1]
|
|
|
|
|
|
def detect_search_mode(query_tokens: list[str]) -> str:
    """Classify a tokenized query as a ``"person"`` or ``"topic"`` search.

    A query counts as a person search only when it has exactly two tokens
    and neither is in ``TECHNICAL_TERMS``. Everything else, including an
    empty query, is a topic search.
    """
    looks_like_name = (
        len(query_tokens) == 2 and TECHNICAL_TERMS.isdisjoint(query_tokens)
    )
    return "person" if looks_like_name else "topic"
|
|
|
|
|
|
def score_tokens(query_tokens: list[str], field_tokens: list[str], weight: int) -> int:
    """Return the weighted number of query-token occurrences in a field.

    Each query token contributes ``weight`` for every time it appears in
    *field_tokens*; a token repeated in the query is counted each time.
    """
    occurrences = Counter(field_tokens)
    # Indexing a Counter yields 0 for absent keys without inserting them.
    return sum(occurrences[token] * weight for token in query_tokens)
|
|
|
|
|
|
def contains_all_tokens(query_tokens: list[str], field_tokens: list[str]) -> bool:
    """Return True when every query token occurs in *field_tokens*.

    An empty *query_tokens* list is trivially contained, so it yields True.
    """
    return set(query_tokens).issubset(field_tokens)
|
|
|
|
|
|
def get_tags(conn: sqlite3.Connection, chunk_id: str) -> list[str]:
    """Return the tags stored in ``chunk_tags`` for *chunk_id*."""
    cursor = conn.execute(
        "SELECT tag FROM chunk_tags WHERE chunk_id = ?",
        (chunk_id,),
    )
    # Each row holds a single column, so unpack it directly.
    return [tag for (tag,) in cursor.fetchall()]
|
|
|
|
|
|
def get_categories(conn: sqlite3.Connection, chunk_id: str) -> list[str]:
    """Return the categories stored in ``chunk_categories`` for *chunk_id*."""
    cursor = conn.execute(
        "SELECT category FROM chunk_categories WHERE chunk_id = ?",
        (chunk_id,),
    )
    # Each row holds a single column, so unpack it directly.
    return [category for (category,) in cursor.fetchall()]
|
|
|
|
|
|
def person_match(query_tokens: list[str], item: dict) -> bool:
    """Return True when one field of *item* contains every query token.

    The title, document path, author and text are each tokenized and
    checked on their own; tokens are not combined across fields. Missing
    or ``None`` fields are treated as empty strings.
    """
    searched_fields = ("title", "document_path", "author", "text")
    token_lists = [tokenize(item.get(field) or "") for field in searched_fields]
    return any(
        contains_all_tokens(query_tokens, field_tokens)
        for field_tokens in token_lists
    )
|
|
|
|
|
|
def _contains_phrase(phrase: str, text: str) -> bool:
    """Return True when *phrase* occurs in *text* on whole-word boundaries.

    Both arguments must be ``normalize_text`` output (words joined by single
    spaces). Padding with spaces keeps a short query such as "rag" from
    matching inside a longer word such as "storage". An empty phrase never
    matches.
    """
    return bool(phrase) and f" {phrase} " in f" {text} "


def score_item(query: str, query_tokens: list[str], item: dict, mode: str) -> int:
    """Compute the relevance score of one chunk for a query.

    Args:
        query: The raw query string; used for the exact-phrase bonus.
        query_tokens: The tokenized query.
        item: Chunk data. The keys ``title``, ``document_path``, ``author``,
            ``text``, ``tags`` and ``categories`` are read; missing or
            ``None`` values are treated as empty.
        mode: ``"person"`` selects name-oriented weights; any other value
            selects topic weights.

    Returns:
        A non-negative integer score; 0 means nothing matched.
    """
    title = item.get("title") or ""
    path = item.get("document_path") or ""

    title_tokens = tokenize(title)
    path_tokens = tokenize(path)
    author_tokens = tokenize(item.get("author") or "")
    text_tokens = tokenize(item.get("text") or "")

    all_in_title = contains_all_tokens(query_tokens, title_tokens)
    all_in_path = contains_all_tokens(query_tokens, path_tokens)

    if mode == "person":
        # Names appear mostly in titles and paths, so those weigh heaviest.
        # Tags and categories are not consulted in this mode.
        score = (
            score_tokens(query_tokens, title_tokens, 30)
            + score_tokens(query_tokens, path_tokens, 30)
            + score_tokens(query_tokens, author_tokens, 15)
            + score_tokens(query_tokens, text_tokens, 2)
        )
        if all_in_title:
            score += 100
        if all_in_path:
            score += 100
        if contains_all_tokens(query_tokens, author_tokens):
            score += 60
        return score

    tag_tokens = tokenize(" ".join(item.get("tags") or []))
    category_tokens = tokenize(" ".join(item.get("categories") or []))

    score = (
        score_tokens(query_tokens, title_tokens, 12)
        + score_tokens(query_tokens, path_tokens, 12)
        + score_tokens(query_tokens, tag_tokens, 10)
        + score_tokens(query_tokens, category_tokens, 6)
        + score_tokens(query_tokens, author_tokens, 3)
        + score_tokens(query_tokens, text_tokens, 2)
    )

    # Bonus when the whole query appears verbatim as a phrase.
    normalized_query = normalize_text(query)
    if _contains_phrase(normalized_query, normalize_text(title)):
        score += 30
    if _contains_phrase(normalized_query, normalize_text(path)):
        score += 30

    # Bonus when every query token is present, in any order. The
    # ``query_tokens`` guard stops an empty query from earning it.
    if query_tokens and all_in_title:
        score += 25
    if query_tokens and all_in_path:
        score += 25

    return score
|
|
|
|
|
|
def make_source_url(document_path: str) -> str:
    """Build the public wiki URL for a repository document path.

    Strips one leading ``pages/`` directory and one trailing ``/README.md``
    file name, then appends the rest to the site root. Only the prefix and
    suffix are removed, so path segments that merely contain those strings
    (for example ``homepages/``) are left intact.
    """
    clean_path = document_path.removeprefix("pages/").removesuffix("/README.md")
    return f"https://zp.kemt.fei.tuke.sk/{clean_path}"
|
|
|
|
|
|
def search_database(query: str, limit: int) -> tuple[str, list[dict]]:
    """Score every chunk in the index against *query* and return the best.

    Args:
        query: Raw user query.
        limit: Maximum number of results to return; values below zero are
            treated as zero.

    Returns:
        ``(mode, results)``. ``mode`` is the detected search mode and
        ``results`` holds the chunk dicts with a positive score, ordered by
        descending score. Each dict carries the chunk columns plus ``tags``,
        ``categories``, ``score`` and ``source_url``.

    Raises:
        FileNotFoundError: If the database file does not exist.
    """
    if not DB_FILE.exists():
        raise FileNotFoundError(f"Databáza neexistuje: {DB_FILE}")

    query_tokens = tokenize(query)
    mode = detect_search_mode(query_tokens)

    results = []
    conn = sqlite3.connect(DB_FILE)
    try:
        rows = conn.execute("""
            SELECT chunk_id, document_path, title, author, chunk_index, text, text_length
            FROM chunks
        """).fetchall()

        for row in rows:
            chunk_id, document_path, title, author, chunk_index, text, text_length = row

            item = {
                "chunk_id": chunk_id,
                "document_path": document_path,
                "title": title,
                "author": author,
                "chunk_index": chunk_index,
                "text": text,
                "text_length": text_length,
            }

            # Person filtering reads only title, path, author and text, so
            # reject non-matching rows before the per-chunk tag/category
            # lookups.
            if mode == "person" and not person_match(query_tokens, item):
                continue

            item["tags"] = get_tags(conn, chunk_id)
            item["categories"] = get_categories(conn, chunk_id)

            score = score_item(query, query_tokens, item, mode)
            if score > 0:
                item["score"] = score
                item["source_url"] = make_source_url(document_path)
                results.append(item)
    finally:
        # Close the connection even when a query or scoring step raises.
        conn.close()

    results.sort(key=lambda item: item["score"], reverse=True)

    # Clamp so a negative limit cannot slice from the end of the list.
    return mode, results[:max(limit, 0)]
|
|
|
|
|
|
@app.get("/health")
def health():
    """Report that the service is up and whether the index file exists."""
    database = DB_FILE
    return dict(
        status="ok",
        database_exists=database.exists(),
        database_path=str(database),
    )
|
|
|
|
|
|
@app.post("/search")
def search(request: SearchRequest):
    """Run a search and return the query, detected mode and ranked results.

    Raises:
        HTTPException: 503 when the index database file is missing, so the
            client gets an explicit "service unavailable" with the reason
            instead of an unhandled server error.
    """
    try:
        mode, results = search_database(request.query, request.limit)
    except FileNotFoundError as error:
        raise HTTPException(status_code=503, detail=str(error)) from error

    return {
        "query": request.query,
        "mode": mode,
        "count": len(results),
        "results": results,
    }
|