Handle FastEmbed generator results
This commit is contained in:
@@ -86,8 +86,15 @@ def embed_text(text: str) -> List[float]:
|
||||
return [0.0] * get_embedding_size()
|
||||
|
||||
model = get_embedding_model()
|
||||
embedding = model.embed([text])
|
||||
return embedding[0].tolist()
|
||||
embeddings = iter(model.embed([text]))
|
||||
try:
|
||||
embedding = next(embeddings)
|
||||
except StopIteration as exc:
|
||||
raise RuntimeError("Embedding model returned no vector") from exc
|
||||
|
||||
if hasattr(embedding, "tolist"):
|
||||
return embedding.tolist()
|
||||
return list(embedding)
|
||||
|
||||
|
||||
def embed_texts(texts: List[str]) -> List[List[float]]:
|
||||
|
||||
@@ -0,0 +1,29 @@
|
||||
"""Tests for FastEmbed result normalization."""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
from backend.app.embeddings import embed_text
|
||||
|
||||
|
||||
class ArrayLike:
|
||||
def tolist(self):
|
||||
return [0.1, 0.2, 0.3]
|
||||
|
||||
|
||||
def test_embed_text_consumes_generator_result():
|
||||
class Model:
|
||||
def embed(self, texts):
|
||||
assert texts == ["query"]
|
||||
yield ArrayLike()
|
||||
|
||||
with patch("backend.app.embeddings.get_embedding_model", return_value=Model()):
|
||||
assert embed_text("query") == [0.1, 0.2, 0.3]
|
||||
|
||||
|
||||
def test_embed_text_accepts_plain_vector_iterables():
|
||||
class Model:
|
||||
def embed(self, texts):
|
||||
return iter([[0.4, 0.5]])
|
||||
|
||||
with patch("backend.app.embeddings.get_embedding_model", return_value=Model()):
|
||||
assert embed_text("query") == [0.4, 0.5]
|
||||
Reference in New Issue
Block a user