451 lines
12 KiB
Python
451 lines
12 KiB
Python
import asyncio
import hashlib
import hmac
import json
import os
import re
import sqlite3
import subprocess
import sys
import threading
import time
import unicodedata
from collections import Counter
from pathlib import Path

from fastapi import FastAPI, Header, HTTPException, Request
from pydantic import BaseModel, Field
|
|
|
|
|
|
# Location of the SQLite search index, relative to the process working directory.
DB_FILE = Path("data") / "zp_index.sqlite"
# Checkout of the zpwiki repository; override with the ZPWIKI_ROOT env variable.
ZPWIKI_ROOT = Path(os.environ.get("ZPWIKI_ROOT", "../zpwiki"))
# Shared secret used to authenticate Gitea webhook deliveries.
# NOTE(review): the "dev-secret" fallback is public knowledge; set
# WEBHOOK_SECRET explicitly in every real deployment.
WEBHOOK_SECRET = os.environ.get("WEBHOOK_SECRET", "dev-secret")
|
|
|
|
|
|
# Tokens that mark a query as technical: detect_search_mode only treats a
# two-token query as a person search when none of its tokens is listed here.
TECHNICAL_TERMS = set(
    """
    rag agent graph knowledge chatbot
    nlp llm lm openwebui
    docker webhook database db neo4j python cloud api
    search retrieval generation embedding vector vectors
    langchain graphrag
    qa question answer
    """.split()
)
|
|
|
|
|
|
# FastAPI application object; the @app.get / @app.post decorators in this
# module register the HTTP endpoints on it.
app = FastAPI(
    version="0.3.0",
    title="ZP Agent API",
    description="API pre vyhľadávanie v repozitári záverečných prác zpwiki.",
)
|
|
|
|
|
|
class SearchRequest(BaseModel):
    """Request body of ``POST /search``."""

    # Free-text query; must be at least one character long.
    query: str = Field(default=..., min_length=1)
    # Maximum number of results to return, between 1 and 50 inclusive.
    limit: int = Field(le=50, ge=1, default=10)
|
|
|
|
|
|
class SyncRequest(BaseModel):
    """Request body of ``POST /sync``."""

    # When true, ``git pull`` runs in the zpwiki checkout before reindexing.
    pull_git: bool = Field(
        description="Ak je true, pred reindexovaním sa vykoná git pull v repozitári zpwiki.",
        default=False,
    )
|
|
|
|
|
|
def normalize_text(text: str) -> str:
    """Fold *text* into lowercase, accent-free, single-space-separated ASCII.

    Diacritics are removed through NFKD decomposition (e.g. "Štúr" -> "stur"),
    and every run of characters outside ``a-z0-9`` collapses to one space.
    Leading and trailing spaces are stripped.
    """
    # The collapsing regex also turns "_", "/" and "-" into separators, so no
    # dedicated replacement pass is needed for them.
    decomposed = unicodedata.normalize("NFKD", text.lower())
    without_marks = "".join(
        char for char in decomposed if not unicodedata.combining(char)
    )
    return re.sub(r"[^a-z0-9]+", " ", without_marks).strip()
|
|
|
|
|
|
def tokenize(text: str) -> list[str]:
    """Split ``normalize_text(text)`` into words, dropping one-character tokens."""
    words = normalize_text(text).split()
    return list(filter(lambda word: len(word) > 1, words))
|
|
|
|
|
|
def detect_search_mode(query_tokens: list[str]) -> str:
    """Classify a tokenized query as a ``"person"`` or ``"topic"`` search.

    Exactly two tokens, none of which appears in ``TECHNICAL_TERMS``, yields
    ``"person"`` (presumably a first name + surname); anything else, including
    an empty token list, yields ``"topic"``.
    """
    looks_like_name = len(query_tokens) == 2 and TECHNICAL_TERMS.isdisjoint(query_tokens)
    return "person" if looks_like_name else "topic"
|
|
|
|
|
|
def score_tokens(query_tokens: list[str], field_tokens: list[str], weight: int) -> int:
    """Return *weight* times the total occurrences of the query tokens in *field_tokens*.

    A query token repeated in *query_tokens* is counted once per repetition.
    """
    occurrences = Counter(field_tokens)
    return weight * sum(occurrences[token] for token in query_tokens)
