From 724d0e3bc7d060514cbf69c53b393c0dd277c791 Mon Sep 17 00:00:00 2001 From: Anush008 <46051506+Anush008@users.noreply.github.com> Date: Sat, 2 Dec 2023 00:53:21 +0530 Subject: [PATCH 1/4] test: api tests --- .../java/io/qdrant/client/QdrantClient.java | 7 +- .../client/QdrantClientCollectionTest.java | 158 ++++++++++++++++++ .../qdrant/client/QdrantClientPointsTest.java | 3 + .../client/QdrantClientServiceTest.java | 44 +++++ .../client/QdrantClientSnapshotsTest.java | 74 ++++++++ 5 files changed, 282 insertions(+), 4 deletions(-) create mode 100644 src/test/java/io/qdrant/client/QdrantClientCollectionTest.java create mode 100644 src/test/java/io/qdrant/client/QdrantClientPointsTest.java create mode 100644 src/test/java/io/qdrant/client/QdrantClientServiceTest.java create mode 100644 src/test/java/io/qdrant/client/QdrantClientSnapshotsTest.java diff --git a/src/main/java/io/qdrant/client/QdrantClient.java b/src/main/java/io/qdrant/client/QdrantClient.java index 5d752221..9a927292 100644 --- a/src/main/java/io/qdrant/client/QdrantClient.java +++ b/src/main/java/io/qdrant/client/QdrantClient.java @@ -191,7 +191,7 @@ public Collections.CollectionOperationResponse createCollection( .setVectorsConfig(config) .setCollectionName(collectionName) .build(); - return collectionsStub.create(details); + return createCollection(details); } /** @@ -216,7 +216,6 @@ public Collections.CollectionOperationResponse createCollection( public Collections.CollectionOperationResponse recreateCollection( String collectionName, long vectorSize, Collections.Distance distance) { - deleteCollection(collectionName); Collections.VectorParams.Builder params = Collections.VectorParams.newBuilder().setDistance(distance).setSize(vectorSize); Collections.VectorsConfig config = @@ -226,7 +225,7 @@ public Collections.CollectionOperationResponse recreateCollection( .setVectorsConfig(config) .setCollectionName(collectionName) .build(); - return collectionsStub.create(details); + return recreateCollection(details); } /** @@ -1190,7 +1189,7 @@ public SnapshotsService.CreateSnapshotResponse createFullSnapshot() { * @param collectionName The name of the collection. * @return The response containing the list of full snapshots. */ - public SnapshotsService.ListSnapshotsResponse listFullSnapshots(String collectionName) { + public SnapshotsService.ListSnapshotsResponse listFullSnapshots() { SnapshotsService.ListFullSnapshotsRequest request = SnapshotsService.ListFullSnapshotsRequest.newBuilder().build(); return snapshotStub.listFull(request); diff --git a/src/test/java/io/qdrant/client/QdrantClientCollectionTest.java b/src/test/java/io/qdrant/client/QdrantClientCollectionTest.java new file mode 100644 index 00000000..af7e9996 --- /dev/null +++ b/src/test/java/io/qdrant/client/QdrantClientCollectionTest.java @@ -0,0 +1,158 @@ +package io.qdrant.client; + +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import io.grpc.StatusRuntimeException; +import io.qdrant.client.grpc.Collections; +import io.qdrant.client.grpc.Collections.Distance; +import java.util.UUID; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; + +class QdrantClientCollectionTest { + + private static QdrantClient qdrantClient; + + @BeforeAll + static void setUp() throws Exception { + String qdrantUrl = System.getenv("QDRANT_URL"); + String apiKey = System.getenv("QDRANT_API_KEY"); + + if (qdrantUrl == null) { + qdrantUrl = "http://localhost:6334"; + } + + if (apiKey == null) { + qdrantClient = new QdrantClient(qdrantUrl); + } else { + qdrantClient = new QdrantClient(qdrantUrl, apiKey); + } + } + + @Test + void testAliasOperations() { + String collectionName = UUID.randomUUID().toString(); + String aliasName = UUID.randomUUID().toString(); + + assertThrows( + StatusRuntimeException.class, + () -> { + // This should fail as collection does not exist + qdrantClient.createAlias(collectionName, aliasName); + }); + + qdrantClient.createCollection(collectionName, 6, Distance.Euclid); + assertDoesNotThrow( + () -> { + Collections.CollectionOperationResponse response = + qdrantClient.createAlias(collectionName, aliasName); + assertTrue(response.getResult()); + }); + + Collections.ListAliasesResponse response = qdrantClient.listAliases(); + assertTrue(response.getAliasesCount() == 1); + Collections.AliasDescription alias = response.getAliasesList().get(0); + assertTrue(alias.getCollectionName().equals(collectionName)); + assertTrue(alias.getAliasName().equals(aliasName)); + + String newAliasName = UUID.randomUUID().toString(); + qdrantClient.renameAlias(aliasName, newAliasName); + response = qdrantClient.listAliases(); + assertTrue(response.getAliasesCount() == 1); + + alias = response.getAliasesList().get(0); + assertTrue(alias.getCollectionName().equals(collectionName)); + assertTrue(alias.getAliasName().equals(newAliasName)); + + qdrantClient.deleteAlias(newAliasName); + response = qdrantClient.listAliases(); + } + + @Test + void testListCollections() { + assertDoesNotThrow( + () -> { + Collections.ListCollectionsResponse response = qdrantClient.listCollections(); + assertTrue(response.getCollectionsCount() >= 0); + }); + } + + @Test + void testHasCollection() { + String collectionName = UUID.randomUUID().toString(); + boolean exists = qdrantClient.hasCollection(collectionName); + assertFalse(exists); + + qdrantClient.createCollection(collectionName, 6, Distance.Euclid); + + exists = qdrantClient.hasCollection(collectionName); + assertTrue(exists); + } + + @Test + void testCollectionConfigOperations() { + long vectorSize = 128; + String collectionName = UUID.randomUUID().toString(); + Collections.Distance distance = Collections.Distance.Cosine; + + Collections.VectorParams.Builder params = + Collections.VectorParams.newBuilder().setDistance(distance).setSize(vectorSize); + + Collections.VectorsConfig config = + Collections.VectorsConfig.newBuilder().setParams(params).build(); + + Collections.HnswConfigDiff hnsw = + Collections.HnswConfigDiff.newBuilder().setM(16).setEfConstruct(200).build(); + + Collections.CreateCollection details = + Collections.CreateCollection.newBuilder() + .setVectorsConfig(config) + .setHnswConfig(hnsw) + .setCollectionName(collectionName) + .build(); + + Collections.CollectionOperationResponse response = qdrantClient.createCollection(details); + assertTrue(response.getResult()); + + Collections.GetCollectionInfoResponse info = qdrantClient.getCollectionInfo(collectionName); + assertTrue(info.getResult().getConfig().getHnswConfig().getM() == 16); + + Collections.UpdateCollection updateCollection = + Collections.UpdateCollection.newBuilder() + .setCollectionName(collectionName) + .setHnswConfig(Collections.HnswConfigDiff.newBuilder().setM(32).build()) + .build(); + qdrantClient.updateCollection(updateCollection); + + info = qdrantClient.getCollectionInfo(collectionName); + assertTrue(info.getResult().getConfig().getHnswConfig().getM() == 32); + } + + @Test + void testRecreateCollection() { + String collectionName = UUID.randomUUID().toString(); + + qdrantClient.createCollection(collectionName, 6, Distance.Euclid); + assertDoesNotThrow( + () -> { + qdrantClient.recreateCollection(collectionName, 12, Distance.Dot); + }); + } + + @Test + void testDeleteCollection() { + String collectionName = UUID.randomUUID().toString(); + + Collections.CollectionOperationResponse response = + qdrantClient.deleteCollection(collectionName); + assertFalse(response.getResult()); + + qdrantClient.createCollection(collectionName, 6, Distance.Euclid); + + response = qdrantClient.deleteCollection(collectionName); + assertTrue(response.getResult()); + } +} diff --git a/src/test/java/io/qdrant/client/QdrantClientPointsTest.java b/src/test/java/io/qdrant/client/QdrantClientPointsTest.java new file mode 100644 index 00000000..ca37439b --- /dev/null +++ b/src/test/java/io/qdrant/client/QdrantClientPointsTest.java @@ -0,0 +1,3 @@ +package io.qdrant.client; + +class QdrantClientPointsTest {} diff --git a/src/test/java/io/qdrant/client/QdrantClientServiceTest.java b/src/test/java/io/qdrant/client/QdrantClientServiceTest.java new file mode 100644 index 00000000..0fa4b037 --- /dev/null +++ b/src/test/java/io/qdrant/client/QdrantClientServiceTest.java @@ -0,0 +1,44 @@ +package io.qdrant.client; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import java.net.MalformedURLException; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; + +class QdrantClientServiceTest { + + private static QdrantClient qdrantClient; + + @BeforeAll + static void setUp() throws Exception { + String qdrantUrl = System.getenv("QDRANT_URL"); + String apiKey = System.getenv("QDRANT_API_KEY"); + + if (qdrantUrl == null) { + qdrantUrl = "http://localhost:6334"; + } + + if (apiKey == null) { + qdrantClient = new QdrantClient(qdrantUrl); + } else { + qdrantClient = new QdrantClient(qdrantUrl, apiKey); + } + } + + @Test + void testInvalidProtocol() { + assertThrows(IllegalArgumentException.class, () -> new QdrantClient("ftp://localhost:6334")); + } + + @Test + void testMalformedUrl() { + assertThrows(MalformedURLException.class, () -> new QdrantClient("qdrant/qdrant:latest")); + } + + @Test + void testHealthCheck() { + assertEquals(qdrantClient.healthCheck().getTitle(), "qdrant - vector search engine"); + } +} diff --git a/src/test/java/io/qdrant/client/QdrantClientSnapshotsTest.java b/src/test/java/io/qdrant/client/QdrantClientSnapshotsTest.java new file mode 100644 index 00000000..738a798c --- /dev/null +++ b/src/test/java/io/qdrant/client/QdrantClientSnapshotsTest.java @@ -0,0 +1,74 @@ +package io.qdrant.client; + +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertEquals; + +import io.qdrant.client.grpc.Collections.Distance; +import io.qdrant.client.grpc.SnapshotsService; +import java.util.UUID; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; + +class QdrantClientSnapshotsTest { + + private static QdrantClient qdrantClient; + + @BeforeAll + static void setUp() throws Exception { + String qdrantUrl = System.getenv("QDRANT_URL"); + String apiKey = System.getenv("QDRANT_API_KEY"); + + if (qdrantUrl == null) { + qdrantUrl = "http://localhost:6334"; + } + + if (apiKey == null) { + qdrantClient = new QdrantClient(qdrantUrl); + } else { + qdrantClient = new QdrantClient(qdrantUrl, apiKey); + } + } + + @Test + void testCollectionSnapshots() { + String collectionName = UUID.randomUUID().toString(); + + qdrantClient.createCollection(collectionName, 768, Distance.Cosine); + + assertEquals( + qdrantClient.listSnapshots(collectionName).getSnapshotDescriptionsList().size(), 0); + + assertDoesNotThrow( + () -> { + SnapshotsService.CreateSnapshotResponse response = + qdrantClient.createSnapshot(collectionName); + String snapshotName = response.getSnapshotDescription().getName(); + + assertEquals( + qdrantClient.listSnapshots(collectionName).getSnapshotDescriptionsList().size(), 1); + + qdrantClient.deleteSnapshot(collectionName, snapshotName); + + assertEquals( + qdrantClient.listSnapshots(collectionName).getSnapshotDescriptionsList().size(), 0); + }); + } + + @Test + void testFullSnapshots() { + + assertEquals(qdrantClient.listFullSnapshots().getSnapshotDescriptionsList().size(), 0); + + assertDoesNotThrow( + () -> { + SnapshotsService.CreateSnapshotResponse response = qdrantClient.createFullSnapshot(); + String snapshotName = response.getSnapshotDescription().getName(); + + assertEquals(qdrantClient.listFullSnapshots().getSnapshotDescriptionsList().size(), 1); + + qdrantClient.deleteFullSnapshot(snapshotName); + + assertEquals(qdrantClient.listFullSnapshots().getSnapshotDescriptionsList().size(), 0); + }); + } +} From d5632eeeb4115879da1722051f0fd7bfdeaafd71 Mon Sep 17 00:00:00 2001 From: Anush008 <46051506+Anush008@users.noreply.github.com> Date: Sat, 2 Dec 2023 14:48:47 +0530 Subject: [PATCH 2/4] test: pointClient --- .../io/qdrant/client/utils/VectorUtil.java | 44 +++++ .../qdrant/client/QdrantClientPointsTest.java | 170 +++++++++++++++++- 2 files changed, 213 insertions(+), 1 deletion(-) diff --git a/src/main/java/io/qdrant/client/utils/VectorUtil.java b/src/main/java/io/qdrant/client/utils/VectorUtil.java index 3b2a8648..c6eed09c 100644 --- a/src/main/java/io/qdrant/client/utils/VectorUtil.java +++ b/src/main/java/io/qdrant/client/utils/VectorUtil.java @@ -6,6 +6,7 @@ import io.qdrant.client.grpc.Points.Vectors; import java.util.List; import java.util.Map; +import java.util.Random; /** Utility class for working with vector data. */ public class VectorUtil { @@ -83,4 +84,47 @@ public static PointVectors pointVectors(String id, String name, float... vector) Vectors vectors = Vectors.newBuilder().setVectors(namedVector(name, vector)).build(); return pointVectorsBuilder.setId(PointUtil.pointId(id)).setVectors(vectors).build(); } + + /** + * Generates dummy embeddings of the specified size. + * + * @param size The size of the embeddings to generate. + * @return An array of floats representing the generated embeddings. + * @throws IllegalArgumentException If the size is less than or equal to zero. + */ + public static float[] dummyEmbeddings(int size) { + if (size <= 0) { + throw new IllegalArgumentException("Size must be greater than zero"); + } + + float[] embeddings = new float[size]; + Random random = new Random(); + + for (int i = 0; i < size; i++) { + embeddings[i] = random.nextFloat(); + } + + return embeddings; + } + + /** + * Generates a dummy vector of the specified size. + * + * @param size The size of the vector. + * @return The generated dummy vector. + */ + public static Vector dummyVector(int size) { + return toVector(dummyEmbeddings(size)); + } + + /** + * Generates a dummy named vector of the specified size. + * + * @param name The name of the vector. + * @param size The size of the vector. + * @return The generated dummy vector. + */ + public static NamedVectors dummyNamedVector(String name, int size) { + return namedVector(name, dummyEmbeddings(size)); + } } diff --git a/src/test/java/io/qdrant/client/QdrantClientPointsTest.java b/src/test/java/io/qdrant/client/QdrantClientPointsTest.java index ca37439b..678b77b5 100644 --- a/src/test/java/io/qdrant/client/QdrantClientPointsTest.java +++ b/src/test/java/io/qdrant/client/QdrantClientPointsTest.java @@ -1,3 +1,171 @@ package io.qdrant.client; -class QdrantClientPointsTest {} +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertEquals; + +import io.qdrant.client.grpc.Collections.Distance; +import io.qdrant.client.grpc.Points; +import io.qdrant.client.grpc.Points.Filter; +import io.qdrant.client.grpc.Points.PointId; +import io.qdrant.client.grpc.Points.PointStruct; +import io.qdrant.client.utils.FilterUtil; +import io.qdrant.client.utils.PayloadUtil; +import io.qdrant.client.utils.PointUtil; +import io.qdrant.client.utils.SelectorUtil; +import io.qdrant.client.utils.VectorUtil; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.UUID; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; + +class QdrantClientPointsTest { + + private static QdrantClient qdrantClient; + + @BeforeAll + static void setUp() throws Exception { + String qdrantUrl = System.getenv("QDRANT_URL"); + String apiKey = System.getenv("QDRANT_API_KEY"); + + if (qdrantUrl == null) { + qdrantUrl = "http://localhost:6334"; + } + + if (apiKey == null) { + qdrantClient = new QdrantClient(qdrantUrl); + } else { + qdrantClient = new QdrantClient(qdrantUrl, apiKey); + } + } + + @Test + void testPointsWithPayloadFilters() { + String collectionName = UUID.randomUUID().toString(); + + UUID pointID = UUID.randomUUID(); + + qdrantClient.createCollection(collectionName, 768, Distance.Cosine); + + PointId[] pointIds = new PointId[] {PointUtil.pointId(pointID)}; + + Points.GetResponse response = + qdrantClient.getPoints( + collectionName, + List.of(pointIds), + SelectorUtil.withVectors(), + SelectorUtil.withPayload(), + null); + + assertEquals(0, response.getResultCount()); + + Map data = new HashMap<>(); + data.put("name", "Anush"); + data.put("age", 32); + + Map nestedData = new HashMap<>(); + nestedData.put("color", "Blue"); + nestedData.put("movie", "Man Of Steel"); + + data.put("favourites", nestedData); + + PointStruct point = + PointUtil.point(pointID, VectorUtil.dummyVector(768), PayloadUtil.toPayload(data)); + + List points = List.of(point); + qdrantClient.upsertPointsBlocking(collectionName, points, null); + response = + qdrantClient.getPoints( + collectionName, + List.of(pointIds), + SelectorUtil.withVectors(), + SelectorUtil.withPayload(), + null); + assertEquals(1, response.getResultCount()); + + Filter filter = + FilterUtil.must( + FilterUtil.fieldCondition("age", FilterUtil.match(32)), + FilterUtil.fieldCondition("name", FilterUtil.match("Anush")), + FilterUtil.fieldCondition("favourites.color", FilterUtil.match("Blue")), + FilterUtil.fieldCondition("favourites.movie", FilterUtil.match("Man Of Steel"))); + qdrantClient.deletePointsBlocking(collectionName, SelectorUtil.filterSelector(filter), null); + + response = + qdrantClient.getPoints( + collectionName, + List.of(pointIds), + SelectorUtil.withVectors(), + SelectorUtil.withPayload(), + null); + + assertEquals(0, response.getResultCount()); + } + + @Test + void testUpsertPoints() { + String collectionName = UUID.randomUUID().toString(); + + UUID pointID = UUID.randomUUID(); + + qdrantClient.createCollection(collectionName, 768, Distance.Cosine); + + PointId[] pointIds = new PointId[] {PointUtil.pointId(pointID)}; + + Points.GetResponse response = + qdrantClient.getPoints( + collectionName, + List.of(pointIds), + SelectorUtil.withVectors(), + SelectorUtil.withPayload(), + null); + + assertEquals(0, response.getResultCount()); + + PointStruct point = PointUtil.point(pointID, VectorUtil.dummyVector(768), null); + List points = List.of(point); + qdrantClient.upsertPointsBlocking(collectionName, points, null); + response = + qdrantClient.getPoints( + collectionName, + List.of(pointIds), + SelectorUtil.withVectors(), + SelectorUtil.withPayload(), + null); + assertEquals(1, response.getResultCount()); + + qdrantClient.deletePointsBlocking(collectionName, SelectorUtil.idsSelector(pointIds), null); + + response = + qdrantClient.getPoints( + collectionName, + List.of(pointIds), + SelectorUtil.withVectors(), + SelectorUtil.withPayload(), + null); + + assertEquals(0, response.getResultCount()); + } + + @Test + void testUpsertPointsBatch() { + String collectionName = UUID.randomUUID().toString(); + + qdrantClient.createCollection(collectionName, 768, Distance.Cosine); + + List points = new ArrayList<>(); + + for (int i = 0; i < 1000; i++) { + UUID pointID = UUID.randomUUID(); + PointStruct point = PointUtil.point(pointID, VectorUtil.dummyVector(768), null); + points.add(point); + } + + assertDoesNotThrow( + () -> { + qdrantClient.upsertPointsBatchBlocking(collectionName, points, null, 100); + }); + } +} From 0ba3d5823779893215edbe4f6977c06c228c994f Mon Sep 17 00:00:00 2001 From: Anush008 <46051506+Anush008@users.noreply.github.com> Date: Sat, 2 Dec 2023 17:49:04 +0530 Subject: [PATCH 3/4] tests: points tests --- .../qdrant/client/QdrantClientPointsTest.java | 104 +++++++++++++++++- .../qdrant/client/utils/PayloadUtilTest.java | 12 +- 2 files changed, 104 insertions(+), 12 deletions(-) diff --git a/src/test/java/io/qdrant/client/QdrantClientPointsTest.java b/src/test/java/io/qdrant/client/QdrantClientPointsTest.java index 678b77b5..c8cd5639 100644 --- a/src/test/java/io/qdrant/client/QdrantClientPointsTest.java +++ b/src/test/java/io/qdrant/client/QdrantClientPointsTest.java @@ -4,10 +4,13 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import io.qdrant.client.grpc.Collections.Distance; +import io.qdrant.client.grpc.JsonWithInt.Value; import io.qdrant.client.grpc.Points; import io.qdrant.client.grpc.Points.Filter; import io.qdrant.client.grpc.Points.PointId; import io.qdrant.client.grpc.Points.PointStruct; +import io.qdrant.client.grpc.Points.SearchPoints; +import io.qdrant.client.grpc.Points.SearchResponse; import io.qdrant.client.utils.FilterUtil; import io.qdrant.client.utils.PayloadUtil; import io.qdrant.client.utils.PointUtil; @@ -23,6 +26,7 @@ class QdrantClientPointsTest { + private static final int EMBEDDINGS_SIZE = 768; private static QdrantClient qdrantClient; @BeforeAll @@ -47,7 +51,7 @@ void testPointsWithPayloadFilters() { UUID pointID = UUID.randomUUID(); - qdrantClient.createCollection(collectionName, 768, Distance.Cosine); + qdrantClient.createCollection(collectionName, EMBEDDINGS_SIZE, Distance.Cosine); PointId[] pointIds = new PointId[] {PointUtil.pointId(pointID)}; @@ -72,7 +76,8 @@ void testPointsWithPayloadFilters() { data.put("favourites", nestedData); PointStruct point = - PointUtil.point(pointID, VectorUtil.dummyVector(768), PayloadUtil.toPayload(data)); + PointUtil.point( + pointID, VectorUtil.dummyVector(EMBEDDINGS_SIZE), PayloadUtil.toPayload(data)); List points = List.of(point); qdrantClient.upsertPointsBlocking(collectionName, points, null); @@ -110,7 +115,7 @@ void testUpsertPoints() { UUID pointID = UUID.randomUUID(); - qdrantClient.createCollection(collectionName, 768, Distance.Cosine); + qdrantClient.createCollection(collectionName, EMBEDDINGS_SIZE, Distance.Cosine); PointId[] pointIds = new PointId[] {PointUtil.pointId(pointID)}; @@ -124,7 +129,7 @@ void testUpsertPoints() { assertEquals(0, response.getResultCount()); - PointStruct point = PointUtil.point(pointID, VectorUtil.dummyVector(768), null); + PointStruct point = PointUtil.point(pointID, VectorUtil.dummyVector(EMBEDDINGS_SIZE), null); List points = List.of(point); qdrantClient.upsertPointsBlocking(collectionName, points, null); response = @@ -153,13 +158,14 @@ void testUpsertPoints() { void testUpsertPointsBatch() { String collectionName = UUID.randomUUID().toString(); - qdrantClient.createCollection(collectionName, 768, Distance.Cosine); + qdrantClient.createCollection(collectionName, EMBEDDINGS_SIZE, Distance.Cosine); List points = new ArrayList<>(); + // Upsert 1000 points with batching for (int i = 0; i < 1000; i++) { UUID pointID = UUID.randomUUID(); - PointStruct point = PointUtil.point(pointID, VectorUtil.dummyVector(768), null); + PointStruct point = PointUtil.point(pointID, VectorUtil.dummyVector(EMBEDDINGS_SIZE), null); points.add(point); } @@ -168,4 +174,90 @@ void testUpsertPointsBatch() { qdrantClient.upsertPointsBatchBlocking(collectionName, points, null, 100); }); } + + @Test + void testSearchPoints() { + String collectionName = UUID.randomUUID().toString(); + + qdrantClient.createCollection(collectionName, EMBEDDINGS_SIZE, Distance.Cosine); + + List points = new ArrayList<>(); + + // Upsert 100 points + for (int i = 0; i < 100; i++) { + UUID pointID = UUID.randomUUID(); + PointStruct point = PointUtil.point(pointID, VectorUtil.dummyVector(EMBEDDINGS_SIZE), null); + points.add(point); + } + + assertDoesNotThrow( + () -> { + qdrantClient.upsertPointsBlocking(collectionName, points, null); + }); + + SearchPoints request = + SearchPoints.newBuilder() + .setCollectionName(collectionName) + .addAllVector(VectorUtil.dummyEmbeddings(EMBEDDINGS_SIZE)) + .setWithPayload(SelectorUtil.withPayload()) + .setLimit(100) + .build(); + + SearchResponse result = qdrantClient.searchPoints(request); + + assertEquals(result.getResultList().size(), 100); + } + + @Test + void testSetPayloadWithScroll() { + String collectionName = UUID.randomUUID().toString(); + + qdrantClient.createCollection(collectionName, EMBEDDINGS_SIZE, Distance.Cosine); + + List points = new ArrayList<>(); + + // Upsert 100 points + for (int i = 0; i < 100; i++) { + UUID pointID = UUID.randomUUID(); + PointStruct point = PointUtil.point(pointID, VectorUtil.dummyVector(EMBEDDINGS_SIZE), null); + points.add(point); + } + + assertDoesNotThrow( + () -> { + qdrantClient.upsertPointsBlocking(collectionName, points, null); + }); + + Map data = new HashMap<>(); + data.put("name", "Anush"); + data.put("age", 32); + + Map nestedData = new HashMap<>(); + nestedData.put("color", "Blue"); + nestedData.put("movie", "Man of Steel"); + + data.put("favourites", nestedData); + + Map payload = PayloadUtil.toPayload(data); + + qdrantClient.setPayloadBlocking( + collectionName, SelectorUtil.filterSelector(FilterUtil.must()), payload, null); + + Points.ScrollPoints request = + Points.ScrollPoints.newBuilder() + .setCollectionName(collectionName) + .setWithPayload(SelectorUtil.withPayload()) + .setLimit(100) + .build(); + + Points.ScrollResponse response = qdrantClient.scroll(request); + + response + .getResultList() + .forEach( + (point) -> { + assertEquals(PayloadUtil.toMap(point.getPayloadMap()), data); + assertEquals(point.getPayloadMap(), PayloadUtil.toPayload(data)); + }); + } } diff --git a/src/test/java/io/qdrant/client/utils/PayloadUtilTest.java b/src/test/java/io/qdrant/client/utils/PayloadUtilTest.java index da781e0a..30353f6a 100644 --- a/src/test/java/io/qdrant/client/utils/PayloadUtilTest.java +++ b/src/test/java/io/qdrant/client/utils/PayloadUtilTest.java @@ -49,11 +49,11 @@ void testToPayload() { } @Test - void testStructToHashMap() { + void testStructtoMap() { // Test case 1: Empty struct Struct.Builder structBuilder = Struct.newBuilder(); Struct struct = structBuilder.build(); - Map structMap = PayloadUtil.payloadStructToHashMap(struct); + Map structMap = PayloadUtil.toMap(struct); assertTrue(structMap.isEmpty()); // Test case 2: Struct with different value types @@ -61,24 +61,24 @@ void testStructToHashMap() { structBuilder.putFields("age", Value.newBuilder().setIntegerValue(52).build()); structBuilder.putFields("isStudent", Value.newBuilder().setBoolValue(true).build()); struct = structBuilder.build(); - structMap = PayloadUtil.payloadStructToHashMap(struct); + structMap = PayloadUtil.toMap(struct); assertEquals("Elon", structMap.get("name")); assertEquals(52, (int) structMap.get("age")); assertEquals(true, structMap.get("isStudent")); } @Test - void testToHashMap() { + void testtoMap() { // Test case 1: Empty payload Map payload = new HashMap<>(); - Map hashMap = PayloadUtil.toHashMap(payload); + Map hashMap = PayloadUtil.toMap(payload); assertTrue(hashMap.isEmpty()); // Test case 2: Payload with different value types payload.put("name", Value.newBuilder().setStringValue("Elon").build()); payload.put("age", Value.newBuilder().setIntegerValue(52).build()); payload.put("isStudent", Value.newBuilder().setBoolValue(true).build()); - hashMap = PayloadUtil.toHashMap(payload); + hashMap = PayloadUtil.toMap(payload); assertEquals("Elon", hashMap.get("name")); assertEquals(52, hashMap.get("age")); assertEquals(true, hashMap.get("isStudent")); From 15da834c62462a6df45b4ad67dac0e51a422ee92 Mon Sep 17 00:00:00 2001 From: Anush008 <46051506+Anush008@users.noreply.github.com> Date: Sat, 2 Dec 2023 17:49:30 +0530 Subject: [PATCH 4/4] refactor: Updatedpayload methods, utility methods --- .../java/io/qdrant/client/QdrantClient.java | 61 ++++++++++++++++++- .../io/qdrant/client/utils/PayloadUtil.java | 8 +-- .../io/qdrant/client/utils/VectorUtil.java | 7 ++- 3 files changed, 66 insertions(+), 10 deletions(-) diff --git a/src/main/java/io/qdrant/client/QdrantClient.java b/src/main/java/io/qdrant/client/QdrantClient.java index 9a927292..1eb32d79 100644 --- a/src/main/java/io/qdrant/client/QdrantClient.java +++ b/src/main/java/io/qdrant/client/QdrantClient.java @@ -542,6 +542,62 @@ public Points.PointsOperationResponse upsertPointsBatchBlocking( return upsertPointsBatch(collectionName, points, ordering, true, chunkSize); } + /** Internal update method */ + private Points.PointsOperationResponse setPayload( + String collectionName, + Points.PointsSelector points, + Map payload, + Points.WriteOrderingType ordering, + Boolean wait) { + Points.SetPayloadPoints.Builder request = + Points.SetPayloadPoints.newBuilder() + .setCollectionName(collectionName) + .setPointsSelector(points) + .putAllPayload(payload) + .setWait(wait); + + if (ordering != null) { + request.setOrdering(PointUtil.ordering(ordering)); + } + return pointsStub.setPayload(request.build()); + } + + /** + * Sets the payload of the specified points in a collection. Does not wait for the operation to + * complete before returning. + * + * @param collectionName The name of the collection. + * @param points The selector for the points to be updated. + * @param payload The new payload to be assigned to the points. + * @param ordering The ordering of the write operation. + * @return The response of the points operation. + */ + public Points.PointsOperationResponse setPayload( + String collectionName, + Points.PointsSelector points, + Map payload, + Points.WriteOrderingType ordering) { + return setPayload(collectionName, points, payload, ordering, false); + } + + /** + * Sets the payload of the specified points in a collection. Waits for the operation to complete + * before returning. + * + * @param collectionName The name of the collection. + * @param points The selector for the points to be updated. + * @param payload The new payload to be assigned to the points. + * @param ordering The ordering of the write operation. + * @return The response of the points operation. + */ + public Points.PointsOperationResponse setPayloadBlocking( + String collectionName, + Points.PointsSelector points, + Map payload, + Points.WriteOrderingType ordering) { + return setPayload(collectionName, points, payload, ordering, true); + } + /** Internal payload overwrite method */ private Points.PointsOperationResponse overwritePayload( String collectionName, @@ -559,7 +615,7 @@ private Points.PointsOperationResponse overwritePayload( if (ordering != null) { request.setOrdering(PointUtil.ordering(ordering)); } - return pointsStub.setPayload(request.build()); + return pointsStub.overwritePayload(request.build()); } /** @@ -598,7 +654,7 @@ public Points.PointsOperationResponse overwritePayloadBlocking( return overwritePayload(collectionName, points, payload, ordering, true); } - /** Internal payload update method */ + /** Internal payload delete method */ private Points.PointsOperationResponse deletePayload( String collectionName, Points.PointsSelector points, @@ -1186,7 +1242,6 @@ public SnapshotsService.CreateSnapshotResponse createFullSnapshot() { /** * Retrieves a list of full snapshots for a given collection. * - * @param collectionName The name of the collection. * @return The response containing the list of full snapshots. */ public SnapshotsService.ListSnapshotsResponse listFullSnapshots() { diff --git a/src/main/java/io/qdrant/client/utils/PayloadUtil.java b/src/main/java/io/qdrant/client/utils/PayloadUtil.java index 7c89eef5..01fa6938 100644 --- a/src/main/java/io/qdrant/client/utils/PayloadUtil.java +++ b/src/main/java/io/qdrant/client/utils/PayloadUtil.java @@ -49,8 +49,8 @@ public static Map toPayload(Map inputMap) { * @param struct The payload struct to convert. * @return The converted hash map. */ - public static Map payloadStructToHashMap(Struct struct) { - Map structMap = toHashMap(struct.getFieldsMap()); + public static Map toMap(Struct struct) { + Map structMap = toMap(struct.getFieldsMap()); return structMap; } @@ -60,7 +60,7 @@ public static Map payloadStructToHashMap(Struct struct) { * @param payload The payload map to convert. * @return The converted hash map. */ - public static Map toHashMap(Map payload) { + public static Map toMap(Map payload) { Map hashMap = new HashMap<>(); for (Map.Entry entry : payload.entrySet()) { String fieldName = entry.getKey(); @@ -144,7 +144,7 @@ static Object valueToObject(Value value) { } else if (value.hasNullValue()) { return null; } else if (value.hasStructValue()) { - return payloadStructToHashMap(value.getStructValue()); + return toMap(value.getStructValue()); } else if (value.hasListValue()) { return listValueToList(value.getListValue()); } diff --git a/src/main/java/io/qdrant/client/utils/VectorUtil.java b/src/main/java/io/qdrant/client/utils/VectorUtil.java index c6eed09c..d9657bf0 100644 --- a/src/main/java/io/qdrant/client/utils/VectorUtil.java +++ b/src/main/java/io/qdrant/client/utils/VectorUtil.java @@ -4,6 +4,7 @@ import io.qdrant.client.grpc.Points.PointVectors; import io.qdrant.client.grpc.Points.Vector; import io.qdrant.client.grpc.Points.Vectors; +import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.Random; @@ -92,16 +93,16 @@ public static PointVectors pointVectors(String id, String name, float... vector) * @return An array of floats representing the generated embeddings. * @throws IllegalArgumentException If the size is less than or equal to zero. */ - public static float[] dummyEmbeddings(int size) { + public static List dummyEmbeddings(int size) { if (size <= 0) { throw new IllegalArgumentException("Size must be greater than zero"); } - float[] embeddings = new float[size]; + List embeddings = new ArrayList<>(); Random random = new Random(); for (int i = 0; i < size; i++) { - embeddings[i] = random.nextFloat(); + embeddings.add(random.nextFloat()); } return embeddings;