Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

test: api tests #1

Merged
merged 4 commits into from
Dec 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 61 additions & 7 deletions src/main/java/io/qdrant/client/QdrantClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ public Collections.CollectionOperationResponse createCollection(
.setVectorsConfig(config)
.setCollectionName(collectionName)
.build();
return collectionsStub.create(details);
return createCollection(details);
}

/**
Expand All @@ -216,7 +216,6 @@ public Collections.CollectionOperationResponse createCollection(
public Collections.CollectionOperationResponse recreateCollection(
String collectionName, long vectorSize, Collections.Distance distance) {

deleteCollection(collectionName);
Collections.VectorParams.Builder params =
Collections.VectorParams.newBuilder().setDistance(distance).setSize(vectorSize);
Collections.VectorsConfig config =
Expand All @@ -226,7 +225,7 @@ public Collections.CollectionOperationResponse recreateCollection(
.setVectorsConfig(config)
.setCollectionName(collectionName)
.build();
return collectionsStub.create(details);
return recreateCollection(details);
}

/**
Expand Down Expand Up @@ -543,6 +542,62 @@ public Points.PointsOperationResponse upsertPointsBatchBlocking(
return upsertPointsBatch(collectionName, points, ordering, true, chunkSize);
}

/** Internal update method */
private Points.PointsOperationResponse setPayload(
String collectionName,
Points.PointsSelector points,
Map<String, Value> payload,
Points.WriteOrderingType ordering,
Boolean wait) {
Points.SetPayloadPoints.Builder request =
Points.SetPayloadPoints.newBuilder()
.setCollectionName(collectionName)
.setPointsSelector(points)
.putAllPayload(payload)
.setWait(wait);

if (ordering != null) {
request.setOrdering(PointUtil.ordering(ordering));
}
return pointsStub.setPayload(request.build());
}

/**
* Sets the payload of the specified points in a collection. Does not wait for the operation to
* complete before returning.
*
* @param collectionName The name of the collection.
* @param points The selector for the points to be updated.
* @param payload The new payload to be assigned to the points.
* @param ordering The ordering of the write operation.
* @return The response of the points operation.
*/
public Points.PointsOperationResponse setPayload(
String collectionName,
Points.PointsSelector points,
Map<String, Value> payload,
Points.WriteOrderingType ordering) {
return setPayload(collectionName, points, payload, ordering, false);
}

/**
* Sets the payload of the specified points in a collection. Waits for the operation to complete
* before returning.
*
* @param collectionName The name of the collection.
* @param points The selector for the points to be updated.
* @param payload The new payload to be assigned to the points.
* @param ordering The ordering of the write operation.
* @return The response of the points operation.
*/
public Points.PointsOperationResponse setPayloadBlocking(
String collectionName,
Points.PointsSelector points,
Map<String, Value> payload,
Points.WriteOrderingType ordering) {
return setPayload(collectionName, points, payload, ordering, true);
}

