diff --git a/cmd_scripting.go b/cmd_scripting.go index 340ccc4..02ca038 100644 --- a/cmd_scripting.go +++ b/cmd_scripting.go @@ -18,7 +18,9 @@ import ( func commandsScripting(m *Miniredis) { m.srv.Register("EVAL", m.cmdEval) + m.srv.Register("EVAL_RO", m.cmdEvalro, server.ReadOnlyOption()) m.srv.Register("EVALSHA", m.cmdEvalsha) + m.srv.Register("EVALSHA_RO", m.cmdEvalshaRo, server.ReadOnlyOption()) m.srv.Register("SCRIPT", m.cmdScript) } @@ -28,7 +30,7 @@ var ( // Execute lua. Needs to run m.Lock()ed, from within withTx(). // Returns true if the lua was OK (and hence should be cached). -func (m *Miniredis) runLuaScript(c *server.Peer, sha, script string, args []string) bool { +func (m *Miniredis) runLuaScript(c *server.Peer, sha, script string, readOnly bool, args []string) bool { l := lua.NewState(lua.Options{SkipOpenLibs: true}) defer l.Close() @@ -85,7 +87,7 @@ func (m *Miniredis) runLuaScript(c *server.Peer, sha, script string, args []stri } l.SetGlobal("ARGV", argvTable) - redisFuncs, redisConstants := mkLua(m.srv, c, sha) + redisFuncs, redisConstants := mkLua(m.srv, c, sha, readOnly) // Register command handlers l.Push(l.NewFunction(func(l *lua.LState) int { mod := l.RegisterModule("redis", redisFuncs).(*lua.LTable) @@ -150,7 +152,8 @@ func compile(script string) (*lua.FunctionProto, error) { return proto, nil } -func (m *Miniredis) cmdEval(c *server.Peer, cmd string, args []string) { +// Shared implementation for EVAL and EVALRO +func (m *Miniredis) cmdEvalShared(c *server.Peer, cmd string, readOnly bool, args []string) { if !m.isValidCMD(c, cmd, args, atLeast(2)) { return } @@ -165,14 +168,20 @@ func (m *Miniredis) cmdEval(c *server.Peer, cmd string, args []string) { withTx(m, c, func(c *server.Peer, ctx *connCtx) { sha := sha1Hex(script) - ok := m.runLuaScript(c, sha, script, args) + ok := m.runLuaScript(c, sha, script, readOnly, args) if ok { m.scripts[sha] = script } }) } -func (m *Miniredis) cmdEvalsha(c *server.Peer, cmd string, args []string) { +// Wrapper function for EVAL command +func (m *Miniredis) cmdEval(c *server.Peer, cmd string, args []string) { + m.cmdEvalShared(c, cmd, false, args) +} + +// Shared implementation for EVALSHA and EVALSHA_RO +func (m *Miniredis) cmdEvalshaShared(c *server.Peer, cmd string, readOnly bool, args []string) { if !m.isValidCMD(c, cmd, args, atLeast(2)) { return } @@ -192,10 +201,25 @@ func (m *Miniredis) cmdEvalsha(c *server.Peer, cmd string, args []string) { return } - m.runLuaScript(c, sha, script, args) + m.runLuaScript(c, sha, script, readOnly, args) }) } +// Wrapper function for EVALSHA command +func (m *Miniredis) cmdEvalsha(c *server.Peer, cmd string, args []string) { + m.cmdEvalshaShared(c, cmd, false, args) +} + +// Wrapper function for EVALRO command +func (m *Miniredis) cmdEvalro(c *server.Peer, cmd string, args []string) { + m.cmdEvalShared(c, cmd, true, args) +} + +// Wrapper function for EVALSHA_RO command +func (m *Miniredis) cmdEvalshaRo(c *server.Peer, cmd string, args []string) { + m.cmdEvalshaShared(c, cmd, true, args) +} + func (m *Miniredis) cmdScript(c *server.Peer, cmd string, args []string) { if !m.isValidCMD(c, cmd, args, atLeast(1)) { return diff --git a/cmd_scripting_test.go b/cmd_scripting_test.go index 144f0f4..005c284 100644 --- a/cmd_scripting_test.go +++ b/cmd_scripting_test.go @@ -602,3 +602,124 @@ func TestLuaTX(t *testing.T) { ) }) } + +func TestEvalRo(t *testing.T) { + _, c := runWithClient(t) + + t.Run("read-only command", func(t *testing.T) { + mustOK(t, c, + "SET", "readonly", "foo", + ) + + // Test EVALRO with read-only command (should work) + mustDo(t, c, + "EVAL_RO", "return redis.call('GET', KEYS[1])", "1", "readonly", + proto.String("foo"), + ) + }) + + t.Run("write command", func(t *testing.T) { + // Test EVALRO with write command (should fail) + mustContain(t, c, + "EVAL_RO", "return redis.call('SET', KEYS[1], ARGV[1])", "1", "key1", "value1", + "Write commands are not allowed in read-only scripts", + ) + }) +} + +func TestEvalshaRo(t *testing.T) { + _, c := runWithClient(t) + + // First load a read-only script + script := "return redis.call('GET', KEYS[1])" + t.Run("read-only script", func(t *testing.T) { + mustDo(t, c, + "SCRIPT", "LOAD", script, + proto.String("d3c21d0c2b9ca22f82737626a27bcaf5d288f99f"), + ) + + mustOK(t, c, + "SET", "readonly", "foo", + ) + + // Test EVALSHA_RO with read-only command (should work) + mustDo(t, c, + "EVALSHA_RO", "d3c21d0c2b9ca22f82737626a27bcaf5d288f99f", "1", "readonly", + proto.String("foo"), + ) + + }) + + t.Run("write script", func(t *testing.T) { + // Load a write script + writeScript := "return redis.call('SET', KEYS[1], ARGV[1])" + mustDo(t, c, + "SCRIPT", "LOAD", writeScript, + proto.String("d8f2fad9f8e86a53d2a6ebd960b33c4972cacc37"), + ) + + // Test EVALSHA_RO with write command (should fail) + mustContain(t, c, + "EVALSHA_RO", "d8f2fad9f8e86a53d2a6ebd960b33c4972cacc37", "1", "key1", "value1", + "Write commands are not allowed in read-only scripts", + ) + }) +} + +func TestEvalRoWriteCommandWithPcall(t *testing.T) { + _, c := runWithClient(t) + + t.Run("return error", func(t *testing.T) { + // Test EVAL with pcall and write command (should fail) + mustContain(t, c, + "EVAL_RO", "return redis.pcall('FAKECOMMAND', KEYS[1], ARGV[1])", "1", "key1", "value1", + "Unknown Redis command called from script", + ) + }) + + t.Run("extra work after error", func(t *testing.T) { + script := ` +local err = redis.pcall('FAKECOMMAND', KEYS[1], ARGV[1]); +local res = "pcall:" .. err['err']; +return res; +` + // Test EVAL with pcall and write command (should fail) + mustContain(t, c, + "EVAL_RO", script, "1", "key1", "value1", + "pcall:ERR Unknown Redis command called from script", + ) + }) + + t.Run("write command in read-only script", func(t *testing.T) { + // Test EVALRO with pcall and write command (should fail) + mustContain(t, c, + "EVAL_RO", "return redis.pcall('SET', KEYS[1], ARGV[1])", "1", "key1", "value1", + "Write commands are not allowed in read-only scripts", + ) + }) +} + +func TestEvalWithPcall(t *testing.T) { + _, c := runWithClient(t) + + t.Run("return error", func(t *testing.T) { + // Test EVAL with pcall and write command (should fail) + mustContain(t, c, + "EVAL", "return redis.pcall('FAKECOMMAND', KEYS[1], ARGV[1])", "1", "key1", "value1", + "Unknown Redis command called from script", + ) + }) + + t.Run("continue after error", func(t *testing.T) { + script := ` +local err = redis.pcall('FAKECOMMAND', KEYS[1], ARGV[1]); +local res = "pcall:" .. err['err']; +return res; +` + // Test EVAL with pcall and write command (should fail) + mustContain(t, c, + "EVAL", script, "1", "foo", "value1", + "pcall:ERR Unknown Redis command called from script", + ) + }) +} diff --git a/integration/script_test.go b/integration/script_test.go index 4182c36..51cf13f 100644 --- a/integration/script_test.go +++ b/integration/script_test.go @@ -21,6 +21,10 @@ func TestScript(t *testing.T) { c.Do("EVAL", "return redis.call('ZRANGE', 'q', 0, -1)", "0") c.Do("EVAL", "return redis.call('LPOP', 'foo')", "0") + c.Do("EVAL_RO", "return 42", "0") + c.Do("EVAL_RO", "return 42+2", "0") + c.Error("Write commands are not allowed", "EVAL_RO", "return redis.call('LPOP', 'foo')", "0") + // failure cases c.Error("wrong number", "EVAL") c.Error("wrong number", "EVAL", "return 42") diff --git a/lua.go b/lua.go index f623739..2ad34a5 100644 --- a/lua.go +++ b/lua.go @@ -18,7 +18,7 @@ var luaRedisConstants = map[string]lua.LValue{ "LOG_WARNING": lua.LNumber(3), } -func mkLua(srv *server.Server, c *server.Peer, sha string) (map[string]lua.LGFunction, map[string]lua.LValue) { +func mkLua(srv *server.Server, c *server.Peer, sha string, readOnly bool) (map[string]lua.LGFunction, map[string]lua.LValue) { mkCall := func(failFast bool) func(l *lua.LState) int { // one server.Ctx for a single Lua run pCtx := &connCtx{} @@ -52,6 +52,20 @@ func mkLua(srv *server.Server, c *server.Peer, sha string) (map[string]lua.LGFun return 0 } + if readOnly && len(args) > 0 { + if srv.IsRegisteredCommand(args[0]) && !srv.IsReadOnlyCommand(args[0]) { + if failFast { + l.Error(lua.LString("Write commands are not allowed in read-only scripts"), 1) + return 0 + } + // pcall() mode - return error table + res := &lua.LTable{} + res.RawSetString("err", lua.LString("Write commands are not allowed in read-only scripts")) + l.Push(res) + return 1 + } + } + buf := &bytes.Buffer{} wr := bufio.NewWriter(buf) peer := server.NewPeer(wr) @@ -71,7 +85,13 @@ func mkLua(srv *server.Server, c *server.Peer, sha string) (map[string]lua.LGFun return 0 } // pcall() mode - l.Push(lua.LNil) + res := &lua.LTable{} + if strings.Contains(err.Error(), "ERR unknown command") { + res.RawSetString("err", lua.LString("ERR Unknown Redis command called from script")) + } else { + res.RawSetString("err", lua.LString(err.Error())) + } + l.Push(res) return 1 } diff --git a/server/server.go b/server/server.go index 5bc2ce1..af36105 100644 --- a/server/server.go +++ b/server/server.go @@ -238,6 +238,15 @@ func (s *Server) TotalCommands() int { return s.infoCmds } +// IsRegisteredCommand checks if a command is registered +func (s *Server) IsRegisteredCommand(cmd string) bool { + s.mu.Lock() + defer s.mu.Unlock() + cmdUp := strings.ToUpper(cmd) + _, ok := s.cmds[cmdUp] + return ok +} + // IsReadOnlyCommand checks if a command is marked as read-only func (s *Server) IsReadOnlyCommand(cmd string) bool { s.mu.Lock() diff --git a/server/server_test.go b/server/server_test.go index ac861fc..eda0fcb 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -308,3 +308,23 @@ func TestReadOnlyOption(t *testing.T) { t.Error("Non-existent command should return false") } } + +func TestIsRegisteredCommand(t *testing.T) { + srv, err := NewServer("127.0.0.1:0") + if err != nil { + t.Fatalf("Failed to create server: %v", err) + } + defer srv.Close() + + srv.Register("TESTGET", func(c *Peer, cmd string, args []string) { + c.WriteOK() + }) + + if !srv.IsRegisteredCommand("TESTGET") { + t.Error("TESTGET should be registered") + } + + if srv.IsRegisteredCommand("NONEXISTENT") { + t.Error("NONEXISTENT should not be registered") + } +}