diff --git a/backend/app/main.py b/backend/app/main.py index 7a66f6f..09f2ec0 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -1,8 +1,12 @@ """Context7 Docs API.""" import asyncio +import io +import os import shutil +import zipfile import yaml from pathlib import Path +from posixpath import normpath from typing import Optional from fastapi import FastAPI, File, Form, HTTPException, Query, Request, UploadFile @@ -41,7 +45,7 @@ class SyncSourcesRequest(BaseModel): override: bool = False -ALLOWED_EXTENSIONS = { +DOCUMENT_EXTENSIONS = { ".md", ".txt", ".py", @@ -54,6 +58,11 @@ ALLOWED_EXTENSIONS = { ".css", ".pdf", } +ALLOWED_EXTENSIONS = DOCUMENT_EXTENSIONS | {".zip"} +MAX_FILE_UPLOAD_BYTES = 5 * 1024 * 1024 +MAX_ZIP_UPLOAD_BYTES = 100 * 1024 * 1024 +MAX_ZIP_EXTRACTED_BYTES = 250 * 1024 * 1024 +MAX_ZIP_FILES = 2000 @app.middleware("http") @@ -110,6 +119,69 @@ def safe_upload_filename(filename: str) -> str: return f"{stem}{ext}" +def safe_archive_member_path(member_name: str) -> Optional[Path]: + normalized = normpath(member_name.replace("\\", "/")).lstrip("/") + if not normalized or normalized == ".": + return None + + path = Path(normalized) + if path.is_absolute() or ".." in path.parts or path.suffix.lower() not in DOCUMENT_EXTENSIONS: + return None + + safe_parts = [] + for part in path.parts: + cleaned = "".join(c for c in part if c.isalnum() or c in "-_ .").strip() + if not cleaned: + return None + safe_parts.append(cleaned) + return Path(*safe_parts) + + +def assert_inside_directory(root: Path, target: Path) -> None: + root_resolved = root.resolve() + target_resolved = target.resolve() + if os.path.commonpath([str(root_resolved), str(target_resolved)]) != str(root_resolved): + raise HTTPException(status_code=400, detail="Archive contains unsafe paths") + + +def extract_zip_upload(contents: bytes, lib_dir: Path) -> dict: + if len(contents) > MAX_ZIP_UPLOAD_BYTES: + raise HTTPException(status_code=400, detail="ZIP too large (max 100MB)") + + extracted = [] + skipped = [] + total_uncompressed = 0 + + try: + archive = zipfile.ZipFile(io.BytesIO(contents)) + except zipfile.BadZipFile: + raise HTTPException(status_code=400, detail="Invalid ZIP archive") + + with archive: + files = [info for info in archive.infolist() if not info.is_dir()] + if len(files) > MAX_ZIP_FILES: + raise HTTPException(status_code=400, detail=f"ZIP contains too many files (max {MAX_ZIP_FILES})") + + for info in files: + relative_path = safe_archive_member_path(info.filename) + if relative_path is None: + skipped.append(info.filename) + continue + + total_uncompressed += info.file_size + if total_uncompressed > MAX_ZIP_EXTRACTED_BYTES: + raise HTTPException(status_code=400, detail="ZIP extracted content too large (max 250MB)") + + target = lib_dir / relative_path + assert_inside_directory(lib_dir, target) + target.parent.mkdir(parents=True, exist_ok=True) + with archive.open(info) as src: + target.write_bytes(src.read()) + extracted.append(str(target.relative_to(lib_dir))) + + return {"extracted": extracted, "skipped": skipped, "size_bytes": len(contents)} + + def docs_root() -> Path: return Path(settings.docs_path) @@ -239,7 +311,18 @@ async def api_upload(library_id: str, file: UploadFile = File(...)): lib_dir.mkdir(parents=True, exist_ok=True) contents = await file.read() - if len(contents) > 5 * 1024 * 1024: + if safe_name.lower().endswith(".zip"): + result = extract_zip_upload(contents, lib_dir) + upsert_library(library_id, library_id, None, library_id) + return { + "success": True, + "library_id": library_id, + "filename": safe_name, + "archive": True, + **result, + } + + if len(contents) > MAX_FILE_UPLOAD_BYTES: raise HTTPException(status_code=400, detail="File too large (max 5MB)") target = lib_dir / safe_name diff --git a/tests/test_upload_zip.py b/tests/test_upload_zip.py new file mode 100644 index 0000000..32d5380 --- /dev/null +++ b/tests/test_upload_zip.py @@ -0,0 +1,44 @@ +import io +import zipfile + +from backend.app.main import extract_zip_upload + + +def make_zip(entries): + payload = io.BytesIO() + with zipfile.ZipFile(payload, "w") as archive: + for name, content in entries.items(): + archive.writestr(name, content) + return payload.getvalue() + + +def test_extract_zip_upload_preserves_folders(tmp_path): + contents = make_zip( + { + "docs/index.md": "# Hello", + "docs/nested/example.py": "print('ok')", + } + ) + + result = extract_zip_upload(contents, tmp_path) + + assert sorted(result["extracted"]) == ["docs/index.md", "docs/nested/example.py"] + assert (tmp_path / "docs" / "index.md").read_text() == "# Hello" + assert (tmp_path / "docs" / "nested" / "example.py").read_text() == "print('ok')" + + +def test_extract_zip_upload_skips_unsafe_and_unsupported_members(tmp_path): + contents = make_zip( + { + "../escape.md": "bad", + "docs/image.png": "unsupported", + "docs/readme.md": "good", + } + ) + + result = extract_zip_upload(contents, tmp_path) + + assert result["extracted"] == ["docs/readme.md"] + assert "../escape.md" in result["skipped"] + assert "docs/image.png" in result["skipped"] + assert not (tmp_path.parent / "escape.md").exists() diff --git a/webui/app/templates/upload.html b/webui/app/templates/upload.html index 85d1cfb..904b353 100644 --- a/webui/app/templates/upload.html +++ b/webui/app/templates/upload.html @@ -17,7 +17,7 @@ - +
@@ -31,7 +31,7 @@ -

Allowed: .md, .txt, .py, .js, .ts, .json, .yaml, .yml, .html, .css, .pdf (max 5MB each)

+

Allowed: .md, .txt, .py, .js, .ts, .json, .yaml, .yml, .html, .css, .pdf, .zip (files max 5MB, ZIPs max 100MB)