diff --git a/README.md b/README.md index 2329365..e577abd 100644 --- a/README.md +++ b/README.md @@ -34,20 +34,20 @@ To install the library, add the following lines to your build config file. io.qdrant client - 1.10.0 + 1.11.0 ``` #### SBT ```sbt -libraryDependencies += "io.qdrant" % "client" % "1.10.0" +libraryDependencies += "io.qdrant" % "client" % "1.11.0" ``` #### Gradle ```gradle -implementation 'io.qdrant:client:1.10.0' +implementation 'io.qdrant:client:1.11.0' ``` > [!NOTE] diff --git a/gradle.properties b/gradle.properties index 2abb78d..2f21d3f 100644 --- a/gradle.properties +++ b/gradle.properties @@ -1,11 +1,8 @@ # The version of qdrant to use to download protos -qdrantProtosVersion=v1.10.0 +qdrantProtosVersion=dev # The version of qdrant docker image to run integration tests against -qdrantVersion=v1.10.0 +qdrantVersion=dev # The version of the client to generate -packageVersion=1.10.0 - -## Extension of the default memory config for the spotless formatter plugin -org.gradle.jvmargs= -Xmx2000m "-XX:MaxMetaspaceSize=1000m" +packageVersion=1.11.0 diff --git a/src/main/java/io/qdrant/client/QdrantClient.java b/src/main/java/io/qdrant/client/QdrantClient.java index d03a8f8..e525b0c 100644 --- a/src/main/java/io/qdrant/client/QdrantClient.java +++ b/src/main/java/io/qdrant/client/QdrantClient.java @@ -67,6 +67,8 @@ import io.qdrant.client.grpc.Points.PointsUpdateOperation; import io.qdrant.client.grpc.Points.QueryBatchPoints; import io.qdrant.client.grpc.Points.QueryBatchResponse; +import io.qdrant.client.grpc.Points.QueryGroupsResponse; +import io.qdrant.client.grpc.Points.QueryPointGroups; import io.qdrant.client.grpc.Points.QueryPoints; import io.qdrant.client.grpc.Points.QueryResponse; import io.qdrant.client.grpc.Points.ReadConsistency; @@ -2771,6 +2773,37 @@ public ListenableFuture> queryBatchAsync( future, QueryBatchResponse::getResultList, MoreExecutors.directExecutor()); } + /** + * Universally query points. Covers all capabilities of search, recommend, discover, filters. + * Grouped by a payload field. + * + * @param request the query request + * @return a new instance of {@link ListenableFuture} + */ + public ListenableFuture> queryGroupsAsync(QueryPointGroups request) { + return queryGroupsAsync(request, null); + } + + /** + * Universally query points. Covers all capabilities of search, recommend, discover, filters. + * Grouped by a payload field. + * + * @param request the query request + * @param timeout the timeout for the call. + * @return a new instance of {@link ListenableFuture} + */ + public ListenableFuture> queryGroupsAsync( + QueryPointGroups request, @Nullable Duration timeout) { + Preconditions.checkArgument( + !request.getCollectionName().isEmpty(), "Collection name must not be empty"); + + logger.debug("Query groups on '{}'", request.getCollectionName()); + ListenableFuture future = getPoints(timeout).queryGroups(request); + addLogFailureCallback(future, "Query groups"); + return Futures.transform( + future, response -> response.getResult().getGroupsList(), MoreExecutors.directExecutor()); + } + // region Snapshot Management /** diff --git a/src/main/java/io/qdrant/client/QueryFactory.java b/src/main/java/io/qdrant/client/QueryFactory.java index 275ab6a..a229942 100644 --- a/src/main/java/io/qdrant/client/QueryFactory.java +++ b/src/main/java/io/qdrant/client/QueryFactory.java @@ -10,6 +10,7 @@ import io.qdrant.client.grpc.Points.PointId; import io.qdrant.client.grpc.Points.Query; import io.qdrant.client.grpc.Points.RecommendInput; +import io.qdrant.client.grpc.Points.Sample; import io.qdrant.client.grpc.Points.VectorInput; import java.util.List; import java.util.UUID; @@ -172,5 +173,15 @@ public static Query nearestMultiVector(List> vectors) { return Query.newBuilder().setNearest(multiVectorInput(vectors)).build(); } + /** + * Creates a {@link Query} for sampling. + * + * @param sample An instance of {@link Sample} + * @return A new instance of {@link Query} + */ + public static Query sample(Sample sample) { + return Query.newBuilder().setSample(sample).build(); + } + // endregion } diff --git a/src/test/java/io/qdrant/client/PointsTest.java b/src/test/java/io/qdrant/client/PointsTest.java index 4422bd8..9549a24 100644 --- a/src/test/java/io/qdrant/client/PointsTest.java +++ b/src/test/java/io/qdrant/client/PointsTest.java @@ -6,6 +6,7 @@ import static io.qdrant.client.QueryFactory.fusion; import static io.qdrant.client.QueryFactory.nearest; import static io.qdrant.client.QueryFactory.orderBy; +import static io.qdrant.client.QueryFactory.sample; import static io.qdrant.client.TargetVectorFactory.targetVector; import static io.qdrant.client.ValueFactory.value; import static io.qdrant.client.VectorFactory.vector; @@ -38,10 +39,12 @@ import io.qdrant.client.grpc.Points.PointsUpdateOperation.ClearPayload; import io.qdrant.client.grpc.Points.PointsUpdateOperation.UpdateVectors; import io.qdrant.client.grpc.Points.PrefetchQuery; +import io.qdrant.client.grpc.Points.QueryPointGroups; import io.qdrant.client.grpc.Points.QueryPoints; import io.qdrant.client.grpc.Points.RecommendPointGroups; import io.qdrant.client.grpc.Points.RecommendPoints; import io.qdrant.client.grpc.Points.RetrievedPoint; +import io.qdrant.client.grpc.Points.Sample; import io.qdrant.client.grpc.Points.ScoredPoint; import io.qdrant.client.grpc.Points.ScrollPoints; import io.qdrant.client.grpc.Points.ScrollResponse; @@ -50,6 +53,7 @@ import io.qdrant.client.grpc.Points.UpdateResult; import io.qdrant.client.grpc.Points.UpdateStatus; import io.qdrant.client.grpc.Points.Vectors; +import java.util.Arrays; import java.util.List; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; @@ -596,7 +600,7 @@ public void batchPointUpdate() throws ExecutionException, InterruptedException { createAndSeedCollection(testName); List operations = - List.of( + Arrays.asList( PointsUpdateOperation.newBuilder() .setClearPayload( ClearPayload.newBuilder() @@ -757,6 +761,58 @@ public void queryWithPrefetchAndFusion() throws ExecutionException, InterruptedE assertEquals(2, points.size()); } + @Test + public void queryWithSampling() throws ExecutionException, InterruptedException { + createAndSeedCollection(testName); + + List points = + client + .queryAsync( + QueryPoints.newBuilder() + .setCollectionName(testName) + .setQuery(sample(Sample.Random)) + .setLimit(1) + .build()) + .get(); + + assertEquals(1, points.size()); + } + + @Test + public void queryGroups() throws ExecutionException, InterruptedException { + createAndSeedCollection(testName); + + client + .upsertAsync( + testName, + ImmutableList.of( + PointStruct.newBuilder() + .setId(id(10)) + .setVectors(VectorsFactory.vectors(30f, 31f)) + .putAllPayload(ImmutableMap.of("foo", value("hello"))) + .build())) + .get(); + // 3 points in total, 2 with "foo" = "hello" and 1 with "foo" = "goodbye" + + List groups = + client + .queryGroupsAsync( + QueryPointGroups.newBuilder() + .setCollectionName(testName) + .setQuery(nearest(ImmutableList.of(10.4f, 11.4f))) + .setGroupBy("foo") + .setGroupSize(2) + .setLimit(10) + .build()) + .get(); + + assertEquals(2, groups.size()); + // A group with 2 hits because of 2 points with "foo" = "hello" + assertEquals(1, groups.stream().filter(g -> g.getHitsCount() == 2).count()); + // A group with 1 hit because of 1 point with "foo" = "goodbye" + assertEquals(1, groups.stream().filter(g -> g.getHitsCount() == 1).count()); + } + private void createAndSeedCollection(String collectionName) throws ExecutionException, InterruptedException { CreateCollection request =