376 lines
12 KiB
Python
376 lines
12 KiB
Python
# Vector Store Operations for Qdrant
|
|
import asyncio
import logging
import uuid
from typing import List, Dict, Any, Optional

try:
    from qdrant_client import QdrantClient
    from qdrant_client.models import Distance, VectorParams, PointStruct, Filter, FieldCondition, MatchValue
except ImportError:
    QdrantClient = None
    Distance = VectorParams = PointStruct = Filter = FieldCondition = MatchValue = None
|
|
|
|
|
|
# Collection name: read from the project settings when the config package can
# be imported and exposes it; any error falls back to the built-in default so
# this module remains importable on its own.
try:
    from .config import settings
    _collection_name = settings.collection_name
except Exception:
    _collection_name = "local_context7_docs"

# Process-wide Qdrant client; stays None until get_client() first builds it.
_client: Optional[Any] = None
|
|
|
|
|
|
def get_client() -> Any:
    """Return the shared Qdrant client, building it on the first call.

    Connection details are resolved only when the client is first created:
    ``QDRANT_URL`` from the environment (after an optional ``.env`` load via
    python-dotenv, if installed) wins; otherwise ``vector_store_host`` and
    ``vector_store_port`` from ``.config.settings`` are used.

    Raises:
        RuntimeError: if qdrant-client is not installed.
    """
    global _client

    # Fast path: reuse the already-built client.
    if _client is not None:
        return _client

    if QdrantClient is None:
        raise RuntimeError("qdrant-client is not installed")

    # python-dotenv is optional; without it only the real environment is read.
    try:
        from dotenv import load_dotenv
        load_dotenv()
    except ImportError:
        pass

    import os

    url = os.getenv("QDRANT_URL")
    if not url:
        from .config import settings
        _client = QdrantClient(
            host=settings.vector_store_host,
            port=settings.vector_store_port,
        )
    else:
        _client = QdrantClient(url=url)

    return _client
|
|
|
|
|
|
def get_collection_name() -> str:
    """Return the Qdrant collection name this module reads and writes.

    The name is fixed at import time: project settings if available,
    otherwise the built-in default.
    """
    name = _collection_name
    return name
|
|
|
|
|
|
def get_embedding_size() -> int:
    """Return the embedding dimension reported by the embeddings module.

    Falls back to 384 when ``.embeddings`` cannot be imported or its
    ``get_embedding_size`` raises RuntimeError (e.g. the embeddings module
    is not loaded yet).
    """
    fallback_dim = 384
    try:
        from .embeddings import get_embedding_size as model_dimension
        dim = model_dimension()
    except (ImportError, RuntimeError):
        dim = fallback_dim
    return dim
|
|
|
|
|
|
def create_collection(client: Any, collection_name: str, size: int, distance: Any) -> None:
    """Create a Qdrant collection across qdrant-client keyword changes.

    qdrant-client releases differ in the keyword used for the vector
    configuration (``vectors_config`` vs ``vectors``) and in whether
    ``wait`` is accepted, so each combination is tried in order until one
    is accepted.

    Args:
        client: A qdrant-client ``QdrantClient`` (or compatible object).
        collection_name: Name of the collection to create.
        size: Vector dimension.
        distance: Distance metric (e.g. ``Distance.COSINE``).

    Raises:
        RuntimeError: if every keyword variant was rejected; the last
            rejection is chained as ``__cause__`` for debugging.
        AssertionError / TypeError / ValueError: re-raised immediately when
            the message does not look like a keyword/signature mismatch.
    """
    vector_params = VectorParams(size=size, distance=distance)
    variants = (
        {"vectors_config": vector_params, "wait": True},
        {"vectors_config": vector_params},
        {"vectors": vector_params, "wait": True},
        {"vectors": vector_params},
    )

    last_error: Optional[Exception] = None
    for kwargs in variants:
        try:
            client.create_collection(collection_name=collection_name, **kwargs)
            return
        except (AssertionError, TypeError, ValueError) as exc:
            # Only retry errors that look like the client rejecting a keyword;
            # anything else is a genuine failure and propagates unchanged.
            if not any(key in str(exc) for key in ("Unknown arguments", "vectors", "wait")):
                raise
            last_error = exc

    # Chain the last rejection so the real cause is not lost in tracebacks.
    raise RuntimeError(f"Could not create Qdrant collection: {last_error}") from last_error
