"""Context7 Docs API."""

import asyncio
import hmac
import io
import os
import shutil
import zipfile
from pathlib import Path
from posixpath import normpath
from typing import Optional

import yaml
from fastapi import FastAPI, File, Form, HTTPException, Query, Request, UploadFile
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field

from .config import settings
from .db import (
    clear_library_documents,
    delete_library,
    init_db,
    list_libraries,
    search_libraries,
    upsert_library,
)
from .git_source import ingest_git_source
from .ingest import ingest_all, ingest_library
from .search import get_library_docs, resolve_library_id, search_docs
from .vector_store import delete_library_vectors, ensure_collection, get_client, get_collection_name

app = FastAPI(
    title="Context7 Docs API",
    description="Document ingestion and semantic search API for local-context7",
    version="1.0.0",
)


class SearchRequest(BaseModel):
    """Request body for ``POST /search``."""

    query: str = Field(..., min_length=1)
    library_id: Optional[str] = None
    limit: int = Field(10, ge=1, le=50)


class SyncSourcesRequest(BaseModel):
    """Optional request body for ``POST /sources/sync``."""

    override: bool = False


class GitSourceRequest(BaseModel):
    """Request body for ``POST /api/v1/sources`` describing one git source."""

    library_id: str = Field(..., min_length=1)
    repo_url: str = Field(..., min_length=1)
    name: Optional[str] = None
    description: Optional[str] = None
    branch: str = "main"
    include_paths: Optional[list[str]] = None
    exclude_paths: Optional[list[str]] = None


# Extensions accepted for individual documents (and for members of a ZIP upload).
DOCUMENT_EXTENSIONS = {
    ".md",
    ".txt",
    ".py",
    ".js",
    ".ts",
    ".json",
    ".yaml",
    ".yml",
    ".html",
    ".css",
    ".pdf",
}
# Extensions accepted by the upload endpoint itself.
ALLOWED_EXTENSIONS = DOCUMENT_EXTENSIONS | {".zip"}
MAX_FILE_UPLOAD_BYTES = 5 * 1024 * 1024
MAX_ZIP_UPLOAD_BYTES = 100 * 1024 * 1024
MAX_ZIP_EXTRACTED_BYTES = 250 * 1024 * 1024
MAX_ZIP_FILES = 2000


def _api_key_matches(provided: Optional[str], expected: Optional[str]) -> bool:
    """Return True when *provided* equals *expected*, comparing in constant time."""
    if not isinstance(provided, str) or not isinstance(expected, str):
        # Missing header or unset key: nothing secret to protect, use plain equality.
        return provided == expected
    # Compare as bytes: compare_digest rejects non-ASCII ``str`` arguments.
    return hmac.compare_digest(provided.encode("utf-8"), expected.encode("utf-8"))


@app.middleware("http")
async def auth_middleware(request: Request, call_next):
    """Require a valid X-API-Key header when authentication is enabled.

    Requests pass through unchecked when ``settings.is_auth_enabled`` is false,
    or when they are GET requests whose path starts with one of the public
    prefixes. Every other request (including GETs to non-public paths) must
    carry an ``X-API-Key`` header equal to ``settings.api_key_docs_api``;
    otherwise a 401 JSON response is returned.
    """
    if not settings.is_auth_enabled:
        return await call_next(request)

    public_prefixes = ("/health", "/libraries", "/docs/")
    if request.method == "GET" and request.url.path.startswith(public_prefixes):
        return await call_next(request)

    if not _api_key_matches(request.headers.get("X-API-Key"), settings.api_key_docs_api):
        return JSONResponse(status_code=401, content={"detail": "Invalid or missing API key"})

    return await call_next(request)


@app.on_event("startup")
async def startup() -> None:
    """Initialise the database and the vector collection before serving.

    Raises:
        RuntimeError: if ``init_db`` reports failure, or if ``ensure_collection``
            has not succeeded after 20 attempts spaced one second apart.
    """
    init_result = init_db()
    if not init_result.get("success"):
        raise RuntimeError(f"Failed to initialize SQLite database: {init_result.get('error')}")

    last_error = None
    # Retry so the API can start while the vector store is still coming up.
    for _ in range(20):
        collection_result = await ensure_collection()
        if collection_result.get("success"):
            return
        last_error = collection_result.get("error")
        await asyncio.sleep(1)
    raise RuntimeError(f"Failed to initialize Qdrant collection: {last_error}")


