Skip to content

Commit

Permalink
fix: Additional tests for the in-flight key tracking
Browse files Browse the repository at this point in the history
  • Loading branch information
creativecreature committed May 25, 2024
1 parent 2532463 commit e3dca8f
Show file tree
Hide file tree
Showing 9 changed files with 265 additions and 172 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ test:
@echo 'Removing test cache...'
go clean -testcache
@echo 'Running tests...'
go test -race -vet=off -timeout 30s ./...
go test -race -vet=off -timeout 15s ./...

## bench: run all benchmarks
.PHONY: bench
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# `sturdyc`: a caching library for building sturdy systemsreadme
# `sturdyc`: a caching library for building sturdy systems

[![Go Reference](https://pkg.go.dev/badge/github.com/creativecreature/sturdyc.svg)](https://pkg.go.dev/github.com/creativecreature/sturdyc)
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://github.com/creativecreature/sturdyc/blob/master/LICENSE)
Expand Down
9 changes: 9 additions & 0 deletions cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -204,3 +204,12 @@ func (c *Client[T]) Delete(key string) {
shard := c.getShard(key)
shard.delete(key)
}

// NumKeysInflight returns the number of keys that are currently being fetched.
func (c *Client[T]) NumKeysInflight() int {
c.inFlightMutex.Lock()
defer c.inFlightMutex.Unlock()
c.inFlightBatchMutex.Lock()
defer c.inFlightBatchMutex.Unlock()
return len(c.inFlightMap) + len(c.inFlightBatchMap)
}
117 changes: 59 additions & 58 deletions inflight.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package sturdyc
import (
"context"
"errors"
"fmt"
"sync"
)

Expand All @@ -20,6 +21,54 @@ func (c *Client[T]) newFlight(key string) *inFlightCall[T] {
return call
}

func makeCall[T, V any](ctx context.Context, c *Client[T], key string, fn FetchFn[V], call *inFlightCall[T]) {
defer func() {
if err := recover(); err != nil {
call.err = fmt.Errorf("panic recovered: %v", err)
}
call.Done()
c.inFlightMutex.Lock()
delete(c.inFlightMap, key)
c.inFlightMutex.Unlock()
}()

response, err := fn(ctx)
if err != nil && c.storeMissingRecords && errors.Is(err, ErrStoreMissingRecord) {
c.SetMissing(key, *new(T), true)
call.err = ErrMissingRecord
return
}

if err != nil {
call.err = err
return
}

res, ok := any(response).(T)
if !ok {
call.err = ErrInvalidType
return
}

call.err = nil
call.val = res
c.SetMissing(key, res, false)
}

func callAndCache[V, T any](ctx context.Context, c *Client[T], key string, fn FetchFn[V]) (V, error) {
c.inFlightMutex.Lock()
if call, ok := c.inFlightMap[key]; ok {
c.inFlightMutex.Unlock()
call.Wait()
return unwrap[V, T](call.val, call.err)
}

call := c.newFlight(key)
c.inFlightMutex.Unlock()
makeCall(ctx, c, key, fn, call)
return unwrap[V, T](call.val, call.err)
}

// newBatchFlight should be called with a lock.
func (c *Client[T]) newBatchFlight(ids []string, keyFn KeyFn) *inFlightCall[map[string]T] {
call := new(inFlightCall[map[string]T])
Expand All @@ -31,13 +80,6 @@ func (c *Client[T]) newBatchFlight(ids []string, keyFn KeyFn) *inFlightCall[map[
return call
}

func (c *Client[T]) endFlight(call *inFlightCall[T], key string) {
call.Done()
c.inFlightMutex.Lock()
delete(c.inFlightMap, key)
c.inFlightMutex.Unlock()
}

func (c *Client[T]) endBatchFlight(ids []string, keyFn KeyFn, call *inFlightCall[map[string]T]) {
call.Done()
c.inFlightBatchMutex.Lock()
Expand All @@ -47,12 +89,6 @@ func (c *Client[T]) endBatchFlight(ids []string, keyFn KeyFn, call *inFlightCall
c.inFlightBatchMutex.Unlock()
}

func (c *Client[T]) endErrorFlight(call *inFlightCall[T], key string, err error) error {
call.err = err
c.endFlight(call, key)
return err
}

type makeBatchCallOpts[T, V any] struct {
ids []string
fn BatchFetchFn[V]
Expand All @@ -64,16 +100,14 @@ func makeBatchCall[T, V any](ctx context.Context, c *Client[T], opts makeBatchCa
response, err := opts.fn(ctx, opts.ids)
if err != nil {
opts.call.err = err
c.endBatchFlight(opts.ids, opts.keyFn, opts.call)
return
}

// Check if we should store any of these IDs as a missing record.
if c.storeMissingRecords && len(response) < len(opts.ids) {
for _, id := range opts.ids {
if _, ok := response[id]; !ok {
var zero T
c.SetMissing(opts.keyFn(id), zero, true)
c.SetMissing(opts.keyFn(id), *new(T), true)
}
}
}
Expand All @@ -87,40 +121,6 @@ func makeBatchCall[T, V any](ctx context.Context, c *Client[T], opts makeBatchCa
c.SetMissing(opts.keyFn(id), v, false)
opts.call.val[id] = v
}
c.endBatchFlight(opts.ids, opts.keyFn, opts.call)
}

func callAndCache[V, T any](ctx context.Context, c *Client[T], key string, fn FetchFn[V]) (V, error) {
c.inFlightMutex.Lock()
if call, ok := c.inFlightMap[key]; ok {
c.inFlightMutex.Unlock()
call.Wait()
return unwrap[V, T](call.val, call.err)
}

call := c.newFlight(key)
c.inFlightMutex.Unlock()

response, err := fn(ctx)
if err != nil && c.storeMissingRecords && errors.Is(err, ErrStoreMissingRecord) {
c.SetMissing(key, *new(T), true)
return response, c.endErrorFlight(call, key, ErrMissingRecord)
}

if err != nil {
return response, c.endErrorFlight(call, key, err)
}

res, ok := any(response).(T)
if !ok {
return response, c.endErrorFlight(call, key, ErrInvalidType)
}

c.SetMissing(key, res, false)
call.val = res
call.err = nil
c.endFlight(call, key)
return response, nil
}

type callBatchOpts[T, V any] struct {
Expand All @@ -132,7 +132,6 @@ type callBatchOpts[T, V any] struct {
func callAndCacheBatch[V, T any](ctx context.Context, c *Client[T], opts callBatchOpts[T, V]) (map[string]V, error) {
c.inFlightBatchMutex.Lock()

// We need to keep track of the specific IDs we're after for a particular call.
callIDs := make(map[*inFlightCall[map[string]T]][]string)
uniqueIDs := make([]string, 0, len(opts.ids))
for _, id := range opts.ids {
Expand All @@ -145,20 +144,22 @@ func callAndCacheBatch[V, T any](ctx context.Context, c *Client[T], opts callBat

if len(uniqueIDs) > 0 {
call := c.newBatchFlight(uniqueIDs, opts.keyFn)
for _, id := range uniqueIDs {
c.inFlightBatchMap[opts.keyFn(id)] = call
callIDs[call] = append(callIDs[call], id)
}

safeGo(func() {
callIDs[call] = append(callIDs[call], uniqueIDs...)
go func() {
defer func() {
if err := recover(); err != nil {
call.err = fmt.Errorf("panic recovered: %v", err)
}
c.endBatchFlight(uniqueIDs, opts.keyFn, call)
}()
batchCallOpts := makeBatchCallOpts[T, V]{
ids: uniqueIDs,
fn: opts.fn,
keyFn: opts.keyFn,
call: call,
}
makeBatchCall(ctx, c, batchCallOpts)
})
}()
}
c.inFlightBatchMutex.Unlock()

Expand Down
Loading

0 comments on commit e3dca8f

Please sign in to comment.