Nahrát soubory do „/“
This commit is contained in:
commit
22f8a3144a
411
graph_builder.py
Normal file
411
graph_builder.py
Normal file
@ -0,0 +1,411 @@
|
|||||||
|
import json
|
||||||
|
import os
|
||||||
|
import networkx as nx
|
||||||
|
import re
|
||||||
|
import hashlib
|
||||||
|
|
||||||
|
from ingest import load_documents
|
||||||
|
from llm import llm_call
|
||||||
|
|
||||||
|
# Cleans response to avoid errors
|
||||||
|
def clean_llm_response(text: str) -> str:
    """Strip Markdown fences and surrounding chatter from an LLM reply.

    The reply is trimmed, a leading ```` ``` ```` / ```` ```json ```` fence and
    a trailing ```` ``` ```` fence are removed, and if the remainder contains a
    ``{ ... }`` span, only that span (first ``{`` to last ``}``) is returned.

    Args:
        text: Raw reply text from the model.

    Returns:
        The extracted JSON object text, or the trimmed reply when no braces
        are present.

    Raises:
        ValueError: If the reply is empty or contains only whitespace.
    """
    # A whitespace-only reply is as useless as an empty one; fail here with a
    # clear message instead of handing "" to the JSON parser.
    if not text or not text.strip():
        raise ValueError("LLM returned empty response")

    text = text.strip()

    # Remove ``` / ```json wrappers (the language tag may be upper-case).
    text = re.sub(r"^```(?:json)?", "", text, flags=re.IGNORECASE)
    text = re.sub(r"```$", "", text)
    text = text.strip()

    # Greedy match on purpose: nested objects must stay inside the span.
    match = re.search(r"\{.*\}", text, re.DOTALL)
    if match:
        return match.group(0)

    return text
|
||||||
|
|
||||||
|
# Gets cached extraction for a chunk, or extracts and caches it
|
||||||
|
def get_cached_extraction(chunk, cache_dir="cache"):
    """Return the extraction for *chunk*, using an on-disk JSON cache.

    The cache key is the MD5 hex digest of the UTF-8 encoded chunk; the entry
    is stored as ``<cache_dir>/<digest>.json``. On a miss the chunk is passed
    to ``extract_entities_and_relations`` and the result is written back.

    Args:
        chunk: Text of the chunk to extract from.
        cache_dir: Directory holding the cache files (created if missing).

    Returns:
        The extraction dict, either loaded from cache or freshly computed.

    Raises:
        Whatever ``extract_entities_and_relations`` raises on a cache miss.
    """
    os.makedirs(cache_dir, exist_ok=True)

    # Create unique hash for chunk (used only as a cache key).
    chunk_hash = hashlib.md5(chunk.encode("utf-8")).hexdigest()
    cache_path = os.path.join(cache_dir, f"{chunk_hash}.json")

    # Cache load (if exists). An unreadable entry is treated as a miss so a
    # half-written file cannot make this chunk fail on every run.
    if os.path.exists(cache_path):
        try:
            with open(cache_path, "r", encoding="utf-8") as f:
                cached = json.load(f)
        except (json.JSONDecodeError, UnicodeDecodeError):
            print(f"[CACHE CORRUPT] {chunk_hash}")
        else:
            print(f"[CACHE HIT] {chunk_hash}")
            return cached

    # Extract json and store in cache
    print(f"[CACHE MISS] {chunk_hash}")
    extraction = extract_entities_and_relations(chunk)

    # Write to a temporary sibling first and rename it into place, so an
    # interrupted run never leaves a truncated cache file behind.
    tmp_path = f"{cache_path}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(extraction, f, indent=2, ensure_ascii=False)
    os.replace(tmp_path, cache_path)

    return extraction
|
||||||
|
|
||||||
|
# Calls the LLM to extract entities and relations from a text chunk and returns them as a validated dict
|
||||||
|
def _normalize_entity_name(raw):
    """Collapse whitespace and capitalize each word; return "" for non-strings."""
    if not isinstance(raw, str):
        return ""
    return " ".join(word.capitalize() for word in raw.split())


