@@ -76,11 +76,11 @@ class Item(Base):
76
76
77
77
78
78
def create_items ():
79
- session = Session (engine )
80
- session .add (Item (id = 1 , embedding = [1 , 1 , 1 ], half_embedding = [1 , 1 , 1 ], binary_embedding = '000' , sparse_embedding = SparseVector ([1 , 1 , 1 ])))
81
- session .add (Item (id = 2 , embedding = [2 , 2 , 2 ], half_embedding = [2 , 2 , 2 ], binary_embedding = '101' , sparse_embedding = SparseVector ([2 , 2 , 2 ])))
82
- session .add (Item (id = 3 , embedding = [1 , 1 , 2 ], half_embedding = [1 , 1 , 2 ], binary_embedding = '111' , sparse_embedding = SparseVector ([1 , 1 , 2 ])))
83
- session .commit ()
79
+ with Session (engine ) as session :
80
+ session .add (Item (id = 1 , embedding = [1 , 1 , 1 ], half_embedding = [1 , 1 , 1 ], binary_embedding = '000' , sparse_embedding = SparseVector ([1 , 1 , 1 ])))
81
+ session .add (Item (id = 2 , embedding = [2 , 2 , 2 ], half_embedding = [2 , 2 , 2 ], binary_embedding = '101' , sparse_embedding = SparseVector ([2 , 2 , 2 ])))
82
+ session .add (Item (id = 3 , embedding = [1 , 1 , 2 ], half_embedding = [1 , 1 , 2 ], binary_embedding = '111' , sparse_embedding = SparseVector ([1 , 1 , 2 ])))
83
+ session .commit ()
84
84
85
85
86
86
class TestSqlalchemy :
@@ -129,11 +129,11 @@ def test_orm(self):
129
129
item2 = Item (embedding = [4 , 5 , 6 ])
130
130
item3 = Item ()
131
131
132
- session = Session (engine )
133
- session .add (item )
134
- session .add (item2 )
135
- session .add (item3 )
136
- session .commit ()
132
+ with Session (engine ) as session :
133
+ session .add (item )
134
+ session .add (item2 )
135
+ session .add (item3 )
136
+ session .commit ()
137
137
138
138
stmt = select (Item )
139
139
with Session (engine ) as session :
@@ -148,11 +148,11 @@ def test_orm(self):
148
148
assert items [2 ].embedding is None
149
149
150
150
def test_vector (self ):
151
- session = Session (engine )
152
- session .add (Item (id = 1 , embedding = [1 , 2 , 3 ]))
153
- session .commit ()
154
- item = session .get (Item , 1 )
155
- assert item .embedding .tolist () == [1 , 2 , 3 ]
151
+ with Session (engine ) as session :
152
+ session .add (Item (id = 1 , embedding = [1 , 2 , 3 ]))
153
+ session .commit ()
154
+ item = session .get (Item , 1 )
155
+ assert item .embedding .tolist () == [1 , 2 , 3 ]
156
156
157
157
def test_vector_l2_distance (self ):
158
158
create_items ()
@@ -203,11 +203,11 @@ def test_vector_l1_distance_orm(self):
203
203
assert [v .id for v in items ] == [1 , 3 , 2 ]
204
204
205
205
def test_halfvec (self ):
206
- session = Session (engine )
207
- session .add (Item (id = 1 , half_embedding = [1 , 2 , 3 ]))
208
- session .commit ()
209
- item = session .get (Item , 1 )
210
- assert item .half_embedding .to_list () == [1 , 2 , 3 ]
206
+ with Session (engine ) as session :
207
+ session .add (Item (id = 1 , half_embedding = [1 , 2 , 3 ]))
208
+ session .commit ()
209
+ item = session .get (Item , 1 )
210
+ assert item .half_embedding .to_list () == [1 , 2 , 3 ]
211
211
212
212
def test_halfvec_l2_distance (self ):
213
213
create_items ()
@@ -258,11 +258,11 @@ def test_halfvec_l1_distance_orm(self):
258
258
assert [v .id for v in items ] == [1 , 3 , 2 ]
259
259
260
260
def test_bit (self ):
261
- session = Session (engine )
262
- session .add (Item (id = 1 , binary_embedding = '101' ))
263
- session .commit ()
264
- item = session .get (Item , 1 )
265
- assert item .binary_embedding == '101'
261
+ with Session (engine ) as session :
262
+ session .add (Item (id = 1 , binary_embedding = '101' ))
263
+ session .commit ()
264
+ item = session .get (Item , 1 )
265
+ assert item .binary_embedding == '101'
266
266
267
267
def test_bit_hamming_distance (self ):
268
268
create_items ()
@@ -289,11 +289,11 @@ def test_bit_jaccard_distance_orm(self):
289
289
assert [v .id for v in items ] == [2 , 3 , 1 ]
290
290
291
291
def test_sparsevec (self ):
292
- session = Session (engine )
293
- session .add (Item (id = 1 , sparse_embedding = [1 , 2 , 3 ]))
294
- session .commit ()
295
- item = session .get (Item , 1 )
296
- assert item .sparse_embedding .to_list () == [1 , 2 , 3 ]
292
+ with Session (engine ) as session :
293
+ session .add (Item (id = 1 , sparse_embedding = [1 , 2 , 3 ]))
294
+ session .commit ()
295
+ item = session .get (Item , 1 )
296
+ assert item .sparse_embedding .to_list () == [1 , 2 , 3 ]
297
297
298
298
def test_sparsevec_l2_distance (self ):
299
299
create_items ()
@@ -405,24 +405,24 @@ def test_sum_orm(self):
405
405
406
406
def test_bad_dimensions (self ):
407
407
item = Item (embedding = [1 , 2 ])
408
- session = Session (engine )
409
- session .add (item )
410
- with pytest .raises (StatementError , match = 'expected 3 dimensions, not 2' ):
411
- session .commit ()
408
+ with Session (engine ) as session :
409
+ session .add (item )
410
+ with pytest .raises (StatementError , match = 'expected 3 dimensions, not 2' ):
411
+ session .commit ()
412
412
413
413
def test_bad_ndim (self ):
414
414
item = Item (embedding = np .array ([[1 , 2 , 3 ]]))
415
- session = Session (engine )
416
- session .add (item )
417
- with pytest .raises (StatementError , match = 'expected ndim to be 1' ):
418
- session .commit ()
415
+ with Session (engine ) as session :
416
+ session .add (item )
417
+ with pytest .raises (StatementError , match = 'expected ndim to be 1' ):
418
+ session .commit ()
419
419
420
420
def test_bad_dtype (self ):
421
421
item = Item (embedding = np .array (['one' , 'two' , 'three' ]))
422
- session = Session (engine )
423
- session .add (item )
424
- with pytest .raises (StatementError , match = 'could not convert string to float' ):
425
- session .commit ()
422
+ with Session (engine ) as session :
423
+ session .add (item )
424
+ with pytest .raises (StatementError , match = 'could not convert string to float' ):
425
+ session .commit ()
426
426
427
427
def test_inspect (self ):
428
428
columns = inspect (engine ).get_columns ('sqlalchemy_orm_item' )
@@ -433,44 +433,48 @@ def test_literal_binds(self):
433
433
assert "embedding <-> '[1.0,2.0,3.0]'" in str (sql )
434
434
435
435
def test_insert (self ):
436
- session .execute (insert (Item ).values (embedding = np .array ([1 , 2 , 3 ])))
436
+ with Session (engine ) as session :
437
+ session .execute (insert (Item ).values (embedding = np .array ([1 , 2 , 3 ])))
437
438
438
439
def test_insert_bulk (self ):
439
- session .execute (insert (Item ), [{'embedding' : np .array ([1 , 2 , 3 ])}])
440
+ with Session (engine ) as session :
441
+ session .execute (insert (Item ), [{'embedding' : np .array ([1 , 2 , 3 ])}])
440
442
441
443
# register_vector in psycopg2 tests change this behavior
442
444
# def test_insert_text(self):
443
- # session.execute(text('INSERT INTO sqlalchemy_orm_item (embedding) VALUES (:embedding)'), {'embedding': np.array([1, 2, 3])})
445
+ # with Session(engine) as session:
446
+ # session.execute(text('INSERT INTO sqlalchemy_orm_item (embedding) VALUES (:embedding)'), {'embedding': np.array([1, 2, 3])})
444
447
445
448
def test_automap (self ):
446
449
metadata = MetaData ()
447
450
metadata .reflect (engine , only = ['sqlalchemy_orm_item' ])
448
451
AutoBase = automap_base (metadata = metadata )
449
452
AutoBase .prepare ()
450
453
AutoItem = AutoBase .classes .sqlalchemy_orm_item
451
- session .execute (insert (AutoItem ), [{'embedding' : np .array ([1 , 2 , 3 ])}])
452
- item = session .query (AutoItem ).first ()
453
- assert item .embedding .tolist () == [1 , 2 , 3 ]
454
+ with Session (engine ) as session :
455
+ session .execute (insert (AutoItem ), [{'embedding' : np .array ([1 , 2 , 3 ])}])
456
+ item = session .query (AutoItem ).first ()
457
+ assert item .embedding .tolist () == [1 , 2 , 3 ]
454
458
455
459
def test_vector_array (self ):
456
- session = Session (array_engine )
457
- session .add (Item (id = 1 , embeddings = [np .array ([1 , 2 , 3 ]), np .array ([4 , 5 , 6 ])]))
458
- session .commit ()
460
+ with Session (array_engine ) as session :
461
+ session .add (Item (id = 1 , embeddings = [np .array ([1 , 2 , 3 ]), np .array ([4 , 5 , 6 ])]))
462
+ session .commit ()
459
463
460
- # this fails if the driver does not cast arrays
461
- item = session .get (Item , 1 )
462
- assert item .embeddings [0 ].tolist () == [1 , 2 , 3 ]
463
- assert item .embeddings [1 ].tolist () == [4 , 5 , 6 ]
464
+ # this fails if the driver does not cast arrays
465
+ item = session .get (Item , 1 )
466
+ assert item .embeddings [0 ].tolist () == [1 , 2 , 3 ]
467
+ assert item .embeddings [1 ].tolist () == [4 , 5 , 6 ]
464
468
465
469
def test_halfvec_array (self ):
466
- session = Session (array_engine )
467
- session .add (Item (id = 1 , half_embeddings = [np .array ([1 , 2 , 3 ]), np .array ([4 , 5 , 6 ])]))
468
- session .commit ()
470
+ with Session (array_engine ) as session :
471
+ session .add (Item (id = 1 , half_embeddings = [np .array ([1 , 2 , 3 ]), np .array ([4 , 5 , 6 ])]))
472
+ session .commit ()
469
473
470
- # this fails if the driver does not cast arrays
471
- item = session .get (Item , 1 )
472
- assert item .half_embeddings [0 ].to_list () == [1 , 2 , 3 ]
473
- assert item .half_embeddings [1 ].to_list () == [4 , 5 , 6 ]
474
+ # this fails if the driver does not cast arrays
475
+ item = session .get (Item , 1 )
476
+ assert item .half_embeddings [0 ].to_list () == [1 , 2 , 3 ]
477
+ assert item .half_embeddings [1 ].to_list () == [4 , 5 , 6 ]
474
478
475
479
def test_half_precision (self ):
476
480
create_items ()
@@ -479,13 +483,12 @@ def test_half_precision(self):
479
483
assert [v .id for v in items ] == [1 , 3 , 2 ]
480
484
481
485
def test_binary_quantize (self ):
482
- session = Session (engine )
483
- session .add (Item (id = 1 , embedding = [- 1 , - 2 , - 3 ]))
484
- session .add (Item (id = 2 , embedding = [1 , - 2 , 3 ]))
485
- session .add (Item (id = 3 , embedding = [1 , 2 , 3 ]))
486
- session .commit ()
487
-
488
486
with Session (engine ) as session :
487
+ session .add (Item (id = 1 , embedding = [- 1 , - 2 , - 3 ]))
488
+ session .add (Item (id = 2 , embedding = [1 , - 2 , 3 ]))
489
+ session .add (Item (id = 3 , embedding = [1 , 2 , 3 ]))
490
+ session .commit ()
491
+
489
492
distance = func .cast (func .binary_quantize (Item .embedding ), BIT (3 )).hamming_distance (func .binary_quantize (func .cast ([3 , - 1 , 2 ], VECTOR (3 ))))
490
493
items = session .query (Item ).order_by (distance ).all ()
491
494
assert [v .id for v in items ] == [2 , 3 , 1 ]
0 commit comments