|
|
|
|
|
|
async def ensure_collection(vector_size: Optional[int] = None) -> Dict[str, Any]:
    """
    Ensure the Qdrant collection exists with the expected vector size.

    A missing collection is created. An existing collection whose configured
    vector size differs from the expected one is deleted and recreated, so
    any points it held are lost. If the current size cannot be read (for
    example the config has no single ``vectors.size``), the existing
    collection is left untouched.

    Args:
        vector_size: Override embedding dimension (uses get_embedding_size() if not provided)

    Returns:
        On success: ``success``, ``collection``, ``vector_size``, ``created``
        and, after a recreate, ``resized``. On failure (including a failed
        delete/recreate): ``{"success": False, "error": message}``.
    """
    try:
        if QdrantClient is None:
            return {"success": False, "error": "qdrant-client is not installed"}

        client = get_client()
        size = vector_size or get_embedding_size()
        distance = Distance.COSINE

        # Check if collection exists
        try:
            collections = client.get_collections().collections
            collection_exists = any(c.name == _collection_name for c in collections)
        except Exception:
            # Listing failed; attempt creation and let that surface any real error.
            collection_exists = False

        if not collection_exists:
            create_collection(client, _collection_name, size, distance)
            return {
                "success": True,
                "collection": _collection_name,
                "vector_size": size,
                "created": True
            }

        # Only reading the current size is best-effort. The destructive
        # delete/recreate below must NOT be inside this try: swallowing a
        # failed recreate would report success with the collection gone.
        try:
            collection_info = client.get_collection(_collection_name)
            current_size = collection_info.config.params.vectors.size
        except Exception:
            current_size = size  # unknown size: keep the existing collection

        if current_size != size:
            # Collection exists with wrong size - delete and recreate
            client.delete_collection(_collection_name)
            create_collection(client, _collection_name, size, distance)
            return {
                "success": True,
                "collection": _collection_name,
                "vector_size": size,
                "created": False,
                "resized": True
            }

        return {
            "success": True,
            "collection": _collection_name,
            "vector_size": size,
            "created": False
        }

    except Exception as e:
        return {"success": False, "error": str(e)}
|
|
|
|
|
|
async def upsert_chunks(chunks: List[Dict[str, Any]]) -> Dict[str, Any]:
    """
    Insert or update document chunks in the vector store.

    Each chunk becomes one Qdrant point. Its ID is a UUIDv5 (URL namespace)
    of ``"<library_id>:<id>"``, so the same chunk always maps to the same
    point ID. ``id``, ``library_id`` and ``embedding`` are required; missing
    ``path``/``title``/``content`` default to ``""`` and ``chunk_index`` to 0.

    Args:
        chunks: Chunk dicts shaped like::

            {"id": "...", "library_id": "...", "path": "...", "title": "...",
             "chunk_index": 0, "content": "...", "embedding": [...]}

    Returns:
        ``{"success": True, "points_added": N}`` on success, otherwise
        ``{"success": False, "error": message}``.
    """
    try:
        if QdrantClient is None:
            return {"success": False, "error": "qdrant-client is not installed"}
        if not chunks:
            return {"success": True, "points_added": 0}

        client = get_client()

        def as_point(chunk: Dict[str, Any]) -> Any:
            # Deterministic ID derived from the (library, chunk) pair.
            stable_id = uuid.uuid5(uuid.NAMESPACE_URL, f"{chunk['library_id']}:{chunk['id']}")
            vector = chunk["embedding"]
            payload = {
                "id": chunk["id"],
                "library_id": chunk["library_id"],
                "path": chunk.get("path", ""),
                "title": chunk.get("title", ""),
                "chunk_index": chunk.get("chunk_index", 0),
                "content": chunk.get("content", ""),
            }
            return PointStruct(id=str(stable_id), vector=vector, payload=payload)

        batch = [as_point(chunk) for chunk in chunks]
        client.upsert(_collection_name, points=batch)
        return {"success": True, "points_added": len(batch)}

    except Exception as err:
        return {"success": False, "error": str(err)}
