diff --git a/CHANGELOG.md b/CHANGELOG.md index b2c840c..2c8cb85 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,7 +1,13 @@ ## Changelog +## v2.37.0 + +- suport HEXPIRE (thanks @mojixcoder) + + ## v2.36.1 + - support CLUSTER SHARDS (thanks @dadrus) diff --git a/cmd_hash.go b/cmd_hash.go index 29fec82..9c5e686 100644 --- a/cmd_hash.go +++ b/cmd_hash.go @@ -3,9 +3,11 @@ package miniredis import ( + "fmt" "math/big" "strconv" "strings" + "time" "github.com/alicebob/miniredis/v2/server" ) @@ -28,6 +30,7 @@ func commandsHash(m *Miniredis) { m.srv.Register("HVALS", m.cmdHvals, server.ReadOnlyOption()) m.srv.Register("HSCAN", m.cmdHscan, server.ReadOnlyOption()) m.srv.Register("HRANDFIELD", m.cmdHrandfield, server.ReadOnlyOption()) + m.srv.Register("HEXPIRE", m.cmdHexpire) } // HSET @@ -641,6 +644,151 @@ func (m *Miniredis) cmdHrandfield(c *server.Peer, cmd string, args []string) { }) } +// HEXPIRE +func (m *Miniredis) cmdHexpire(c *server.Peer, cmd string, args []string) { + if !m.isValidCMD(c, cmd, args, atLeast(5)) { + return + } + + opts, err := parseHExpireArgs(args) + if err != "" { + setDirty(c) + c.WriteError(err) + return + } + + withTx(m, c, func(peer *server.Peer, ctx *connCtx) { + db := m.db(ctx.selectedDB) + + if _, ok := db.keys[opts.key]; !ok { + c.WriteLen(len(opts.fields)) + for range opts.fields { + c.WriteInt(-2) + } + return + } + + if db.t(opts.key) != keyTypeHash { + c.WriteError(msgWrongType) + return + } + + fieldTTLs := db.hashTTLs[opts.key] + if fieldTTLs == nil { + fieldTTLs = map[string]time.Duration{} + db.hashTTLs[opts.key] = fieldTTLs + } + + c.WriteLen(len(opts.fields)) + for _, field := range opts.fields { + if _, ok := db.hashKeys[opts.key][field]; !ok { + c.WriteInt(-2) + continue + } + + currentTtl, ok := fieldTTLs[field] + newTTL := time.Duration(opts.ttl) * time.Second + + // NX -- For each specified field, + // set expiration only when the field has no expiration. + if opts.nx && ok { + c.WriteInt(0) + continue + } + + // XX -- For each specified field, + // set expiration only when the field has an existing expiration. + if opts.xx && !ok { + c.WriteInt(0) + continue + } + + // GT -- For each specified field, + // set expiration only when the new expiration is greater than current one. + if opts.gt && (!ok || newTTL <= currentTtl) { + c.WriteInt(0) + continue + } + + // LT -- For each specified field, + // set expiration only when the new expiration is less than current one. + if opts.lt && ok && newTTL >= currentTtl { + c.WriteInt(0) + continue + } + + fieldTTLs[field] = newTTL + c.WriteInt(1) + } + }) +} + +type hexpireOpts struct { + key string + ttl int + nx bool + xx bool + gt bool + lt bool + fields []string +} + +func parseHExpireArgs(args []string) (hexpireOpts, string) { + var opts hexpireOpts + opts.key = args[0] + + if err := optIntSimple(args[1], &opts.ttl); err != nil { + return hexpireOpts{}, err.Error() + } + + args = args[2:] + + for len(args) > 0 { + switch strings.ToLower(args[0]) { + case "nx": + opts.nx = true + args = args[1:] + case "xx": + opts.xx = true + args = args[1:] + case "gt": + opts.gt = true + args = args[1:] + case "lt": + opts.lt = true + args = args[1:] + case "fields": + var numFields int + if err := optIntSimple(args[1], &numFields); err != nil { + return hexpireOpts{}, msgNumFieldsInvalid + } + if numFields <= 0 { + return hexpireOpts{}, msgNumFieldsInvalid + } + + // FIELDS numFields field1 field2 ... + if len(args) < 2+numFields { + return hexpireOpts{}, msgNumFieldsParameter + } + + opts.fields = append([]string{}, args[2:2+numFields]...) + args = args[2+numFields:] + default: + return hexpireOpts{}, fmt.Sprintf(msgMandatoryArgument, "FIELDS") + } + } + + if opts.gt && opts.lt { + return hexpireOpts{}, msgGTandLT + } + + if opts.nx && (opts.xx || opts.gt || opts.lt) { + return hexpireOpts{}, msgNXandXXGTLT + } + + return opts, "" +} + func abs(n int) int { if n < 0 { return -n diff --git a/cmd_hash_test.go b/cmd_hash_test.go index 0a9747f..70aab7e 100644 --- a/cmd_hash_test.go +++ b/cmd_hash_test.go @@ -684,3 +684,707 @@ func TestHashRandField(t *testing.T) { proto.Error(msgInvalidInt), ) } + +func TestParseHExpireArgs(t *testing.T) { + tests := []struct { + name string + args []string + want hexpireOpts + wantErr string + description string + }{ + { + name: "basic usage", + args: []string{"mykey", "300", "FIELDS", "2", "field1", "field2"}, + want: hexpireOpts{ + key: "mykey", + ttl: 300, + fields: []string{"field1", "field2"}, + }, + wantErr: "", + description: "Basic HEXPIRE with key, ttl, and fields", + }, + { + name: "with NX option", + args: []string{"mykey", "300", "NX", "FIELDS", "1", "field1"}, + want: hexpireOpts{ + key: "mykey", + ttl: 300, + nx: true, + fields: []string{"field1"}, + }, + wantErr: "", + description: "HEXPIRE with NX flag", + }, + { + name: "with XX option", + args: []string{"mykey", "300", "XX", "FIELDS", "1", "field1"}, + want: hexpireOpts{ + key: "mykey", + ttl: 300, + xx: true, + fields: []string{"field1"}, + }, + wantErr: "", + description: "HEXPIRE with XX flag", + }, + { + name: "with GT option", + args: []string{"mykey", "300", "GT", "FIELDS", "1", "field1"}, + want: hexpireOpts{ + key: "mykey", + ttl: 300, + gt: true, + fields: []string{"field1"}, + }, + wantErr: "", + description: "HEXPIRE with GT flag", + }, + { + name: "with LT option", + args: []string{"mykey", "300", "LT", "FIELDS", "1", "field1"}, + want: hexpireOpts{ + key: "mykey", + ttl: 300, + lt: true, + fields: []string{"field1"}, + }, + wantErr: "", + description: "HEXPIRE with LT flag", + }, + { + name: "multiple options", + args: []string{"mykey", "300", "XX", "GT", "FIELDS", "3", "f1", "f2", "f3"}, + want: hexpireOpts{ + key: "mykey", + ttl: 300, + xx: true, + gt: true, + fields: []string{"f1", "f2", "f3"}, + }, + wantErr: "", + description: "HEXPIRE with multiple options", + }, + { + name: "invalid TTL", + args: []string{"mykey", "invalid", "FIELDS", "1", "field1"}, + want: hexpireOpts{}, + wantErr: msgInvalidInt, + description: "Invalid TTL value should return error", + }, + { + name: "missing FIELDS keyword", + args: []string{"mykey", "300"}, + want: hexpireOpts{ + key: "mykey", + ttl: 300, + }, + wantErr: "", + description: "Missing FIELDS is OK - validation happens at command level", + }, + { + name: "invalid numFields", + args: []string{"mykey", "300", "FIELDS", "invalid", "field1"}, + want: hexpireOpts{}, + wantErr: msgNumFieldsInvalid, + description: "Invalid numFields should return error", + }, + { + name: "zero numFields", + args: []string{"mykey", "300", "FIELDS", "0"}, + want: hexpireOpts{}, + wantErr: msgNumFieldsInvalid, + description: "Zero numFields should return error", + }, + { + name: "negative numFields", + args: []string{"mykey", "300", "FIELDS", "-1"}, + want: hexpireOpts{}, + wantErr: msgNumFieldsInvalid, + description: "Negative numFields should return error", + }, + { + name: "not enough fields provided", + args: []string{"mykey", "300", "FIELDS", "3", "field1"}, + want: hexpireOpts{}, + wantErr: msgNumFieldsParameter, + description: "Not enough fields provided should return error", + }, + { + name: "GT and LT together", + args: []string{"mykey", "300", "GT", "LT", "FIELDS", "1", "field1"}, + want: hexpireOpts{}, + wantErr: msgGTandLT, + description: "GT and LT together should return error", + }, + { + name: "NX and XX together", + args: []string{"mykey", "300", "NX", "XX", "FIELDS", "1", "field1"}, + want: hexpireOpts{}, + wantErr: msgNXandXXGTLT, + description: "NX and XX together should return error", + }, + { + name: "NX and GT together", + args: []string{"mykey", "300", "NX", "GT", "FIELDS", "1", "field1"}, + want: hexpireOpts{}, + wantErr: msgNXandXXGTLT, + description: "NX and GT together should return error", + }, + { + name: "NX and LT together", + args: []string{"mykey", "300", "NX", "LT", "FIELDS", "1", "field1"}, + want: hexpireOpts{}, + wantErr: msgNXandXXGTLT, + description: "NX and LT together should return error", + }, + { + name: "case insensitive options", + args: []string{"mykey", "300", "nx", "fields", "1", "field1"}, + want: hexpireOpts{ + key: "mykey", + ttl: 300, + nx: true, + fields: []string{"field1"}, + }, + wantErr: "", + description: "Options should be case insensitive", + }, + { + name: "multiple fields", + args: []string{"mykey", "300", "FIELDS", "5", "f1", "f2", "f3", "f4", "f5"}, + want: hexpireOpts{ + key: "mykey", + ttl: 300, + fields: []string{"f1", "f2", "f3", "f4", "f5"}, + }, + wantErr: "", + description: "Should handle multiple fields correctly", + }, + { + name: "negative TTL", + args: []string{"mykey", "-1", "FIELDS", "1", "field1"}, + want: hexpireOpts{ + key: "mykey", + ttl: -1, + fields: []string{"field1"}, + }, + wantErr: "", + description: "Negative TTL should be accepted (for immediate expiration)", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, gotErr := parseHExpireArgs(tt.args) + + // Check error + if tt.wantErr != "" { + if gotErr == "" { + t.Errorf("parseHExpireArgs() error = %q, wantErr containing %q", gotErr, tt.wantErr) + return + } + // Check if the error contains the expected message + if !contains(gotErr, tt.wantErr) { + t.Errorf("parseHExpireArgs() error = %q, wantErr containing %q", gotErr, tt.wantErr) + } + return + } + + if gotErr != "" { + t.Errorf("parseHExpireArgs() unexpected error = %q", gotErr) + return + } + + // Check result + if got.key != tt.want.key { + t.Errorf("parseHExpireArgs() key = %q, want %q", got.key, tt.want.key) + } + if got.ttl != tt.want.ttl { + t.Errorf("parseHExpireArgs() ttl = %d, want %d", got.ttl, tt.want.ttl) + } + if got.nx != tt.want.nx { + t.Errorf("parseHExpireArgs() nx = %v, want %v", got.nx, tt.want.nx) + } + if got.xx != tt.want.xx { + t.Errorf("parseHExpireArgs() xx = %v, want %v", got.xx, tt.want.xx) + } + if got.gt != tt.want.gt { + t.Errorf("parseHExpireArgs() gt = %v, want %v", got.gt, tt.want.gt) + } + if got.lt != tt.want.lt { + t.Errorf("parseHExpireArgs() lt = %v, want %v", got.lt, tt.want.lt) + } + if len(got.fields) != len(tt.want.fields) { + t.Errorf("parseHExpireArgs() fields length = %d, want %d", len(got.fields), len(tt.want.fields)) + } else { + for i := range got.fields { + if got.fields[i] != tt.want.fields[i] { + t.Errorf("parseHExpireArgs() fields[%d] = %q, want %q", i, got.fields[i], tt.want.fields[i]) + } + } + } + }) + } +} + +func contains(s, substr string) bool { + if len(substr) == 0 { + return true + } + for i := 0; i <= len(s)-len(substr); i++ { + if s[i:i+len(substr)] == substr { + return true + } + } + return false +} + +func TestHexpire(t *testing.T) { + s, c := runWithClient(t) + + t.Run("basic expiration", func(t *testing.T) { + must1(t, c, "HSET", "myhash", "field1", "value1") + mustDo(t, c, + "HEXPIRE", "myhash", "10", "FIELDS", "1", "field1", + proto.Ints(1), + ) + }) + + t.Run("expire multiple fields", func(t *testing.T) { + mustDo(t, c, "HSET", "myhash2", "field1", "value1", "field2", "value2", proto.Int(2)) + mustDo(t, c, + "HEXPIRE", "myhash2", "20", "FIELDS", "2", "field1", "field2", + proto.Ints(1, 1), + ) + }) + + t.Run("expire non-existent field", func(t *testing.T) { + must1(t, c, "HSET", "myhash3", "field1", "value1") + mustDo(t, c, + "HEXPIRE", "myhash3", "10", "FIELDS", "1", "nonexistent", + proto.Ints(-2), + ) + }) + + t.Run("expire on non-existent key", func(t *testing.T) { + mustDo(t, c, + "HEXPIRE", "nokey", "10", "FIELDS", "1", "field1", + proto.Ints(-2), + ) + }) + + t.Run("NX option - set only when no expiration", func(t *testing.T) { + must1(t, c, "HSET", "hash2", "f1", "v1") + + // First time should succeed (no expiration set) + mustDo(t, c, + "HEXPIRE", "hash2", "10", "NX", "FIELDS", "1", "f1", + proto.Ints(1), + ) + + // Second time should fail (expiration already set) + mustDo(t, c, + "HEXPIRE", "hash2", "20", "NX", "FIELDS", "1", "f1", + proto.Ints(0), + ) + }) + + t.Run("XX option - set only when expiration exists", func(t *testing.T) { + must1(t, c, "HSET", "hash3", "f1", "v1") + + // First time should fail (no expiration set) + mustDo(t, c, + "HEXPIRE", "hash3", "10", "XX", "FIELDS", "1", "f1", + proto.Ints(0), + ) + + // Set expiration first + mustDo(t, c, + "HEXPIRE", "hash3", "10", "FIELDS", "1", "f1", + proto.Ints(1), + ) + + // Now XX should succeed + mustDo(t, c, + "HEXPIRE", "hash3", "20", "XX", "FIELDS", "1", "f1", + proto.Ints(1), + ) + }) + + t.Run("GT option - set only when new expiration is greater", func(t *testing.T) { + must1(t, c, "HSET", "hash4", "f1", "v1") + + // Set initial expiration + mustDo(t, c, + "HEXPIRE", "hash4", "10", "FIELDS", "1", "f1", + proto.Ints(1), + ) + + // Try to set lower expiration with GT - should fail + mustDo(t, c, + "HEXPIRE", "hash4", "5", "GT", "FIELDS", "1", "f1", + proto.Ints(0), + ) + + // Set higher expiration with GT - should succeed + mustDo(t, c, + "HEXPIRE", "hash4", "20", "GT", "FIELDS", "1", "f1", + proto.Ints(1), + ) + }) + + t.Run("LT option - set only when new expiration is less", func(t *testing.T) { + must1(t, c, "HSET", "hash5", "f1", "v1") + + // Set initial expiration + mustDo(t, c, + "HEXPIRE", "hash5", "20", "FIELDS", "1", "f1", + proto.Ints(1), + ) + + // Try to set higher expiration with LT - should fail + mustDo(t, c, + "HEXPIRE", "hash5", "30", "LT", "FIELDS", "1", "f1", + proto.Ints(0), + ) + + // Set lower expiration with LT - should succeed + mustDo(t, c, + "HEXPIRE", "hash5", "10", "LT", "FIELDS", "1", "f1", + proto.Ints(1), + ) + }) + + t.Run("field expiration actually expires", func(t *testing.T) { + mustDo(t, c, "HSET", "hash6", "f1", "v1", "f2", "v2", proto.Int(2)) + + // Set very short expiration + mustDo(t, c, + "HEXPIRE", "hash6", "1", "FIELDS", "1", "f1", + proto.Ints(1), + ) + + // Field should exist now + mustDo(t, c, + "HGET", "hash6", "f1", + proto.String("v1"), + ) + + // Fast forward past expiration + s.FastForward(2 * time.Second) + + // Field should be gone + mustDo(t, c, + "HGET", "hash6", "f1", + proto.Nil, + ) + + // But other field should still exist + mustDo(t, c, + "HGET", "hash6", "f2", + proto.String("v2"), + ) + }) + + t.Run("all fields expired removes hash", func(t *testing.T) { + must1(t, c, "HSET", "hash7", "f1", "v1") + + // Set very short expiration + mustDo(t, c, + "HEXPIRE", "hash7", "1", "FIELDS", "1", "f1", + proto.Ints(1), + ) + + // Hash should exist + mustDo(t, c, + "EXISTS", "hash7", + proto.Int(1), + ) + + // Fast forward past expiration + s.FastForward(2 * time.Second) + + // Hash should be gone + mustDo(t, c, + "EXISTS", "hash7", + proto.Int(0), + ) + }) + + t.Run("error cases", func(t *testing.T) { + mustOK(t, c, "SET", "stringkey", "value") + + // Wrong number of arguments + mustDo(t, c, + "HEXPIRE", "myhash", + proto.Error(errWrongNumber("hexpire")), + ) + + // Wrong type + mustDo(t, c, + "HEXPIRE", "stringkey", "10", "FIELDS", "1", "field1", + proto.Error(msgWrongType), + ) + + // Invalid TTL + mustDo(t, c, + "HEXPIRE", "myhash", "notanumber", "FIELDS", "1", "field1", + proto.Error(msgInvalidInt), + ) + + // Invalid numFields + mustDo(t, c, + "HEXPIRE", "myhash", "10", "FIELDS", "notanumber", "field1", + proto.Error(msgNumFieldsInvalid), + ) + + // Zero numFields - needs at least one dummy field to pass atLeast(5) check + mustDo(t, c, + "HEXPIRE", "myhash", "10", "FIELDS", "0", "dummy", + proto.Error(msgNumFieldsInvalid), + ) + + // Not enough fields + mustDo(t, c, + "HEXPIRE", "myhash", "10", "FIELDS", "2", "field1", + proto.Error(msgNumFieldsParameter), + ) + + // GT and LT together + mustDo(t, c, + "HEXPIRE", "myhash", "10", "GT", "LT", "FIELDS", "1", "field1", + proto.Error(msgGTandLT), + ) + + // NX and XX together + mustDo(t, c, + "HEXPIRE", "myhash", "10", "NX", "XX", "FIELDS", "1", "field1", + proto.Error(msgNXandXXGTLT), + ) + }) + + t.Run("negative TTL for immediate expiration", func(t *testing.T) { + mustDo(t, c, "HSET", "hash8", "f1", "v1", "f2", "v2", proto.Int(2)) + + // Set negative expiration (immediate expiration) + mustDo(t, c, + "HEXPIRE", "hash8", "-1", "FIELDS", "1", "f1", + proto.Ints(1), + ) + + // Fast forward a tiny bit + s.FastForward(100 * time.Millisecond) + + // Field should be gone + mustDo(t, c, + "HGET", "hash8", "f1", + proto.Nil, + ) + }) + + t.Run("case insensitive options", func(t *testing.T) { + must1(t, c, "HSET", "hash9", "f1", "v1") + + mustDo(t, c, + "HEXPIRE", "hash9", "10", "nx", "fields", "1", "f1", + proto.Ints(1), + ) + }) + + t.Run("TTL is actually stored in hashTTLs map", func(t *testing.T) { + must1(t, c, "HSET", "hash10", "field1", "value1") + + // Set TTL + mustDo(t, c, + "HEXPIRE", "hash10", "300", "FIELDS", "1", "field1", + proto.Ints(1), + ) + + // Verify TTL is stored in the internal map + // Note: s.DB(0) internally handles locking + fieldTTLs, ok := s.DB(0).hashTTLs["hash10"] + if !ok { + t.Fatal("hashTTLs map not created for key") + } + ttl, ok := fieldTTLs["field1"] + if !ok { + t.Fatal("TTL not set for field1") + } + expectedTTL := 300 * time.Second + if ttl != expectedTTL { + t.Errorf("TTL mismatch: got %v, want %v", ttl, expectedTTL) + } + + // Set another field's TTL + must1(t, c, "HSET", "hash10", "field2", "value2") + mustDo(t, c, + "HEXPIRE", "hash10", "600", "FIELDS", "1", "field2", + proto.Ints(1), + ) + + // Verify both TTLs are stored + fieldTTLs = s.DB(0).hashTTLs["hash10"] + if len(fieldTTLs) != 2 { + t.Errorf("Expected 2 field TTLs, got %d", len(fieldTTLs)) + } + ttl1 := fieldTTLs["field1"] + ttl2 := fieldTTLs["field2"] + if ttl1 != 300*time.Second { + t.Errorf("field1 TTL mismatch: got %v, want %v", ttl1, 300*time.Second) + } + if ttl2 != 600*time.Second { + t.Errorf("field2 TTL mismatch: got %v, want %v", ttl2, 600*time.Second) + } + }) +} + +func TestCheckHashFieldTTL(t *testing.T) { + s := NewMiniRedis() + defer s.Close() + + t.Run("no TTLs set - no-op", func(t *testing.T) { + s.HSet("hash1", "field1", "value1") + s.HSet("hash1", "field2", "value2") + + // Call checkHashFieldTTL with no TTLs set + s.DB(0).checkHashFieldTTL("hash1", 5*time.Second) + + // Fields should still exist + equals(t, "value1", s.HGet("hash1", "field1")) + equals(t, "value2", s.HGet("hash1", "field2")) + }) + + t.Run("key not in hashTTLs map - no-op", func(t *testing.T) { + s.HSet("hash2", "field1", "value1") + + // Call checkHashFieldTTL for a key not in hashTTLs + s.DB(0).checkHashFieldTTL("hash2", 5*time.Second) + + // Field should still exist + equals(t, "value1", s.HGet("hash2", "field1")) + }) + + t.Run("TTL decrements correctly", func(t *testing.T) { + s.HSet("hash3", "field1", "value1") + + // Manually set TTL + db := s.DB(0) + db.hashTTLs["hash3"] = map[string]time.Duration{ + "field1": 10 * time.Second, + } + + // Decrement by 3 seconds + db.checkHashFieldTTL("hash3", 3*time.Second) + + // TTL should be 7 seconds now + equals(t, 7*time.Second, db.hashTTLs["hash3"]["field1"]) + equals(t, "value1", s.HGet("hash3", "field1")) + }) + + t.Run("field expires when TTL reaches zero", func(t *testing.T) { + s.HSet("hash4", "field1", "value1") + s.HSet("hash4", "field2", "value2") + + // Set TTL that will expire + db := s.DB(0) + db.hashTTLs["hash4"] = map[string]time.Duration{ + "field1": 2 * time.Second, + } + + // Decrement past zero + db.checkHashFieldTTL("hash4", 3*time.Second) + + // field1 should be deleted + equals(t, "", s.HGet("hash4", "field1")) + // field2 should still exist + equals(t, "value2", s.HGet("hash4", "field2")) + // TTL entry should be removed + _, exists := db.hashTTLs["hash4"]["field1"] + equals(t, false, exists) + }) + + t.Run("multiple fields with different TTLs", func(t *testing.T) { + s.HSet("hash5", "field1", "value1") + s.HSet("hash5", "field2", "value2") + s.HSet("hash5", "field3", "value3") + + db := s.DB(0) + db.hashTTLs["hash5"] = map[string]time.Duration{ + "field1": 2 * time.Second, + "field2": 5 * time.Second, + "field3": 10 * time.Second, + } + + // Decrement by 3 seconds + db.checkHashFieldTTL("hash5", 3*time.Second) + + // field1 should be deleted (2-3 = -1 <= 0) + equals(t, "", s.HGet("hash5", "field1")) + // field2 should still exist with 2 seconds left + equals(t, "value2", s.HGet("hash5", "field2")) + equals(t, 2*time.Second, db.hashTTLs["hash5"]["field2"]) + // field3 should still exist with 7 seconds left + equals(t, "value3", s.HGet("hash5", "field3")) + equals(t, 7*time.Second, db.hashTTLs["hash5"]["field3"]) + }) + + t.Run("hash deleted when all fields expire", func(t *testing.T) { + s.HSet("hash6", "field1", "value1") + s.HSet("hash6", "field2", "value2") + + db := s.DB(0) + db.hashTTLs["hash6"] = map[string]time.Duration{ + "field1": 2 * time.Second, + "field2": 3 * time.Second, + } + + // Decrement past all TTLs + db.checkHashFieldTTL("hash6", 5*time.Second) + + // Both fields should be deleted + equals(t, "", s.HGet("hash6", "field1")) + equals(t, "", s.HGet("hash6", "field2")) + + // Hash key should not exist + assert(t, !s.Exists("hash6"), "hash6 should be deleted") + }) + + t.Run("hash not deleted when some fields remain", func(t *testing.T) { + s.HSet("hash7", "field1", "value1") + s.HSet("hash7", "field2", "value2") + + db := s.DB(0) + db.hashTTLs["hash7"] = map[string]time.Duration{ + "field1": 2 * time.Second, + // field2 has no TTL + } + + // Decrement past field1's TTL + db.checkHashFieldTTL("hash7", 3*time.Second) + + // field1 should be deleted + equals(t, "", s.HGet("hash7", "field1")) + // field2 should still exist (no TTL) + equals(t, "value2", s.HGet("hash7", "field2")) + + // Hash key should still exist + assert(t, s.Exists("hash7"), "hash7 should still exist") + }) + + t.Run("negative TTL causes immediate expiration", func(t *testing.T) { + s.HSet("hash8", "field1", "value1") + + db := s.DB(0) + db.hashTTLs["hash8"] = map[string]time.Duration{ + "field1": -1 * time.Second, + } + + // Any decrement should trigger deletion + db.checkHashFieldTTL("hash8", 1*time.Millisecond) + + // field should be deleted + equals(t, "", s.HGet("hash8", "field1")) + assert(t, !s.Exists("hash8"), "hash8 should be deleted") + }) +} diff --git a/db.go b/db.go index 6af7ba3..97bdf7c 100644 --- a/db.go +++ b/db.go @@ -55,6 +55,7 @@ func (db *RedisDB) flush() { db.hllKeys = map[string]*hll{} db.sortedsetKeys = map[string]sortedSet{} db.ttl = map[string]time.Duration{} + db.hashTTLs = map[string]map[string]time.Duration{} db.streamKeys = map[string]*streamKey{} } @@ -74,6 +75,9 @@ func (db *RedisDB) move(key string, to *RedisDB) bool { to.stringKeys[key] = db.stringKeys[key] case keyTypeHash: to.hashKeys[key] = db.hashKeys[key] + if fieldTTLs, ok := db.hashTTLs[key]; ok { + to.hashTTLs[key] = fieldTTLs + } case keyTypeList: to.listKeys[key] = db.listKeys[key] case keyTypeSet: @@ -102,6 +106,9 @@ func (db *RedisDB) rename(from, to string) { db.stringKeys[to] = db.stringKeys[from] case keyTypeHash: db.hashKeys[to] = db.hashKeys[from] + if fieldTTLs, ok := db.hashTTLs[from]; ok { + db.hashTTLs[to] = fieldTTLs + } case keyTypeList: db.listKeys[to] = db.listKeys[from] case keyTypeSet: @@ -140,6 +147,7 @@ func (db *RedisDB) del(k string, delTTL bool) { delete(db.stringKeys, k) case keyTypeHash: delete(db.hashKeys, k) + delete(db.hashTTLs, k) case keyTypeList: delete(db.listKeys, k) case keyTypeSet: @@ -714,6 +722,32 @@ func (db *RedisDB) fastForward(duration time.Duration) { db.ttl[key] = value - duration db.checkTTL(key) } + + // Handle hash field TTLs + if db.t(key) == keyTypeHash { + db.checkHashFieldTTL(key, duration) + } + } +} + +func (db *RedisDB) checkHashFieldTTL(key string, duration time.Duration) { + fieldTTLs, ok := db.hashTTLs[key] + if !ok { + return + } + + for field, ttl := range fieldTTLs { + fieldTTLs[field] = ttl - duration + if fieldTTLs[field] <= 0 { + // Delete the expired field + delete(db.hashKeys[key], field) + delete(fieldTTLs, field) + + // If hash is now empty, delete the entire key + if len(db.hashKeys[key]) == 0 { + db.del(key, true) + } + } } } diff --git a/integration/hash_test.go b/integration/hash_test.go index 198b1ca..3e82e3d 100644 --- a/integration/hash_test.go +++ b/integration/hash_test.go @@ -54,6 +54,21 @@ func TestHash(t *testing.T) { c.Do("EXEC") }) }) + + t.Run("expire", func(t *testing.T) { + testRaw(t, func(c *client) { + c.Do("HSET", "aap", "noot", "mies") + c.Do("HEXPIRE", "aap", "3", "FIELDS", "2", "noot", "vuur") + + c.Error("wrong number", "HEXPIRE", "aap", "3", "FIELDS", "0") + c.Error("wrong number", "HEXPIRE", "aap", "3") + c.Error("wrong number", "HEXPIRE", "aap", "3", "FIELDS") + c.Error("wrong number", "HEXPIRE", "aap", "-3", "FIELDS", "0") + c.Error("wrong number", "HEXPIRE", "aap", "noot", "3") + c.Error("not an int", "HEXPIRE", "aap", "3.14", "FIELDS", "noot", "3.14") + c.Error("numfields", "HEXPIRE", "aap", "3", "FIELDS", "3", "noot", "vuur") + }) + }) } func TestHashSetnx(t *testing.T) { diff --git a/miniredis.go b/miniredis.go index cc87474..7e65f16 100644 --- a/miniredis.go +++ b/miniredis.go @@ -37,19 +37,20 @@ type setKey map[string]struct{} // RedisDB holds a single (numbered) Redis database. type RedisDB struct { - master *Miniredis // pointer to the lock in Miniredis - id int // db id - keys map[string]string // Master map of keys with their type - stringKeys map[string]string // GET/SET &c. keys - hashKeys map[string]hashKey // MGET/MSET &c. keys - listKeys map[string]listKey // LPUSH &c. keys - setKeys map[string]setKey // SADD &c. keys - hllKeys map[string]*hll // PFADD &c. keys - sortedsetKeys map[string]sortedSet // ZADD &c. keys - streamKeys map[string]*streamKey // XADD &c. keys - ttl map[string]time.Duration // effective TTL values - lru map[string]time.Time // last recently used ( read or written to ) - keyVersion map[string]uint // used to watch values + master *Miniredis // pointer to the lock in Miniredis + id int // db id + keys map[string]string // Master map of keys with their type + stringKeys map[string]string // GET/SET &c. keys + hashKeys map[string]hashKey // MGET/MSET &c. keys + listKeys map[string]listKey // LPUSH &c. keys + setKeys map[string]setKey // SADD &c. keys + hllKeys map[string]*hll // PFADD &c. keys + sortedsetKeys map[string]sortedSet // ZADD &c. keys + streamKeys map[string]*streamKey // XADD &c. keys + ttl map[string]time.Duration // effective TTL values + hashTTLs map[string]map[string]time.Duration // Hash TTL values + lru map[string]time.Time // last recently used ( read or written to ) + keyVersion map[string]uint // used to watch values } // Miniredis is a Redis server implementation. @@ -116,6 +117,7 @@ func newRedisDB(id int, m *Miniredis) RedisDB { sortedsetKeys: map[string]sortedSet{}, streamKeys: map[string]*streamKey{}, ttl: map[string]time.Duration{}, + hashTTLs: make(map[string]map[string]time.Duration), keyVersion: map[string]uint{}, } } diff --git a/redis.go b/redis.go index 2bf3bae..eae0e2f 100644 --- a/redis.go +++ b/redis.go @@ -69,6 +69,11 @@ const ( msgMaxLengthIsNegative = "ERR MAXLEN can't be negative" msgLimitIsNegative = "ERR LIMIT can't be negative" msgMemorySubcommand = "ERR unknown subcommand '%s'. Try MEMORY HELP." + msgNumFieldsParameter = "ERR The `numfields` parameter must match the number of arguments" + msgNumFieldsInvalid = "ERR Parameter `numFields` should be greater than 0" + msgMandatoryArgument = "ERR Mandatory argument %s is missing or not at the right position" + msgGTandLT = "ERR GT and LT options at the same time are not compatible" + msgNXandXXGTLT = "ERR NX and XX, GT or LT options at the same time are not compatible" ) func errWrongNumber(cmd string) string {