Skip to content

Commit

Permalink
WorkerProto: Support fine-grained protocol feature negotiation
Browse files Browse the repository at this point in the history
Currently, the worker protocol has a version number that we increment
whenever we change something in the protocol. However, this can cause
a collision between Nix PRs / forks that make protocol changes
(e.g. PR NixOS#9857 increments the version, which could collide with
another PR). So instead, the client and daemon now exchange a set of
protocol features (such as `auth-forwarding`). They will use the
intersection of the sets of features, i.e. the features they both
support.

Note that protocol features are completely distinct from
`ExperimentalFeature`s.
  • Loading branch information
edolstra committed Jul 24, 2024
1 parent b13ba74 commit 27f2711
Show file tree
Hide file tree
Showing 6 changed files with 127 additions and 29 deletions.
11 changes: 6 additions & 5 deletions src/libstore/daemon.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1025,19 +1025,20 @@ void processConnection(
#endif

/* Exchange the greeting. */
WorkerProto::Version clientVersion =
auto [protoVersion, features] =
WorkerProto::BasicServerConnection::handshake(
to, from, PROTOCOL_VERSION);
to, from, PROTOCOL_VERSION, WorkerProto::allFeatures);

if (clientVersion < 0x10a)
if (protoVersion < 0x10a)
throw Error("the Nix client version is too old");

WorkerProto::BasicServerConnection conn;
conn.to = std::move(to);
conn.from = std::move(from);
conn.protoVersion = clientVersion;
conn.protoVersion = protoVersion;
conn.features = features;

auto tunnelLogger = new TunnelLogger(conn.to, clientVersion);
auto tunnelLogger = new TunnelLogger(conn.to, protoVersion);
auto prevLogger = nix::logger;
// FIXME
if (!recursive)
Expand Down
10 changes: 8 additions & 2 deletions src/libstore/remote-store.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,11 @@ void RemoteStore::initConnection(Connection & conn)
StringSink saved;
TeeSource tee(conn.from, saved);
try {
conn.protoVersion = WorkerProto::BasicClientConnection::handshake(
conn.to, tee, PROTOCOL_VERSION);
auto [protoVersion, features] = WorkerProto::BasicClientConnection::handshake(
conn.to, tee, PROTOCOL_VERSION,
WorkerProto::allFeatures);
conn.protoVersion = protoVersion;
conn.features = features;
} catch (SerialisationError & e) {
/* In case the other side is waiting for our input, close
it. */
Expand All @@ -88,6 +91,9 @@ void RemoteStore::initConnection(Connection & conn)

static_cast<WorkerProto::ClientHandshakeInfo &>(conn) = conn.postHandshake(*this);

for (auto & feature : conn.features)
debug("negotiated feature '%s'", feature);

auto ex = conn.processStderrReturn();
if (ex) std::rethrow_exception(ex);
}
Expand Down
51 changes: 45 additions & 6 deletions src/libstore/worker-protocol-connection.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

