Files
DocsMCP/backend/app/main.py
T
2026-06-06 00:30:24 +01:00

383 lines
12 KiB
Python

"""Context7 Docs API."""
import asyncio
import io
import os
import secrets
import shutil
import zipfile
from pathlib import Path
from posixpath import normpath
from typing import Iterable, Optional

import yaml
from fastapi import FastAPI, File, Form, HTTPException, Query, Request, UploadFile
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field

from .config import settings
from .db import (
    clear_library_documents,
    delete_library,
    init_db,
    list_libraries,
    search_libraries,
    upsert_library,
)
from .git_source import ingest_git_source
from .ingest import ingest_all, ingest_library
from .search import get_library_docs, resolve_library_id, search_docs
from .vector_store import delete_library_vectors, ensure_collection, get_client, get_collection_name
# The application object; every middleware and route in this module is
# registered on it via decorators.
app = FastAPI(
    version="1.0.0",
    title="Context7 Docs API",
    description="Document ingestion and semantic search API for local-context7",
)
class SearchRequest(BaseModel):
    """JSON body accepted by ``POST /search``."""

    # Free-text query; required and must be at least one character long.
    query: str = Field(..., min_length=1)
    # Optional library ID that is forwarded to search_docs to narrow the search.
    library_id: Optional[str] = Field(default=None)
    # Maximum number of results to return, between 1 and 50 (default 10).
    limit: int = Field(default=10, ge=1, le=50)
class SyncSourcesRequest(BaseModel):
    """Optional JSON body accepted by ``POST /sources/sync``."""

    # NOTE(review): the sync endpoint does not appear to forward this flag to
    # ingest_git_source — confirm what "override" is meant to control.
    override: bool = Field(default=False)
# Extensions treated as documents: accepted as direct uploads and as members
# of uploaded ZIP archives.
DOCUMENT_EXTENSIONS = set(
    (".md", ".txt", ".py", ".js", ".ts", ".json", ".yaml", ".yml", ".html", ".css", ".pdf")
)
# Direct uploads may additionally be ZIP archives, which get extracted.
ALLOWED_EXTENSIONS = DOCUMENT_EXTENSIONS.union({".zip"})

# Size cap for a single non-archive upload (5 MiB).
MAX_FILE_UPLOAD_BYTES = 5 << 20
# Size cap for the raw (compressed) ZIP upload (100 MiB).
MAX_ZIP_UPLOAD_BYTES = 100 << 20
# Cap on the summed declared uncompressed size of extracted members (250 MiB).
MAX_ZIP_EXTRACTED_BYTES = 250 << 20
# Maximum number of non-directory entries allowed in one archive.
MAX_ZIP_FILES = 2_000
@app.middleware("http")
async def auth_middleware(request: Request, call_next):
    """Require a matching ``X-API-Key`` header when API key auth is enabled.

    When ``settings.is_auth_enabled`` is false every request passes through.
    Otherwise:

    - GET requests whose path starts with ``/health``, ``/libraries`` or
      ``/docs/`` are public.
    - Every other request must carry an ``X-API-Key`` header equal to
      ``settings.api_key_docs_api``, or it receives a 401 JSON response.
    """
    if not settings.is_auth_enabled:
        return await call_next(request)
    public_prefixes = ("/health", "/libraries", "/docs/")
    if request.method == "GET" and request.url.path.startswith(public_prefixes):
        return await call_next(request)
    provided = request.headers.get("X-API-Key")
    expected = settings.api_key_docs_api
    # compare_digest takes time independent of where the inputs differ, so
    # response latency cannot be used to guess the key byte by byte.
    # Compare UTF-8 bytes because compare_digest rejects non-ASCII str input.
    if (
        provided is None
        or expected is None
        or not secrets.compare_digest(provided.encode("utf-8"), str(expected).encode("utf-8"))
    ):
        return JSONResponse(status_code=401, content={"detail": "Invalid or missing API key"})
    return await call_next(request)
@app.on_event("startup")
async def startup() -> None:
    """Initialise SQLite, then wait for the Qdrant collection to be ready.

    ``ensure_collection()`` is attempted up to 20 times, sleeping one second
    after each failed attempt.

    Raises:
        RuntimeError: if ``init_db()`` reports failure, or if every
            ``ensure_collection()`` attempt fails (the last error is included).
    """
    db_status = init_db()
    if not db_status.get("success"):
        raise RuntimeError(f"Failed to initialize SQLite database: {db_status.get('error')}")

    error = None
    attempt = 0
    while attempt < 20:
        attempt += 1
        outcome = await ensure_collection()
        if outcome.get("success"):
            return
        error = outcome.get("error")
        await asyncio.sleep(1)
    raise RuntimeError(f"Failed to initialize Qdrant collection: {error}")
