Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add version check on client init #61

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ dependencies {

testImplementation "io.grpc:grpc-testing:${grpcVersion}"
testImplementation "org.junit.jupiter:junit-jupiter-api:${jUnitVersion}"
testImplementation "org.junit.jupiter:junit-jupiter-params:${jUnitVersion}"
testImplementation "org.mockito:mockito-core:3.4.0"
testImplementation "org.slf4j:slf4j-nop:${slf4jVersion}"
testImplementation "org.testcontainers:qdrant:${testcontainersVersion}"
Expand Down
77 changes: 66 additions & 11 deletions src/main/java/io/qdrant/client/QdrantGrpcClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,10 @@
import io.grpc.Deadline;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import io.qdrant.client.grpc.CollectionsGrpc;
import io.qdrant.client.grpc.*;
import io.qdrant.client.grpc.CollectionsGrpc.CollectionsFutureStub;
import io.qdrant.client.grpc.PointsGrpc;
import io.qdrant.client.grpc.PointsGrpc.PointsFutureStub;
import io.qdrant.client.grpc.QdrantGrpc;
import io.qdrant.client.grpc.QdrantGrpc.QdrantFutureStub;
import io.qdrant.client.grpc.SnapshotsGrpc;
import io.qdrant.client.grpc.SnapshotsGrpc.SnapshotsFutureStub;
import java.time.Duration;
import java.util.concurrent.TimeUnit;
Expand Down Expand Up @@ -45,7 +42,7 @@ public class QdrantGrpcClient implements AutoCloseable {
* @return a new instance of {@link Builder}
*/
public static Builder newBuilder(ManagedChannel channel) {
return new Builder(channel, false);
return new Builder(channel, false, true);
}

/**
Expand All @@ -56,7 +53,21 @@ public static Builder newBuilder(ManagedChannel channel) {
* @return a new instance of {@link Builder}
*/
public static Builder newBuilder(ManagedChannel channel, boolean shutdownChannelOnClose) {
return new Builder(channel, shutdownChannelOnClose);
return new Builder(channel, shutdownChannelOnClose, true);
}

/**
* Creates a new builder to build a client.
*
* @param channel The channel for communication.
* @param shutdownChannelOnClose Whether the channel is shutdown on client close.
* @param checkCompatibility Whether to check compatibility between client's and server's
* versions.
* @return a new instance of {@link Builder}
*/
public static Builder newBuilder(
ManagedChannel channel, boolean shutdownChannelOnClose, boolean checkCompatibility) {
return new Builder(channel, shutdownChannelOnClose, checkCompatibility);
}

/**
Expand All @@ -66,7 +77,7 @@ public static Builder newBuilder(ManagedChannel channel, boolean shutdownChannel
* @return a new instance of {@link Builder}
*/
public static Builder newBuilder(String host) {
return new Builder(host, 6334, true);
return new Builder(host, 6334, true, true);
}

/**
Expand All @@ -77,7 +88,7 @@ public static Builder newBuilder(String host) {
* @return a new instance of {@link Builder}
*/
public static Builder newBuilder(String host, int port) {
return new Builder(host, port, true);
return new Builder(host, port, true, true);
}

/**
Expand All @@ -90,7 +101,23 @@ public static Builder newBuilder(String host, int port) {
* @return a new instance of {@link Builder}
*/
public static Builder newBuilder(String host, int port, boolean useTransportLayerSecurity) {
return new Builder(host, port, useTransportLayerSecurity);
return new Builder(host, port, useTransportLayerSecurity, true);
}

/**
* Creates a new builder to build a client.
*
* @param host The host to connect to.
* @param port The port to connect to.
* @param useTransportLayerSecurity Whether the client uses Transport Layer Security (TLS) to
* secure communications. Running without TLS should only be used for testing purposes.
* @param checkCompatibility Whether to check compatibility between client's and server's
* versions.
* @return a new instance of {@link Builder}
*/
public static Builder newBuilder(
String host, int port, boolean useTransportLayerSecurity, boolean checkCompatibility) {
return new Builder(host, port, useTransportLayerSecurity, checkCompatibility);
}

/**
Expand Down Expand Up @@ -168,17 +195,24 @@ public static class Builder {
@Nullable private CallCredentials callCredentials;
@Nullable private Duration timeout;

Builder(ManagedChannel channel, boolean shutdownChannelOnClose) {
Builder(ManagedChannel channel, boolean shutdownChannelOnClose, boolean checkCompatibility) {
this.channel = channel;
this.shutdownChannelOnClose = shutdownChannelOnClose;
String clientVersion = Builder.class.getPackage().getImplementationVersion();
if (checkCompatibility) {
checkVersionsCompatibility(clientVersion);
}
}

Builder(String host, int port, boolean useTransportLayerSecurity) {
Builder(String host, int port, boolean useTransportLayerSecurity, boolean checkCompatibility) {
String clientVersion = Builder.class.getPackage().getImplementationVersion();
String javaVersion = System.getProperty("java.version");
String userAgent = "java-client/" + clientVersion + " java/" + javaVersion;
this.channel = createChannel(host, port, useTransportLayerSecurity, userAgent);
this.shutdownChannelOnClose = true;
if (checkCompatibility) {
checkVersionsCompatibility(clientVersion);
}
}

/**
Expand Down Expand Up @@ -238,5 +272,26 @@ private static ManagedChannel createChannel(

return channelBuilder.build();
}

private void checkVersionsCompatibility(String clientVersion) {
try {
String serverVersion =
QdrantGrpc.newBlockingStub(this.channel)
.healthCheck(QdrantOuterClass.HealthCheckRequest.getDefaultInstance())
.getVersion();
if (!VersionsCompatibilityChecker.isCompatible(clientVersion, serverVersion)) {
System.out.println(
tellet-q marked this conversation as resolved.
Show resolved Hide resolved
"Qdrant client version "
+ clientVersion
+ " is incompatible with server version "
+ serverVersion
+ ". Major versions should match and minor version difference must not exceed 1. "
+ "Set check_version=False to skip version check.");
}
} catch (Exception e) {
System.out.println(
"Failed to obtain server version. Unable to check client-server compatibility. Set checkCompatibility=False to skip version check.");
}
}
}
}
96 changes: 96 additions & 0 deletions src/main/java/io/qdrant/client/VersionsCompatibilityChecker.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
package io.qdrant.client;

import java.util.ArrayList;
import java.util.Arrays;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

class Version {
private final int major;
private final int minor;
private final String rest;

public Version(int major, int minor, String rest) {
this.major = major;
this.minor = minor;
this.rest = rest;
}

public int getMajor() {
return major;
}

public int getMinor() {
return minor;
}

public String getRest() {
return rest;
}
}

/** Utility class to check compatibility between server's and client's versions. */
public class VersionsCompatibilityChecker {
private static final Logger logger = LoggerFactory.getLogger(VersionsCompatibilityChecker.class);

/** Default constructor. */
public VersionsCompatibilityChecker() {}

private static Version parseVersion(String version) throws IllegalArgumentException {
if (version.isEmpty()) {
throw new IllegalArgumentException("Version is None");
}

try {
String[] parts = version.split("\\.");
int major = parts.length > 0 ? Integer.parseInt(parts[0]) : 0;
int minor = parts.length > 1 ? Integer.parseInt(parts[1]) : 0;
String rest =
parts.length > 2
? String.join(".", new ArrayList<>(Arrays.asList(parts).subList(2, parts.length)))
: "";

return new Version(major, minor, rest);
Comment on lines +48 to +53
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rest part of the version is not used.

} catch (Exception e) {
throw new IllegalArgumentException(
"Unable to parse version, expected format: x.y.z, found: " + version, e);
}
}

/**
* Compares server's and client's versions.
*
* @param clientVersion The client's version.
* @param serverVersion The server's version.
* @return True if the versions are compatible, false otherwise.
*/
public static boolean isCompatible(String clientVersion, String serverVersion) {
if (clientVersion.isEmpty()) {
logger.warn("Unable to compare with client version {}", clientVersion);
return false;
}

if (serverVersion.isEmpty()) {
logger.warn("Unable to compare with server version {}", serverVersion);
return false;
}

if (clientVersion.equals(serverVersion)) {
return true;
}

try {
Version parsedServerVersion = parseVersion(serverVersion);
Version parsedClientVersion = parseVersion(clientVersion);

int majorDiff = Math.abs(parsedServerVersion.getMajor() - parsedClientVersion.getMajor());
if (majorDiff >= 1) {
return false;
}
return Math.abs(parsedServerVersion.getMinor() - parsedClientVersion.getMinor()) <= 1;
} catch (IllegalArgumentException e) {
logger.warn("Unable to compare versions: {}", e.getMessage());
return false;
}
}
Comment on lines +67 to +95
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
public static boolean isCompatible(String clientVersion, String serverVersion) {
if (clientVersion.isEmpty()) {
logger.warn("Unable to compare with client version {}", clientVersion);
return false;
}
if (serverVersion.isEmpty()) {
logger.warn("Unable to compare with server version {}", serverVersion);
return false;
}
if (clientVersion.equals(serverVersion)) {
return true;
}
try {
Version parsedServerVersion = parseVersion(serverVersion);
Version parsedClientVersion = parseVersion(clientVersion);
int majorDiff = Math.abs(parsedServerVersion.getMajor() - parsedClientVersion.getMajor());
if (majorDiff >= 1) {
return false;
}
return Math.abs(parsedServerVersion.getMinor() - parsedClientVersion.getMinor()) <= 1;
} catch (IllegalArgumentException e) {
logger.warn("Unable to compare versions: {}", e.getMessage());
return false;
}
}
public static boolean isCompatible(String clientVersion, String serverVersion) {
try {
Version client = parseVersion(clientVersion);
Version server = parseVersion(serverVersion);
if (client.getMajor() != server.getMajor()) return false;
return Math.abs(client.getMinor() - server.getMinor()) <= 1;
} catch (IllegalArgumentException e) {
logger.warn("Version comparison failed: {}", e.getMessage());
return false;
}
}

Could be simpler like this I think.

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
package io.qdrant.client;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;

import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.stream.Stream;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;

public class VersionsCompatibilityCheckerTest {
private static Stream<Object[]> validVersionProvider() {
return Stream.of(
new Object[] {"1.2.3", 1, 2, "3"},
new Object[] {"1.2.3-alpha", 1, 2, "3-alpha"},
new Object[] {"1.2", 1, 2, ""},
new Object[] {"1", 1, 0, ""},
new Object[] {"1.", 1, 0, ""});
}

@ParameterizedTest
@MethodSource("validVersionProvider")
public void testParseVersion_validVersion(
String versionStr, int expectedMajor, int expectedMinor, String expectedRest)
throws Exception {
Method method =
VersionsCompatibilityChecker.class.getDeclaredMethod("parseVersion", String.class);
method.setAccessible(true);
Version version = (Version) method.invoke(null, versionStr);
assertEquals(expectedMajor, version.getMajor());
assertEquals(expectedMinor, version.getMinor());
assertEquals(expectedRest, version.getRest());
}

private static Stream<String> invalidVersionProvider() {
return Stream.of("v1.12.0", "", ".1", ".1.", "1.null.1", "null.0.1", null);
}

@ParameterizedTest
@MethodSource("invalidVersionProvider")
public void testParseVersion_invalidVersion(String versionStr) throws Exception {
Method method =
VersionsCompatibilityChecker.class.getDeclaredMethod("parseVersion", String.class);
method.setAccessible(true);
assertThrows(
InvocationTargetException.class,
() -> method.invoke(null, versionStr));
}

private static Stream<Object[]> versionCompatibilityProvider() {
return Stream.of(
new Object[] {"1.9.3.dev0", "2.8.1.dev12-something", false},
new Object[] {"1.9", "2.8", false},
new Object[] {"1", "2", false},
new Object[] {"1.9.0", "2.9.0", false},
new Object[] {"1.1.0", "1.2.9", true},
new Object[] {"1.2.7", "1.1.8.dev0", true},
new Object[] {"1.2.1", "1.2.29", true},
new Object[] {"1.2.0", "1.2.0", true},
new Object[] {"1.2.0", "1.4.0", false},
new Object[] {"1.4.0", "1.2.0", false},
new Object[] {"1.9.0", "3.7.0", false},
new Object[] {"3.0.0", "1.0.0", false},
new Object[] {"", "1.0.0", false},
new Object[] {"1.0.0", "", false},
new Object[] {"", "", false});
}

@ParameterizedTest
@MethodSource("versionCompatibilityProvider")
public void testIsCompatible(String clientVersion, String serverVersion, boolean expected) {
assertEquals(expected, VersionsCompatibilityChecker.isCompatible(clientVersion, serverVersion));
}
}
Loading