Skip to content

Commit 72a6b15

Browse files
authored
Add evalro support (#416)
* Add EVAL_RO and EVALSHA_RO commands Signed-off-by: Maximilian Frank <1375575+max-frank@users.noreply.github.com>
1 parent 3f960ee commit 72a6b15

File tree

6 files changed

+206
-8
lines changed

6 files changed

+206
-8
lines changed

cmd_scripting.go

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@ import (
1818

1919
func commandsScripting(m *Miniredis) {
2020
m.srv.Register("EVAL", m.cmdEval)
21+
m.srv.Register("EVAL_RO", m.cmdEvalro, server.ReadOnlyOption())
2122
m.srv.Register("EVALSHA", m.cmdEvalsha)
23+
m.srv.Register("EVALSHA_RO", m.cmdEvalshaRo, server.ReadOnlyOption())
2224
m.srv.Register("SCRIPT", m.cmdScript)
2325
}
2426

@@ -28,7 +30,7 @@ var (
2830

2931
// Execute lua. Needs to run m.Lock()ed, from within withTx().
3032
// Returns true if the lua was OK (and hence should be cached).
31-
func (m *Miniredis) runLuaScript(c *server.Peer, sha, script string, args []string) bool {
33+
func (m *Miniredis) runLuaScript(c *server.Peer, sha, script string, readOnly bool, args []string) bool {
3234
l := lua.NewState(lua.Options{SkipOpenLibs: true})
3335
defer l.Close()
3436

@@ -85,7 +87,7 @@ func (m *Miniredis) runLuaScript(c *server.Peer, sha, script string, args []stri
8587
}
8688
l.SetGlobal("ARGV", argvTable)
8789

88-
redisFuncs, redisConstants := mkLua(m.srv, c, sha)
90+
redisFuncs, redisConstants := mkLua(m.srv, c, sha, readOnly)
8991
// Register command handlers
9092
l.Push(l.NewFunction(func(l *lua.LState) int {
9193
mod := l.RegisterModule("redis", redisFuncs).(*lua.LTable)
@@ -150,7 +152,8 @@ func compile(script string) (*lua.FunctionProto, error) {
150152
return proto, nil
151153
}
152154

153-
func (m *Miniredis) cmdEval(c *server.Peer, cmd string, args []string) {
155+
// Shared implementation for EVAL and EVALRO
156+
func (m *Miniredis) cmdEvalShared(c *server.Peer, cmd string, readOnly bool, args []string) {
154157
if !m.isValidCMD(c, cmd, args, atLeast(2)) {
155158
return
156159
}
@@ -165,14 +168,20 @@ func (m *Miniredis) cmdEval(c *server.Peer, cmd string, args []string) {
165168

166169
withTx(m, c, func(c *server.Peer, ctx *connCtx) {
167170
sha := sha1Hex(script)
168-
ok := m.runLuaScript(c, sha, script, args)
171+
ok := m.runLuaScript(c, sha, script, readOnly, args)
169172
if ok {
170173
m.scripts[sha] = script
171174
}
172175
})
173176
}
174177

175-
func (m *Miniredis) cmdEvalsha(c *server.Peer, cmd string, args []string) {
178+
// Wrapper function for EVAL command
179+
func (m *Miniredis) cmdEval(c *server.Peer, cmd string, args []string) {
180+
m.cmdEvalShared(c, cmd, false, args)
181+
}
182+
183+
// Shared implementation for EVALSHA and EVALSHA_RO
184+
func (m *Miniredis) cmdEvalshaShared(c *server.Peer, cmd string, readOnly bool, args []string) {
176185
if !m.isValidCMD(c, cmd, args, atLeast(2)) {
177186
return
178187
}
@@ -192,10 +201,25 @@ func (m *Miniredis) cmdEvalsha(c *server.Peer, cmd string, args []string) {
192201
return
193202
}
194203

195-
m.runLuaScript(c, sha, script, args)
204+
m.runLuaScript(c, sha, script, readOnly, args)
196205
})
197206
}
198207

208+
// Wrapper function for EVALSHA command
209+
func (m *Miniredis) cmdEvalsha(c *server.Peer, cmd string, args []string) {
210+
m.cmdEvalshaShared(c, cmd, false, args)
211+
}
212+
213+
// Wrapper function for EVALRO command
214+
func (m *Miniredis) cmdEvalro(c *server.Peer, cmd string, args []string) {
215+
m.cmdEvalShared(c, cmd, true, args)
216+
}
217+
218+
// Wrapper function for EVALSHA_RO command
219+
func (m *Miniredis) cmdEvalshaRo(c *server.Peer, cmd string, args []string) {
220+
m.cmdEvalshaShared(c, cmd, true, args)
221+
}
222+
199223
func (m *Miniredis) cmdScript(c *server.Peer, cmd string, args []string) {
200224
if !m.isValidCMD(c, cmd, args, atLeast(1)) {
201225
return

cmd_scripting_test.go

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -602,3 +602,124 @@ func TestLuaTX(t *testing.T) {
602602
)
603603
})
604604
}
605+
606+
func TestEvalRo(t *testing.T) {
607+
_, c := runWithClient(t)
608+
609+
t.Run("read-only command", func(t *testing.T) {
610+
mustOK(t, c,
611+
"SET", "readonly", "foo",
612+
)
613+
614+
// Test EVALRO with read-only command (should work)
615+
mustDo(t, c,
616+
"EVAL_RO", "return redis.call('GET', KEYS[1])", "1", "readonly",
617+
proto.String("foo"),
618+
)
619+
})
620+
621+
t.Run("write command", func(t *testing.T) {
622+
// Test EVALRO with write command (should fail)
623+
mustContain(t, c,
624+
"EVAL_RO", "return redis.call('SET', KEYS[1], ARGV[1])", "1", "key1", "value1",
625+
"Write commands are not allowed in read-only scripts",
626+
)
627+
})
628+
}
629+
630+
func TestEvalshaRo(t *testing.T) {
631+
_, c := runWithClient(t)
632+
633+
// First load a read-only script
634+
script := "return redis.call('GET', KEYS[1])"
635+
t.Run("read-only script", func(t *testing.T) {
636+
mustDo(t, c,
637+
"SCRIPT", "LOAD", script,
638+
proto.String("d3c21d0c2b9ca22f82737626a27bcaf5d288f99f"),
639+
)
640+
641+
mustOK(t, c,
642+
"SET", "readonly", "foo",
643+
)
644+
645+
// Test EVALSHA_RO with read-only command (should work)
646+
mustDo(t, c,
647+
"EVALSHA_RO", "d3c21d0c2b9ca22f82737626a27bcaf5d288f99f", "1", "readonly",
648+
proto.String("foo"),
649+
)
650+
651+
})
652+
653+
t.Run("write script", func(t *testing.T) {
654+
// Load a write script
655+
writeScript := "return redis.call('SET', KEYS[1], ARGV[1])"
656+
mustDo(t, c,
657+
"SCRIPT", "LOAD", writeScript,
658+
proto.String("d8f2fad9f8e86a53d2a6ebd960b33c4972cacc37"),
659+
)
660+
661+
// Test EVALSHA_RO with write command (should fail)
662+
mustContain(t, c,
663+
"EVALSHA_RO", "d8f2fad9f8e86a53d2a6ebd960b33c4972cacc37", "1", "key1", "value1",
664+
"Write commands are not allowed in read-only scripts",
665+
)
666+
})
667+
}
668+
669+
func TestEvalRoWriteCommandWithPcall(t *testing.T) {
670+
_, c := runWithClient(t)
671+
672+
t.Run("return error", func(t *testing.T) {
673+
// Test EVAL with pcall and write command (should fail)
674+
mustContain(t, c,
675+
"EVAL_RO", "return redis.pcall('FAKECOMMAND', KEYS[1], ARGV[1])", "1", "key1", "value1",
676+
"Unknown Redis command called from script",
677+
)
678+
})
679+
680+
t.Run("extra work after error", func(t *testing.T) {
681+
script := `
682+
local err = redis.pcall('FAKECOMMAND', KEYS[1], ARGV[1]);
683+
local res = "pcall:" .. err['err'];
684+
return res;
685+
`
686+
// Test EVAL with pcall and write command (should fail)
687+
mustContain(t, c,
688+
"EVAL_RO", script, "1", "key1", "value1",
689+
"pcall:ERR Unknown Redis command called from script",
690+
)
691+
})
692+
693+
t.Run("write command in read-only script", func(t *testing.T) {
694+
// Test EVALRO with pcall and write command (should fail)
695+
mustContain(t, c,
696+
"EVAL_RO", "return redis.pcall('SET', KEYS[1], ARGV[1])", "1", "key1", "value1",
697+
"Write commands are not allowed in read-only scripts",
698+
)
699+
})
700+
}
701+
702+
func TestEvalWithPcall(t *testing.T) {
703+
_, c := runWithClient(t)
704+
705+
t.Run("return error", func(t *testing.T) {
706+
// Test EVAL with pcall and write command (should fail)
707+
mustContain(t, c,
708+
"EVAL", "return redis.pcall('FAKECOMMAND', KEYS[1], ARGV[1])", "1", "key1", "value1",
709+
"Unknown Redis command called from script",
710+
)
711+
})
712+
713+
t.Run("continue after error", func(t *testing.T) {
714+
script := `
715+
local err = redis.pcall('FAKECOMMAND', KEYS[1], ARGV[1]);
716+
local res = "pcall:" .. err['err'];
717+
return res;
718+
`
719+
// Test EVAL with pcall and write command (should fail)
720+
mustContain(t, c,
721+
"EVAL", script, "1", "foo", "value1",
722+
"pcall:ERR Unknown Redis command called from script",
723+
)
724+
})
725+
}

integration/script_test.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@ func TestScript(t *testing.T) {
2121
c.Do("EVAL", "return redis.call('ZRANGE', 'q', 0, -1)", "0")
2222
c.Do("EVAL", "return redis.call('LPOP', 'foo')", "0")
2323

24+
c.Do("EVAL_RO", "return 42", "0")
25+
c.Do("EVAL_RO", "return 42+2", "0")
26+
c.Error("Write commands are not allowed", "EVAL_RO", "return redis.call('LPOP', 'foo')", "0")
27+
2428
// failure cases
2529
c.Error("wrong number", "EVAL")
2630
c.Error("wrong number", "EVAL", "return 42")

lua.go

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ var luaRedisConstants = map[string]lua.LValue{
1818
"LOG_WARNING": lua.LNumber(3),
1919
}
2020

21-
func mkLua(srv *server.Server, c *server.Peer, sha string) (map[string]lua.LGFunction, map[string]lua.LValue) {
21+
func mkLua(srv *server.Server, c *server.Peer, sha string, readOnly bool) (map[string]lua.LGFunction, map[string]lua.LValue) {
2222
mkCall := func(failFast bool) func(l *lua.LState) int {
2323
// one server.Ctx for a single Lua run
2424
pCtx := &connCtx{}
@@ -52,6 +52,20 @@ func mkLua(srv *server.Server, c *server.Peer, sha string) (map[string]lua.LGFun
5252
return 0
5353
}
5454

55+
if readOnly && len(args) > 0 {
56+
if srv.IsRegisteredCommand(args[0]) && !srv.IsReadOnlyCommand(args[0]) {
57+
if failFast {
58+
l.Error(lua.LString("Write commands are not allowed in read-only scripts"), 1)
59+
return 0
60+
}
61+
// pcall() mode - return error table
62+
res := &lua.LTable{}
63+
res.RawSetString("err", lua.LString("Write commands are not allowed in read-only scripts"))
64+
l.Push(res)
65+
return 1
66+
}
67+
}
68+
5569
buf := &bytes.Buffer{}
5670
wr := bufio.NewWriter(buf)
5771
peer := server.NewPeer(wr)
@@ -71,7 +85,13 @@ func mkLua(srv *server.Server, c *server.Peer, sha string) (map[string]lua.LGFun
7185
return 0
7286
}
7387
// pcall() mode
74-
l.Push(lua.LNil)
88+
res := &lua.LTable{}
89+
if strings.Contains(err.Error(), "ERR unknown command") {
90+
res.RawSetString("err", lua.LString("ERR Unknown Redis command called from script"))
91+
} else {
92+
res.RawSetString("err", lua.LString(err.Error()))
93+
}
94+
l.Push(res)
7595
return 1
7696
}
7797

server/server.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,15 @@ func (s *Server) TotalCommands() int {
238238
return s.infoCmds
239239
}
240240

241+
// IsRegisteredCommand checks if a command is registered
242+
func (s *Server) IsRegisteredCommand(cmd string) bool {
243+
s.mu.Lock()
244+
defer s.mu.Unlock()
245+
cmdUp := strings.ToUpper(cmd)
246+
_, ok := s.cmds[cmdUp]
247+
return ok
248+
}
249+
241250
// IsReadOnlyCommand checks if a command is marked as read-only
242251
func (s *Server) IsReadOnlyCommand(cmd string) bool {
243252
s.mu.Lock()

server/server_test.go

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,3 +308,23 @@ func TestReadOnlyOption(t *testing.T) {
308308
t.Error("Non-existent command should return false")
309309
}
310310
}
311+
312+
func TestIsRegisteredCommand(t *testing.T) {
313+
srv, err := NewServer("127.0.0.1:0")
314+
if err != nil {
315+
t.Fatalf("Failed to create server: %v", err)
316+
}
317+
defer srv.Close()
318+
319+
srv.Register("TESTGET", func(c *Peer, cmd string, args []string) {
320+
c.WriteOK()
321+
})
322+
323+
if !srv.IsRegisteredCommand("TESTGET") {
324+
t.Error("TESTGET should be registered")
325+
}
326+
327+
if srv.IsRegisteredCommand("NONEXISTENT") {
328+
t.Error("NONEXISTENT should not be registered")
329+
}
330+
}

0 commit comments

Comments
 (0)