GraphRAG Studio — initial commit: multimodal RAG system with KG visualization
Full-stack application for document-to-knowledge-graph pipeline: - Backend: FastAPI + LangGraph ReAct agent + DeepSeek + MinerU parsing - Frontend: React 19 + Vite + D3.js + shadcn/ui - Pipeline: MinerU parsing → LangExtract entity extraction → KG building Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
0
backend/services/__init__.py
Normal file
0
backend/services/__init__.py
Normal file
109
backend/services/document_service.py
Normal file
109
backend/services/document_service.py
Normal file
@@ -0,0 +1,109 @@
|
||||
"""Document Service — file upload, metadata CRUD."""
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
|
||||
from storage import file_store as fs
|
||||
|
||||
ALLOWED_EXTENSIONS = {"pdf", "docx", "doc", "pptx", "ppt", "png", "jpg", "jpeg", "html"}
|
||||
MAX_FILE_SIZE_MB = 200
|
||||
|
||||
|
||||
def validate_upload(filename: str, size_bytes: int) -> tuple[bool, int, str]:
|
||||
"""Returns (ok, error_code, error_msg)."""
|
||||
if not filename or "/" in filename or "\\" in filename:
|
||||
return False, 1001, "Invalid filename"
|
||||
ext = Path(filename).suffix.lower().lstrip(".")
|
||||
if ext not in ALLOWED_EXTENSIONS:
|
||||
return False, 1002, f"Unsupported file format: .{ext}. Supported: {', '.join(sorted(ALLOWED_EXTENSIONS))}"
|
||||
size_mb = size_bytes / (1024 * 1024)
|
||||
if size_mb > MAX_FILE_SIZE_MB:
|
||||
return False, 1003, f"File size {size_mb:.1f}MB exceeds {MAX_FILE_SIZE_MB}MB limit"
|
||||
return True, 0, ""
|
||||
|
||||
|
||||
def save_upload(filename: str, content: bytes, language: str = "ch",
|
||||
enable_formula: bool = True, enable_table: bool = True) -> dict:
|
||||
doc_id = uuid.uuid4().hex[:8]
|
||||
ext = Path(filename).suffix.lower().lstrip(".")
|
||||
upload_filename = f"{doc_id}_{filename}"
|
||||
upload_path = fs.UPLOADS_DIR / upload_filename
|
||||
upload_path.write_bytes(content)
|
||||
|
||||
doc = {
|
||||
"doc_id": doc_id,
|
||||
"filename": filename,
|
||||
"format": ext,
|
||||
"size_bytes": len(content),
|
||||
"pages": None,
|
||||
"uploaded_at": datetime.now(timezone.utc).isoformat(),
|
||||
"status": "uploaded",
|
||||
"language": language,
|
||||
"enable_formula": enable_formula,
|
||||
"enable_table": enable_table,
|
||||
"upload_filename": upload_filename, # internal: actual stored filename
|
||||
}
|
||||
fs.save_doc(doc)
|
||||
return doc
|
||||
|
||||
|
||||
def get_document(doc_id: str) -> dict | None:
|
||||
return fs.get_doc(doc_id)
|
||||
|
||||
|
||||
def list_documents(page: int = 1, page_size: int = 20,
|
||||
status: str | None = None, fmt: str | None = None) -> dict:
|
||||
index = fs.load_docs_index()
|
||||
items = list(index.values())
|
||||
items.sort(key=lambda d: d.get("uploaded_at", ""), reverse=True)
|
||||
if status:
|
||||
items = [d for d in items if d.get("status") == status]
|
||||
if fmt:
|
||||
items = [d for d in items if d.get("format") == fmt.lower()]
|
||||
total = len(items)
|
||||
start = (page - 1) * page_size
|
||||
return {
|
||||
"total": total,
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
"items": items[start: start + page_size],
|
||||
}
|
||||
|
||||
|
||||
def delete_document(doc_id: str) -> tuple[bool, int, int]:
|
||||
"""Delete doc and its KG contributions. Returns (ok, removed_nodes, removed_edges)."""
|
||||
doc = fs.get_doc(doc_id)
|
||||
if not doc:
|
||||
return False, 0, 0
|
||||
|
||||
# Remove from KG
|
||||
removed_nodes, removed_edges = fs.remove_doc_from_kg(doc_id)
|
||||
|
||||
# Remove upload file
|
||||
upload_filename = doc.get("upload_filename", "")
|
||||
upload_path = fs.UPLOADS_DIR / upload_filename
|
||||
if upload_path.exists():
|
||||
upload_path.unlink(missing_ok=True)
|
||||
|
||||
# Remove associated jobs
|
||||
for meta in fs.list_all_jobs():
|
||||
if meta.get("doc_id") == doc_id:
|
||||
fs.delete_job(meta["job_id"])
|
||||
|
||||
# Remove from index
|
||||
index = fs.load_docs_index()
|
||||
index.pop(doc_id, None)
|
||||
fs.save_docs_index(index)
|
||||
|
||||
return True, removed_nodes, removed_edges
|
||||
|
||||
|
||||
def update_doc_status(doc_id: str, status: str, pages: int | None = None) -> None:
|
||||
index = fs.load_docs_index()
|
||||
if doc_id in index:
|
||||
index[doc_id]["status"] = status
|
||||
if pages is not None:
|
||||
index[doc_id]["pages"] = pages
|
||||
fs.save_docs_index(index)
|
||||
255
backend/services/indexing_service.py
Normal file
255
backend/services/indexing_service.py
Normal file
@@ -0,0 +1,255 @@
|
||||
"""Indexing Service — Pipeline orchestration (parsing → extracting → indexing)."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import subprocess
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from storage import file_store as fs
|
||||
from services.document_service import update_doc_status
|
||||
|
||||
load_dotenv(Path(__file__).parent.parent / ".env", override=True)
|
||||
|
||||
MINERU_PYTHON = Path(os.getenv("MINERU_PYTHON", "F:/GraphRAGAgent/mineru_mvp/.venv/Scripts/python.exe"))
|
||||
MINERU_PIPELINE = Path(os.getenv("MINERU_PIPELINE", "F:/GraphRAGAgent/mineru_mvp/pipeline.py"))
|
||||
|
||||
# In-memory registry of active jobs {job_id: threading.Thread}
|
||||
_active_threads: dict[str, threading.Thread] = {}
|
||||
_cancel_flags: dict[str, bool] = {}
|
||||
|
||||
|
||||
def start_indexing(doc_id: str) -> dict:
|
||||
doc = fs.get_doc(doc_id)
|
||||
if not doc:
|
||||
return None # type: ignore
|
||||
|
||||
job_id = f"job_{uuid.uuid4().hex[:8]}"
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
|
||||
meta = {
|
||||
"job_id": job_id,
|
||||
"doc_id": doc_id,
|
||||
"status": "submitted",
|
||||
"stage": "Job submitted",
|
||||
"progress": {"parsed_pages": 0, "total_pages": 0, "extracted_entities": 0},
|
||||
"created_at": now,
|
||||
"elapsed_seconds": 0.0,
|
||||
"error": None,
|
||||
"pdf_name": doc["filename"],
|
||||
"pdf_path": str(fs.UPLOADS_DIR / doc.get("upload_filename", "")),
|
||||
}
|
||||
fs.save_job_meta(job_id, meta)
|
||||
|
||||
_cancel_flags[job_id] = False
|
||||
thread = threading.Thread(target=_run_pipeline, args=(job_id,), daemon=True)
|
||||
_active_threads[job_id] = thread
|
||||
thread.start()
|
||||
|
||||
return meta
|
||||
|
||||
|
||||
def _update_meta(job_id: str, **kwargs) -> None:
|
||||
meta = fs.load_job_meta(job_id) or {}
|
||||
meta.update(kwargs)
|
||||
meta["elapsed_seconds"] = round(
|
||||
(datetime.now(timezone.utc) - datetime.fromisoformat(meta["created_at"])).total_seconds(), 1
|
||||
)
|
||||
fs.save_job_meta(job_id, meta)
|
||||
|
||||
|
||||
def _run_pipeline(job_id: str) -> None:
|
||||
meta = fs.load_job_meta(job_id)
|
||||
if not meta:
|
||||
return
|
||||
|
||||
doc_id = meta["doc_id"]
|
||||
pdf_path = Path(meta["pdf_path"])
|
||||
job_dir = fs.job_dir(job_id)
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
# ── Stage 1: parsing ──────────────────────────────────────────────
|
||||
if _cancel_flags.get(job_id):
|
||||
_update_meta(job_id, status="cancelled", stage="Cancelled")
|
||||
return
|
||||
|
||||
_update_meta(job_id, status="parsing", stage="MinerU document parsing...")
|
||||
mineru_out_dir = job_dir / "mineru_output"
|
||||
mineru_out_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
result = subprocess.run(
|
||||
[str(MINERU_PYTHON), str(MINERU_PIPELINE), str(pdf_path)],
|
||||
cwd=str(MINERU_PIPELINE.parent),
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=600,
|
||||
)
|
||||
|
||||
if result.returncode != 0:
|
||||
raise RuntimeError(f"MinerU failed: {result.stderr[:500]}")
|
||||
|
||||
# Find content_list.json in MinerU output
|
||||
# MinerU writes output to mineru_mvp/output/{stem}/
|
||||
stem = pdf_path.stem
|
||||
mineru_default_out = MINERU_PIPELINE.parent / "output" / stem
|
||||
content_list_path = None
|
||||
|
||||
if mineru_default_out.exists():
|
||||
matches = list(mineru_default_out.glob("*_content_list.json"))
|
||||
if matches:
|
||||
content_list_path = matches[0]
|
||||
# Copy to our job dir
|
||||
import shutil
|
||||
shutil.copytree(str(mineru_default_out), str(mineru_out_dir), dirs_exist_ok=True)
|
||||
|
||||
if not content_list_path:
|
||||
# Fallback: search job mineru_output dir
|
||||
matches = list(mineru_out_dir.glob("*_content_list.json"))
|
||||
if matches:
|
||||
content_list_path = matches[0]
|
||||
|
||||
if not content_list_path or not content_list_path.exists():
|
||||
raise RuntimeError(f"MinerU output content_list.json not found. stdout: {result.stdout[:300]}")
|
||||
|
||||
# ── Stage 2: extracting ───────────────────────────────────────────
|
||||
if _cancel_flags.get(job_id):
|
||||
_update_meta(job_id, status="cancelled", stage="Cancelled")
|
||||
return
|
||||
|
||||
from pipeline.text_assembler import load_content_list, assemble_pages, count_blocks_by_type
|
||||
from pipeline.entity_extractor import create_model, extract_entities
|
||||
from pipeline.kg_builder import build_kg, extractions_to_records
|
||||
|
||||
content_list = load_content_list(content_list_path)
|
||||
pages = assemble_pages(content_list)
|
||||
total_pages = len(pages)
|
||||
block_types = count_blocks_by_type(content_list)
|
||||
|
||||
_update_meta(
|
||||
job_id,
|
||||
status="extracting",
|
||||
stage=f"Extracting entities (LangExtract + DeepSeek)...",
|
||||
progress={"parsed_pages": total_pages, "total_pages": total_pages, "extracted_entities": 0},
|
||||
)
|
||||
update_doc_status(doc_id, "indexing", pages=total_pages)
|
||||
|
||||
model = create_model()
|
||||
annotated_docs = []
|
||||
total_entities = 0
|
||||
|
||||
for i, page in enumerate(pages):
|
||||
if _cancel_flags.get(job_id):
|
||||
_update_meta(job_id, status="cancelled", stage="Cancelled")
|
||||
return
|
||||
|
||||
_update_meta(
|
||||
job_id,
|
||||
stage=f"Extracting entities page {i+1}/{total_pages} (LangExtract + DeepSeek)...",
|
||||
progress={"parsed_pages": total_pages, "total_pages": total_pages,
|
||||
"extracted_entities": total_entities},
|
||||
)
|
||||
ann_doc = extract_entities(page.text, model)
|
||||
annotated_docs.append(ann_doc)
|
||||
total_entities += len(ann_doc.extractions) if ann_doc.extractions else 0
|
||||
|
||||
# Save raw extractions
|
||||
records = extractions_to_records(pages, annotated_docs, doc_id)
|
||||
fs.write_json(job_dir / "extractions.json", records)
|
||||
|
||||
# ── Stage 3: indexing ─────────────────────────────────────────────
|
||||
_update_meta(job_id, status="indexing", stage="Building knowledge graph...")
|
||||
|
||||
nodes, edges = build_kg(pages, annotated_docs, doc_id)
|
||||
fs.write_json(job_dir / "kg_nodes.json", nodes)
|
||||
fs.write_json(job_dir / "kg_edges.json", edges)
|
||||
|
||||
# Merge into global KG
|
||||
fs.merge_kg(nodes, edges, doc_id)
|
||||
|
||||
# Count alignment types
|
||||
alignment_counts: dict[str, int] = {}
|
||||
type_counts: dict[str, int] = {}
|
||||
for r in records:
|
||||
al = r.get("alignment") or "null"
|
||||
alignment_counts[al] = alignment_counts.get(al, 0) + 1
|
||||
t = r.get("type", "UNKNOWN")
|
||||
type_counts[t] = type_counts.get(t, 0) + 1
|
||||
|
||||
elapsed = round(time.time() - start_time, 1)
|
||||
stats = {
|
||||
"blocks": len(content_list),
|
||||
"block_types": block_types,
|
||||
"pages": total_pages,
|
||||
"raw_extractions": len(records),
|
||||
"nodes": len(nodes),
|
||||
"edges": len(edges),
|
||||
"type_counts": type_counts,
|
||||
"alignment_counts": alignment_counts,
|
||||
"elapsed_seconds": elapsed,
|
||||
}
|
||||
fs.write_json(job_dir / "stats.json", stats)
|
||||
|
||||
_update_meta(
|
||||
job_id,
|
||||
status="done",
|
||||
stage="Complete",
|
||||
progress={"parsed_pages": total_pages, "total_pages": total_pages,
|
||||
"extracted_entities": len(records)},
|
||||
)
|
||||
update_doc_status(doc_id, "indexed", pages=total_pages)
|
||||
|
||||
except Exception as exc:
|
||||
_update_meta(job_id, status="failed", stage=f"Error: {exc}", error=str(exc))
|
||||
update_doc_status(doc_id, "failed")
|
||||
finally:
|
||||
_active_threads.pop(job_id, None)
|
||||
_cancel_flags.pop(job_id, None)
|
||||
|
||||
|
||||
def get_job_status(job_id: str) -> dict | None:
|
||||
return fs.load_job_meta(job_id)
|
||||
|
||||
|
||||
def get_job_result(job_id: str) -> dict | None:
|
||||
meta = fs.load_job_meta(job_id)
|
||||
if not meta:
|
||||
return None
|
||||
if meta["status"] != "done":
|
||||
return meta
|
||||
|
||||
job_dir = fs.job_dir(job_id)
|
||||
stats = fs.read_json(job_dir / "stats.json") or {}
|
||||
extractions = fs.read_json(job_dir / "extractions.json") or []
|
||||
nodes = fs.read_json(job_dir / "kg_nodes.json") or []
|
||||
edges = fs.read_json(job_dir / "kg_edges.json") or []
|
||||
|
||||
return {
|
||||
"job_id": meta["job_id"],
|
||||
"doc_id": meta["doc_id"],
|
||||
"status": "done",
|
||||
"stats": stats,
|
||||
"extractions": extractions,
|
||||
"nodes": nodes,
|
||||
"edges": edges,
|
||||
}
|
||||
|
||||
|
||||
def cancel_job(job_id: str) -> tuple[bool, str]:
|
||||
meta = fs.load_job_meta(job_id)
|
||||
if not meta:
|
||||
return False, "not_found"
|
||||
prev_status = meta["status"]
|
||||
_cancel_flags[job_id] = True
|
||||
_update_meta(job_id, status="cancelled", stage="Cancelled by user")
|
||||
return True, prev_status
|
||||
|
||||
|
||||
def count_active_jobs() -> int:
|
||||
return sum(1 for t in _active_threads.values() if t.is_alive())
|
||||
167
backend/services/kg_service.py
Normal file
167
backend/services/kg_service.py
Normal file
@@ -0,0 +1,167 @@
|
||||
"""KG Service — NetworkX graph operations over the global KG."""
|
||||
from __future__ import annotations
|
||||
|
||||
import networkx as nx
|
||||
|
||||
from storage import file_store as fs
|
||||
|
||||
|
||||
def _load_graph() -> nx.Graph:
|
||||
nodes = fs.load_kg_nodes()
|
||||
edges = fs.load_kg_edges()
|
||||
G = nx.Graph()
|
||||
for n in nodes:
|
||||
G.add_node(n["id"], **n)
|
||||
for e in edges:
|
||||
G.add_edge(e["source"], e["target"],
|
||||
relation=e.get("relation", "CO_OCCURS_IN"),
|
||||
doc_id=e.get("doc_id", ""),
|
||||
page=e.get("page", 0))
|
||||
return G
|
||||
|
||||
|
||||
def get_nodes(page: int = 1, page_size: int = 50,
|
||||
node_type: str | None = None,
|
||||
doc_id: str | None = None,
|
||||
confidence: str | None = None) -> dict:
|
||||
nodes = fs.load_kg_nodes()
|
||||
G = _load_graph()
|
||||
# Attach degree
|
||||
degrees = dict(G.degree())
|
||||
for n in nodes:
|
||||
n["degree"] = degrees.get(n["id"], 0)
|
||||
|
||||
if node_type:
|
||||
nodes = [n for n in nodes if n.get("type", "").upper() == node_type.upper()]
|
||||
if doc_id:
|
||||
nodes = [n for n in nodes if n.get("source_doc") == doc_id]
|
||||
if confidence:
|
||||
nodes = [n for n in nodes if n.get("confidence") == confidence]
|
||||
|
||||
total = len(nodes)
|
||||
start = (page - 1) * page_size
|
||||
return {"total": total, "page": page, "page_size": page_size,
|
||||
"items": nodes[start: start + page_size]}
|
||||
|
||||
|
||||
def get_edges(page: int = 1, page_size: int = 100,
|
||||
doc_id: str | None = None,
|
||||
relation: str | None = None) -> dict:
|
||||
edges = fs.load_kg_edges()
|
||||
if doc_id:
|
||||
edges = [e for e in edges if e.get("doc_id") == doc_id]
|
||||
if relation:
|
||||
edges = [e for e in edges if e.get("relation") == relation]
|
||||
total = len(edges)
|
||||
start = (page - 1) * page_size
|
||||
return {"total": total, "page": page, "page_size": page_size,
|
||||
"items": edges[start: start + page_size]}
|
||||
|
||||
|
||||
def get_node_detail(node_id: str) -> dict | None:
|
||||
nodes = fs.load_kg_nodes()
|
||||
node = next((n for n in nodes if n["id"] == node_id), None)
|
||||
if not node:
|
||||
return None
|
||||
G = _load_graph()
|
||||
if node_id not in G:
|
||||
node["degree"] = 0
|
||||
node["degree_centrality"] = 0.0
|
||||
node["neighbor_count"] = 0
|
||||
return node
|
||||
deg = G.degree(node_id)
|
||||
centrality = nx.degree_centrality(G)
|
||||
node["degree"] = deg
|
||||
node["degree_centrality"] = round(centrality.get(node_id, 0.0), 4)
|
||||
node["neighbor_count"] = deg
|
||||
return node
|
||||
|
||||
|
||||
def get_neighbors(node_id: str, hops: int = 1) -> dict | None:
|
||||
nodes = fs.load_kg_nodes()
|
||||
node = next((n for n in nodes if n["id"] == node_id), None)
|
||||
if not node:
|
||||
return None
|
||||
G = _load_graph()
|
||||
if node_id not in G:
|
||||
return {
|
||||
"center": {"id": node_id, "name": node["name"], "type": node["type"], "page": node.get("page", 0)},
|
||||
"hops": hops, "neighbors_by_hop": {}, "total_neighbors": 0,
|
||||
}
|
||||
hops = max(1, min(hops, 3))
|
||||
reachable = nx.single_source_shortest_path_length(G, node_id, cutoff=hops)
|
||||
by_hop: dict[str, list] = {}
|
||||
for nid, dist in reachable.items():
|
||||
if dist == 0:
|
||||
continue
|
||||
nd = G.nodes[nid]
|
||||
by_hop.setdefault(str(dist), []).append({
|
||||
"id": nid, "name": nd.get("name", ""), "type": nd.get("type", ""), "page": nd.get("page", 0)
|
||||
})
|
||||
total = sum(len(v) for v in by_hop.values())
|
||||
return {
|
||||
"center": {"id": node_id, "name": node["name"], "type": node["type"], "page": node.get("page", 0)},
|
||||
"hops": hops,
|
||||
"neighbors_by_hop": by_hop,
|
||||
"total_neighbors": total,
|
||||
}
|
||||
|
||||
|
||||
def get_stats() -> dict:
|
||||
nodes = fs.load_kg_nodes()
|
||||
edges = fs.load_kg_edges()
|
||||
G = _load_graph()
|
||||
|
||||
type_dist: dict[str, int] = {}
|
||||
for n in nodes:
|
||||
t = n.get("type", "UNKNOWN")
|
||||
type_dist[t] = type_dist.get(t, 0) + 1
|
||||
|
||||
relation_types: dict[str, int] = {}
|
||||
for e in edges:
|
||||
r = e.get("relation", "CO_OCCURS_IN")
|
||||
relation_types[r] = relation_types.get(r, 0) + 1
|
||||
|
||||
density = round(nx.density(G), 4) if G.number_of_nodes() > 1 else 0.0
|
||||
|
||||
top5: list[dict] = []
|
||||
if G.number_of_nodes() > 0:
|
||||
centrality = nx.degree_centrality(G)
|
||||
for nid, c in sorted(centrality.items(), key=lambda x: x[1], reverse=True)[:5]:
|
||||
nd = G.nodes[nid]
|
||||
top5.append({"node_id": nid, "name": nd.get("name", ""), "type": nd.get("type", ""),
|
||||
"centrality": round(c, 4)})
|
||||
|
||||
source_docs = list({n.get("source_doc", "") for n in nodes if n.get("source_doc")})
|
||||
|
||||
return {
|
||||
"total_nodes": len(nodes),
|
||||
"total_edges": len(edges),
|
||||
"density": density,
|
||||
"type_distribution": type_dist,
|
||||
"relation_types": relation_types,
|
||||
"top5_central_nodes": top5,
|
||||
"source_documents": source_docs,
|
||||
}
|
||||
|
||||
|
||||
def export_kg(doc_id: str | None = None) -> dict:
|
||||
from datetime import datetime, timezone
|
||||
nodes = fs.load_kg_nodes()
|
||||
edges = fs.load_kg_edges()
|
||||
G = _load_graph()
|
||||
degrees = dict(G.degree())
|
||||
for n in nodes:
|
||||
n["degree"] = degrees.get(n["id"], 0)
|
||||
if doc_id:
|
||||
nodes = [n for n in nodes if n.get("source_doc") == doc_id]
|
||||
edges = [e for e in edges if e.get("doc_id") == doc_id]
|
||||
return {
|
||||
"format": "json",
|
||||
"doc_id": doc_id,
|
||||
"total_nodes": len(nodes),
|
||||
"total_edges": len(edges),
|
||||
"exported_at": datetime.now(timezone.utc).isoformat(),
|
||||
"nodes": nodes,
|
||||
"edges": edges,
|
||||
}
|
||||
85
backend/services/qa_service.py
Normal file
85
backend/services/qa_service.py
Normal file
@@ -0,0 +1,85 @@
|
||||
"""QA Service — Agentic-RAG wrapper."""
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from storage import file_store as fs
|
||||
|
||||
|
||||
def run_query(question: str, history: list[dict]) -> dict:
|
||||
from pipeline.qa_agent import run_qa
|
||||
|
||||
nodes = fs.load_kg_nodes()
|
||||
edges = fs.load_kg_edges()
|
||||
|
||||
if not nodes:
|
||||
raise ValueError("KG_EMPTY")
|
||||
|
||||
start = time.time()
|
||||
result = run_qa(question, history, nodes, edges)
|
||||
elapsed = round(time.time() - start, 2)
|
||||
|
||||
query_id = f"q_{uuid.uuid4().hex[:10]}"
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
|
||||
record = {
|
||||
"id": query_id,
|
||||
"question": question,
|
||||
"answer": result["answer"],
|
||||
"tool_calls": result["tool_calls"],
|
||||
"cited_nodes": result["cited_nodes"],
|
||||
"duration_seconds": elapsed,
|
||||
"timestamp": now,
|
||||
}
|
||||
fs.append_query_history(record)
|
||||
return record
|
||||
|
||||
|
||||
def get_history(page: int = 1, page_size: int = 20) -> dict:
|
||||
all_records = fs.load_query_history()
|
||||
total = len(all_records)
|
||||
start = (page - 1) * page_size
|
||||
return {
|
||||
"total": total,
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
"items": all_records[start: start + page_size],
|
||||
}
|
||||
|
||||
|
||||
def start_batch(questions: list[str]) -> dict:
|
||||
import threading
|
||||
|
||||
batch_id = f"batch_{uuid.uuid4().hex[:10]}"
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
meta = {
|
||||
"batch_id": batch_id,
|
||||
"total": len(questions),
|
||||
"completed": 0,
|
||||
"failed": 0,
|
||||
"status": "submitted",
|
||||
"created_at": now,
|
||||
"results": [],
|
||||
}
|
||||
fs.save_batch_meta(batch_id, meta)
|
||||
|
||||
def _run():
|
||||
for q in questions:
|
||||
try:
|
||||
res = run_query(q, [])
|
||||
meta["results"].append(res)
|
||||
meta["completed"] += 1
|
||||
except Exception as e:
|
||||
meta["failed"] += 1
|
||||
meta["results"].append({"question": q, "error": str(e)})
|
||||
meta["status"] = "done"
|
||||
fs.save_batch_meta(batch_id, meta)
|
||||
|
||||
threading.Thread(target=_run, daemon=True).start()
|
||||
return {"batch_id": batch_id, "total": len(questions), "status": "submitted", "created_at": now}
|
||||
|
||||
|
||||
def get_batch_result(batch_id: str) -> dict | None:
|
||||
return fs.load_batch_meta(batch_id)
|
||||
106
backend/services/search_service.py
Normal file
106
backend/services/search_service.py
Normal file
@@ -0,0 +1,106 @@
|
||||
"""Search Service — entity, path, and graph search."""
|
||||
from __future__ import annotations
|
||||
|
||||
import networkx as nx
|
||||
|
||||
from storage import file_store as fs
|
||||
|
||||
|
||||
def _load_graph() -> nx.Graph:
|
||||
nodes = fs.load_kg_nodes()
|
||||
edges = fs.load_kg_edges()
|
||||
G = nx.Graph()
|
||||
for n in nodes:
|
||||
G.add_node(n["id"], **n)
|
||||
for e in edges:
|
||||
G.add_edge(e["source"], e["target"],
|
||||
relation=e.get("relation", "CO_OCCURS_IN"),
|
||||
doc_id=e.get("doc_id", ""), page=e.get("page", 0))
|
||||
return G
|
||||
|
||||
|
||||
def search_entities(q: str, entity_type: str | None = None, limit: int = 15) -> dict:
|
||||
nodes = fs.load_kg_nodes()
|
||||
G = _load_graph()
|
||||
degrees = dict(G.degree())
|
||||
q_lower = q.lower()
|
||||
matches = [n for n in nodes if q_lower in n.get("name", "").lower()]
|
||||
if entity_type:
|
||||
matches = [n for n in matches if n.get("type", "").upper() == entity_type.upper()]
|
||||
for n in matches:
|
||||
n["degree"] = degrees.get(n["id"], 0)
|
||||
matches = matches[:limit]
|
||||
return {"query": q, "total": len(matches), "items": matches}
|
||||
|
||||
|
||||
def search_path(from_id: str, to_id: str, max_hops: int = 3) -> dict | None:
|
||||
nodes = fs.load_kg_nodes()
|
||||
node_map = {n["id"]: n for n in nodes}
|
||||
if from_id not in node_map or to_id not in node_map:
|
||||
return None # node not found
|
||||
|
||||
G = _load_graph()
|
||||
max_hops = max(1, min(max_hops, 5))
|
||||
|
||||
try:
|
||||
raw_paths = list(nx.all_simple_paths(G, from_id, to_id, cutoff=max_hops))
|
||||
except nx.NetworkXError:
|
||||
raw_paths = []
|
||||
|
||||
paths = []
|
||||
for path_nodes in raw_paths:
|
||||
path_edges = []
|
||||
for i in range(len(path_nodes) - 1):
|
||||
s, t = path_nodes[i], path_nodes[i + 1]
|
||||
edge_data = G.edges[s, t]
|
||||
path_edges.append({"source": s, "target": t,
|
||||
"relation": edge_data.get("relation", "CO_OCCURS_IN")})
|
||||
paths.append({
|
||||
"length": len(path_nodes) - 1,
|
||||
"nodes": [{"id": nid, "name": node_map.get(nid, {}).get("name", nid),
|
||||
"type": node_map.get(nid, {}).get("type", "")} for nid in path_nodes],
|
||||
"edges": path_edges,
|
||||
})
|
||||
|
||||
from_node = node_map[from_id]
|
||||
to_node = node_map[to_id]
|
||||
return {
|
||||
"from": {"id": from_id, "name": from_node.get("name", ""), "type": from_node.get("type", "")},
|
||||
"to": {"id": to_id, "name": to_node.get("name", ""), "type": to_node.get("type", "")},
|
||||
"max_hops": max_hops,
|
||||
"paths": paths,
|
||||
"total_paths": len(paths),
|
||||
}
|
||||
|
||||
|
||||
def search_graph(q: str, include_neighbors: bool = False) -> dict:
|
||||
nodes = fs.load_kg_nodes()
|
||||
edges = fs.load_kg_edges()
|
||||
G = _load_graph()
|
||||
degrees = dict(G.degree())
|
||||
q_lower = q.lower()
|
||||
|
||||
matched = [n for n in nodes if q_lower in n.get("name", "").lower()]
|
||||
matched_ids = {n["id"] for n in matched}
|
||||
for n in matched:
|
||||
n["degree"] = degrees.get(n["id"], 0)
|
||||
|
||||
if include_neighbors:
|
||||
neighbor_ids = set()
|
||||
for nid in matched_ids:
|
||||
if nid in G:
|
||||
neighbor_ids.update(G.neighbors(nid))
|
||||
all_relevant = matched_ids | neighbor_ids
|
||||
else:
|
||||
all_relevant = matched_ids
|
||||
|
||||
subgraph_edges = [
|
||||
e for e in edges
|
||||
if e.get("source") in all_relevant and e.get("target") in all_relevant
|
||||
]
|
||||
|
||||
return {
|
||||
"query": q,
|
||||
"matched_nodes": matched,
|
||||
"subgraph_edges": subgraph_edges,
|
||||
}
|
||||
Reference in New Issue
Block a user