Skip to content

Commit

Permalink
Use GetLoadingProgress for sync load coll/part (#509)
Browse files Browse the repository at this point in the history
Signed-off-by: Congqi Xia <[email protected]>
  • Loading branch information
congqixia authored Jul 4, 2023
1 parent f519502 commit 1829b63
Show file tree
Hide file tree
Showing 4 changed files with 336 additions and 309 deletions.
41 changes: 29 additions & 12 deletions client/collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -431,22 +431,21 @@ func (c *GrpcClient) LoadCollection(ctx context.Context, collName string, async
}

if !async {
ticker := time.NewTicker(200 * time.Millisecond)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return errors.New("context deadline exceeded")
default:
return ctx.Err()
case <-ticker.C:
progress, err := c.getLoadingProgress(ctx, collName)
if err != nil {
return err
}
if progress == 100 {
return nil
}
}

coll, err := c.ShowCollection(ctx, collName)
if err != nil {
return err
}
if coll.Loaded {
break
}

time.Sleep(200 * time.Millisecond) // TODO change to configuration
}
}
return nil
Expand Down Expand Up @@ -596,3 +595,21 @@ func (c *GrpcClient) AlterCollection(ctx context.Context, collName string, attrs
}
return handleRespStatus(resp)
}

func (c *GrpcClient) getLoadingProgress(ctx context.Context, collectionName string, partitionNames ...string) (int64, error) {
req := &server.GetLoadingProgressRequest{
Base: &common.MsgBase{},
DbName: "",
CollectionName: collectionName,
PartitionNames: partitionNames,
}

resp, err := c.Service.GetLoadingProgress(ctx, req)
if err != nil {
return -1, err
}
if err := handleRespStatus(resp.GetStatus()); err != nil {
return -1, err
}
return resp.GetProgress(), nil
}
227 changes: 132 additions & 95 deletions client/collection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"fmt"
"math/rand"
"testing"
"time"

"github.com/golang/protobuf/proto"
common "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
Expand Down Expand Up @@ -59,90 +58,6 @@ func TestGrpcClientDropCollection(t *testing.T) {
})
}

func TestGrpcClientLoadCollection(t *testing.T) {
ctx := context.Background()
c := testClient(ctx, t)
mockServer.SetInjection(MHasCollection, hasCollectionDefault)
// injection check collection name equals
mockServer.SetInjection(MLoadCollection, func(_ context.Context, raw proto.Message) (proto.Message, error) {
req, ok := raw.(*server.LoadCollectionRequest)
if !ok {
return BadRequestStatus()
}
assert.Equal(t, testCollectionName, req.GetCollectionName())
return SuccessStatus()
})
t.Run("Load collection normal async", func(t *testing.T) {
assert.Nil(t, c.LoadCollection(ctx, testCollectionName, true))
})
t.Run("Load collection sync", func(t *testing.T) {

loadTime := rand.Intn(500) + 500 // in milli seconds, 100~1000 milliseconds
passed := false // ### flag variable
start := time.Now()

mockServer.SetInjection(MShowCollections, func(_ context.Context, raw proto.Message) (proto.Message, error) {
req, ok := raw.(*server.ShowCollectionsRequest)
r := &server.ShowCollectionsResponse{}
if !ok || req == nil {
s, err := BadRequestStatus()
r.Status = s
return r, err
}
s, err := SuccessStatus()
r.Status = s
r.CollectionIds = []int64{1}
var perc int64
if time.Since(start) > time.Duration(loadTime)*time.Millisecond {
t.Log("passed")
perc = 100
passed = true
}
r.InMemoryPercentages = []int64{perc}
return r, err
})
assert.Nil(t, c.LoadCollection(ctx, testCollectionName, false))
assert.True(t, passed)

start = time.Now()
passed = false
quickCtx, cancel := context.WithTimeout(ctx, 50*time.Millisecond)
defer cancel()
assert.NotNil(t, c.LoadCollection(quickCtx, testCollectionName, false))

// remove injection
mockServer.DelInjection(MShowCollections)
})
t.Run("Load default replica", func(t *testing.T) {
mockServer.SetInjection(MLoadCollection, func(ctx context.Context, raw proto.Message) (proto.Message, error) {
req, ok := raw.(*server.LoadCollectionRequest)
if !ok {
return BadRequestStatus()
}
assert.Equal(t, testDefaultReplicaNumber, req.GetReplicaNumber())
assert.Equal(t, testCollectionName, req.GetCollectionName())
return SuccessStatus()
})
defer mockServer.DelInjection(MLoadCollection)
assert.Nil(t, c.LoadCollection(ctx, testCollectionName, true))
})
t.Run("Load multiple replica", func(t *testing.T) {
mockServer.DelInjection(MLoadCollection)

mockServer.SetInjection(MLoadCollection, func(ctx context.Context, raw proto.Message) (proto.Message, error) {
req, ok := raw.(*server.LoadCollectionRequest)
if !ok {
return BadRequestStatus()
}
assert.Equal(t, testMultiReplicaNumber, req.GetReplicaNumber())
assert.Equal(t, testCollectionName, req.GetCollectionName())
return SuccessStatus()
})
defer mockServer.DelInjection(MLoadCollection)
assert.Nil(t, c.LoadCollection(ctx, testCollectionName, true, WithReplicaNumber(testMultiReplicaNumber)))
})
}