|
|
|
|
|
|
def contains_all_tokens(query_tokens: list[str], field_tokens: list[str]) -> bool:
    """Return True if every query token occurs in *field_tokens* (True for no tokens)."""
    return set(query_tokens).issubset(field_tokens)
|
|
|
|
|
|
def get_tags(conn: sqlite3.Connection, chunk_id: str) -> list[str]:
    """Return every tag stored in ``chunk_tags`` for the chunk *chunk_id*."""
    cursor = conn.execute(
        "SELECT tag FROM chunk_tags WHERE chunk_id = ?",
        (chunk_id,),
    )
    return [tag for (tag,) in cursor]
|
|
|
|
|
|
def get_categories(conn: sqlite3.Connection, chunk_id: str) -> list[str]:
    """Return every category stored in ``chunk_categories`` for the chunk *chunk_id*."""
    cursor = conn.execute(
        "SELECT category FROM chunk_categories WHERE chunk_id = ?",
        (chunk_id,),
    )
    return [category for (category,) in cursor]
|
|
|
|
|
|
def person_match(query_tokens: list[str], item: dict) -> bool:
    """Return True if all query tokens occur together within a single item field.

    The fields checked are ``title``, ``document_path``, ``author`` and
    ``text``; a missing or ``None`` value is treated as an empty string.
    """
    # Tokenize every field up front, then test them in priority order.
    tokenized_fields = [
        tokenize(item.get(field) or "")
        for field in ("title", "document_path", "author", "text")
    ]
    return any(contains_all_tokens(query_tokens, tokens) for tokens in tokenized_fields)
|
|
|
|
|
|
def _phrase_starts_word(phrase: str, haystack: str) -> bool:
    """Return True if *phrase* occurs in *haystack* starting at a word boundary.

    Both arguments are expected to be ``normalize_text`` output (words joined
    by single spaces). The leading-space padding stops short queries such as
    "rag" or "lm" from matching inside unrelated words ("storage", "film"),
    while "knowledge graph" still matches "knowledge graphs".
    """
    return bool(phrase) and f" {phrase}" in f" {haystack}"


def _score_person(
    query_tokens: list[str],
    title_tokens: list[str],
    path_tokens: list[str],
    author_tokens: list[str],
    text_tokens: list[str],
) -> int:
    """Score a chunk for a person query: title/path/author dominate, body text barely counts."""
    score = (
        score_tokens(query_tokens, title_tokens, 30)
        + score_tokens(query_tokens, path_tokens, 30)
        + score_tokens(query_tokens, author_tokens, 15)
        + score_tokens(query_tokens, text_tokens, 2)
    )
    # Bonuses when the whole name appears within a single field.
    if contains_all_tokens(query_tokens, title_tokens):
        score += 100
    if contains_all_tokens(query_tokens, path_tokens):
        score += 100
    if contains_all_tokens(query_tokens, author_tokens):
        score += 60
    return score


def score_item(query: str, query_tokens: list[str], item: dict, mode: str) -> int:
    """Compute the relevance score of one indexed chunk for a search query.

    Args:
        query: Raw query string; used for the whole-phrase bonus in topic mode.
        query_tokens: ``tokenize(query)``.
        item: Chunk dict with optional string fields ``title``,
            ``document_path``, ``author``, ``text`` and optional string lists
            ``tags`` and ``categories``; missing/``None`` values count as empty.
        mode: ``"person"`` for name searches; any other value is scored as a
            topic search.

    Returns:
        Non-negative integer score; 0 means the chunk did not match.
    """
    title = item.get("title") or ""
    path = item.get("document_path") or ""

    title_tokens = tokenize(title)
    path_tokens = tokenize(path)
    author_tokens = tokenize(item.get("author") or "")
    text_tokens = tokenize(item.get("text") or "")

    if mode == "person":
        return _score_person(query_tokens, title_tokens, path_tokens, author_tokens, text_tokens)

    # Tags and categories only contribute to topic searches.
    tag_tokens = tokenize(" ".join(item.get("tags") or []))
    category_tokens = tokenize(" ".join(item.get("categories") or []))

    score = (
        score_tokens(query_tokens, title_tokens, 12)
        + score_tokens(query_tokens, path_tokens, 12)
        + score_tokens(query_tokens, tag_tokens, 10)
        + score_tokens(query_tokens, category_tokens, 6)
        + score_tokens(query_tokens, author_tokens, 3)
        + score_tokens(query_tokens, text_tokens, 2)
    )

    # Whole-phrase bonus, anchored at a word start so that substrings of
    # unrelated words do not count (previously "rag" matched "storage").
    normalized_query = normalize_text(query)
    if _phrase_starts_word(normalized_query, normalize_text(title)):
        score += 30
    if _phrase_starts_word(normalized_query, normalize_text(path)):
        score += 30

    # All-tokens bonus; contains_all_tokens([]) is vacuously True, hence the guard.
    if query_tokens and contains_all_tokens(query_tokens, title_tokens):
        score += 25
    if query_tokens and contains_all_tokens(query_tokens, path_tokens):
        score += 25

    return score