|
|
|
|
|
|
async def search_vectors(
    query_vector: List[float],
    library_id: Optional[str] = None,
    limit: int = 10
) -> List[Dict[str, Any]]:
    """
    Search for semantically similar vectors.

    Hits with a non-positive score or an empty payload are dropped. This is
    best-effort: any error yields ``[]``, but the exception is now logged so
    a failed search (e.g. unreachable server, missing collection) can be told
    apart from "no matches".

    Args:
        query_vector: The embedding vector to search against
        library_id: Optional filter by library ID
        limit: Maximum results to return

    Returns:
        List of result dicts with format:
            {
                "id": "...",
                "score": 0.123,
                "library_id": "...",
                "path": "...",
                "title": "...",
                "chunk_index": 0
            }
    """
    try:
        if QdrantClient is None:
            return []

        client = get_client()

        # Build filter if library_id is specified
        search_filter = None
        if library_id:
            search_filter = Filter(
                must=[
                    FieldCondition(
                        key="library_id",
                        match=MatchValue(value=library_id),
                    )
                ]
            )

        # Perform vector search
        results = client.search(
            collection_name=_collection_name,
            query_vector=query_vector,
            limit=limit,
            search_filter=search_filter
        )

        # Format results
        formatted_results = []
        for result in results:
            if result.score > 0 and result.payload:
                formatted_results.append({
                    "id": result.payload["id"],
                    "score": float(result.score),
                    "library_id": result.payload["library_id"],
                    "path": result.payload.get("path", ""),
                    "title": result.payload.get("title", ""),
                    "chunk_index": result.payload.get("chunk_index", 0)
                })

        return formatted_results

    except Exception:
        # Keep the best-effort contract (callers get an empty list) but do not
        # hide the failure: record it with its traceback.
        logging.getLogger(__name__).exception(
            "Vector search failed in collection %r", _collection_name
        )
        return []
|
|
|
|
|
|
async def delete_library_vectors(library_id: str) -> Dict[str, Any]:
    """
    Delete all vectors for a given library.

    Matching points are scrolled in pages of 100 and deleted page by page,
    following the ``next_page_offset`` returned by ``scroll``. A collection
    that does not exist yet is treated as "nothing to delete" (success).
    Any error while listing, scrolling or deleting is reported as a failure
    instead of being swallowed, because a partial deletion must not look like
    success to callers.

    Args:
        library_id: The library ID to delete vectors for

    Returns:
        ``{"success": True, "library_id": library_id}`` on success (plus a
        ``skipped`` reason when qdrant-client is missing), otherwise
        ``{"success": False, "error": message}``.
    """
    try:
        if QdrantClient is None:
            return {"success": True, "library_id": library_id, "skipped": "qdrant-client is not installed"}

        client = get_client()

        # Nothing to delete if the collection was never created (same check
        # as ensure_collection).
        existing = {c.name for c in client.get_collections().collections}
        if _collection_name not in existing:
            return {"success": True, "library_id": library_id}

        # Use filter to delete only vectors matching the library_id
        filter_condition = Filter(
            must=[
                FieldCondition(
                    key="library_id",
                    match=MatchValue(value=library_id),
                )
            ]
        )

        batch_size = 100
        offset = None
        while True:
            points, next_offset = client.scroll(
                collection_name=_collection_name,
                scroll_filter=filter_condition,
                limit=batch_size,
                offset=offset,
                with_payload=False,  # only IDs are needed
                with_vectors=False
            )
            if not points:
                break

            client.delete(
                collection_name=_collection_name,
                points_selector=[p.id for p in points]
            )

            # next_offset is None once the last page has been returned.
            if next_offset is None:
                break
            offset = next_offset

        return {
            "success": True,
            "library_id": library_id
        }

    except Exception as e:
        return {"success": False, "error": str(e)}
|
|
|
|
|
|
if __name__ == "__main__":
    # Manual smoke test of the vector store module (needs a reachable Qdrant).
    print("Testing vector store module...\n")

    # Test ensure_collection
    print("1. Testing ensure_collection()...")
    result = asyncio.run(ensure_collection())
    print(f" Result: {result}\n")

    # Test search with a dummy vector (will return empty since no vectors
    # exist yet). Size it to the collection's dimension: a hard-coded length
    # fails against a non-384 collection, and search_vectors returns [] on
    # error, hiding the mismatch.
    print("2. Testing search_vectors() with dummy vector...")
    dim = result.get("vector_size") or get_embedding_size()
    dummy_vector = [0.1] * dim
    results = asyncio.run(search_vectors(dummy_vector, limit=5))
    print(f" Results count: {len(results)}\n")

    # Test delete_library_vectors (will succeed even if no vectors exist)
    print("3. Testing delete_library_vectors()...")
    result = asyncio.run(delete_library_vectors("test-library"))
    print(f" Result: {result}\n")

    print("✅ All tests completed!")
|