def safe_library_id(library_id: str) -> str:
    """Normalize user-provided library IDs to a single path segment.

    Raises:
        HTTPException: 400 when the ID is empty, is a dot segment, or contains
            ``..``, ``/`` or a backslash.
    """
    base = Path(library_id).name.strip()
    if not base or base in {".", ".."} or ".." in library_id or "/" in library_id or "\\" in library_id:
        raise HTTPException(status_code=400, detail="Invalid library ID")
    return base


def safe_upload_filename(filename: str) -> str:
    """Return a sanitised ``stem + extension`` for an uploaded file name.

    The extension is lower-cased and must be in ``ALLOWED_EXTENSIONS``; the stem
    keeps only alphanumerics, ``-``, ``_`` and spaces.

    Raises:
        HTTPException: 400 when the extension is not allowed or no safe
            characters remain in the stem.
    """
    ext = Path(filename).suffix.lower()
    if ext not in ALLOWED_EXTENSIONS:
        raise HTTPException(
            status_code=400,
            detail=f"Unsafe extension: {ext}. Allowed extensions: {', '.join(sorted(ALLOWED_EXTENSIONS))}",
        )
    stem = "".join(c for c in Path(filename).stem if c.isalnum() or c in "-_ ").strip()
    if not stem:
        raise HTTPException(status_code=400, detail="Filename contains only unsafe characters")
    return f"{stem}{ext}"


def safe_archive_member_path(member_name: str) -> Optional[Path]:
    """Map a ZIP member name to a safe relative path, or None to skip it.

    A member is skipped when its normalised name is empty, absolute, contains a
    ``..`` segment, has an extension outside ``DOCUMENT_EXTENSIONS``, or when any
    segment is empty or a dot segment after sanitising.
    """
    normalized = normpath(member_name.replace("\\", "/")).lstrip("/")
    if not normalized or normalized == ".":
        return None

    path = Path(normalized)
    if path.is_absolute() or ".." in path.parts or path.suffix.lower() not in DOCUMENT_EXTENSIONS:
        return None

    safe_parts = []
    for part in path.parts:
        cleaned = "".join(c for c in part if c.isalnum() or c in "-_ .").strip()
        # Sanitising can turn e.g. "..;" or ".. " into "..": re-check after cleaning
        # so a crafted segment cannot reintroduce path traversal.
        if not cleaned or cleaned in {".", ".."}:
            return None
        safe_parts.append(cleaned)

    result = Path(*safe_parts)
    # Stripping characters may also have changed the effective extension.
    if result.suffix.lower() not in DOCUMENT_EXTENSIONS:
        return None
    return result


def assert_inside_directory(root: Path, target: Path) -> None:
    """Raise a 400 ``HTTPException`` unless *target* resolves inside *root*."""
    root_resolved = root.resolve()
    target_resolved = target.resolve()
    if not target_resolved.is_relative_to(root_resolved):
        raise HTTPException(status_code=400, detail="Archive contains unsafe paths")


