diff --git a/README.md b/README.md index f987363c..d046ee0e 100644 --- a/README.md +++ b/README.md @@ -15,17 +15,25 @@ After this command *gomemcache* is ready to use. Its source will be in: ## Example - import ( - "github.com/bradfitz/gomemcache/memcache" - ) - - func main() { - mc := memcache.New("10.0.0.1:11211", "10.0.0.2:11211", "10.0.0.3:11212") - mc.Set(&memcache.Item{Key: "foo", Value: []byte("my value")}) - - it, err := mc.Get("foo") - ... - } +```go +import ( + "github.com/bradfitz/gomemcache/memcache" +) + +func main() { + mc := memcache.New("10.0.0.1:11211", "10.0.0.2:11211", "10.0.0.3:11212") + mc.Set(&memcache.Item{Key: "foo", Value: []byte("my value")}) + + it, err := mc.Get("foo") + + // With context + ctx, cancel := context.WithTimeout(2 * time.Second) + doLongTimeJob(ctx) + ... + it, err = mc.GetWithContext(ctx, "bar") + ... +} +``` ## Full docs, see: diff --git a/memcache/memcache.go b/memcache/memcache.go index 545a3e79..6064c28d 100644 --- a/memcache/memcache.go +++ b/memcache/memcache.go @@ -20,11 +20,11 @@ package memcache import ( "bufio" "bytes" + "context" "errors" "fmt" "io" "net" - "strconv" "strings" "sync" @@ -272,7 +272,14 @@ func (c *Client) dial(addr net.Addr) (net.Conn, error) { return nil, err } -func (c *Client) getConn(addr net.Addr) (*conn, error) { +func (c *Client) getConn(ctx context.Context, addr net.Addr) (*conn, error) { + // Check if the context is expired. + select { + default: + case <-ctx.Done(): + return nil, ctx.Err() + } + cn, ok := c.getFreeConn(addr) if ok { cn.extendDeadline() @@ -292,12 +299,12 @@ func (c *Client) getConn(addr net.Addr) (*conn, error) { return cn, nil } -func (c *Client) onItem(item *Item, fn func(*Client, *bufio.ReadWriter, *Item) error) error { +func (c *Client) onItem(ctx context.Context, item *Item, fn func(*Client, *bufio.ReadWriter, *Item) error) error { addr, err := c.selector.PickServer(item.Key) if err != nil { return err } - cn, err := c.getConn(addr) + cn, err := c.getConn(ctx, addr) if err != nil { return err } @@ -309,14 +316,24 @@ func (c *Client) onItem(item *Item, fn func(*Client, *bufio.ReadWriter, *Item) e } func (c *Client) FlushAll() error { - return c.selector.Each(c.flushAllFromAddr) + return c.FlushAllWithContext(context.Background()) +} + +func (c *Client) FlushAllWithContext(ctx context.Context) error { + return c.selector.Each(c.flushAllFromAddrWithContext(ctx)) } // Get gets the item for the given key. ErrCacheMiss is returned for a // memcache cache miss. The key must be at most 250 bytes in length. func (c *Client) Get(key string) (item *Item, err error) { + return c.GetWithContext(context.Background(), key) +} + +// GetWithContext gets the item for the given key. ErrCacheMiss is returned +// for a memcache cache miss. The key must be at most 250 bytes in length. +func (c *Client) GetWithContext(ctx context.Context, key string) (item *Item, err error) { err = c.withKeyAddr(key, func(addr net.Addr) error { - return c.getFromAddr(addr, []string{key}, func(it *Item) { item = it }) + return c.getFromAddr(ctx, addr, []string{key}, func(it *Item) { item = it }) }) if err == nil && item == nil { err = ErrCacheMiss @@ -330,8 +347,17 @@ func (c *Client) Get(key string) (item *Item, err error) { // no expiration time. ErrCacheMiss is returned if the key is not in the cache. // The key must be at most 250 bytes in length. func (c *Client) Touch(key string, seconds int32) (err error) { + return c.TouchWithContext(context.Background(), key, seconds) +} + +// TouchWithContext updates the expiry for the given key. The seconds parameter +// is either a Unix timestamp or, if seconds is less than 1 month, the number +// of seconds into the future at which time the item will expire. Zero means the +// item has no expiration time. ErrCacheMiss is returned if the key is not in +// the cache. The key must be at most 250 bytes in length. +func (c *Client) TouchWithContext(ctx context.Context, key string, seconds int32) (err error) { return c.withKeyAddr(key, func(addr net.Addr) error { - return c.touchFromAddr(addr, []string{key}, seconds) + return c.touchFromAddr(ctx, addr, []string{key}, seconds) }) } @@ -346,8 +372,8 @@ func (c *Client) withKeyAddr(key string, fn func(net.Addr) error) (err error) { return fn(addr) } -func (c *Client) withAddrRw(addr net.Addr, fn func(*bufio.ReadWriter) error) (err error) { - cn, err := c.getConn(addr) +func (c *Client) withAddrRw(ctx context.Context, addr net.Addr, fn func(*bufio.ReadWriter) error) (err error) { + cn, err := c.getConn(ctx, addr) if err != nil { return err } @@ -355,14 +381,14 @@ func (c *Client) withAddrRw(addr net.Addr, fn func(*bufio.ReadWriter) error) (er return fn(cn.rw) } -func (c *Client) withKeyRw(key string, fn func(*bufio.ReadWriter) error) error { +func (c *Client) withKeyRw(ctx context.Context, key string, fn func(*bufio.ReadWriter) error) error { return c.withKeyAddr(key, func(addr net.Addr) error { - return c.withAddrRw(addr, fn) + return c.withAddrRw(ctx, addr, fn) }) } -func (c *Client) getFromAddr(addr net.Addr, keys []string, cb func(*Item)) error { - return c.withAddrRw(addr, func(rw *bufio.ReadWriter) error { +func (c *Client) getFromAddr(ctx context.Context, addr net.Addr, keys []string, cb func(*Item)) error { + return c.withAddrRw(ctx, addr, func(rw *bufio.ReadWriter) error { if _, err := fmt.Fprintf(rw, "gets %s\r\n", strings.Join(keys, " ")); err != nil { return err } @@ -376,55 +402,59 @@ func (c *Client) getFromAddr(addr net.Addr, keys []string, cb func(*Item)) error }) } -// flushAllFromAddr send the flush_all command to the given addr -func (c *Client) flushAllFromAddr(addr net.Addr) error { - return c.withAddrRw(addr, func(rw *bufio.ReadWriter) error { - if _, err := fmt.Fprintf(rw, "flush_all\r\n"); err != nil { - return err - } - if err := rw.Flush(); err != nil { - return err - } - line, err := rw.ReadSlice('\n') - if err != nil { - return err - } - switch { - case bytes.Equal(line, resultOk): - break - default: - return fmt.Errorf("memcache: unexpected response line from flush_all: %q", string(line)) - } - return nil - }) +// flushAllFromAddrWithContext send the flush_all command to the given addr +func (c *Client) flushAllFromAddrWithContext(ctx context.Context) func(addr net.Addr) error { + return func(addr net.Addr) error { + return c.withAddrRw(ctx, addr, func(rw *bufio.ReadWriter) error { + if _, err := fmt.Fprintf(rw, "flush_all\r\n"); err != nil { + return err + } + if err := rw.Flush(); err != nil { + return err + } + line, err := rw.ReadSlice('\n') + if err != nil { + return err + } + switch { + case bytes.Equal(line, resultOk): + break + default: + return fmt.Errorf("memcache: unexpected response line from flush_all: %q", string(line)) + } + return nil + }) + } } -// ping sends the version command to the given addr -func (c *Client) ping(addr net.Addr) error { - return c.withAddrRw(addr, func(rw *bufio.ReadWriter) error { - if _, err := fmt.Fprintf(rw, "version\r\n"); err != nil { - return err - } - if err := rw.Flush(); err != nil { - return err - } - line, err := rw.ReadSlice('\n') - if err != nil { - return err - } +// pingWithContext sends the version command to the given addr +func (c *Client) pingWithContext(ctx context.Context) func(addr net.Addr) error { + return func(addr net.Addr) error { + return c.withAddrRw(ctx, addr, func(rw *bufio.ReadWriter) error { + if _, err := fmt.Fprintf(rw, "version\r\n"); err != nil { + return err + } + if err := rw.Flush(); err != nil { + return err + } + line, err := rw.ReadSlice('\n') + if err != nil { + return err + } - switch { - case bytes.HasPrefix(line, versionPrefix): - break - default: - return fmt.Errorf("memcache: unexpected response line from ping: %q", string(line)) - } - return nil - }) + switch { + case bytes.HasPrefix(line, versionPrefix): + break + default: + return fmt.Errorf("memcache: unexpected response line from ping: %q", string(line)) + } + return nil + }) + } } -func (c *Client) touchFromAddr(addr net.Addr, keys []string, expiration int32) error { - return c.withAddrRw(addr, func(rw *bufio.ReadWriter) error { +func (c *Client) touchFromAddr(ctx context.Context, addr net.Addr, keys []string, expiration int32) error { + return c.withAddrRw(ctx, addr, func(rw *bufio.ReadWriter) error { for _, key := range keys { if _, err := fmt.Fprintf(rw, "touch %s %d\r\n", key, expiration); err != nil { return err @@ -454,6 +484,14 @@ func (c *Client) touchFromAddr(addr net.Addr, keys []string, expiration int32) e // cache misses. Each key must be at most 250 bytes in length. // If no error is returned, the returned map will also be non-nil. func (c *Client) GetMulti(keys []string) (map[string]*Item, error) { + return c.GetMultiWithContext(context.Background(), keys) +} + +// GetMultiWithContext is a batch version of Get. The returned map from +// keys to items may have fewer elements than the input slice, due to +// memcache cache misses. Each key must be at most 250 bytes in length. +// If no error is returned, the returned map will also be non-nil. +func (c *Client) GetMultiWithContext(ctx context.Context, keys []string) (map[string]*Item, error) { var lk sync.Mutex m := make(map[string]*Item) addItemToMap := func(it *Item) { @@ -477,7 +515,7 @@ func (c *Client) GetMulti(keys []string) (map[string]*Item, error) { ch := make(chan error, buffered) for addr, keys := range keyMap { go func(addr net.Addr, keys []string) { - ch <- c.getFromAddr(addr, keys, addItemToMap) + ch <- c.getFromAddr(ctx, addr, keys, addItemToMap) }(addr, keys) } @@ -539,7 +577,12 @@ func scanGetResponseLine(line []byte, it *Item) (size int, err error) { // Set writes the given item, unconditionally. func (c *Client) Set(item *Item) error { - return c.onItem(item, (*Client).set) + return c.SetWithContext(context.Background(), item) +} + +// SetWithContext writes the given item, unconditionally. +func (c *Client) SetWithContext(ctx context.Context, item *Item) error { + return c.onItem(ctx, item, (*Client).set) } func (c *Client) set(rw *bufio.ReadWriter, item *Item) error { @@ -549,7 +592,13 @@ func (c *Client) set(rw *bufio.ReadWriter, item *Item) error { // Add writes the given item, if no value already exists for its // key. ErrNotStored is returned if that condition is not met. func (c *Client) Add(item *Item) error { - return c.onItem(item, (*Client).add) + return c.AddWithContext(context.Background(), item) +} + +// AddWithContext writes the given item, if no value already exists for +// its key. ErrNotStored is returned if that condition is not met. +func (c *Client) AddWithContext(ctx context.Context, item *Item) error { + return c.onItem(ctx, item, (*Client).add) } func (c *Client) add(rw *bufio.ReadWriter, item *Item) error { @@ -559,7 +608,13 @@ func (c *Client) add(rw *bufio.ReadWriter, item *Item) error { // Replace writes the given item, but only if the server *does* // already hold data for this key func (c *Client) Replace(item *Item) error { - return c.onItem(item, (*Client).replace) + return c.ReplaceWithContext(context.Background(), item) +} + +// ReplaceWithContext writes the given item, but only if the server +// *does* already hold data for this key +func (c *Client) ReplaceWithContext(ctx context.Context, item *Item) error { + return c.onItem(ctx, item, (*Client).replace) } func (c *Client) replace(rw *bufio.ReadWriter, item *Item) error { @@ -574,7 +629,18 @@ func (c *Client) replace(rw *bufio.ReadWriter, item *Item) error { // calls. ErrNotStored is returned if the value was evicted in between // the calls. func (c *Client) CompareAndSwap(item *Item) error { - return c.onItem(item, (*Client).cas) + return c.CompareAndSwapWithContext(context.Background(), item) +} + +// CompareAndSwapWithContext writes the given item that was previously +// returned by Get, if the value was neither modified or evicted between +// the Get and the CompareAndSwap calls. The item's Key should not change +// between calls but all other item fields may differ. ErrCASConflict +// is returned if the value was modified in between the +// calls. ErrNotStored is returned if the value was evicted in between +// the calls. +func (c *Client) CompareAndSwapWithContext(ctx context.Context, item *Item) error { + return c.onItem(ctx, item, (*Client).cas) } func (c *Client) cas(rw *bufio.ReadWriter, item *Item) error { @@ -657,14 +723,25 @@ func writeExpectf(rw *bufio.ReadWriter, expect []byte, format string, args ...in // Delete deletes the item with the provided key. The error ErrCacheMiss is // returned if the item didn't already exist in the cache. func (c *Client) Delete(key string) error { - return c.withKeyRw(key, func(rw *bufio.ReadWriter) error { + return c.DeleteWithContext(context.Background(), key) +} + +// DeleteWithContext deletes the item with the provided key. The error +// ErrCacheMiss is returned if the item didn't already exist in the cache. +func (c *Client) DeleteWithContext(ctx context.Context, key string) error { + return c.withKeyRw(ctx, key, func(rw *bufio.ReadWriter) error { return writeExpectf(rw, resultDeleted, "delete %s\r\n", key) }) } // DeleteAll deletes all items in the cache. func (c *Client) DeleteAll() error { - return c.withKeyRw("", func(rw *bufio.ReadWriter) error { + return c.DeleteAllWithContext(context.Background()) +} + +// DeleteAllWithContext deletes all items in the cache. +func (c *Client) DeleteAllWithContext(ctx context.Context) error { + return c.withKeyRw(ctx, "", func(rw *bufio.ReadWriter) error { return writeExpectf(rw, resultDeleted, "flush_all\r\n") }) } @@ -672,7 +749,13 @@ func (c *Client) DeleteAll() error { // Ping checks all instances if they are alive. Returns error if any // of them is down. func (c *Client) Ping() error { - return c.selector.Each(c.ping) + return c.PingWithContext(context.Background()) +} + +// Ping checks all instances if they are alive. Returns error if any +// of them is down. +func (c *Client) PingWithContext(ctx context.Context) error { + return c.selector.Each(c.pingWithContext(ctx)) } // Increment atomically increments key by delta. The return value is @@ -681,7 +764,16 @@ func (c *Client) Ping() error { // memcached must be an decimal number, or an error will be returned. // On 64-bit overflow, the new value wraps around. func (c *Client) Increment(key string, delta uint64) (newValue uint64, err error) { - return c.incrDecr("incr", key, delta) + return c.IncrementWithContext(context.Background(), key, delta) +} + +// IncrementWithContext atomically increments key by delta. The return +// value is the new value after being incremented or an error. If the +// value didn't exist in memcached the error is ErrCacheMiss. The value +// in memcached must be an decimal number, or an error will be returned. +// On 64-bit overflow, the new value wraps around. +func (c *Client) IncrementWithContext(ctx context.Context, key string, delta uint64) (newValue uint64, err error) { + return c.incrDecr(ctx, "incr", key, delta) } // Decrement atomically decrements key by delta. The return value is @@ -691,12 +783,22 @@ func (c *Client) Increment(key string, delta uint64) (newValue uint64, err error // On underflow, the new value is capped at zero and does not wrap // around. func (c *Client) Decrement(key string, delta uint64) (newValue uint64, err error) { - return c.incrDecr("decr", key, delta) + return c.DecrementWithContext(context.Background(), key, delta) +} + +// DecrementWithContext atomically decrements key by delta. The return +// value is the new value after being decremented or an error. If the +// value didn't exist in memcached the error is ErrCacheMiss. The value +// in memcached must be an decimal number, or an error will be returned. +// On underflow, the new value is capped at zero and does not wrap +// around. +func (c *Client) DecrementWithContext(ctx context.Context, key string, delta uint64) (newValue uint64, err error) { + return c.incrDecr(ctx, "decr", key, delta) } -func (c *Client) incrDecr(verb, key string, delta uint64) (uint64, error) { +func (c *Client) incrDecr(ctx context.Context, verb, key string, delta uint64) (uint64, error) { var val uint64 - err := c.withKeyRw(key, func(rw *bufio.ReadWriter) error { + err := c.withKeyRw(ctx, key, func(rw *bufio.ReadWriter) error { line, err := writeReadLine(rw, "%s %s %d\r\n", verb, key, delta) if err != nil { return err diff --git a/memcache/memcache_test.go b/memcache/memcache_test.go index 70d47026..d36aea37 100644 --- a/memcache/memcache_test.go +++ b/memcache/memcache_test.go @@ -19,6 +19,7 @@ package memcache import ( "bufio" + "context" "fmt" "io" "io/ioutil" @@ -121,13 +122,23 @@ func testWithClient(t *testing.T, c *Client) { t.Errorf("get(Hello_世界) Value = %q, want hello world", string(it.Value)) } + // Get with expired context + key := "foo" + ctx, cancel := context.WithTimeout(context.TODO(), time.Second) + defer cancel() + time.Sleep(1010 * time.Millisecond) + it, err = c.GetWithContext(ctx, key) + if err != context.DeadlineExceeded { + t.Errorf("getWithContext(foo) should return context.DeadlineExceeded instead of %v", err) + } + // Set malformed keys malFormed := &Item{Key: "foo bar", Value: []byte("foobarval")} err = c.Set(malFormed) if err != ErrMalformedKey { t.Errorf("set(foo bar) should return ErrMalformedKey instead of %v", err) } - malFormed = &Item{Key: "foo" + string(0x7f), Value: []byte("foobarval")} + malFormed = &Item{Key: "foo" + string(rune(0x7f)), Value: []byte("foobarval")} err = c.Set(malFormed) if err != ErrMalformedKey { t.Errorf("set(foo<0x7f>) should return ErrMalformedKey instead of %v", err) @@ -279,7 +290,7 @@ func BenchmarkOnItem(b *testing.B) { addr := fakeServer.Addr() c := New(addr.String()) - if _, err := c.getConn(addr); err != nil { + if _, err := c.getConn(context.TODO(), addr); err != nil { b.Fatal("failed to initialize connection to fake server") } @@ -287,6 +298,6 @@ func BenchmarkOnItem(b *testing.B) { dummyFn := func(_ *Client, _ *bufio.ReadWriter, _ *Item) error { return nil } b.ResetTimer() for i := 0; i < b.N; i++ { - c.onItem(&item, dummyFn) + c.onItem(context.TODO(), &item, dummyFn) } }