GraphRAG Studio — initial commit: multimodal RAG system with KG visualization
Full-stack application for a document-to-knowledge-graph pipeline:
- Backend: FastAPI + LangGraph ReAct agent + DeepSeek + MinerU parsing
- Frontend: React 19 + Vite + D3.js + shadcn/ui
- Pipeline: MinerU parsing → LangExtract entity extraction → KG building

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
0
backend/pipeline/__init__.py
Normal file
0
backend/pipeline/__init__.py
Normal file
66
backend/pipeline/entity_extractor.py
Normal file
66
backend/pipeline/entity_extractor.py
Normal file
@@ -0,0 +1,66 @@
|
||||
"""
|
||||
Entity Extractor — LangExtract + DeepSeek entity extraction.
|
||||
Independent implementation for the GraphRAG Studio backend.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
import langextract as lx
|
||||
from langextract.providers.openai import OpenAILanguageModel
|
||||
|
||||
# Read settings from the ".env" file two directories above this module
# (backend/.env). override=True lets entries in that file replace variables
# already present in the process environment.
load_dotenv(Path(__file__).parents[1] / ".env", override=True)

# DeepSeek is reached through the OpenAI-compatible provider configured below.
DEEPSEEK_API_KEY = os.environ.get("DEEPSEEK_API_KEY", "")
DEEPSEEK_BASE_URL = os.environ.get("DEEPSEEK_BASE_URL", "https://api.deepseek.com")
MODEL_ID = "deepseek-chat"

# Task description handed to LangExtract; it enumerates the entity classes.
PROMPT_DESCRIPTION = "".join([
    "Extract named entities from the text in order of appearance. ",
    "Entity types: TECHNOLOGY (software, algorithms, models, tools), ",
    "ORGANIZATION (companies, research groups, institutions), ",
    "PERSON (individual people), ",
    "LOCATION (places, geographic entities), ",
    "CONCEPT (technical concepts, methodologies, frameworks).",
])

# Single few-shot example: a short passage plus the entities expected from it.
_EXAMPLE_TEXT = "".join([
    "LangChain is a framework created by Harrison Chase for building ",
    "LLM applications. It integrates with OpenAI models and Pinecone ",
    "vector database for semantic search.",
])

# (extraction_class, extraction_text) pairs, listed in order of appearance.
_EXAMPLE_ENTITIES = [
    ("TECHNOLOGY", "LangChain"),
    ("PERSON", "Harrison Chase"),
    ("CONCEPT", "LLM applications"),
    ("TECHNOLOGY", "OpenAI models"),
    ("TECHNOLOGY", "Pinecone"),
    ("CONCEPT", "semantic search"),
]

EXAMPLES = [
    lx.data.ExampleData(
        text=_EXAMPLE_TEXT,
        extractions=[
            lx.data.Extraction(extraction_class=label, extraction_text=surface)
            for label, surface in _EXAMPLE_ENTITIES
        ],
    )
]
|
||||
|
||||
|
||||
def create_model() -> OpenAILanguageModel:
    """Build the LangExtract language model that talks to DeepSeek.

    Returns:
        An ``OpenAILanguageModel`` configured with the module-level
        ``MODEL_ID``, ``DEEPSEEK_API_KEY`` and ``DEEPSEEK_BASE_URL``.

    Raises:
        ValueError: If ``DEEPSEEK_API_KEY`` is empty.
    """
    if DEEPSEEK_API_KEY:
        return OpenAILanguageModel(
            model_id=MODEL_ID,
            api_key=DEEPSEEK_API_KEY,
            base_url=DEEPSEEK_BASE_URL,
        )
    raise ValueError("DEEPSEEK_API_KEY not set in backend/.env")
|
||||
|
||||
|
||||
def extract_entities(page_text: str, model: OpenAILanguageModel) -> lx.data.AnnotatedDocument:
    """Run LangExtract entity extraction over the text of one page.

    Args:
        page_text: Plain text to extract entities from.
        model: Language model instance to use (see ``create_model``).

    Returns:
        Whatever ``lx.extract`` returns for the given text, prompt and examples.
    """
    request = {
        "text_or_documents": page_text,
        "prompt_description": PROMPT_DESCRIPTION,
        "examples": EXAMPLES,
        "model": model,
        # Progress output is switched off for this call.
        "show_progress": False,
    }
    return lx.extract(**request)
|
||||
123
backend/pipeline/kg_builder.py
Normal file
123
backend/pipeline/kg_builder.py
Normal file
@@ -0,0 +1,123 @@
|
||||
"""
|
||||
KG Builder — node deduplication + CO_OCCURS_IN edge generation.
|
||||
Independent implementation for the GraphRAG Studio backend.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import defaultdict
|
||||
|
||||
import langextract as lx
|
||||
|
||||
from pipeline.text_assembler import PageText
|
||||
|
||||
# Alignment statuses an extraction must have to be turned into a graph node.
ACCEPTED_ALIGNMENTS = {"match_exact", "match_greater", "match_lesser"}