/** Internal payload overwrite method */
private Points.PointsOperationResponse overwritePayload(
String collectionName,
Expand All @@ -560,7 +615,7 @@ private Points.PointsOperationResponse overwritePayload(
if (ordering != null) {
request.setOrdering(PointUtil.ordering(ordering));
}
return pointsStub.setPayload(request.build());
return pointsStub.overwritePayload(request.build());
}

/**
Expand Down Expand Up @@ -599,7 +654,7 @@ public Points.PointsOperationResponse overwritePayloadBlocking(
return overwritePayload(collectionName, points, payload, ordering, true);
}

/** Internal payload update method */
/** Internal payload delete method */
private Points.PointsOperationResponse deletePayload(
String collectionName,
Points.PointsSelector points,
Expand Down Expand Up @@ -1187,10 +1242,9 @@ public SnapshotsService.CreateSnapshotResponse createFullSnapshot() {
/**
* Retrieves a list of full snapshots for a given collection.
*
* @param collectionName The name of the collection.
* @return The response containing the list of full snapshots.
*/
public SnapshotsService.ListSnapshotsResponse listFullSnapshots(String collectionName) {
public SnapshotsService.ListSnapshotsResponse listFullSnapshots() {
SnapshotsService.ListFullSnapshotsRequest request =
SnapshotsService.ListFullSnapshotsRequest.newBuilder().build();
return snapshotStub.listFull(request);
Expand Down
8 changes: 4 additions & 4 deletions src/main/java/io/qdrant/client/utils/PayloadUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ public static Map<String, Value> toPayload(Map<String, Object> inputMap) {
* @param struct The payload struct to convert.
* @return The converted hash map.
*/
public static Map<String, Object> payloadStructToHashMap(Struct struct) {
Map<String, Object> structMap = toHashMap(struct.getFieldsMap());
public static Map<String, Object> toMap(Struct struct) {
Map<String, Object> structMap = toMap(struct.getFieldsMap());
return structMap;
}

Expand All @@ -60,7 +60,7 @@ public static Map<String, Object> payloadStructToHashMap(Struct struct) {
* @param payload The payload map to convert.
* @return The converted hash map.
*/
public static Map<String, Object> toHashMap(Map<String, Value> payload) {
public static Map<String, Object> toMap(Map<String, Value> payload) {
Map<String, Object> hashMap = new HashMap<>();
for (Map.Entry<String, Value> entry : payload.entrySet()) {
String fieldName = entry.getKey();
Expand Down Expand Up @@ -144,7 +144,7 @@ static Object valueToObject(Value value) {
} else if (value.hasNullValue()) {
return null;
} else if (value.hasStructValue()) {
return payloadStructToHashMap(value.getStructValue());
return toMap(value.getStructValue());
} else if (value.hasListValue()) {
return listValueToList(value.getListValue());
}
Expand Down
45 changes: 45 additions & 0 deletions src/main/java/io/qdrant/client/utils/VectorUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
import io.qdrant.client.grpc.Points.PointVectors;
import io.qdrant.client.grpc.Points.Vector;
import io.qdrant.client.grpc.Points.Vectors;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Random;

/** Utility class for working with vector data. */
public class VectorUtil {
Expand Down Expand Up @@ -83,4 +85,47 @@ public static PointVectors pointVectors(String id, String name, float... vector)
Vectors vectors = Vectors.newBuilder().setVectors(namedVector(name, vector)).build();
return pointVectorsBuilder.setId(PointUtil.pointId(id)).setVectors(vectors).build();
}

/**
* Generates dummy embeddings of the specified size.
*
* @param size The size of the embeddings to generate.
* @return An array of floats representing the generated embeddings.
* @throws IllegalArgumentException If the size is less than or equal to zero.
*/
public static List<Float> dummyEmbeddings(int size) {
if (size <= 0) {
throw new IllegalArgumentException("Size must be greater than zero");
}

List<Float> embeddings = new ArrayList<>();
Random random = new Random();

for (int i = 0; i < size; i++) {
embeddings.add(random.nextFloat());
}

return embeddings;
}

/**
* Generates a dummy vector of the specified size.
*
* @param size The size of the vector.
* @return The generated dummy vector.
*/
public static Vector dummyVector(int size) {
return toVector(dummyEmbeddings(size));
}

/**
* Generates a dummy named vector of the specified size.
*
* @param name The name of the vector.
* @param size The size of the vector.
* @return The generated dummy vector.
*/
public static NamedVectors dummyNamedVector(String name, int size) {
return namedVector(name, dummyEmbeddings(size));
}
}
158 changes: 158 additions & 0 deletions src/test/java/io/qdrant/client/QdrantClientCollectionTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
package io.qdrant.client;

import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;

import io.grpc.StatusRuntimeException;
import io.qdrant.client.grpc.Collections;
import io.qdrant.client.grpc.Collections.Distance;
import java.util.UUID;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;

class QdrantClientCollectionTest {

private static QdrantClient qdrantClient;

@BeforeAll
static void setUp() throws Exception {
String qdrantUrl = System.getenv("QDRANT_URL");
String apiKey = System.getenv("QDRANT_API_KEY");

if (qdrantUrl == null) {
qdrantUrl = "http://localhost:6334";
}

if (apiKey == null) {
qdrantClient = new QdrantClient(qdrantUrl);
} else {
qdrantClient = new QdrantClient(qdrantUrl, apiKey);
}
}

@Test
void testAliasOperations() {
String collectionName = UUID.randomUUID().toString();
String aliasName = UUID.randomUUID().toString();

assertThrows(
StatusRuntimeException.class,
() -> {
// This should fail as collection does not exist
qdrantClient.createAlias(collectionName, aliasName);
});

qdrantClient.createCollection(collectionName, 6, Distance.Euclid);
assertDoesNotThrow(
() -> {
Collections.CollectionOperationResponse response =
qdrantClient.createAlias(collectionName, aliasName);
assertTrue(response.getResult());
});

Collections.ListAliasesResponse response = qdrantClient.listAliases();
assertTrue(response.getAliasesCount() == 1);
Collections.AliasDescription alias = response.getAliasesList().get(0);
assertTrue(alias.getCollectionName().equals(collectionName));
assertTrue(alias.getAliasName().equals(aliasName));

String newAliasName = UUID.randomUUID().toString();
qdrantClient.renameAlias(aliasName, newAliasName);
response = qdrantClient.listAliases();
assertTrue(response.getAliasesCount() == 1);

alias = response.getAliasesList().get(0);
assertTrue(alias.getCollectionName().equals(collectionName));
assertTrue(alias.getAliasName().equals(newAliasName));

qdrantClient.deleteAlias(newAliasName);
response = qdrantClient.listAliases();
}

@Test
void testListCollections() {
assertDoesNotThrow(
() -> {
Collections.ListCollectionsResponse response = qdrantClient.listCollections();
assertTrue(response.getCollectionsCount() >= 0);
});
}

@Test
void testHasCollection() {
String collectionName = UUID.randomUUID().toString();
boolean exists = qdrantClient.hasCollection(collectionName);
assertFalse(exists);

qdrantClient.createCollection(collectionName, 6, Distance.Euclid);

exists = qdrantClient.hasCollection(collectionName);
assertTrue(exists);
}

@Test
void testCollectionConfigOperations() {
long vectorSize = 128;
String collectionName = UUID.randomUUID().toString();
Collections.Distance distance = Collections.Distance.Cosine;

Collections.VectorParams.Builder params =
Collections.VectorParams.newBuilder().setDistance(distance).setSize(vectorSize);

Collections.VectorsConfig config =
Collections.VectorsConfig.newBuilder().setParams(params).build();

Collections.HnswConfigDiff hnsw =
Collections.HnswConfigDiff.newBuilder().setM(16).setEfConstruct(200).build();

Collections.CreateCollection details =
Collections.CreateCollection.newBuilder()
.setVectorsConfig(config)
.setHnswConfig(hnsw)
.setCollectionName(collectionName)
.build();

Collections.CollectionOperationResponse response = qdrantClient.createCollection(details);
assertTrue(response.getResult());

Collections.GetCollectionInfoResponse info = qdrantClient.getCollectionInfo(collectionName);
assertTrue(info.getResult().getConfig().getHnswConfig().getM() == 16);

Collections.UpdateCollection updateCollection =
Collections.UpdateCollection.newBuilder()
.setCollectionName(collectionName)
.setHnswConfig(Collections.HnswConfigDiff.newBuilder().setM(32).build())
.build();
qdrantClient.updateCollection(updateCollection);

info = qdrantClient.getCollectionInfo(collectionName);
assertTrue(info.getResult().getConfig().getHnswConfig().getM() == 32);
}

@Test
void testRecreateCollection() {
String collectionName = UUID.randomUUID().toString();

qdrantClient.createCollection(collectionName, 6, Distance.Euclid);
assertDoesNotThrow(
() -> {
qdrantClient.recreateCollection(collectionName, 12, Distance.Dot);
});
}

@Test
void testDeleteCollection() {
String collectionName = UUID.randomUUID().toString();

Collections.CollectionOperationResponse response =
qdrantClient.deleteCollection(collectionName);
assertFalse(response.getResult());

qdrantClient.createCollection(collectionName, 6, Distance.Euclid);

response = qdrantClient.deleteCollection(collectionName);
assertTrue(response.getResult());
}
}
Loading