def extract_entities_and_relations(text):
    """Ask the LLM for entities and relations in *text* and validate the reply.

    Only the first 8000 characters of *text* are sent. The reply is cleaned
    with ``clean_llm_response``, parsed as JSON and then filtered:

    * entity and relation endpoint names are normalized to capitalized words;
    * malformed items (non-dict, missing or non-string name/endpoints, blank
      relation label) are skipped instead of failing the whole chunk;
    * self-loops and relations whose endpoints are not listed entities are
      dropped;
    * entities that take part in no remaining relation are dropped.

    Args:
        text: Chunk of document text.

    Returns:
        ``{"entities": [{"name", "type"}, ...],
        "relations": [{"source", "target", "relation"}, ...]}``.

    Raises:
        ValueError: If the reply is empty or the parsed JSON is not an object.
        json.JSONDecodeError: If the cleaned reply is not valid JSON (the
            offending text is printed first).
    """
    text = text[:8000]  # prevent context overflow

    prompt = f"""
You are an educational knowledge graph extraction system.

Your task:

Extract educational entities and relations.

Rules:
- Return ONLY EXACT JSON
- Use ONLY meaningful educational concepts
- NO markdown
- NO explanations
- NO extra text
- Use these relations ONLY:
[ "is_a",
"part_of",
"uses",
"teaches",
"requires",
"depends_on",
"implements",
"applies_to",
"subfield_of",
"contains",
"defined_by"]
- Entity names must be normalized:
GOOD:
- "Machine Learning"
- "Artificial Intelligence"
BAD:
- "machine learning"
- "ML"
- "AI systems"
- Prefer fewer HIGH QUALITY relations over many weak relations
- Every entity should participate in at least one relation

OUTPUT FORMAT EXACTLY:
{{
"entities": [
{{"name": "...", "type": "..."}}
],
"relations": [
{{
"source": "...",
"target": "...",
"relation": "..."
}}
]
}}

TEXT:
{text}
"""

    response = llm_call([
        {"role": "user", "content": prompt}
    ])

    cleaned = clean_llm_response(response)

    try:
        data = json.loads(cleaned)
    except json.JSONDecodeError:
        print("\nINVALID JSON")
        print(cleaned)
        raise

    # Checks if llm output is valid
    if not isinstance(data, dict):
        raise ValueError("LLM response JSON is not an object")

    entities = data.get("entities", [])
    relations = data.get("relations", [])
    if not isinstance(entities, list):
        entities = []
    if not isinstance(relations, list):
        relations = []

    # Normalize entity names; later duplicates overwrite earlier ones.
    normalized_entities = {}
    for ent in entities:
        if not isinstance(ent, dict):
            continue

        name = _normalize_entity_name(ent.get("name"))
        if not name:
            continue

        # A null / non-string type would later break GraphML serialization.
        ent_type = ent.get("type", "concept")
        if not isinstance(ent_type, str) or not ent_type.strip():
            ent_type = "concept"

        normalized_entities[name] = {"name": name, "type": ent_type}

    valid_relations = []
    for rel in relations:
        if not isinstance(rel, dict):
            continue

        source = _normalize_entity_name(rel.get("source"))
        target = _normalize_entity_name(rel.get("target"))
        relation = rel.get("relation")
        if not isinstance(relation, str) or not relation.strip():
            continue
        relation = relation.strip()

        # Skip self-loops
        if source == target:
            continue

        # Ensure both entities exist
        if source not in normalized_entities or target not in normalized_entities:
            continue

        valid_relations.append({
            "source": source,
            "target": target,
            "relation": relation,
        })

    # Remove isolated entities
    connected = set()
    for rel in valid_relations:
        connected.add(rel["source"])
        connected.add(rel["target"])

    final_entities = [
        ent for ent in normalized_entities.values()
        if ent["name"] in connected
    ]

    return {
        "entities": final_entities,
        "relations": valid_relations,
    }