def safe_library_id(library_id: str) -> str:
    """Validate a user-supplied library ID and return it as one path segment.

    The ID is used directly as a directory name under the docs root, so
    anything that could escape that directory, or that the filesystem
    cannot represent, is rejected.

    Args:
        library_id: Raw ID taken from the request path.

    Returns:
        The ID with surrounding whitespace stripped.

    Raises:
        HTTPException: 400 in any of these cases:
            - the ID is empty or whitespace-only;
            - it is ``.`` or ``..``;
            - it contains ``..``;
            - it contains a ``/`` or ``\\`` separator;
            - it contains a NUL byte.
    """
    base = Path(library_id).name.strip()
    if (
        not base
        or base in {".", ".."}
        or ".." in library_id
        or "/" in library_id
        or "\\" in library_id
        # A NUL byte would otherwise surface later as a ValueError from
        # mkdir/open ("embedded null byte") and turn into a 500.
        or "\x00" in library_id
    ):
        raise HTTPException(status_code=400, detail="Invalid library ID")
    return base
def safe_upload_filename(filename: str) -> str:
    """Return a sanitized form of an uploaded file's name.

    The lower-cased extension must be listed in ``ALLOWED_EXTENSIONS``. The
    stem keeps only alphanumeric characters, ``-``, ``_`` and spaces, and
    surrounding whitespace is stripped.

    Raises:
        HTTPException: 400 if the extension is not allowed, or if nothing of
            the stem survives sanitizing.
    """
    original = Path(filename)
    extension = original.suffix.lower()
    if extension not in ALLOWED_EXTENSIONS:
        allowed = ", ".join(sorted(ALLOWED_EXTENSIONS))
        raise HTTPException(
            status_code=400,
            detail=f"Unsafe extension: {extension}. Allowed extensions: {allowed}",
        )
    kept = [ch for ch in original.stem if ch.isalnum() or ch in "-_ "]
    stem = "".join(kept).strip()
    if stem == "":
        raise HTTPException(status_code=400, detail="Filename contains only unsafe characters")
    return stem + extension
def safe_archive_member_path(
    member_name: str,
    *,
    allowed_extensions: Optional[Iterable[str]] = None,
) -> Optional[Path]:
    """Map a ZIP member name to a safe relative path, or ``None`` to skip it.

    Steps:

    1. Backslashes become ``/``, the path is normalized with POSIX rules, and
       leading slashes are dropped.
    2. Members with ``..`` components or a non-document extension are
       skipped.
    3. Each path component is reduced to alphanumerics plus ``-``, ``_``,
       space and ``.``, then stripped.

    Args:
        member_name: The member's name as stored in the archive.
        allowed_extensions: Lower-case extensions (with the leading dot) to
            accept. Defaults to ``DOCUMENT_EXTENSIONS``.

    Returns:
        The sanitized relative path, or ``None`` if the member should be
        skipped.
    """
    extensions = DOCUMENT_EXTENSIONS if allowed_extensions is None else set(allowed_extensions)
    normalized = normpath(member_name.replace("\\", "/")).lstrip("/")
    if not normalized or normalized == ".":
        return None
    path = Path(normalized)
    if path.is_absolute() or ".." in path.parts or path.suffix.lower() not in extensions:
        return None
    safe_parts = []
    for part in path.parts:
        cleaned = "".join(c for c in part if c.isalnum() or c in "-_ .").strip()
        # Cleaning can manufacture a parent reference: " .. " passes the
        # check above (normpath leaves it alone) but cleans to "..".
        if not cleaned or cleaned == "..":
            return None
        safe_parts.append(cleaned)
    safe_path = Path(*safe_parts)
    # Cleaning can also strip the extension away, e.g. "!.md" -> ".md",
    # which is a suffix-less dotfile, so re-check the final name.
    if safe_path.suffix.lower() not in extensions:
        return None
    return safe_path
def assert_inside_directory(root: Path, target: Path) -> None:
    """Raise HTTPException(400) unless *target* resolves to *root* or below it.

    Both paths are passed through ``Path.resolve()`` before comparison, so
    ``..`` components and symlinks are taken into account.
    """
    resolved_root = str(root.resolve())
    resolved_target = str(target.resolve())
    shared = os.path.commonpath([resolved_root, resolved_target])
    if shared == resolved_root:
        return
    raise HTTPException(status_code=400, detail="Archive contains unsafe paths")
