181 lines
5.4 KiB
Python
181 lines
5.4 KiB
Python
# Local Embedding Generation using FastEmbed
|
|
import asyncio
|
|
from typing import List
|
|
from functools import lru_cache
|
|
|
|
|
|
# Cached FastEmbed model instance shared by the whole module. It stays None
# until _load_model() successfully constructs the model.
_embedding_model = None

# Dimension of the vectors produced by BAAI/bge-small-en-v1.5 (the model that
# _load_model() loads); get_embedding_size() reports this without loading it.
_embedding_size: int = 384
|
|
|
|
|
|
def _load_model():
    """Lazy-load the FastEmbed model on first use and cache it in ``_embedding_model``.

    Returns:
        The cached FastEmbed ``TextEmbedding`` instance.

    Raises:
        ImportError: If the ``fastembed`` package cannot be imported.
        RuntimeError: If constructing the model fails. Disk-space failures are
            re-raised with guidance on freeing space or relocating the cache;
            other RuntimeErrors propagate unchanged.
    """
    global _embedding_model

    # Keep the try body to the import alone, so only a missing fastembed
    # package is reported as "not installed"; ImportErrors raised while the
    # model itself is being built propagate with their own message.
    try:
        from fastembed import TextEmbedding
    except ImportError as e:
        raise ImportError(
            "FastEmbed is not installed. Please install with:\n"
            " pip install fastembed\n\n"
            f"Import error details: {e}"
        ) from e

    if _embedding_model is None:
        print("Loading embedding model (this may take a few minutes on first run)...")

        # Use BAAI/bge-small-en-v1.5 - lightweight (~90MB), cached in .embed_cache
        try:
            _embedding_model = TextEmbedding(model_name="BAAI/bge-small-en-v1.5", cache_dir=".embed_cache")
        except RuntimeError as e:
            # Model download/installation failed. Compare case-insensitively:
            # the OS reports EDQUOT as "Disk quota exceeded" (capital D).
            # NOTE(review): disk-full errors raised as OSError by the download
            # layer are not intercepted here - confirm whether they should be.
            details = str(e).lower()
            if "no space left" in details or "disk quota exceeded" in details:
                raise RuntimeError(
                    "Failed to load embedding model due to disk space constraints.\n\n"
                    "Please free up space on your system (at least 500MB required).\n"
                    "Or specify a custom cache directory with available space:\n"
                    " from fastembed import TextEmbedding\n"
                    " model = TextEmbedding(model_name='...', cache_dir='/path/to/large/storage')\n\n"
                    f"Error: {e}"
                ) from e
            raise

        print("Embedding model loaded successfully.")

    return _embedding_model
|
|
|
|
|
|
def get_embedding_model():
    """
    Return the module-wide FastEmbed model, loading it on the first call.

    Once a load has succeeded, later calls return the same cached object
    without calling _load_model() again.

    Returns:
        FastEmbed TextEmbedding instance (lazy-loaded on first call)

    Raises:
        ImportError: If FastEmbed is not installed
        RuntimeError: If model download/load failed
    """
    global _embedding_model

    cached = _embedding_model
    if cached is not None:
        return cached

    _embedding_model = _load_model()
    return _embedding_model
|
|
|
|
|
|
def embed_text(text: str) -> List[float]:
    """
    Generate an embedding for a single text.

    Args:
        text: The text string to embed. Empty strings and non-``str`` values
            are not sent to the model; a zero vector of length
            ``get_embedding_size()`` is returned instead.

    Returns:
        List of floats representing the embedding vector

    Raises:
        ImportError: If FastEmbed is not installed
        RuntimeError: If model loading failed
    """
    if not text or not isinstance(text, str):
        return [0.0] * get_embedding_size()

    model = get_embedding_model()
    # TextEmbedding.embed() yields vectors lazily (it is a generator), so it
    # cannot be indexed with [0]; take the first (and only) yielded item.
    vector = next(iter(model.embed([text])))
    # numpy arrays expose tolist(); any other sequence is copied into a list.
    return vector.tolist() if hasattr(vector, "tolist") else list(vector)