func TestReleaseCollection(t *testing.T) {
ctx := context.Background()

Expand Down Expand Up @@ -717,11 +632,9 @@ func (s *CollectionSuite) TestNewCollection() {
},
}, nil)
s.mock.EXPECT().LoadCollection(mock.Anything, mock.Anything).Return(&common.Status{ErrorCode: common.ErrorCode_Success}, nil)
s.mock.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&server.ShowCollectionsResponse{
Status: &common.Status{ErrorCode: common.ErrorCode_Success},
CollectionNames: []string{testCollectionName},
CollectionIds: []int64{0},
InMemoryPercentages: []int64{100},
s.mock.EXPECT().GetLoadingProgress(mock.Anything, mock.Anything).Return(&server.GetLoadingProgressResponse{
Status: &common.Status{ErrorCode: common.ErrorCode_Success},
Progress: 100,
}, nil)
s.mock.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&server.DescribeCollectionResponse{
Status: &common.Status{ErrorCode: common.ErrorCode_Success},
Expand Down Expand Up @@ -775,11 +688,9 @@ func (s *CollectionSuite) TestNewCollection() {
},
}, nil)
s.mock.EXPECT().LoadCollection(mock.Anything, mock.Anything).Return(&common.Status{ErrorCode: common.ErrorCode_Success}, nil)
s.mock.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&server.ShowCollectionsResponse{
Status: &common.Status{ErrorCode: common.ErrorCode_Success},
CollectionNames: []string{testCollectionName},
CollectionIds: []int64{0},
InMemoryPercentages: []int64{100},
s.mock.EXPECT().GetLoadingProgress(mock.Anything, mock.Anything).Return(&server.GetLoadingProgressResponse{
Status: &common.Status{ErrorCode: common.ErrorCode_Success},
Progress: 100,
}, nil)
s.mock.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&server.DescribeCollectionResponse{
Status: &common.Status{ErrorCode: common.ErrorCode_Success},
Expand Down Expand Up @@ -869,6 +780,132 @@ func (s *CollectionSuite) TestAlterCollection() {
})
}

func (s *CollectionSuite) TestLoadCollection() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

c := s.client

s.Run("normal_run_async", func() {
defer s.resetMock()
s.mock.EXPECT().HasCollection(mock.Anything, &server.HasCollectionRequest{CollectionName: testCollectionName}).
Return(&server.BoolResponse{Status: &common.Status{}, Value: true}, nil)

s.mock.EXPECT().LoadCollection(mock.Anything, mock.Anything).Return(&common.Status{ErrorCode: common.ErrorCode_Success}, nil)

err := c.LoadCollection(ctx, testCollectionName, true)
s.NoError(err)
})

s.Run("normal_run_sync", func() {
defer s.resetMock()
s.mock.EXPECT().HasCollection(mock.Anything, &server.HasCollectionRequest{CollectionName: testCollectionName}).
Return(&server.BoolResponse{Status: &common.Status{}, Value: true}, nil)

s.mock.EXPECT().LoadCollection(mock.Anything, mock.Anything).Return(&common.Status{ErrorCode: common.ErrorCode_Success}, nil)
s.mock.EXPECT().GetLoadingProgress(mock.Anything, mock.Anything).
Return(&server.GetLoadingProgressResponse{
Status: &common.Status{ErrorCode: common.ErrorCode_Success},
Progress: 100,
}, nil)

err := c.LoadCollection(ctx, testCollectionName, true)
s.NoError(err)
})

s.Run("load_default_replica", func() {
defer s.resetMock()
s.mock.EXPECT().HasCollection(mock.Anything, &server.HasCollectionRequest{CollectionName: testCollectionName}).
Return(&server.BoolResponse{Status: &common.Status{}, Value: true}, nil)

s.mock.EXPECT().LoadCollection(mock.Anything, mock.Anything).Run(func(_ context.Context, req *server.LoadCollectionRequest) {
s.Equal(testDefaultReplicaNumber, req.GetReplicaNumber())
s.Equal(testCollectionName, req.GetCollectionName())
}).Return(&common.Status{ErrorCode: common.ErrorCode_Success}, nil)

err := c.LoadCollection(ctx, testCollectionName, true)
s.NoError(err)
})

