# DP_PRACA/graph_builder.py
# 2026-05-16 08:50:22 +02:00
#
# 412 lines
# 8.9 KiB
# Python

import json
import os
import networkx as nx
import re
import hashlib
from ingest import load_documents
from llm import llm_call
# Cleans response to avoid errors
def clean_llm_response(text: str) -> str:
    """
    Removes markdown, extra text, and extracts JSON safely.

    First strips a leading Markdown code fence, with any language tag
    (e.g. ```json or ```JSON), and a trailing fence. It then returns the
    first complete JSON object found in what remains. If no complete object
    can be decoded, it falls back to the span from the first "{" to the last
    "}". If there is no "{" at all, the stripped text is returned unchanged.

    Args:
        text: Raw LLM response.

    Returns:
        The substring most likely to be the JSON payload.

    Raises:
        ValueError: If *text* is None, empty, or whitespace only.
    """
    # Reject whitespace-only responses too; they would otherwise come back as
    # "" and only fail later inside json.loads with a less useful message.
    if not text or not text.strip():
        raise ValueError("LLM returned empty response")
    text = text.strip()

    # Remove an opening fence (optionally followed by a language tag) and a
    # closing fence.
    text = re.sub(r"^```[A-Za-z0-9_+-]*", "", text)
    text = re.sub(r"```$", "", text)
    text = text.strip()

    start = text.find("{")
    if start == -1:
        return text

    candidate = text[start:]
    try:
        # raw_decode stops at the end of the first complete JSON value, so
        # chatter after the object (even chatter containing braces) is dropped.
        _, end = json.JSONDecoder().raw_decode(candidate)
        return candidate[:end]
    except json.JSONDecodeError:
        pass

    # Fallback: greedy first-"{"-to-last-"}" extraction.
    match = re.search(r"\{.*\}", text, re.DOTALL)
    if match:
        return match.group(0)
    return text
# Gets cached extraction for a chunk, or extracts and caches it
def get_cached_extraction(chunk, cache_dir="cache"):
    """Return the entity/relation extraction for *chunk*, using an on-disk cache.

    The cache key is the MD5 hex digest of the UTF-8 encoded chunk text. The
    entry is stored as ``<cache_dir>/<digest>.json``. On a miss, or when the
    cache file cannot be read or parsed, the chunk is sent to
    ``extract_entities_and_relations`` and the result is written back.

    Note: the key covers only the chunk text, not the prompt or the model.
    Clear *cache_dir* after changing either to get fresh extractions.

    Args:
        chunk: Text chunk to extract from.
        cache_dir: Directory holding the JSON cache files (created if missing).

    Returns:
        The extraction dict (as cached or as freshly extracted).

    Raises:
        Whatever ``extract_entities_and_relations`` raises on a cache miss.
    """
    import tempfile  # local import: only needed for the atomic cache write

    os.makedirs(cache_dir, exist_ok=True)

    # Create unique hash for chunk
    chunk_hash = hashlib.md5(chunk.encode("utf-8")).hexdigest()
    cache_path = os.path.join(cache_dir, f"{chunk_hash}.json")

    # Cache load (if present and readable)
    if os.path.exists(cache_path):
        try:
            with open(cache_path, "r", encoding="utf-8") as f:
                cached = json.load(f)
        except (OSError, ValueError) as e:
            # A truncated/corrupt file would otherwise make this chunk fail
            # on every run; treat it as a miss and re-extract instead.
            print(f"[CACHE CORRUPT] {chunk_hash}: {e}")
        else:
            print(f"[CACHE HIT] {chunk_hash}")
            return cached

    # Extract JSON and store it in the cache
    print(f"[CACHE MISS] {chunk_hash}")
    extraction = extract_entities_and_relations(chunk)

    # Write to a temp file in the same directory, then atomically move it into
    # place, so an interrupted run never leaves a half-written cache entry.
    fd, tmp_path = tempfile.mkstemp(dir=cache_dir, suffix=".tmp")
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            json.dump(extraction, f, indent=2, ensure_ascii=False)
        os.replace(tmp_path, cache_path)
    except BaseException:
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
    return extraction
# Calls the LLM to extract entities and relations as JSON and validates them (caching is handled by get_cached_extraction)
def _title_case(value):
    """Collapse whitespace in *value* and capitalize each word.

    Returns None when *value* is not a string or is blank. ``str.capitalize``
    lower-cases the rest of each word (e.g. "SQL" -> "Sql"). It is kept as-is
    so names stay consistent with extractions already stored in the cache.
    """
    if not isinstance(value, str):
        return None
    words = value.split()
    if not words:
        return None
    return " ".join(word.capitalize() for word in words)