|
||||||
|
|
||||||
|
# Chunks document into chunks (with overlap)
|
||||||
|
def chunk_text(text, chunk_size=8000, overlap=300):
    """Split *text* into newline-joined chunks of roughly *chunk_size* chars.

    The text is split on ``"\\n"`` (carriage returns are replaced by spaces
    first). Paragraphs shorter than 40 characters after stripping are
    discarded. Paragraphs are appended to the current chunk while the running
    length stays below *chunk_size*; otherwise the chunk is closed and the
    next one starts with the last *overlap* characters of the previous chunk.

    A single paragraph is never split, so a chunk can exceed *chunk_size*
    when one paragraph is longer than the limit.

    Args:
        text: Full document text.
        chunk_size: Soft upper bound on chunk length in characters.
        overlap: Number of trailing characters carried into the next chunk;
            ``0`` (or a negative value) disables the carry-over.

    Returns:
        List of non-empty chunk strings (empty list when nothing is kept).
    """
    paragraphs = text.replace("\r", " ").split("\n")

    chunks = []
    current = ""

    for para in paragraphs:
        para = para.strip()

        # Skip useless tiny paragraphs
        if len(para) < 40:
            continue

        if len(current) + len(para) < chunk_size:
            current += para + "\n"
            continue

        if not current:
            # Oversized paragraph with nothing buffered: start the chunk with
            # it instead of emitting an empty chunk first.
            current = para + "\n"
            continue

        chunks.append(current)

        # overlap keeps context continuity; `current[-0:]` would be the whole
        # string, so a zero overlap has to be handled explicitly.
        tail = current[-overlap:] if overlap > 0 else ""
        current = tail + "\n" + para + "\n"

    if current:
        chunks.append(current)

    return chunks
|
||||||
|
|
||||||
|
# Builds graph from documents (must be in documents folder [.pdf, .txt, .docx])
|
||||||
|
def build_graph_from_documents(folder="documents"):
    """Build a directed knowledge graph from every document in *folder*.

    Each document returned by ``load_documents`` is chunked with
    ``chunk_text``; every chunk goes through ``get_cached_extraction``. The
    extracted entities become nodes (attributes ``type`` and ``source_doc``)
    and the relations become edges (attributes ``relation`` and
    ``source_doc``). A chunk whose extraction raises is reported and skipped.

    Args:
        folder: Directory passed to ``load_documents``.

    Returns:
        The populated ``networkx.DiGraph``.
    """
    graph = nx.DiGraph()

    for document in load_documents(folder):
        print(f"\n[INFO] Processing document: {document['filename']}")

        # Document is split into chunks
        pieces = chunk_text(document["content"])
        total = len(pieces)
        print(f"[INFO] Total chunks: {total}")

        # Process chunk by chunk
        for position, piece in enumerate(pieces, start=1):
            print(f"[INFO] Processing chunk {position}/{total}")

            try:
                extraction = get_cached_extraction(piece)
            except Exception as exc:
                # Best effort: one bad chunk must not abort the whole build.
                print(f"[WARN] Extraction failed on chunk {position}")
                print(exc)
                continue

            # Add entities (nodes) to graph
            for entity in extraction["entities"]:
                graph.add_node(
                    entity["name"],
                    type=entity["type"],
                    source_doc=document["filename"],
                )

            # Add entity relations (links between nodes)
            for link in extraction["relations"]:
                graph.add_edge(
                    link["source"],
                    link["target"],
                    relation=link["relation"],
                    source_doc=document["filename"],
                )

    return graph
|
||||||
|
|
||||||
|
# Saves graph to file
|
||||||
|
def save_graph(G, path="graph/kg.graphml"):
    """Write graph *G* to *path* in GraphML format.

    The parent directory of *path* is created when it does not exist.

    Args:
        G: Graph to serialize with ``networkx.write_graphml``.
        path: Destination file path.
    """
    # dirname is "" for a bare file name, and os.makedirs("") raises.
    directory = os.path.dirname(path)
    if directory:
        os.makedirs(directory, exist_ok=True)

    nx.write_graphml(G, path)
    print(f"Graph saved to {path}")
|
||||||
|
|
||||||
|
# Loads graph from file
|
||||||
|
def load_graph(path="graph/kg.graphml"):
    """Read the GraphML file at *path* and return the resulting graph."""
    graph = nx.read_graphml(path)
    return graph
