From 5e3707c3e785f8d51def742649de915af2cc5140 Mon Sep 17 00:00:00 2001 From: Anush008 Date: Fri, 15 Dec 2023 11:48:28 +0530 Subject: [PATCH] test: discover searches --- .../java/io/qdrant/client/QdrantClient.java | 4 +- .../java/io/qdrant/client/PointsTest.java | 54 +++++++++++++++++-- 2 files changed, 51 insertions(+), 7 deletions(-) diff --git a/src/main/java/io/qdrant/client/QdrantClient.java b/src/main/java/io/qdrant/client/QdrantClient.java index a9e85ac..569d3e9 100644 --- a/src/main/java/io/qdrant/client/QdrantClient.java +++ b/src/main/java/io/qdrant/client/QdrantClient.java @@ -2269,7 +2269,7 @@ public ListenableFuture> discoverAsync(DiscoverPoints request, * Constrained by the context. * * @param collectionName The name of the collection - * @param request The list for discover point searches + * @param discoverSearches The list for discover point searches * @param readConsistency Options for specifying read consistency guarantees * @return a new instance of {@link ListenableFuture} */ @@ -2286,7 +2286,7 @@ public ListenableFuture> discoverBatchAsync( * Constrained by the context. * * @param collectionName The name of the collection - * @param request The list for discover point searches + * @param discoverSearches The list for discover point searches * @param readConsistency Options for specifying read consistency guarantees * @param timeout The timeout for the call. * @return a new instance of {@link ListenableFuture} diff --git a/src/test/java/io/qdrant/client/PointsTest.java b/src/test/java/io/qdrant/client/PointsTest.java index 82c30c5..75cb540 100644 --- a/src/test/java/io/qdrant/client/PointsTest.java +++ b/src/test/java/io/qdrant/client/PointsTest.java @@ -13,6 +13,7 @@ import org.testcontainers.shaded.com.google.common.collect.ImmutableMap; import org.testcontainers.shaded.com.google.common.collect.ImmutableSet; import io.qdrant.client.container.QdrantContainer; +import io.qdrant.client.grpc.Points.DiscoverPoints; import java.util.List; import java.util.concurrent.ExecutionException; @@ -45,8 +46,9 @@ import static io.qdrant.client.ConditionFactory.hasId; import static io.qdrant.client.ConditionFactory.matchKeyword; import static io.qdrant.client.PointIdFactory.id; +import static io.qdrant.client.TargetVectorFactory.targetVector; import static io.qdrant.client.ValueFactory.value; -import static io.qdrant.client.VectorsFactory.vector; +import static io.qdrant.client.VectorFactory.vector; @Testcontainers class PointsTest { @@ -307,7 +309,7 @@ public void searchGroups() throws ExecutionException, InterruptedException { ImmutableList.of( PointStruct.newBuilder() .setId(id(10)) - .setVectors(VectorsFactory.vector(30f, 31f)) + .setVectors(VectorsFactory.vectors(30f, 31f)) .putAllPayload(ImmutableMap.of("foo", value("hello"))) .build() ) @@ -404,7 +406,7 @@ public void recommendGroups() throws ExecutionException, InterruptedException { ImmutableList.of( PointStruct.newBuilder() .setId(id(10)) - .setVectors(VectorsFactory.vector(30f, 31f)) + .setVectors(VectorsFactory.vectors(30f, 31f)) .putAllPayload(ImmutableMap.of("foo", value("hello"))) .build() ) @@ -423,6 +425,48 @@ public void recommendGroups() throws ExecutionException, InterruptedException { assertEquals(2, groups.get(0).getHitsCount()); } + @Test + public void discover() throws ExecutionException, InterruptedException { + createAndSeedCollection(testName); + + List points = client.discoverAsync(DiscoverPoints.newBuilder() + .setCollectionName(testName) + .setTarget(targetVector(vector(ImmutableList.of(10.4f, 11.4f)))) + .setLimit(1) + .build()).get(); + + assertEquals(1, points.size()); + assertEquals(id(9), points.get(0).getId()); + } + + @Test + public void discoverBatch() throws ExecutionException, InterruptedException { + createAndSeedCollection(testName); + + List batchResults = client.discoverBatchAsync( + testName, + ImmutableList.of( + DiscoverPoints.newBuilder() + .setCollectionName(testName) + .setTarget(targetVector(vector(ImmutableList.of(10.4f, 11.4f)))) + .setLimit(1) + .build(), + DiscoverPoints.newBuilder() + .setCollectionName(testName) + .setTarget(targetVector(vector(ImmutableList.of(3.5f, 4.5f)))) + .setLimit(1) + .build()), + null).get(); + + assertEquals(2, batchResults.size()); + BatchResult result = batchResults.get(0); + assertEquals(1, result.getResultCount()); + assertEquals(id(9), result.getResult(0).getId()); + result = batchResults.get(1); + assertEquals(1, result.getResultCount()); + assertEquals(id(8), result.getResult(0).getId()); + } + @Test public void count() throws ExecutionException, InterruptedException { createAndSeedCollection(testName); @@ -512,7 +556,7 @@ private void createAndSeedCollection(String collectionName) throws ExecutionExce UpdateResult result = client.upsertAsync(collectionName, ImmutableList.of( PointStruct.newBuilder() .setId(id(8)) - .setVectors(VectorsFactory.vector(ImmutableList.of(3.5f, 4.5f))) + .setVectors(VectorsFactory.vectors(ImmutableList.of(3.5f, 4.5f))) .putAllPayload(ImmutableMap.of( "foo", value("hello"), "bar", value(1) @@ -520,7 +564,7 @@ private void createAndSeedCollection(String collectionName) throws ExecutionExce .build(), PointStruct.newBuilder() .setId(id(9)) - .setVectors(VectorsFactory.vector(ImmutableList.of(10.5f, 11.5f))) + .setVectors(VectorsFactory.vectors(ImmutableList.of(10.5f, 11.5f))) .putAllPayload(ImmutableMap.of( "foo", value("goodbye"), "bar", value(2)