Skip to content

Commit

Permalink
feat: download snapshots
Browse files Browse the repository at this point in the history
  • Loading branch information
Anush authored and Anush committed Dec 14, 2023
1 parent 6ce7f19 commit 4ee52fa
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 0 deletions.
67 changes: 67 additions & 0 deletions src/main/java/io/qdrant/client/QdrantClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,17 @@
import org.slf4j.LoggerFactory;

import javax.annotation.Nullable;

import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.net.HttpURLConnection;
import java.net.URL;
import java.nio.file.Path;
import java.time.Duration;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;

Expand Down Expand Up @@ -2380,6 +2388,65 @@ public ListenableFuture<DeleteSnapshotResponse> deleteFullSnapshotAsync(String s
addLogFailureCallback(future, "Delete full snapshot");
return future;
}
/**
* Downloads a snapshot of a collection from the specified REST API URI and
* saves it to the given
* output path.
*
* @param outPath The path where the snapshot will be saved.
* @param collectionName The name of the collection.
* @param snapshotName The name of the snapshot. If null, the latest snapshot
* will be downloaded.
* @param restApiUri The URI of the REST API. If null, the default URI
* "http://localhost:6333"
* will be used.
*/
public void downloadSnapshot(
Path outPath,
String collectionName,
@Nullable String snapshotName,
@Nullable String restApiUri) throws InterruptedException, IOException, ExecutionException {
String resolvedSnapshotName;
if (snapshotName != null) {
resolvedSnapshotName = snapshotName;
} else {
// Get the latest(0th) snapshot of the collection
List<SnapshotDescription> snapshots = listSnapshotAsync(collectionName).get();
if (snapshots.isEmpty()) {
throw new RuntimeException("No snapshots found");
}
resolvedSnapshotName = snapshots.get(0).getName();
}

String uri;
if (restApiUri != null) {
uri = String.format(
"%s/collections/%s/snapshots/%s", restApiUri, collectionName, resolvedSnapshotName);
} else {
uri = String.format(
"http://localhost:6333/collections/%s/snapshots/%s",
collectionName, resolvedSnapshotName);
}

URL url = new URL(uri);
HttpURLConnection connection = (HttpURLConnection) url.openConnection();

if (connection.getResponseCode() == 200) {
try (InputStream in = connection.getInputStream();
FileOutputStream fileOut = new FileOutputStream(outPath.toFile())) {

byte[] buffer = new byte[8192];
int bytesRead;
while ((bytesRead = in.read(buffer)) != -1) {
fileOut.write(buffer, 0, bytesRead);
}

logger.info("Downloaded snapshot to {}", outPath);
}
} else {
throw new RuntimeException("Download failed. HTTP Status Code: " + connection.getResponseCode());
}
}

//endregion

Expand Down
27 changes: 27 additions & 0 deletions src/test/java/io/qdrant/client/SnapshotsTest.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package io.qdrant.client;

import io.qdrant.client.grpc.Collections;
import io.qdrant.client.grpc.SnapshotsService;
import io.grpc.Grpc;
import io.grpc.InsecureChannelCredentials;
import io.grpc.ManagedChannel;
Expand All @@ -14,13 +15,17 @@
import org.testcontainers.junit.jupiter.Testcontainers;
import io.qdrant.client.container.QdrantContainer;

import java.io.IOException;
import java.nio.file.FileSystems;
import java.nio.file.Path;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;

import static io.qdrant.client.grpc.SnapshotsService.SnapshotDescription;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;

@Testcontainers
class SnapshotsTest {
Expand Down Expand Up @@ -134,6 +139,28 @@ public void listFullSnapshots() throws ExecutionException, InterruptedException
assertEquals(2, snapshotDescriptions.size());
}

@Test
public void testDownloadSnapshot() throws ExecutionException, InterruptedException, IOException {
String restApiUri = "http://" + QDRANT_CONTAINER.getHttpHostAddress();
createCollection(testName);

assertEquals(client.listSnapshotAsync(testName).get().size(), 0);

// Test with snapshot name
SnapshotsService.SnapshotDescription response = client.createSnapshotAsync(testName).get();
String snapshotName = response.getName();

Path path = FileSystems.getDefault().getPath("./test.snapshot");

client.downloadSnapshot(path, testName, snapshotName, restApiUri);
assertTrue(path.toFile().exists());

// Test without snapshot name
path = FileSystems.getDefault().getPath("./test_2.snapshot");
client.downloadSnapshot(path, testName, null, restApiUri);
assertTrue(path.toFile().exists());
}

private void createCollection(String collectionName) throws ExecutionException, InterruptedException {
Collections.CreateCollection request = Collections.CreateCollection.newBuilder()
.setCollectionName(collectionName)
Expand Down

0 comments on commit 4ee52fa

Please sign in to comment.