300 lines
9.2 KiB
Python
300 lines
9.2 KiB
Python
"""Context7 Docs API."""
|
|
import asyncio
|
|
import shutil
|
|
import yaml
|
|
from pathlib import Path
|
|
from typing import Optional
|
|
|
|
from fastapi import FastAPI, File, Form, HTTPException, Query, Request, UploadFile
|
|
from fastapi.responses import JSONResponse
|
|
from pydantic import BaseModel, Field
|
|
|
|
from .config import settings
|
|
from .db import (
|
|
clear_library_documents,
|
|
delete_library,
|
|
init_db,
|
|
list_libraries,
|
|
search_libraries,
|
|
upsert_library,
|
|
)
|
|
from .git_source import ingest_git_source
|
|
from .ingest import ingest_all, ingest_library
|
|
from .search import get_library_docs, resolve_library_id, search_docs
|
|
from .vector_store import delete_library_vectors, ensure_collection, get_client, get_collection_name
|
|
|
|
|
|
# ASGI application object; the middleware and route decorators in this module
# all register against this instance.
app = FastAPI(
    version="1.0.0",
    title="Context7 Docs API",
    description="Document ingestion and semantic search API for local-context7",
)
|
|
|
|
|
|
class SearchRequest(BaseModel):
    # JSON body accepted by POST /search.

    # Search text; validation rejects an empty string.
    query: str = Field(default=..., min_length=1)
    # Library to search within; None is forwarded to search_docs unchanged.
    library_id: Optional[str] = Field(default=None)
    # Maximum number of results, validated to the range 1..50.
    limit: int = Field(default=10, ge=1, le=50)
|
|
|
|
|
|
class SyncSourcesRequest(BaseModel):
    # Optional JSON body accepted by POST /sources/sync.

    # Caller-supplied flag; stays False unless the request sets it.
    override: bool = Field(default=False)
|
|
|
|
|
|
# Upload whitelist. Entries are lowercase and include the leading dot because
# safe_upload_filename compares them against ``Path(...).suffix.lower()``.
ALLOWED_EXTENSIONS = {
    # prose / documents
    ".md", ".txt", ".pdf",
    # source code
    ".py", ".js", ".ts",
    # structured data
    ".json", ".yaml", ".yml",
    # web assets
    ".html", ".css",
}
|
|
|
|
|
|
@app.middleware("http")
async def auth_middleware(request: Request, call_next):
    """Enforce the X-API-Key header when API-key authentication is enabled.

    Behaviour:
      * If ``settings.is_auth_enabled`` is falsy, every request passes through.
      * GET requests whose path starts with ``/health``, ``/libraries`` or
        ``/docs/`` are public and pass through without a key.
      * Every other request (all non-GET methods, and GETs outside the public
        prefixes) must carry an ``X-API-Key`` header equal to
        ``settings.api_key_docs_api``; otherwise a 401 JSON response is
        returned and the downstream handler is never invoked.

    Args:
        request: The incoming HTTP request.
        call_next: Callable that forwards the request to the next handler.

    Returns:
        The downstream response, or a 401 ``JSONResponse`` on a bad/missing key.
    """
    # Function-scope import so this block does not depend on the module's
    # top-of-file import list.
    import hmac

    if not settings.is_auth_enabled:
        return await call_next(request)

    public_prefixes = ("/health", "/libraries", "/docs/")
    if request.method == "GET" and request.url.path.startswith(public_prefixes):
        return await call_next(request)

    supplied_key = request.headers.get("X-API-Key") or ""
    expected_key = settings.api_key_docs_api or ""

    # Fail closed when no key is configured: a plain ``!=`` would treat a
    # missing header (None) and an unset key (None) as a match.
    # compare_digest avoids leaking the key through comparison timing; both
    # sides are encoded to bytes because it rejects non-ASCII ``str`` input.
    key_matches = bool(expected_key) and hmac.compare_digest(
        supplied_key.encode("utf-8"),
        expected_key.encode("utf-8"),
    )
    if not key_matches:
        return JSONResponse(status_code=401, content={"detail": "Invalid or missing API key"})

    return await call_next(request)
|
|
|
|
|
|
@app.on_event("startup")
async def startup() -> None:
    """Prepare storage backends before the application starts serving.

    Initializes the SQLite database once, then tries ``ensure_collection`` up
    to 20 times, sleeping one second after each failed attempt.

    Raises:
        RuntimeError: If database initialization reports failure, or if no
            ``ensure_collection`` attempt succeeds; the message carries the
            most recent reported error.
    """
    db_status = init_db()
    if not db_status.get("success"):
        raise RuntimeError(f"Failed to initialize SQLite database: {db_status.get('error')}")

    attempts_remaining = 20
    most_recent_error = None
    while attempts_remaining > 0:
        status = await ensure_collection()
        if status.get("success"):
            return
        most_recent_error = status.get("error")
        await asyncio.sleep(1)
        attempts_remaining -= 1

    raise RuntimeError(f"Failed to initialize Qdrant collection: {most_recent_error}")
