Skip to content

Commit

Permalink
test: api tests (#1)
Browse files Browse the repository at this point in the history
* test: api tests

* test: pointClient

* tests: points tests

* refactor: Updatedpayload methods, utility methods
  • Loading branch information
Anush008 authored Dec 2, 2023
1 parent 9dd6a9d commit 9fd3840
Show file tree
Hide file tree
Showing 8 changed files with 655 additions and 17 deletions.
68 changes: 61 additions & 7 deletions src/main/java/io/qdrant/client/QdrantClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ public Collections.CollectionOperationResponse createCollection(
.setVectorsConfig(config)
.setCollectionName(collectionName)
.build();
return collectionsStub.create(details);
return createCollection(details);
}

/**
Expand All @@ -216,7 +216,6 @@ public Collections.CollectionOperationResponse createCollection(
public Collections.CollectionOperationResponse recreateCollection(
String collectionName, long vectorSize, Collections.Distance distance) {

deleteCollection(collectionName);
Collections.VectorParams.Builder params =
Collections.VectorParams.newBuilder().setDistance(distance).setSize(vectorSize);
Collections.VectorsConfig config =
Expand All @@ -226,7 +225,7 @@ public Collections.CollectionOperationResponse recreateCollection(
.setVectorsConfig(config)
.setCollectionName(collectionName)
.build();
return collectionsStub.create(details);
return recreateCollection(details);
}

/**
Expand Down Expand Up @@ -543,6 +542,62 @@ public Points.PointsOperationResponse upsertPointsBatchBlocking(
return upsertPointsBatch(collectionName, points, ordering, true, chunkSize);
}

/** Internal update method */
private Points.PointsOperationResponse setPayload(
String collectionName,
Points.PointsSelector points,
Map<String, Value> payload,
Points.WriteOrderingType ordering,
Boolean wait) {
Points.SetPayloadPoints.Builder request =
Points.SetPayloadPoints.newBuilder()
.setCollectionName(collectionName)
.setPointsSelector(points)
.putAllPayload(payload)
.setWait(wait);

if (ordering != null) {
request.setOrdering(PointUtil.ordering(ordering));
}
return pointsStub.setPayload(request.build());
}

/**
* Sets the payload of the specified points in a collection. Does not wait for the operation to
* complete before returning.
*
* @param collectionName The name of the collection.
* @param points The selector for the points to be updated.
* @param payload The new payload to be assigned to the points.
* @param ordering The ordering of the write operation.
* @return The response of the points operation.
*/
public Points.PointsOperationResponse setPayload(
String collectionName,
Points.PointsSelector points,
Map<String, Value> payload,
Points.WriteOrderingType ordering) {
return setPayload(collectionName, points, payload, ordering, false);
}

/**
* Sets the payload of the specified points in a collection. Waits for the operation to complete
* before returning.
*
* @param collectionName The name of the collection.
* @param points The selector for the points to be updated.
* @param payload The new payload to be assigned to the points.
* @param ordering The ordering of the write operation.
* @return The response of the points operation.
*/
public Points.PointsOperationResponse setPayloadBlocking(
String collectionName,
Points.PointsSelector points,
Map<String, Value> payload,
Points.WriteOrderingType ordering) {
return setPayload(collectionName, points, payload, ordering, true);
}

/** Internal payload overwrite method */
private Points.PointsOperationResponse overwritePayload(
String collectionName,
Expand All @@ -560,7 +615,7 @@ private Points.PointsOperationResponse overwritePayload(
if (ordering != null) {
request.setOrdering(PointUtil.ordering(ordering));
}
return pointsStub.setPayload(request.build());
return pointsStub.overwritePayload(request.build());
}

/**
Expand Down Expand Up @@ -599,7 +654,7 @@ public Points.PointsOperationResponse overwritePayloadBlocking(
return overwritePayload(collectionName, points, payload, ordering, true);
}

/** Internal payload update method */
/** Internal payload delete method */
private Points.PointsOperationResponse deletePayload(
String collectionName,
Points.PointsSelector points,
Expand Down Expand Up @@ -1187,10 +1242,9 @@ public SnapshotsService.CreateSnapshotResponse createFullSnapshot() {
/**
* Retrieves a list of full snapshots for a given collection.
*
* @param collectionName The name of the collection.
* @return The response containing the list of full snapshots.
*/
public SnapshotsService.ListSnapshotsResponse listFullSnapshots(String collectionName) {
public SnapshotsService.ListSnapshotsResponse listFullSnapshots() {
SnapshotsService.ListFullSnapshotsRequest request =
SnapshotsService.ListFullSnapshotsRequest.newBuilder().build();
return snapshotStub.listFull(request);
Expand Down
8 changes: 4 additions & 4 deletions src/main/java/io/qdrant/client/utils/PayloadUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ public static Map<String, Value> toPayload(Map<String, Object> inputMap) {
* @param struct The payload struct to convert.
* @return The converted hash map.
*/
public static Map<String, Object> payloadStructToHashMap(Struct struct) {
Map<String, Object> structMap = toHashMap(struct.getFieldsMap());
public static Map<String, Object> toMap(Struct struct) {
Map<String, Object> structMap = toMap(struct.getFieldsMap());
return structMap;
}

Expand All @@ -60,7 +60,7 @@ public static Map<String, Object> payloadStructToHashMap(Struct struct) {
* @param payload The payload map to convert.
* @return The converted hash map.
*/
public static Map<String, Object> toHashMap(Map<String, Value> payload) {
public static Map<String, Object> toMap(Map<String, Value> payload) {
Map<String, Object> hashMap = new HashMap<>();
for (Map.Entry<String, Value> entry : payload.entrySet()) {
String fieldName = entry.getKey();
Expand Down Expand Up @@ -144,7 +144,7 @@ static Object valueToObject(Value value) {
} else if (value.hasNullValue()) {
return null;
} else if (value.hasStructValue()) {
return payloadStructToHashMap(value.getStructValue());
return toMap(value.getStructValue());
} else if (value.hasListValue()) {
return listValueToList(value.getListValue());
}
Expand Down
45 changes: 45 additions & 0 deletions src/main/java/io/qdrant/client/utils/VectorUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
import io.qdrant.client.grpc.Points.PointVectors;
import io.qdrant.client.grpc.Points.Vector;
import io.qdrant.client.grpc.Points.Vectors;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Random;

/** Utility class for working with vector data. */
public class VectorUtil {
Expand Down Expand Up @@ -83,4 +85,47 @@ public static PointVectors pointVectors(String id, String name, float... vector)
Vectors vectors = Vectors.newBuilder().setVectors(namedVector(name, vector)).build();
return pointVectorsBuilder.setId(PointUtil.pointId(id)).setVectors(vectors).build();
}

/**
* Generates dummy embeddings of the specified size.
*
* @param size The size of the embeddings to generate.
* @return An array of floats representing the generated embeddings.
* @throws IllegalArgumentException If the size is less than or equal to zero.
*/
public static List<Float> dummyEmbeddings(int size) {
if (size <= 0) {
throw new IllegalArgumentException("Size must be greater than zero");
}

List<Float> embeddings = new ArrayList<>();
Random random = new Random();

for (int i = 0; i < size; i++) {
embeddings.add(random.nextFloat());
}

return embeddings;
}

/**
* Generates a dummy vector of the specified size.
*
* @param size The size of the vector.
* @return The generated dummy vector.
*/
public static Vector dummyVector(int size) {
return toVector(dummyEmbeddings(size));
}

/**
* Generates a dummy named vector of the specified size.
*
* @param name The name of the vector.
* @param size The size of the vector.
* @return The generated dummy vector.
*/
public static NamedVectors dummyNamedVector(String name, int size) {
return namedVector(name, dummyEmbeddings(size));
}
}
158 changes: 158 additions & 0 deletions src/test/java/io/qdrant/client/QdrantClientCollectionTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
package io.qdrant.client;

import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;

import io.grpc.StatusRuntimeException;
import io.qdrant.client.grpc.Collections;
import io.qdrant.client.grpc.Collections.Distance;
import java.util.UUID;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;

class QdrantClientCollectionTest {

private static QdrantClient qdrantClient;

@BeforeAll
static void setUp() throws Exception {
String qdrantUrl = System.getenv("QDRANT_URL");
String apiKey = System.getenv("QDRANT_API_KEY");

if (qdrantUrl == null) {
qdrantUrl = "http://localhost:6334";
}

if (apiKey == null) {
qdrantClient = new QdrantClient(qdrantUrl);
} else {
qdrantClient = new QdrantClient(qdrantUrl, apiKey);
}
}

@Test
void testAliasOperations() {
String collectionName = UUID.randomUUID().toString();
String aliasName = UUID.randomUUID().toString();

assertThrows(
StatusRuntimeException.class,
() -> {
// This should fail as collection does not exist
qdrantClient.createAlias(collectionName, aliasName);
});

qdrantClient.createCollection(collectionName, 6, Distance.Euclid);
assertDoesNotThrow(
() -> {
Collections.CollectionOperationResponse response =
qdrantClient.createAlias(collectionName, aliasName);
assertTrue(response.getResult());
});

Collections.ListAliasesResponse response = qdrantClient.listAliases();
assertTrue(response.getAliasesCount() == 1);
Collections.AliasDescription alias = response.getAliasesList().get(0);
assertTrue(alias.getCollectionName().equals(collectionName));
assertTrue(alias.getAliasName().equals(aliasName));

String newAliasName = UUID.randomUUID().toString();
qdrantClient.renameAlias(aliasName, newAliasName);
response = qdrantClient.listAliases();
assertTrue(response.getAliasesCount() == 1);

alias = response.getAliasesList().get(0);
assertTrue(alias.getCollectionName().equals(collectionName));
assertTrue(alias.getAliasName().equals(newAliasName));

qdrantClient.deleteAlias(newAliasName);
response = qdrantClient.listAliases();
}

@Test
void testListCollections() {
assertDoesNotThrow(
() -> {
Collections.ListCollectionsResponse response = qdrantClient.listCollections();
assertTrue(response.getCollectionsCount() >= 0);
});
}

@Test
void testHasCollection() {
String collectionName = UUID.randomUUID().toString();
boolean exists = qdrantClient.hasCollection(collectionName);
assertFalse(exists);

qdrantClient.createCollection(collectionName, 6, Distance.Euclid);

exists = qdrantClient.hasCollection(collectionName);
assertTrue(exists);
}

@Test
void testCollectionConfigOperations() {
long vectorSize = 128;
String collectionName = UUID.randomUUID().toString();
Collections.Distance distance = Collections.Distance.Cosine;

Collections.VectorParams.Builder params =
Collections.VectorParams.newBuilder().setDistance(distance).setSize(vectorSize);

Collections.VectorsConfig config =
Collections.VectorsConfig.newBuilder().setParams(params).build();

Collections.HnswConfigDiff hnsw =
Collections.HnswConfigDiff.newBuilder().setM(16).setEfConstruct(200).build();

Collections.CreateCollection details =
Collections.CreateCollection.newBuilder()
.setVectorsConfig(config)
.setHnswConfig(hnsw)
.setCollectionName(collectionName)
.build();

Collections.CollectionOperationResponse response = qdrantClient.createCollection(details);
assertTrue(response.getResult());

Collections.GetCollectionInfoResponse info = qdrantClient.getCollectionInfo(collectionName);
assertTrue(info.getResult().getConfig().getHnswConfig().getM() == 16);

Collections.UpdateCollection updateCollection =
Collections.UpdateCollection.newBuilder()
.setCollectionName(collectionName)
.setHnswConfig(Collections.HnswConfigDiff.newBuilder().setM(32).build())
.build();
qdrantClient.updateCollection(updateCollection);

info = qdrantClient.getCollectionInfo(collectionName);
assertTrue(info.getResult().getConfig().getHnswConfig().getM() == 32);
}

@Test
void testRecreateCollection() {
String collectionName = UUID.randomUUID().toString();

qdrantClient.createCollection(collectionName, 6, Distance.Euclid);
assertDoesNotThrow(
() -> {
qdrantClient.recreateCollection(collectionName, 12, Distance.Dot);
});
}

@Test
void testDeleteCollection() {
String collectionName = UUID.randomUUID().toString();

Collections.CollectionOperationResponse response =
qdrantClient.deleteCollection(collectionName);
assertFalse(response.getResult());

qdrantClient.createCollection(collectionName, 6, Distance.Euclid);

response = qdrantClient.deleteCollection(collectionName);
assertTrue(response.getResult());
}
}
Loading

0 comments on commit 9fd3840

Please sign in to comment.