Skip to content

Commit dea7631

Browse files
authored
Lua .call() not via a socket (#126)
* don't go over a socker for lua .call() This solves some locking issues, and is nicer anyway.
1 parent f930c06 commit dea7631

File tree

12 files changed

+282
-64
lines changed

12 files changed

+282
-64
lines changed

cmd_connection.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,10 @@ func (m *Miniredis) cmdAuth(c *server.Peer, cmd string, args []string) {
6363
if m.checkPubsub(c) {
6464
return
6565
}
66+
if getCtx(c).nested {
67+
c.WriteError(msgNotFromScripts)
68+
return
69+
}
6670

6771
pw := args[0]
6872

cmd_pubsub.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,10 @@ func (m *Miniredis) cmdSubscribe(c *server.Peer, cmd string, args []string) {
2929
if !m.handleAuth(c) {
3030
return
3131
}
32+
if getCtx(c).nested {
33+
c.WriteError(msgNotFromScripts)
34+
return
35+
}
3236

3337
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
3438
sub := m.subscribedState(c)
@@ -49,6 +53,10 @@ func (m *Miniredis) cmdUnsubscribe(c *server.Peer, cmd string, args []string) {
4953
if !m.handleAuth(c) {
5054
return
5155
}
56+
if getCtx(c).nested {
57+
c.WriteError(msgNotFromScripts)
58+
return
59+
}
5260

5361
channels := args
5462

@@ -86,6 +94,10 @@ func (m *Miniredis) cmdPsubscribe(c *server.Peer, cmd string, args []string) {
8694
if !m.handleAuth(c) {
8795
return
8896
}
97+
if getCtx(c).nested {
98+
c.WriteError(msgNotFromScripts)
99+
return
100+
}
89101

90102
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
91103
sub := m.subscribedState(c)
@@ -106,6 +118,10 @@ func (m *Miniredis) cmdPunsubscribe(c *server.Peer, cmd string, args []string) {
106118
if !m.handleAuth(c) {
107119
return
108120
}
121+
if getCtx(c).nested {
122+
c.WriteError(msgNotFromScripts)
123+
return
124+
}
109125

110126
patterns := args
111127

cmd_scripting.go

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,6 @@ func (m *Miniredis) runLuaScript(c *server.Peer, script string, args []string) {
5151
luajson.Preload(l)
5252
requireGlobal(l, "cjson", "json")
5353

54-
m.Unlock()
55-
conn := m.redigo()
56-
m.Lock()
57-
defer conn.Close()
58-
5954
// set global variable KEYS
6055
keysTable := l.NewTable()
6156
keysS, args := args[0], args[1:]
@@ -84,7 +79,7 @@ func (m *Miniredis) runLuaScript(c *server.Peer, script string, args []string) {
8479
}
8580
l.SetGlobal("ARGV", argvTable)
8681

87-
redisFuncs := mkLuaFuncs(conn)
82+
redisFuncs := mkLuaFuncs(m.srv, c)
8883
// Register command handlers
8984
l.Push(l.NewFunction(func(l *lua.LState) int {
9085
mod := l.RegisterModule("redis", redisFuncs).(*lua.LTable)
@@ -97,8 +92,6 @@ func (m *Miniredis) runLuaScript(c *server.Peer, script string, args []string) {
9792
l.Push(lua.LString("redis"))
9893
l.Call(1, 0)
9994

100-
m.Unlock() // This runs in a transaction, but can access our db recursively
101-
defer m.Lock()
10295
if err := l.DoString(script); err != nil {
10396
c.WriteError(errLuaParseError(err))
10497
return
@@ -120,6 +113,11 @@ func (m *Miniredis) cmdEval(c *server.Peer, cmd string, args []string) {
120113
return
121114
}
122115

116+
if getCtx(c).nested {
117+
c.WriteError(msgNotFromScripts)
118+
return
119+
}
120+
123121
script, args := args[0], args[1:]
124122

125123
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
@@ -139,6 +137,10 @@ func (m *Miniredis) cmdEvalsha(c *server.Peer, cmd string, args []string) {
139137
if m.checkPubsub(c) {
140138
return
141139
}
140+
if getCtx(c).nested {
141+
c.WriteError(msgNotFromScripts)
142+
return
143+
}
142144

143145
sha, args := args[0], args[1:]
144146

@@ -166,6 +168,11 @@ func (m *Miniredis) cmdScript(c *server.Peer, cmd string, args []string) {
166168
return
167169
}
168170

171+
if getCtx(c).nested {
172+
c.WriteError(msgNotFromScripts)
173+
return
174+
}
175+
169176
subcmd, args := args[0], args[1:]
170177

171178
withTx(m, c, func(c *server.Peer, ctx *connCtx) {

cmd_transactions.go

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,10 @@ func (m *Miniredis) cmdMulti(c *server.Peer, cmd string, args []string) {
2929
}
3030

3131
ctx := getCtx(c)
32-
32+
if ctx.nested {
33+
c.WriteError(msgNotFromScripts)
34+
return
35+
}
3336
if inTx(ctx) {
3437
c.WriteError("ERR MULTI calls can not be nested")
3538
return
@@ -55,7 +58,10 @@ func (m *Miniredis) cmdExec(c *server.Peer, cmd string, args []string) {
5558
}
5659

5760
ctx := getCtx(c)
58-
61+
if ctx.nested {
62+
c.WriteError(msgNotFromScripts)
63+
return
64+
}
5965
if !inTx(ctx) {
6066
c.WriteError("ERR EXEC without MULTI")
6167
return
@@ -130,6 +136,10 @@ func (m *Miniredis) cmdWatch(c *server.Peer, cmd string, args []string) {
130136
}
131137

132138
ctx := getCtx(c)
139+
if ctx.nested {
140+
c.WriteError(msgNotFromScripts)
141+
return
142+
}
133143
if inTx(ctx) {
134144
c.WriteError("ERR WATCH in MULTI")
135145
return

integration/script_test.go

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ func TestEval(t *testing.T) {
2020
succ("EVAL", "return redis.call('GET', 'nosuch')==nil", 0),
2121
succ("EVAL", "local a = redis.call('MGET', 'bar'); return a[1] == false", 0),
2222
succ("EVAL", "local a = redis.call('MGET', 'bar'); return a[1] == nil", 0),
23+
succ("EVAL", "return redis.call('ZRANGE', 'q', 0, -1)", 0),
24+
succ("EVAL", "return redis.call('LPOP', 'foo')", 0),
2325

2426
// failure cases
2527
fail("EVAL"),
@@ -98,6 +100,7 @@ func TestLua(t *testing.T) {
98100
succ("EVAL", "return 3.9999+0.201", 0),
99101
succ("EVAL", "return {{1}}", 0),
100102
succ("EVAL", "return {1,{1,{1,'bar'}}}", 0),
103+
succ("EVAL", "return nil", 0),
101104
)
102105

103106
// special returns
@@ -277,7 +280,7 @@ func TestLuaCall(t *testing.T) {
277280
succ("GET", "res"),
278281
)
279282

280-
// call() with transaction commands
283+
// call() with non-allowed commands
281284
testCommands(t,
282285
succ("SET", "foo", 1),
283286

@@ -289,6 +292,42 @@ func TestLuaCall(t *testing.T) {
289292
"This Redis command is not allowed from scripts",
290293
"EVAL", `redis.call("EXEC")`, 0,
291294
),
295+
failWith(
296+
"This Redis command is not allowed from scripts",
297+
"EVAL", `redis.call("EVAL", "redis.call(\"GET\", \"foo\")", 0)`, 0,
298+
),
299+
failWith(
300+
"This Redis command is not allowed from scripts",
301+
"EVAL", `redis.call("SCRIPT", "LOAD", "return 42")`, 0,
302+
),
303+
failWith(
304+
"This Redis command is not allowed from scripts",
305+
"EVAL", `redis.call("EVALSHA", "123", "0")`, 0,
306+
),
307+
failWith(
308+
"This Redis command is not allowed from scripts",
309+
"EVAL", `redis.call("AUTH", "foobar")`, 0,
310+
),
311+
failWith(
312+
"This Redis command is not allowed from scripts",
313+
"EVAL", `redis.call("WATCH", "foobar")`, 0,
314+
),
315+
failWith(
316+
"This Redis command is not allowed from scripts",
317+
"EVAL", `redis.call("SUBSCRIBE", "foo")`, 0,
318+
),
319+
failWith(
320+
"This Redis command is not allowed from scripts",
321+
"EVAL", `redis.call("UNSUBSCRIBE", "foo")`, 0,
322+
),
323+
failWith(
324+
"This Redis command is not allowed from scripts",
325+
"EVAL", `redis.call("PSUBSCRIBE", "foo")`, 0,
326+
),
327+
failWith(
328+
"This Redis command is not allowed from scripts",
329+
"EVAL", `redis.call("PUNSUBSCRIBE", "foo")`, 0,
330+
),
292331
succ("EVAL", `redis.pcall("EXEC")`, 0),
293332
succ("GET", "foo"),
294333
)

lua.go

Lines changed: 28 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,32 @@
11
package miniredis
22

33
import (
4+
"bufio"
5+
"bytes"
6+
"fmt"
47
"strings"
58

6-
redigo "github.com/gomodule/redigo/redis"
79
lua "github.com/yuin/gopher-lua"
810

911
"github.com/alicebob/miniredis/v2/server"
1012
)
1113

12-
func mkLuaFuncs(conn redigo.Conn) map[string]lua.LGFunction {
14+
func mkLuaFuncs(srv *server.Server, c *server.Peer) map[string]lua.LGFunction {
1315
mkCall := func(failFast bool) func(l *lua.LState) int {
16+
// one server.Ctx for a single Lua run
17+
pCtx := &connCtx{}
18+
if getCtx(c).authenticated {
19+
pCtx.authenticated = true
20+
}
21+
pCtx.nested = true
22+
1423
return func(l *lua.LState) int {
1524
top := l.GetTop()
1625
if top == 0 {
1726
l.Error(lua.LString("Please specify at least one argument for redis.call()"), 1)
1827
return 0
1928
}
20-
var args []interface{}
29+
var args []string
2130
for i := 1; i <= top; i++ {
2231
switch a := l.Get(i).(type) {
2332
case lua.LNumber:
@@ -29,22 +38,19 @@ func mkLuaFuncs(conn redigo.Conn) map[string]lua.LGFunction {
2938
return 0
3039
}
3140
}
32-
cmd, ok := args[0].(string)
33-
if !ok {
34-
l.Error(lua.LString("Unknown Redis command called from Lua script"), 1)
41+
if len(args) == 0 {
42+
l.Error(lua.LString(msgNotFromScripts), 1)
3543
return 0
3644
}
37-
switch strings.ToUpper(cmd) {
38-
case "MULTI", "EXEC":
39-
if failFast {
40-
l.Error(lua.LString("This Redis command is not allowed from scripts"), 1)
41-
return 0
42-
}
43-
l.Push(lua.LNil)
44-
return 1
45-
}
4645

47-
res, err := conn.Do(cmd, args[1:]...)
46+
buf := &bytes.Buffer{}
47+
wr := bufio.NewWriter(buf)
48+
peer := server.NewPeer(wr)
49+
peer.Ctx = pCtx
50+
srv.Dispatch(peer, args)
51+
wr.Flush()
52+
53+
res, err := server.ParseReply(bufio.NewReader(buf))
4854
if err != nil {
4955
if failFast {
5056
// call() mode
@@ -66,14 +72,19 @@ func mkLuaFuncs(conn redigo.Conn) map[string]lua.LGFunction {
6672
switch r := res.(type) {
6773
case int64:
6874
l.Push(lua.LNumber(r))
75+
case int:
76+
l.Push(lua.LNumber(r))
6977
case []uint8:
7078
l.Push(lua.LString(string(r)))
7179
case []interface{}:
7280
l.Push(redisToLua(l, r))
7381
case string:
7482
l.Push(lua.LString(r))
83+
case error:
84+
l.Error(lua.LString(r.Error()), 1)
85+
return 0
7586
default:
76-
panic("type not handled")
87+
panic(fmt.Sprintf("type not handled (%T)", r))
7788
}
7889
}
7990
return 1

miniredis.go

Lines changed: 9 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,10 @@ import (
1818
"context"
1919
"fmt"
2020
"math/rand"
21-
"net"
2221
"strconv"
2322
"sync"
2423
"time"
2524

26-
redigo "github.com/gomodule/redigo/redis"
27-
2825
"github.com/alicebob/miniredis/v2/server"
2926
)
3027

@@ -80,6 +77,7 @@ type connCtx struct {
8077
dirtyTransaction bool // any error during QUEUEing
8178
watch map[dbKey]uint // WATCHed keys
8279
subscriber *Subscriber // client is in PUBSUB mode if not nil
80+
nested bool // this is called via Lua
8381
}
8482

8583
// NewMiniRedis makes a new, non-started, Miniredis object.
@@ -287,19 +285,6 @@ func (m *Miniredis) Server() *server.Server {
287285
return m.srv
288286
}
289287

290-
// redigo returns a redigo.Conn, connected using net.Pipe
291-
func (m *Miniredis) redigo() redigo.Conn {
292-
c1, c2 := net.Pipe()
293-
m.srv.ServeConn(c1)
294-
c := redigo.NewConn(c2, 0, 0)
295-
if m.password != "" {
296-
if _, err := c.Do("AUTH", m.password); err != nil {
297-
// ?
298-
}
299-
}
300-
return c
301-
}
302-
303288
// Dump returns a text version of the selected DB, usable for debugging.
304289
func (m *Miniredis) Dump() string {
305290
m.Lock()
@@ -366,6 +351,10 @@ func (m *Miniredis) SetTime(t time.Time) {
366351

367352
// handleAuth returns false if connection has no access. It sends the reply.
368353
func (m *Miniredis) handleAuth(c *server.Peer) bool {
354+
if getCtx(c).nested {
355+
return true
356+
}
357+
369358
m.Lock()
370359
defer m.Unlock()
371360
if m.password == "" {
@@ -381,6 +370,10 @@ func (m *Miniredis) handleAuth(c *server.Peer) bool {
381370
// handlePubsub sends an error to the user if the connection is in PUBSUB mode.
382371
// It'll return true if it did.
383372
func (m *Miniredis) checkPubsub(c *server.Peer) bool {
373+
if getCtx(c).nested {
374+
return false
375+
}
376+
384377
m.Lock()
385378
defer m.Unlock()
386379

miniredis_test.go

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -238,21 +238,3 @@ func TestExpireWithFastForward(t *testing.T) {
238238
s.FastForward(5 * time.Second)
239239
equals(t, 1, len(s.Keys()))
240240
}
241-
242-
func TestRedigo(t *testing.T) {
243-
s, err := Run()
244-
ok(t, err)
245-
246-
r := s.redigo()
247-
defer r.Close()
248-
249-
_, err = r.Do("SELECT", 2)
250-
ok(t, err)
251-
252-
_, err = r.Do("SET", "foo", "bar")
253-
ok(t, err)
254-
255-
v, err := redis.String(r.Do("GET", "foo"))
256-
ok(t, err)
257-
equals(t, "bar", v)
258-
}

0 commit comments

Comments
 (0)