|
|
|
|
|
|
def make_source_url(document_path: str) -> str:
    """Map an indexed document path to its page URL on zp.kemt.fei.tuke.sk.

    Only a leading ``pages/`` and a trailing ``README.md`` are removed, so a
    directory whose name merely contains ``pages/`` (e.g. ``homepages/``)
    is left intact. Other files keep their full path.

    >>> make_source_url("pages/students/2024/jan.novak/README.md")
    'https://zp.kemt.fei.tuke.sk/students/2024/jan.novak'
    """
    clean_path = document_path.removeprefix("pages/")
    if clean_path == "README.md" or clean_path.endswith("/README.md"):
        clean_path = clean_path.removesuffix("README.md").rstrip("/")
    return f"https://zp.kemt.fei.tuke.sk/{clean_path}"
|
|
|
|
|
|
def _group_by_chunk(conn: sqlite3.Connection, sql: str) -> dict[str, list[str]]:
    """Run a two-column ``(chunk_id, value)`` query and group the values per chunk_id."""
    grouped: dict[str, list[str]] = {}
    for chunk_id, value in conn.execute(sql):
        grouped.setdefault(chunk_id, []).append(value)
    return grouped


def search_database(query: str, limit: int) -> tuple[str, list[dict]]:
    """Score every indexed chunk against *query* and return the best matches.

    Args:
        query: Raw user query.
        limit: Maximum number of results to return.

    Returns:
        ``(mode, results)``: *mode* is the detected search mode (``"person"``
        or ``"topic"``); *results* are chunk dicts extended with ``score`` and
        ``source_url``, sorted by descending score and cut to *limit*.

    Raises:
        FileNotFoundError: If the SQLite index file does not exist.
    """
    if not DB_FILE.exists():
        raise FileNotFoundError(f"Databáza neexistuje: {DB_FILE}")

    query_tokens = tokenize(query)
    mode = detect_search_mode(query_tokens)

    conn = sqlite3.connect(DB_FILE)
    try:
        rows = conn.execute("""
            SELECT chunk_id, document_path, title, author, chunk_index, text, text_length
            FROM chunks
        """).fetchall()
        # One query per table instead of two extra queries for every chunk.
        tags_by_chunk = _group_by_chunk(conn, "SELECT chunk_id, tag FROM chunk_tags")
        categories_by_chunk = _group_by_chunk(
            conn, "SELECT chunk_id, category FROM chunk_categories"
        )
    finally:
        # Close even if a query fails (e.g. a missing table).
        conn.close()

    results = []
    for chunk_id, document_path, title, author, chunk_index, text, text_length in rows:
        item = {
            "chunk_id": chunk_id,
            "document_path": document_path,
            "title": title,
            "author": author,
            "chunk_index": chunk_index,
            "text": text,
            "text_length": text_length,
            "tags": tags_by_chunk.get(chunk_id, []),
            "categories": categories_by_chunk.get(chunk_id, []),
        }

        # Person searches only consider chunks that contain the whole name in one field.
        if mode == "person" and not person_match(query_tokens, item):
            continue

        score = score_item(query, query_tokens, item, mode)
        if score > 0:
            item["score"] = score
            item["source_url"] = make_source_url(document_path)
            results.append(item)

    results.sort(key=lambda result: result["score"], reverse=True)
    return mode, results[:limit]