|
||||||
|
|
||||||
|
# Adds document to graph
|
||||||
|
def add_document(filepath, G):
    """Extract one document and merge its entities and relations into *G*.

    *filepath* may be either a path to a single document file or a directory.
    For a file, its parent directory is loaded and the document whose
    filename matches is used. For a directory, the first document returned by
    ``load_documents`` is used (only one document is added per call).

    *G* is modified in place. When no document is found, a warning is printed
    and nothing is changed. A chunk whose extraction raises is reported and
    skipped.

    Args:
        filepath: Path to a document file or to a directory of documents.
        G: Graph to extend with nodes and edges.
    """
    if os.path.isfile(filepath):
        # load_documents only accepts a folder, so load the parent directory
        # and keep the one requested file.
        parent, wanted = os.path.split(filepath)
        docs = [
            d for d in load_documents(folder=parent or ".")
            if d["filename"] == wanted
        ]
    else:
        docs = load_documents(folder=filepath)

    if not docs:
        print(f"[WARN] No documents found in: {filepath}")
        return

    doc = docs[0]

    print(f"\n[INFO] Adding document: {doc['filename']}")

    # Document is split into chunks
    chunks = chunk_text(doc["content"])

    print(f"[INFO] Total chunks: {len(chunks)}")

    # Process chunk by chunk
    for idx, chunk in enumerate(chunks):
        print(f"[INFO] Processing chunk {idx + 1}/{len(chunks)}")

        try:
            extraction = get_cached_extraction(chunk)
        except Exception as e:
            # Best effort: one bad chunk must not abort the whole document.
            print(f"[WARN] Extraction failed on chunk {idx + 1}")
            print(e)
            continue

        # Add entities (nodes) to graph
        for ent in extraction["entities"]:
            G.add_node(
                ent["name"],
                type=ent["type"],
                source_doc=doc["filename"],
            )

        # Add entity relations (links between nodes)
        for rel in extraction["relations"]:
            G.add_edge(
                rel["source"],
                rel["target"],
                relation=rel["relation"],
                source_doc=doc["filename"],
            )

    print(f"[INFO] Finished adding {doc['filename']}")
|
||||||
|
|
||||||
|
# Removes document from graph
|
||||||
|
def remove_document(filename, G):
    """Delete from *G* every edge tagged with *filename*, then prune orphans.

    Edges whose ``source_doc`` attribute equals *filename* are removed, after
    which every node left without any edge is removed as well. *G* is
    modified in place.

    Args:
        filename: Value of the ``source_doc`` edge attribute to remove.
        G: Graph to modify.
    """
    # Collect first: the edge view must not change while it is iterated.
    doomed_edges = [
        (start, end)
        for start, end, attrs in G.edges(data=True)
        if attrs.get("source_doc") == filename
    ]
    G.remove_edges_from(doomed_edges)

    # Remove orphan nodes
    orphans = list(nx.isolates(G))
    G.remove_nodes_from(orphans)

    print(f"[INFO] Removed document: {filename}")
|
||||||
|
|
||||||
|
'''
|
||||||
|
#test
|
||||||
|
if __name__ == "__main__":
|
||||||
|
|
||||||
|
docs = load_documents()
|
||||||
|
|
||||||
|
extraction = extract_entities_and_relations(
|
||||||
|
docs[0]["content"]
|
||||||
|
)
|
||||||
|
|
||||||
|
print(extraction)
|
||||||
|
'''
|
||||||
|
|
||||||
|
#test
|
||||||
|
if __name__ == "__main__":
    # Manual run: rebuild the graph from the default documents folder and
    # store it at the default GraphML path.
    built_graph = build_graph_from_documents()
    save_graph(built_graph)
