Handle FastEmbed generator results

This commit is contained in:
george
2026-06-06 12:52:04 +01:00
parent 7707a6306d
commit 30fe050182
2 changed files with 38 additions and 2 deletions
+9 -2
View File
@@ -86,8 +86,15 @@ def embed_text(text: str) -> List[float]:
return [0.0] * get_embedding_size() return [0.0] * get_embedding_size()
model = get_embedding_model() model = get_embedding_model()
embedding = model.embed([text]) embeddings = iter(model.embed([text]))
return embedding[0].tolist() try:
embedding = next(embeddings)
except StopIteration as exc:
raise RuntimeError("Embedding model returned no vector") from exc
if hasattr(embedding, "tolist"):
return embedding.tolist()
return list(embedding)
def embed_texts(texts: List[str]) -> List[List[float]]: def embed_texts(texts: List[str]) -> List[List[float]]:
+29
View File
@@ -0,0 +1,29 @@
"""Tests for FastEmbed result normalization."""
from unittest.mock import patch
from backend.app.embeddings import embed_text
class ArrayLike:
def tolist(self):
return [0.1, 0.2, 0.3]
def test_embed_text_consumes_generator_result():
class Model:
def embed(self, texts):
assert texts == ["query"]
yield ArrayLike()
with patch("backend.app.embeddings.get_embedding_model", return_value=Model()):
assert embed_text("query") == [0.1, 0.2, 0.3]
def test_embed_text_accepts_plain_vector_iterables():
class Model:
def embed(self, texts):
return iter([[0.4, 0.5]])
with patch("backend.app.embeddings.get_embedding_model", return_value=Model()):
assert embed_text("query") == [0.4, 0.5]