|
|
|
|
|
|
def run_command(command: list[str], cwd: Path | None = None) -> str:
|
|
result = subprocess.run(
|
|
command,
|
|
cwd=cwd,
|
|
text=True,
|
|
capture_output=True,
|
|
)
|
|
|
|
output = ""
|
|
|
|
if result.stdout:
|
|
output += result.stdout
|
|
|
|
if result.stderr:
|
|
output += result.stderr
|
|
|
|
if result.returncode != 0:
|
|
raise RuntimeError(output.strip())
|
|
|
|
return output.strip()
|
|
|
|
|
|
def get_index_counts() -> dict:
    """Return row counts of the index tables, or all zeros if the DB file is missing.

    Returns:
        Dict with integer ``documents``, ``chunks``, ``tags`` and
        ``categories`` counts, read from the ``documents``, ``chunks``,
        ``chunk_tags`` and ``chunk_categories`` tables respectively.

    Raises:
        sqlite3.Error: If the database exists but a query fails (e.g. a
            missing table); the connection is closed before it propagates.
    """
    tables = {
        "documents": "documents",
        "chunks": "chunks",
        "tags": "chunk_tags",
        "categories": "chunk_categories",
    }

    if not DB_FILE.exists():
        return dict.fromkeys(tables, 0)

    conn = sqlite3.connect(DB_FILE)
    try:
        # Table names come from the constant mapping above, never from user input.
        return {
            key: conn.execute(f"SELECT COUNT(*) FROM {table}").fetchone()[0]
            for key, table in tables.items()
        }
    finally:
        conn.close()
|
|
|
|
|
|
# Serializes index rebuilds. FastAPI runs plain ``def`` endpoints in a thread
# pool, so two overlapping /sync (or webhook) requests would otherwise run the
# build scripts over each other's output files.
_REBUILD_LOCK = threading.Lock()

# Indexing pipeline, run in order with the current interpreter.
# NOTE(review): paths are relative to the process working directory.
_INDEX_SCRIPTS = (
    "scripts/scan_zpwiki.py",
    "scripts/build_chunks.py",
    "scripts/build_sqlite_index.py",
)


def rebuild_index(pull_git: bool = False) -> dict:
    """Regenerate the SQLite search index by running the indexing scripts.

    Args:
        pull_git: If true, run ``git pull`` in ``ZPWIKI_ROOT`` first.

    Returns:
        Dict with ``duration_seconds`` (elapsed time rounded to 2 decimals),
        ``counts`` (from ``get_index_counts``) and ``logs`` (output of each
        command, in execution order).

    Raises:
        RuntimeError: If *pull_git* is true and ``ZPWIKI_ROOT`` is missing or
            is not a git checkout, or if any command fails.
    """
    with _REBUILD_LOCK:
        # monotonic() is immune to wall-clock adjustments (NTP, DST), unlike time().
        start = time.monotonic()
        logs = []

        if pull_git:
            if not ZPWIKI_ROOT.exists():
                raise RuntimeError(f"ZPWIKI_ROOT neexistuje: {ZPWIKI_ROOT}")
            if not (ZPWIKI_ROOT / ".git").exists():
                raise RuntimeError(f"Nie je to git repozitár: {ZPWIKI_ROOT}")
            logs.append(run_command(["git", "pull"], cwd=ZPWIKI_ROOT))

        for script in _INDEX_SCRIPTS:
            logs.append(run_command([sys.executable, script]))

        counts = get_index_counts()
        duration = round(time.monotonic() - start, 2)

    return {
        "duration_seconds": duration,
        "counts": counts,
        "logs": logs,
    }
|
|
|
|
|
|
def verify_gitea_signature(raw_body: bytes, signature: str | None) -> bool:
|
|
if not signature:
|
|
return False
|
|
|
|
expected = hmac.new(
|
|
WEBHOOK_SECRET.encode("utf-8"),
|
|
raw_body,
|
|
hashlib.sha256,
|
|
).hexdigest()
|
|
|
|
signature = signature.strip()
|
|
|
|
if signature.startswith("sha256="):
|
|
signature = signature.replace("sha256=", "", 1)
|
|
|
|
return hmac.compare_digest(expected, signature)
|
|
|
|
|
|
def verify_simple_token(token: str | None) -> bool:
|
|
if not token:
|
|
return False
|
|
|
|
return hmac.compare_digest(token, WEBHOOK_SECRET)
|
|
|
|
|
|
@app.get("/health")
def health():
    """Report liveness plus whether the index database and zpwiki checkout exist.

    ``webhook_secret_configured`` is true whenever any secret is set, which
    includes the built-in ``"dev-secret"`` fallback. ``webhook_secret_is_default``
    makes that insecure fallback visible to operators.
    """
    return {
        "status": "ok",
        "database_exists": DB_FILE.exists(),
        "database_path": str(DB_FILE),
        "zpwiki_root": str(ZPWIKI_ROOT),
        "zpwiki_exists": ZPWIKI_ROOT.exists(),
        "webhook_secret_configured": bool(WEBHOOK_SECRET),
        "webhook_secret_is_default": WEBHOOK_SECRET == "dev-secret",
    }