def _id_fragment(raw: str, limit: int) -> str:
    """Return a lower-cased, identifier-safe prefix of ``raw`` for node ids.

    Only letters, digits, ``_``, ``-`` and ``.`` are kept, so ids never contain
    whitespace, commas or brackets that would break ``id=<node id>`` parsing of
    tool output downstream.
    """
    cleaned = "".join(ch for ch in raw.lower() if ch.isalnum() or ch in "_-.")
    return cleaned[:limit]


def build_kg(
    pages: list[PageText],
    annotated_docs: list[lx.data.AnnotatedDocument],
    source_doc_id: str,
) -> tuple[list[dict], list[dict]]:
    """Build KG nodes and edges from LangExtract results.

    Args:
        pages: Page objects; only ``page_idx`` is read here.
        annotated_docs: Extraction results, paired positionally with ``pages``
            (``zip`` semantics: surplus items in the longer list are ignored).
        source_doc_id: Document identifier stamped on every node and edge.

    Returns:
        (nodes, edges) — deduplicated node list and edge list. Nodes are
        deduplicated on (lower-cased name, type) and keep the attributes of the
        first occurrence. One ``CO_OCCURS_IN`` edge is emitted per unordered
        node pair per page on which both nodes appear.
    """
    # Phase 1: collect raw entities whose alignment status is accepted.
    raw_entities: list[dict] = []
    for page, doc in zip(pages, annotated_docs):
        for ext in doc.extractions or ():
            status = ext.alignment_status.value if ext.alignment_status else None
            if status not in ACCEPTED_ALIGNMENTS:
                continue
            interval = ext.char_interval
            raw_entities.append({
                "name": ext.extraction_text,
                "type": ext.extraction_class,
                "char_start": interval.start_pos if interval else None,
                "char_end": interval.end_pos if interval else None,
                "confidence": status,
                "page": page.page_idx,
                "source_doc": source_doc_id,
            })

    # Phase 2: deduplicate nodes and remember which pages each node occurs on.
    seen: dict[tuple[str, str], int] = {}
    nodes: list[dict] = []
    page_members: dict[int, set[int]] = defaultdict(set)

    for entity in raw_entities:
        dedup_key = (entity["name"].lower(), entity["type"])
        node_idx = seen.get(dedup_key)
        if node_idx is None:
            node_idx = len(nodes)
            seen[dedup_key] = node_idx
            type_prefix = _id_fragment(entity["type"], 4)
            name_slug = _id_fragment(entity["name"], 12)
            nodes.append({
                # The running index suffix keeps ids unique even when slugs collide.
                "id": f"{type_prefix}_{name_slug}_{node_idx}",
                "name": entity["name"],
                "type": entity["type"],
                "source_doc": entity["source_doc"],
                "char_start": entity["char_start"],
                "char_end": entity["char_end"],
                "confidence": entity["confidence"],
                "page": entity["page"],
            })
        page_members[entity["page"]].add(node_idx)

    # Phase 3: CO_OCCURS_IN edges. Every node index appears at most once per
    # page, so each unordered pair is produced exactly once per page and no
    # extra de-duplication set is required.
    edges: list[dict] = []
    for page_idx in sorted(page_members):
        members = sorted(page_members[page_idx])
        for pos, first_idx in enumerate(members):
            for second_idx in members[pos + 1:]:
                a = nodes[first_idx]["id"]
                b = nodes[second_idx]["id"]
                # Canonical orientation: lexicographically smaller id is the source.
                src, tgt = (a, b) if a < b else (b, a)
                edges.append({
                    "source": src,
                    "target": tgt,
                    "relation": "CO_OCCURS_IN",
                    "doc_id": source_doc_id,
                    "page": page_idx,
                })

    return nodes, edges
