Skip to content

Commit

Permalink
Handle vector field data in QueryByPks (#256)
Browse files Browse the repository at this point in the history
Signed-off-by: Congqi Xia <[email protected]>
  • Loading branch information
congqixia authored Jan 26, 2022
1 parent e6520ad commit 8ae7001
Show file tree
Hide file tree
Showing 10 changed files with 188 additions and 36 deletions.
9 changes: 9 additions & 0 deletions client/client_grpc_data.go
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,15 @@ func (c *grpcClient) QueryByPks(ctx context.Context, collectionName string, part
fieldsData := resp.GetFieldsData()
columns := make([]entity.Column, 0, len(fieldsData))
for _, fieldData := range resp.GetFieldsData() {
if fieldData.GetType() == schema.DataType_FloatVector ||
fieldData.GetType() == schema.DataType_BinaryVector {
column, err := entity.FieldDataVector(fieldData)
if err != nil {
return nil, err
}
columns = append(columns, column)
continue
}
column, err := entity.FieldDataColumn(fieldData, 0, -1)
if err != nil {
return nil, err
Expand Down
49 changes: 47 additions & 2 deletions client/client_grpc_data_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -410,16 +410,31 @@ func TestGrpcQueryByPks(t *testing.T) {
},
},
},
{
Type: schema.DataType_FloatVector,
FieldName: testVectorField,
Field: &schema.FieldData_Vectors{
Vectors: &schema.VectorField{
Dim: 1,
Data: &schema.VectorField_FloatVector{
FloatVector: &schema.FloatArray{
Data: []float32{0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0},
},
},
},
},
},
}

return resp, err
})
defer mock.delInjection(mQuery)

columns, err := c.QueryByPks(ctx, testCollectionName, []string{partName}, entity.NewColumnInt64(testPrimaryField, []int64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), []string{"int64"})
columns, err := c.QueryByPks(ctx, testCollectionName, []string{partName}, entity.NewColumnInt64(testPrimaryField, []int64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), []string{"int64", testVectorField})
assert.NoError(t, err)
assert.Equal(t, 1, len(columns))
assert.Equal(t, 2, len(columns))
assert.Equal(t, entity.FieldTypeInt64, columns[0].Type())
assert.Equal(t, entity.FieldTypeFloatVector, columns[1].Type())
assert.Equal(t, 10, columns[0].Len())

colInt64, ok := columns[0].(*entity.ColumnInt64)
Expand Down Expand Up @@ -519,6 +534,36 @@ func TestGrpcQueryByPks(t *testing.T) {
})
_, err = c.QueryByPks(ctx, testCollectionName, []string{}, entity.NewColumnInt64(testPrimaryField, []int64{1}), []string{"*"})
assert.Error(t, err)

mock.setInjection(mQuery, func(_ context.Context, raw proto.Message) (proto.Message, error) {
_, ok := raw.(*server.QueryRequest)
if !ok {
t.FailNow()
}

resp := &server.QueryResults{}
s, err := successStatus()
resp.Status = s
resp.FieldsData = []*schema.FieldData{
{
Type: schema.DataType_FloatVector,
FieldName: "int64",
Field: &schema.FieldData_Scalars{
Scalars: &schema.ScalarField{
Data: &schema.ScalarField_LongData{
LongData: &schema.LongArray{
Data: []int64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
},
},
},
},
},
}

return resp, err
})
_, err = c.QueryByPks(ctx, testCollectionName, []string{}, entity.NewColumnInt64(testPrimaryField, []int64{1}), []string{"*"})
assert.Error(t, err)
})
}

Expand Down
36 changes: 36 additions & 0 deletions entity/columns.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,43 @@ func FieldDataColumn(fd *schema.FieldData, begin, end int) (Column, error) {
return NewColumnString(fd.GetFieldName(), data.StringData.GetData()[begin:]), nil
}
return NewColumnString(fd.GetFieldName(), data.StringData.GetData()[begin:end]), nil
default:
return nil, errors.New("unsupported data type")
}
}

// FieldDataColumn converts schema.FieldData to vector Column
func FieldDataVector(fd *schema.FieldData) (Column, error) {
switch fd.GetType() {
case schema.DataType_FloatVector:
vectors := fd.GetVectors()
data := vectors.GetFloatVector().GetData()
if data == nil {
return nil, errFieldDataTypeNotMatch
}
dim := int(vectors.GetDim())
vector := make([][]float32, 0, len(data)/dim) // shall not have remanunt
for i := 0; i < len(data)/dim; i++ {
v := make([]float32, dim)
copy(v, data[i*dim:(i+1)*dim])
vector = append(vector, v)
}
return NewColumnFloatVector(fd.GetFieldName(), dim, vector), nil
case schema.DataType_BinaryVector:
vectors := fd.GetVectors()
data := vectors.GetBinaryVector()
if data == nil {
return nil, errFieldDataTypeNotMatch
}
dim := int(vectors.GetDim())
blen := dim / 8
vector := make([][]byte, 0, len(data)/blen)
for i := 0; i < len(data)/blen; i++ {
v := make([]byte, blen)
copy(v, data[i*blen:(i+1)*blen])
vector = append(vector, v)
}
return NewColumnBinaryVector(fd.GetFieldName(), dim, vector), nil
default:
return nil, errors.New("unsupported data type")
}
Expand Down
2 changes: 1 addition & 1 deletion entity/columns_vector_gen.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

59 changes: 50 additions & 9 deletions entity/columns_vector_gen_test.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

29 changes: 25 additions & 4 deletions entity/gen/gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -344,17 +344,25 @@ import (
"testing"
"time"
"github.com/milvus-io/milvus-sdk-go/v2/internal/proto/schema"
"github.com/stretchr/testify/assert"
)
{{range .Types}}{{with .}}
func TestColumn{{.TypeName}}(t *testing.T) {
rand.Seed(time.Now().UnixNano())
columnName := fmt.Sprintf("column_{{.TypeName}}_%d", rand.Int())
columnLen := 8 + rand.Intn(10)
dim := ([]int{8, 32, 64, 128})[rand.Intn(4)]
columnLen := 12 + rand.Intn(10)
dim := ([]int{64, 128, 256, 512})[rand.Intn(4)]
v := make([]{{.TypeDef}}, columnLen)
column := NewColumn{{.TypeName}}(columnName,dim, v)
v := make([]{{.TypeDef}},0, columnLen)
dlen := dim
{{if eq .TypeName "BinaryVector" }}dlen /= 8{{end}}
for i := 0; i < columnLen; i++ {
entry := make({{.TypeDef}}, dlen)
v = append(v, entry)
}
column := NewColumn{{.TypeName}}(columnName, dim, v)
t.Run("test meta", func(t *testing.T) {
ft := FieldType{{.TypeName}}
Expand Down Expand Up @@ -386,6 +394,19 @@ func TestColumn{{.TypeName}}(t *testing.T) {
fd := column.FieldData()
assert.NotNil(t, fd)
assert.Equal(t, fd.GetFieldName(), columnName)
c, err := FieldDataVector(fd)
assert.NotNil(t, c)
assert.NoError(t, err)
})
t.Run("test column field data error", func(t *testing.T) {
fd := &schema.FieldData{
Type: schema.DataType_{{.TypeName}},
FieldName: columnName,
}
_, err := FieldDataVector(fd)
assert.Error(t, err)
})
}
Expand Down
Loading

0 comments on commit 8ae7001

Please sign in to comment.