|
|
|
|
|
|
def safe_library_id(library_id: str) -> str:
    """Validate a caller-supplied library ID and return it as one path segment.

    The returned value is the final path component of ``library_id`` with
    surrounding whitespace stripped.

    Raises:
        HTTPException: 400 if the stripped component is empty, ``.`` or ``..``,
            or if the raw input contains ``..``, ``/`` or a backslash anywhere.
    """
    segment = Path(library_id).name.strip()
    has_traversal_token = any(token in library_id for token in ("..", "/", "\\"))
    if segment in {"", ".", ".."} or has_traversal_token:
        raise HTTPException(status_code=400, detail="Invalid library ID")
    return segment
|
|
|
|
|
|
def safe_upload_filename(filename: str) -> str:
    """Build a sanitized filename from an uploaded file's name.

    The extension is lower-cased and must be in ``ALLOWED_EXTENSIONS``. The
    stem keeps only alphanumeric characters, ``-``, ``_`` and spaces, and is
    then stripped of surrounding whitespace.

    Raises:
        HTTPException: 400 if the extension is not allowed, or if nothing is
            left of the stem after filtering.
    """
    original = Path(filename)
    ext = original.suffix.lower()
    if ext not in ALLOWED_EXTENSIONS:
        allowed_listing = ", ".join(sorted(ALLOWED_EXTENSIONS))
        raise HTTPException(
            status_code=400,
            detail=f"Unsafe extension: {ext}. Allowed extensions: {allowed_listing}",
        )

    kept_chars = [ch for ch in original.stem if ch.isalnum() or ch in "-_ "]
    stem = "".join(kept_chars).strip()
    if stem:
        return f"{stem}{ext}"
    raise HTTPException(status_code=400, detail="Filename contains only unsafe characters")
|
|
|
|
|
|
def docs_root() -> Path:
    """Return ``settings.docs_path`` wrapped in a ``Path``."""
    configured_location = settings.docs_path
    return Path(configured_location)
|
|
|
|
|
|
def sources_config_path() -> Path:
    """Return the path of ``docs_sources.yaml``.

    The file is located three directory levels above this module's resolved
    location (``parents[2]`` of the module file).
    """
    module_file = Path(__file__).resolve()
    base_dir = module_file.parents[2]
    return base_dir.joinpath("docs_sources.yaml")
|
|
|
|
|
|
@app.get("/health")
async def health_check():
    """Return a fixed payload identifying the service and reporting it is up."""
    payload = {"status": "ok", "service": "docs-api"}
    return payload
|
|
|
|
|
|
@app.get("/collections")
async def collections():
    """Report the vector count of the configured vector-store collection.

    Returns:
        ``{"collections": {<name>: {"vectors": <count>}}}`` on success. If any
        step raises, the error is not propagated: the response is
        ``{"collections": {}, "warning": <error text>}`` instead.
    """
    try:
        store = get_client()
        details = store.get_collection(get_collection_name())
        # Prefer vectors_count; fall back to points_count, then to 0, whenever
        # the preferred attribute is missing or falsy.
        count = getattr(details, "vectors_count", None)
        if not count:
            count = getattr(details, "points_count", 0) or 0
        return {"collections": {get_collection_name(): {"vectors": count}}}
    except Exception as exc:
        return {"collections": {}, "warning": str(exc)}
|
|
|
|
|
|
@app.get("/libraries")
async def list_libraries_api():
    """List the registered libraries together with their count.

    Raises:
        HTTPException: 500 when ``list_libraries`` returns a dict whose
            ``success`` entry is falsy.
    """
    listing = list_libraries()
    reported_failure = isinstance(listing, dict) and not listing.get("success", True)
    if reported_failure:
        raise HTTPException(status_code=500, detail=listing.get("error", "Failed to list libraries"))
    return {"libraries": listing, "count": len(listing)}
|
|
|
|
|
|
@app.get("/libraries/search")
async def search_libraries_api(q: str = Query(..., min_length=1)):
    """Resolve the non-empty query ``q`` to candidate libraries.

    Returns:
        The matches produced by ``resolve_library_id`` and how many there are.
    """
    found = resolve_library_id(q)
    response = {"matches": found, "count": len(found)}
    return response
|
|
|
|
|
|
@app.post("/search")
async def search_docs_api(payload: SearchRequest):
    """Run a documentation search and echo the request parameters back.

    Returns:
        A dict with the original ``query`` and ``library_id``, the ``results``
        from ``search_docs``, and their ``count``.
    """
    hits = search_docs(
        payload.query,
        library_id=payload.library_id,
        limit=payload.limit,
    )
    response = {"query": payload.query, "library_id": payload.library_id}
    response["results"] = hits
    response["count"] = len(hits)
    return response
