Skip to content

Commit f7bedb5

Browse files
authored
Merge pull request #141 from alicebob/setstore
use a normal set as a ZUNIONSTORE source
2 parents 6d16f8d + 96b23f6 commit f7bedb5

File tree

3 files changed

+43
-23
lines changed

3 files changed

+43
-23
lines changed

cmd_sorted_set.go

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1265,32 +1265,42 @@ func (m *Miniredis) cmdZunionstore(c *server.Peer, cmd string, args []string) {
12651265
if !db.exists(key) {
12661266
continue
12671267
}
1268-
if db.t(key) != "zset" {
1268+
1269+
var set map[string]float64
1270+
switch db.t(key) {
1271+
case "set":
1272+
set = map[string]float64{}
1273+
for elem := range db.setKeys[key] {
1274+
set[elem] = 1.0
1275+
}
1276+
case "zset":
1277+
set = db.sortedSet(key)
1278+
default:
12691279
c.WriteError(msgWrongType)
12701280
return
12711281
}
1272-
for _, el := range db.ssetElements(key) {
1273-
score := el.score
1282+
1283+
for member, score := range set {
12741284
if withWeights {
12751285
score *= weights[i]
12761286
}
1277-
old, ok := sset[el.member]
1287+
old, ok := sset[member]
12781288
if !ok {
1279-
sset[el.member] = score
1289+
sset[member] = score
12801290
continue
12811291
}
12821292
switch aggregate {
12831293
default:
12841294
panic("Invalid aggregate")
12851295
case "sum":
1286-
sset[el.member] += score
1296+
sset[member] += score
12871297
case "min":
12881298
if score < old {
1289-
sset[el.member] = score
1299+
sset[member] = score
12901300
}
12911301
case "max":
12921302
if score > old {
1293-
sset[el.member] = score
1303+
sset[member] = score
12941304
}
12951305
}
12961306
}

cmd_sorted_set_test.go

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1174,19 +1174,17 @@ func TestZunionstore(t *testing.T) {
11741174
s.ZAdd("h2", 1.0, "field1")
11751175
s.ZAdd("h2", 2.0, "field2")
11761176

1177-
// Simple case
1178-
{
1177+
t.Run("simple case", func(t *testing.T) {
11791178
res, err := redis.Int(c.Do("ZUNIONSTORE", "new", 2, "h1", "h2"))
11801179
ok(t, err)
11811180
equals(t, 2, res)
11821181

11831182
ss, err := s.SortedSet("new")
11841183
ok(t, err)
11851184
equals(t, map[string]float64{"field1": 2, "field2": 4}, ss)
1186-
}
1185+
})
11871186

1188-
// Merge destination with itself.
1189-
{
1187+
t.Run("merge destination with itself", func(t *testing.T) {
11901188
s.ZAdd("h3", 1.0, "field1")
11911189
s.ZAdd("h3", 3.0, "field3")
11921190

@@ -1197,32 +1195,37 @@ func TestZunionstore(t *testing.T) {
11971195
ss, err := s.SortedSet("h3")
11981196
ok(t, err)
11991197
equals(t, map[string]float64{"field1": 2, "field2": 2, "field3": 3}, ss)
1200-
}
1198+
})
12011199

1202-
// WEIGHTS
1203-
{
1200+
t.Run("WEIGHTS", func(t *testing.T) {
12041201
res, err := redis.Int(c.Do("ZUNIONSTORE", "weighted", 2, "h1", "h2", "WeIgHtS", "4.5", "12"))
12051202
ok(t, err)
12061203
equals(t, 2, res)
12071204

12081205
ss, err := s.SortedSet("weighted")
12091206
ok(t, err)
12101207
equals(t, map[string]float64{"field1": 16.5, "field2": 33}, ss)
1211-
}
1208+
})
12121209

1213-
// AGGREGATE
1214-
{
1210+
t.Run("AGGREGATE", func(t *testing.T) {
12151211
res, err := redis.Int(c.Do("ZUNIONSTORE", "aggr", 2, "h1", "h2", "AgGrEgAtE", "min"))
12161212
ok(t, err)
12171213
equals(t, 2, res)
12181214

12191215
ss, err := s.SortedSet("aggr")
12201216
ok(t, err)
12211217
equals(t, map[string]float64{"field1": 1.0, "field2": 2.0}, ss)
1222-
}
1218+
})
12231219

1224-
// Wrong usage
1225-
{
1220+
t.Run("normal set", func(t *testing.T) {
1221+
_, err := c.Do("SADD", "set", "aap", "noot", "mies")
1222+
ok(t, err)
1223+
res, err := redis.Int(c.Do("ZUNIONSTORE", "aggr", 1, "set"))
1224+
ok(t, err)
1225+
equals(t, 3, res)
1226+
})
1227+
1228+
t.Run("wrong usage", func(t *testing.T) {
12261229
_, err := redis.Int(c.Do("ZUNIONSTORE"))
12271230
assert(t, err != nil, "do ZUNIONSTORE error")
12281231
_, err = redis.Int(c.Do("ZUNIONSTORE", "set"))
@@ -1255,7 +1258,7 @@ func TestZunionstore(t *testing.T) {
12551258
s.Set("str", "value")
12561259
_, err = redis.Int(c.Do("ZUNIONSTORE", "set", 1, "str"))
12571260
assert(t, err != nil, "do ZUNIONSTORE error")
1258-
}
1261+
})
12591262
}
12601263

12611264
func TestZinterstore(t *testing.T) {

integration/sorted_set_test.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -624,6 +624,13 @@ func TestZunionstore(t *testing.T) {
624624
succ("TYPE", "h1"),
625625
succ("TYPE", "h2"),
626626
)
627+
// not a sorted set, still fine
628+
testCommands(t,
629+
succ("SADD", "super", "1", "2", "3"),
630+
succ("SADD", "exclude", "3"),
631+
succ("ZUNIONSTORE", "tmp", "2", "super", "exclude", "weights", "1", "0", "aggregate", "min"),
632+
succ("ZRANGE", "tmp", "0", "-1", "withscores"),
633+
)
627634
}
628635

629636
func TestZinterstore(t *testing.T) {

0 commit comments

Comments
 (0)