Skip to content

Commit 70feec9

Browse files
committed
Added distance functions for Ent - #23
1 parent 83d40cb commit 70feec9

File tree

3 files changed

+47
-1
lines changed

3 files changed

+47
-1
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
## 0.2.4 (unreleased)
2+
3+
- Added distance functions for Ent
4+
15
## 0.2.3 (2025-01-15)
26

37
- Added support for Postgres arrays for pgx

ent_test.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
_ "github.com/lib/pq"
1010
"github.com/pgvector/pgvector-go"
1111
"github.com/pgvector/pgvector-go/ent"
12+
"github.com/pgvector/pgvector-go/entvec"
1213
)
1314

1415
func TestEnt(t *testing.T) {
@@ -67,7 +68,7 @@ func TestEnt(t *testing.T) {
6768
items, err := client.Item.
6869
Query().
6970
Order(func(s *sql.Selector) {
70-
s.OrderExpr(sql.ExprP("embedding <-> $1", embedding))
71+
s.OrderExpr(entvec.L2Distance("embedding", embedding))
7172
}).
7273
Limit(5).
7374
All(ctx)

entvec/distance.go

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
package entvec
2+
3+
import (
4+
"entgo.io/ent/dialect/sql"
5+
)
6+
7+
func L2Distance(column string, value any) sql.Querier {
8+
return sql.ExprFunc(func(b *sql.Builder) {
9+
b.Ident(column).WriteString(" <-> ").Arg(value)
10+
})
11+
}
12+
13+
func MaxInnerProduct(column string, value any) sql.Querier {
14+
return sql.ExprFunc(func(b *sql.Builder) {
15+
b.Ident(column).WriteString(" <#> ").Arg(value)
16+
})
17+
}
18+
19+
func CosineDistance(column string, value any) sql.Querier {
20+
return sql.ExprFunc(func(b *sql.Builder) {
21+
b.Ident(column).WriteString(" <=> ").Arg(value)
22+
})
23+
}
24+
25+
func L1Distance(column string, value any) sql.Querier {
26+
return sql.ExprFunc(func(b *sql.Builder) {
27+
b.Ident(column).WriteString(" <+> ").Arg(value)
28+
})
29+
}
30+
31+
func HammingDistance(column string, value any) sql.Querier {
32+
return sql.ExprFunc(func(b *sql.Builder) {
33+
b.Ident(column).WriteString(" <~> ").Arg(value)
34+
})
35+
}
36+
37+
func JaccardDistance(column string, value any) sql.Querier {
38+
return sql.ExprFunc(func(b *sql.Builder) {
39+
b.Ident(column).WriteString(" <%> ").Arg(value)
40+
})
41+
}

0 commit comments

Comments
 (0)