s.Run("load_multiple_replica", func() {
defer s.resetMock()
s.mock.EXPECT().HasCollection(mock.Anything, &server.HasCollectionRequest{CollectionName: testCollectionName}).
Return(&server.BoolResponse{Status: &common.Status{}, Value: true}, nil)

s.mock.EXPECT().LoadCollection(mock.Anything, mock.Anything).Run(func(_ context.Context, req *server.LoadCollectionRequest) {
s.Equal(testMultiReplicaNumber, req.GetReplicaNumber())
s.Equal(testCollectionName, req.GetCollectionName())
}).Return(&common.Status{ErrorCode: common.ErrorCode_Success}, nil)

err := c.LoadCollection(ctx, testCollectionName, true, WithReplicaNumber(testMultiReplicaNumber))
s.NoError(err)
})

s.Run("has_collection_failure", func() {
s.Run("return_false", func() {
defer s.resetMock()
s.mock.EXPECT().HasCollection(mock.Anything, &server.HasCollectionRequest{CollectionName: testCollectionName}).
Return(&server.BoolResponse{Status: &common.Status{}, Value: false}, nil)

err := c.LoadCollection(ctx, testCollectionName, true)
s.Error(err)
})

s.Run("return_error", func() {
defer s.resetMock()
s.mock.EXPECT().HasCollection(mock.Anything, &server.HasCollectionRequest{CollectionName: testCollectionName}).
Return(nil, errors.New("mock error"))

err := c.LoadCollection(ctx, testCollectionName, true)
s.Error(err)
})
})

s.Run("load_collection_failure", func() {
s.Run("failure_status", func() {
defer s.resetMock()
s.mock.EXPECT().HasCollection(mock.Anything, &server.HasCollectionRequest{CollectionName: testCollectionName}).
Return(&server.BoolResponse{Status: &common.Status{}, Value: true}, nil)

s.mock.EXPECT().LoadCollection(mock.Anything, mock.Anything).
Return(&common.Status{ErrorCode: common.ErrorCode_UnexpectedError}, nil)

err := c.LoadCollection(ctx, testCollectionName, true)
s.Error(err)
})

s.Run("return_error", func() {
s.mock.EXPECT().HasCollection(mock.Anything, &server.HasCollectionRequest{CollectionName: testCollectionName}).
Return(&server.BoolResponse{Status: &common.Status{}, Value: true}, nil)

s.mock.EXPECT().LoadCollection(mock.Anything, mock.Anything).
Return(nil, errors.New("mock error"))

err := c.LoadCollection(ctx, testCollectionName, true)
s.Error(err)
})
})

s.Run("get_loading_progress_failure", func() {
defer s.resetMock()
s.mock.EXPECT().HasCollection(mock.Anything, &server.HasCollectionRequest{CollectionName: testCollectionName}).
Return(&server.BoolResponse{Status: &common.Status{}, Value: true}, nil)

s.mock.EXPECT().LoadCollection(mock.Anything, mock.Anything).Return(&common.Status{ErrorCode: common.ErrorCode_Success}, nil)
s.mock.EXPECT().GetLoadingProgress(mock.Anything, mock.Anything).
Return(nil, errors.New("mock error"))

err := c.LoadCollection(ctx, testCollectionName, false)
s.Error(err)
})

s.Run("service_not_ready", func() {
c := &GrpcClient{}
err := c.LoadCollection(ctx, testCollectionName, false)
s.ErrorIs(err, ErrClientNotReady)
})
}

func TestCollectionSuite(t *testing.T) {
suite.Run(t, new(CollectionSuite))
}
49 changes: 10 additions & 39 deletions client/partition.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,23 +148,6 @@ func (c *GrpcClient) LoadPartitions(ctx context.Context, collName string, partit
return err
}
}
partitions, err := c.ShowPartitions(ctx, collName)
if err != nil {
return err
}
m := make(map[string]int64)
for _, partition := range partitions {
m[partition.Name] = partition.ID
}
// load partitions ids
ids := make(map[int64]struct{})
for _, partitionName := range partitionNames {
id, has := m[partitionName]
if !has {
return fmt.Errorf("collection %s does not has partitions %s", collName, partitionName)
}
ids[id] = struct{}{}
}

req := &server.LoadPartitionsRequest{
DbName: "", // reserved
Expand All @@ -180,34 +163,22 @@ func (c *GrpcClient) LoadPartitions(ctx context.Context, collName string, partit
}

if !async {
ticker := time.NewTicker(200 * time.Millisecond)
defer ticker.Stop()

for {
select {
case <-ctx.Done():
return errors.New("context deadline exceeded")
default:
}
partitions, err := c.ShowPartitions(ctx, collName)
if err != nil {
return err
}
foundLoading := false
loaded := 0
for _, partition := range partitions {
if _, has := ids[partition.ID]; !has {
continue
return ctx.Err()
case <-ticker.C:
progress, err := c.getLoadingProgress(ctx, collName, partitionNames...)
if err != nil {
return err
}
if !partition.Loaded {
//Not loaded
foundLoading = true
break
if progress == 100 {
return nil
}
loaded++
}
if foundLoading || loaded < len(partitionNames) {
time.Sleep(time.Millisecond * 100)
continue
}
break
}
}

Expand Down
Loading

0 comments on commit 1829b63

Please sign in to comment.