Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Batch update points #16

Merged
merged 9 commits into from
Jan 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 37 additions & 32 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,25 +34,26 @@ To install the library, add the following lines to your build config file.
<dependency>
<groupId>io.qdrant</groupId>
<artifactId>client</artifactId>
<version>1.7.0</version>
<version>1.7.1</version>
</dependency>
```

#### Scala SBT
#### SBT

```sbt
libraryDependencies += "io.qdrant" % "client" % "1.7.0"
libraryDependencies += "io.qdrant" % "client" % "1.7.1"
```

#### Gradle

```gradle
implementation 'io.qdrant:client:1.7.0'
implementation 'io.qdrant:client:1.7.1'
```

## 📖 Documentation

- [`QdrantClient` Reference](https://qdrant.github.io/java-client/io/qdrant/client/QdrantClient.html#constructor-detail)
- [JavaDoc Reference](https://qdrant.github.io/java-client/)
- Usage examples are available throughout the [Qdrant documentation](https://qdrant.tech/documentation/quick-start/)

## 🔌 Getting started

Expand Down Expand Up @@ -125,39 +126,43 @@ Insert vectors into a collection
// import static convenience methods
import static io.qdrant.client.PointIdFactory.id;
import static io.qdrant.client.ValueFactory.value;
import static io.qdrant.client.VectorsFactory.vector;

Random random = new Random();
List<PointStruct> points = IntStream.range(1, 101)
.mapToObj(i -> PointStruct.newBuilder()
.setId(id(i))
.setVectors(vector(IntStream.range(1, 101)
.mapToObj(v -> random.nextFloat())
.collect(Collectors.toList())))
.putAllPayload(ImmutableMap.of(
"color", value("red"),
"rand_number", value(i % 10))
)
.build()
)
.collect(Collectors.toList());
import static io.qdrant.client.VectorsFactory.vectors;

List<PointStruct> points =
List.of(
PointStruct.newBuilder()
.setId(id(1))
.setVectors(vectors(0.32f, 0.52f, 0.21f, 0.52f))
.putAllPayload(
Map.of(
"color", value("red"),
"rand_number", value(32)))
.build(),
PointStruct.newBuilder()
.setId(id(2))
.setVectors(vectors(0.42f, 0.52f, 0.67f, 0.632f))
.putAllPayload(
Map.of(
"color", value("black"),
"rand_number", value(53),
"extra_field", value(true)))
.build());

UpdateResult updateResult = client.upsertAsync("my_collection", points).get();
```

Search for similar vectors

```java
List<Float> queryVector = IntStream.range(1, 101)
.mapToObj(v -> random.nextFloat())
.collect(Collectors.toList());

List<ScoredPoint> points = client.searchAsync(SearchPoints.newBuilder()
.setCollectionName("my_collection")
.addAllVector(queryVector)
.setLimit(5)
.build()
).get();
List<ScoredPoint> anush =
client
.searchAsync(
SearchPoints.newBuilder()
.setCollectionName("my_collection")
.addAllVector(List.of(0.6235f, 0.123f, 0.532f, 0.123f))
.setLimit(5)
.build())
.get();
```

Search for similar vectors with filtering condition
Expand All @@ -168,7 +173,7 @@ import static io.qdrant.client.ConditionFactory.range;