def _validate_extraction(data):
    """Normalize and filter a parsed LLM extraction dict.

    Malformed items (non-dicts, or missing/blank/non-string names,
    endpoints, or relation labels) are skipped individually instead of
    failing the whole chunk. Entity types that are missing, blank, or not
    strings default to "concept". Self-loops and relations whose endpoints
    are not declared entities are dropped. Entities that end up in no
    relation are removed.

    Returns:
        {"entities": [{"name", "type"}, ...],
         "relations": [{"source", "target", "relation"}, ...]}
    """
    entities = data.get("entities")
    relations = data.get("relations")
    if not isinstance(entities, list):
        entities = []
    if not isinstance(relations, list):
        relations = []

    # Normalize entity names; later duplicates overwrite earlier ones.
    normalized_entities = {}
    for ent in entities:
        if not isinstance(ent, dict):
            continue
        name = _title_case(ent.get("name"))
        if name is None:
            continue
        ent_type = ent.get("type")
        # An explicit null (or non-string) type would later break GraphML export.
        if not isinstance(ent_type, str) or not ent_type.strip():
            ent_type = "concept"
        normalized_entities[name] = {"name": name, "type": ent_type}

    valid_relations = []
    for rel in relations:
        if not isinstance(rel, dict):
            continue
        source = _title_case(rel.get("source"))
        target = _title_case(rel.get("target"))
        relation = rel.get("relation")
        if source is None or target is None:
            continue
        if not isinstance(relation, str) or not relation.strip():
            continue
        # Skip self-loops
        if source == target:
            continue
        # Ensure both endpoints were declared as entities
        if source not in normalized_entities or target not in normalized_entities:
            continue
        valid_relations.append({
            "source": source,
            "target": target,
            "relation": relation.strip(),
        })

    # Remove isolated entities
    connected = {rel["source"] for rel in valid_relations}
    connected.update(rel["target"] for rel in valid_relations)
    final_entities = [
        ent for ent in normalized_entities.values()
        if ent["name"] in connected
    ]

    return {
        "entities": final_entities,
        "relations": valid_relations,
    }


def extract_entities_and_relations(text):
    """Ask the LLM for educational entities/relations in *text* and validate them.

    Only the first 8000 characters of *text* are sent, to avoid overflowing
    the model context. The response is cleaned with ``clean_llm_response``,
    parsed as JSON, and normalized by ``_validate_extraction``.

    Returns:
        {"entities": [{"name", "type"}, ...],
         "relations": [{"source", "target", "relation"}, ...]}

    Raises:
        ValueError: If the response is empty or is not a JSON object.
        json.JSONDecodeError: If the response is not valid JSON. The cleaned
            text is printed before re-raising.
    """
    text = text[:8000]  # prevent context overflow
    prompt = f"""
You are an educational knowledge graph extraction system.
Your task:
Extract educational entities and relations.
Rules:
- Return ONLY EXACT JSON
- Use ONLY meaningful educational concepts
- NO markdown
- NO explanations
- NO extra text
- Use these relations ONLY:
[ "is_a",
"part_of",
"uses",
"teaches",
"requires",
"depends_on",
"implements",
"applies_to",
"subfield_of",
"contains",
"defined_by"]
- Entity names must be normalized:
GOOD:
- "Machine Learning"
- "Artificial Intelligence"
BAD:
- "machine learning"
- "ML"
- "AI systems"
- Prefer fewer HIGH QUALITY relations over many weak relations
- Every entity should participate in at least one relation
OUTPUT FORMAT EXACTLY:
{{
"entities": [
{{"name": "...", "type": "..."}}
],
"relations": [
{{
"source": "...",
"target": "...",
"relation": "..."
}}
]
}}
TEXT:
{text}
"""
    response = llm_call([
        {"role": "user", "content": prompt}
    ])
    cleaned = clean_llm_response(response)
    try:
        data = json.loads(cleaned)
    except json.JSONDecodeError:
        print("\nINVALID JSON")
        print(cleaned)
        raise
    # A top-level list/string would otherwise crash on .get with AttributeError.
    if not isinstance(data, dict):
        raise ValueError(
            f"Expected a JSON object from the LLM, got {type(data).__name__}"
        )
    return _validate_extraction(data)
