extend unlimited offset for query iterator (milvus-io#2418)
Signed-off-by: MrPresent-Han <[email protected]>
MrPresent-Han committed Dec 9, 2024
1 parent b1bc025 commit 2c06080
Showing 2 changed files with 81 additions and 38 deletions.
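
What the commit does, in short: the query iterator's __seek step used to issue the whole requested offset as the limit of a single query, so an offset beyond one query's window could not be skipped; it now consumes the offset in bounded batches (see the second file below). A minimal usage sketch of what this enables — host, port, the collection name "demo" and the pk field "id" are placeholders, not part of this commit, and the collection is assumed to be loaded and populated:

from pymilvus import connections, Collection

connections.connect("default", host="localhost", port="19530")
collection = Collection("demo")  # assumed to exist, be loaded, and hold >48000 rows

# With the batched seek, a large offset such as 48000 can be skipped before
# the first page is returned.
it = collection.query_iterator(output_fields=["id"], offset=48000, batch_size=50)
while True:
    page = it.next()
    if len(page) == 0:
        it.close()
        break
    for row in page:
        print(row)
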
36 changes: 29 additions & 7 deletions examples/orm/iterator.py
@@ -1,5 +1,6 @@
import numpy as np
import random
import logging
from pymilvus import (
connections,
utility,
@@ -17,9 +18,13 @@
PICTURE = "picture"
CONSISTENCY_LEVEL = "Eventually"
LIMIT = 5
NUM_ENTITIES = 1000
NUM_ENTITIES = 10000
DIM = 8
CLEAR_EXIST = False
CLEAR_EXIST = True

# Create a logger for the main script
log = logging.getLogger("pymilvus")
log.setLevel(logging.INFO)


def re_create_collection(skip_data_period: bool):
@@ -29,8 +34,7 @@ def re_create_collection(skip_data_period: bool):
print(f"dropped existed collection{COLLECTION_NAME}")

fields = [
FieldSchema(name=USER_ID, dtype=DataType.VARCHAR, is_primary=True,
auto_id=False, max_length=MAX_LENGTH),
FieldSchema(name=USER_ID, dtype=DataType.INT64, is_primary=True, auto_id=False),
FieldSchema(name=AGE, dtype=DataType.INT64),
FieldSchema(name=DEPOSIT, dtype=DataType.DOUBLE),
FieldSchema(name=PICTURE, dtype=DataType.FLOAT_VECTOR, dim=DIM)
@@ -58,10 +62,9 @@ def random_pk(filter_set: set, lower_bound: int, upper_bound: int) -> str:
def insert_data(collection):
rng = np.random.default_rng(seed=19530)
batch_count = 5
filter_set: set = {}
for i in range(batch_count):
entities = [
[random_pk(filter_set, 0, batch_count * NUM_ENTITIES) for _ in range(NUM_ENTITIES)],
[i for i in range(NUM_ENTITIES*i, NUM_ENTITIES*(i + 1))],
[int(ni % 100) for ni in range(NUM_ENTITIES)],
[float(ni) for ni in range(NUM_ENTITIES)],
rng.random((NUM_ENTITIES, DIM)),
@@ -150,6 +153,22 @@ def query_iterate_collection_with_offset(collection):
print(f"page{page_idx}-------------------------")


def query_iterate_collection_with_large_offset(collection):
query_iterator = collection.query_iterator(output_fields=[USER_ID, AGE],
offset=48000, batch_size=50, consistency_level=CONSISTENCY_LEVEL)
page_idx = 0
while True:
res = query_iterator.next()
if len(res) == 0:
print("query iteration finished, close")
query_iterator.close()
break
for i in range(len(res)):
print(res[i])
page_idx += 1
print(f"page{page_idx}-------------------------")


def query_iterate_collection_with_limit(collection):
expr = f"10 <= {AGE} <= 44"
query_iterator = collection.query_iterator(expr=expr, output_fields=[USER_ID, AGE],
@@ -167,6 +186,8 @@ def query_iterate_collection_with_limit(collection):
print(f"page{page_idx}-------------------------")




def search_iterator_collection(collection):
SEARCH_NQ = 1
DIM = 8
@@ -216,11 +237,12 @@ def search_iterator_collection_with_limit(collection):


def main():
skip_data_period = False
skip_data_period = True
connections.connect("default", host=HOST, port=PORT)
collection = re_create_collection(skip_data_period)
if not skip_data_period:
collection = prepare_data(collection)
query_iterate_collection_with_large_offset(collection)
query_iterate_collection_no_offset(collection)
query_iterate_collection_with_offset(collection)
query_iterate_collection_with_limit(collection)
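
For reference, the updated example now inserts five batches of 10,000 rows with sequential INT64 primary keys (the random VARCHAR pks and the filter_set bookkeeping are gone), so the 50,000 deterministic rows make the offset=48000 case above meaningful. A sketch of the layout it produces, using the same constants as the example; the insert call is commented out because this only illustrates the entity columns:

import numpy as np

NUM_ENTITIES, DIM = 10000, 8
rng = np.random.default_rng(seed=19530)
for i in range(5):
    entities = [
        list(range(NUM_ENTITIES * i, NUM_ENTITIES * (i + 1))),  # sequential INT64 pks, 0..49999 overall
        [int(n % 100) for n in range(NUM_ENTITIES)],             # age
        [float(n) for n in range(NUM_ENTITIES)],                 # deposit
        rng.random((NUM_ENTITIES, DIM)),                         # picture vectors
    ]
    # collection.insert(entities)
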
83 changes: 52 additions & 31 deletions pymilvus/orm/iterator.py
@@ -1,4 +1,5 @@
import logging
import time
from copy import deepcopy
from typing import Any, Dict, List, Optional, TypeVar, Union

@@ -39,8 +40,7 @@
from .schema import CollectionSchema
from .types import DataType

LOGGER = logging.getLogger(__name__)
LOGGER.setLevel(logging.ERROR)
log = logging.getLogger(__name__)
QueryIterator = TypeVar("QueryIterator")
SearchIterator = TypeVar("SearchIterator")

@@ -77,19 +77,16 @@ def __init__(
self._schema = schema
self._timeout = timeout
self._kwargs = kwargs
self.__set_up_iteration_states()
self.__check_set_batch_size(batch_size)
self._limit = limit
self.__check_set_reduce_stop_for_best()
self._returned_count = 0
self.__setup__pk_prop()
self.__set_up_expr(expr)
self._next_id = None
self.__seek()
self._cache_id_in_use = NO_CACHE_ID

def __set_up_iteration_states(self):
self._kwargs[ITERATOR_FIELD] = "True"

def __check_set_reduce_stop_for_best(self):
if self._kwargs.get(REDUCE_STOP_FOR_BEST, True):
self._kwargs[REDUCE_STOP_FOR_BEST] = "True"
@@ -115,25 +112,47 @@ def __set_up_expr(self, expr: str):

def __seek(self):
self._cache_id_in_use = NO_CACHE_ID
if self._kwargs.get(OFFSET, 0) == 0:
self._next_id = None
return

first_cursor_kwargs = self._kwargs.copy()
first_cursor_kwargs[OFFSET] = 0
# offset may be too large, needed to seek in multiple times
first_cursor_kwargs[MILVUS_LIMIT] = self._kwargs[OFFSET]

res = self._conn.query(
collection_name=self._collection_name,
expr=self._expr,
output_field=self._output_fields,
partition_name=self._partition_names,
timeout=self._timeout,
**first_cursor_kwargs,
)
self.__update_cursor(res)
self._kwargs[OFFSET] = 0
offset = self._kwargs.get(OFFSET, 0)
if offset > 0:
seek_params = self._kwargs.copy()
seek_params[OFFSET] = 0
# offset may be too large, needed to seek in multiple times
seek_params[ITERATOR_FIELD] = "False"
seek_params[REDUCE_STOP_FOR_BEST] = "False"
start_time = time.time()

def seek_offset_by_batch(batch: int, expr: str) -> int:
seek_params[MILVUS_LIMIT] = batch
res = self._conn.query(
collection_name=self._collection_name,
expr=expr,
output_field=[],
partition_name=self._partition_names,
timeout=self._timeout,
**seek_params,
)
self.__update_cursor(res)
return len(res)

while offset > 0:
batch_size = min(MAX_BATCH_SIZE, offset)
next_expr = self.__setup_next_expr()
seeked_count = seek_offset_by_batch(batch_size, next_expr)
log.debug(
f"seeked offset, seek_expr:{next_expr} batch_size:{batch_size} seeked_count:{seeked_count}"
)
if seeked_count == 0:
log.info(
"seek offset has drained all matched results for query iterator, break"
)
break
offset -= seeked_count
self._kwargs[OFFSET] = 0
seek_offset_duration = time.time() - start_time
log.info(
f"Finish seek offset for query iterator, offset:{offset}, current_pk_cursor:{self._next_id}, "
f"duration:{seek_offset_duration}"
)

def __maybe_cache(self, result: List):
if len(result) < 2 * self._kwargs[BATCH_SIZE]:
@@ -156,6 +175,7 @@ def next(self):
else:
iterator_cache.release_cache(self._cache_id_in_use)
current_expr = self.__setup_next_expr()
log.debug(f"query_iterator_next_expr:{current_expr}")
res = self._conn.query(
collection_name=self._collection_name,
expr=current_expr,
@@ -194,7 +214,7 @@ def __setup__pk_prop(self):
if self._pk_field_name is None or self._pk_field_name == "":
raise MilvusException(message="schema must contain pk field, broke")

def __setup_next_expr(self) -> None:
def __setup_next_expr(self) -> str:
current_expr = self._expr
if self._next_id is None:
return current_expr
@@ -339,7 +359,7 @@ def __init_search_iterator(self):
"Cannot init search iterator because init page contains no matched rows, "
"please check the radius and range_filter set up by searchParams"
)
LOGGER.error(message)
log.error(message)
self._cache_id = NO_CACHE_ID
self._init_success = False
return
@@ -364,14 +384,14 @@ def __update_width(self, page: SearchPage):
def __set_up_range_parameters(self, page: SearchPage):
self.__update_width(page)
self._tail_band = page[-1].distance
LOGGER.debug(
log.debug(
f"set up init parameter for searchIterator width:{self._width} tail_band:{self._tail_band}"
)

def __check_reached_limit(self) -> bool:
if self._limit == UNLIMITED or self._returned_count < self._limit:
return False
LOGGER.debug(
log.debug(
f"reached search limit:{self._limit}, returned_count:{self._returned_count}, directly return"
)
return True
@@ -528,7 +548,7 @@ def __try_search_fill(self) -> SearchPage:
if len(final_page) >= self._iterator_params[BATCH_SIZE]:
break
if try_time > MAX_TRY_TIME:
LOGGER.warning(f"Search probe exceed max try times:{MAX_TRY_TIME} directly break")
log.warning(f"Search probe exceed max try times:{MAX_TRY_TIME} directly break")
break
# if there's a ring containing no vectors matched, then we need to extend
# the ring continually to avoid empty ring problem
@@ -538,6 +558,7 @@
def __execute_next_search(
self, next_params: dict, next_expr: str, to_extend_batch: bool
) -> SearchPage:
log.debug(f"search_iterator_next_expr:{next_expr}, next_params:{next_params}")
res = self._conn.search(
self._iterator_params["collection_name"],
self._iterator_params["data"],
@@ -592,7 +613,7 @@ def __next_params(self, coefficient: int):
else:
next_params[PARAMS][RADIUS] = next_radius
next_params[PARAMS][RANGE_FILTER] = self._tail_band
LOGGER.debug(
log.debug(
f"next round search iteration radius:{next_params[PARAMS][RADIUS]},"
f"range_filter:{next_params[PARAMS][RANGE_FILTER]},"
f"coefficient:{coefficient}"
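
The core of the change is in __seek above: instead of one query whose limit equals the whole offset, the iterator now repeatedly queries at most MAX_BATCH_SIZE rows, advances the primary-key cursor via __update_cursor/__setup_next_expr, and subtracts what it skipped until the offset is consumed or the matching rows run out. A standalone sketch of that loop shape — query_batch, the after_pk cursor, and the 16384 bound are illustrative stand-ins, not pymilvus internals:

MAX_BATCH_SIZE = 16384  # assumed cap on a single server-side query

def seek_offset(query_batch, offset: int):
    """Skip `offset` matching rows by issuing bounded queries and tracking the last pk seen."""
    cursor = None
    while offset > 0:
        batch = min(MAX_BATCH_SIZE, offset)
        rows = query_batch(after_pk=cursor, limit=batch)  # bounded query over pk > cursor
        if len(rows) == 0:
            break  # fewer matches than the requested offset; nothing left to skip
        cursor = rows[-1]["pk"]  # remember the last pk returned
        offset -= len(rows)
    return cursor  # iteration resumes after this pk
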
