Skip to content

Commit 00cd08e

Browse files
committed
Improved tests
1 parent 8a7040d commit 00cd08e

File tree

1 file changed

+71
-68
lines changed

1 file changed

+71
-68
lines changed

tests/test_sqlalchemy.py

+71-68
Original file line numberDiff line numberDiff line change
@@ -76,11 +76,11 @@ class Item(Base):
7676

7777

7878
def create_items():
79-
session = Session(engine)
80-
session.add(Item(id=1, embedding=[1, 1, 1], half_embedding=[1, 1, 1], binary_embedding='000', sparse_embedding=SparseVector([1, 1, 1])))
81-
session.add(Item(id=2, embedding=[2, 2, 2], half_embedding=[2, 2, 2], binary_embedding='101', sparse_embedding=SparseVector([2, 2, 2])))
82-
session.add(Item(id=3, embedding=[1, 1, 2], half_embedding=[1, 1, 2], binary_embedding='111', sparse_embedding=SparseVector([1, 1, 2])))
83-
session.commit()
79+
with Session(engine) as session:
80+
session.add(Item(id=1, embedding=[1, 1, 1], half_embedding=[1, 1, 1], binary_embedding='000', sparse_embedding=SparseVector([1, 1, 1])))
81+
session.add(Item(id=2, embedding=[2, 2, 2], half_embedding=[2, 2, 2], binary_embedding='101', sparse_embedding=SparseVector([2, 2, 2])))
82+
session.add(Item(id=3, embedding=[1, 1, 2], half_embedding=[1, 1, 2], binary_embedding='111', sparse_embedding=SparseVector([1, 1, 2])))
83+
session.commit()
8484

8585

8686
class TestSqlalchemy:
@@ -129,11 +129,11 @@ def test_orm(self):
129129
item2 = Item(embedding=[4, 5, 6])
130130
item3 = Item()
131131

132-
session = Session(engine)
133-
session.add(item)
134-
session.add(item2)
135-
session.add(item3)
136-
session.commit()
132+
with Session(engine) as session:
133+
session.add(item)
134+
session.add(item2)
135+
session.add(item3)
136+
session.commit()
137137

138138
stmt = select(Item)
139139
with Session(engine) as session:
@@ -148,11 +148,11 @@ def test_orm(self):
148148
assert items[2].embedding is None
149149

150150
def test_vector(self):
151-
session = Session(engine)
152-
session.add(Item(id=1, embedding=[1, 2, 3]))
153-
session.commit()
154-
item = session.get(Item, 1)
155-
assert item.embedding.tolist() == [1, 2, 3]
151+
with Session(engine) as session:
152+
session.add(Item(id=1, embedding=[1, 2, 3]))
153+
session.commit()
154+
item = session.get(Item, 1)
155+
assert item.embedding.tolist() == [1, 2, 3]
156156

157157
def test_vector_l2_distance(self):
158158
create_items()
@@ -203,11 +203,11 @@ def test_vector_l1_distance_orm(self):
203203
assert [v.id for v in items] == [1, 3, 2]
204204

205205
def test_halfvec(self):
206-
session = Session(engine)
207-
session.add(Item(id=1, half_embedding=[1, 2, 3]))
208-
session.commit()
209-
item = session.get(Item, 1)
210-
assert item.half_embedding.to_list() == [1, 2, 3]
206+
with Session(engine) as session:
207+
session.add(Item(id=1, half_embedding=[1, 2, 3]))
208+
session.commit()
209+
item = session.get(Item, 1)
210+
assert item.half_embedding.to_list() == [1, 2, 3]
211211

212212
def test_halfvec_l2_distance(self):
213213
create_items()
@@ -258,11 +258,11 @@ def test_halfvec_l1_distance_orm(self):
258258
assert [v.id for v in items] == [1, 3, 2]
259259

260260
def test_bit(self):
261-
session = Session(engine)
262-
session.add(Item(id=1, binary_embedding='101'))
263-
session.commit()
264-
item = session.get(Item, 1)
265-
assert item.binary_embedding == '101'
261+
with Session(engine) as session:
262+
session.add(Item(id=1, binary_embedding='101'))
263+
session.commit()
264+
item = session.get(Item, 1)
265+
assert item.binary_embedding == '101'
266266

267267
def test_bit_hamming_distance(self):
268268
create_items()
@@ -289,11 +289,11 @@ def test_bit_jaccard_distance_orm(self):
289289
assert [v.id for v in items] == [2, 3, 1]
290290

291291
def test_sparsevec(self):
292-
session = Session(engine)
293-
session.add(Item(id=1, sparse_embedding=[1, 2, 3]))
294-
session.commit()
295-
item = session.get(Item, 1)
296-
assert item.sparse_embedding.to_list() == [1, 2, 3]
292+
with Session(engine) as session:
293+
session.add(Item(id=1, sparse_embedding=[1, 2, 3]))
294+
session.commit()
295+
item = session.get(Item, 1)
296+
assert item.sparse_embedding.to_list() == [1, 2, 3]
297297

298298
def test_sparsevec_l2_distance(self):
299299
create_items()
@@ -405,24 +405,24 @@ def test_sum_orm(self):
405405

406406
def test_bad_dimensions(self):
407407
item = Item(embedding=[1, 2])
408-
session = Session(engine)
409-
session.add(item)
410-
with pytest.raises(StatementError, match='expected 3 dimensions, not 2'):
411-
session.commit()
408+
with Session(engine) as session:
409+
session.add(item)
410+
with pytest.raises(StatementError, match='expected 3 dimensions, not 2'):
411+
session.commit()
412412