# Chunks document into chunks (with overlap)
def _split_paragraph(para, max_len):
    """Yield consecutive slices of *para*, each at most *max_len* characters,
    breaking at the last space before the limit when there is one."""
    while len(para) > max_len:
        cut = para.rfind(" ", 0, max_len + 1)
        if cut <= 0:
            cut = max_len  # no usable space: hard split
        yield para[:cut].rstrip()
        para = para[cut:].lstrip()
    if para:
        yield para


def chunk_text(text, chunk_size=8000, overlap=300):
    """Split *text* into overlapping chunks of roughly ``chunk_size`` characters.

    The text is split on newlines into paragraphs, and paragraphs shorter
    than 40 characters are dropped. Paragraphs are packed into the current
    chunk until adding the next would reach ``chunk_size``. The next chunk
    then starts with the last ``overlap`` characters of the previous one,
    to keep context continuous.

    Paragraphs too long to fit are split, preferring whitespace boundaries.
    This stops them from being silently truncated downstream. For
    ``overlap`` well below ``chunk_size``, no chunk exceeds ``chunk_size``.

    Args:
        text: Document text.
        chunk_size: Target maximum chunk length in characters.
        overlap: Characters carried over from the previous chunk (<= 0 disables).

    Returns:
        List of non-empty chunk strings.
    """
    text = text.replace("\r", " ")
    overlap = max(overlap, 0)
    # Longest paragraph slice that still fits after an overlap tail plus two
    # newlines; never shrink below half a chunk for odd overlap settings.
    max_piece = max(chunk_size - overlap - 2, chunk_size // 2, 1)

    chunks = []
    current = ""
    for para in text.split("\n"):
        para = para.strip()
        # Skip useless tiny paragraphs
        if len(para) < 40:
            continue
        for piece in _split_paragraph(para, max_piece):
            if len(current) + len(piece) < chunk_size:
                current += piece + "\n"
                continue
            if current:  # never emit an empty chunk
                chunks.append(current)
            # overlap keeps context continuity; s[-0:] would copy the whole
            # string, so overlap == 0 must be handled explicitly.
            tail = current[-overlap:] if overlap else ""
            current = (tail + "\n" if tail else "") + piece + "\n"
    if current:
        chunks.append(current)
    return chunks
# Builds graph from documents (must be in documents folder [.pdf, .txt, .docx])
def build_graph_from_documents(folder="documents"):
    """Build a directed knowledge graph from every document in *folder*.

    Each document returned by ``load_documents(folder)`` is split with
    ``chunk_text``. Every chunk's (cached) extraction contributes entity nodes
    and relation edges. Nodes and edges record the document filename as
    ``source_doc``. When several documents produce the same node or edge,
    the attributes from the last one processed win.

    Chunks whose extraction fails are logged and skipped.

    Args:
        folder: Directory passed to ``load_documents``.

    Returns:
        A ``networkx.DiGraph``.
    """
    G = nx.DiGraph()
    docs = load_documents(folder)
    for doc in docs:
        filename = doc["filename"]
        print(f"\n[INFO] Processing document: {filename}")
        # Document is split into chunks
        chunks = chunk_text(doc["content"])
        total = len(chunks)
        print(f"[INFO] Total chunks: {total}")
        # Process chunk by chunk
        for idx, chunk in enumerate(chunks, start=1):
            print(f"[INFO] Processing chunk {idx}/{total}")
            try:
                extraction = get_cached_extraction(chunk)
                entities = extraction.get("entities", [])
                relations = extraction.get("relations", [])
            except Exception as e:
                print(f"[WARN] Extraction failed on chunk {idx}")
                print(e)
                continue
            # Add entities (nodes) to graph
            for ent in entities:
                ent_type = ent.get("type")
                # Older cache entries may store "type": null (or a non-string);
                # nx.write_graphml rejects None values, so default to "concept".
                if not isinstance(ent_type, str) or not ent_type.strip():
                    ent_type = "concept"
                G.add_node(ent["name"], type=ent_type, source_doc=filename)
            # Add entity relations (edges linking nodes)
            for rel in relations:
                G.add_edge(
                    rel["source"],
                    rel["target"],
                    relation=rel["relation"],
                    source_doc=filename,
                )
    return G
# Saves graph to file
def save_graph(G, path="graph/kg.graphml"):
    """Write *G* to *path* in GraphML format, creating the parent directory.

    A bare filename (no directory component) is written to the current
    working directory.

    Args:
        G: Graph to serialize with ``nx.write_graphml``.
        path: Destination file path.
    """
    directory = os.path.dirname(path)
    # os.makedirs("") raises FileNotFoundError, so only create a directory
    # when the path actually contains one.
    if directory:
        os.makedirs(directory, exist_ok=True)
    nx.write_graphml(G, path)
    print(f"Graph saved to {path}")
# Loads graph from file
def load_graph(path="graph/kg.graphml"):
    """Read a GraphML file and return the graph it contains.

    Args:
        path: GraphML file to read; the default matches ``save_graph``'s
            default output location.

    Returns:
        The graph object produced by ``nx.read_graphml``.
    """
    graph = nx.read_graphml(path)
    return graph
# Adds document to graph
def add_document(filepath, G):
    """Extract one document and merge its entities/relations into *G* in place.

    Calls ``load_documents(folder=filepath)`` and uses only the first
    returned document; a warning is printed if more than one came back.
    NOTE(review): the argument is named ``filepath`` but is passed as
    ``folder``. Confirm that load_documents accepts a single-file path.

    Nodes and edges get ``source_doc`` set to the document's filename.
    Chunks whose extraction fails are logged and skipped.

    Args:
        filepath: Path handed to ``load_documents``.
        G: Graph to update in place.
    """
    docs = load_documents(folder=filepath)
    if not docs:
        print(f"[WARN] No documents found in: {filepath}")
        return
    if len(docs) > 1:
        print(
            f"[WARN] {len(docs)} documents found in: {filepath}; "
            f"only the first is added"
        )
    doc = docs[0]
    filename = doc["filename"]
    print(f"\n[INFO] Adding document: {filename}")
    # Document is split into chunks
    chunks = chunk_text(doc["content"])
    total = len(chunks)
    print(f"[INFO] Total chunks: {total}")
    # Process chunk by chunk
    for idx, chunk in enumerate(chunks, start=1):
        print(f"[INFO] Processing chunk {idx}/{total}")
        try:
            extraction = get_cached_extraction(chunk)
            entities = extraction.get("entities", [])
            relations = extraction.get("relations", [])
        except Exception as e:
            print(f"[WARN] Extraction failed on chunk {idx}")
            print(e)
            continue
        # Add entities (nodes) to graph
        for ent in entities:
            ent_type = ent.get("type")
            # Older cache entries may store "type": null (or a non-string);
            # nx.write_graphml rejects None values, so default to "concept".
            if not isinstance(ent_type, str) or not ent_type.strip():
                ent_type = "concept"
            G.add_node(ent["name"], type=ent_type, source_doc=filename)
        # Add entity relations (edges linking nodes)
        for rel in relations:
            G.add_edge(
                rel["source"],
                rel["target"],
                relation=rel["relation"],
                source_doc=filename,
            )
    print(f"[INFO] Finished adding {filename}")
# Removes document from graph
def remove_document(filename, G):
    """Remove a document's contribution from *G* in place.

    Deletes every edge whose ``source_doc`` equals *filename*. It then
    deletes nodes left with no edges, but only nodes this document touched:
    endpoints of the removed edges, or nodes whose own ``source_doc`` is
    *filename*. Isolated nodes that belong to other documents are kept.

    Args:
        filename: Value of the ``source_doc`` attribute to remove.
        G: Graph to modify.
    """
    edges_to_remove = [
        (u, v)
        for u, v, data in G.edges(data=True)
        if data.get("source_doc") == filename
    ]

    # Candidate orphans: nodes connected to the removed edges, plus nodes
    # attributed to this document.
    candidates = {node for edge in edges_to_remove for node in edge}
    candidates.update(
        node
        for node, data in G.nodes(data=True)
        if data.get("source_doc") == filename
    )

    G.remove_edges_from(edges_to_remove)

    # Remove orphan nodes created by this removal
    orphans = [node for node in candidates if G.degree(node) == 0]
    G.remove_nodes_from(orphans)
    print(f"[INFO] Removed document: {filename}")
# Script entry point: build the knowledge graph from the default "documents"
# folder and write it to save_graph's default path. To inspect a saved graph,
# use load_graph(). To smoke-test extraction on a single document, call
# extract_entities_and_relations(load_documents()[0]["content"]).
if __name__ == "__main__":
    G = build_graph_from_documents()
    save_graph(G)