Skip to content

Commit

Permalink
Support the replicate message api (#622)
Browse files Browse the repository at this point in the history
Signed-off-by: SimFG <[email protected]>
  • Loading branch information
SimFG authored Nov 10, 2023
1 parent b61b0b6 commit 087d17c
Show file tree
Hide file tree
Showing 19 changed files with 439 additions and 198 deletions.
28 changes: 18 additions & 10 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ import (

"google.golang.org/grpc"

"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"

"github.com/milvus-io/milvus-sdk-go/v2/entity"
)

Expand All @@ -38,9 +40,9 @@ type Client interface {
// ListDatabases list all database in milvus cluster.
ListDatabases(ctx context.Context) ([]entity.Database, error)
// CreateDatabase create database with the given name.
CreateDatabase(ctx context.Context, dbName string) error
CreateDatabase(ctx context.Context, dbName string, opts ...CreateDatabaseOption) error
// DropDatabase drop database with the given db name.
DropDatabase(ctx context.Context, dbName string) error
DropDatabase(ctx context.Context, dbName string, opts ...DropDatabaseOption) error

// -- collection --

Expand All @@ -53,13 +55,13 @@ type Client interface {
// DescribeCollection describe collection meta
DescribeCollection(ctx context.Context, collName string) (*entity.Collection, error)
// DropCollection drop the specified collection
DropCollection(ctx context.Context, collName string) error
DropCollection(ctx context.Context, collName string, opts ...DropCollectionOption) error
// GetCollectionStatistics get collection statistics
GetCollectionStatistics(ctx context.Context, collName string) (map[string]string, error)
// LoadCollection load collection into memory
LoadCollection(ctx context.Context, collName string, async bool, opts ...LoadCollectionOption) error
// ReleaseCollection release loaded collection
ReleaseCollection(ctx context.Context, collName string) error
ReleaseCollection(ctx context.Context, collName string, opts ...ReleaseCollectionOption) error
// HasCollection check whether collection exists
HasCollection(ctx context.Context, collName string) (bool, error)
// RenameCollection performs renaming for provided collection.
Expand Down Expand Up @@ -91,17 +93,17 @@ type Client interface {
// -- partition --

// CreatePartition create partition for collection
CreatePartition(ctx context.Context, collName string, partitionName string) error
CreatePartition(ctx context.Context, collName string, partitionName string, opts ...CreatePartitionOption) error
// DropPartition drop partition from collection
DropPartition(ctx context.Context, collName string, partitionName string) error
DropPartition(ctx context.Context, collName string, partitionName string, opts ...DropPartitionOption) error
// ShowPartitions list all partitions from collection
ShowPartitions(ctx context.Context, collName string) ([]*entity.Partition, error)
// HasPartition check whether partition exists in collection
HasPartition(ctx context.Context, collName string, partitionName string) (bool, error)
// LoadPartitions load partitions into memory
LoadPartitions(ctx context.Context, collName string, partitionNames []string, async bool) error
LoadPartitions(ctx context.Context, collName string, partitionNames []string, async bool, opts ...LoadPartitionsOption) error
// ReleasePartitions release partitions
ReleasePartitions(ctx context.Context, collName string, partitionNames []string) error
ReleasePartitions(ctx context.Context, collName string, partitionNames []string, opts ...ReleasePartitionsOption) error

// -- segment --
GetPersistentSegmentInfo(ctx context.Context, collName string) ([]*entity.Segment, error)
Expand All @@ -124,10 +126,10 @@ type Client interface {
// Insert column-based data into collection, returns id column values
Insert(ctx context.Context, collName string, partitionName string, columns ...entity.Column) (entity.Column, error)
// Flush collection, specified
Flush(ctx context.Context, collName string, async bool) error
Flush(ctx context.Context, collName string, async bool, opts ...FlushOption) error
// FlushV2 flush collection, specified, return newly sealed segmentIds, all flushed segmentIds of the collection, seal time and error
// currently it is only used in milvus-backup(https://github.com/zilliztech/milvus-backup)
FlushV2(ctx context.Context, collName string, async bool) ([]int64, []int64, int64, error)
FlushV2(ctx context.Context, collName string, async bool, opts ...FlushOption) ([]int64, []int64, int64, error)
// DeleteByPks deletes entries related to provided primary keys
DeleteByPks(ctx context.Context, collName string, partitionName string, ids entity.Column) error
// Delete deletes entries match expression
Expand Down Expand Up @@ -211,6 +213,12 @@ type Client interface {
GetVersion(ctx context.Context) (string, error)
// CheckHealth returns milvus state
CheckHealth(ctx context.Context) (*entity.MilvusState, error)

ReplicateMessage(ctx context.Context,
channelName string, beginTs, endTs uint64,
msgsBytes [][]byte, startPositions, endPositions []*msgpb.MsgPosition,
opts ...ReplicateMessageOption,
) (*entity.MessageInfo, error)
}

// NewClient create a client connected to remote milvus cluster.
Expand Down
12 changes: 10 additions & 2 deletions client/client_mock_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,8 @@ const (
MListDatabase ServiceMethod = 1000
MCreateDatabase ServiceMethod = 1001
MDropDatabase ServiceMethod = 1002

MReplicateMessage ServiceMethod = 1100
)

// injection function definition
Expand Down Expand Up @@ -924,8 +926,14 @@ func (m *MockServer) AllocTimestamp(_ context.Context, _ *milvuspb.AllocTimestam
panic("not implemented")
}

func (m *MockServer) ReplicateMessage(_ context.Context, _ *milvuspb.ReplicateMessageRequest) (*milvuspb.ReplicateMessageResponse, error) {
panic("not implemented")
func (m *MockServer) ReplicateMessage(ctx context.Context, req *milvuspb.ReplicateMessageRequest) (*milvuspb.ReplicateMessageResponse, error) {
f := m.GetInjection(MReplicateMessage)
if f != nil {
r, err := f(ctx, req)
return r.(*milvuspb.ReplicateMessageResponse), err
}
s, err := SuccessStatus()
return &milvuspb.ReplicateMessageResponse{Status: s}, err
}

func (m *MockServer) Connect(_ context.Context, _ *milvuspb.ConnectRequest) (*milvuspb.ConnectResponse, error) {
Expand Down
12 changes: 10 additions & 2 deletions client/collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"github.com/cockroachdb/errors"

"github.com/golang/protobuf/proto"

"github.com/milvus-io/milvus-sdk-go/v2/entity"

"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
Expand Down Expand Up @@ -151,6 +152,7 @@ func (c *GrpcClient) requestCreateCollection(ctx context.Context, sch *entity.Sc
}

req := &milvuspb.CreateCollectionRequest{
Base: opt.MsgBase,
DbName: "", // reserved fields, not used for now
CollectionName: sch.CollectionName,
Schema: bs,
Expand Down Expand Up @@ -279,7 +281,7 @@ func (c *GrpcClient) DescribeCollection(ctx context.Context, collName string) (*
}

// DropCollection drop collection by name
func (c *GrpcClient) DropCollection(ctx context.Context, collName string) error {
func (c *GrpcClient) DropCollection(ctx context.Context, collName string, opts ...DropCollectionOption) error {
if c.Service == nil {
return ErrClientNotReady
}
Expand All @@ -290,6 +292,9 @@ func (c *GrpcClient) DropCollection(ctx context.Context, collName string) error
req := &milvuspb.DropCollectionRequest{
CollectionName: collName,
}
for _, opt := range opts {
opt(req)
}
resp, err := c.Service.DropCollection(ctx, req)
if err != nil {
return err
Expand Down Expand Up @@ -447,7 +452,7 @@ func (c *GrpcClient) LoadCollection(ctx context.Context, collName string, async
}

// ReleaseCollection release loaded collection
func (c *GrpcClient) ReleaseCollection(ctx context.Context, collName string) error {
func (c *GrpcClient) ReleaseCollection(ctx context.Context, collName string, opts ...ReleaseCollectionOption) error {
if c.Service == nil {
return ErrClientNotReady
}
Expand All @@ -459,6 +464,9 @@ func (c *GrpcClient) ReleaseCollection(ctx context.Context, collName string) err
DbName: "", // reserved
CollectionName: collName,
}
for _, opt := range opts {
opt(req)
}
resp, err := c.Service.ReleaseCollection(ctx, req)
if err != nil {
return err
Expand Down
8 changes: 4 additions & 4 deletions client/collection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ func (s *CollectionSuite) TestCreateCollection() {
Return(&commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, nil)
s.mock.EXPECT().HasCollection(mock.Anything, &milvuspb.HasCollectionRequest{CollectionName: testCollectionName}).Return(&milvuspb.BoolResponse{Status: &commonpb.Status{}, Value: false}, nil)

err := c.CreateCollection(ctx, ds, shardsNum)
err := c.CreateCollection(ctx, ds, shardsNum, WithCreateCollectionMsgBase(&commonpb.MsgBase{}))
s.NoError(err)
})

Expand Down Expand Up @@ -514,7 +514,7 @@ func (s *CollectionSuite) TestLoadCollection() {

s.mock.EXPECT().LoadCollection(mock.Anything, mock.Anything).Return(&commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, nil)

err := c.LoadCollection(ctx, testCollectionName, true)
err := c.LoadCollection(ctx, testCollectionName, true, WithLoadCollectionMsgBase(&commonpb.MsgBase{}))
s.NoError(err)
})

Expand Down Expand Up @@ -663,7 +663,7 @@ func TestGrpcClientDropCollection(t *testing.T) {
})

t.Run("Test Normal drop", func(t *testing.T) {
assert.Nil(t, c.DropCollection(ctx, testCollectionName))
assert.Nil(t, c.DropCollection(ctx, testCollectionName, WithDropCollectionMsgBase(&commonpb.MsgBase{})))
})

t.Run("Test drop non-existing collection", func(t *testing.T) {
Expand All @@ -685,7 +685,7 @@ func TestReleaseCollection(t *testing.T) {
return SuccessStatus()
})

c.ReleaseCollection(ctx, testCollectionName)
c.ReleaseCollection(ctx, testCollectionName, WithReleaseCollectionMsgBase(&commonpb.MsgBase{}))
}

func TestGrpcClientHasCollection(t *testing.T) {
Expand Down
2 changes: 1 addition & 1 deletion client/data_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func TestGrpcClientFlush(t *testing.T) {
c := testClient(ctx, t)

t.Run("test async flush", func(t *testing.T) {
assert.Nil(t, c.Flush(ctx, testCollectionName, true))
assert.Nil(t, c.Flush(ctx, testCollectionName, true, WithFlushMsgBase(&commonpb.MsgBase{})))
})

t.Run("test sync flush", func(t *testing.T) {
Expand Down
10 changes: 8 additions & 2 deletions client/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func (c *GrpcClient) UsingDatabase(ctx context.Context, dbName string) error {

// CreateDatabase creates a new database for remote Milvus cluster.
// TODO:New options can be added as expanding parameters.
func (c *GrpcClient) CreateDatabase(ctx context.Context, dbName string) error {
func (c *GrpcClient) CreateDatabase(ctx context.Context, dbName string, opts ...CreateDatabaseOption) error {
if c.Service == nil {
return ErrClientNotReady
}
Expand All @@ -50,6 +50,9 @@ func (c *GrpcClient) CreateDatabase(ctx context.Context, dbName string) error {
req := &milvuspb.CreateDatabaseRequest{
DbName: dbName,
}
for _, opt := range opts {
opt(req)
}
resp, err := c.Service.CreateDatabase(ctx, req)
if err != nil {
return err
Expand Down Expand Up @@ -84,7 +87,7 @@ func (c *GrpcClient) ListDatabases(ctx context.Context) ([]entity.Database, erro
}

// DropDatabase drop all database in milvus cluster.
func (c *GrpcClient) DropDatabase(ctx context.Context, dbName string) error {
func (c *GrpcClient) DropDatabase(ctx context.Context, dbName string, opts ...DropDatabaseOption) error {
if c.Service == nil {
return ErrClientNotReady
}
Expand All @@ -95,6 +98,9 @@ func (c *GrpcClient) DropDatabase(ctx context.Context, dbName string) error {
req := &milvuspb.DropDatabaseRequest{
DbName: dbName,
}
for _, opt := range opts {
opt(req)
}
resp, err := c.Service.DropDatabase(ctx, req)
if err != nil {
return err
Expand Down
5 changes: 3 additions & 2 deletions client/database_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"github.com/go-faker/faker/v4"
"github.com/go-faker/faker/v4/pkg/options"
"github.com/golang/protobuf/proto"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/stretchr/testify/assert"
)
Expand Down Expand Up @@ -45,7 +46,7 @@ func TestGrpcClientCreateDatabase(t *testing.T) {
mockServer.SetInjection(MCreateDatabase, func(ctx context.Context, m proto.Message) (proto.Message, error) {
return SuccessStatus()
})
err := c.CreateDatabase(ctx, "a")
err := c.CreateDatabase(ctx, "a", WithCreateDatabaseMsgBase(&commonpb.MsgBase{}))
assert.Nil(t, err)
}

Expand All @@ -55,6 +56,6 @@ func TestGrpcClientDropDatabase(t *testing.T) {
mockServer.SetInjection(MDropDatabase, func(ctx context.Context, m proto.Message) (proto.Message, error) {
return SuccessStatus()
})
err := c.DropDatabase(ctx, "a")
err := c.DropDatabase(ctx, "a", WithDropDatabaseMsgBase(&commonpb.MsgBase{}))
assert.Nil(t, err)
}
9 changes: 9 additions & 0 deletions client/index.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ type indexDef struct {
name string
fieldName string
collectionName string
MsgBase *commonpb.MsgBase
}

// IndexOption is the predefined function to alter index def.
Expand All @@ -71,6 +72,12 @@ func WithIndexName(name string) IndexOption {
}
}

func WithIndexMsgBase(msgBase *commonpb.MsgBase) IndexOption {
return func(def *indexDef) {
def.MsgBase = msgBase
}
}

func getIndexDef(opts ...IndexOption) indexDef {
idxDef := indexDef{}
for _, opt := range opts {
Expand All @@ -93,6 +100,7 @@ func (c *GrpcClient) CreateIndex(ctx context.Context, collName string, fieldName
idxDef := getIndexDef(opts...)

req := &milvuspb.CreateIndexRequest{
Base: idxDef.MsgBase,
DbName: "", // reserved
CollectionName: collName,
FieldName: fieldName,
Expand Down Expand Up @@ -167,6 +175,7 @@ func (c *GrpcClient) DropIndex(ctx context.Context, collName string, fieldName s

idxDef := getIndexDef(opts...)
req := &milvuspb.DropIndexRequest{
Base: idxDef.MsgBase,
DbName: "", //reserved,
CollectionName: collName,
FieldName: fieldName,
Expand Down
4 changes: 2 additions & 2 deletions client/index_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func TestGrpcClientCreateIndex(t *testing.T) {
})

t.Run("test async create index", func(t *testing.T) {
assert.Nil(t, c.CreateIndex(ctx, testCollectionName, fieldName, idx, true))
assert.Nil(t, c.CreateIndex(ctx, testCollectionName, fieldName, idx, true, WithIndexMsgBase(&commonpb.MsgBase{})))
})

t.Run("test sync create index", func(t *testing.T) {
Expand Down Expand Up @@ -85,7 +85,7 @@ func TestGrpcClientDropIndex(t *testing.T) {
c := testClient(ctx, t)
mockServer.SetInjection(MHasCollection, hasCollectionDefault)
mockServer.SetInjection(MDescribeCollection, describeCollectionInjection(t, 0, testCollectionName, defaultSchema()))
assert.Nil(t, c.DropIndex(ctx, testCollectionName, "vector"))
assert.Nil(t, c.DropIndex(ctx, testCollectionName, "vector", WithIndexMsgBase(&commonpb.MsgBase{})))
}

func TestGrpcClientDescribeIndex(t *testing.T) {
Expand Down
9 changes: 6 additions & 3 deletions client/insert.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,14 +190,14 @@ func (c *GrpcClient) mergeDynamicColumns(dynamicName string, rowSize int, column

// Flush force collection to flush memory records into storage
// in sync mode, flush will wait all segments to be flushed
func (c *GrpcClient) Flush(ctx context.Context, collName string, async bool) error {
_, _, _, err := c.FlushV2(ctx, collName, async)
func (c *GrpcClient) Flush(ctx context.Context, collName string, async bool, opts ...FlushOption) error {
_, _, _, err := c.FlushV2(ctx, collName, async, opts...)
return err
}

// Flush force collection to flush memory records into storage
// in sync mode, flush will wait all segments to be flushed
func (c *GrpcClient) FlushV2(ctx context.Context, collName string, async bool) ([]int64, []int64, int64, error) {
func (c *GrpcClient) FlushV2(ctx context.Context, collName string, async bool, opts ...FlushOption) ([]int64, []int64, int64, error) {
if c.Service == nil {
return nil, nil, 0, ErrClientNotReady
}
Expand All @@ -208,6 +208,9 @@ func (c *GrpcClient) FlushV2(ctx context.Context, collName string, async bool) (
DbName: "", // reserved,
CollectionNames: []string{collName},
}
for _, opt := range opts {
opt(req)
}
resp, err := c.Service.Flush(ctx, req)
if err != nil {
return nil, nil, 0, err
Expand Down
41 changes: 41 additions & 0 deletions client/mq_message.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package client

import (
"context"

"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
"github.com/milvus-io/milvus-sdk-go/v2/entity"
)

func (c *GrpcClient) ReplicateMessage(ctx context.Context,
channelName string, beginTs, endTs uint64,
msgsBytes [][]byte, startPositions, endPositions []*msgpb.MsgPosition,
opts ...ReplicateMessageOption) (*entity.MessageInfo, error) {

if c.Service == nil {
return nil, ErrClientNotReady
}
req := &milvuspb.ReplicateMessageRequest{
ChannelName: channelName,
BeginTs: beginTs,
EndTs: endTs,
Msgs: msgsBytes,
StartPositions: startPositions,
EndPositions: endPositions,
}
for _, opt := range opts {
opt(req)
}
resp, err := c.Service.ReplicateMessage(ctx, req)
if err != nil {
return nil, err
}
err = handleRespStatus(resp.GetStatus())
if err != nil {
return nil, err
}
return &entity.MessageInfo{
Position: resp.GetPosition(),
}, nil
}
Loading

0 comments on commit 087d17c

Please sign in to comment.