412 lines
8.9 KiB
Python
412 lines
8.9 KiB
Python
import json
|
|
import os
|
|
import networkx as nx
|
|
import re
|
|
import hashlib
|
|
|
|
from ingest import load_documents
|
|
from llm import llm_call
|
|
|
|
# Cleans response to avoid errors
|
|
def clean_llm_response(text: str) -> str:
|
|
"""
|
|
Removes markdown, extra text, and extracts JSON safely.
|
|
"""
|
|
|
|
if not text:
|
|
raise ValueError("LLM returned empty response")
|
|
|
|
text = text.strip()
|
|
|
|
# Remove ```json or ``` wrappers
|
|
text = re.sub(r"^```json", "", text)
|
|
text = re.sub(r"^```", "", text)
|
|
text = re.sub(r"```$", "", text)
|
|
|
|
text = text.strip()
|
|
|
|
# Try to extract JSON object if model adds extra text
|
|
match = re.search(r"\{.*\}", text, re.DOTALL)
|
|
if match:
|
|
return match.group(0)
|
|
|
|
return text
|
|
|
|
# Gets cached extraction for a chunk, or extracts and caches it
|
|
def get_cached_extraction(chunk, cache_dir="cache"):
|
|
|
|
os.makedirs(cache_dir, exist_ok=True)
|
|
|
|
# Create unique hash for chunk
|
|
chunk_hash = hashlib.md5(
|
|
chunk.encode("utf-8")
|
|
).hexdigest()
|
|
|
|
cache_path = os.path.join(
|
|
cache_dir,
|
|
f"{chunk_hash}.json"
|
|
)
|
|
|
|
|
|
# Cache Load (if exists)
|
|
if os.path.exists(cache_path):
|
|
|
|
print(f"[CACHE HIT] {chunk_hash}")
|
|
|
|
with open(cache_path, "r", encoding="utf-8") as f:
|
|
return json.load(f)
|
|
|
|
|
|
# Extract json and store in cache
|
|
print(f"[CACHE MISS] {chunk_hash}")
|
|
|
|
extraction = extract_entities_and_relations(chunk)
|
|
|
|
|
|
# Save cache
|
|
with open(cache_path, "w", encoding="utf-8") as f:
|
|
json.dump(
|
|
extraction,
|
|
f,
|
|
indent=2,
|
|
ensure_ascii=False
|
|
)
|
|
|
|
return extraction
|
|
|
|
# Calls LLM to get entities and realtions (.json) and saves them to cache
|
|
def extract_entities_and_relations(text):
|
|
|
|
text = text[:8000] # prevent context overflow
|
|
|
|
prompt = f"""
|
|
You are an educational knowledge graph extraction system.
|
|
|
|
Your task:
|
|
|
|
Extract educational entities and relations.
|
|
|
|
Rules:
|
|
- Return ONLY EXACT JSON
|
|
- Use ONLY meaningful educational concepts
|
|
- NO markdown
|
|
- NO explanations
|
|
- NO extra text
|
|
- Use these relations ONLY:
|
|
[ "is_a",
|
|
"part_of",
|
|
"uses",
|
|
"teaches",
|
|
"requires",
|
|
"depends_on",
|
|
"implements",
|
|
"applies_to",
|
|
"subfield_of",
|
|
"contains",
|
|
"defined_by"]
|
|
- Entity names must be normalized:
|
|
GOOD:
|
|
- "Machine Learning"
|
|
- "Artificial Intelligence"
|
|
BAD:
|
|
- "machine learning"
|
|
- "ML"
|
|
- "AI systems"
|
|
- Prefer fewer HIGH QUALITY relations over many weak relations
|
|
- Every entity should participate in at least one relation
|
|
|
|
OUTPUT FORMAT EXACTLY:
|
|
{{
|
|
"entities": [
|
|
{{"name": "...", "type": "..."}}
|
|
],
|
|
"relations": [
|
|
{{
|
|
"source": "...",
|
|
"target": "...",
|
|
"relation": "..."
|
|
}}
|
|
]
|
|
}}
|
|
|
|
TEXT:
|
|
{text}
|
|
"""
|
|
|
|
response = llm_call([
|
|
{"role": "user", "content": prompt}
|
|
])
|
|
|
|
cleaned = clean_llm_response(response)
|
|
|
|
try:
|
|
data = json.loads(cleaned)
|
|
|
|
except json.JSONDecodeError:
|
|
print("\nINVALID JSON")
|
|
print(cleaned)
|
|
raise
|
|
|
|
# Checks if llm output is valid
|
|
entities = data.get("entities", [])
|
|
relations = data.get("relations", [])
|
|
|
|
# Normalize entity names
|
|
normalized_entities = {}
|
|
|
|
for ent in entities:
|
|
|
|
name = ent["name"].strip()
|
|
|
|
# title case normalization
|
|
name = " ".join(word.capitalize() for word in name.split())
|
|
|
|
normalized_entities[name] = {
|
|
"name": name,
|
|
"type": ent.get("type", "concept")
|
|
}
|
|
|
|
entity_names = set(normalized_entities.keys())
|
|
|
|
valid_relations = []
|
|
|
|
for rel in relations:
|
|
|
|
source = rel["source"].strip()
|
|
target = rel["target"].strip()
|
|
relation = rel["relation"].strip()
|
|
|
|
source = " ".join(word.capitalize() for word in source.split())
|
|
target = " ".join(word.capitalize() for word in target.split())
|
|
|
|
# Skip self-loops
|
|
if source == target:
|
|
continue
|
|
|
|
# Ensure both entities exist
|
|
if source not in entity_names:
|
|
continue
|
|
|
|
if target not in entity_names:
|
|
continue
|
|
|
|
valid_relations.append({
|
|
"source": source,
|
|
"target": target,
|
|
"relation": relation
|
|
})
|
|
|
|
# Remove isolated entities
|
|
connected = set()
|
|
|
|
for rel in valid_relations:
|
|
connected.add(rel["source"])
|
|
connected.add(rel["target"])
|
|
|
|
final_entities = [
|
|
ent for ent in normalized_entities.values()
|
|
if ent["name"] in connected
|
|
]
|
|
|
|
return {
|
|
"entities": final_entities,
|
|
"relations": valid_relations
|
|
}
|
|
|
|
# Chunks document into chunks (with overlap)
|
|
def chunk_text(text, chunk_size=8000, overlap=300):
|
|
|
|
text = text.replace("\r", " ")
|
|
|
|
paragraphs = text.split("\n")
|
|
|
|
chunks = []
|
|
current = ""
|
|
|
|
for para in paragraphs:
|
|
|
|
para = para.strip()
|
|
|
|
# Skip useless tiny paragraphs
|
|
if len(para) < 40:
|
|
continue
|
|
|
|
if len(current) + len(para) < chunk_size:
|
|
|
|
current += para + "\n"
|
|
|
|
else:
|
|
|
|
chunks.append(current)
|
|
|
|
# overlap keeps context continuity
|
|
current = current[-overlap:] + "\n" + para + "\n"
|
|
|
|
if current:
|
|
chunks.append(current)
|
|
|
|
return chunks
|
|
|
|
# Builds graph from documents (must be in documents folder [.pdf, .txt, .docx])
|
|
def build_graph_from_documents(folder="documents"):
|
|
|
|
G = nx.DiGraph()
|
|
|
|
docs = load_documents(folder)
|
|
|
|
for doc in docs:
|
|
|
|
print(f"\n[INFO] Processing document: {doc['filename']}")
|
|
|
|
# Document is split into chunks
|
|
chunks = chunk_text(doc["content"])
|
|
|
|
print(f"[INFO] Total chunks: {len(chunks)}")
|
|
|
|
# Process chunk by chunk
|
|
for idx, chunk in enumerate(chunks):
|
|
|
|
print(f"[INFO] Processing chunk {idx + 1}/{len(chunks)}")
|
|
|
|
try:
|
|
extraction = get_cached_extraction(chunk)
|
|
|
|
except Exception as e:
|
|
|
|
print(f"[WARN] Extraction failed on chunk {idx + 1}")
|
|
print(e)
|
|
|
|
continue
|
|
|
|
# Add entities (nodes) to graph
|
|
for ent in extraction["entities"]:
|
|
|
|
G.add_node(
|
|
ent["name"],
|
|
type=ent["type"],
|
|
source_doc=doc["filename"]
|
|
)
|
|
|
|
# Adds entity realtions (links nodes)
|
|
for rel in extraction["relations"]:
|
|
|
|
G.add_edge(
|
|
rel["source"],
|
|
rel["target"],
|
|
relation=rel["relation"],
|
|
source_doc=doc["filename"]
|
|
)
|
|
|
|
return G
|
|
|
|
# Saves graph to file
|
|
def save_graph(G, path="graph/kg.graphml"):
|
|
os.makedirs(os.path.dirname(path), exist_ok=True)
|
|
nx.write_graphml(G, path)
|
|
print(f"Graph saved to {path}")
|
|
|
|
# Loads graph from file
|
|
def load_graph(path="graph/kg.graphml"):
|
|
return nx.read_graphml(path)
|
|
|
|
# Adds document to graph
|
|
def add_document(filepath, G):
|
|
|
|
docs = load_documents(folder=filepath)
|
|
|
|
if not docs:
|
|
print(f"[WARN] No documents found in: {filepath}")
|
|
return
|
|
|
|
doc = docs[0]
|
|
|
|
print(f"\n[INFO] Adding document: {doc['filename']}")
|
|
|
|
# Document is split into chunks
|
|
chunks = chunk_text(doc["content"])
|
|
|
|
print(f"[INFO] Total chunks: {len(chunks)}")
|
|
|
|
# Process chunk by chunk
|
|
for idx, chunk in enumerate(chunks):
|
|
|
|
print(f"[INFO] Processing chunk {idx + 1}/{len(chunks)}")
|
|
|
|
try:
|
|
|
|
extraction = get_cached_extraction(chunk)
|
|
|
|
except Exception as e:
|
|
|
|
print(f"[WARN] Extraction failed on chunk {idx + 1}")
|
|
print(e)
|
|
|
|
continue
|
|
|
|
# Add entities (nodes) to graph
|
|
for ent in extraction["entities"]:
|
|
|
|
G.add_node(
|
|
ent["name"],
|
|
type=ent["type"],
|
|
source_doc=doc["filename"]
|
|
)
|
|
|
|
# Adds entity realtions (links nodes)
|
|
for rel in extraction["relations"]:
|
|
|
|
G.add_edge(
|
|
rel["source"],
|
|
rel["target"],
|
|
relation=rel["relation"],
|
|
source_doc=doc["filename"]
|
|
)
|
|
|
|
print(f"[INFO] Finished adding {doc['filename']}")
|
|
|
|
# Removes document from graph
|
|
def remove_document(filename, G):
|
|
|
|
edges_to_remove = []
|
|
|
|
for u, v, data in G.edges(data=True):
|
|
|
|
if data.get("source_doc") == filename:
|
|
edges_to_remove.append((u, v))
|
|
|
|
G.remove_edges_from(edges_to_remove)
|
|
|
|
# Remove orphan nodes
|
|
isolated_nodes = list(nx.isolates(G))
|
|
|
|
G.remove_nodes_from(isolated_nodes)
|
|
|
|
print(f"[INFO] Removed document: {filename}")
|
|
|
|
'''
|
|
#test
|
|
if __name__ == "__main__":
|
|
|
|
docs = load_documents()
|
|
|
|
extraction = extract_entities_and_relations(
|
|
docs[0]["content"]
|
|
)
|
|
|
|
print(extraction)
|
|
'''
|
|
|
|
#test
|
|
if __name__ == "__main__":
|
|
|
|
#G2 = load_graph()
|
|
|
|
#print(G2.nodes(data=True))
|
|
G = build_graph_from_documents()
|
|
|
|
save_graph(G)
|
|
#print(G.nodes(data=True))
|
|
|
|
#print(G.edges(data=True))
|
|
|