Skip to content

Commit

Permalink
Fix handling of count with non-zero revision
Browse files Browse the repository at this point in the history
Signed-off-by: Brad Davidson <[email protected]>
  • Loading branch information
brandond committed Feb 2, 2024
1 parent ab4df41 commit 3773672
Show file tree
Hide file tree
Showing 9 changed files with 81 additions and 36 deletions.
26 changes: 22 additions & 4 deletions pkg/drivers/generic/generic.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,8 @@ type Generic struct {
RevisionSQL string
ListRevisionStartSQL string
GetRevisionAfterSQL string
CountSQL string
CountCurrentSQL string
CountRevisionSQL string
AfterSQL string
DeleteSQL string
CompactSQL string
Expand Down Expand Up @@ -219,12 +220,18 @@ func Open(ctx context.Context, driverName, dataSourceName string, connPoolConfig
ListRevisionStartSQL: q(fmt.Sprintf(listSQL, "AND mkv.id <= ?"), paramCharacter, numbered),
GetRevisionAfterSQL: q(fmt.Sprintf(listSQL, idOfKey), paramCharacter, numbered),

CountSQL: q(fmt.Sprintf(`
CountCurrentSQL: q(fmt.Sprintf(`
SELECT (%s), COUNT(c.theid)
FROM (
%s
) c`, revSQL, fmt.Sprintf(listSQL, "")), paramCharacter, numbered),

CountRevisionSQL: q(fmt.Sprintf(`
SELECT (%s), COUNT(c.theid)
FROM (
%s
) c`, revSQL, fmt.Sprintf(listSQL, "AND mkv.id <= ?")), paramCharacter, numbered),

AfterSQL: q(fmt.Sprintf(`
SELECT (%s), (%s), %s
FROM kine AS kv
Expand Down Expand Up @@ -360,13 +367,24 @@ func (d *Generic) List(ctx context.Context, prefix, startKey string, limit, revi
return d.query(ctx, sql, prefix, revision, startKey, revision, includeDeleted)
}

func (d *Generic) Count(ctx context.Context, prefix string) (int64, int64, error) {
func (d *Generic) CountCurrent(ctx context.Context, prefix string) (int64, int64, error) {
var (
rev sql.NullInt64
id int64
)

row := d.queryRow(ctx, d.CountCurrentSQL, prefix, false)
err := row.Scan(&rev, &id)
return rev.Int64, id, err
}

func (d *Generic) Count(ctx context.Context, prefix string, revision int64) (int64, int64, error) {
var (
rev sql.NullInt64
id int64
)

row := d.queryRow(ctx, d.CountSQL, prefix, false)
row := d.queryRow(ctx, d.CountRevisionSQL, prefix, revision, false)
err := row.Scan(&rev, &id)
return rev.Int64, id, err
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/drivers/nats/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,8 @@ func (b *Backend) CurrentRevision(ctx context.Context) (int64, error) {
}

// Count returns an exact count of the number of matching keys and the current revision of the database.
func (b *Backend) Count(ctx context.Context, prefix string) (int64, int64, error) {
count, err := b.kv.Count(ctx, prefix)
func (b *Backend) Count(ctx context.Context, prefix string, revision int64) (int64, int64, error) {
count, err := b.kv.Count(ctx, prefix, revision)
if err != nil {
return 0, 0, err
}
Expand Down
6 changes: 3 additions & 3 deletions pkg/drivers/nats/backend_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,14 +129,14 @@ func TestBackend_Create(t *testing.T) {

time.Sleep(2 * time.Millisecond)

srev, count, err := b.Count(ctx, "/")
srev, count, err := b.Count(ctx, "/", 0)
noErr(t, err)
expEqual(t, 4, srev)
expEqual(t, 4, count)

time.Sleep(time.Second)

srev, count, err = b.Count(ctx, "/")
srev, count, err = b.Count(ctx, "/", 0)
noErr(t, err)
expEqual(t, 4, srev)
expEqual(t, 3, count)
Expand All @@ -149,7 +149,7 @@ func TestBackend_Create(t *testing.T) {

time.Sleep(2 * time.Millisecond)

srev, count, err = b.Count(ctx, "/")
srev, count, err = b.Count(ctx, "/", 0)
noErr(t, err)
expEqual(t, 6, srev)
expEqual(t, 4, count)
Expand Down
31 changes: 24 additions & 7 deletions pkg/drivers/nats/kv.go
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,7 @@ type keySeq struct {
seq uint64
}

func (e *KeyValue) Count(ctx context.Context, prefix string) (int64, error) {
func (e *KeyValue) Count(ctx context.Context, prefix string, revision int64) (int64, error) {
it := e.bt.Iter()

if prefix != "" {
Expand All @@ -396,11 +396,27 @@ func (e *KeyValue) Count(ctx context.Context, prefix string) (int64, error) {
break
}
v := it.Value()
so := v[len(v)-1]

if so.op == jetstream.KeyValuePut {
if so.ex.IsZero() || so.ex.After(now) {
count++
// Get the latest update for the key.
if revision <= 0 {
so := v[len(v)-1]
if so.op == jetstream.KeyValuePut {
if so.ex.IsZero() || so.ex.After(now) {
count++
}
}
} else {
// Find the latest update below the given revision.
for i := len(v) - 1; i >= 0; i-- {
so := v[i]
if so.seq <= uint64(revision) {
if so.op == jetstream.KeyValuePut {
if so.ex.IsZero() || so.ex.After(now) {
count++
}
}
break
}
}
}

Expand Down Expand Up @@ -429,6 +445,7 @@ func (e *KeyValue) List(ctx context.Context, prefix, startKey string, limit, rev
}

var matches []*keySeq
now := time.Now()

e.btm.RLock()

Expand All @@ -448,7 +465,7 @@ func (e *KeyValue) List(ctx context.Context, prefix, startKey string, limit, rev
if revision <= 0 {
so := v[len(v)-1]
if so.op == jetstream.KeyValuePut {
if so.ex.IsZero() || so.ex.After(time.Now()) {
if so.ex.IsZero() || so.ex.After(now) {
matches = append(matches, &keySeq{key: k, seq: so.seq})
}
}
Expand All @@ -458,7 +475,7 @@ func (e *KeyValue) List(ctx context.Context, prefix, startKey string, limit, rev
so := v[i]
if so.seq <= uint64(revision) {
if so.op == jetstream.KeyValuePut {
if so.ex.IsZero() || so.ex.After(time.Now()) {
if so.ex.IsZero() || so.ex.After(now) {
matches = append(matches, &keySeq{key: k, seq: so.seq})
}
}
Expand Down
8 changes: 4 additions & 4 deletions pkg/drivers/nats/logger.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,15 +81,15 @@ func (b *BackendLogger) List(ctx context.Context, prefix, startKey string, limit
}

// Count returns an exact count of the number of matching keys and the current revision of the database
func (b *BackendLogger) Count(ctx context.Context, prefix string) (revRet int64, count int64, err error) {
func (b *BackendLogger) Count(ctx context.Context, prefix string, revision int64) (revRet int64, count int64, err error) {
start := time.Now()
defer func() {
dur := time.Since(start)
fStr := "COUNT %s => rev=%d, count=%d, err=%v, duration=%s"
b.logMethod(dur, fStr, prefix, revRet, count, err, dur)
fStr := "COUNT %s, rev=%d => rev=%d, count=%d, err=%v, duration=%s"
b.logMethod(dur, fStr, prefix, revision, revRet, count, err, dur)
}()

return b.backend.Count(ctx, prefix)
return b.backend.Count(ctx, prefix, revision)
}

func (b *BackendLogger) Update(ctx context.Context, key string, value []byte, revision, lease int64) (revRet int64, kvRet *server.KeyValue, updateRet bool, errRet error) {
Expand Down
8 changes: 4 additions & 4 deletions pkg/logstructured/logstructured.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ type Log interface {
List(ctx context.Context, prefix, startKey string, limit, revision int64, includeDeletes bool) (int64, []*server.Event, error)
After(ctx context.Context, prefix string, revision, limit int64) (int64, []*server.Event, error)
Watch(ctx context.Context, prefix string) <-chan []*server.Event
Count(ctx context.Context, prefix string) (int64, int64, error)
Count(ctx context.Context, prefix string, revision int64) (int64, int64, error)
Append(ctx context.Context, event *server.Event) (int64, error)
DbSize(ctx context.Context) (int64, error)
}
Expand Down Expand Up @@ -198,11 +198,11 @@ func (l *LogStructured) List(ctx context.Context, prefix, startKey string, limit
return rev, kvs, nil
}

func (l *LogStructured) Count(ctx context.Context, prefix string) (revRet int64, count int64, err error) {
func (l *LogStructured) Count(ctx context.Context, prefix string, revision int64) (revRet int64, count int64, err error) {
defer func() {
logrus.Tracef("COUNT %s => rev=%d, count=%d, err=%v", prefix, revRet, count, err)
logrus.Tracef("COUNT %s, rev=%d => rev=%d, count=%d, err=%v", prefix, revision, revRet, count, err)
}()
rev, count, err := l.log.Count(ctx, prefix)
rev, count, err := l.log.Count(ctx, prefix, revision)
if err != nil {
return 0, 0, err
}
Expand Down
8 changes: 6 additions & 2 deletions pkg/logstructured/sqllog/sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -524,11 +524,15 @@ func canSkipRevision(rev, skip int64, skipTime time.Time) bool {
return rev == skip && time.Since(skipTime) > time.Second
}

func (s *SQLLog) Count(ctx context.Context, prefix string) (int64, int64, error) {
func (s *SQLLog) Count(ctx context.Context, prefix string, revision int64) (int64, int64, error) {
if strings.HasSuffix(prefix, "/") {
prefix += "%"
}
return s.d.Count(ctx, prefix)

if revision == 0 {
return s.d.CountCurrent(ctx, prefix)
}
return s.d.Count(ctx, prefix, revision)
}

func (s *SQLLog) Append(ctx context.Context, event *server.Event) (int64, error) {
Expand Down
21 changes: 13 additions & 8 deletions pkg/server/list.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,14 @@ func (l *LimitedServer) list(ctx context.Context, r *etcdserverpb.RangeRequest)
prefix = prefix + "/"
}
start := string(bytes.TrimRight(r.Key, "\x00"))
revision := r.Revision

if r.CountOnly {
rev, count, err := l.backend.Count(ctx, prefix)
rev, count, err := l.backend.Count(ctx, prefix, revision)
if err != nil {
return nil, err
}
logrus.Tracef("LIST COUNT key=%s, end=%s, revision=%d, currentRev=%d count=%d", r.Key, r.RangeEnd, r.Revision, rev, count)
logrus.Tracef("LIST COUNT key=%s, end=%s, revision=%d, currentRev=%d count=%d", r.Key, r.RangeEnd, revision, rev, count)
return &RangeResponse{
Header: txnHeader(rev),
Count: count,
Expand All @@ -38,29 +39,33 @@ func (l *LimitedServer) list(ctx context.Context, r *etcdserverpb.RangeRequest)
limit++
}

rev, kvs, err := l.backend.List(ctx, prefix, start, limit, r.Revision)
rev, kvs, err := l.backend.List(ctx, prefix, start, limit, revision)
if err != nil {
return nil, err
}

logrus.Tracef("LIST key=%s, end=%s, revision=%d, currentRev=%d count=%d, limit=%d", r.Key, r.RangeEnd, r.Revision, rev, len(kvs), r.Limit)
logrus.Tracef("LIST key=%s, end=%s, revision=%d, currentRev=%d count=%d, limit=%d", r.Key, r.RangeEnd, revision, rev, len(kvs), r.Limit)
resp := &RangeResponse{
Header: txnHeader(rev),
Count: int64(len(kvs)),
Kvs: kvs,
}

// count the actual number of results if there are more items in the db.
if limit > 0 && resp.Count > r.Limit {
resp.More = true
resp.Kvs = kvs[0 : limit-1]

// count the actual number of results if there are more items in the db.
_, count, err := l.backend.Count(ctx, prefix)
if revision == 0 {
revision = rev
}

rev, resp.Count, err = l.backend.Count(ctx, prefix, revision)
if err != nil {
return nil, err
}
logrus.Tracef("LIST COUNT key=%s, end=%s, revision=%d, currentRev=%d count=%d", r.Key, r.RangeEnd, r.Revision, rev, count)
resp.Count = count
logrus.Tracef("LIST COUNT key=%s, end=%s, revision=%d, currentRev=%d count=%d", r.Key, r.RangeEnd, revision, rev, resp.Count)
resp.Header = txnHeader(rev)
}

return resp, nil
Expand Down
5 changes: 3 additions & 2 deletions pkg/server/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ type Backend interface {
Create(ctx context.Context, key string, value []byte, lease int64) (int64, error)
Delete(ctx context.Context, key string, revision int64) (int64, *KeyValue, bool, error)
List(ctx context.Context, prefix, startKey string, limit, revision int64) (int64, []*KeyValue, error)
Count(ctx context.Context, prefix string) (int64, int64, error)
Count(ctx context.Context, prefix string, revision int64) (int64, int64, error)
Update(ctx context.Context, key string, value []byte, revision, lease int64) (int64, *KeyValue, bool, error)
Watch(ctx context.Context, key string, revision int64) WatchResult
DbSize(ctx context.Context) (int64, error)
Expand All @@ -33,7 +33,8 @@ type Backend interface {
type Dialect interface {
ListCurrent(ctx context.Context, prefix string, limit int64, includeDeleted bool) (*sql.Rows, error)
List(ctx context.Context, prefix, startKey string, limit, revision int64, includeDeleted bool) (*sql.Rows, error)
Count(ctx context.Context, prefix string) (int64, int64, error)
CountCurrent(ctx context.Context, prefix string) (int64, int64, error)
Count(ctx context.Context, prefix string, revision int64) (int64, int64, error)
CurrentRevision(ctx context.Context) (int64, error)
After(ctx context.Context, prefix string, rev, limit int64) (*sql.Rows, error)
Insert(ctx context.Context, key string, create, delete bool, createRevision, previousRevision int64, ttl int64, value, prevValue []byte) (int64, error)
Expand Down

0 comments on commit 3773672

Please sign in to comment.