Skip to content

Commit

Permalink
fix: enforce load limits when loading snapshot (#4136)
Browse files Browse the repository at this point in the history
* fix: enforce load limits when loading snapshot

Prevent loading snapshots with used memory higher than max memory limit.

1. Store the used memory metadata only inside the summary file
2. Load the summary file before loading anything else, and if the used-memory is higher,
   abort the load.
---------

Signed-off-by: Roman Gershman <[email protected]>
  • Loading branch information
romange authored Nov 20, 2024
1 parent 4e7800f commit 0e7ae34
Show file tree
Hide file tree
Showing 9 changed files with 81 additions and 80 deletions.
2 changes: 1 addition & 1 deletion src/redis/rdb.h
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,6 @@
// Currently moved here from server.h
#define LONG_STR_SIZE 21 /* Bytes needed for long -> str + '\0' */

#define REDIS_VERSION "999.999.999"
#define REDIS_VERSION "6.2.11"

#endif
14 changes: 9 additions & 5 deletions src/server/detail/snapshot_storage.cc
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,8 @@ string SnapshotStorage::FindMatchingFile(string_view prefix, string_view dbfilen
return {};
}

io::Result<vector<string>, GenericError> SnapshotStorage::ExpandSnapshot(const string& load_path) {
io::Result<SnapshotStorage::ExpandResult, GenericError> SnapshotStorage::ExpandSnapshot(
const string& load_path) {
if (!(absl::EndsWith(load_path, ".rdb") || absl::EndsWith(load_path, "summary.dfs"))) {
return nonstd::make_unexpected(
GenericError(std::make_error_code(std::errc::invalid_argument), "Bad filename extension"));
Expand All @@ -101,17 +102,20 @@ io::Result<vector<string>, GenericError> SnapshotStorage::ExpandSnapshot(const s
return nonstd::make_unexpected(GenericError(ec, "File not found"));
}

vector<string> paths{{load_path}};
ExpandResult result;

// Collect all other files in case we're loading dfs.
if (absl::EndsWith(load_path, "summary.dfs")) {
auto res = ExpandFromPath(load_path);
if (!res) {
return res;
return nonstd::make_unexpected(res.error());
}
paths.insert(paths.end(), res->begin(), res->end());
result = std::move(*res);
result.push_back(load_path);
} else {
result.push_back(load_path);
}
return paths;
return result;
}

FileSnapshotStorage::FileSnapshotStorage(fb2::FiberQueueThreadPool* fq_threadpool)
Expand Down
3 changes: 2 additions & 1 deletion src/server/detail/snapshot_storage.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,9 @@ class SnapshotStorage {
virtual io::Result<std::string, GenericError> LoadPath(std::string_view dir,
std::string_view dbfilename) = 0;

using ExpandResult = std::vector<std::string>;
// Searches for all the relevant snapshot files given the RDB file or DFS summary file path.
io::Result<std::vector<std::string>, GenericError> ExpandSnapshot(const std::string& load_path);
io::Result<ExpandResult, GenericError> ExpandSnapshot(const std::string& load_path);

virtual bool IsCloud() const {
return false;
Expand Down
13 changes: 11 additions & 2 deletions src/server/rdb_load.cc
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,8 @@ string error_category::message(int ev) const {
switch (ev) {
case errc::wrong_signature:
return "Wrong signature while trying to load from rdb file";
case errc::out_of_memory:
return "Out of memory, or used memory is too high";
default:
return absl::StrCat("Internal error when loading RDB file ", ev);
break;
Expand Down Expand Up @@ -2596,7 +2598,9 @@ error_code RdbLoader::HandleAux() {
} else if (auxkey == "lua") {
LoadScriptFromAux(std::move(auxval));
} else if (auxkey == "redis-ver") {
VLOG(1) << "Loading RDB produced by version " << auxval;
VLOG(1) << "Loading RDB produced by Redis version " << auxval;
} else if (auxkey == "df-ver") {
VLOG(1) << "Loading RDB produced by Dragonfly version " << auxval;
} else if (auxkey == "ctime") {
int64_t ctime;
if (absl::SimpleAtoi(auxval, &ctime)) {
Expand All @@ -2606,9 +2610,14 @@ error_code RdbLoader::HandleAux() {
VLOG(1) << "RDB age " << strings::HumanReadableElapsedTime(age);
}
} else if (auxkey == "used-mem") {
long long usedmem;
int64_t usedmem;
if (absl::SimpleAtoi(auxval, &usedmem)) {
VLOG(1) << "RDB memory usage when created " << strings::HumanReadableNumBytes(usedmem);
if (usedmem > ssize_t(max_memory_limit)) {
LOG(WARNING) << "Could not load snapshot - its used memory is " << usedmem
<< " but the limit is " << max_memory_limit;
return RdbError(errc::out_of_memory);
}
}
} else if (auxkey == "aof-preamble") {
long long haspreamble;
Expand Down
8 changes: 5 additions & 3 deletions src/server/rdb_save.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1561,16 +1561,18 @@ void RdbSaver::FillFreqMap(RdbTypeFreqMap* freq_map) {
error_code RdbSaver::SaveAux(const GlobalData& glob_state) {
static_assert(sizeof(void*) == 8, "");

int aof_preamble = false;
error_code ec;

/* Add a few fields about the state when the RDB was created. */
RETURN_ON_ERR(impl_->SaveAuxFieldStrStr("redis-ver", REDIS_VERSION));
RETURN_ON_ERR(impl_->SaveAuxFieldStrStr("df-ver", GetVersion()));
RETURN_ON_ERR(SaveAuxFieldStrInt("redis-bits", 64));

RETURN_ON_ERR(SaveAuxFieldStrInt("ctime", time(NULL)));
RETURN_ON_ERR(SaveAuxFieldStrInt("used-mem", used_mem_current.load(memory_order_relaxed)));
RETURN_ON_ERR(SaveAuxFieldStrInt("aof-preamble", aof_preamble));
auto used_mem = used_mem_current.load(memory_order_relaxed);
VLOG(1) << "Used memory during save: " << used_mem;
RETURN_ON_ERR(SaveAuxFieldStrInt("used-mem", used_mem));
RETURN_ON_ERR(SaveAuxFieldStrInt("aof-preamble", 0));

// Save lua scripts only in rdb or summary file
DCHECK(save_mode_ != SaveMode::SINGLE_SHARD || glob_state.lua_scripts.empty());
Expand Down
11 changes: 10 additions & 1 deletion src/server/rdb_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ TEST_F(RdbTest, Crc) {

TEST_F(RdbTest, LoadEmpty) {
auto ec = LoadRdb("empty.rdb");
CHECK(!ec);
ASSERT_FALSE(ec) << ec;
}

TEST_F(RdbTest, LoadSmall6) {
Expand Down Expand Up @@ -646,4 +646,13 @@ TEST_F(RdbTest, LoadHugeStream) {
ASSERT_EQ(2000, CheckedInt({"xlen", "test:0"}));
}

TEST_F(RdbTest, SnapshotTooBig) {
// Run({"debug", "populate", "10000", "foo", "1000"});
// usleep(5000); // let the stats to sync
max_memory_limit = 100000;
used_mem_current = 1000000;
auto resp = Run({"debug", "reload"});
ASSERT_THAT(resp, ErrArg("Out of memory"));
}

} // namespace dfly
49 changes: 25 additions & 24 deletions src/server/server_family.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1083,24 +1083,23 @@ std::optional<fb2::Future<GenericError>> ServerFamily::Load(string_view load_pat

DCHECK_GT(shard_count(), 0u);

if (ServerState::tlocal() && !ServerState::tlocal()->is_master) {
// TODO: to move it to helio.
auto immediate = [](auto val) {
fb2::Future<GenericError> future;
future.Resolve(string("Replica cannot load data"));
future.Resolve(val);
return future;
}

auto paths_result = snapshot_storage_->ExpandSnapshot(path);
if (!paths_result) {
LOG(ERROR) << "Failed to load snapshot: " << paths_result.error().Format();
};

fb2::Future<GenericError> future;
future.Resolve(paths_result.error());
return future;
if (ServerState::tlocal() && !ServerState::tlocal()->is_master) {
return immediate(string("Replica cannot load data"));
}

std::vector<std::string> paths = *paths_result;
auto expand_result = snapshot_storage_->ExpandSnapshot(path);
if (!expand_result) {
LOG(ERROR) << "Failed to load snapshot: " << expand_result.error().Format();

LOG(INFO) << "Loading " << path;
return immediate(expand_result.error());
}

auto new_state = service_.SwitchState(GlobalState::ACTIVE, GlobalState::LOADING);
if (new_state != GlobalState::LOADING) {
Expand All @@ -1110,6 +1109,10 @@ std::optional<fb2::Future<GenericError>> ServerFamily::Load(string_view load_pat

auto& pool = service_.proactor_pool();

const vector<string>& paths = *expand_result;

LOG(INFO) << "Loading " << path;

vector<fb2::Fiber> load_fibers;
load_fibers.reserve(paths.size());

Expand All @@ -1125,39 +1128,36 @@ std::optional<fb2::Future<GenericError>> ServerFamily::Load(string_view load_pat
proactor = pool.GetNextProactor();
}

auto load_fiber = [this, aggregated_result, existing_keys, path = std::move(path)]() {
auto load_func = [this, aggregated_result, existing_keys, path = std::move(path)]() {
auto load_result = LoadRdb(path, existing_keys);
if (load_result.has_value())
aggregated_result->keys_read.fetch_add(*load_result);
else
aggregated_result->first_error = load_result.error();
};
load_fibers.push_back(proactor->LaunchFiber(std::move(load_fiber)));
load_fibers.push_back(proactor->LaunchFiber(std::move(load_func)));
}

fb2::Future<GenericError> future;

// Run fiber that empties the channel and sets ec_promise.
auto load_join_fiber = [this, aggregated_result, load_fibers = std::move(load_fibers),
future]() mutable {
auto load_join_func = [this, aggregated_result, load_fibers = std::move(load_fibers),
future]() mutable {
for (auto& fiber : load_fibers) {
fiber.Join();
}

if (aggregated_result->first_error) {
LOG(ERROR) << "Rdb load failed. " << (*aggregated_result->first_error).message();
service_.SwitchState(GlobalState::LOADING, GlobalState::ACTIVE);
future.Resolve(*aggregated_result->first_error);
return;
LOG(ERROR) << "Rdb load failed: " << (*aggregated_result->first_error).message();
} else {
RdbLoader::PerformPostLoad(&service_);
LOG(INFO) << "Load finished, num keys read: " << aggregated_result->keys_read;
}

RdbLoader::PerformPostLoad(&service_);

LOG(INFO) << "Load finished, num keys read: " << aggregated_result->keys_read;
service_.SwitchState(GlobalState::LOADING, GlobalState::ACTIVE);
future.Resolve(*(aggregated_result->first_error));
};
pool.GetNextProactor()->Dispatch(std::move(load_join_fiber));
pool.GetNextProactor()->Dispatch(std::move(load_join_func));

return future;
}
Expand Down Expand Up @@ -1196,6 +1196,7 @@ void ServerFamily::SnapshotScheduling() {
io::Result<size_t> ServerFamily::LoadRdb(const std::string& rdb_file,
LoadExistingKeys existing_keys) {
VLOG(1) << "Loading data from " << rdb_file;
CHECK(fb2::ProactorBase::IsProactorThread()) << "must be called from proactor thread";

error_code ec;
io::ReadonlyFileOrError res = snapshot_storage_->OpenReadFile(rdb_file);
Expand Down
Loading

0 comments on commit 0e7ae34

Please sign in to comment.