def extract_zip_upload(contents: bytes, lib_dir: Path) -> dict:
    """Extract the document members of an uploaded ZIP archive into *lib_dir*.

    Every member is validated first: file count, safe name, total declared
    uncompressed size, and containment in *lib_dir*. An archive rejected by
    any of these checks therefore writes nothing. Accepted members are then
    streamed to disk one at a time.

    Args:
        contents: Raw bytes of the uploaded archive.
        lib_dir: Library directory that receives the extracted files.

    Returns:
        ``{"extracted": [...], "skipped": [...], "size_bytes": len(contents)}``.
        ``extracted`` lists paths relative to *lib_dir*. ``skipped`` lists
        the member names that ``safe_archive_member_path`` rejected.

    Raises:
        HTTPException: 400 in any of these cases:
            - the archive is too large or is not a valid ZIP;
            - it has too many files;
            - it would expand past ``MAX_ZIP_EXTRACTED_BYTES``;
            - it contains an unsafe path;
            - a member cannot be decompressed. Earlier members may already
              have been written when this last failure happens.
    """
    if len(contents) > MAX_ZIP_UPLOAD_BYTES:
        raise HTTPException(status_code=400, detail="ZIP too large (max 100MB)")
    try:
        archive = zipfile.ZipFile(io.BytesIO(contents))
    except zipfile.BadZipFile:
        raise HTTPException(status_code=400, detail="Invalid ZIP archive") from None

    with archive:
        files = [info for info in archive.infolist() if not info.is_dir()]
        if len(files) > MAX_ZIP_FILES:
            raise HTTPException(status_code=400, detail=f"ZIP contains too many files (max {MAX_ZIP_FILES})")

        # Pass 1: validate every member before anything touches the filesystem.
        planned = []
        skipped = []
        total_uncompressed = 0
        for info in files:
            relative_path = safe_archive_member_path(info.filename)
            if relative_path is None:
                skipped.append(info.filename)
                continue
            # file_size is the uncompressed size declared in the archive entry.
            total_uncompressed += info.file_size
            if total_uncompressed > MAX_ZIP_EXTRACTED_BYTES:
                raise HTTPException(status_code=400, detail="ZIP extracted content too large (max 250MB)")
            target = lib_dir / relative_path
            assert_inside_directory(lib_dir, target)
            planned.append((info, target))

        # Pass 2: stream each member to disk instead of buffering it whole.
        extracted = []
        for info, target in planned:
            target.parent.mkdir(parents=True, exist_ok=True)
            try:
                with archive.open(info) as src, target.open("wb") as dst:
                    shutil.copyfileobj(src, dst)
            except (zipfile.BadZipFile, RuntimeError, NotImplementedError) as exc:
                # NOTE(review): these are the errors zipfile is expected to raise
                # for corrupt, encrypted or unsupported-compression members;
                # confirm against the stdlib version in use.
                raise HTTPException(
                    status_code=400,
                    detail=f"Could not extract {info.filename}: {exc}",
                ) from exc
            extracted.append(str(target.relative_to(lib_dir)))

    return {"extracted": extracted, "skipped": skipped, "size_bytes": len(contents)}
def docs_root() -> Path:
    """Return ``settings.docs_path`` as a Path; library folders live directly under it."""
    configured = settings.docs_path
    return Path(configured)
def sources_config_path() -> Path:
    """Return the location of ``docs_sources.yaml``.

    The file sits two directories above the folder containing this module.
    """
    module_file = Path(__file__).resolve()
    return module_file.parents[2].joinpath("docs_sources.yaml")
@app.get("/health")
async def health_check():
    """Liveness probe that always answers ``{"status": "ok", "service": "docs-api"}``."""
    return dict(status="ok", service="docs-api")
@app.get("/collections")
async def collections():
    """Report how many vectors the configured Qdrant collection holds.

    This is a best-effort endpoint: on any failure it returns an empty
    ``collections`` mapping plus the error text under ``warning``.
    """
    try:
        client = get_client()
        collection = get_collection_name()
        details = client.get_collection(collection)
        vectors = getattr(details, "vectors_count", None)
        if not vectors:
            # Fall back to points_count when vectors_count is missing or zero.
            vectors = getattr(details, "points_count", 0) or 0
        return {"collections": {collection: {"vectors": vectors}}}
    except Exception as exc:
        return {"collections": {}, "warning": str(exc)}