|
||||
|
||||
|
||||
def extractions_to_records(
    pages: list[PageText],
    annotated_docs: list[lx.data.AnnotatedDocument],
    doc_id: str,
) -> list[dict]:
    """Flatten LangExtract results into a list of ExtractionRecord dicts.

    Unlike ``build_kg``, no alignment filtering happens here: every extraction
    is emitted, with ``alignment`` set to ``None`` when no status is present.

    Args:
        pages: Page objects; only ``page_idx`` is read.
        annotated_docs: Extraction results paired positionally with ``pages``.
        doc_id: Document identifier copied into every record.

    Returns:
        One dict per extraction with the keys ``text``, ``type``,
        ``char_start``, ``char_end``, ``alignment``, ``page`` and ``doc_id``.
    """
    flattened: list[dict] = []
    for page_text, annotated in zip(pages, annotated_docs):
        # Documents with no (or empty) extractions contribute nothing.
        for item in annotated.extractions or ():
            alignment = item.alignment_status.value if item.alignment_status else None
            span = item.char_interval
            flattened.append({
                "text": item.extraction_text,
                "type": item.extraction_class,
                "char_start": span.start_pos if span else None,
                "char_end": span.end_pos if span else None,
                "alignment": alignment,
                "page": page_text.page_idx,
                "doc_id": doc_id,
            })
    return flattened