namespace nix {

const std::set<WorkerProto::Feature> WorkerProto::allFeatures{};

WorkerProto::BasicClientConnection::~BasicClientConnection()
{
try {
Expand Down Expand Up @@ -137,8 +139,21 @@ void WorkerProto::BasicClientConnection::processStderr(bool * daemonException, S
}
}

WorkerProto::Version
WorkerProto::BasicClientConnection::handshake(BufferedSink & to, Source & from, WorkerProto::Version localVersion)
static std::set<WorkerProto::Feature>
intersectFeatures(const std::set<WorkerProto::Feature> & a, const std::set<WorkerProto::Feature> & b)
{
std::set<WorkerProto::Feature> res;
for (auto & x : a)
if (b.contains(x))
res.insert(x);
return res;
}

std::tuple<WorkerProto::Version, std::set<WorkerProto::Feature>> WorkerProto::BasicClientConnection::handshake(
BufferedSink & to,
Source & from,
WorkerProto::Version localVersion,
const std::set<WorkerProto::Feature> & supportedFeatures)
{
to << WORKER_MAGIC_1 << localVersion;
to.flush();
Expand All @@ -153,19 +168,43 @@ WorkerProto::BasicClientConnection::handshake(BufferedSink & to, Source & from,
if (GET_PROTOCOL_MINOR(daemonVersion) < 10)
throw Error("the Nix daemon version is too old");

return std::min(daemonVersion, localVersion);
auto protoVersion = std::min(daemonVersion, localVersion);

/* Exchange features. */
std::set<WorkerProto::Feature> daemonFeatures;
if (GET_PROTOCOL_MINOR(protoVersion) >= 38) {
to << supportedFeatures;
to.flush();
daemonFeatures = readStrings<std::set<WorkerProto::Feature>>(from);
}

return {protoVersion, intersectFeatures(daemonFeatures, supportedFeatures)};
}

WorkerProto::Version
WorkerProto::BasicServerConnection::handshake(BufferedSink & to, Source & from, WorkerProto::Version localVersion)
std::tuple<WorkerProto::Version, std::set<WorkerProto::Feature>> WorkerProto::BasicServerConnection::handshake(
BufferedSink & to,
Source & from,
WorkerProto::Version localVersion,
const std::set<WorkerProto::Feature> & supportedFeatures)
{
unsigned int magic = readInt(from);
if (magic != WORKER_MAGIC_1)
throw Error("protocol mismatch");
to << WORKER_MAGIC_2 << localVersion;
to.flush();
auto clientVersion = readInt(from);
return std::min(clientVersion, localVersion);

auto protoVersion = std::min(clientVersion, localVersion);

/* Exchange features. */
std::set<WorkerProto::Feature> clientFeatures;
if (GET_PROTOCOL_MINOR(protoVersion) >= 38) {
clientFeatures = readStrings<std::set<WorkerProto::Feature>>(from);
to << supportedFeatures;
to.flush();
}

return {protoVersion, intersectFeatures(clientFeatures, supportedFeatures)};
}

WorkerProto::ClientHandshakeInfo WorkerProto::BasicClientConnection::postHandshake(const StoreDirConfig & store)
Expand Down
27 changes: 23 additions & 4 deletions src/libstore/worker-protocol-connection.hh
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@ struct WorkerProto::BasicConnection
*/
WorkerProto::Version protoVersion;

/**
* The set of features that both sides support.
*/
std::set<Feature> features;

/**
* Coercion to `WorkerProto::ReadConn`. This makes it easy to use the
* factored out serve protocol serializers with a
Expand Down Expand Up @@ -72,8 +77,8 @@ struct WorkerProto::BasicClientConnection : WorkerProto::BasicConnection
/**
* Establishes connection, negotiating version.
*
* @return the version provided by the other side of the
* connection.
* @return the minimum version supported by both sides and the set
* of protocol features supported by both sides.
*
* @param to Taken by reference to allow for various error handling
* mechanisms.
Expand All @@ -82,8 +87,15 @@ struct WorkerProto::BasicClientConnection : WorkerProto::BasicConnection
* handling mechanisms.
*
* @param localVersion Our version which is sent over
*
* @param features The protocol features that we support
*/
static Version handshake(BufferedSink & to, Source & from, WorkerProto::Version localVersion);
// FIXME: this should probably be a constructor.
static std::tuple<Version, std::set<Feature>> handshake(
BufferedSink & to,
Source & from,
WorkerProto::Version localVersion,
const std::set<Feature> & supportedFeatures);

/**
* After calling handshake, must call this to exchange some basic
Expand Down Expand Up @@ -138,8 +150,15 @@ struct WorkerProto::BasicServerConnection : WorkerProto::BasicConnection
* handling mechanisms.
*
* @param localVersion Our version which is sent over
*
* @param features The protocol features that we support
*/
static WorkerProto::Version handshake(BufferedSink & to, Source & from, WorkerProto::Version localVersion);
// FIXME: this should probably be a constructor.
static std::tuple<Version, std::set<Feature>> handshake(
BufferedSink & to,
Source & from,
WorkerProto::Version localVersion,
const std::set<Feature> & supportedFeatures);

/**
* After calling handshake, must call this to exchange some basic
Expand Down
8 changes: 7 additions & 1 deletion src/libstore/worker-protocol.hh
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ namespace nix {
#define WORKER_MAGIC_1 0x6e697863
#define WORKER_MAGIC_2 0x6478696f

#define PROTOCOL_VERSION (1 << 8 | 37)
/* Note: you generally shouldn't change the protocol version. Add a
feature to `WorkerProtoFeature` instead. */
#define PROTOCOL_VERSION (1 << 8 | 38)
#define GET_PROTOCOL_MAJOR(x) ((x) & 0xff00)
#define GET_PROTOCOL_MINOR(x) ((x) & 0x00ff)

Expand Down Expand Up @@ -131,6 +133,10 @@ struct WorkerProto
{
WorkerProto::Serialise<T>::write(store, conn, t);
}

using Feature = std::string;

static const std::set<Feature> allFeatures;
};

enum struct WorkerProto::Op : uint64_t
Expand Down
49 changes: 38 additions & 11 deletions tests/unit/libstore/worker-protocol.cc
Original file line number Diff line number Diff line change
Expand Up @@ -658,15 +658,15 @@ TEST_F(WorkerProtoTest, handshake_log)
FdSink out { toServer.writeSide.get() };
FdSource in0 { toClient.readSide.get() };
TeeSource in { in0, toClientLog };
clientResult = WorkerProto::BasicClientConnection::handshake(
out, in, defaultVersion);
clientResult = std::get<0>(WorkerProto::BasicClientConnection::handshake(
out, in, defaultVersion, {}));
});

