Skip to content

Commit 2496340

Browse files
committed
Added support for pg8000
1 parent 8443ff5 commit 2496340

File tree

5 files changed

+136
-1
lines changed

5 files changed

+136
-1
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
## 0.4.0 (unreleased)
22

33
- Added top-level `pgvector` package
4+
- Added support for pg8000
45
- Changed `globally` option to default to `False` for Psycopg 2
56
- Changed `arrays` option to default to `True` for Psycopg 2
67
- Fixed equality for `Vector`, `HalfVector`, `Bit`, and `SparseVector` classes

README.md

+47-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
[pgvector](https://github.com/pgvector/pgvector) support for Python
44

5-
Supports [Django](https://github.com/django/django), [SQLAlchemy](https://github.com/sqlalchemy/sqlalchemy), [SQLModel](https://github.com/tiangolo/sqlmodel), [Psycopg 3](https://github.com/psycopg/psycopg), [Psycopg 2](https://github.com/psycopg/psycopg2), [asyncpg](https://github.com/MagicStack/asyncpg), and [Peewee](https://github.com/coleifer/peewee)
5+
Supports [Django](https://github.com/django/django), [SQLAlchemy](https://github.com/sqlalchemy/sqlalchemy), [SQLModel](https://github.com/tiangolo/sqlmodel), [Psycopg 3](https://github.com/psycopg/psycopg), [Psycopg 2](https://github.com/psycopg/psycopg2), [asyncpg](https://github.com/MagicStack/asyncpg), [pg8000](https://github.com/tlocke/pg8000), and [Peewee](https://github.com/coleifer/peewee)
66

77
[![Build Status](https://github.com/pgvector/pgvector-python/actions/workflows/build.yml/badge.svg)](https://github.com/pgvector/pgvector-python/actions)
88

@@ -22,6 +22,7 @@ And follow the instructions for your database library:
2222
- [Psycopg 3](#psycopg-3)
2323
- [Psycopg 2](#psycopg-2)
2424
- [asyncpg](#asyncpg)
25+
- [pg8000](#pg8000) [unreleased]
2526
- [Peewee](#peewee)
2627

2728
Or check out some examples:
@@ -562,6 +563,51 @@ await conn.execute('CREATE INDEX ON items USING ivfflat (embedding vector_l2_ops
562563

563564
Use `vector_ip_ops` for inner product and `vector_cosine_ops` for cosine distance
564565

566+
## pg8000
567+
568+
Enable the extension
569+
570+
```python
571+
conn.run('CREATE EXTENSION IF NOT EXISTS vector')
572+
```
573+
574+
Register the vector type with your connection
575+
576+
```python
577+
from pgvector.pg8000 import register_vector
578+
579+
register_vector(conn)
580+
```
581+
582+
Create a table
583+
584+
```python
585+
conn.run('CREATE TABLE items (id bigserial PRIMARY KEY, embedding vector(3))')
586+
```
587+
588+
Insert a vector
589+
590+
```python
591+
embedding = np.array([1, 2, 3])
592+
conn.run('INSERT INTO items (embedding) VALUES (:embedding)', embedding=embedding)
593+
```
594+
595+
Get the nearest neighbors to a vector
596+
597+
```python
598+
conn.run('SELECT * FROM items ORDER BY embedding <-> :embedding LIMIT 5', embedding=embedding)
599+
```
600+
601+
Add an approximate index
602+
603+
```python
604+
conn.run('CREATE INDEX ON items USING hnsw (embedding vector_l2_ops)')
605+
# or
606+
conn.run('CREATE INDEX ON items USING ivfflat (embedding vector_l2_ops) WITH (lists = 100)')
607+
```
608+
609+
Use `vector_ip_ops` for inner product and `vector_cosine_ops` for cosine distance
610+
565611
## Peewee
566612

567613
Add a vector column

pgvector/pg8000/__init__.py

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from .register import register_vector
2+
3+
__all__ = [
4+
'register_vector'
5+
]

pgvector/pg8000/register.py

+23
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import numpy as np
2+
from .. import Vector, HalfVector, SparseVector
3+
4+
5+
def register_vector(conn):
    """Register pgvector type adapters on a connection.

    Looks up the OIDs of the ``vector``, ``halfvec`` and ``sparsevec`` types
    and wires up the matching adapters on ``conn``:

    * ``Vector`` and ``numpy.ndarray`` values are sent with ``Vector._to_db``
      and ``vector`` columns are read with ``Vector._from_db``.
    * ``HalfVector`` / ``SparseVector`` adapters are registered only when the
      corresponding type was found by the lookup query.

    Args:
        conn: Connection object exposing ``run``, ``register_out_adapter``
            and ``register_in_adapter``.

    Raises:
        RuntimeError: If the lookup query does not return a ``vector`` row.
    """
    # to_regtype resolves each name to the first matching type in the search path
    rows = conn.run("SELECT typname, oid FROM pg_type WHERE oid IN (to_regtype('vector'), to_regtype('halfvec'), to_regtype('sparsevec'))")
    oid_by_name = {typname: oid for typname, oid in rows}

    if 'vector' not in oid_by_name:
        raise RuntimeError('vector type not found in the database')

    # The base vector type is mandatory; NumPy arrays share its serializer.
    conn.register_out_adapter(Vector, Vector._to_db)
    conn.register_out_adapter(np.ndarray, Vector._to_db)
    conn.register_in_adapter(oid_by_name['vector'], Vector._from_db)

    # The remaining types are optional: skip any the lookup did not return.
    for typname, vector_cls in (('halfvec', HalfVector), ('sparsevec', SparseVector)):
        if typname in oid_by_name:
            conn.register_out_adapter(vector_cls, vector_cls._to_db)
            conn.register_in_adapter(oid_by_name[typname], vector_cls._from_db)

tests/test_pg8000.py

+60
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
import numpy as np
2+
import os
3+
from pgvector import HalfVector, SparseVector, Vector
4+
from pgvector.pg8000 import register_vector
5+
from pg8000.native import Connection
6+
7+
# Shared connection for every test in this module. The USER environment
# variable is passed as the first argument (presumably the database user name
# -- confirm against pg8000.native.Connection); a KeyError is raised at import
# time if USER is not set.
conn = Connection(os.environ["USER"], database='pgvector_python_test')
8+
9+
# Make sure the extension exists before creating columns that use its types.
conn.run('CREATE EXTENSION IF NOT EXISTS vector')
10+
# Recreate the table from scratch so each run starts from a known schema.
conn.run('DROP TABLE IF EXISTS pg8000_items')
11+
conn.run('CREATE TABLE pg8000_items (id bigserial PRIMARY KEY, embedding vector(3), half_embedding halfvec(3), binary_embedding bit(3), sparse_embedding sparsevec(3), embeddings vector[], half_embeddings halfvec[], sparse_embeddings sparsevec[])')
12+
13+
# Register the pgvector adapters on the shared connection (after the
# extension has been created above).
register_vector(conn)
14+
15+
16+
class TestPg8000:
    """Round-trip tests for the columns of the ``pg8000_items`` table.

    Every test inserts one value plus a NULL row through the module-level
    ``conn`` and checks what a subsequent SELECT (ordered by id) returns.
    """

    def setup_method(self):
        """Empty the table so each test sees only its own rows."""
        conn.run('DELETE FROM pg8000_items')

    def test_vector(self):
        """A NumPy array written to a vector column is read back as float32."""
        vec = np.array([1.5, 2, 3])
        conn.run('INSERT INTO pg8000_items (embedding) VALUES (:embedding), (NULL)', embedding=vec)

        rows = conn.run('SELECT embedding FROM pg8000_items ORDER BY id')
        stored = rows[0][0]
        assert np.array_equal(stored, vec)
        assert stored.dtype == np.float32
        assert rows[1][0] is None

    def test_vector_class(self):
        """A ``Vector`` written to a vector column is read back as float32."""
        vec = Vector([1.5, 2, 3])
        conn.run('INSERT INTO pg8000_items (embedding) VALUES (:embedding), (NULL)', embedding=vec)

        rows = conn.run('SELECT embedding FROM pg8000_items ORDER BY id')
        stored = rows[0][0]
        assert np.array_equal(stored, vec.to_numpy())
        assert stored.dtype == np.float32
        assert rows[1][0] is None

    def test_halfvec(self):
        """A ``HalfVector`` round-trips to an equal value."""
        vec = HalfVector([1.5, 2, 3])
        conn.run('INSERT INTO pg8000_items (half_embedding) VALUES (:embedding), (NULL)', embedding=vec)

        rows = conn.run('SELECT half_embedding FROM pg8000_items ORDER BY id')
        stored = rows[0][0]
        assert stored == vec
        assert rows[1][0] is None

    def test_bit(self):
        """A bit string round-trips as the same text."""
        bits = '101'
        conn.run('INSERT INTO pg8000_items (binary_embedding) VALUES (:embedding), (NULL)', embedding=bits)

        rows = conn.run('SELECT binary_embedding FROM pg8000_items ORDER BY id')
        stored = rows[0][0]
        assert stored == '101'
        assert rows[1][0] is None

    def test_sparsevec(self):
        """A ``SparseVector`` round-trips to an equal value."""
        vec = SparseVector([1.5, 2, 3])
        conn.run('INSERT INTO pg8000_items (sparse_embedding) VALUES (:embedding), (NULL)', embedding=vec)

        rows = conn.run('SELECT sparse_embedding FROM pg8000_items ORDER BY id')
        stored = rows[0][0]
        assert stored == vec
        assert rows[1][0] is None

0 commit comments

Comments
 (0)