def extract_zip_upload(contents: bytes, lib_dir: Path) -> dict:
    """Extract the document files of an uploaded ZIP archive into *lib_dir*.

    All members are validated (count, sanitised path, cumulative declared size,
    containment in *lib_dir*) before anything is written, so a rejected archive
    leaves no partially extracted files behind.

    Returns:
        A dict with ``extracted`` (paths relative to *lib_dir*), ``skipped``
        (original member names that were ignored) and ``size_bytes`` (size of
        the uploaded archive).

    Raises:
        HTTPException: 400 when the archive is too large, invalid, has too many
            files, expands beyond the extraction limit, or targets unsafe paths.
    """
    if len(contents) > MAX_ZIP_UPLOAD_BYTES:
        raise HTTPException(status_code=400, detail="ZIP too large (max 100MB)")

    try:
        archive = zipfile.ZipFile(io.BytesIO(contents))
    except zipfile.BadZipFile as exc:
        raise HTTPException(status_code=400, detail="Invalid ZIP archive") from exc

    extracted = []
    skipped = []
    with archive:
        files = [info for info in archive.infolist() if not info.is_dir()]
        if len(files) > MAX_ZIP_FILES:
            raise HTTPException(status_code=400, detail=f"ZIP contains too many files (max {MAX_ZIP_FILES})")

        # Pass 1: validate every member and decide where it goes.
        planned = []
        total_uncompressed = 0
        for info in files:
            relative_path = safe_archive_member_path(info.filename)
            if relative_path is None:
                skipped.append(info.filename)
                continue
            total_uncompressed += info.file_size
            if total_uncompressed > MAX_ZIP_EXTRACTED_BYTES:
                raise HTTPException(status_code=400, detail="ZIP extracted content too large (max 250MB)")
            target = lib_dir / relative_path
            assert_inside_directory(lib_dir, target)
            planned.append((info, target))

        # Pass 2: write the accepted members, streaming instead of buffering each file.
        for info, target in planned:
            target.parent.mkdir(parents=True, exist_ok=True)
            with archive.open(info) as src, target.open("wb") as dst:
                shutil.copyfileobj(src, dst)
            extracted.append(str(target.relative_to(lib_dir)))

    return {"extracted": extracted, "skipped": skipped, "size_bytes": len(contents)}


def docs_root() -> Path:
    """Return the documents root directory taken from ``settings.docs_path``."""
    return Path(settings.docs_path)


def sources_config_path() -> Path:
    """Return the path of ``docs_sources.yaml``, two directories above this file's directory."""
    return Path(__file__).resolve().parents[2] / "docs_sources.yaml"


def clean_source_paths(paths: Optional[list[str]]) -> list[str]:
    """Return *paths* stripped of whitespace and slashes, dropping unsafe entries.

    Empty entries, ``.``, absolute paths and paths containing a ``..`` segment
    are dropped. ``None`` yields an empty list.
    """
    cleaned = []
    for raw_path in paths or []:
        path = raw_path.strip().strip("/")
        if not path or path == "." or ".." in Path(path).parts or Path(path).is_absolute():
            continue
        cleaned.append(path)
    return cleaned


def load_sources_config() -> dict:
    """Load the sources configuration, always returning a dict with a ``sources`` list.

    A missing file, an empty file or an unexpected top-level type yields
    ``{"sources": []}``; a top-level list is wrapped as the ``sources`` value.
    """
    path = sources_config_path()
    if not path.exists():
        return {"sources": []}

    with path.open(encoding="utf-8") as f:
        data = yaml.safe_load(f) or {}

    if isinstance(data, list):
        return {"sources": data}
    if not isinstance(data, dict):
        return {"sources": []}

    sources = data.get("sources", [])
    data["sources"] = sources if isinstance(sources, list) else []
    return data


def save_sources_config(data: dict) -> None:
    """Write *data* to the sources configuration file as YAML, preserving key order."""
    path = sources_config_path()
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", encoding="utf-8") as f:
        yaml.safe_dump(data, f, sort_keys=False)


@app.get("/health")
async def health_check():
    """Return a static liveness payload."""
    return {"status": "ok", "service": "docs-api"}


@app.get("/collections")
async def collections():
    """Report the vector count of the configured collection.

    Any error while querying the vector store is reported as a ``warning``
    string alongside an empty ``collections`` mapping instead of failing.
    """
    try:
        client = get_client()
        info = client.get_collection(get_collection_name())
        vectors = getattr(info, "vectors_count", None) or getattr(info, "points_count", 0) or 0
        return {"collections": {get_collection_name(): {"vectors": vectors}}}
    except Exception as e:
        # Deliberate best effort: this endpoint is informational only.
        return {"collections": {}, "warning": str(e)}


@app.get("/libraries")
async def list_libraries_api():
    """List all libraries; respond 500 when ``list_libraries`` returns a failure dict."""
    libs = list_libraries()
    if isinstance(libs, dict) and not libs.get("success", True):
        raise HTTPException(status_code=500, detail=libs.get("error", "Failed to list libraries"))
    return {"libraries": libs, "count": len(libs)}