{
FdSink out { toClient.writeSide.get() };
FdSource in { toServer.readSide.get() };
WorkerProto::BasicServerConnection::handshake(
out, in, defaultVersion);
out, in, defaultVersion, {});
};

thread.join();
Expand All @@ -675,6 +675,33 @@ TEST_F(WorkerProtoTest, handshake_log)
});
}

TEST_F(WorkerProtoTest, handshake_features)
{
Pipe toClient, toServer;
toClient.create();
toServer.create();

std::tuple<WorkerProto::Version, std::set<WorkerProto::Feature>> clientResult;

auto clientThread = std::thread([&]() {
FdSink out { toServer.writeSide.get() };
FdSource in { toClient.readSide.get() };
clientResult = WorkerProto::BasicClientConnection::handshake(
out, in, 123, {"bar", "aap", "mies", "xyzzy"});
});

FdSink out { toClient.writeSide.get() };
FdSource in { toServer.readSide.get() };
auto daemonResult = WorkerProto::BasicServerConnection::handshake(
out, in, 456, {"foo", "bar", "xyzzy"});

clientThread.join();

EXPECT_EQ(clientResult, daemonResult);
EXPECT_EQ(std::get<0>(clientResult), 123);
EXPECT_EQ(std::get<1>(clientResult), std::set<WorkerProto::Feature>({"bar", "xyzzy"}));
}

/// Has to be a `BufferedSink` for handshake.
struct NullBufferedSink : BufferedSink {
void writeUnbuffered(std::string_view data) override { }
Expand All @@ -686,8 +713,8 @@ TEST_F(WorkerProtoTest, handshake_client_replay)
NullBufferedSink nullSink;

StringSource in { toClientLog };
auto clientResult = WorkerProto::BasicClientConnection::handshake(
nullSink, in, defaultVersion);
auto clientResult = std::get<0>(WorkerProto::BasicClientConnection::handshake(
nullSink, in, defaultVersion, {}));

EXPECT_EQ(clientResult, defaultVersion);
});
Expand All @@ -705,13 +732,13 @@ TEST_F(WorkerProtoTest, handshake_client_truncated_replay_throws)
if (len < 8) {
EXPECT_THROW(
WorkerProto::BasicClientConnection::handshake(
nullSink, in, defaultVersion),
nullSink, in, defaultVersion, {}),
EndOfFile);
} else {
// Not sure why cannot keep on checking for `EndOfFile`.
EXPECT_THROW(
WorkerProto::BasicClientConnection::handshake(
nullSink, in, defaultVersion),
nullSink, in, defaultVersion, {}),
Error);
}
}
Expand All @@ -734,17 +761,17 @@ TEST_F(WorkerProtoTest, handshake_client_corrupted_throws)
// magic bytes don't match
EXPECT_THROW(
WorkerProto::BasicClientConnection::handshake(
nullSink, in, defaultVersion),
nullSink, in, defaultVersion, {}),
Error);
} else if (idx < 8 || idx >= 12) {
// Number out of bounds
EXPECT_THROW(
WorkerProto::BasicClientConnection::handshake(
nullSink, in, defaultVersion),
nullSink, in, defaultVersion, {}),
SerialisationError);
} else {
auto ver = WorkerProto::BasicClientConnection::handshake(
nullSink, in, defaultVersion);
auto ver = std::get<0>(WorkerProto::BasicClientConnection::handshake(
nullSink, in, defaultVersion, {}));
// `std::min` of this and the other version saves us
EXPECT_EQ(ver, defaultVersion);
}
Expand Down

0 comments on commit 27f2711

Please sign in to comment.