diff --git a/examples/orm/iterator.py b/examples/orm/iterator.py index ed7d2bf54..9a563ff0e 100644 --- a/examples/orm/iterator.py +++ b/examples/orm/iterator.py @@ -1,5 +1,6 @@ import numpy as np import random +import logging from pymilvus import ( connections, utility, @@ -17,9 +18,13 @@ PICTURE = "picture" CONSISTENCY_LEVEL = "Eventually" LIMIT = 5 -NUM_ENTITIES = 1000 +NUM_ENTITIES = 10000 DIM = 8 -CLEAR_EXIST = False +CLEAR_EXIST = True + +# Create a logger for the main script +log = logging.getLogger("pymilvus") +log.setLevel(logging.INFO) def re_create_collection(skip_data_period: bool): @@ -29,8 +34,7 @@ def re_create_collection(skip_data_period: bool): print(f"dropped existed collection{COLLECTION_NAME}") fields = [ - FieldSchema(name=USER_ID, dtype=DataType.VARCHAR, is_primary=True, - auto_id=False, max_length=MAX_LENGTH), + FieldSchema(name=USER_ID, dtype=DataType.INT64, is_primary=True, auto_id=False), FieldSchema(name=AGE, dtype=DataType.INT64), FieldSchema(name=DEPOSIT, dtype=DataType.DOUBLE), FieldSchema(name=PICTURE, dtype=DataType.FLOAT_VECTOR, dim=DIM) @@ -58,10 +62,9 @@ def random_pk(filter_set: set, lower_bound: int, upper_bound: int) -> str: def insert_data(collection): rng = np.random.default_rng(seed=19530) batch_count = 5 - filter_set: set = {} for i in range(batch_count): entities = [ - [random_pk(filter_set, 0, batch_count * NUM_ENTITIES) for _ in range(NUM_ENTITIES)], + [i for i in range(NUM_ENTITIES*i, NUM_ENTITIES*(i + 1))], [int(ni % 100) for ni in range(NUM_ENTITIES)], [float(ni) for ni in range(NUM_ENTITIES)], rng.random((NUM_ENTITIES, DIM)), @@ -150,6 +153,22 @@ def query_iterate_collection_with_offset(collection): print(f"page{page_idx}-------------------------") +def query_iterate_collection_with_large_offset(collection): + query_iterator = collection.query_iterator(output_fields=[USER_ID, AGE], + offset=48000, batch_size=50, consistency_level=CONSISTENCY_LEVEL) + page_idx = 0 + while True: + res = query_iterator.next() + if len(res) == 0: + print("query iteration finished, close") + query_iterator.close() + break + for i in range(len(res)): + print(res[i]) + page_idx += 1 + print(f"page{page_idx}-------------------------") + + def query_iterate_collection_with_limit(collection): expr = f"10 <= {AGE} <= 44" query_iterator = collection.query_iterator(expr=expr, output_fields=[USER_ID, AGE], @@ -167,6 +186,8 @@ def query_iterate_collection_with_limit(collection): print(f"page{page_idx}-------------------------") + + def search_iterator_collection(collection): SEARCH_NQ = 1 DIM = 8 @@ -216,11 +237,12 @@ def search_iterator_collection_with_limit(collection): def main(): - skip_data_period = False + skip_data_period = True connections.connect("default", host=HOST, port=PORT) collection = re_create_collection(skip_data_period) if not skip_data_period: collection = prepare_data(collection) + query_iterate_collection_with_large_offset(collection) query_iterate_collection_no_offset(collection) query_iterate_collection_with_offset(collection) query_iterate_collection_with_limit(collection) diff --git a/pymilvus/orm/iterator.py b/pymilvus/orm/iterator.py index 702d744a1..fef5dead2 100644 --- a/pymilvus/orm/iterator.py +++ b/pymilvus/orm/iterator.py @@ -1,4 +1,5 @@ import logging +import time from copy import deepcopy from typing import Any, Dict, List, Optional, TypeVar, Union @@ -39,8 +40,7 @@ from .schema import CollectionSchema from .types import DataType -LOGGER = logging.getLogger(__name__) -LOGGER.setLevel(logging.ERROR) +log = logging.getLogger(__name__) QueryIterator = TypeVar("QueryIterator") SearchIterator = TypeVar("SearchIterator") @@ -77,19 +77,16 @@ def __init__( self._schema = schema self._timeout = timeout self._kwargs = kwargs - self.__set_up_iteration_states() self.__check_set_batch_size(batch_size) self._limit = limit self.__check_set_reduce_stop_for_best() self._returned_count = 0 self.__setup__pk_prop() self.__set_up_expr(expr) + self._next_id = None self.__seek() self._cache_id_in_use = NO_CACHE_ID - def __set_up_iteration_states(self): - self._kwargs[ITERATOR_FIELD] = "True" - def __check_set_reduce_stop_for_best(self): if self._kwargs.get(REDUCE_STOP_FOR_BEST, True): self._kwargs[REDUCE_STOP_FOR_BEST] = "True" @@ -115,25 +112,47 @@ def __set_up_expr(self, expr: str): def __seek(self): self._cache_id_in_use = NO_CACHE_ID - if self._kwargs.get(OFFSET, 0) == 0: - self._next_id = None - return - - first_cursor_kwargs = self._kwargs.copy() - first_cursor_kwargs[OFFSET] = 0 - # offset may be too large, needed to seek in multiple times - first_cursor_kwargs[MILVUS_LIMIT] = self._kwargs[OFFSET] - - res = self._conn.query( - collection_name=self._collection_name, - expr=self._expr, - output_field=self._output_fields, - partition_name=self._partition_names, - timeout=self._timeout, - **first_cursor_kwargs, - ) - self.__update_cursor(res) - self._kwargs[OFFSET] = 0 + offset = self._kwargs.get(OFFSET, 0) + if offset > 0: + seek_params = self._kwargs.copy() + seek_params[OFFSET] = 0 + # offset may be too large, needed to seek in multiple times + seek_params[ITERATOR_FIELD] = "False" + seek_params[REDUCE_STOP_FOR_BEST] = "False" + start_time = time.time() + + def seek_offset_by_batch(batch: int, expr: str) -> int: + seek_params[MILVUS_LIMIT] = batch + res = self._conn.query( + collection_name=self._collection_name, + expr=expr, + output_field=[], + partition_name=self._partition_names, + timeout=self._timeout, + **seek_params, + ) + self.__update_cursor(res) + return len(res) + + while offset > 0: + batch_size = min(MAX_BATCH_SIZE, offset) + next_expr = self.__setup_next_expr() + seeked_count = seek_offset_by_batch(batch_size, next_expr) + log.debug( + f"seeked offset, seek_expr:{next_expr} batch_size:{batch_size} seeked_count:{seeked_count}" + ) + if seeked_count == 0: + log.info( + "seek offset has drained all matched results for query iterator, break" + ) + break + offset -= seeked_count + self._kwargs[OFFSET] = 0 + seek_offset_duration = time.time() - start_time + log.info( + f"Finish seek offset for query iterator, offset:{offset}, current_pk_cursor:{self._next_id}, " + f"duration:{seek_offset_duration}" + ) def __maybe_cache(self, result: List): if len(result) < 2 * self._kwargs[BATCH_SIZE]: @@ -156,6 +175,7 @@ def next(self): else: iterator_cache.release_cache(self._cache_id_in_use) current_expr = self.__setup_next_expr() + log.debug(f"query_iterator_next_expr:{current_expr}") res = self._conn.query( collection_name=self._collection_name, expr=current_expr, @@ -194,7 +214,7 @@ def __setup__pk_prop(self): if self._pk_field_name is None or self._pk_field_name == "": raise MilvusException(message="schema must contain pk field, broke") - def __setup_next_expr(self) -> None: + def __setup_next_expr(self) -> str: current_expr = self._expr if self._next_id is None: return current_expr @@ -339,7 +359,7 @@ def __init_search_iterator(self): "Cannot init search iterator because init page contains no matched rows, " "please check the radius and range_filter set up by searchParams" ) - LOGGER.error(message) + log.error(message) self._cache_id = NO_CACHE_ID self._init_success = False return @@ -364,14 +384,14 @@ def __update_width(self, page: SearchPage): def __set_up_range_parameters(self, page: SearchPage): self.__update_width(page) self._tail_band = page[-1].distance - LOGGER.debug( + log.debug( f"set up init parameter for searchIterator width:{self._width} tail_band:{self._tail_band}" ) def __check_reached_limit(self) -> bool: if self._limit == UNLIMITED or self._returned_count < self._limit: return False - LOGGER.debug( + log.debug( f"reached search limit:{self._limit}, returned_count:{self._returned_count}, directly return" ) return True @@ -528,7 +548,7 @@ def __try_search_fill(self) -> SearchPage: if len(final_page) >= self._iterator_params[BATCH_SIZE]: break if try_time > MAX_TRY_TIME: - LOGGER.warning(f"Search probe exceed max try times:{MAX_TRY_TIME} directly break") + log.warning(f"Search probe exceed max try times:{MAX_TRY_TIME} directly break") break # if there's a ring containing no vectors matched, then we need to extend # the ring continually to avoid empty ring problem @@ -538,6 +558,7 @@ def __try_search_fill(self) -> SearchPage: def __execute_next_search( self, next_params: dict, next_expr: str, to_extend_batch: bool ) -> SearchPage: + log.debug(f"search_iterator_next_expr:{next_expr}, next_params:{next_params}") res = self._conn.search( self._iterator_params["collection_name"], self._iterator_params["data"], @@ -592,7 +613,7 @@ def __next_params(self, coefficient: int): else: next_params[PARAMS][RADIUS] = next_radius next_params[PARAMS][RANGE_FILTER] = self._tail_band - LOGGER.debug( + log.debug( f"next round search iteration radius:{next_params[PARAMS][RADIUS]}," f"range_filter:{next_params[PARAMS][RANGE_FILTER]}," f"coefficient:{coefficient}"