@app.get("/libraries/search")
async def search_libraries_api(q: str = Query(..., min_length=1)):
    """Return the libraries that ``resolve_library_id`` matches for *q*."""
    matches = resolve_library_id(q)
    return {"matches": matches, "count": len(matches)}


@app.post("/search")
async def search_docs_api(payload: SearchRequest):
    """Run a document search and echo the query parameters with the results."""
    results = search_docs(payload.query, library_id=payload.library_id, limit=payload.limit)
    return {
        "query": payload.query,
        "library_id": payload.library_id,
        "results": results,
        "count": len(results),
    }


@app.get("/docs/{library_id}")
@app.get("/libraries/{library_id}/docs")
async def get_library_docs_api(
    library_id: str,
    topic: Optional[str] = Query(None),
    tokens: int = Query(8000, ge=1),
):
    """Return documentation content for a library, optionally narrowed by *topic*."""
    docs = get_library_docs(library_id=library_id, topic=topic, token_limit=tokens)
    return {"library_id": library_id, "content": docs}


@app.post("/ingest/all")
async def ingest_all_api():
    """Trigger ingestion via ``ingest_all`` and return its result."""
    return await ingest_all()


@app.post("/ingest/{library_id}")
async def ingest_library_api(library_id: str):
    """Ingest one library, using the validated ID as both its name and source path."""
    library_id = safe_library_id(library_id)
    source_path = library_id
    return await ingest_library(library_id=library_id, name=library_id, source_path=source_path)


@app.post("/api/v1/libraries/{library_id}")
async def api_create_library(
    library_id: str,
    name: Optional[str] = Form(None),
    description: Optional[str] = Form(None),
):
    """Create the library directory and upsert its database record.

    Raises:
        HTTPException: 400 for an invalid ID, 500 when the upsert fails.
    """
    library_id = safe_library_id(library_id)
    lib_dir = docs_root() / library_id
    lib_dir.mkdir(parents=True, exist_ok=True)

    result = upsert_library(library_id, name or library_id, description, library_id)
    if not result.get("success"):
        raise HTTPException(status_code=500, detail=result.get("error", "Failed to create library"))

    return {
        "success": True,
        "created": not result.get("exists", False),
        "library_id": library_id,
        "name": name or library_id,
        "description": description,
        "path": str(lib_dir),
    }


@app.delete("/api/v1/libraries/{library_id}")
async def api_delete_library(library_id: str):
    """Delete a library's files, documents, vectors and database record.

    Raises:
        HTTPException: 400 for an invalid ID, 500 when any of the three
            deletion steps reports failure (messages joined with ``; ``).
    """
    library_id = safe_library_id(library_id)
    lib_dir = docs_root() / library_id

    deleted_files = 0
    if lib_dir.exists():
        # Count before removing so the response can report how many files went away.
        deleted_files = sum(1 for path in lib_dir.rglob("*") if path.is_file())
        shutil.rmtree(lib_dir)

    docs_result = clear_library_documents(library_id)
    vectors_result = await delete_library_vectors(library_id)
    library_result = delete_library(library_id)

    # A failed result without an "error" value must not break the join below.
    failures = [
        str(r.get("error") or "Unknown error")
        for r in (docs_result, vectors_result, library_result)
        if isinstance(r, dict) and not r.get("success", True)
    ]
    if failures:
        raise HTTPException(status_code=500, detail="; ".join(failures))

    return {"success": True, "library_id": library_id, "deleted_files": deleted_files}


