Skip to content

Commit 32f91e5

Browse files
committed
Improved array tests
1 parent 73016ab commit 32f91e5

File tree

1 file changed

+11
-4
lines changed

1 file changed

+11
-4
lines changed

pgx_test.go

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ type PgxItem struct {
1919
HalfEmbedding pgvector.HalfVector
2020
BinaryEmbedding string
2121
SparseEmbedding pgvector.SparseVector
22+
Embeddings []pgvector.Vector
2223
}
2324

2425
func CreatePgxItems(ctx context.Context, conn *pgx.Conn) {
@@ -28,23 +29,26 @@ func CreatePgxItems(ctx context.Context, conn *pgx.Conn) {
2829
HalfEmbedding: pgvector.NewHalfVector([]float32{1, 1, 1}),
2930
BinaryEmbedding: "000",
3031
SparseEmbedding: pgvector.NewSparseVector([]float32{1, 1, 1}),
32+
Embeddings: []pgvector.Vector{pgvector.NewVector([]float32{1, 1, 1})},
3133
},
3234
PgxItem{
3335
Embedding: pgvector.NewVector([]float32{2, 2, 2}),
3436
HalfEmbedding: pgvector.NewHalfVector([]float32{2, 2, 2}),
3537
BinaryEmbedding: "101",
3638
SparseEmbedding: pgvector.NewSparseVector([]float32{2, 2, 2}),
39+
Embeddings: []pgvector.Vector{pgvector.NewVector([]float32{2, 2, 2})},
3740
},
3841
PgxItem{
3942
Embedding: pgvector.NewVector([]float32{1, 1, 2}),
4043
HalfEmbedding: pgvector.NewHalfVector([]float32{1, 1, 2}),
4144
BinaryEmbedding: "111",
4245
SparseEmbedding: pgvector.NewSparseVector([]float32{1, 1, 2}),
46+
Embeddings: []pgvector.Vector{pgvector.NewVector([]float32{1, 1, 2})},
4347
},
4448
}
4549

4650
for _, item := range items {
47-
_, err := conn.Exec(ctx, "INSERT INTO pgx_items (embedding, half_embedding, binary_embedding, sparse_embedding) VALUES ($1, $2, $3, $4)", item.Embedding, item.HalfEmbedding, item.BinaryEmbedding, item.SparseEmbedding)
51+
_, err := conn.Exec(ctx, "INSERT INTO pgx_items (embedding, half_embedding, binary_embedding, sparse_embedding, embeddings) VALUES ($1, $2, $3, $4, $5)", item.Embedding, item.HalfEmbedding, item.BinaryEmbedding, item.SparseEmbedding, item.Embeddings)
4852
if err != nil {
4953
panic(err)
5054
}
@@ -75,7 +79,7 @@ func TestPgx(t *testing.T) {
7579
panic(err)
7680
}
7781

78-
_, err = conn.Exec(ctx, "CREATE TABLE pgx_items (id bigserial PRIMARY KEY, embedding vector(3), half_embedding halfvec(3), binary_embedding bit(3), sparse_embedding sparsevec(3))")
82+
_, err = conn.Exec(ctx, "CREATE TABLE pgx_items (id bigserial PRIMARY KEY, embedding vector(3), half_embedding halfvec(3), binary_embedding bit(3), sparse_embedding sparsevec(3), embeddings vector(3)[])")
7983
if err != nil {
8084
panic(err)
8185
}
@@ -87,7 +91,7 @@ func TestPgx(t *testing.T) {
8791

8892
CreatePgxItems(ctx, conn)
8993

90-
rows, err := conn.Query(ctx, "SELECT id, embedding, half_embedding, binary_embedding, sparse_embedding, embedding <-> $1 FROM pgx_items ORDER BY embedding <-> $1 LIMIT 5", pgvector.NewVector([]float32{1, 1, 1}))
94+
rows, err := conn.Query(ctx, "SELECT id, embedding, half_embedding, binary_embedding, sparse_embedding, embeddings, embedding <-> $1 FROM pgx_items ORDER BY embedding <-> $1 LIMIT 5", pgvector.NewVector([]float32{1, 1, 1}))
9195
if err != nil {
9296
panic(err)
9397
}
@@ -100,7 +104,7 @@ func TestPgx(t *testing.T) {
100104
var item PgxItem
101105
var binaryEmbedding pgtype.Bits
102106
var distance float64
103-
err = rows.Scan(&item.Id, &item.Embedding, &item.HalfEmbedding, &binaryEmbedding, &item.SparseEmbedding, &distance)
107+
err = rows.Scan(&item.Id, &item.Embedding, &item.HalfEmbedding, &binaryEmbedding, &item.SparseEmbedding, &item.Embeddings, &distance)
104108
if err != nil {
105109
panic(err)
106110
}
@@ -128,6 +132,9 @@ func TestPgx(t *testing.T) {
128132
if !reflect.DeepEqual(items[1].SparseEmbedding.Slice(), []float32{1, 1, 2}) {
129133
t.Error()
130134
}
135+
if !reflect.DeepEqual(items[1].Embeddings, []pgvector.Vector{pgvector.NewVector([]float32{1, 1, 2})}) {
136+
t.Error()
137+
}
131138
if distances[0] != 0 || distances[1] != 1 || distances[2] != math.Sqrt(3) {
132139
t.Error()
133140
}

0 commit comments

Comments
 (0)