@app.get("/libraries")
async def list_libraries_api():
    """Return all registered libraries together with their count.

    A dict result from ``list_libraries()`` whose ``success`` is falsy is
    treated as an error and turned into an HTTP 500.
    """
    libraries = list_libraries()
    errored = isinstance(libraries, dict) and not libraries.get("success", True)
    if errored:
        message = libraries.get("error", "Failed to list libraries")
        raise HTTPException(status_code=500, detail=message)
    return {"libraries": libraries, "count": len(libraries)}
@app.get("/libraries/search")
async def search_libraries_api(q: str = Query(..., min_length=1)):
    """Return whatever ``resolve_library_id`` matches for the query *q*, with a count."""
    found = resolve_library_id(q)
    return dict(matches=found, count=len(found))
@app.post("/search")
async def search_docs_api(payload: SearchRequest):
    """Run ``search_docs`` for the request and echo the query alongside the hits."""
    query, library_id = payload.query, payload.library_id
    hits = search_docs(query, library_id=library_id, limit=payload.limit)
    return {"query": query, "library_id": library_id, "results": hits, "count": len(hits)}
@app.get("/docs/{library_id}")
@app.get("/libraries/{library_id}/docs")
async def get_library_docs_api(
    library_id: str,
    topic: Optional[str] = Query(None),
    tokens: int = Query(8000, ge=1),
):
    """Return a library's documentation, optionally focused on *topic*.

    The *tokens* query parameter is passed to ``get_library_docs`` as
    ``token_limit``.
    """
    content = get_library_docs(library_id=library_id, topic=topic, token_limit=tokens)
    response = {"library_id": library_id}
    response["content"] = content
    return response
@app.post("/ingest/all")
async def ingest_all_api():
    """Run ``ingest_all()`` and return its result unchanged."""
    outcome = await ingest_all()
    return outcome
@app.post("/ingest/{library_id}")
async def ingest_library_api(library_id: str):
    """Ingest one library; its sanitized ID doubles as display name and source path."""
    checked_id = safe_library_id(library_id)
    return await ingest_library(
        library_id=checked_id,
        name=checked_id,
        source_path=checked_id,
    )
@app.post("/api/v1/libraries/{library_id}")
async def api_create_library(
    library_id: str,
    name: Optional[str] = Form(None),
    description: Optional[str] = Form(None),
):
    """Create (or update) a library record and ensure its directory exists.

    *name* defaults to the library ID. The response's ``created`` flag is
    false when ``upsert_library`` reports that the library already existed.

    Raises:
        HTTPException: 400 for an invalid ID; 500 if ``upsert_library`` fails.
    """
    library_id = safe_library_id(library_id)
    display_name = name or library_id
    lib_dir = docs_root() / library_id
    lib_dir.mkdir(parents=True, exist_ok=True)
    outcome = upsert_library(library_id, display_name, description, library_id)
    if not outcome.get("success"):
        raise HTTPException(status_code=500, detail=outcome.get("error", "Failed to create library"))
    already_existed = outcome.get("exists", False)
    return {
        "success": True,
        "created": not already_existed,
        "library_id": library_id,
        "name": display_name,
        "description": description,
        "path": str(lib_dir),
    }
@app.delete("/api/v1/libraries/{library_id}")
async def api_delete_library(library_id: str):
    """Delete a library's files, stored documents, vectors and registry entry.

    The library directory (if present) is removed first. Document records,
    vectors and the library record are then cleared in that order.

    Returns:
        ``{"success": True, "library_id": ..., "deleted_files": n}``, where
        ``n`` counts the regular files found under the directory before
        removal.

    Raises:
        HTTPException: 400 for an invalid ID. 500 if any cleanup step reports
            failure, with the failures joined by ``"; "``.
    """
    library_id = safe_library_id(library_id)
    lib_dir = docs_root() / library_id
    deleted_files = 0
    if lib_dir.exists():
        deleted_files = sum(1 for path in lib_dir.rglob("*") if path.is_file())
        shutil.rmtree(lib_dir)
    docs_result = clear_library_documents(library_id)
    vectors_result = await delete_library_vectors(library_id)
    library_result = delete_library(library_id)
    steps = (
        ("documents", docs_result),
        ("vectors", vectors_result),
        ("library", library_result),
    )
    # A failed result may omit "error"; joining a None would raise TypeError
    # and hide the real failure, so fall back to a per-step message.
    failures = [
        str(result.get("error") or f"Failed to delete {step}")
        for step, result in steps
        if isinstance(result, dict) and not result.get("success", True)
    ]
    if failures:
        raise HTTPException(status_code=500, detail="; ".join(failures))
    return {"success": True, "library_id": library_id, "deleted_files": deleted_files}