|
||||||
|
|
||||||
43
ingest.py
Normal file
43
ingest.py
Normal file
@ -0,0 +1,43 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
from pypdf import PdfReader
|
||||||
|
from docx import Document
|
||||||
|
|
||||||
|
def load_documents(folder="documents"):
    """Load every ``.txt``, ``.pdf`` and ``.docx`` file directly in *folder*.

    Files are visited in sorted path order so the result does not depend on
    the filesystem's listing order. Suffixes are matched case-insensitively;
    other entries (including sub-directories) are ignored.

    * ``.txt``  - read as UTF-8 text.
    * ``.pdf``  - text of all pages joined with newlines (a page with no
      extractable text contributes an empty string).
    * ``.docx`` - paragraph texts joined with newlines.

    Args:
        folder: Directory to scan (not recursive).

    Returns:
        List of ``{"filename": <name>, "content": <text>}`` dicts; empty when
        the folder has no supported files or does not exist.
    """
    docs = []

    for file in sorted(Path(folder).glob("*")):
        if not file.is_file():
            continue

        suffix = file.suffix.lower()

        if suffix == ".txt":
            text = file.read_text(encoding="utf-8")
        elif suffix == ".pdf":
            reader = PdfReader(file)
            # Join pages with a newline so the last word of one page is not
            # glued to the first word of the next.
            text = "\n".join(page.extract_text() or "" for page in reader.pages)
        elif suffix == ".docx":
            doc = Document(file)
            text = "\n".join(paragraph.text for paragraph in doc.paragraphs)
        else:
            continue

        docs.append({
            "filename": file.name,
            "content": text,
        })

    return docs
|
||||||
|
|
||||||
|
#test
|
||||||
|
if __name__ == "__main__":
    # Manual check: print everything loaded from the default folder.
    print(load_documents())
|
||||||
35
llm.py
Normal file
35
llm.py
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
from __future__ import annotations

import os
from dataclasses import dataclass
from typing import Dict, List

import requests
|
||||||
|
|
||||||
|
# TUKE LLM Connection
|
||||||
|
UNIVERSITY_BASE_URL = "https://ui.tukekemt.xyz/api/v1/chat/completions"
# SECURITY: the API key should come from the UNIVERSITY_API_KEY environment
# variable. The literal fallback is kept only so existing setups keep working;
# it is already exposed in version control, so it should be rotated and the
# fallback removed.
UNIVERSITY_API_KEY = (
    os.environ.get("UNIVERSITY_API_KEY")
    or "sk-06098ff9afb946c2b9d197cb400cd752"
)
UNIVERSITY_MODEL = "model2"
|
||||||
|
|
||||||
|
# LLM call
|
||||||
|
def llm_call(messages: List[Dict[str, str]]) -> str:
    """POST *messages* to the configured chat-completions endpoint.

    Args:
        messages: Chat messages as ``{"role": ..., "content": ...}`` dicts.

    Returns:
        The ``content`` of the first choice's message in the JSON response.

    Raises:
        requests.HTTPError: When the server answers with an error status
            (via ``raise_for_status``).
    """
    payload = {"model": UNIVERSITY_MODEL, "messages": messages, "stream": False}
    request_headers = {
        "Authorization": f"Bearer {UNIVERSITY_API_KEY}",
        "Content-Type": "application/json",
    }

    resp = requests.post(
        UNIVERSITY_BASE_URL,
        json=payload,
        headers=request_headers,
        timeout=600,
    )
    resp.raise_for_status()

    first_choice = resp.json()["choices"][0]
    return first_choice["message"]["content"]
|
||||||
|
|
||||||
|
#test
|
||||||
|
if __name__ == "__main__":
    # Manual connectivity check against the configured endpoint.
    greeting_request = [{"role": "user", "content": "Say hello"}]
    response = llm_call(greeting_request)

    print(response)
