From b4f1f0b328fa6134316ec183a3449e208daefb66 Mon Sep 17 00:00:00 2001 From: "Pascal S. de Kloe" Date: Thu, 28 Nov 2019 03:59:20 +0100 Subject: [PATCH] Change SET option vararg to struct. --- command.go | 159 ++++++++++++++++++++++++++++++++++++++++-------- command_test.go | 6 +- example_test.go | 8 ++- resp.go | 108 +++++++++++++++++++++++++++++--- 4 files changed, 240 insertions(+), 41 deletions(-) diff --git a/command.go b/command.go index 1b30d29..9866380 100644 --- a/command.go +++ b/command.go @@ -1,18 +1,68 @@ package redis -// Option Codes For SETWithArgs -const ( - // EX sets the specified expire time, in seconds. - EX = "EX" - // PX sets the specified expire time, in milliseconds. - PX = "PX" +import ( + "errors" + "fmt" + "time" +) +// Flags For SETOptions +const ( // NX only sets the key if it does not already exist. - NX = "NX" + NX = 1 << iota // XX only sets the key if it does already exist. - XX = "XX" + XX + + // EX sets an expire time, in seconds. + EX + // PX sets an expire time, in milliseconds. + PX ) +// SETOptions are extra arguments for the SET command. +type SETOptions struct { + // Composotion of NX, XX, EX or PX. + Flags uint + + // The value is rounded to seconds with the EX flag, + // and milliseconds with PX. Non-zero values without + // expiry Flags are rejected to prevent mistakes. + Expire time.Duration +} + +func (o *SETOptions) args() (existArg, expireArg string, expire int64, err error) { + if unknown := o.Flags &^ (NX | XX | EX | PX); unknown != 0 { + return "", "", 0, fmt.Errorf("redis: unknown flags %#x", unknown) + } + + switch o.Flags & (NX | XX) { + case 0: + break + case NX: + existArg = "NX" + case XX: + existArg = "XX" + default: + return "", "", 0, errors.New("redis: combination of NX and XX not allowed") + } + + switch o.Flags & (EX | PX) { + case 0: + if o.Expire != 0 { + return "", "", 0, errors.New("redis: expire time without EX nor PX not allowed") + } + case EX: + expireArg = "EX" + expire = int64(o.Expire / time.Second) + case PX: + expireArg = "PX" + expire = int64(o.Expire / time.Millisecond) + default: + return "", "", 0, errors.New("redis: combination of EX and PX not allowed") + } + return +} + // SELECT executes . func (c *Client) SELECT(db int64) error { r := newRequest("*2\r\n$6\r\nSELECT\r\n$") @@ -127,39 +177,96 @@ func (c *Client) SETString(key, value string) error { return c.commandOK(r) } -// SETWithArgs executes with options. +// SETWithOptions executes with options. // The return is false if the SET operation was not performed due to an NX or XX -// condition. See EX, PX, NX and XX for details. -func (c *Client) SETWithArgs(key string, value []byte, options ...string) (bool, error) { - r := newRequestSize(3+len(options), "\r\n$3\r\nSET\r\n$") - r.addStringBytesStringList(key, value, options) - err := c.commandOK(r) +// condition. +func (c *Client) SETWithOptions(key string, value []byte, o SETOptions) (bool, error) { + existArg, expireArg, expire, err := o.args() + if err != nil { + return false, err + } + + var r *request + switch { + case existArg != "" && expireArg == "": + r = newRequest("*4\r\n$3\r\nSET\r\n$") + r.addStringBytesString(key, value, existArg) + case existArg == "" && expireArg != "": + r = newRequest("*5\r\n$3\r\nSET\r\n$") + r.addStringBytesStringInt(key, value, expireArg, expire) + case existArg != "" && expireArg != "": + r = newRequest("*6\r\n$3\r\nSET\r\n$") + r.addStringBytesStringStringInt(key, value, existArg, expireArg, expire) + default: + err := c.SET(key, value) + return err == nil, err + } + + err = c.commandOK(r) if err == errNull { return false, nil } return err == nil, err } -// BytesSETWithArgs executes with options. +// BytesSETWithOptions executes with options. // The return is false if the SET operation was not performed due to an NX or XX -// condition. See EX, PX, NX and XX for details. -func (c *Client) BytesSETWithArgs(key, value []byte, options ...string) (bool, error) { - r := newRequestSize(3+len(options), "\r\n$3\r\nSET\r\n$") - r.addBytesBytesStringList(key, value, options) - err := c.commandOK(r) +// condition. +func (c *Client) BytesSETWithOptions(key, value []byte, o SETOptions) (bool, error) { + existArg, expireArg, expire, err := o.args() + if err != nil { + return false, err + } + + var r *request + switch { + case existArg != "" && expireArg == "": + r = newRequest("*4\r\n$3\r\nSET\r\n$") + r.addBytesBytesString(key, value, existArg) + case existArg == "" && expireArg != "": + r = newRequest("*5\r\n$3\r\nSET\r\n$") + r.addBytesBytesStringInt(key, value, expireArg, expire) + case existArg != "" && expireArg != "": + r = newRequest("*6\r\n$3\r\nSET\r\n$") + r.addBytesBytesStringStringInt(key, value, existArg, expireArg, expire) + default: + err := c.BytesSET(key, value) + return err == nil, err + } + + err = c.commandOK(r) if err == errNull { return false, nil } return err == nil, err } -// SETStringWithArgs executes with options. +// SETStringWithOptions executes with options. // The return is false if the SET operation was not performed due to an NX or XX -// condition. See EX, PX, NX and XX for details. -func (c *Client) SETStringWithArgs(key, value string, options ...string) (bool, error) { - r := newRequestSize(3+len(options), "\r\n$3\r\nSET\r\n$") - r.addStringStringStringList(key, value, options) - err := c.commandOK(r) +// condition. +func (c *Client) SETStringWithOptions(key, value string, o SETOptions) (bool, error) { + existArg, expireArg, expire, err := o.args() + if err != nil { + return false, err + } + + var r *request + switch { + case existArg != "" && expireArg == "": + r = newRequest("*4\r\n$3\r\nSET\r\n$") + r.addStringStringString(key, value, existArg) + case existArg == "" && expireArg != "": + r = newRequest("*5\r\n$3\r\nSET\r\n$") + r.addStringStringStringInt(key, value, expireArg, expire) + case existArg != "" && expireArg != "": + r = newRequest("*6\r\n$3\r\nSET\r\n$") + r.addStringStringStringStringInt(key, value, existArg, expireArg, expire) + default: + err := c.SETString(key, value) + return err == nil, err + } + + err = c.commandOK(r) if err == errNull { return false, nil } diff --git a/command_test.go b/command_test.go index d5f81f9..d5790b6 100644 --- a/command_test.go +++ b/command_test.go @@ -241,13 +241,13 @@ func TestKeyOptions(t *testing.T) { t.Parallel() key := randomKey("test") - if ok, err := testClient.BytesSETWithArgs([]byte(key), nil, "XX"); err != nil { + if ok, err := testClient.BytesSETWithOptions([]byte(key), nil, SETOptions{Flags: XX}); err != nil { t.Fatalf(`SET %q "" XX error: %s`, key, err) } else if ok { t.Fatalf(`SET %q "" XX got true`, key) } - if ok, err := testClient.SETWithArgs(key, nil, "PX", "1"); err != nil { + if ok, err := testClient.SETWithOptions(key, nil, SETOptions{Flags: PX, Expire: time.Millisecond}); err != nil { t.Fatalf(`SET %q "" PX 1 error: %s`, key, err) } else if !ok { t.Fatalf(`SET %q "" PX 1 got false`, key) @@ -255,7 +255,7 @@ func TestKeyOptions(t *testing.T) { time.Sleep(20 * time.Millisecond) - if ok, err := testClient.SETStringWithArgs(key, "value", "NX"); err != nil { + if ok, err := testClient.SETStringWithOptions(key, "value", SETOptions{Flags: NX}); err != nil { t.Errorf(`SET %q "value" "NX" error: %s`, key, err) } else if !ok { t.Errorf(`SET %q "value" "NX" got false`, key) diff --git a/example_test.go b/example_test.go index 45c0e69..9a7aed3 100644 --- a/example_test.go +++ b/example_test.go @@ -7,15 +7,17 @@ import ( "github.com/pascaldekloe/redis" ) -// SET With Options -func ExampleClient_SETWithArgs() { +func ExampleClient_SETStringWithOptions() { // connection setup var Redis = redis.NewClient("rds1.example.com", 5*time.Millisecond, time.Second) // terminate after example defer Redis.Close() // execute command - ok, err := Redis.SETWithArgs("hello", nil, redis.EX, "60", redis.NX) + ok, err := Redis.SETStringWithOptions("k", "v", redis.SETOptions{ + Flags: redis.NX | redis.EX, + Expire: time.Minute, + }) if err != nil { log.Print("error: ", err) return diff --git a/resp.go b/resp.go index c1c6d42..ec59238 100644 --- a/resp.go +++ b/resp.go @@ -262,6 +262,39 @@ func (r *request) addBytesBytes(a1, a2 []byte) { r.buf = append(r.buf, '\r', '\n') } +func (r *request) addBytesBytesString(a1, a2 []byte, a3 string) { + r.bytes(a1) + r.buf = append(r.buf, '\r', '\n', '$') + r.bytes(a2) + r.buf = append(r.buf, '\r', '\n', '$') + r.string(a3) + r.buf = append(r.buf, '\r', '\n') +} + +func (r *request) addBytesBytesStringInt(a1, a2 []byte, a3 string, a4 int64) { + r.bytes(a1) + r.buf = append(r.buf, '\r', '\n', '$') + r.bytes(a2) + r.buf = append(r.buf, '\r', '\n', '$') + r.string(a3) + r.buf = append(r.buf, '\r', '\n', '$') + r.decimal(a4) + r.buf = append(r.buf, '\r', '\n') +} + +func (r *request) addBytesBytesStringStringInt(a1, a2 []byte, a3, a4 string, a5 int64) { + r.bytes(a1) + r.buf = append(r.buf, '\r', '\n', '$') + r.bytes(a2) + r.buf = append(r.buf, '\r', '\n', '$') + r.string(a3) + r.buf = append(r.buf, '\r', '\n', '$') + r.string(a4) + r.buf = append(r.buf, '\r', '\n', '$') + r.decimal(a5) + r.buf = append(r.buf, '\r', '\n') +} + func (r *request) addBytesBytesList(a1 []byte, a2 [][]byte) { r.bytes(a1) for _, b := range a2 { @@ -299,6 +332,39 @@ func (r *request) addStringBytes(a1 string, a2 []byte) { r.buf = append(r.buf, '\r', '\n') } +func (r *request) addStringBytesString(a1 string, a2 []byte, a3 string) { + r.string(a1) + r.buf = append(r.buf, '\r', '\n', '$') + r.bytes(a2) + r.buf = append(r.buf, '\r', '\n', '$') + r.string(a3) + r.buf = append(r.buf, '\r', '\n') +} + +func (r *request) addStringBytesStringInt(a1 string, a2 []byte, a3 string, a4 int64) { + r.string(a1) + r.buf = append(r.buf, '\r', '\n', '$') + r.bytes(a2) + r.buf = append(r.buf, '\r', '\n', '$') + r.string(a3) + r.buf = append(r.buf, '\r', '\n', '$') + r.decimal(a4) + r.buf = append(r.buf, '\r', '\n') +} + +func (r *request) addStringBytesStringStringInt(a1 string, a2 []byte, a3, a4 string, a5 int64) { + r.string(a1) + r.buf = append(r.buf, '\r', '\n', '$') + r.bytes(a2) + r.buf = append(r.buf, '\r', '\n', '$') + r.string(a3) + r.buf = append(r.buf, '\r', '\n', '$') + r.string(a4) + r.buf = append(r.buf, '\r', '\n', '$') + r.decimal(a5) + r.buf = append(r.buf, '\r', '\n') +} + func (r *request) addStringBytesMapLists(a1 []string, a2 [][]byte) error { if len(a1) != len(a2) { return errMapSlices @@ -327,6 +393,39 @@ func (r *request) addStringString(a1, a2 string) { r.buf = append(r.buf, '\r', '\n') } +func (r *request) addStringStringString(a1, a2, a3 string) { + r.string(a1) + r.buf = append(r.buf, '\r', '\n', '$') + r.string(a2) + r.buf = append(r.buf, '\r', '\n', '$') + r.string(a3) + r.buf = append(r.buf, '\r', '\n') +} + +func (r *request) addStringStringStringInt(a1, a2, a3 string, a4 int64) { + r.string(a1) + r.buf = append(r.buf, '\r', '\n', '$') + r.string(a2) + r.buf = append(r.buf, '\r', '\n', '$') + r.string(a3) + r.buf = append(r.buf, '\r', '\n', '$') + r.decimal(a4) + r.buf = append(r.buf, '\r', '\n') +} + +func (r *request) addStringStringStringStringInt(a1, a2, a3, a4 string, a5 int64) { + r.string(a1) + r.buf = append(r.buf, '\r', '\n', '$') + r.string(a2) + r.buf = append(r.buf, '\r', '\n', '$') + r.string(a3) + r.buf = append(r.buf, '\r', '\n', '$') + r.string(a4) + r.buf = append(r.buf, '\r', '\n', '$') + r.decimal(a5) + r.buf = append(r.buf, '\r', '\n') +} + func (r *request) addStringStringMapLists(a1, a2 []string) error { if len(a1) != len(a2) { return errMapSlices @@ -451,15 +550,6 @@ func (r *request) addStringStringBytes(a1, a2 string, a3 []byte) { r.buf = append(r.buf, '\r', '\n') } -func (r *request) addStringStringString(a1, a2, a3 string) { - r.string(a1) - r.buf = append(r.buf, '\r', '\n', '$') - r.string(a2) - r.buf = append(r.buf, '\r', '\n', '$') - r.string(a3) - r.buf = append(r.buf, '\r', '\n') -} - func (r *request) addStringStringStringList(a1, a2 string, a3 []string) { r.string(a1) r.buf = append(r.buf, '\r', '\n', '$')