diff --git a/gradle.properties b/gradle.properties index 8db2040..8126872 100644 --- a/gradle.properties +++ b/gradle.properties @@ -1,8 +1,8 @@ # The version of qdrant to use to download protos -qdrantProtosVersion=v1.9.5 +qdrantProtosVersion=v1.10.0 # The version of qdrant docker image to run integration tests against -qdrantVersion=v1.9.5 +qdrantVersion=v1.10.0 # The version of the client to generate -packageVersion=1.9.1 +packageVersion=1.10.0 \ No newline at end of file diff --git a/src/main/java/io/qdrant/client/QdrantClient.java b/src/main/java/io/qdrant/client/QdrantClient.java index 5281e52..e71dea0 100644 --- a/src/main/java/io/qdrant/client/QdrantClient.java +++ b/src/main/java/io/qdrant/client/QdrantClient.java @@ -77,6 +77,10 @@ import io.qdrant.client.grpc.Points.PointsOperationResponse; import io.qdrant.client.grpc.Points.PointsSelector; import io.qdrant.client.grpc.Points.PointsUpdateOperation; +import io.qdrant.client.grpc.Points.QueryBatchPoints; +import io.qdrant.client.grpc.Points.QueryBatchResponse; +import io.qdrant.client.grpc.Points.QueryPoints; +import io.qdrant.client.grpc.Points.QueryResponse; import io.qdrant.client.grpc.Points.ReadConsistency; import io.qdrant.client.grpc.Points.RecommendBatchPoints; import io.qdrant.client.grpc.Points.RecommendBatchResponse; @@ -2746,6 +2750,103 @@ public ListenableFuture countAsync( return Futures.transform(future, response -> response.getResult().getCount(), MoreExecutors.directExecutor()); } + /** + * Universally query points. + * Covers all capabilities of search, recommend, discover, filters. + * Also enables hybrid and multi-stage queries. + * + * @param request the query request + * @return a new instance of {@link ListenableFuture} + */ + public ListenableFuture> queryAsync(QueryPoints request) { + return queryAsync(request, null); + } + + /** + * Universally query points. + * Covers all capabilities of search, recommend, discover, filters. + * Also enables hybrid and multi-stage queries. + * + * @param request the query request + * @param timeout the timeout for the call. + * @return a new instance of {@link ListenableFuture} + */ + public ListenableFuture> queryAsync(QueryPoints request, @Nullable Duration timeout) { + Preconditions.checkArgument( + !request.getCollectionName().isEmpty(), + "Collection name must not be empty"); + + logger.debug("Query on '{}'", request.getCollectionName()); + ListenableFuture future = getPoints(timeout).query(request); + addLogFailureCallback(future, "Query"); + return Futures.transform(future, QueryResponse::getResultList, MoreExecutors.directExecutor()); + } + + /** + * Universally query points in batch. + * Covers all capabilities of search, recommend, discover, filters. + * Also enables hybrid and multi-stage queries. + * + * @param collectionName The name of the collection + * @param queries The queries to be performed in the batch. + * @return a new instance of {@link ListenableFuture} + */ + public ListenableFuture> queryBatchAsync( + String collectionName, + List queries + ) { + return queryBatchAsync(collectionName, queries, null, null); + } + + /** + * Universally query points in batch. + * Covers all capabilities of search, recommend, discover, filters. + * Also enables hybrid and multi-stage queries. + * + * @param collectionName The name of the collection + * @param queries The queries to be performed in the batch. + * @param readConsistency Options for specifying read consistency guarantees. + * @return a new instance of {@link ListenableFuture} + */ + public ListenableFuture> queryBatchAsync( + String collectionName, + List queries, + @Nullable ReadConsistency readConsistency + ) { + return queryBatchAsync(collectionName, queries, readConsistency, null); + } + + /** + * Universally query points in batch. + * Covers all capabilities of search, recommend, discover, filters. + * Also enables hybrid and multi-stage queries. + * + * @param collectionName The name of the collection + * @param queries The queries to be performed in the batch. + * @param readConsistency Options for specifying read consistency guarantees. + * @param timeout The timeout for the call. + * @return a new instance of {@link ListenableFuture} + */ + public ListenableFuture> queryBatchAsync( + String collectionName, + List queries, + @Nullable ReadConsistency readConsistency, + @Nullable Duration timeout + ) { + QueryBatchPoints.Builder requestBuilder = QueryBatchPoints.newBuilder() + .setCollectionName(collectionName) + .addAllQueryPoints(queries); + + if (readConsistency != null) { + requestBuilder.setReadConsistency(readConsistency); + } + + logger.debug("Query batch on '{}'", collectionName); + ListenableFuture future = getPoints(timeout).queryBatch(requestBuilder.build()); + addLogFailureCallback(future, "Query batch"); + return Futures.transform(future, QueryBatchResponse::getResultList, MoreExecutors.directExecutor()); + } + //region Snapshot Management /** diff --git a/src/main/java/io/qdrant/client/QueryFactory.java b/src/main/java/io/qdrant/client/QueryFactory.java new file mode 100644 index 0000000..4b876ca --- /dev/null +++ b/src/main/java/io/qdrant/client/QueryFactory.java @@ -0,0 +1,182 @@ +package io.qdrant.client; + +import java.util.List; +import java.util.UUID; +import io.qdrant.client.grpc.Points.ContextInput; +import io.qdrant.client.grpc.Points.DiscoverInput; +import io.qdrant.client.grpc.Points.Fusion; +import io.qdrant.client.grpc.Points.OrderBy; +import io.qdrant.client.grpc.Points.PointId; +import io.qdrant.client.grpc.Points.Query; +import io.qdrant.client.grpc.Points.RecommendInput; +import io.qdrant.client.grpc.Points.VectorInput; + +import static io.qdrant.client.VectorInputFactory.vectorInput; +import static io.qdrant.client.VectorInputFactory.multiVectorInput; + + +/** + * Convenience methods for constructing {@link Query} + */ +public final class QueryFactory { + private QueryFactory() { + } + + /** + * Creates a {@link Query} for recommendation. + * + * @param input An instance of {@link RecommendInput} + * @return a new instance of {@link Query} + */ + public static Query recommend(RecommendInput input) { + return Query.newBuilder().setRecommend(input).build(); + } + + /** + * Creates a {@link Query} for discovery. + * + * @param input An instance of {@link DiscoverInput} + * @return a new instance of {@link Query} + */ + public static Query discover(DiscoverInput input) { + return Query.newBuilder().setDiscover(input).build(); + } + + /** + * Creates a {@link Query} for context search. + * + * @param input An instance of {@link ContextInput} + * @return a new instance of {@link Query} + */ + public static Query context(ContextInput input) { + return Query.newBuilder().setContext(input).build(); + } + + /** + * Creates a {@link Query} for pre-fetch results fusion. + * + * @param fusion An instance of {@link Fusion} + * @return a new instance of {@link Query} + */ + public static Query fusion(Fusion fusion) { + return Query.newBuilder().setFusion(fusion).build(); + } + + /** + * Creates a {@link Query} to order points by a payload field. + * + * @param key Name of the payload field to order by + * @return a new instance of {@link Query} + */ + public static Query orderBy(String key) { + OrderBy orderBy = OrderBy.newBuilder().setKey(key).build(); + return Query.newBuilder().setOrderBy(orderBy).build(); + } + + /** + * Creates a {@link Query} to order points by a payload field. + * + * @param orderBy An instance of {@link OrderBy} + * @return a new instance of {@link Query} + */ + public static Query orderBy(OrderBy orderBy) { + return Query.newBuilder().setOrderBy(orderBy).build(); + } + + // region Nearest search queries + + /** + * Creates a {@link Query} for nearest search. + * + * @param input An instance of {@link VectorInput} + * @return a new instance of {@link Query} + */ + public static Query nearest(VectorInput input) { + return Query.newBuilder().setNearest(input).build(); + } + + /** + * Creates a {@link Query} from a list of floats + * + * @param values A map of vector names to values + * @return A new instance of {@link Query} + */ + public static Query nearest(List < Float > values) { + return Query.newBuilder().setNearest(vectorInput(values)).build(); + } + + /** + * Creates a {@link Query} from a list of floats + * + * @param values A list of values + * @return A new instance of {@link Query} + */ + public static Query nearest(float...values) { + return Query.newBuilder().setNearest(vectorInput(values)).build(); + } + + /** + * Creates a {@link Query} from a list of floats and integers as indices + * + * @param values The list of floats representing the vector. + * @param indices The list of integers representing the indices. + * @return A new instance of {@link Query} + */ + public static Query nearest(List < Float > values, List < Integer > indices) { + return Query.newBuilder().setNearest(vectorInput(values, indices)).build(); + } + + /** + * Creates a {@link Query} from a nested array of floats representing a multi + * vector + * + * @param vectors The nested array of floats. + * @return A new instance of {@link Query} + */ + public static Query nearest(float[][] vectors) { + return Query.newBuilder().setNearest(multiVectorInput(vectors)).build(); + } + + /** + * Creates a {@link Query} from a {@link long} + * + * @param id The point id + * @return a new instance of {@link Query} + */ + public static Query nearest(long id) { + return Query.newBuilder().setNearest(vectorInput(id)).build(); + } + + /** + * Creates a {@link Query} from a {@link UUID} + * + * @param id The pint id + * @return a new instance of {@link Query} + */ + public static Query nearest(UUID id) { + return Query.newBuilder().setNearest(vectorInput(id)).build(); + } + + /** + * Creates a {@link Query} from a {@link PointId} + * + * @param id The pint id + * @return a new instance of {@link Query} + */ + public static Query nearest(PointId id) { + return Query.newBuilder().setNearest(vectorInput(id)).build(); + } + + /** + * Creates a {@link Query} from a nested list of floats representing a multi + * vector + * + * @param vectors The nested list of floats. + * @return A new instance of {@link Query} + */ + public static Query nearestMultiVector(List < List < Float >> vectors) { + return Query.newBuilder().setNearest(multiVectorInput(vectors)).build(); + } + + // endregion +} diff --git a/src/main/java/io/qdrant/client/VectorFactory.java b/src/main/java/io/qdrant/client/VectorFactory.java index 0e7ab18..6bea026 100644 --- a/src/main/java/io/qdrant/client/VectorFactory.java +++ b/src/main/java/io/qdrant/client/VectorFactory.java @@ -1,6 +1,8 @@ package io.qdrant.client; +import java.util.ArrayList; import java.util.List; +import java.util.stream.Collectors; import com.google.common.primitives.Floats; @@ -49,4 +51,42 @@ public static Vector vector(List vector, List indices) { .setIndices(SparseIndices.newBuilder().addAllData(indices).build()) .build(); } + + /** + * Creates a multi vector from a nested list of floats + * + * @param vectors The nested list of floats representing the multi vector. + * @return A new instance of {@link Vector} + */ + public static Vector multiVector(List> vectors) { + int vectorSize = vectors.size(); + List flatVector = vectors.stream().flatMap(List::stream).collect(Collectors.toList()); + + return Vector.newBuilder() + .addAllData(flatVector) + .setVectorsCount(vectorSize) + .build(); + } + + /** + * Creates a multi vector from a nested array of floats + * + * @param vectors The nested array of floats representing the multi vector. + * @return A new instance of {@link Vector} + */ + public static Vector multiVector(float[][] vectors) { + int vectorSize = vectors.length; + + List flatVector = new ArrayList<>(); + for (float[] vector : vectors) { + for (float value : vector) { + flatVector.add(value); + } + } + + return Vector.newBuilder() + .addAllData(flatVector) + .setVectorsCount(vectorSize) + .build(); + } } diff --git a/src/main/java/io/qdrant/client/VectorInputFactory.java b/src/main/java/io/qdrant/client/VectorInputFactory.java new file mode 100644 index 0000000..b14f156 --- /dev/null +++ b/src/main/java/io/qdrant/client/VectorInputFactory.java @@ -0,0 +1,122 @@ +package io.qdrant.client; +import java.util.ArrayList; +import java.util.List; +import java.util.UUID; +import java.util.stream.Collectors; +import com.google.common.primitives.Floats; +import io.qdrant.client.grpc.Points.DenseVector; +import io.qdrant.client.grpc.Points.MultiDenseVector; +import io.qdrant.client.grpc.Points.PointId; +import io.qdrant.client.grpc.Points.SparseVector; +import io.qdrant.client.grpc.Points.VectorInput; +import static io.qdrant.client.PointIdFactory.id; + +/** + * Convenience methods for constructing {@link VectorInput} + */ +public final class VectorInputFactory { + private VectorInputFactory() {} + + /** + * Creates a {@link VectorInput} from a list of floats + * + * @param values A map of vector names to values + * @return A new instance of {@link VectorInput} + */ + public static VectorInput vectorInput(List < Float > values) { + return VectorInput.newBuilder().setDense(DenseVector.newBuilder().addAllData(values)).build(); + } + /** + * Creates a {@link VectorInput} from a list of floats + * + * @param values A list of values + * @return A new instance of {@link VectorInput} + */ + public static VectorInput vectorInput(float...values) { + return VectorInput.newBuilder().setDense(DenseVector.newBuilder().addAllData(Floats.asList(values))) + .build(); + } + /** + * Creates a {@link VectorInput} from a list of floats and integers as indices + * + * @param vector The list of floats representing the vector. + * @param indices The list of integers representing the indices. + * @return A new instance of {@link VectorInput} + */ + public static VectorInput vectorInput(List < Float > vector, List < Integer > indices) { + return VectorInput.newBuilder() + .setSparse(SparseVector.newBuilder() + .addAllValues(vector) + .addAllIndices(indices) + .build()) + .build(); + } + /** + * Creates a {@link VectorInput} from a nested list of floats representing a multi + * vector + * + * @param vectors The nested list of floats. + * @return A new instance of {@link VectorInput} + */ + public static VectorInput multiVectorInput(List < List < Float >> vectors) { + List < DenseVector > denseVectors = vectors.stream() + .map(v -> DenseVector.newBuilder().addAllData(v).build()) + .collect(Collectors.toList()); + return VectorInput.newBuilder() + .setMultiDense(MultiDenseVector.newBuilder() + .addAllVectors(denseVectors) + .build()) + .build(); + } + /** + * Creates a {@link VectorInput} from a nested array of floats representing a multi + * vector + * + * @param vectors The nested array of floats. + * @return A new instance of {@link VectorInput} + */ + public static VectorInput multiVectorInput(float[][] vectors) { + List < DenseVector > denseVectors = new ArrayList < > (); + for (float[] vector: vectors) { + denseVectors.add(DenseVector.newBuilder().addAllData(Floats.asList(vector)).build()); + } + return VectorInput.newBuilder() + .setMultiDense(MultiDenseVector.newBuilder() + .addAllVectors(denseVectors) + .build()) + .build(); + } + /** + * Creates a {@link VectorInput} from a {@link long} + * + * @param id The point id + * @return a new instance of {@link VectorInput} + */ + public static VectorInput vectorInput(long id) { + return VectorInput.newBuilder() + .setId(id(id)) + .build(); + } + /** + * Creates a {@link VectorInput} from a {@link UUID} + * + * @param id The pint id + * @return a new instance of {@link VectorInput} + */ + public static VectorInput vectorInput(UUID id) { + return VectorInput.newBuilder() + .setId(id(id)) + .build(); + } + /** + * Creates a {@link VectorInput} from a {@link PointId} + * + * @param id The pint id + * @return a new instance of {@link VectorInput} + */ + public static VectorInput vectorInput(PointId id) { + return VectorInput.newBuilder() + .setId(id) + .build(); + } +} \ No newline at end of file diff --git a/src/main/java/io/qdrant/client/VectorsFactory.java b/src/main/java/io/qdrant/client/VectorsFactory.java index b88d7f1..1bfba8a 100644 --- a/src/main/java/io/qdrant/client/VectorsFactory.java +++ b/src/main/java/io/qdrant/client/VectorsFactory.java @@ -50,4 +50,15 @@ public static Vectors vectors(float... values) { .setVector(vector(values)) .build(); } + + /** + * Creates a vector + * @param vector An instance of {@link Vector} + * @return a new instance of {@link Vectors} + */ + public static Vectors vectors(Vector vector) { + return Vectors.newBuilder() + .setVector(vector) + .build(); + } } diff --git a/src/test/java/io/qdrant/client/PointsTest.java b/src/test/java/io/qdrant/client/PointsTest.java index ecf0845..b354d32 100644 --- a/src/test/java/io/qdrant/client/PointsTest.java +++ b/src/test/java/io/qdrant/client/PointsTest.java @@ -3,6 +3,9 @@ import static io.qdrant.client.ConditionFactory.hasId; import static io.qdrant.client.ConditionFactory.matchKeyword; import static io.qdrant.client.PointIdFactory.id; +import static io.qdrant.client.QueryFactory.fusion; +import static io.qdrant.client.QueryFactory.nearest; +import static io.qdrant.client.QueryFactory.orderBy; import static io.qdrant.client.TargetVectorFactory.targetVector; import static io.qdrant.client.ValueFactory.value; import static io.qdrant.client.VectorFactory.vector; @@ -40,6 +43,7 @@ import io.qdrant.client.grpc.Points.BatchResult; import io.qdrant.client.grpc.Points.DiscoverPoints; import io.qdrant.client.grpc.Points.Filter; +import io.qdrant.client.grpc.Points.Fusion; import io.qdrant.client.grpc.Points.PointGroup; import io.qdrant.client.grpc.Points.PointStruct; import io.qdrant.client.grpc.Points.PointVectors; @@ -48,6 +52,8 @@ import io.qdrant.client.grpc.Points.PointsUpdateOperation; import io.qdrant.client.grpc.Points.PointsUpdateOperation.ClearPayload; import io.qdrant.client.grpc.Points.PointsUpdateOperation.UpdateVectors; +import io.qdrant.client.grpc.Points.PrefetchQuery; +import io.qdrant.client.grpc.Points.QueryPoints; import io.qdrant.client.grpc.Points.RecommendPointGroups; import io.qdrant.client.grpc.Points.RecommendPoints; import io.qdrant.client.grpc.Points.RetrievedPoint; @@ -630,6 +636,147 @@ public void batchPointUpdate() throws ExecutionException, InterruptedException { response.forEach(result -> assertEquals(UpdateStatus.Completed, result.getStatus())); } + @Test + public void query() throws ExecutionException, InterruptedException { + createAndSeedCollection(testName); + + List points = client.queryAsync( + QueryPoints.newBuilder().setCollectionName(testName) + .build()) + .get(); + + assertEquals(2, points.size()); + assertEquals(points.get(0).getId(), id(8)); + assertEquals(points.get(1).getId(), id(9)); + + points = client.queryAsync( + QueryPoints.newBuilder().setCollectionName(testName) + .setLimit(1) + .build()) + .get(); + + assertEquals(1, points.size()); + assertEquals(id(8), points.get(0).getId()); + } + + @Test + public void queryWithFilter() throws ExecutionException, InterruptedException { + createAndSeedCollection(testName); + + List points = client.queryAsync( + QueryPoints.newBuilder() + .setCollectionName(testName) + .setFilter(Filter.newBuilder().addMust(matchKeyword("foo", "hello")).build()) + .build()) + .get(); + + assertEquals(1, points.size()); + assertEquals(id(8), points.get(0).getId()); + } + + @Test + public void queryNearestWithID() throws ExecutionException, InterruptedException { + createAndSeedCollection(testName); + + List points = client.queryAsync( + QueryPoints.newBuilder() + .setCollectionName(testName) + .setQuery(nearest(8)) + .build()) + .get(); + + assertEquals(1, points.size()); + assertEquals(id(9), points.get(0).getId()); + } + + @Test + public void queryNearestWithVector() throws ExecutionException, InterruptedException { + createAndSeedCollection(testName); + + List points = client.queryAsync( + QueryPoints.newBuilder() + .setCollectionName(testName) + .setQuery(nearest(10.5f, 11.5f)) + .build()) + .get(); + + assertEquals(2, points.size()); + assertEquals(id(9), points.get(0).getId()); + } + + @Test + public void queryOrderBy() throws ExecutionException, InterruptedException { + createAndSeedCollection(testName); + + Collections.PayloadIndexParams params = Collections.PayloadIndexParams.newBuilder() + .setIntegerIndexParams( + Collections.IntegerIndexParams.newBuilder().setLookup(false).setRange(true).build()) + .build(); + + UpdateResult resultIndex = client.createPayloadIndexAsync( + testName, + "bar", + PayloadSchemaType.Integer, + params, + true, + null, + null).get(); + + assertEquals(UpdateStatus.Completed, resultIndex.getStatus()); + + CollectionInfo collectionInfo = client.getCollectionInfoAsync(testName).get(); + assertEquals(ImmutableSet.of("bar"), collectionInfo.getPayloadSchemaMap().keySet()); + assertEquals(PayloadSchemaType.Integer, collectionInfo.getPayloadSchemaMap().get("bar").getDataType()); + + List points = client.queryAsync( + QueryPoints.newBuilder() + .setCollectionName(testName) + .setLimit(1) + .setQuery(orderBy("bar")) + .build()) + .get(); + + assertEquals(1, points.size()); + assertEquals(id(8), points.get(0).getId()); + } + + @Test + public void queryWithPrefetchLimit() throws ExecutionException, InterruptedException { + createAndSeedCollection(testName); + + List points = client.queryAsync( + QueryPoints.newBuilder() + .addPrefetch(PrefetchQuery.newBuilder() + .setLimit(1) + .build()) + .setCollectionName(testName) + .setQuery(nearest(10.5f, 11.5f)) + .build()) + .get(); + + assertEquals(1, points.size()); + } + + @Test + public void queryWithPrefetchAndFusion() throws ExecutionException, InterruptedException { + createAndSeedCollection(testName); + + List points = client.queryAsync( + QueryPoints.newBuilder() + .addPrefetch(PrefetchQuery.newBuilder() + .setQuery(nearest(10.5f, 11.5f)) + .build()) + .addPrefetch(PrefetchQuery.newBuilder() + .setQuery(nearest(3.5f, 4.5f)) + .build()) + .setCollectionName(testName) + .setQuery(fusion(Fusion.RRF)) + .build()) + .get(); + + assertEquals(2, points.size()); + } + private void createAndSeedCollection(String collectionName) throws ExecutionException, InterruptedException { CreateCollection request = CreateCollection.newBuilder() .setCollectionName(collectionName)