Skip to content

Commit

Permalink
Create unit tests for the serve proto handshake
Browse files Browse the repository at this point in the history
  • Loading branch information
Ericson2314 committed Jan 19, 2024
1 parent 7577997 commit b3da462
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 0 deletions.
Binary file not shown.
110 changes: 110 additions & 0 deletions tests/unit/libstore/serve-protocol.cc
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#include <thread>
#include <regex>

#include <nlohmann/json.hpp>
Expand All @@ -6,6 +7,7 @@
#include "serve-protocol.hh"
#include "serve-protocol-impl.hh"
#include "build-result.hh"
#include "file-descriptor.hh"
#include "tests/protocol.hh"
#include "tests/characterization.hh"

Expand Down Expand Up @@ -401,4 +403,112 @@ VERSIONED_CHARACTERIZATION_TEST(
},
}))

TEST_F(ServeProtoTest, handshake_log)
{
CharacterizationTest::writeTest("handshake-to-client", [&]() -> std::string {
StringSink toClientLog;

Pipe toClient, toServer;
toClient.create();
toServer.create();

ServeProto::Version clientResult, serverResult;

auto thread = std::thread([&]() {
FdSink out { toServer.writeSide.get() };
FdSource in0 { toClient.readSide.get() };
TeeSource in { in0, toClientLog };
clientResult = ServeProto::BasicClientConnection::handshake(
out, in, defaultVersion, "blah");
});

{
FdSink out { toClient.writeSide.get() };
FdSource in { toServer.readSide.get() };
serverResult = ServeProto::BasicServerConnection::handshake(
out, in, defaultVersion);
};

thread.join();

return std::move(toClientLog.s);
});
}

/// Has to be a `BufferedSink` for handshake.
struct NullBufferedSink : BufferedSink {
void writeUnbuffered(std::string_view data) override { }
};

TEST_F(ServeProtoTest, handshake_client_replay)
{
CharacterizationTest::readTest("handshake-to-client", [&](std::string toClientLog) {
NullBufferedSink nullSink;

StringSource in { toClientLog };
auto clientResult = ServeProto::BasicClientConnection::handshake(
nullSink, in, defaultVersion, "blah");

EXPECT_EQ(clientResult, defaultVersion);
});
}

TEST_F(ServeProtoTest, handshake_client_trunated_replay_throws)
{
CharacterizationTest::readTest("handshake-to-client", [&](std::string toClientLog) {
for (size_t len = 0; len < toClientLog.size(); ++len) {
NullBufferedSink nullSink;
StringSource in {
// truncate
toClientLog.substr(0, len)
};
if (len < 8) {
EXPECT_THROW(
ServeProto::BasicClientConnection::handshake(
nullSink, in, defaultVersion, "blah"),
EndOfFile);
} else {
// Not sure why cannot keep on checking for `EndOfFile`.
EXPECT_THROW(
ServeProto::BasicClientConnection::handshake(
nullSink, in, defaultVersion, "blah"),
Error);
}
}
});
}

TEST_F(ServeProtoTest, handshake_client_corrupted_throws)
{
CharacterizationTest::readTest("handshake-to-client", [&](const std::string toClientLog) {
for (size_t idx = 0; idx < toClientLog.size(); ++idx) {
// corrupt a copy
std::string toClientLogCorrupt = toClientLog;
toClientLogCorrupt[idx] *= 4;
++toClientLogCorrupt[idx];

NullBufferedSink nullSink;
StringSource in { toClientLogCorrupt };

if (idx < 4 || idx == 9) {
// magic bytes don't match
EXPECT_THROW(
ServeProto::BasicClientConnection::handshake(
nullSink, in, defaultVersion, "blah"),
Error);
} else if (idx < 8 || idx >= 12) {
// Number out of bounds
EXPECT_THROW(
ServeProto::BasicClientConnection::handshake(
nullSink, in, defaultVersion, "blah"),
SerialisationError);
} else {
auto ver = ServeProto::BasicClientConnection::handshake(
nullSink, in, defaultVersion, "blah");
EXPECT_NE(ver, defaultVersion);
}
}
});
}

}

0 comments on commit b3da462

Please sign in to comment.