Skip to content

Commit

Permalink
test: discover searches
Browse files Browse the repository at this point in the history
  • Loading branch information
Anush008 committed Dec 15, 2023
1 parent 77aecbf commit 5e3707c
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 7 deletions.
4 changes: 2 additions & 2 deletions src/main/java/io/qdrant/client/QdrantClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -2269,7 +2269,7 @@ public ListenableFuture<List<ScoredPoint>> discoverAsync(DiscoverPoints request,
* Constrained by the context.
*
* @param collectionName The name of the collection
* @param request The list for discover point searches
* @param discoverSearches The list for discover point searches
* @param readConsistency Options for specifying read consistency guarantees
* @return a new instance of {@link ListenableFuture}
*/
Expand All @@ -2286,7 +2286,7 @@ public ListenableFuture<List<BatchResult>> discoverBatchAsync(
* Constrained by the context.
*
* @param collectionName The name of the collection
* @param request The list for discover point searches
* @param discoverSearches The list for discover point searches
* @param readConsistency Options for specifying read consistency guarantees
* @param timeout The timeout for the call.
* @return a new instance of {@link ListenableFuture}
Expand Down
54 changes: 49 additions & 5 deletions src/test/java/io/qdrant/client/PointsTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import org.testcontainers.shaded.com.google.common.collect.ImmutableMap;
import org.testcontainers.shaded.com.google.common.collect.ImmutableSet;
import io.qdrant.client.container.QdrantContainer;
import io.qdrant.client.grpc.Points.DiscoverPoints;

import java.util.List;
import java.util.concurrent.ExecutionException;
Expand Down Expand Up @@ -45,8 +46,9 @@
import static io.qdrant.client.ConditionFactory.hasId;
import static io.qdrant.client.ConditionFactory.matchKeyword;
import static io.qdrant.client.PointIdFactory.id;
import static io.qdrant.client.TargetVectorFactory.targetVector;
import static io.qdrant.client.ValueFactory.value;
import static io.qdrant.client.VectorsFactory.vector;
import static io.qdrant.client.VectorFactory.vector;

@Testcontainers
class PointsTest {
Expand Down Expand Up @@ -307,7 +309,7 @@ public void searchGroups() throws ExecutionException, InterruptedException {
ImmutableList.of(
PointStruct.newBuilder()
.setId(id(10))
.setVectors(VectorsFactory.vector(30f, 31f))
.setVectors(VectorsFactory.vectors(30f, 31f))
.putAllPayload(ImmutableMap.of("foo", value("hello")))
.build()
)
Expand Down Expand Up @@ -404,7 +406,7 @@ public void recommendGroups() throws ExecutionException, InterruptedException {
ImmutableList.of(
PointStruct.newBuilder()
.setId(id(10))
.setVectors(VectorsFactory.vector(30f, 31f))
.setVectors(VectorsFactory.vectors(30f, 31f))
.putAllPayload(ImmutableMap.of("foo", value("hello")))
.build()
)
Expand All @@ -423,6 +425,48 @@ public void recommendGroups() throws ExecutionException, InterruptedException {
assertEquals(2, groups.get(0).getHitsCount());
}

@Test
public void discover() throws ExecutionException, InterruptedException {
createAndSeedCollection(testName);

List<ScoredPoint> points = client.discoverAsync(DiscoverPoints.newBuilder()
.setCollectionName(testName)
.setTarget(targetVector(vector(ImmutableList.of(10.4f, 11.4f))))
.setLimit(1)
.build()).get();

assertEquals(1, points.size());
assertEquals(id(9), points.get(0).getId());
}

@Test
public void discoverBatch() throws ExecutionException, InterruptedException {
createAndSeedCollection(testName);

List<BatchResult> batchResults = client.discoverBatchAsync(
testName,
ImmutableList.of(
DiscoverPoints.newBuilder()
.setCollectionName(testName)
.setTarget(targetVector(vector(ImmutableList.of(10.4f, 11.4f))))
.setLimit(1)
.build(),
DiscoverPoints.newBuilder()
.setCollectionName(testName)
.setTarget(targetVector(vector(ImmutableList.of(3.5f, 4.5f))))
.setLimit(1)
.build()),
null).get();

assertEquals(2, batchResults.size());
BatchResult result = batchResults.get(0);
assertEquals(1, result.getResultCount());
assertEquals(id(9), result.getResult(0).getId());
result = batchResults.get(1);
assertEquals(1, result.getResultCount());
assertEquals(id(8), result.getResult(0).getId());
}

@Test
public void count() throws ExecutionException, InterruptedException {
createAndSeedCollection(testName);
Expand Down Expand Up @@ -512,15 +556,15 @@ private void createAndSeedCollection(String collectionName) throws ExecutionExce
UpdateResult result = client.upsertAsync(collectionName, ImmutableList.of(
PointStruct.newBuilder()
.setId(id(8))
.setVectors(VectorsFactory.vector(ImmutableList.of(3.5f, 4.5f)))
.setVectors(VectorsFactory.vectors(ImmutableList.of(3.5f, 4.5f)))
.putAllPayload(ImmutableMap.of(
"foo", value("hello"),
"bar", value(1)
))
.build(),
PointStruct.newBuilder()
.setId(id(9))
.setVectors(VectorsFactory.vector(ImmutableList.of(10.5f, 11.5f)))
.setVectors(VectorsFactory.vectors(ImmutableList.of(10.5f, 11.5f)))
.putAllPayload(ImmutableMap.of(
"foo", value("goodbye"),
"bar", value(2)
Expand Down

0 comments on commit 5e3707c

Please sign in to comment.