|
|
|
|
|
|
@app.get("/docs/{library_id}")
@app.get("/libraries/{library_id}/docs")
async def get_library_docs_api(
    library_id: str,
    topic: Optional[str] = Query(None),
    tokens: int = Query(8000, ge=1),
):
    """Fetch documentation content for a library (served under two routes).

    Args:
        library_id: Library identifier taken from the URL path.
        topic: Optional topic, forwarded to ``get_library_docs``.
        tokens: Token limit forwarded as ``token_limit``; at least 1.

    Returns:
        A dict with the requested ``library_id`` and the retrieved ``content``.
    """
    content = get_library_docs(
        library_id=library_id,
        topic=topic,
        token_limit=tokens,
    )
    return dict(library_id=library_id, content=content)
|
|
|
|
|
|
@app.post("/ingest/all")
async def ingest_all_api():
    """Run ``ingest_all`` and return whatever it reports."""
    outcome = await ingest_all()
    return outcome
|
|
|
|
|
|
@app.post("/ingest/{library_id}")
async def ingest_library_api(library_id: str):
    """Ingest one library, using its validated ID as name and source path.

    Raises:
        HTTPException: 400 from ``safe_library_id`` when the ID is invalid.
    """
    validated_id = safe_library_id(library_id)
    outcome = await ingest_library(
        library_id=validated_id,
        name=validated_id,
        source_path=validated_id,
    )
    return outcome
|
|
|
|
|
|
@app.post("/api/v1/libraries/{library_id}")
async def api_create_library(
    library_id: str,
    name: Optional[str] = Form(None),
    description: Optional[str] = Form(None),
):
    """Create a library's directory and upsert its database record.

    Args:
        library_id: Library identifier from the URL path; validated first.
        name: Optional display name; the library ID is used when omitted/empty.
        description: Optional description stored with the record.

    Returns:
        A dict describing the library; ``created`` is the negation of the
        ``exists`` entry in the upsert result.

    Raises:
        HTTPException: 400 for an invalid ID; 500 when the upsert result does
            not report success.
    """
    library_id = safe_library_id(library_id)
    display_name = name or library_id

    lib_dir = docs_root() / library_id
    lib_dir.mkdir(parents=True, exist_ok=True)

    outcome = upsert_library(library_id, display_name, description, library_id)
    if not outcome.get("success"):
        raise HTTPException(status_code=500, detail=outcome.get("error", "Failed to create library"))

    already_present = outcome.get("exists", False)
    return {
        "success": True,
        "created": not already_present,
        "library_id": library_id,
        "name": display_name,
        "description": description,
        "path": str(lib_dir),
    }
|
|
|
|
|
|
@app.delete("/api/v1/libraries/{library_id}")
async def api_delete_library(library_id: str):
    """Delete a library's files, document records, vectors and library record.

    Steps, in order: remove the library directory under the docs root (if it
    exists), clear its document records, delete its vectors, then delete the
    library record. All three backend steps are attempted even if an earlier
    one reports failure; failures are collected and reported together.

    Args:
        library_id: Library identifier from the URL path; validated first.

    Returns:
        ``{"success": True, "library_id": ..., "deleted_files": <count>}``
        where the count is the number of regular files that were under the
        library directory before removal.

    Raises:
        HTTPException: 400 for an invalid ID; 500 (with the joined error
            messages) when any backend step returns a dict reporting failure.
    """
    library_id = safe_library_id(library_id)
    lib_dir = docs_root() / library_id
    deleted_files = 0

    if lib_dir.exists():
        # Count before removing so the response can report what was deleted.
        deleted_files = sum(1 for path in lib_dir.rglob("*") if path.is_file())
        shutil.rmtree(lib_dir)

    docs_result = clear_library_documents(library_id)
    vectors_result = await delete_library_vectors(library_id)
    library_result = delete_library(library_id)

    # A failed result may carry no "error" value; coerce to text so the join
    # below cannot raise TypeError and hide the real failure.
    failures = [
        str(r.get("error") or "unknown error")
        for r in (docs_result, vectors_result, library_result)
        if isinstance(r, dict) and not r.get("success", True)
    ]
    if failures:
        raise HTTPException(status_code=500, detail="; ".join(failures))

    return {"success": True, "library_id": library_id, "deleted_files": deleted_files}
