From fcdff932a7035221c9c3e0dbb9a1a90a7cbc06ec Mon Sep 17 00:00:00 2001 From: ryjiang Date: Fri, 26 Apr 2024 10:28:11 +0800 Subject: [PATCH] Support query iterator (#309) * query iterator Signed-off-by: ruiyi.jiang * finish query iterator Signed-off-by: ruiyi.jiang * test more Signed-off-by: ruiyi.jiang * move getQueryIterator to utils/function and add tests for it Signed-off-by: ruiyi.jiang * rename pageSize -> batchSize Signed-off-by: ruiyi.jiang * init search iterator Signed-off-by: ruiyi.jiang * search iterator part1 Signed-off-by: ryjiang * stash Signed-off-by: ryjiang * refine Signed-off-by: ruiyi.jiang * refine Signed-off-by: ryjiang * refine Signed-off-by: ryjiang * refine id expression Signed-off-by: ruiyi.jiang * refine Signed-off-by: ruiyi.jiang * refine Signed-off-by: ryjiang * refine Signed-off-by: ryjiang * refine Signed-off-by: ryjiang * init consine Signed-off-by: ryjiang * close debug Signed-off-by: ryjiang --------- Signed-off-by: ruiyi.jiang Signed-off-by: ryjiang --- milvus/const/defaults.ts | 5 + milvus/const/milvus.ts | 17 ++ milvus/grpc/Collection.ts | 24 ++ milvus/grpc/Data.ts | 270 +++++++++++++++++++ milvus/types/Data.ts | 18 +- milvus/utils/Function.ts | 85 +++++- milvus/utils/Grpc.ts | 9 +- test/grpc/Iterator.spec.ts | 508 ++++++++++++++++++++++++++++++++++++ test/utils/Function.spec.ts | 231 +++++++++++++++- 9 files changed, 1161 insertions(+), 6 deletions(-) create mode 100644 test/grpc/Iterator.spec.ts diff --git a/milvus/const/defaults.ts b/milvus/const/defaults.ts index e3b1b949..df3acea2 100644 --- a/milvus/const/defaults.ts +++ b/milvus/const/defaults.ts @@ -17,3 +17,8 @@ export const DEFAULT_HTTP_ENDPOINT_VERSION = 'v2'; // api version, default v1 export const DEFAULT_POOL_MAX = 10; // default max pool client number export const DEFAULT_POOL_MIN = 2; // default min pool client number + +export const DEFAULT_MIN_INT64 = `-9223372036854775807`; // min int64 +export const DEFAULT_MAX_SEARCH_SIZE = 16384; // max query/search size +export const DEFAULT_MAX_L2_DISTANCE = 99999999; // max l2 distance +export const DEFAULT_MIN_COSINE_DISTANCE = -2.0; // min cosine distance diff --git a/milvus/const/milvus.ts b/milvus/const/milvus.ts index 0f39939a..4a26fbb1 100644 --- a/milvus/const/milvus.ts +++ b/milvus/const/milvus.ts @@ -300,6 +300,23 @@ export const DataTypeMap: { [key in keyof typeof DataType]: number } = { SparseFloatVector: 104, }; +// data type string enum +export enum DataTypeStringEnum { + None = 'None', + Bool = 'Bool', + Int8 = 'Int8', + Int16 = 'Int16', + Int32 = 'Int32', + Int64 = 'Int64', + Float = 'Float', + Double = 'Double', + VarChar = 'VarChar', + Array = 'Array', + JSON = 'JSON', + BinaryVector = 'BinaryVector', + FloatVector = 'FloatVector', +} + // RBAC: operate user role type export enum OperateUserRoleType { AddUserToRole = 0, diff --git a/milvus/grpc/Collection.ts b/milvus/grpc/Collection.ts index 66cab0ba..95fdc547 100644 --- a/milvus/grpc/Collection.ts +++ b/milvus/grpc/Collection.ts @@ -53,6 +53,7 @@ import { parseToKeyValue, CreateCollectionWithFieldsReq, CreateCollectionWithSchemaReq, + FieldSchema, } from '../'; /** @@ -1057,4 +1058,27 @@ export class Collection extends Database { return pkFieldType; } + + /** + * Get the primary field + */ + async getPkField(data: DescribeCollectionReq): Promise { + // get collection info + const collectionInfo = await this.describeCollection(data); + + // pk field + let pkField: FieldSchema = collectionInfo.schema.fields[0]; + // extract key information + for (let i = 0; i < 
collectionInfo.schema.fields.length; i++) { + const f = collectionInfo.schema.fields[i]; + + // get pk field info + if (f.is_primary_key) { + pkField = f; + break; + } + } + + return pkField; + } } diff --git a/milvus/grpc/Data.ts b/milvus/grpc/Data.ts index cdaa0e87..8bf7a0ed 100644 --- a/milvus/grpc/Data.ts +++ b/milvus/grpc/Data.ts @@ -36,6 +36,8 @@ import { SearchReq, SearchRes, SearchSimpleReq, + SearchIteratorReq, + DEFAULT_TOPK, HybridSearchReq, promisify, findKeyValue, @@ -53,6 +55,12 @@ import { CountReq, CountResult, DEFAULT_COUNT_QUERY_STRING, + getQueryIteratorExpr, + QueryIteratorReq, + getRangeFromSearchResult, + SearchResultData, + getPKFieldExpr, + DEFAULT_MAX_SEARCH_SIZE, SparseFloatVector, sparseRowsToBytes, getSparseDim, @@ -478,6 +486,268 @@ export class Data extends Collection { }; } + // async searchIterator(data: SearchIteratorReq): Promise { + // // store client + // const client = this; + // // get collection info + // const pkField = await this.getPkField(data); + // // get available count + // const count = await client.count({ + // collection_name: data.collection_name, + // expr: data.expr || data.filter || '', + // }); + // // make sure limit is not exceed the total count + // const total = data.limit > count.data ? count.data : data.limit; + // // make sure batch size is exceed the total count + // let batchSize = data.batchSize > total ? total : data.batchSize; + // // make sure batch size is not exceed max search size + // batchSize = + // batchSize > DEFAULT_MAX_SEARCH_SIZE ? DEFAULT_MAX_SEARCH_SIZE : batchSize; + + // // init expr + // const initExpr = data.expr || data.filter || ''; + // // init search params object + // data.params = data.params || {}; + // data.limit = batchSize; + + // // user range filter set + // const initRadius = Number(data.params.radius) || 0; + // const initRangeFilter = Number(data.params.range_filter) || 0; + // // range params object + // const rangeFilterParams = { + // radius: initRadius, + // rangeFilter: initRangeFilter, + // expr: initExpr, + // }; + + // // force quite if true, at first, if total is 0, return done + // let done = total === 0; + // // batch result store + // let lastBatchRes: SearchResultData[] = []; + + // // build cache + // const cache = await client.search({ + // ...data, + // limit: total > DEFAULT_MAX_SEARCH_SIZE ? DEFAULT_MAX_SEARCH_SIZE : total, + // }); + + // return { + // currentTotal: 0, + // [Symbol.asyncIterator]() { + // return { + // currentTotal: this.currentTotal, + // async next() { + // // check if reach the limit + // if ( + // (this.currentTotal >= total && this.currentTotal !== 0) || + // done + // ) { + // return { done: true, value: lastBatchRes }; + // } + + // // batch result container + // const batchRes: SearchResultData[] = []; + // const bs = + // this.currentTotal + batchSize > total + // ? total - this.currentTotal + // : batchSize; + + // // keep getting search data if not reach the batch size + // while (batchRes.length < bs) { + // // search results container + // let searchResults: SearchResults = { + // status: { error_code: 'SUCCESS', reason: '' }, + // results: [], + // }; + + // // Iterate through the cached data, adding it to the search results container until the batch size is reached. 
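+              // Overall flow of this draft: drain the pre-fetched cache first; once the cache is
+              // empty, re-issue client.search() with an updated radius / range_filter window and a
+              // pk-exclusion expr, then de-duplicate against ids returned in earlier batches so
+              // that each yielded batch only contains new hits.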
+ // if (cache.results.length > 0) { + // while ( + // cache.results.length > 0 && + // searchResults.results.length < bs + // ) { + // searchResults.results.push(cache.results.shift()!); + // } + // } else if (searchResults.results.length < bs) { + // // build search params, overwrite range filter + // if (rangeFilterParams.radius && rangeFilterParams.rangeFilter) { + // data.params = { + // ...data.params, + // radius: rangeFilterParams.radius, + // range_filter: + // rangeFilterParams.rangeFilter, + // }; + // } + // // set search expr + // data.expr = rangeFilterParams.expr; + + // console.log('search param', data.params, data.expr); + + // // iterate search, if no result, double the radius, until we doubled for 5 times + // let newSearchRes = await client.search(data); + // let retry = 0; + // while (newSearchRes.results.length === 0 && retry < 5) { + // newSearchRes = await client.search(data); + // if (searchResults.results.length === 0) { + // const newRadius = rangeFilterParams.radius * 2; + + // data.params = { + // ...data.params, + // radius: newRadius, + // }; + // } + + // retry++; + // } + + // // combine search results + // searchResults.results = [ + // ...searchResults.results, + // ...newSearchRes.results, + // ]; + // } + + // console.log('return', searchResults.results); + + // // filter result, batchRes should be unique + // const filterResult = searchResults.results.filter( + // r => + // !lastBatchRes.some(l => l.id === r.id) && + // !batchRes.some(c => c.id === r.id) + // ); + + // // fill filter result to batch result, it should not exceed the batch size + // for (let i = 0; i < filterResult.length; i++) { + // if (batchRes.length < bs) { + // batchRes.push(filterResult[i]); + // } + // } + + // // get data range about last batch result + // const resultRange = getRangeFromSearchResult(filterResult); + + // console.log('result range', resultRange); + + // // if no more result, force quite + // if (resultRange.lastDistance === 0) { + // done = true; + // return { done: false, value: batchRes }; + // } + + // // update next range and expr + // rangeFilterParams.rangeFilter = resultRange.lastDistance; + // rangeFilterParams.radius = + // rangeFilterParams.radius + resultRange.radius; + // rangeFilterParams.expr = getPKFieldExpr({ + // pkField, + // value: resultRange.id as string, + // expr: initExpr, + // }); + + // console.log('last', rangeFilterParams); + // } + + // // store last result + // lastBatchRes = batchRes; + + // // update current total + // this.currentTotal += batchRes.length; + + // // return batch result + // return { done: false, value: batchRes }; + // }, + // }; + // }, + // }; + // } + + /** + * Executes a query and returns an async iterator that allows iterating over the results in batches. + * + * @param {QueryIteratorReq} data - The query iterator request data. + * @returns {Promise} - An async iterator that yields batches of query results. + * @throws {Error} - If an error occurs during the query execution. 
+ * + * @example + * const queryData = { + * collection_name: 'my_collection', + * expr: 'age > 30', + * limit: 100, + * pageSize: 10 + * }; + * + * const iterator = await queryIterator(queryData); + * + * for await (const batch of iterator) { + * console.log(batch); // Process each batch of query results + * } + */ + async queryIterator(data: QueryIteratorReq): Promise { + // get collection info + const pkField = await this.getPkField(data); + // store client; + const client = this; + // expr + const userExpr = data.expr || data.filter || ''; + // get count + const count = await client.count({ + collection_name: data.collection_name, + expr: userExpr, + }); + // total should be the minimum of total and count + const total = data.limit > count.data ? count.data : data.limit; + const batchSize = + data.batchSize > DEFAULT_MAX_SEARCH_SIZE + ? DEFAULT_MAX_SEARCH_SIZE + : data.batchSize; + + // local variables + let expr = userExpr; + let lastBatchRes: Record = []; + let lastPKId: string | number = ''; + let currentBatchSize = batchSize; // Store the current batch size + + // return iterator + return { + currentTotal: 0, + [Symbol.asyncIterator]() { + return { + currentTotal: this.currentTotal, + async next() { + // if reach the limit, return done + if (this.currentTotal >= total) { + return { done: true, value: lastBatchRes }; + } + // set limit for current batch + data.limit = currentBatchSize; // Use the current batch size + + // get current page expr + data.expr = getQueryIteratorExpr({ + expr: expr, + pkField, + lastPKId, + }); + + // search data + const res = await client.query(data); + + // get last item of the data + const lastItem = res.data[res.data.length - 1]; + // update last pk id + lastPKId = lastItem && lastItem[pkField.name]; + + // store last batch result + lastBatchRes = res.data; + // update current total + this.currentTotal += lastBatchRes.length; + // Update the current batch size based on remaining data + currentBatchSize = Math.min(batchSize, total - this.currentTotal); + return { done: false, value: lastBatchRes }; + }, + }; + }, + }; + } // alias hybridSearch = this.search; diff --git a/milvus/types/Data.ts b/milvus/types/Data.ts index a59c5c0e..a94a8a8c 100644 --- a/milvus/types/Data.ts +++ b/milvus/types/Data.ts @@ -216,7 +216,7 @@ export interface MutationResult extends resStatusResponse { } export interface QueryResults extends resStatusResponse { - data: { [x: string]: any }[]; + data: Record[]; } export interface CountResult extends resStatusResponse { @@ -284,6 +284,16 @@ export interface SearchReq extends collectionNameReq { transformers?: OutputTransformers; // provide custom data transformer for specific data type like bf16 or f16 vectors } +export interface SearchIteratorReq + extends Omit< + SearchSimpleReq, + 'data' | 'vectors' | 'offset' | 'limit' | 'topk' + > { + data: number[]; // data to search + batchSize: number; + limit: number; +} + // simplified search api parameter type export interface SearchSimpleReq extends collectionNameReq { partition_names?: string[]; // partition names @@ -393,6 +403,12 @@ export interface QueryReq extends collectionNameReq { transformers?: OutputTransformers; // provide custom data transformer for specific data type like bf16 or f16 vectors } +export interface QueryIteratorReq + extends Omit { + limit: number; + batchSize: number; +} + export interface GetReq extends collectionNameReq { ids: string[] | number[]; // primary key values output_fields?: string[]; // fields to return diff --git a/milvus/utils/Function.ts 
b/milvus/utils/Function.ts
index 61951d30..d85a8ddf 100644
--- a/milvus/utils/Function.ts
+++ b/milvus/utils/Function.ts
@@ -1,4 +1,13 @@
-import { KeyValuePair, DataType, ERROR_REASONS, SparseFloatVector } from '../';
+import {
+  KeyValuePair,
+  DataType,
+  ERROR_REASONS,
+  FieldSchema,
+  DataTypeStringEnum,
+  DEFAULT_MIN_INT64,
+  SearchResultData,
+  SparseFloatVector,
+} from '../';
 import { Pool } from 'generic-pool';
 
 /**
@@ -139,6 +148,80 @@ export const getDataKey = (type: DataType, camelCase: boolean = false) => {
   return camelCase ? convertToCamelCase(dataKey) : dataKey;
 };
 
+/**
+ * Returns the query iterator expression based on the provided parameters.
+ *
+ * @param params - The parameters for generating the query iterator expression.
+ * @param params.expr - The expression to be combined with the iterator expression.
+ * @param params.pkField - The primary key field schema.
+ * @param params.lastPKId - The primary key value of the last item returned by the
+ *   previous batch; used as the cursor for the next batch.
+ * @returns The query iterator expression.
+ */
+export const getQueryIteratorExpr = (params: {
+  expr: string;
+  pkField: FieldSchema;
+  lastPKId: string | number;
+}) => {
+  // get params
+  const { expr, lastPKId, pkField } = params;
+
+  // If lastPKId is not set yet, seed the comparison value based on the primary key type
+  let compareValue = '';
+  if (!lastPKId) {
+    // get default value
+    compareValue =
+      pkField?.data_type === DataTypeStringEnum.VarChar
+        ? ''
+        : `${DEFAULT_MIN_INT64}`;
+  } else {
+    compareValue = lastPKId as string;
+  }
+
+  // return expr combined with iteratorExpr
+  return getPKFieldExpr({
+    pkField,
+    value: compareValue,
+    expr,
+    condition: '>',
+  });
+};
+
+// return the distance range between the first and the last item of the given search results
+export const getRangeFromSearchResult = (results: SearchResultData[]) => {
+  // get first item
+  const firstItem = results[0];
+  const lastItem = results[results.length - 1];
+
+  if (firstItem && lastItem) {
+    const radius = lastItem.score * 2 - firstItem.score;
+    return {
+      radius: radius,
+      lastDistance: lastItem.score,
+      id: lastItem.id,
+    };
+  } else {
+    return {
+      radius: 0,
+      lastDistance: 0,
+    };
+  }
+};
+
+// return a `pk field != value` expression based on the pk field type; if the pk field is a VarChar, the value is quoted
+export const getPKFieldExpr = (data: {
+  pkField: FieldSchema;
+  value: string | number;
+  condition?: string;
+  expr?: string;
+}) => {
+  const { pkField, value, condition = '!=', expr = '' } = data;
+  const pkValue =
+    pkField?.data_type === DataTypeStringEnum.VarChar
+      ? `'${value}'`
+      : `${value}`;
+  return `${pkField?.name} ${condition} ${pkValue}${expr ? ` && ${expr}` : ''}`;
+};
+
 // get biggest size of sparse vector array
 export const getSparseDim = (data: SparseFloatVector[]) => {
   let dim = 0;
diff --git a/milvus/utils/Grpc.ts b/milvus/utils/Grpc.ts
index 721fd90d..28459b89 100644
--- a/milvus/utils/Grpc.ts
+++ b/milvus/utils/Grpc.ts
@@ -193,13 +193,16 @@ export const getRetryInterceptor = ({
         newCall.start(savedMetadata, retryListener);
         newCall.sendMessage(savedSendMessage);
       } else {
+        const string = JSON.stringify(savedReceiveMessage);
+        const msg =
+          string.length > 2048 ? string.slice(0, 2048) + '...'
: string; + logger.debug( `\x1b[32m[Response(${ Date.now() - startTime.getTime() - }ms)]\x1b[0m\x1b[2m${clientId}\x1b[0m>${dbname}>\x1b[1m${methodName}\x1b[0m: ${JSON.stringify( - savedReceiveMessage - )}` + }ms)]\x1b[0m\x1b[2m${clientId}\x1b[0m>${dbname}>\x1b[1m${methodName}\x1b[0m: ${msg}` ); + savedMessageNext(savedReceiveMessage); savedStatusNext(status); } diff --git a/test/grpc/Iterator.spec.ts b/test/grpc/Iterator.spec.ts new file mode 100644 index 00000000..b0fd5db4 --- /dev/null +++ b/test/grpc/Iterator.spec.ts @@ -0,0 +1,508 @@ +import { MilvusClient, DataType } from '../../milvus'; +import { + IP, + genCollectionParams, + GENERATE_NAME, + generateInsertData, + dynamicFields, +} from '../tools'; + +const milvusClient = new MilvusClient({ address: IP, logLevel: 'info' }); +const COLLECTION = GENERATE_NAME(); +const COLLECTION_COSINE = GENERATE_NAME(); +const dbParam = { + db_name: 'Iterator_test_db', +}; +const numPartitions = 3; + +// create +const createCollectionParams = genCollectionParams({ + collectionName: COLLECTION, + dim: [4], + vectorType: [DataType.FloatVector], + autoID: false, + partitionKeyEnabled: true, + numPartitions, + enableDynamic: true, +}); + +const createCosineCollectionParams = genCollectionParams({ + collectionName: COLLECTION_COSINE, + dim: [4], + vectorType: [DataType.FloatVector], + autoID: false, + partitionKeyEnabled: true, + numPartitions, + enableDynamic: false, +}); +// data to insert +const data = generateInsertData( + [...createCollectionParams.fields, ...dynamicFields], + 20000 +); + +const cosineData = generateInsertData( + [...createCosineCollectionParams.fields], + 20000 +); + +// id map for faster test +const dataMap = new Map(data.map(item => [item.id.toString(), item])); + +describe(`Iterator API`, () => { + beforeAll(async () => { + // create db and use db + await milvusClient.createDatabase(dbParam); + await milvusClient.use(dbParam); + // create collection + await milvusClient.createCollection(createCollectionParams); + await milvusClient.createCollection(createCosineCollectionParams); + // insert data + await milvusClient.insert({ + collection_name: COLLECTION, + fields_data: data, + }); + await milvusClient.insert({ + collection_name: COLLECTION_COSINE, + fields_data: cosineData, + }); + + // create index + await milvusClient.createIndex({ + collection_name: COLLECTION, + index_name: 't', + field_name: 'vector', + index_type: 'IVF_FLAT', + metric_type: 'L2', + params: { nlist: 1024 }, + }); + + await milvusClient.createIndex({ + collection_name: COLLECTION_COSINE, + index_name: 't', + field_name: 'vector', + index_type: 'IVF_FLAT', + metric_type: 'COSINE', + params: { nlist: 1024 }, + }); + + // load collection + await milvusClient.loadCollectionSync({ + collection_name: COLLECTION, + }); + await milvusClient.loadCollectionSync({ + collection_name: COLLECTION_COSINE, + }); + }); + + afterAll(async () => { + await milvusClient.dropCollection({ + collection_name: COLLECTION, + }); + await milvusClient.dropCollection({ + collection_name: COLLECTION_COSINE, + }); + await milvusClient.dropDatabase(dbParam); + }); + + it(`query iterator with batch size = 1 should success`, async () => { + // page size + const batchSize = 1; + const total = 10; + const iterator = await milvusClient.queryIterator({ + collection_name: COLLECTION, + batchSize: batchSize, + expr: 'id > 0', + output_fields: ['id'], + limit: total, + }); + + const results: any = []; + let page = 0; + for await (const value of iterator) { + results.push(...value); + page += 1; + } 
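+    // Under the hood each batch re-queries with a pk cursor built by getQueryIteratorExpr:
+    // the first batch is seeded with DEFAULT_MIN_INT64 (for an Int64 pk named `id` this is
+    // roughly `id > -9223372036854775807 && id > 0`), and every following batch swaps the
+    // seed for the last pk returned, e.g. `id > <last id of previous batch> && id > 0`.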
+ + // page size should equal to page + expect(page).toEqual(Math.ceil(total / batchSize)); + // results length should equal to data length + expect(results.length).toEqual(total); + + // results id should be unique + const idSet = new Set(); + results.forEach((result: any) => { + idSet.add(result.id); + }); + expect(idSet.size).toEqual(total); + + // every id in query result should be founded in the original data + results.forEach((result: any) => { + const item = dataMap.get(result.id.toString()); + expect(typeof item !== 'undefined').toEqual(true); + }); + }); + + it(`query iterator with batch size > 16384 should success`, async () => { + // page size + const batchSize = 16384; + const total = data.length; + const iterator = await milvusClient.queryIterator({ + collection_name: COLLECTION, + batchSize: batchSize, + expr: 'id > 0', + output_fields: ['id'], + limit: total, + }); + + const results: any = []; + let page = 0; + for await (const value of iterator) { + results.push(...value); + page += 1; + } + + // page size should equal to page + expect(page).toEqual(Math.ceil(total / batchSize)); + // results length should equal to data length + expect(results.length).toEqual(total); + + // results id should be unique + const idSet = new Set(); + results.forEach((result: any) => { + idSet.add(result.id); + }); + expect(idSet.size).toEqual(total); + + // every id in query result should be founded in the original data + results.forEach((result: any) => { + const item = dataMap.get(result.id.toString()); + expect(typeof item !== 'undefined').toEqual(true); + }); + }); + + it(`query iterator with batch size > total should success`, async () => { + // page size + const batchSize = data.length + 1; + const total = data.length; + const iterator = await milvusClient.queryIterator({ + collection_name: COLLECTION, + batchSize: batchSize, + expr: 'id > 0', + output_fields: ['id'], + limit: total, + }); + + const results: any = []; + let page = 0; + for await (const value of iterator) { + results.push(...value); + page += 1; + } + + // page size should equal to page + expect(page).toEqual( + batchSize > 16384 + ? 
Math.ceil(total / 16384) + : Math.ceil(total / batchSize) + ); + // results length should equal to data length + expect(results.length).toEqual(total); + + // results id should be unique + const idSet = new Set(); + results.forEach((result: any) => { + idSet.add(result.id); + }); + expect(idSet.size).toEqual(total); + + // every id in query result should be founded in the original data + results.forEach((result: any) => { + const item = dataMap.get(result.id.toString()); + expect(typeof item !== 'undefined').toEqual(true); + }); + }); + + it(`query iterator with limit < total should success`, async () => { + // page size + const batchSize = 2; + const total = 10; + const iterator = await milvusClient.queryIterator({ + collection_name: COLLECTION, + batchSize: batchSize, + expr: 'id > 0', + output_fields: ['id'], + limit: total, + }); + + const results: any = []; + let page = 0; + for await (const value of iterator) { + results.push(...value); + page += 1; + } + + // page size should equal to page + expect(page).toEqual(Math.ceil(total / batchSize)); + // results length should equal to data length + expect(results.length).toEqual(total); + + // results id should be unique + const idSet = new Set(); + results.forEach((result: any) => { + idSet.add(result.id); + }); + expect(idSet.size).toEqual(total); + + // every id in query result should be founded in the original data + results.forEach((result: any) => { + const item = dataMap.get(result.id.toString()); + expect(typeof item !== 'undefined').toEqual(true); + }); + }); + + it(`query iterator with limit > total should success`, async () => { + // page size + const batchSize = 5000; + const total = 30000; + const iterator = await milvusClient.queryIterator({ + collection_name: COLLECTION, + batchSize: batchSize, + expr: 'id > 0', + output_fields: ['id'], + limit: total, + }); + + const results: any = []; + let page = 0; + for await (const value of iterator) { + results.push(...value); + page += 1; + } + + // page size should equal to page + expect(page).toEqual(Math.ceil(data.length / batchSize)); + // results length should equal to data length + expect(results.length).toEqual(data.length); + + // results id should be unique + const idSet = new Set(); + results.forEach((result: any) => { + idSet.add(result.id); + }); + expect(idSet.size).toEqual(data.length); + + // every id in query result should be founded in the original data + results.forEach((result: any) => { + const item = dataMap.get(result.id.toString()); + expect(typeof item !== 'undefined').toEqual(true); + }); + }); + + // it('search iterator with batch size = total should success', async () => { + // const batchSize = 100; + // const total = 100; + // const iterator = await milvusClient.searchIterator({ + // collection_name: COLLECTION, + // batchSize: batchSize, + // data: data[0].vector, + // expr: 'id > 0', + // output_fields: ['id'], + // limit: total, + // }); + + // const results: any = []; + // // let batch = 0; + // for await (const value of iterator) { + // // console.log(`batch${batch++}`, value.length); + // // console.log(value.map((item: any) => item.score)); + // results.push(...value); + // } + + // // results id should be unique + // const idSet = new Set(); + // results.forEach((result: any) => { + // idSet.add(result.id); + // }); + // expect(idSet.size).toEqual(total); + // }); + + // it('search iterator with batch size > total should success', async () => { + // const batchSize = 200; + // const total = 100; + // const iterator = await 
milvusClient.searchIterator({ + // collection_name: COLLECTION, + // batchSize: batchSize, + // data: data[0].vector, + // expr: 'id > 0', + // output_fields: ['id'], + // limit: total, + // }); + + // const results: any = []; + // // let batch = 0; + // for await (const value of iterator) { + // // console.log(`batch${batch++}`, value.length); + // // console.log(value.map((item: any) => item.score)); + // results.push(...value); + // } + + // // results id should be unique + // const idSet = new Set(); + // results.forEach((result: any) => { + // idSet.add(result.id); + // }); + // expect(idSet.size).toEqual(total); + // }); + + // it('search iterator with batch size < total should success', async () => { + // const batchSize = 33; + // const total = 100; + // const iterator = await milvusClient.searchIterator({ + // collection_name: COLLECTION, + // batchSize: batchSize, + // data: data[0].vector, + // expr: 'id > 0', + // output_fields: ['id'], + // limit: total, + // }); + + // const results: any = []; + // let batchTimes = 0; + // for await (const value of iterator) { + // // console.log(`batch${batch++}`, value.length); + // // console.log(value.map((item: any) => item.score)); + // batchTimes++; + // results.push(...value); + // } + // expect(batchTimes).toEqual(Math.ceil(total / batchSize)); + + // // results id should be unique + // const idSet = new Set(); + // results.forEach((result: any) => { + // idSet.add(result.id); + // }); + // expect(idSet.size).toEqual(total); + // }); + + // it('search iterator with batch size = 2 should success, and ignore total', async () => { + // const batchSize = 2; + // const total = 10; + // const iterator = await milvusClient.searchIterator({ + // collection_name: COLLECTION, + // batchSize: batchSize, + // data: [0.1, 0.2, 0.3, 0.4], + // expr: 'id > 0', + // output_fields: ['id'], + // limit: total, + // }); + + // const results: any = []; + // // let batch = 0; + // for await (const value of iterator) { + // // console.log(`batch${batch++}`, value.length); + // // console.log(value.map((item: any) => item.score)); + // results.push(...value); + // } + + // expect(results.length).toEqual(total); + + // // results id should be unique + // const idSet = new Set(); + // results.forEach((result: any) => { + // idSet.add(result.id); + // }); + // expect(idSet.size).toEqual(total); + // }); + + // it('search iterator with batch size = 1 should success, and ignore total', async () => { + // const search = await milvusClient.search({ + // collection_name: COLLECTION, + // data: [0.1, 0.2, 0.3, 0.4], + // expr: 'id > 0', + // output_fields: ['id'], + // limit: 10, + // }); + + // const batchSize = 1; + // const total = 10; + // const iterator = await milvusClient.searchIterator({ + // collection_name: COLLECTION, + // batchSize: batchSize, + // data: [0.1, 0.2, 0.3, 0.4], + // expr: 'id > 0', + // output_fields: ['id'], + // limit: total, + // }); + + // const results: any = []; + // // let batch = 0; + // for await (const value of iterator) { + // // console.log(`batch${batch++}`, value.length); + // // console.log(value.map((item: any) => item.score)); + // results.push(...value); + // } + + // expect(results.length).toEqual(total); + + // // results id should be unique + // const idSet = new Set(); + // results.forEach((result: any) => { + // idSet.add(result.id); + // }); + // expect(idSet.size).toEqual(total); + + // // get result scores + // const scores = results.map((result: any) => result.score); + + // // compare with search result, should 
be equal + // expect(scores).toEqual(search.results.map(s => s.score)); + // }); + + // it('search iterator with limit > all data count should success, and ignore total', async () => { + // const batchSize = 5000; + // const total = 30000; + // const iterator = await milvusClient.searchIterator({ + // collection_name: COLLECTION, + // batchSize: batchSize, + // data: data[0].vector, + // expr: 'id > 0', + // output_fields: ['id'], + // limit: total, + // }); + + // const results: any = []; + // // let batch = 0; + // for await (const value of iterator) { + // // console.log(`batch${batch++}`, value.length); + // // console.log(value.map((item: any) => item.score)); + // results.push(...value); + // } + + // expect(results.length).toEqual(data.length); + + // // results id should be unique + // const idSet = new Set(); + // results.forEach((result: any) => { + // idSet.add(result.id); + // }); + // expect(idSet.size).toEqual(data.length); + // }); + + // it('search with cosine similarity should success', async () => { + // const batchSize = 10000; + // const total = 2000; + // const searchRes = await milvusClient.search({ + // collection_name: COLLECTION_COSINE, + // data: cosineData[0].vector, + // expr: 'id > 0', + // output_fields: ['id'], + // limit: total, + // params: { + // radius: 0.6, + // range_filter: 0.5, + // }, + // }); + // console.log(searchRes.results.map(s => s.score)); + // }); +}); diff --git a/test/utils/Function.spec.ts b/test/utils/Function.spec.ts index 793a6082..d742657a 100644 --- a/test/utils/Function.spec.ts +++ b/test/utils/Function.spec.ts @@ -1,12 +1,18 @@ import { promisify, + getQueryIteratorExpr, + DataTypeStringEnum, + DEFAULT_MIN_INT64, + getPKFieldExpr, + getRangeFromSearchResult, + SearchResultData, getSparseDim, SparseFloatVector, getDataKey, DataType, } from '../../milvus'; -describe('promisify', () => { +describe('Function API testing', () => { let pool: any; let client: any; @@ -53,6 +59,229 @@ describe('promisify', () => { expect(pool.release).toHaveBeenCalled(); }); + it('should return varchar expression when cache does not exist', () => { + const params = { + expr: '', + pkField: { + name: 'id', + data_type: DataTypeStringEnum.VarChar, + }, + lastPkId: '', + } as any; + + const result = getQueryIteratorExpr(params); + + expect(result).toBe("id > ''"); + }); + + it('should return varchar expression when cache exists', () => { + const params = { + expr: '', + pkField: { + name: 'id', + data_type: DataTypeStringEnum.VarChar, + }, + lastPKId: 'abc', + } as any; + + const result = getQueryIteratorExpr(params); + + expect(result).toBe("id > 'abc'"); + }); + + it('should return varchar expression combined with iteratorExpr when expr is provided', () => { + const params = { + expr: 'field > 10', + pkField: { + name: 'id', + data_type: DataTypeStringEnum.VarChar, + }, + page: 1, + lastPkId: '', + } as any; + + const result = getQueryIteratorExpr(params); + + expect(result).toBe("id > '' && field > 10"); + }); + + it('should return int64 expression when cache does not exist', () => { + const params = { + expr: '', + pkField: { + name: 'id', + data_type: DataTypeStringEnum.Int64, + }, + lastPkId: '', + } as any; + + const result = getQueryIteratorExpr(params); + + expect(result).toBe(`id > ${DEFAULT_MIN_INT64}`); + }); + + it('should return int64 expression when cache exists', () => { + const params = { + expr: '', + pkField: { + name: 'id', + data_type: DataTypeStringEnum.Int64, + }, + lastPKId: 10, + } as any; + + const result = 
getQueryIteratorExpr(params); + + expect(result).toBe('id > 10'); + }); + + it('should return int64 expression combined with iteratorExpr when expr is provided and cache exists', () => { + const params = { + expr: 'field > 10', + pkField: { + name: 'id', + data_type: DataTypeStringEnum.Int64, + }, + lastPKId: 10, + } as any; + + const result = getQueryIteratorExpr(params); + + expect(result).toBe('id > 10 && field > 10'); + }); + + it('should return 0 radius when results are empty', () => { + const results = [] as any; + + const result = getRangeFromSearchResult(results); + + expect(result).toEqual({ + radius: 0, + lastDistance: 0, + }); + }); + + it('should return radius and lastDistance when results are not empty', () => { + const results: SearchResultData[] = [ + { + id: '1', + score: 0.1, + }, + { + id: '2', + score: 0.2, + }, + { + id: '3', + score: 0.3, + }, + ]; + + const result = getRangeFromSearchResult(results); + + expect(result).toEqual({ + radius: 0.3 * 2 - 0.1, + lastDistance: 0.3, + id: '3', + }); + }); + + it('should return 0 radius when results contain only one item', () => { + const results: SearchResultData[] = [ + { + id: '1', + score: 0.1, + }, + ]; + + const result = getRangeFromSearchResult(results); + + expect(result).toEqual({ + radius: 0.1 * 2 - 0.1, + lastDistance: 0.1, + id: '1', + }); + }); + + it('should return 0 radius when results contain only two items', () => { + const results: SearchResultData[] = [ + { + id: '1', + score: 0.1, + }, + { + id: '2', + score: 0.2, + }, + ]; + + const result = getRangeFromSearchResult(results); + + expect(result).toEqual({ + radius: 0.2 * 2 - 0.1, + lastDistance: 0.2, + id: '2', + }); + }); + + it('should return varchar expression when pk field is varchar', () => { + const pkField: any = { + name: 'id', + data_type: DataTypeStringEnum.VarChar, + }; + + const result = getPKFieldExpr({ + pkField, + value: 'abc', + }); + + expect(result).toBe("id != 'abc'"); + }); + + it('should return int64 expression when pk field is int64', () => { + const pkField: any = { + name: 'id', + data_type: DataTypeStringEnum.Int64, + }; + + const result = getPKFieldExpr({ + pkField, + value: 10, + }); + + expect(result).toBe('id != 10'); + }); + + it('should return int64 expression with condition when condition is provided', () => { + const pkField: any = { + name: 'id', + data_type: DataTypeStringEnum.Int64, + }; + + const result = getPKFieldExpr({ + pkField, + value: 10, + condition: '>', + }); + + expect(result).toBe('id > 10'); + }); + + it('should return int64 expression with condition and expr when expr is provided', () => { + const pkField: any = { + name: 'id', + data_type: DataTypeStringEnum.Int64, + }; + + const result = getPKFieldExpr({ + pkField, + value: 10, + condition: '>', + expr: 'field > 10', + }); + + expect(result).toBe('id > 10 && field > 10'); + }); it('should return the correct dimension of the sparse vector', () => { const data = [ { '0': 1, '1': 2, '2': 3 },