|
||||||
88
main.py
Normal file
88
main.py
Normal file
@ -0,0 +1,88 @@
|
|||||||
|
from graph_builder import (
|
||||||
|
build_graph_from_documents,
|
||||||
|
save_graph,
|
||||||
|
load_graph
|
||||||
|
)
|
||||||
|
from react_agent import react_agent
|
||||||
|
from visualization import visualize_graph
|
||||||
|
|
||||||
|
GRAPH_PATH = "graph/kg.graphml"
|
||||||
|
|
||||||
|
# Create graph or load existing one
|
||||||
|
def get_or_build_graph(force_rebuild=False):
    """Return the knowledge graph, loading it from disk when possible.

    Unless *force_rebuild* is true, the graph at ``GRAPH_PATH`` is loaded and
    returned. If loading raises (or a rebuild is forced) the graph is built
    from the documents, saved to ``GRAPH_PATH`` and returned.

    Args:
        force_rebuild: Skip the load attempt and always rebuild.

    Returns:
        The loaded or freshly built graph.
    """
    if not force_rebuild:
        print("\n[INFO] Loading existing graph...")
        try:
            existing = load_graph(GRAPH_PATH)
            summary = (
                f"[INFO] Loaded graph: {len(existing.nodes())} nodes, "
                f"{len(existing.edges())} edges"
            )
        except Exception as exc:
            # Any load failure (missing or unreadable file) falls through to
            # a rebuild.
            print(f"[WARN] Could not load graph: {exc}")
            print("[INFO] Rebuilding graph...")
        else:
            print(summary)
            return existing

    print("\n[INFO] Building graph from documents...")
    fresh = build_graph_from_documents()
    print(f"[INFO] Built graph: {len(fresh.nodes())} nodes, {len(fresh.edges())} edges")

    print("\n[INFO] Saving graph...")
    save_graph(fresh, GRAPH_PATH)
    return fresh
|
||||||
|
|
||||||
|
# Testing questions (query)
|
||||||
|
def ask_question(G, question):
    """Run *question* through the ReAct agent and print answer and evidence.

    The agent is called with an empty conversation history.

    Args:
        G: Knowledge graph handed to the agent.
        question: Question text.

    Returns:
        The response object returned by ``react_agent``.
    """
    banner = "=" * 60
    print("\n" + banner)
    print(f"QUESTION: {question}")
    print(banner)

    # Each question is asked in isolation, so no prior turns are passed.
    result = react_agent(user_message=question, history=[], G=G)

    print("\n--- FINAL ANSWER ---")
    print(result.answer)

    print("\n--- EVIDENCE (GRAPH) ---")
    print(result.evidence)

    return result
|
||||||
|
|
||||||
|
|
||||||
|
# Visualisation
|
||||||
|
def show_graph(G):
    """Render *G* with ``visualize_graph`` and report the output file."""
    print("\n[INFO] Generating visualization...")
    visualize_graph(G)
    print("[INFO] Graph saved as kg.html")
|
||||||
|
|
||||||
|
# Run main
|
||||||
|
if __name__ == "__main__":
    # Create graph or load existing one
    G = get_or_build_graph(force_rebuild=False)

    # Optional, for testing: render the graph to HTML.
    # show_graph(G)

    # Testing questions (query); alternative queries are kept for manual runs.
    test_questions = [
        # "What is taught in Informatics?",
        # "What is used in machine learning?",
        # "What Literary Texts are taught in Slovak Language And Slovak Literature?",
        "Čo sa učí na predmete Slovenský jazyk a literatúra?",
    ]

    for query in test_questions:
        ask_question(G, query)

    print("\n[INFO] PIPELINE COMPLETED")
|
||||||
30
models.py
Normal file
30
models.py
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional, List
|
||||||
|
|
||||||
|
@dataclass
class GraphRAGResponse:
    """Result of one question answered by the GraphRAG agent."""

    # The question as asked by the user.
    question: str
    # The answer text.
    answer: str
    # Optional supporting text; all optional fields default to None.
    evidence: Optional[str] = None
    entities_used: Optional[List[str]] = None
    relations_used: Optional[List[str]] = None
    source_documents: Optional[List[str]] = None

    def __str__(self):
        """Return a multi-line listing of all fields using their ``repr``."""
        rows = (
            f" question = {self.question!r}\n",
            f" answer = {self.answer!r}\n",
            f" evidence = {self.evidence!r}\n",
            f" entities_used = {self.entities_used!r}\n",
            f" relations_used = {self.relations_used!r}\n",
            f" source_documents = {self.source_documents!r}\n",
        )
        return "\nGraphRAGResponse(\n" + "".join(rows) + ")"