@app.post("/api/v1/upload/{library_id}")
async def api_upload(library_id: str, file: UploadFile = File(...)):
    """Store an uploaded document, or extract an uploaded ZIP, into a library.

    Raises:
        HTTPException: 400 for an invalid library ID or file name, for a file
            above the size limit, or for a rejected ZIP archive.
    """
    library_id = safe_library_id(library_id)
    safe_name = safe_upload_filename(file.filename or "upload.txt")
    lib_dir = docs_root() / library_id

    is_archive = safe_name.lower().endswith(".zip")
    size_limit = MAX_ZIP_UPLOAD_BYTES if is_archive else MAX_FILE_UPLOAD_BYTES
    # Read at most one byte past the limit: enough to detect an oversized upload
    # without buffering an arbitrarily large body in memory.
    contents = await file.read(size_limit + 1)

    if not is_archive and len(contents) > MAX_FILE_UPLOAD_BYTES:
        raise HTTPException(status_code=400, detail="File too large (max 5MB)")

    lib_dir.mkdir(parents=True, exist_ok=True)

    if is_archive:
        # extract_zip_upload enforces the archive size limit itself.
        result = extract_zip_upload(contents, lib_dir)
        upsert_library(library_id, library_id, None, library_id)
        return {
            "success": True,
            "library_id": library_id,
            "filename": safe_name,
            "archive": True,
            **result,
        }

    target = lib_dir / safe_name
    target.write_bytes(contents)
    upsert_library(library_id, library_id, None, library_id)
    return {
        "success": True,
        "library_id": library_id,
        "filename": safe_name,
        "path": str(target.relative_to(docs_root())),
        "size_bytes": len(contents),
    }


@app.get("/api/v1/sources")
@app.get("/sources/config")
async def api_list_sources():
    """Return the configured git sources from the sources configuration file."""
    data = load_sources_config()
    sources = data["sources"]
    return {"success": True, "sources": sources, "count": len(sources)}


@app.post("/api/v1/sources")
async def api_add_source(source: GitSourceRequest):
    """Add a git source to the configuration, or replace the one with the same library ID.

    Raises:
        HTTPException: 400 for an invalid library ID or a blank repository URL.
    """
    library_id = safe_library_id(source.library_id)

    repo_url = source.repo_url.strip()
    if not repo_url:
        # min_length=1 still admits whitespace-only values.
        raise HTTPException(status_code=400, detail="Repository URL must not be empty")

    branch = source.branch.strip() or "main"
    include_paths = clean_source_paths(source.include_paths)
    exclude_paths = clean_source_paths(source.exclude_paths) or ["node_modules", ".git"]

    source_entry = {
        "library_id": library_id,
        # Fall back to the library ID when the name is missing or blank.
        "name": (source.name or "").strip() or library_id,
        "description": (source.description or "").strip(),
        "repo_url": repo_url,
        "branch": branch,
        "include_paths": include_paths or ["docs"],
        "exclude_paths": exclude_paths,
    }

    data = load_sources_config()
    sources = data["sources"]
    # The YAML file is hand-editable, so tolerate entries that are not mappings.
    existing_index = next(
        (
            index
            for index, item in enumerate(sources)
            if isinstance(item, dict) and item.get("library_id") == library_id
        ),
        None,
    )

    if existing_index is None:
        sources.append(source_entry)
        created = True
    else:
        sources[existing_index] = source_entry
        created = False

    save_sources_config(data)
    return {"success": True, "created": created, "source": source_entry}


@app.post("/sources/sync")
async def sync_sources_api(payload: Optional[SyncSourcesRequest] = None):
    """Ingest every configured git source and summarise the outcomes.

    Entries that are not mappings, or that lack ``library_id`` or ``repo_url``,
    are reported as failed results instead of aborting the whole sync.
    """
    # NOTE(review): payload.override is accepted but not forwarded anywhere;
    # confirm whether ingest_git_source supports an override flag.
    source_data = await api_list_sources()
    sources = source_data["sources"]

    results = []
    for source in sources:
        if not isinstance(source, dict) or not source.get("library_id") or not source.get("repo_url"):
            results.append(
                {
                    "success": False,
                    "error": "Invalid source entry: library_id and repo_url are required",
                    "source": source,
                }
            )
            continue

        result = await ingest_git_source(
            library_id=source["library_id"],
            name=source.get("name") or source["library_id"],
            description=source.get("description"),
            repo_url=source["repo_url"],
            branch=source.get("branch", "main"),
            include_paths=source.get("include_paths"),
            exclude_paths=source.get("exclude_paths"),
        )
        results.append(result)

    successful = sum(1 for r in results if r.get("success"))
    return {
        "success": successful == len(results),
        "total_sources": len(results),
        "successful": successful,
        "failed": len(results) - successful,
        "results": results,
    }