|
|
|
|
|
|
def embed_texts(texts: List[str]) -> List[List[float]]:
|
|
"""
|
|
Generate embeddings for multiple texts.
|
|
|
|
Args:
|
|
texts: List of text strings to embed
|
|
|
|
Returns:
|
|
List of lists containing embedding vectors (one per input text)
|
|
|
|
Raises:
|
|
ImportError: If FastEmbed is not installed
|
|
RuntimeError: If model loading failed
|
|
"""
|
|
if not texts:
|
|
return []
|
|
|
|
model = get_embedding_model()
|
|
embeddings = model.embed(texts)
|
|
|
|
result = []
|
|
for emb in embeddings:
|
|
if hasattr(emb, 'tolist'):
|
|
result.append(emb.tolist())
|
|
else:
|
|
result.append(emb)
|
|
|
|
return result
|
|
|
|
|
|
def get_embedding_size() -> int:
    """
    Report the length of the vectors this module produces.

    Returns:
        The configured vector dimension (384 for bge-small-en-v1.5).

    Note:
        This reads the module-level ``_embedding_size`` constant rather than
        querying the loaded model, so calling it never triggers a model load.
    """
    dimension = _embedding_size
    return dimension
|
|
|
|
|
|
# Async wrapper for compatibility with existing code
async def generate_embeddings(chunks: List[str]) -> List[List[float]]:
    """
    Async wrapper around embed_texts for compatibility.

    embed_texts is synchronous and can be slow: the first call may download
    and load the model. It is therefore run in the event loop's default
    executor so awaiting this coroutine does not block other tasks.

    Args:
        chunks: List of text strings to embed

    Returns:
        List of embedding vectors (``[]`` for empty input)

    Raises:
        ImportError: If FastEmbed is not installed
        RuntimeError: If model loading failed
    """
    if not chunks:
        # Nothing to embed; skip the thread hand-off (embed_texts returns [] too).
        return []

    loop = asyncio.get_running_loop()
    # NOTE(review): concurrent first calls can now race on the lazy model load
    # in get_embedding_model(); add a lock there if double-loading matters.
    return await loop.run_in_executor(None, embed_texts, chunks)
|
|
|
|
|
|
if __name__ == "__main__":
    # Manual smoke test for this module. The single/batch checks load the
    # model (downloading it on first run); the empty-input checks do not.
    import sys

    print("Testing embeddings module...\n")

    # Count caught failures so the final summary and the exit status reflect
    # them instead of always reporting success.
    failures = 0

    # Test get_embedding_size
    size = get_embedding_size()
    print(f"Embedding dimension: {size}")

    # Test single text embedding
    test_text = "Hello, world! This is a test of the embedding generation."
    try:
        emb = embed_text(test_text)
        print(f"\nSingle text embedding shape: ({len(emb)},)")
        print(f"First 5 values: {emb[:5]}")
        print("✓ Single embedding works")
    except Exception as e:
        failures += 1
        print(f"✗ Single embedding failed: {e}")

    # Test batch embedding
    test_texts = [
        "The quick brown fox jumps over the lazy dog.",
        "Machine learning is a subset of artificial intelligence.",
        "Natural language processing enables computers to understand human language.",
    ]
    try:
        embeddings = embed_texts(test_texts)
        print(f"\nBatch embedding shape: ({len(embeddings)}, {len(embeddings[0])})")
        print("✓ Batch embeddings work")
    except Exception as e:
        failures += 1
        print(f"✗ Batch embeddings failed: {e}")

    # Test empty inputs (these paths return before the model is loaded)
    assert embed_text("") == [0.0] * size, "Empty text should return zero vector"
    assert embed_texts([]) == [], "Empty list should return empty list"
    print("✓ Empty input handling works")

    if failures:
        print(f"\n✗ {failures} test(s) failed.")
        sys.exit(1)

    print("\n✅ All tests passed!")