|
||||||
82
react_agent.py
Normal file
82
react_agent.py
Normal file
@ -0,0 +1,82 @@
|
|||||||
|
from tools import TOOL_MAP
|
||||||
|
from llm import llm_call
|
||||||
|
import re
|
||||||
|
from typing import List, Dict
|
||||||
|
from models import GraphRAGResponse
|
||||||
|
|
||||||
|
# System message sent first in every agent conversation. It defines the
# reply format that parse_action / parse_final look for.
SYSTEM_PROMPT = (
    "\n"
    "You are an educational GraphRAG assistant for schools.\n"
    "\n"
    "You MUST use tools when answering knowledge questions.\n"
    "\n"
    "Available tool:\n"
    "- query_knowledge_graph\n"
    "\n"
    "Format:\n"
    "Action: query_knowledge_graph\n"
    "Action Input: <question>\n"
    "\n"
    "OR:\n"
    "\n"
    "Final Answer: <answer>\n"
)
|
||||||
|
|
||||||
|
def parse_action(text: str):
    """Extract a tool call from an agent reply.

    Looks for ``Action: <tool_name>`` and ``Action Input: <rest of line>``.

    Args:
        text: Raw reply from the model.

    Returns:
        ``(tool_name, tool_input)`` when both parts are present, else ``None``.
        The input has surrounding whitespace removed.
    """
    action = re.search(r"Action:\s*(\w+)", text)
    input_ = re.search(r"Action Input:\s*(.+)", text)

    if action and input_:
        # `.+` stops at "\n" but keeps trailing spaces and "\r"; strip them so
        # the tool receives a clean query.
        return action.group(1), input_.group(1).strip()

    return None
|
||||||
|
|
||||||
|
def parse_final(text: str):
    """Return the text after ``Final Answer:`` (stripped), or ``None``.

    Everything up to the end of *text* is captured, including later lines.
    """
    found = re.search(r"Final Answer:\s*(.+)", text, re.DOTALL)
    if found is None:
        return None
    return found.group(1).strip()
|
||||||
|
|
||||||
|
def react_agent(user_message: str, history: List[Dict], G=None, max_steps: int = 5):
    """Answer *user_message* with a small ReAct-style loop over the LLM.

    Each step sends the conversation to the LLM and inspects the reply:

    * a ``Final Answer:`` is returned as the answer with empty evidence;
    * a call to a tool registered in ``TOOL_MAP`` runs the tool with the
      parsed input and *G*, and the tool's ``answer`` / ``evidence`` are
      returned directly (no further LLM round);
    * anything else is fed back to the model together with a format reminder
      and the loop continues.

    Args:
        user_message: The user's question.
        history: Earlier chat messages inserted between the system prompt and
            the question; the list itself is not modified.
        G: Knowledge graph passed to the tool.
        max_steps: Maximum number of LLM calls.

    Returns:
        A ``GraphRAGResponse``. When no step yields a final answer or a tool
        call, the last raw reply is returned as the answer.
    """
    messages = (
        [{"role": "system", "content": SYSTEM_PROMPT}]
        + history
        + [{"role": "user", "content": user_message}]
    )

    last = ""

    for _ in range(max_steps):
        reply = llm_call(messages)
        last = reply

        final = parse_final(reply)
        if final:
            return GraphRAGResponse(
                question=user_message,
                answer=final,
                evidence="",
            )

        action = parse_action(reply)
        tool = TOOL_MAP.get(action[0]) if action else None

        if tool is not None:
            result = tool(action[1], G)
            return GraphRAGResponse(
                question=user_message,
                answer=result["answer"],
                evidence=result["evidence"],
            )

        # Unusable reply (no final answer, no known tool). Without feedback
        # the next iteration would resend the identical conversation.
        messages.append({"role": "assistant", "content": reply})
        messages.append({
            "role": "user",
            "content": (
                "Observation: Could not parse a valid step. Reply with "
                "\"Action: query_knowledge_graph\" followed by "
                "\"Action Input: <question>\", or with "
                "\"Final Answer: <answer>\"."
            ),
        })

    return GraphRAGResponse(
        question=user_message,
        answer=last,
        evidence="",
    )