|
||||
217
backend/pipeline/qa_agent.py
Normal file
217
backend/pipeline/qa_agent.py
Normal file
@@ -0,0 +1,217 @@
|
||||
"""
|
||||
QA Agent — LangGraph ReAct agent over the knowledge graph.
|
||||
Independent implementation for the GraphRAG Studio backend.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
import networkx as nx
|
||||
from dotenv import load_dotenv
|
||||
from langchain.tools import tool
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain_core.messages import HumanMessage, AIMessage, ToolMessage, SystemMessage
|
||||
from langgraph.prebuilt import create_react_agent
|
||||
|
||||
# Read settings from the ".env" file two directories above this module
# (backend/.env). override=True lets entries in that file replace variables
# already present in the process environment.
load_dotenv(Path(__file__).parents[1] / ".env", override=True)

# Credentials and endpoint used by run_qa() to reach DeepSeek.
DEEPSEEK_API_KEY = os.environ.get("DEEPSEEK_API_KEY", "")
DEEPSEEK_BASE_URL = os.environ.get("DEEPSEEK_BASE_URL", "https://api.deepseek.com")
|
||||
|
||||
|
||||
def build_kg_graph(nodes: list[dict], edges: list[dict]) -> nx.Graph:
    """Materialise node and edge dicts as an undirected NetworkX graph.

    Args:
        nodes: Node dicts; each needs an ``id`` key, and every key of the dict
            (``id`` included) is stored as a node attribute.
        edges: Edge dicts; ``source`` and ``target`` name the endpoints and all
            remaining keys are stored as edge attributes.

    Returns:
        The populated ``nx.Graph``.
    """
    graph = nx.Graph()

    for node in nodes:
        graph.add_node(node["id"], **node)

    for edge in edges:
        attributes = dict(edge)
        head = attributes.pop("source")
        tail = attributes.pop("target")
        graph.add_edge(head, tail, **attributes)

    return graph
|
||||
|
||||
|
||||
def make_tools(G: nx.Graph) -> list:
    """Create the query tools the QA agent uses against graph ``G``.

    Each tool is a closure over ``G`` and returns a plain-text report.

    NOTE(review): the inner docstrings are presumably what LangChain's
    ``@tool`` exposes to the model as tool descriptions, so treat them as
    runtime text rather than ordinary comments — confirm before rewording.

    Args:
        G: Knowledge graph whose nodes carry ``id``, ``name``, ``type`` and
            optionally ``confidence`` / ``page`` attributes.

    Returns:
        ``[search_entities, get_neighbors, get_entities_by_type, describe_graph]``.
    """

    @tool
    def search_entities(query: str) -> str:
        """Search knowledge graph entities by name (case-insensitive substring).
        Args:
            query: Keyword to search for in entity names.
        """
        q = query.lower()
        matches = [data for _, data in G.nodes(data=True) if q in data.get("name", "").lower()]
        if not matches:
            # Offer a few existing names so the caller can retry with a real one.
            sample = ", ".join(d.get("name", "") for _, d in list(G.nodes(data=True))[:8])
            return f"No entities found matching '{query}'. Sample: {sample}"
        lines = [f"Found {len(matches)} entity(ies) matching '{query}':"]
        for m in matches[:15]:
            lines.append(
                f" [{m['type']}] \"{m['name']}\" "
                f"(confidence={m.get('confidence','?')}, page={m.get('page',0)}, id={m['id']})"
            )
        return "\n".join(lines)

    @tool
    def get_neighbors(entity_name: str, hops: int = 1) -> str:
        """Get N-hop neighbors of an entity in the knowledge graph.
        Args:
            entity_name: Entity name (partial match).
            hops: Number of hops (1-3, default 1).
        """
        hops = max(1, min(int(hops), 3))
        needle = entity_name.lower()
        candidates = [(nid, d) for nid, d in G.nodes(data=True)
                      if needle in d.get("name", "").lower()]
        if not candidates:
            return f"Entity '{entity_name}' not found. Use search_entities first."
        # Prefer an exact (case-insensitive) name match; otherwise fall back to
        # the first partial match. Without this, asking for "Python" could
        # resolve to "CPython" merely because it was inserted earlier.
        node_id, node_data = next(
            ((nid, d) for nid, d in candidates if d.get("name", "").lower() == needle),
            candidates[0],
        )
        reachable = nx.single_source_shortest_path_length(G, node_id, cutoff=hops)
        by_hop: dict[int, list] = {}
        for nid, dist in reachable.items():
            if dist > 0:  # distance 0 is the start node itself
                by_hop.setdefault(dist, []).append(G.nodes[nid])
        lines = [f"Neighbors of '{node_data['name']}' [{node_data['type']}] within {hops} hop(s):"]
        for hop in sorted(by_hop):
            hop_nodes = by_hop[hop]
            lines.append(f"\n Hop {hop} — {len(hop_nodes)} related entities:")
            for n in hop_nodes[:20]:
                lines.append(f" [{n.get('type','?')}] {n.get('name','?')}")
            if len(hop_nodes) > 20:
                lines.append(f" ... and {len(hop_nodes)-20} more")
        lines.append(f"\n Total related entities: {sum(len(v) for v in by_hop.values())}")
        return "\n".join(lines)

    @tool
    def get_entities_by_type(entity_type: str) -> str:
        """List all entities of a specific type.
        Args:
            entity_type: TECHNOLOGY, CONCEPT, PERSON, ORGANIZATION, or LOCATION.
        """
        t_upper = entity_type.strip().upper()
        valid = {"TECHNOLOGY", "CONCEPT", "PERSON", "ORGANIZATION", "LOCATION"}
        present = sorted({d.get("type", "") for _, d in G.nodes(data=True)})
        # Also accept any type that actually occurs in the graph, so the tool
        # never answers "Unknown type" for a type it lists as present.
        if t_upper not in valid and t_upper not in {p.upper() for p in present}:
            return f"Unknown type '{entity_type}'. Present: {present}"
        matches = [d for _, d in G.nodes(data=True) if d.get("type", "").upper() == t_upper]
        if not matches:
            return f"No {t_upper} entities found."
        lines = [f"Found {len(matches)} {t_upper} entities:"]
        for m in matches[:30]:
            lines.append(f" \"{m['name']}\" (page={m.get('page',0)}, id={m['id']})")
        if len(matches) > 30:
            lines.append(f" ... and {len(matches)-30} more")
        return "\n".join(lines)

    @tool
    def describe_graph() -> str:
        """Get an overview of the knowledge graph statistics."""
        n_nodes = G.number_of_nodes()
        n_edges = G.number_of_edges()
        type_counts: dict[str, int] = {}
        for _, d in G.nodes(data=True):
            t = d.get("type", "UNKNOWN")
            type_counts[t] = type_counts.get(t, 0) + 1
        lines = [
            "Knowledge Graph Overview:",
            f" Nodes: {n_nodes}",
            f" Edges: {n_edges}",
            f" Entity types: {type_counts}",
        ]
        if n_nodes > 0:
            centrality = nx.degree_centrality(G)
            top5 = sorted(centrality.items(), key=lambda x: x[1], reverse=True)[:5]
            lines.append(" Top 5 central nodes:")
            for nid, c in top5:
                nd = G.nodes[nid]
                lines.append(f" [{nd.get('type','?')}] {nd.get('name','?')} (centrality={c:.3f})")
        return "\n".join(lines)

    return [search_entities, get_neighbors, get_entities_by_type, describe_graph]
|
||||
|
||||
|
||||
# Captures the token following "id=" in tool output (the tools print "id=<node id>").
_NODE_ID_RE = re.compile(r'\bid=([^\s,\)\]]+)')


def _history_to_messages(history: list[dict]) -> list:
    """Convert the most recent stored chat turns (at most 8) into messages."""
    messages: list = []
    for msg in (history or [])[-8:]:
        role = msg.get("role", "human")
        content = msg.get("content", "") or msg.get("answer", "")
        # "user" is accepted as an alias of "human"; previously such turns fell
        # into the else-branch and were replayed as assistant messages.
        if role in ("human", "user"):
            messages.append(HumanMessage(content=msg.get("question", content)))
        else:
            messages.append(AIMessage(content=content))
    return messages


def _collect_tool_trace(all_messages: list, known_ids: set[str]) -> tuple[list[dict], list[str]]:
    """Return the ordered tool-call log and the sorted node ids cited in tool output."""
    # Index tool outputs by call id once instead of rescanning for every call;
    # setdefault keeps the first output seen for an id.
    outputs: dict = {}
    for msg in all_messages:
        if isinstance(msg, ToolMessage):
            outputs.setdefault(msg.tool_call_id, msg.content)

    tool_calls: list[dict] = []
    cited: set[str] = set()
    step = 0
    for msg in all_messages:
        if not (isinstance(msg, AIMessage) and msg.tool_calls):
            continue
        for tc in msg.tool_calls:
            step += 1
            output = str(outputs.get(tc.get("id"), ""))
            tool_calls.append({
                "step": step,
                "tool_name": tc.get("name", ""),
                "tool_input": str(tc.get("args", {})),
                "tool_output": output,
            })
            # Only keep tokens that are real node ids of this graph.
            cited.update(t for t in _NODE_ID_RE.findall(output) if t in known_ids)
    # Sorted so the response is deterministic across runs.
    return tool_calls, sorted(cited)


def run_qa(
    question: str,
    history: list[dict],
    nodes: list[dict],
    edges: list[dict],
) -> dict:
    """Run Agentic-RAG QA. Returns dict with answer, tool_calls, cited_nodes.

    Args:
        question: The user's current question.
        history: Earlier turns as dicts; ``role`` (``"human"``/``"user"`` for
            user turns, anything else is treated as the assistant) plus
            ``content`` / ``answer`` / ``question`` text. Only the last 8 are used.
        nodes: Knowledge-graph node dicts (see ``build_kg_graph``).
        edges: Knowledge-graph edge dicts (see ``build_kg_graph``).

    Returns:
        ``{"answer": str, "tool_calls": list[dict], "cited_nodes": list[str]}``.
        ``answer`` is the content of the last AI message that has content and
        no tool calls (empty string if none); ``cited_nodes`` is sorted.

    Raises:
        ValueError: If ``DEEPSEEK_API_KEY`` is empty.
    """
    if not DEEPSEEK_API_KEY:
        raise ValueError("DEEPSEEK_API_KEY not set in backend/.env")

    G = build_kg_graph(nodes, edges)
    tools = make_tools(G)

    llm = ChatOpenAI(
        model="deepseek-chat",
        api_key=DEEPSEEK_API_KEY,
        base_url=DEEPSEEK_BASE_URL,
        temperature=0,
    )

    system_prompt = (
        "You are a helpful assistant with access to a knowledge graph (KG) built from the user's documents.\n"
        "\n"
        "Guidelines:\n"
        "- If the question is clearly unrelated to the KG (greetings, math, general knowledge, etc.), "
        "answer directly WITHOUT using any tools.\n"
        "- If the question might be answered by the KG (topics related to entities in the documents), "
        "use the tools to search and explore before answering.\n"
        "- When you DO use the KG, cite the entity names and types you found.\n"
        "- If the KG has no relevant information, say so honestly and answer from general knowledge if possible.\n"
        "\n"
        "Available tools: search entities by name, get neighbors, list entities by type, get graph overview."
    )

    agent = create_react_agent(llm, tools, prompt=system_prompt)

    # Conversation = recent history + current question; the system prompt is
    # supplied to the agent separately via ``prompt=``.
    messages = _history_to_messages(history)
    messages.append(HumanMessage(content=question))

    result = agent.invoke({"messages": messages})
    all_messages = result.get("messages", [])

    # The answer is the last AI message that carries text and requests no tools.
    answer = ""
    for msg in reversed(all_messages):
        if isinstance(msg, AIMessage) and msg.content and not msg.tool_calls:
            answer = msg.content
            break

    tool_calls, cited_nodes = _collect_tool_trace(all_messages, {str(nid) for nid in G.nodes})

    return {
        "answer": answer,
        "tool_calls": tool_calls,
        "cited_nodes": cited_nodes,
    }
|
||||
107
backend/pipeline/text_assembler.py
Normal file
107
backend/pipeline/text_assembler.py
Normal file
@@ -0,0 +1,107 @@
|
||||
"""
|
||||
Text Assembler — MinerU content_list.json → per-page plain text.
|
||||
Independent implementation for the GraphRAG Studio backend.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import json
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
|
||||
@dataclasses.dataclass
class BlockSpan:
    """Location of one source block inside the assembled text of its page."""

    # Position of the block in the original content list.
    block_index: int
    # Block kind as given by the block's "type" field.
    block_type: str
    # Page the block belongs to.
    page_idx: int
    # Offset of the block's first character within the page text.
    char_start: int
    # Offset one past the block's last character within the page text.
    char_end: int
    # Bounding box copied from the block; units/ordering are not defined here
    # — TODO confirm against the MinerU output format.
    bbox: list