413413
def test_bad_ndim(self):
414414
item = Item(embedding=np.array([[1, 2, 3]]))
415-
session = Session(engine)
416-
session.add(item)
417-
with pytest.raises(StatementError, match='expected ndim to be 1'):
418-
session.commit()
415+
with Session(engine) as session:
416+
session.add(item)
417+
with pytest.raises(StatementError, match='expected ndim to be 1'):
418+
session.commit()
419419

420420
def test_bad_dtype(self):
421421
item = Item(embedding=np.array(['one', 'two', 'three']))
422-
session = Session(engine)
423-
session.add(item)
424-
with pytest.raises(StatementError, match='could not convert string to float'):
425-
session.commit()
422+
with Session(engine) as session:
423+
session.add(item)
424+
with pytest.raises(StatementError, match='could not convert string to float'):
425+
session.commit()
426426

427427
def test_inspect(self):
428428
columns = inspect(engine).get_columns('sqlalchemy_orm_item')
@@ -433,44 +433,48 @@ def test_literal_binds(self):
433433
assert "embedding <-> '[1.0,2.0,3.0]'" in str(sql)
434434

435435
def test_insert(self):
436-
session.execute(insert(Item).values(embedding=np.array([1, 2, 3])))
436+
with Session(engine) as session:
437+
session.execute(insert(Item).values(embedding=np.array([1, 2, 3])))
437438

438439
def test_insert_bulk(self):
439-
session.execute(insert(Item), [{'embedding': np.array([1, 2, 3])}])
440+
with Session(engine) as session:
441+
session.execute(insert(Item), [{'embedding': np.array([1, 2, 3])}])
440442

441443
# register_vector in psycopg2 tests change this behavior
442444
# def test_insert_text(self):
443-
# session.execute(text('INSERT INTO sqlalchemy_orm_item (embedding) VALUES (:embedding)'), {'embedding': np.array([1, 2, 3])})
445+
# with Session(engine) as session:
446+
# session.execute(text('INSERT INTO sqlalchemy_orm_item (embedding) VALUES (:embedding)'), {'embedding': np.array([1, 2, 3])})
444447

445448
def test_automap(self):
446449
metadata = MetaData()
447450
metadata.reflect(engine, only=['sqlalchemy_orm_item'])
448451
AutoBase = automap_base(metadata=metadata)
449452
AutoBase.prepare()
450453
AutoItem = AutoBase.classes.sqlalchemy_orm_item
451-
session.execute(insert(AutoItem), [{'embedding': np.array([1, 2, 3])}])
452-
item = session.query(AutoItem).first()
453-
assert item.embedding.tolist() == [1, 2, 3]
454+
with Session(engine) as session:
455+
session.execute(insert(AutoItem), [{'embedding': np.array([1, 2, 3])}])
456+
item = session.query(AutoItem).first()
457+
assert item.embedding.tolist() == [1, 2, 3]
454458

455459
def test_vector_array(self):
456-
session = Session(array_engine)
457-
session.add(Item(id=1, embeddings=[np.array([1, 2, 3]), np.array([4, 5, 6])]))
458-
session.commit()
460+
with Session(array_engine) as session:
461+
session.add(Item(id=1, embeddings=[np.array([1, 2, 3]), np.array([4, 5, 6])]))
462+
session.commit()
459463

460-
# this fails if the driver does not cast arrays
461-
item = session.get(Item, 1)
462-
assert item.embeddings[0].tolist() == [1, 2, 3]
463-
assert item.embeddings[1].tolist() == [4, 5, 6]
464+
# this fails if the driver does not cast arrays
465+
item = session.get(Item, 1)
466+
assert item.embeddings[0].tolist() == [1, 2, 3]
467+
assert item.embeddings[1].tolist() == [4, 5, 6]
464468

465469
def test_halfvec_array(self):
466-
session = Session(array_engine)
467-
session.add(Item(id=1, half_embeddings=[np.array([1, 2, 3]), np.array([4, 5, 6])]))
468-
session.commit()
470+
with Session(array_engine) as session:
471+
session.add(Item(id=1, half_embeddings=[np.array([1, 2, 3]), np.array([4, 5, 6])]))
472+
session.commit()
469473

470-
# this fails if the driver does not cast arrays
471-
item = session.get(Item, 1)
472-
assert item.half_embeddings[0].to_list() == [1, 2, 3]
473-
assert item.half_embeddings[1].to_list() == [4, 5, 6]
474+
# this fails if the driver does not cast arrays
475+
item = session.get(Item, 1)
476+
assert item.half_embeddings[0].to_list() == [1, 2, 3]
477+
assert item.half_embeddings[1].to_list() == [4, 5, 6]
474478

475479
def test_half_precision(self):
476480
create_items()
@@ -479,13 +483,12 @@ def test_half_precision(self):
479483
assert [v.id for v in items] == [1, 3, 2]
480484

481485
def test_binary_quantize(self):
482-
session = Session(engine)
483-
session.add(Item(id=1, embedding=[-1, -2, -3]))
484-
session.add(Item(id=2, embedding=[1, -2, 3]))
485-
session.add(Item(id=3, embedding=[1, 2, 3]))
486-
session.commit()
487-
488486
with Session(engine) as session:
487+
session.add(Item(id=1, embedding=[-1, -2, -3]))
488+
session.add(Item(id=2, embedding=[1, -2, 3]))
489+
session.add(Item(id=3, embedding=[1, 2, 3]))
490+
session.commit()
491+
489492
distance = func.cast(func.binary_quantize(Item.embedding), BIT(3)).hamming_distance(func.binary_quantize(func.cast([3, -1, 2], VECTOR(3))))
490493
items = session.query(Item).order_by(distance).all()
491494
assert [v.id for v in items] == [2, 3, 1]

0 commit comments

Comments
 (0)