|
|
|
|
|
|
@app.post("/search")
def search(request: SearchRequest):
    """Search the SQLite index and return the scored matches.

    The response echoes the query and includes the detected search mode, the
    result count and the results. Fails with HTTP 500 (error text as
    ``detail``) when the index file is missing or a SQLite query fails, for
    example because a table does not exist.
    """
    try:
        mode, results = search_database(request.query, request.limit)
    except (FileNotFoundError, sqlite3.Error) as error:
        raise HTTPException(status_code=500, detail=str(error)) from error

    return {
        "query": request.query,
        "mode": mode,
        "count": len(results),
        "results": results,
    }
|
|
|
|
|
|
@app.post("/sync")
def sync(request: SyncRequest):
    """Rebuild the search index on demand, optionally pulling zpwiki first.

    Fails with HTTP 500 (error text as ``detail``) when the rebuild raises
    RuntimeError.

    NOTE(review): this endpoint performs no authentication, so any client that
    can reach the API can trigger ``git pull`` and a full reindex.
    """
    try:
        outcome = rebuild_index(pull_git=request.pull_git)
    except RuntimeError as error:
        raise HTTPException(status_code=500, detail=str(error)) from error

    response = {"status": "ok", "pull_git": request.pull_git}
    response["duration_seconds"] = outcome["duration_seconds"]
    response["counts"] = outcome["counts"]
    return response
|
|
|
|
|
|
# Webhook rebuilds run in a worker thread so the event loop keeps serving
# other requests; this lock stops two webhook deliveries from rebuilding at
# the same time.
_WEBHOOK_REBUILD_LOCK = asyncio.Lock()


def _repository_name(raw_body: bytes) -> str | None:
    """Return ``repository.full_name`` (or ``repository.name``) from a webhook body, if present.

    Any body that is not UTF-8 JSON with an object at ``repository`` yields None.
    """
    try:
        payload = json.loads(raw_body.decode("utf-8")) if raw_body else {}
    except ValueError:  # covers both UnicodeDecodeError and json.JSONDecodeError
        return None
    if not isinstance(payload, dict):
        return None
    repository = payload.get("repository")
    if not isinstance(repository, dict):
        return None
    return repository.get("full_name") or repository.get("name")


@app.post("/webhook/gitea")
async def gitea_webhook(
    request: Request,
    x_gitea_event: str | None = Header(default=None, alias="X-Gitea-Event"),
    x_gitea_signature: str | None = Header(default=None, alias="X-Gitea-Signature"),
    x_gitea_token: str | None = Header(default=None, alias="X-Gitea-Token"),
):
    """Handle a Gitea webhook delivery by rebuilding the search index.

    The request must carry either a valid ``X-Gitea-Signature`` (HMAC of the
    body) or a matching ``X-Gitea-Token``. The body is parsed only to report
    the repository name; malformed or unexpected JSON is tolerated. The
    rebuild runs without ``git pull``.

    Raises:
        HTTPException: 401 if neither credential verifies, 500 if the rebuild
            raises RuntimeError.
    """
    raw_body = await request.body()

    signature_ok = verify_gitea_signature(raw_body, x_gitea_signature)
    token_ok = verify_simple_token(x_gitea_token)

    if not signature_ok and not token_ok:
        raise HTTPException(
            status_code=401,
            detail="Invalid webhook signature or token",
        )

    repository_name = _repository_name(raw_body)

    try:
        # rebuild_index blocks while the build scripts run; calling it directly
        # inside this async handler would freeze the whole server.
        async with _WEBHOOK_REBUILD_LOCK:
            result = await asyncio.to_thread(rebuild_index, pull_git=False)
    except RuntimeError as error:
        raise HTTPException(status_code=500, detail=str(error)) from error

    return {
        "status": "ok",
        "event": x_gitea_event or "unknown",
        "repository": repository_name,
        "verified_by": "signature" if signature_ok else "token",
        "duration_seconds": result["duration_seconds"],
        "counts": result["counts"],
    }
|