|
||||||
37
retrieval.py
Normal file
37
retrieval.py
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
def retrieve_subgraph(G, query):
    """Collect the edges around every node whose name matches *query*.

    A node matches when its lower-cased name is a substring of the
    lower-cased query or the other way round. For each match the outgoing
    edges are listed and, on a directed graph, the incoming edges as well, so
    facts that point at the entity are retrieved too. Every edge is listed
    once, formatted as ``source --relation--> target``; a missing
    ``relation`` attribute is shown as ``related_to``.

    Args:
        G: Graph with string node names (NetworkX-style API).
        query: Free-text question or entity name.

    Returns:
        The edge lines joined with newlines; ``""`` when nothing matches.
    """
    query = query.lower()

    context = []
    seen = set()

    def _record(source, target):
        # Append the formatted edge unless it was already emitted.
        relation = G[source][target].get("relation", "related_to")
        line = f"{source} --{relation}--> {target}"
        if line not in seen:
            seen.add(line)
            context.append(line)

    directed = G.is_directed()

    for name in G.nodes():
        node_name = name.lower()

        # An empty name is a substring of every query and would match always.
        if not node_name:
            continue

        if query in node_name or node_name in query:
            for neighbor in G.neighbors(name):
                _record(name, neighbor)

            if directed:
                for predecessor in G.predecessors(name):
                    _record(predecessor, name)

    return "\n".join(context)
|
||||||
|
|
||||||
|
#test
|
||||||
|
from graph_builder import load_graph
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Manual check: print the edges retrieved for a sample query.
    saved_graph = load_graph()
    print(retrieve_subgraph(saved_graph, "Python"))
|
||||||
50
tools.py
Normal file
50
tools.py
Normal file
@ -0,0 +1,50 @@
|
|||||||
|
from retrieval import retrieve_subgraph
|
||||||
|
from llm import llm_call
|
||||||
|
|
||||||
|
def query_knowledge_graph(question: str, G):
    """Answer *question* with the LLM, grounded in edges retrieved from *G*.

    The edges returned by ``retrieve_subgraph`` are embedded in a prompt that
    tells the model to use only that knowledge.

    Args:
        question: The user's question.
        G: Knowledge graph to search.

    Returns:
        ``{"answer": <LLM reply>, "evidence": <retrieved edge lines>}``.
    """
    graph_context = retrieve_subgraph(G, question)

    prompt = f"""
You are an educational assistant.

Use ONLY the graph knowledge below to answer.

Graph Knowledge:
{graph_context}

Question:
{question}

Rules:
- If graph is empty, say "No information in knowledge graph"
- Be concise and educational
"""

    conversation = [{"role": "user", "content": prompt}]
    reply = llm_call(conversation)

    return {"answer": reply, "evidence": graph_context}


# Registry of tools the agent may call, keyed by the name used in "Action:".
TOOL_MAP = {"query_knowledge_graph": query_knowledge_graph}
|
||||||
|
|
||||||
|
#test
|
||||||
|
from graph_builder import load_graph
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Manual check: answer a sample question against the saved graph.
    saved_graph = load_graph()
    print(query_knowledge_graph("What is used in machine learning?", saved_graph))
|
||||||
31
visualization.py
Normal file
31
visualization.py
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
from pyvis.network import Network
|
||||||
|
from graph_builder import load_graph
|
||||||
|
|
||||||
|
def visualize_graph(G):
    """Render *G* as an interactive directed network in ``kg.html``.

    Every node of *G* is added by name; every edge is added with its
    ``relation`` attribute as the label (empty when missing).

    Args:
        G: Graph to render.
    """
    net = Network(directed=True, notebook=False)

    for name in G.nodes():
        net.add_node(name)

    for start, end, attrs in G.edges(data=True):
        net.add_edge(start, end, label=attrs.get("relation", ""))

    output_path = "kg.html"
    net.write_html(output_path)

    print(f"Graph saved to {output_path}")
|
||||||
|
|
||||||
|
#test
|
||||||
|
if __name__ == "__main__":
    # Manual check: render the saved graph.
    visualize_graph(load_graph())
|
||||||
Loading…
Reference in New Issue
Block a user