33package miniredis
44
55import (
6+ "fmt"
67 "math/big"
78 "strconv"
89 "strings"
10+ "time"
911
1012 "github.com/alicebob/miniredis/v2/server"
1113)
@@ -28,6 +30,7 @@ func commandsHash(m *Miniredis) {
2830 m .srv .Register ("HVALS" , m .cmdHvals , server .ReadOnlyOption ())
2931 m .srv .Register ("HSCAN" , m .cmdHscan , server .ReadOnlyOption ())
3032 m .srv .Register ("HRANDFIELD" , m .cmdHrandfield , server .ReadOnlyOption ())
33+ m .srv .Register ("HEXPIRE" , m .cmdHexpire )
3134}
3235
3336// HSET
@@ -641,6 +644,151 @@ func (m *Miniredis) cmdHrandfield(c *server.Peer, cmd string, args []string) {
641644 })
642645}
643646
647+ // HEXPIRE
648+ func (m * Miniredis ) cmdHexpire (c * server.Peer , cmd string , args []string ) {
649+ if ! m .isValidCMD (c , cmd , args , atLeast (5 )) {
650+ return
651+ }
652+
653+ opts , err := parseHExpireArgs (args )
654+ if err != "" {
655+ setDirty (c )
656+ c .WriteError (err )
657+ return
658+ }
659+
660+ withTx (m , c , func (peer * server.Peer , ctx * connCtx ) {
661+ db := m .db (ctx .selectedDB )
662+
663+ if _ , ok := db .keys [opts .key ]; ! ok {
664+ c .WriteLen (len (opts .fields ))
665+ for range opts .fields {
666+ c .WriteInt (- 2 )
667+ }
668+ return
669+ }
670+
671+ if db .t (opts .key ) != keyTypeHash {
672+ c .WriteError (msgWrongType )
673+ return
674+ }
675+
676+ fieldTTLs := db .hashTTLs [opts .key ]
677+ if fieldTTLs == nil {
678+ fieldTTLs = map [string ]time.Duration {}
679+ db .hashTTLs [opts .key ] = fieldTTLs
680+ }
681+
682+ c .WriteLen (len (opts .fields ))
683+ for _ , field := range opts .fields {
684+ if _ , ok := db.hashKeys [opts.key ][field ]; ! ok {
685+ c .WriteInt (- 2 )
686+ continue
687+ }
688+
689+ currentTtl , ok := fieldTTLs [field ]
690+ newTTL := time .Duration (opts .ttl ) * time .Second
691+
692+ // NX -- For each specified field,
693+ // set expiration only when the field has no expiration.
694+ if opts .nx && ok {
695+ c .WriteInt (0 )
696+ continue
697+ }
698+
699+ // XX -- For each specified field,
700+ // set expiration only when the field has an existing expiration.
701+ if opts .xx && ! ok {
702+ c .WriteInt (0 )
703+ continue
704+ }
705+
706+ // GT -- For each specified field,
707+ // set expiration only when the new expiration is greater than current one.
708+ if opts .gt && (! ok || newTTL <= currentTtl ) {
709+ c .WriteInt (0 )
710+ continue
711+ }
712+
713+ // LT -- For each specified field,
714+ // set expiration only when the new expiration is less than current one.
715+ if opts .lt && ok && newTTL >= currentTtl {
716+ c .WriteInt (0 )
717+ continue
718+ }
719+
720+ fieldTTLs [field ] = newTTL
721+ c .WriteInt (1 )
722+ }
723+ })
724+ }
725+
726+ type hexpireOpts struct {
727+ key string
728+ ttl int
729+ nx bool
730+ xx bool
731+ gt bool
732+ lt bool
733+ fields []string
734+ }
735+
736+ func parseHExpireArgs (args []string ) (hexpireOpts , string ) {
737+ var opts hexpireOpts
738+ opts .key = args [0 ]
739+
740+ if err := optIntSimple (args [1 ], & opts .ttl ); err != nil {
741+ return hexpireOpts {}, err .Error ()
742+ }
743+
744+ args = args [2 :]
745+
746+ for len (args ) > 0 {
747+ switch strings .ToLower (args [0 ]) {
748+ case "nx" :
749+ opts .nx = true
750+ args = args [1 :]
751+ case "xx" :
752+ opts .xx = true
753+ args = args [1 :]
754+ case "gt" :
755+ opts .gt = true
756+ args = args [1 :]
757+ case "lt" :
758+ opts .lt = true
759+ args = args [1 :]
760+ case "fields" :
761+ var numFields int
762+ if err := optIntSimple (args [1 ], & numFields ); err != nil {
763+ return hexpireOpts {}, msgNumFieldsInvalid
764+ }
765+ if numFields <= 0 {
766+ return hexpireOpts {}, msgNumFieldsInvalid
767+ }
768+
769+ // FIELDS numFields field1 field2 ...
770+ if len (args ) < 2 + numFields {
771+ return hexpireOpts {}, msgNumFieldsParameter
772+ }
773+
774+ opts .fields = append ([]string {}, args [2 :2 + numFields ]... )
775+ args = args [2 + numFields :]
776+ default :
777+ return hexpireOpts {}, fmt .Sprintf (msgMandatoryArgument , "FIELDS" )
778+ }
779+ }
780+
781+ if opts .gt && opts .lt {
782+ return hexpireOpts {}, msgGTandLT
783+ }
784+
785+ if opts .nx && (opts .xx || opts .gt || opts .lt ) {
786+ return hexpireOpts {}, msgNXandXXGTLT
787+ }
788+
789+ return opts , ""
790+ }
791+
644792func abs (n int ) int {
645793 if n < 0 {
646794 return - n
0 commit comments