List<ScoredPoint> points = client.searchAsync(SearchPoints.newBuilder()
.setCollectionName("my_collection")
.addAllVector(queryVector)
.addAllVector(List.of(0.6235f, 0.123f, 0.532f, 0.123f))
.setFilter(Filter.newBuilder()
.addMust(range("rand_number", Range.newBuilder().setGte(3).build()))
.build())
Expand Down
4 changes: 2 additions & 2 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ plugins {
id 'maven-publish'

id 'com.google.protobuf' version '0.9.4'
id "net.ltgt.errorprone" version '3.1.0'
id 'io.github.gradle-nexus.publish-plugin' version "1.3.0"
id 'net.ltgt.errorprone' version '3.1.0'
id 'io.github.gradle-nexus.publish-plugin' version '1.3.0'
}

group = 'io.qdrant'
Expand Down
2 changes: 1 addition & 1 deletion gradle.properties
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@ qdrantProtosVersion=v1.7.0
qdrantVersion=v1.7.0

# The version of the client to generate
packageVersion=1.7.0
packageVersion=1.7.1
71 changes: 68 additions & 3 deletions src/main/java/io/qdrant/client/QdrantClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@
import static io.qdrant.client.grpc.Points.DiscoverBatchResponse;
import static io.qdrant.client.grpc.Points.DiscoverPoints;
import static io.qdrant.client.grpc.Points.DiscoverResponse;
import static io.qdrant.client.grpc.Points.PointsUpdateOperation;
import static io.qdrant.client.grpc.Points.UpdateBatchPoints;
import static io.qdrant.client.grpc.Points.UpdateBatchResponse;
import static io.qdrant.client.grpc.Collections.GetCollectionInfoRequest;
import static io.qdrant.client.grpc.Collections.GetCollectionInfoResponse;
import static io.qdrant.client.grpc.Collections.ListAliasesRequest;
Expand Down Expand Up @@ -1551,7 +1554,7 @@ public ListenableFuture<UpdateResult> overwritePayloadAsync(
}

/**
* Overwrites the payload for the given ids.
* Overwrites the payload for the filtered points.
*
* @param collectionName The name of the collection.
* @param payload New payload values
Expand Down Expand Up @@ -1696,7 +1699,7 @@ public ListenableFuture<UpdateResult> deletePayloadAsync(
}

/**
* Delete specified key payload for the given ids.
* Delete specified key payload for the filtered points.
*
* @param collectionName The name of the collection.
* @param keys List of keys to delete.
Expand Down Expand Up @@ -1832,7 +1835,7 @@ public ListenableFuture<UpdateResult> clearPayloadAsync(
}

/**
* Removes all payload for the given ids.
* Removes all payload for the filtered points.
*
* @param collectionName The name of the collection.
* @param filter A filter selecting the points for which to remove the payload.
Expand Down Expand Up @@ -2204,6 +2207,68 @@ public ListenableFuture<List<BatchResult>> recommendBatchAsync(
MoreExecutors.directExecutor());
}

/**
* Performs a batch update of points.
*
* @param collectionName The name of the collection.
* @param operations The list of point update operations.
*
* @return a new instance of {@link ListenableFuture}
*/
public ListenableFuture<List<UpdateResult>> batchUpdateAsync(String collectionName, List<PointsUpdateOperation> operations) {
return batchUpdateAsync(collectionName, operations, null, null, null);
}

/**
* Performs a batch update of points.
*
* @param collectionName The name of the collection.
* @param operations The list of point update operations.
* @param wait Whether to wait until the changes have been applied. Defaults to <code>true</code>.
* @param ordering Write ordering guarantees.
* @param timeout The timeout for the call.
*
* @return a new instance of {@link ListenableFuture}
*/
public ListenableFuture<List<UpdateResult>> batchUpdateAsync(
String collectionName,
List<PointsUpdateOperation> operations,
@Nullable Boolean wait,
@Nullable WriteOrdering ordering,
@Nullable Duration timeout) {

UpdateBatchPoints.Builder requestBuilder = UpdateBatchPoints.newBuilder()
.setCollectionName(collectionName)
.addAllOperations(operations)
.setWait(wait == null || wait);

if (ordering != null) {
requestBuilder.setOrdering(ordering);
}
return batchUpdateAsync(requestBuilder.build(), timeout);
}


/**
* Performs a batch update of points.
*
* @param request The update batch request.
* @param timeout The timeout for the call.
*
* @return a new instance of {@link ListenableFuture}
*/
public ListenableFuture<List<UpdateResult>> batchUpdateAsync(UpdateBatchPoints request, @Nullable Duration timeout) {
String collectionName = request.getCollectionName();
Preconditions.checkArgument(!collectionName.isEmpty(), "Collection name must not be empty");
logger.debug("Batch update points on '{}'", collectionName);
ListenableFuture<UpdateBatchResponse> future = getPoints(timeout).updateBatch(request);
addLogFailureCallback(future, "Batch update points");
return Futures.transform(
future,
UpdateBatchResponse::getResultList,
MoreExecutors.directExecutor());
}

/**
* Look for the points which are closer to stored positive examples and at the same time further to negative
* examples, grouped by a given field
Expand Down
7 changes: 4 additions & 3 deletions src/main/java/io/qdrant/client/VectorsFactory.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import static io.qdrant.client.VectorFactory.vector;
import static io.qdrant.client.grpc.Points.NamedVectors;
import io.qdrant.client.grpc.Points.Vector;
import static io.qdrant.client.grpc.Points.Vectors;

/**
Expand All @@ -18,13 +19,13 @@ private VectorsFactory() {

/**
* Creates named vectors
* @param values A map of vector names to values
* @param values A map of vector names to {@link Vector}
* @return a new instance of {@link Vectors}
*/
public static Vectors namedVectors(Map<String, List<Float>> values) {
public static Vectors namedVectors(Map<String, Vector> values) {
Anush008 marked this conversation as resolved.
Show resolved Hide resolved
return Vectors.newBuilder()
.setVectors(NamedVectors.newBuilder()
.putAllVectors(Maps.transformValues(values, v -> vector(v)))
.putAllVectors(values)
)
.build();
}
Expand Down
30 changes: 30 additions & 0 deletions src/test/java/io/qdrant/client/PointsTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,13 @@
import org.testcontainers.shaded.com.google.common.collect.ImmutableSet;
import io.qdrant.client.container.QdrantContainer;
import io.qdrant.client.grpc.Points.DiscoverPoints;
import io.qdrant.client.grpc.Points.PointVectors;
import io.qdrant.client.grpc.Points.PointsIdsList;
import io.qdrant.client.grpc.Points.PointsSelector;
import io.qdrant.client.grpc.Points.PointsUpdateOperation;
import io.qdrant.client.grpc.Points.UpdateBatchResponse;
import io.qdrant.client.grpc.Points.PointsUpdateOperation.ClearPayload;
import io.qdrant.client.grpc.Points.PointsUpdateOperation.UpdateVectors;

import java.util.List;
import java.util.concurrent.ExecutionException;
Expand Down Expand Up @@ -49,6 +56,7 @@
import static io.qdrant.client.TargetVectorFactory.targetVector;
import static io.qdrant.client.ValueFactory.value;
import static io.qdrant.client.VectorFactory.vector;
import static io.qdrant.client.VectorsFactory.vectors;

@Testcontainers
class PointsTest {
Expand Down Expand Up @@ -540,6 +548,28 @@ public void delete_by_filter() throws ExecutionException, InterruptedException {
assertEquals(0, points.size());
}

@Test
public void batchPointUpdate() throws ExecutionException, InterruptedException {
createAndSeedCollection(testName);

List<PointsUpdateOperation> operations = List.of(
PointsUpdateOperation.newBuilder()
.setClearPayload(ClearPayload.newBuilder().setPoints(
PointsSelector.newBuilder().setPoints(PointsIdsList.newBuilder().addIds(id(9))))
.build())
.build(),
PointsUpdateOperation.newBuilder()
.setUpdateVectors(UpdateVectors.newBuilder()
.addPoints(PointVectors.newBuilder()
.setId(id(9))
.setVectors(vectors(0.6f, 0.7f))))
.build());

List<UpdateResult> response = client.batchUpdateAsync(testName, operations).get();

response.forEach(result -> assertEquals(UpdateStatus.Completed, result.getStatus()));
}

private void createAndSeedCollection(String collectionName) throws ExecutionException, InterruptedException {
CreateCollection request = CreateCollection.newBuilder()
.setCollectionName(collectionName)
Expand Down