|
|
|
|
|
|
@app.post("/api/v1/upload/{library_id}")
async def api_upload(library_id: str, file: UploadFile = File(...)):
    """Store one uploaded file in a library's directory and register the library.

    The library ID and filename are validated and the size limit is enforced
    before anything is written, so a rejected upload leaves no directory or
    file behind.

    Args:
        library_id: Library identifier from the URL path; validated first.
        file: The uploaded file; ``upload.txt`` is assumed when it has no name.

    Returns:
        A dict with the library ID, the sanitized filename, the stored path
        relative to the docs root, and the size in bytes.

    Raises:
        HTTPException: 400 for an invalid library ID, a disallowed or unsafe
            filename, or a file larger than 5 MB; 500 when registering the
            library reports failure.
    """
    max_bytes = 5 * 1024 * 1024

    library_id = safe_library_id(library_id)
    safe_name = safe_upload_filename(file.filename or "upload.txt")

    # Read at most one byte past the limit: enough to detect an oversized
    # upload without loading an arbitrarily large body into memory.
    contents = await file.read(max_bytes + 1)
    if len(contents) > max_bytes:
        raise HTTPException(status_code=400, detail="File too large (max 5MB)")

    lib_dir = docs_root() / library_id
    lib_dir.mkdir(parents=True, exist_ok=True)
    target = lib_dir / safe_name
    target.write_bytes(contents)

    # NOTE(review): this upserts with the ID as name and no description;
    # whether that overwrites metadata of an existing library depends on
    # upsert_library, which is not visible here -- confirm.
    result = upsert_library(library_id, library_id, None, library_id)
    if isinstance(result, dict) and not result.get("success", True):
        raise HTTPException(status_code=500, detail=result.get("error") or "Failed to register library")

    return {
        "success": True,
        "library_id": library_id,
        "filename": safe_name,
        "path": str(target.relative_to(docs_root())),
        "size_bytes": len(contents),
    }
|
|
|
|
|
|
@app.get("/api/v1/sources")
@app.get("/sources/config")
async def api_list_sources():
    """Return the source entries declared in the sources config file.

    Two document layouts are accepted: a mapping with a ``sources`` list, or a
    top-level list of entries. A missing file, an empty document, any other
    top-level type, or a ``sources`` value that is not a list all yield an
    empty result rather than an error.

    Returns:
        ``{"success": True, "sources": <list>, "count": <len(list)>}``.
    """
    path = sources_config_path()
    if not path.exists():
        return {"success": True, "sources": [], "count": 0}

    # Explicit encoding so parsing does not depend on the platform default.
    with path.open(encoding="utf-8") as f:
        data = yaml.safe_load(f)

    # Dispatch on the parsed type: calling .get() on a top-level list (or a
    # scalar) would raise AttributeError.
    if isinstance(data, dict):
        sources = data.get("sources", [])
    elif isinstance(data, list):
        sources = data
    else:
        sources = []

    if not isinstance(sources, list):
        sources = []
    return {"success": True, "sources": sources, "count": len(sources)}
|
|
|
|
|
|
@app.post("/sources/sync")
async def sync_sources_api(payload: Optional[SyncSourcesRequest] = None):
    """Ingest every configured git source, one after another.

    Sources come from ``api_list_sources``. Each well-formed entry is passed
    to ``ingest_git_source``; ``name`` falls back to the library ID and
    ``branch`` to ``"main"``. An entry that is not a mapping, or that lacks
    ``library_id`` or ``repo_url``, is recorded as a failed result instead of
    aborting the whole sync.

    Args:
        payload: Optional request body.
            NOTE(review): ``payload.override`` is accepted but not forwarded;
            the previous code read it into a local and never used it. Wiring
            it up needs ``ingest_git_source``'s signature, which is not
            visible here.

    Returns:
        A summary dict: overall ``success`` (true only if every source
        succeeded), ``total_sources``, ``successful``, ``failed``, and the
        per-source ``results``.
    """
    source_data = await api_list_sources()
    results = []

    for position, source in enumerate(source_data["sources"]):
        if not isinstance(source, dict):
            results.append(
                {
                    "success": False,
                    "library_id": None,
                    "error": f"Source entry {position} is not a mapping",
                }
            )
            continue

        missing = [key for key in ("library_id", "repo_url") if not source.get(key)]
        if missing:
            results.append(
                {
                    "success": False,
                    "library_id": source.get("library_id"),
                    "error": f"Source entry {position} is missing required field(s): {', '.join(missing)}",
                }
            )
            continue

        result = await ingest_git_source(
            library_id=source["library_id"],
            name=source.get("name") or source["library_id"],
            description=source.get("description"),
            repo_url=source["repo_url"],
            branch=source.get("branch", "main"),
            include_paths=source.get("include_paths"),
            exclude_paths=source.get("exclude_paths"),
        )
        results.append(result)

    successful = sum(1 for r in results if isinstance(r, dict) and r.get("success"))
    return {
        "success": successful == len(results),
        "total_sources": len(results),
        "successful": successful,
        "failed": len(results) - successful,
        "results": results,
    }
|