@app.post("/api/v1/upload/{library_id}")
async def api_upload(library_id: str, file: UploadFile = File(...)):
    """Store an uploaded document, or the documents inside a ZIP, in a library.

    A ``.zip`` upload is handed to ``extract_zip_upload``. Any other allowed
    file is written under its sanitized name. Either way the library is then
    registered via ``upsert_library``.

    Raises:
        HTTPException: 400 for an invalid library ID or filename, an
            oversized upload, or a rejected archive. 500 if the library
            cannot be registered.
    """
    library_id = safe_library_id(library_id)
    safe_name = safe_upload_filename(file.filename or "upload.txt")
    is_archive = safe_name.lower().endswith(".zip")
    size_limit = MAX_ZIP_UPLOAD_BYTES if is_archive else MAX_FILE_UPLOAD_BYTES
    # Read at most one byte past the limit: enough to detect an oversized
    # upload without pulling all of it into memory.
    contents = await file.read(size_limit + 1)
    if not is_archive and len(contents) > MAX_FILE_UPLOAD_BYTES:
        raise HTTPException(status_code=400, detail="File too large (max 5MB)")

    lib_dir = docs_root() / library_id
    lib_dir.mkdir(parents=True, exist_ok=True)

    if is_archive:
        # extract_zip_upload applies the MAX_ZIP_UPLOAD_BYTES check itself.
        response = {
            "success": True,
            "library_id": library_id,
            "filename": safe_name,
            "archive": True,
            **extract_zip_upload(contents, lib_dir),
        }
    else:
        target = lib_dir / safe_name
        target.write_bytes(contents)
        response = {
            "success": True,
            "library_id": library_id,
            "filename": safe_name,
            "path": str(target.relative_to(docs_root())),
            "size_bytes": len(contents),
        }

    # Handle a registration failure the same way api_create_library does
    # instead of silently reporting success.
    registration = upsert_library(library_id, library_id, None, library_id)
    if not registration.get("success"):
        raise HTTPException(status_code=500, detail=registration.get("error", "Failed to register library"))
    return response
@app.get("/api/v1/sources")
@app.get("/sources/config")
async def api_list_sources():
    """Return the git sources declared in ``docs_sources.yaml``.

    Accepted file layouts:

    - a mapping with a ``sources`` list;
    - a bare top-level list of sources.

    A missing file, an empty file, or any other shape yields an empty list.
    """
    path = sources_config_path()
    if not path.exists():
        return {"success": True, "sources": [], "count": 0}
    with path.open(encoding="utf-8") as f:
        # safe_load: the config must not be able to construct arbitrary objects.
        data = yaml.safe_load(f)
    # Check the type before calling .get(): a top-level list (or a scalar)
    # has no .get and previously crashed with AttributeError.
    if isinstance(data, list):
        sources = data
    elif isinstance(data, dict):
        sources = data.get("sources", [])
    else:
        sources = []
    if not isinstance(sources, list):
        sources = []
    return {"success": True, "sources": sources, "count": len(sources)}
@app.post("/sources/sync")
async def sync_sources_api(payload: Optional[SyncSourcesRequest] = None):
    """Ingest every git source from ``docs_sources.yaml``, one after another.

    Each well-formed entry is passed to ``ingest_git_source``. A well-formed
    entry is a mapping with non-empty ``library_id`` and ``repo_url``.
    Malformed entries are recorded as failed results instead of aborting the
    whole sync.

    Returns:
        A summary containing:

        - ``success``: true only if every source succeeded;
        - ``total_sources``, ``successful`` and ``failed`` counts;
        - the per-source ``results``.
    """
    # NOTE(review): payload.override is accepted but not forwarded to
    # ingest_git_source; confirm whether ingestion should honor it.
    source_data = await api_list_sources()
    results = []
    for source in source_data["sources"]:
        if not isinstance(source, dict) or not source.get("library_id") or not source.get("repo_url"):
            results.append(
                {
                    "success": False,
                    "library_id": source.get("library_id") if isinstance(source, dict) else None,
                    "error": "Source entry must be a mapping with 'library_id' and 'repo_url'",
                }
            )
            continue
        result = await ingest_git_source(
            library_id=source["library_id"],
            name=source.get("name") or source["library_id"],
            description=source.get("description"),
            repo_url=source["repo_url"],
            branch=source.get("branch", "main"),
            include_paths=source.get("include_paths"),
            exclude_paths=source.get("exclude_paths"),
        )
        results.append(result)
    successful = sum(1 for r in results if r.get("success"))
    return {
        "success": successful == len(results),
        "total_sources": len(results),
        "successful": successful,
        "failed": len(results) - successful,
        "results": results,
    }