Handle FastEmbed generator results
This commit is contained in:
@@ -86,8 +86,15 @@ def embed_text(text: str) -> List[float]:
|
|||||||
return [0.0] * get_embedding_size()
|
return [0.0] * get_embedding_size()
|
||||||
|
|
||||||
model = get_embedding_model()
|
model = get_embedding_model()
|
||||||
embedding = model.embed([text])
|
embeddings = iter(model.embed([text]))
|
||||||
return embedding[0].tolist()
|
try:
|
||||||
|
embedding = next(embeddings)
|
||||||
|
except StopIteration as exc:
|
||||||
|
raise RuntimeError("Embedding model returned no vector") from exc
|
||||||
|
|
||||||
|
if hasattr(embedding, "tolist"):
|
||||||
|
return embedding.tolist()
|
||||||
|
return list(embedding)
|
||||||
|
|
||||||
|
|
||||||
def embed_texts(texts: List[str]) -> List[List[float]]:
|
def embed_texts(texts: List[str]) -> List[List[float]]:
|
||||||
|
|||||||
@@ -0,0 +1,29 @@
|
|||||||
|
"""Tests for FastEmbed result normalization."""
|
||||||
|
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
from backend.app.embeddings import embed_text
|
||||||
|
|
||||||
|
|
||||||
|
class ArrayLike:
    """Test double for an array-style embedding vector.

    It exposes only ``tolist()``, which is the method ``embed_text`` probes
    for (via ``hasattr``) before falling back to ``list(...)``.
    """

    def tolist(self):
        """Return the fixed vector as a new plain list of floats."""
        return [0.1, 0.2, 0.3]
|
||||||
|
|
||||||
|
|
||||||
|
def test_embed_text_consumes_generator_result():
    """``embed_text`` returns the first vector yielded by a generator model.

    The fake model's ``embed`` is a generator function, so its result cannot
    be indexed; the vector it yields only offers ``tolist()``.
    """

    class Model:
        def embed(self, texts):
            # Generator body: this check runs only once the result is iterated.
            assert texts == ["query"]
            yield ArrayLike()

    with patch("backend.app.embeddings.get_embedding_model", return_value=Model()):
        assert embed_text("query") == [0.1, 0.2, 0.3]
|
||||||
|
|
||||||
|
|
||||||
|
def test_embed_text_accepts_plain_vector_iterables():
    """``embed_text`` also handles vectors that have no ``tolist()`` method.

    The fake model returns an iterator over a plain Python list, so the
    result must be converted with ``list(...)`` rather than ``tolist()``.
    """

    class Model:
        def embed(self, texts):
            # A plain list has no ``tolist``; this exercises the fallback path.
            return iter([[0.4, 0.5]])

    with patch("backend.app.embeddings.get_embedding_model", return_value=Model()):
        assert embed_text("query") == [0.4, 0.5]
|
||||||
Reference in New Issue
Block a user