Skip to content

Commit

Permalink
Change SET option vararg to struct.
Browse files Browse the repository at this point in the history
  • Loading branch information
pascaldekloe committed Nov 28, 2019
1 parent a6ee032 commit b4f1f0b
Show file tree
Hide file tree
Showing 4 changed files with 240 additions and 41 deletions.
159 changes: 133 additions & 26 deletions command.go
Original file line number Diff line number Diff line change
@@ -1,18 +1,68 @@
package redis

// Option Codes For SETWithArgs
const (
// EX sets the specified expire time, in seconds.
EX = "EX"
// PX sets the specified expire time, in milliseconds.
PX = "PX"
import (
"errors"
"fmt"
"time"
)

// Flags For SETOptions
const (
// NX only sets the key if it does not already exist.
NX = "NX"
NX = 1 << iota
// XX only sets the key if it does already exist.
XX = "XX"
XX

// EX sets an expire time, in seconds.
EX
// PX sets an expire time, in milliseconds.
PX
)

// SETOptions are extra arguments for the SET command.
type SETOptions struct {
// Composotion of NX, XX, EX or PX.
Flags uint

// The value is rounded to seconds with the EX flag,
// and milliseconds with PX. Non-zero values without
// expiry Flags are rejected to prevent mistakes.
Expire time.Duration
}

func (o *SETOptions) args() (existArg, expireArg string, expire int64, err error) {
if unknown := o.Flags &^ (NX | XX | EX | PX); unknown != 0 {
return "", "", 0, fmt.Errorf("redis: unknown flags %#x", unknown)
}

switch o.Flags & (NX | XX) {
case 0:
break
case NX:
existArg = "NX"
case XX:
existArg = "XX"
default:
return "", "", 0, errors.New("redis: combination of NX and XX not allowed")
}

switch o.Flags & (EX | PX) {
case 0:
if o.Expire != 0 {
return "", "", 0, errors.New("redis: expire time without EX nor PX not allowed")
}
case EX:
expireArg = "EX"
expire = int64(o.Expire / time.Second)
case PX:
expireArg = "PX"
expire = int64(o.Expire / time.Millisecond)
default:
return "", "", 0, errors.New("redis: combination of EX and PX not allowed")
}
return
}

// SELECT executes <https://redis.io/commands/select>.
func (c *Client) SELECT(db int64) error {
r := newRequest("*2\r\n$6\r\nSELECT\r\n$")
Expand Down Expand Up @@ -127,39 +177,96 @@ func (c *Client) SETString(key, value string) error {
return c.commandOK(r)
}

// SETWithArgs executes <https://redis.io/commands/set> with options.
// SETWithOptions executes <https://redis.io/commands/set> with options.
// The return is false if the SET operation was not performed due to an NX or XX
// condition. See EX, PX, NX and XX for details.
func (c *Client) SETWithArgs(key string, value []byte, options ...string) (bool, error) {
r := newRequestSize(3+len(options), "\r\n$3\r\nSET\r\n$")
r.addStringBytesStringList(key, value, options)
err := c.commandOK(r)
// condition.
func (c *Client) SETWithOptions(key string, value []byte, o SETOptions) (bool, error) {
existArg, expireArg, expire, err := o.args()
if err != nil {
return false, err
}

var r *request
switch {
case existArg != "" && expireArg == "":
r = newRequest("*4\r\n$3\r\nSET\r\n$")
r.addStringBytesString(key, value, existArg)
case existArg == "" && expireArg != "":
r = newRequest("*5\r\n$3\r\nSET\r\n$")
r.addStringBytesStringInt(key, value, expireArg, expire)
case existArg != "" && expireArg != "":
r = newRequest("*6\r\n$3\r\nSET\r\n$")
r.addStringBytesStringStringInt(key, value, existArg, expireArg, expire)
default:
err := c.SET(key, value)
return err == nil, err
}

err = c.commandOK(r)
if err == errNull {
return false, nil
}
return err == nil, err
}

// BytesSETWithArgs executes <https://redis.io/commands/set> with options.
// BytesSETWithOptions executes <https://redis.io/commands/set> with options.
// The return is false if the SET operation was not performed due to an NX or XX
// condition. See EX, PX, NX and XX for details.
func (c *Client) BytesSETWithArgs(key, value []byte, options ...string) (bool, error) {
r := newRequestSize(3+len(options), "\r\n$3\r\nSET\r\n$")
r.addBytesBytesStringList(key, value, options)
err := c.commandOK(r)
// condition.
func (c *Client) BytesSETWithOptions(key, value []byte, o SETOptions) (bool, error) {
existArg, expireArg, expire, err := o.args()
if err != nil {
return false, err
}

var r *request
switch {
case existArg != "" && expireArg == "":
r = newRequest("*4\r\n$3\r\nSET\r\n$")
r.addBytesBytesString(key, value, existArg)
case existArg == "" && expireArg != "":
r = newRequest("*5\r\n$3\r\nSET\r\n$")
r.addBytesBytesStringInt(key, value, expireArg, expire)
case existArg != "" && expireArg != "":
r = newRequest("*6\r\n$3\r\nSET\r\n$")
r.addBytesBytesStringStringInt(key, value, existArg, expireArg, expire)
default:
err := c.BytesSET(key, value)
return err == nil, err
}

err = c.commandOK(r)
if err == errNull {
return false, nil
}
return err == nil, err
}

// SETStringWithArgs executes <https://redis.io/commands/set> with options.
// SETStringWithOptions executes <https://redis.io/commands/set> with options.
// The return is false if the SET operation was not performed due to an NX or XX
// condition. See EX, PX, NX and XX for details.
func (c *Client) SETStringWithArgs(key, value string, options ...string) (bool, error) {
r := newRequestSize(3+len(options), "\r\n$3\r\nSET\r\n$")
r.addStringStringStringList(key, value, options)
err := c.commandOK(r)
// condition.
func (c *Client) SETStringWithOptions(key, value string, o SETOptions) (bool, error) {
existArg, expireArg, expire, err := o.args()
if err != nil {
return false, err
}

var r *request
switch {
case existArg != "" && expireArg == "":
r = newRequest("*4\r\n$3\r\nSET\r\n$")
r.addStringStringString(key, value, existArg)
case existArg == "" && expireArg != "":
r = newRequest("*5\r\n$3\r\nSET\r\n$")
r.addStringStringStringInt(key, value, expireArg, expire)
case existArg != "" && expireArg != "":
r = newRequest("*6\r\n$3\r\nSET\r\n$")
r.addStringStringStringStringInt(key, value, existArg, expireArg, expire)
default:
err := c.SETString(key, value)
return err == nil, err
}

err = c.commandOK(r)
if err == errNull {
return false, nil
}
Expand Down
6 changes: 3 additions & 3 deletions command_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -241,21 +241,21 @@ func TestKeyOptions(t *testing.T) {
t.Parallel()
key := randomKey("test")

if ok, err := testClient.BytesSETWithArgs([]byte(key), nil, "XX"); err != nil {
if ok, err := testClient.BytesSETWithOptions([]byte(key), nil, SETOptions{Flags: XX}); err != nil {
t.Fatalf(`SET %q "" XX error: %s`, key, err)
} else if ok {
t.Fatalf(`SET %q "" XX got true`, key)
}

if ok, err := testClient.SETWithArgs(key, nil, "PX", "1"); err != nil {
if ok, err := testClient.SETWithOptions(key, nil, SETOptions{Flags: PX, Expire: time.Millisecond}); err != nil {
t.Fatalf(`SET %q "" PX 1 error: %s`, key, err)
} else if !ok {
t.Fatalf(`SET %q "" PX 1 got false`, key)
}

time.Sleep(20 * time.Millisecond)

if ok, err := testClient.SETStringWithArgs(key, "value", "NX"); err != nil {
if ok, err := testClient.SETStringWithOptions(key, "value", SETOptions{Flags: NX}); err != nil {
t.Errorf(`SET %q "value" "NX" error: %s`, key, err)
} else if !ok {
t.Errorf(`SET %q "value" "NX" got false`, key)
Expand Down
8 changes: 5 additions & 3 deletions example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,17 @@ import (
"github.com/pascaldekloe/redis"
)

// SET With Options
func ExampleClient_SETWithArgs() {
func ExampleClient_SETStringWithOptions() {
// connection setup
var Redis = redis.NewClient("rds1.example.com", 5*time.Millisecond, time.Second)
// terminate after example
defer Redis.Close()

// execute command
ok, err := Redis.SETWithArgs("hello", nil, redis.EX, "60", redis.NX)
ok, err := Redis.SETStringWithOptions("k", "v", redis.SETOptions{
Flags: redis.NX | redis.EX,
Expire: time.Minute,
})
if err != nil {
log.Print("error: ", err)
return
Expand Down
108 changes: 99 additions & 9 deletions resp.go
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,39 @@ func (r *request) addBytesBytes(a1, a2 []byte) {
r.buf = append(r.buf, '\r', '\n')
}

func (r *request) addBytesBytesString(a1, a2 []byte, a3 string) {
r.bytes(a1)
r.buf = append(r.buf, '\r', '\n', '$')
r.bytes(a2)
r.buf = append(r.buf, '\r', '\n', '$')
r.string(a3)
r.buf = append(r.buf, '\r', '\n')
}

func (r *request) addBytesBytesStringInt(a1, a2 []byte, a3 string, a4 int64) {
r.bytes(a1)
r.buf = append(r.buf, '\r', '\n', '$')
r.bytes(a2)
r.buf = append(r.buf, '\r', '\n', '$')
r.string(a3)
r.buf = append(r.buf, '\r', '\n', '$')
r.decimal(a4)
r.buf = append(r.buf, '\r', '\n')
}

func (r *request) addBytesBytesStringStringInt(a1, a2 []byte, a3, a4 string, a5 int64) {
r.bytes(a1)
r.buf = append(r.buf, '\r', '\n', '$')
r.bytes(a2)
r.buf = append(r.buf, '\r', '\n', '$')
r.string(a3)
r.buf = append(r.buf, '\r', '\n', '$')
r.string(a4)
r.buf = append(r.buf, '\r', '\n', '$')
r.decimal(a5)
r.buf = append(r.buf, '\r', '\n')
}

func (r *request) addBytesBytesList(a1 []byte, a2 [][]byte) {
r.bytes(a1)
for _, b := range a2 {
Expand Down Expand Up @@ -299,6 +332,39 @@ func (r *request) addStringBytes(a1 string, a2 []byte) {
r.buf = append(r.buf, '\r', '\n')
}

func (r *request) addStringBytesString(a1 string, a2 []byte, a3 string) {
r.string(a1)
r.buf = append(r.buf, '\r', '\n', '$')
r.bytes(a2)
r.buf = append(r.buf, '\r', '\n', '$')
r.string(a3)
r.buf = append(r.buf, '\r', '\n')
}

func (r *request) addStringBytesStringInt(a1 string, a2 []byte, a3 string, a4 int64) {
r.string(a1)
r.buf = append(r.buf, '\r', '\n', '$')
r.bytes(a2)
r.buf = append(r.buf, '\r', '\n', '$')
r.string(a3)
r.buf = append(r.buf, '\r', '\n', '$')
r.decimal(a4)
r.buf = append(r.buf, '\r', '\n')
}

func (r *request) addStringBytesStringStringInt(a1 string, a2 []byte, a3, a4 string, a5 int64) {
r.string(a1)
r.buf = append(r.buf, '\r', '\n', '$')
r.bytes(a2)
r.buf = append(r.buf, '\r', '\n', '$')
r.string(a3)
r.buf = append(r.buf, '\r', '\n', '$')
r.string(a4)
r.buf = append(r.buf, '\r', '\n', '$')
r.decimal(a5)
r.buf = append(r.buf, '\r', '\n')
}

func (r *request) addStringBytesMapLists(a1 []string, a2 [][]byte) error {
if len(a1) != len(a2) {
return errMapSlices
Expand Down Expand Up @@ -327,6 +393,39 @@ func (r *request) addStringString(a1, a2 string) {
r.buf = append(r.buf, '\r', '\n')
}

func (r *request) addStringStringString(a1, a2, a3 string) {
r.string(a1)
r.buf = append(r.buf, '\r', '\n', '$')
r.string(a2)
r.buf = append(r.buf, '\r', '\n', '$')
r.string(a3)
r.buf = append(r.buf, '\r', '\n')
}

func (r *request) addStringStringStringInt(a1, a2, a3 string, a4 int64) {
r.string(a1)
r.buf = append(r.buf, '\r', '\n', '$')
r.string(a2)
r.buf = append(r.buf, '\r', '\n', '$')
r.string(a3)
r.buf = append(r.buf, '\r', '\n', '$')
r.decimal(a4)
r.buf = append(r.buf, '\r', '\n')
}

func (r *request) addStringStringStringStringInt(a1, a2, a3, a4 string, a5 int64) {
r.string(a1)
r.buf = append(r.buf, '\r', '\n', '$')
r.string(a2)
r.buf = append(r.buf, '\r', '\n', '$')
r.string(a3)
r.buf = append(r.buf, '\r', '\n', '$')
r.string(a4)
r.buf = append(r.buf, '\r', '\n', '$')
r.decimal(a5)
r.buf = append(r.buf, '\r', '\n')
}

func (r *request) addStringStringMapLists(a1, a2 []string) error {
if len(a1) != len(a2) {
return errMapSlices
Expand Down Expand Up @@ -451,15 +550,6 @@ func (r *request) addStringStringBytes(a1, a2 string, a3 []byte) {
r.buf = append(r.buf, '\r', '\n')
}

func (r *request) addStringStringString(a1, a2, a3 string) {
r.string(a1)
r.buf = append(r.buf, '\r', '\n', '$')
r.string(a2)
r.buf = append(r.buf, '\r', '\n', '$')
r.string(a3)
r.buf = append(r.buf, '\r', '\n')
}

func (r *request) addStringStringStringList(a1, a2 string, a3 []string) {
r.string(a1)
r.buf = append(r.buf, '\r', '\n', '$')
Expand Down

0 comments on commit b4f1f0b

Please sign in to comment.