Handle FastEmbed generator results
This commit is contained in:
@@ -0,0 +1,29 @@
|
||||
"""Tests for FastEmbed result normalization."""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
from backend.app.embeddings import embed_text
|
||||
|
||||
|
||||
class ArrayLike:
|
||||
def tolist(self):
|
||||
return [0.1, 0.2, 0.3]
|
||||
|
||||
|
||||
def test_embed_text_consumes_generator_result():
|
||||
class Model:
|
||||
def embed(self, texts):
|
||||
assert texts == ["query"]
|
||||
yield ArrayLike()
|
||||
|
||||
with patch("backend.app.embeddings.get_embedding_model", return_value=Model()):
|
||||
assert embed_text("query") == [0.1, 0.2, 0.3]
|
||||
|
||||
|
||||
def test_embed_text_accepts_plain_vector_iterables():
|
||||
class Model:
|
||||
def embed(self, texts):
|
||||
return iter([[0.4, 0.5]])
|
||||
|
||||
with patch("backend.app.embeddings.get_embedding_model", return_value=Model()):
|
||||
assert embed_text("query") == [0.4, 0.5]
|
||||
Reference in New Issue
Block a user