|
||||
|
||||
|
||||
@dataclasses.dataclass
class PageText:
    """Plain text assembled for one page, plus the spans of its source blocks."""

    # Index of the page this text was assembled for.
    page_idx: int
    # Text of the page's blocks, separated by newlines.
    text: str
    # One entry per block that contributed text, in assembly order.
    block_spans: list[BlockSpan]
|
||||
|
||||
|
||||
def html_table_to_text(table_body: str) -> str:
    """Render an HTML table as plain text, one line per ``<tr>``.

    Args:
        table_body: HTML markup containing the table rows.

    Returns:
        The rows joined by newlines; within a row the whitespace-stripped text
        of each ``<td>``/``<th>`` cell is joined by ``" | "``. Empty string
        when the markup contains no ``<tr>``.
    """
    parsed = BeautifulSoup(table_body, "html.parser")
    return "\n".join(
        " | ".join(cell.get_text(strip=True) for cell in row.find_all(["td", "th"]))
        for row in parsed.find_all("tr")
    )
|
||||
|
||||
|
||||
def load_content_list(path: Path) -> list[dict]:
    """Load a MinerU ``content_list.json`` from a file or a directory.

    Args:
        path: Either the JSON file itself or a directory containing it. A
            plain ``str`` is accepted as well.

    Returns:
        The parsed JSON content.

    Raises:
        FileNotFoundError: If ``path`` is a directory with no matching file.
    """
    path = Path(path)
    if path.is_dir():
        # glob() yields entries in arbitrary order, so sort to make the choice
        # deterministic when several files match. Files named
        # "<something>_content_list.json" take precedence over the looser
        # "*content_list.json" fallback.
        matches = sorted(path.glob("*_content_list.json")) or sorted(path.glob("*content_list.json"))
        if not matches:
            raise FileNotFoundError(f"No content_list.json found in {path}")
        path = matches[0]
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)
|
||||
|
||||
|
||||
def assemble_pages(content_list: list[dict]) -> list[PageText]:
    """Group content blocks by page and assemble each page's plain text.

    Only ``"text"`` and ``"table"`` blocks contribute; tables are flattened
    with ``html_table_to_text``. Blocks that yield no text are skipped. The
    page text is the contributing block texts joined by a single newline, and
    every ``BlockSpan`` satisfies
    ``page.text[span.char_start:span.char_end] == block text``.

    Args:
        content_list: Block dicts; ``page_idx`` (default 0), ``type``,
            ``text`` / ``table_body`` and ``bbox`` are read.

    Returns:
        One ``PageText`` per distinct ``page_idx``, in ascending page order.
        A page whose blocks are all skipped still yields an entry with empty
        text and no spans.
    """
    by_page: dict[int, list[tuple[int, dict]]] = defaultdict(list)
    for block_index, block in enumerate(content_list):
        by_page[block.get("page_idx", 0)].append((block_index, block))

    result = []
    for page_idx in sorted(by_page):
        parts: list[str] = []
        spans: list[BlockSpan] = []
        cursor = 0

        for block_index, block in by_page[page_idx]:
            block_type = block.get("type", "unknown")

            if block_type == "text":
                # "or ''" also covers an explicit null value for the text field.
                block_text = (block.get("text") or "").rstrip()
            elif block_type == "table":
                table_body = block.get("table_body") or ""
                block_text = html_table_to_text(table_body) if table_body else ""
            else:
                continue

            if not block_text:
                continue

            char_end = cursor + len(block_text)
            spans.append(BlockSpan(
                block_index=block_index,
                block_type=block_type,
                page_idx=page_idx,
                char_start=cursor,
                char_end=char_end,
                bbox=block.get("bbox", [0, 0, 0, 0]),
            ))
            parts.append(block_text)
            cursor = char_end + 1  # +1 for the newline separator

        # Joining (rather than appending separators and stripping trailing
        # newlines afterwards) never removes newlines that belong to the last
        # block itself, so the final span cannot point past the end of the text.
        result.append(PageText(page_idx=page_idx, text="\n".join(parts), block_spans=spans))

    return result
|
||||
|
||||
|
||||
def count_blocks_by_type(content_list: list[dict]) -> dict[str, int]:
    """Tally content blocks by their ``type`` field.

    Args:
        content_list: Block dicts; a block without ``type`` counts as
            ``"unknown"``.

    Returns:
        A plain dict mapping each type to its number of blocks, in order of
        first appearance.
    """
    tally: dict[str, int] = {}
    for entry in content_list:
        kind = entry.get("type", "unknown")
        tally[kind] = tally.get(kind, 0) + 1
    return tally
|
||||
Reference in New Issue
Block a user