From 1351e5489fe367725ea1a7ab183f7b415bbf5bd5 Mon Sep 17 00:00:00 2001 From: Kostas Kyrimis Date: Tue, 14 May 2024 20:02:32 +0300 Subject: [PATCH 1/6] chore: add WasmRegistry and WasmModule (#3044) * add WasmRegistry and WasmModule * add wasmtime deps --- src/server/CMakeLists.txt | 11 +- src/server/acl/acl_commands_def.h | 4 +- src/server/main_service.cc | 1 + src/server/main_service.h | 2 + src/server/wasm/api.h | 25 + src/server/wasm/wasm_family.cc | 63 + src/server/wasm/wasm_family.h | 29 + src/server/wasm/wasm_registry.cc | 100 + src/server/wasm/wasm_registry.h | 92 + src/server/wasm/wasmtime.hh | 3268 +++++++++++++++++++++++++++++ 10 files changed, 3593 insertions(+), 2 deletions(-) create mode 100644 src/server/wasm/api.h create mode 100644 src/server/wasm/wasm_family.cc create mode 100644 src/server/wasm/wasm_family.h create mode 100644 src/server/wasm/wasm_registry.cc create mode 100644 src/server/wasm/wasm_registry.h create mode 100644 src/server/wasm/wasmtime.hh diff --git a/src/server/CMakeLists.txt b/src/server/CMakeLists.txt index 57e7514511aa..a8aa76225419 100644 --- a/src/server/CMakeLists.txt +++ b/src/server/CMakeLists.txt @@ -54,7 +54,8 @@ add_library(dragonfly_lib bloom_family.cc engine_shard_set.cc channel_store.cc cluster/cluster_family.cc cluster/incoming_slot_migration.cc cluster/outgoing_slot_migration.cc cluster/cluster_defs.cc acl/user.cc acl/user_registry.cc acl/acl_family.cc - acl/validator.cc acl/helpers.cc) + acl/validator.cc acl/helpers.cc + wasm/wasm_registry.cc wasm/wasm_family.cc) if (DF_ENABLE_MEMORY_TRACKING) target_compile_definitions(dragonfly_lib PRIVATE DFLY_ENABLE_MEMORY_TRACKING) @@ -77,6 +78,14 @@ cxx_link(dragonfly_lib dfly_transaction dfly_facade redis_lib awsv2_lib jsonpath http_client_lib absl::random_random TRDP::jsoncons ${ZSTD_LIB} TRDP::lz4 TRDP::croncpp TRDP::flatbuffers) +# Better way to integrate +# https://github.com/corrosion-rs/corrosion +target_include_directories(dfly_transaction PUBLIC ${CMAKE_SOURCE_DIR}/c-api/include) + +add_library(wasmtime STATIC IMPORTED) +target_include_directories(dragonfly_lib PRIVATE ${CMAKE_SOURCE_DIR}/c-api/include) +target_link_libraries(dragonfly_lib ${CMAKE_SOURCE_DIR}/c-api/lib/libwasmtime.a) + if (DF_USE_SSL) set(TLS_LIB tls_lib) target_compile_definitions(dragonfly_lib PRIVATE DFLY_USE_SSL) diff --git a/src/server/acl/acl_commands_def.h b/src/server/acl/acl_commands_def.h index 86324ab490bd..f927e120a5be 100644 --- a/src/server/acl/acl_commands_def.h +++ b/src/server/acl/acl_commands_def.h @@ -37,6 +37,7 @@ enum AclCat { SCRIPTING = 1ULL << 20, // Extensions + WASM = 1ULL << 27, BLOOM = 1ULL << 28, FT_SEARCH = 1ULL << 29, THROTTLE = 1ULL << 30, @@ -67,6 +68,7 @@ inline const absl::flat_hash_map CATEGORY_INDEX_TABL {"CONNECTION", CONNECTION}, {"TRANSACTION", TRANSACTION}, {"SCRIPTING", SCRIPTING}, + {"WASM", WASM}, {"BLOOM", BLOOM}, {"FT_SEARCH", FT_SEARCH}, {"THROTTLE", THROTTLE}, @@ -81,7 +83,7 @@ inline const std::vector REVERSE_CATEGORY_INDEX_TABLE{ "KEYSPACE", "READ", "WRITE", "SET", "SORTEDSET", "LIST", "HASH", "STRING", "BITMAP", "HYPERLOG", "GEO", "STREAM", "PUBSUB", "ADMIN", "FAST", "SLOW", "BLOCKING", "DANGEROUS", "CONNECTION", "TRANSACTION", "SCRIPTING", - "_RESERVED", "_RESERVED", "_RESERVED", "_RESERVED", "_RESERVED", "_RESERVED", "_RESERVED", + "_RESERVED", "_RESERVED", "_RESERVED", "_RESERVED", "_RESERVED", "_RESERVED", "WASM", "BLOOM", "FT_SEARCH", "THROTTLE", "JSON"}; using RevCommandField = std::vector; diff --git a/src/server/main_service.cc b/src/server/main_service.cc index 0827bb75b8f9..4244e273323b 100644 --- a/src/server/main_service.cc +++ b/src/server/main_service.cc @@ -2641,6 +2641,7 @@ void Service::RegisterCommands() { BloomFamily::Register(®istry_); server_family_.Register(®istry_); cluster_family_.Register(®istry_); + wasm_family_.Register(®istry_); acl_family_.Register(®istry_); acl::BuildIndexers(registry_.GetFamilies()); diff --git a/src/server/main_service.h b/src/server/main_service.h index 0bc318378d3f..b70d539f2631 100644 --- a/src/server/main_service.h +++ b/src/server/main_service.h @@ -17,6 +17,7 @@ #include "server/config_registry.h" #include "server/engine_shard_set.h" #include "server/server_family.h" +#include "server/wasm/wasm_family.h" namespace util { class AcceptServer; @@ -185,6 +186,7 @@ class Service : public facade::ServiceInterface { cluster::ClusterFamily cluster_family_; CommandRegistry registry_; absl::flat_hash_map unknown_cmds_; + wasm::WasmFamily wasm_family_; const CommandId* exec_cid_; // command id of EXEC command for pipeline squashing diff --git a/src/server/wasm/api.h b/src/server/wasm/api.h new file mode 100644 index 000000000000..05919062a2be --- /dev/null +++ b/src/server/wasm/api.h @@ -0,0 +1,25 @@ +// Copyright 2024, DragonflyDB authors. All rights reserved. +// See LICENSE for licensing terms. +// + +#pragma once + +#include +#include +#include +#include + +#include "base/logging.h" +#include "server/wasm/wasmtime.hh" + +namespace dfly::wasm::api { + +template +bool RegisterApiFunction(std::string_view name, Fn f, wasmtime::Linker* linker) { + std::string module_name = "dragonfly"; + auto args_signature = wasmtime::FuncType({}, {}); + auto res = linker->func_new(module_name, name, args_signature, f); + return (bool)res; +} + +} // namespace dfly::wasm::api diff --git a/src/server/wasm/wasm_family.cc b/src/server/wasm/wasm_family.cc new file mode 100644 index 000000000000..c6130b62f0bc --- /dev/null +++ b/src/server/wasm/wasm_family.cc @@ -0,0 +1,63 @@ +// Copyright 2024, DragonflyDB authors. All rights reserved. +// See LICENSE for licensing terms. +// +#include "server/wasm/wasm_family.h" + +#include "absl/strings/str_cat.h" +#include "facade/facade_types.h" +#include "server/acl/acl_commands_def.h" +#include "server/command_registry.h" +#include "server/conn_context.h" + +namespace dfly { +namespace wasm { + +using MemberFunc = void (WasmFamily::*)(CmdArgList args, ConnectionContext* cntx); + +CommandId::Handler HandlerFunc(WasmFamily* wasm, MemberFunc f) { + return [=](CmdArgList args, ConnectionContext* cntx) { return (wasm->*f)(args, cntx); }; +} + +#define HFUNC(x) SetHandler(HandlerFunc(this, &WasmFamily::x)) + +void WasmFamily::Register(dfly::CommandRegistry* registry) { + using CI = dfly::CommandId; + registry->StartFamily(); + *registry << CI{"WASMCALL", dfly::CO::LOADING, 2, 0, 0, acl::WASM}.HFUNC(Call); + *registry << CI{"WASMLOAD", dfly::CO::LOADING, 2, 0, 0, acl::WASM}.HFUNC(Load); + *registry << CI{"WASMDEL", dfly::CO::LOADING, 2, 0, 0, acl::WASM}.HFUNC(Delete); +} + +void WasmFamily::Load(CmdArgList args, ConnectionContext* cntx) { + auto path = absl::StrCat(facade::ToSV(args[0]), "\0"); + if (auto res = registry_.Add(path); !res.empty()) { + cntx->SendError(res); + return; + } + auto slash = path.rfind('/'); + auto name = path; + if (slash != path.npos) { + name = name.substr(slash + 1); + } + cntx->SendOk(); +} + +void WasmFamily::Call(CmdArgList args, ConnectionContext* cntx) { + auto name = facade::ToSV(args[0]); + auto res = registry_.GetInstanceFromModule(name); + if (!res) { + cntx->SendError(absl::StrCat("Could not find module with ", name)); + return; + } + auto& wasm_instance = *res; + wasm_instance(); + cntx->SendOk(); +} + +void WasmFamily::Delete(CmdArgList args, ConnectionContext* cntx) { + auto name = facade::ToSV(args[0]); + cntx->SendLong(registry_.Delete(name)); +} + +} // namespace wasm +} // namespace dfly diff --git a/src/server/wasm/wasm_family.h b/src/server/wasm/wasm_family.h new file mode 100644 index 000000000000..6b1bdaa70990 --- /dev/null +++ b/src/server/wasm/wasm_family.h @@ -0,0 +1,29 @@ +// Copyright 2024, DragonflyDB authors. All rights reserved. +// See LICENSE for licensing terms. +// + +#pragma once + +#include "facade/facade_types.h" +#include "server/command_registry.h" +#include "server/wasm/wasm_registry.h" + +namespace dfly { + +class ConnectionContext; +namespace wasm { + +class WasmFamily final { + public: + void Register(CommandRegistry* registry); + + private: + void Load(facade::CmdArgList args, ConnectionContext* cntx); + void Call(facade::CmdArgList args, ConnectionContext* cntx); + void Delete(facade::CmdArgList args, ConnectionContext* cntx); + + WasmRegistry registry_; +}; + +} // namespace wasm +} // namespace dfly diff --git a/src/server/wasm/wasm_registry.cc b/src/server/wasm/wasm_registry.cc new file mode 100644 index 000000000000..6f06e5567090 --- /dev/null +++ b/src/server/wasm/wasm_registry.cc @@ -0,0 +1,100 @@ +// Copyright 2024, DragonflyDB authors. All rights reserved. +// See LICENSE for licensing terms. +// + +#include "server/wasm/wasm_registry.h" + +#include + +#include +#include + +#include "absl/cleanup/cleanup.h" +#include "absl/strings/str_cat.h" +#include "base/logging.h" +#include "io/file_util.h" +#include "server/wasm/api.h" +#include "server/wasm/wasmtime.hh" + +namespace dfly::wasm { + +WasmRegistry::WasmRegistry() + : engine_(WasmRegistry::GetConfig()), linker_(engine_), store_(engine_) { + api::RegisterApiFunction( + "hello", + [](auto...) { + LOG(INFO) << "Hello from WASM"; + return std::monostate(); + }, + &linker_); + + wasmtime::WasiConfig wasi; + wasi.inherit_argv(); + wasi.inherit_env(); + wasi.inherit_stdin(); + wasi.inherit_stdout(); + wasi.inherit_stderr(); + store_.context().set_wasi(std::move(wasi)).unwrap(); + + linker_.define_wasi().unwrap(); +} + +WasmRegistry::~WasmRegistry() { +} + +std::string WasmRegistry::Add(std::string_view path) { + // 1. Read the wasm file in path + auto is_file_read = io::ReadFileToString(path); + if (!is_file_read) { + return absl::StrCat("File error for path: ", path, " with error ", + is_file_read.error().message()); + } + + // In this context the cast is safe + wasmtime::Span wasm_bin{reinterpret_cast(is_file_read->data()), + is_file_read->size()}; + + // 2. Setup && compile + auto result = wasmtime::Module::compile(engine_, wasm_bin); + if (!result) { + return absl::StrCat("Error compiling file: ", path, " with error: ", result.err().message()); + } + + // 3. Insert to registry + auto slash = path.rfind('/'); + auto name = path; + if (slash != path.npos) { + name = name.substr(slash + 1); + // HELLO + } + std::unique_lock lock(mu_); + modules_.emplace(name, std::move(result.ok())); + + return {}; +} + +bool WasmRegistry::Delete(std::string_view name) { + std::unique_lock lock(mu_); + return modules_.erase(name); +} + +std::optional WasmRegistry::GetInstanceFromModule( + std::string_view module_name) { + std::shared_lock lock(mu_); + + if (!modules_.contains(module_name)) { + return {}; + } + + auto& module = modules_.at(module_name); + + auto instance = linker_.instantiate(store_, module.GetImpl()); + if (!instance) { + LOG(INFO) << instance.err().message(); + return {}; + } + + return WasmModuleInstance{instance.ok(), &store_}; +} + +} // namespace dfly::wasm diff --git a/src/server/wasm/wasm_registry.h b/src/server/wasm/wasm_registry.h new file mode 100644 index 000000000000..8208edc0ad0b --- /dev/null +++ b/src/server/wasm/wasm_registry.h @@ -0,0 +1,92 @@ +// Copyright 2024, DragonflyDB authors. All rights reserved. +// See LICENSE for licensing terms. +// + +#pragma once + +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "base/logging.h" +#include "server/wasm/api.h" +#include "server/wasm/wasmtime.hh" +#include "util/fibers/synchronization.h" + +namespace dfly::wasm { + +class WasmModule { + public: + explicit WasmModule(wasmtime::Module module) : module_(std::move(module)) { + } + + WasmModule(WasmModule&&) = default; + WasmModule& operator=(WasmModule&&) = default; + WasmModule(const WasmModule&) = delete; + ~WasmModule() = default; + + wasmtime::Module& GetImpl() { + return module_; + } + + private: + wasmtime::Module module_; +}; + +class WasmRegistry { + public: + WasmRegistry(); + WasmRegistry(const WasmRegistry&) = delete; + WasmRegistry(WasmRegistry&&) = delete; + ~WasmRegistry(); + std::string Add(std::string_view path); + bool Delete(std::string_view name); + + // Very light-weight. Each Module is compiled *once* but each UDF call, e,g, `CALLWASM` + // will spawn its own instantiation of the wasm module. This is fine, because the former + // is the expensive operation while the later is used as the context upon the function + // will execute (effectively allowing concurrent calls to the same wasm module) + class WasmModuleInstance { + public: + explicit WasmModuleInstance(wasmtime::Instance instance, wasmtime::Store* store) + : instance_{instance}, store_(store) { + } + + void operator()() { + // Users will export functions for their modules via the attribute + // __attribute__((export_name(func_name))). We will expose this in our sdk + auto extern_def = instance_.get(*store_, "my_fun"); + if (!extern_def) { + // return error + return; + } + auto run = std::get(*extern_def); + run.call(store_, {}).unwrap(); + } + + private: + wasmtime::Instance instance_; + wasmtime::Store* store_; + }; + + std::optional GetInstanceFromModule(std::string_view module_name); + + private: + absl::flat_hash_map modules_; + mutable util::fb2::SharedMutex mu_; + + // Global available for all threads + // see: https://docs.wasmtime.dev/c-api/wasmtime_8h.html in section thread safety + wasmtime::Engine engine_; + wasmtime::Linker linker_; + wasmtime::Store store_; + static wasmtime::Config GetConfig() { + wasmtime::Config config; + config.epoch_interruption(false); + return config; + } +}; + +} // namespace dfly::wasm diff --git a/src/server/wasm/wasmtime.hh b/src/server/wasm/wasmtime.hh new file mode 100644 index 000000000000..c2dee4b4b994 --- /dev/null +++ b/src/server/wasm/wasmtime.hh @@ -0,0 +1,3268 @@ +/** + * \mainpage + * + * This project is a C++ API for + * [Wasmtime](https://github.com/bytecodealliance/wasmtime). Support for the + * C++ API is exclusively built on the [C API of + * Wasmtime](https://docs.wasmtime.dev/c-api/), so the C++ support for this is + * simply a single header file. To use this header file, though, it must be + * combined with the header and binary of Wasmtime's C API. Note, though, that + * while this header is built on top of the `wasmtime.h` header file you should + * only need to use the contents of this header file to interact with Wasmtime. + * + * Examples can be [found + * online](https://github.com/bytecodealliance/wasmtime-cpp/tree/main/examples) + * and otherwise be sure to check out the + * [README](https://github.com/bytecodealliance/wasmtime-cpp/blob/main/README.md) + * for simple usage instructions. Otherwise you can dive right in to the + * reference documentation of \ref wasmtime.hh + * + * \example hello.cc + * \example gcd.cc + * \example linking.cc + * \example memory.cc + * \example interrupt.cc + * \example externref.cc + */ + +/** + * \file wasmtime.hh + */ + +#ifndef WASMTIME_HH +#define WASMTIME_HH + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#ifdef __has_include +#if __has_include() +#include +#endif +#endif + +#include "wasmtime.h" + +namespace wasmtime { + +#ifdef __cpp_lib_span + +/// \brief Alias to C++20 std::span when it is available +template using Span = std::span; + +#else + +/// \brief Means number of elements determined at runtime +inline constexpr size_t dynamic_extent = std::numeric_limits::max(); + +/** + * \brief Span class used when c++20 is not available + * @tparam T Type of data + * @tparam Extent Static size of data referred by Span class + */ +template class Span; + +/// \brief Check whether a type is `Span` +template struct IsSpan : std::false_type {}; + +template struct IsSpan> : std::true_type {}; + +template class Span { + static_assert(Extent == dynamic_extent, + "The current implementation supports dynamic-extent span only"); + + public: + /// \brief Type used to iterate over this span (a raw pointer) + using iterator = T*; + + /// \brief Constructor of Span class + Span(T* t, std::size_t n) : ptr_{t}, size_{n} { + } + + /// \brief Constructor of Span class for containers + template < + typename C, + std::enable_if_t< + !IsSpan::value && std::is_pointer_v().data())> && + std::is_convertible_v< + std::remove_pointer_t().data())> (*)[], T (*)[]> && + std::is_convertible_v().size()), std::size_t>, + int> = 0> + Span(C& range) : ptr_{range.data()}, size_{range.size()} { + } + + /// \brief Returns item by index + T& operator[](ptrdiff_t idx) const { + return ptr_[idx]; // NOLINT + } + + /// \brief Returns pointer to data + T* data() const { + return ptr_; + } + + /// \brief Returns number of data that referred by Span class + std::size_t size() const { + return size_; + } + + /// \brief Returns begin iterator + iterator begin() const { + return ptr_; + } + + /// \brief Returns end iterator + iterator end() const { + return ptr_ + size_; // NOLINT + } + + /// \brief Returns size in bytes + std::size_t size_bytes() const { + return sizeof(T) * size_; + } + + private: + T* ptr_; + std::size_t size_; +}; + +#endif + +class Trace; + +/** + * \brief Errors coming from Wasmtime + * + * This class represents an error that came from Wasmtime and contains a textual + * description of the error that occurred. + */ +class Error { + struct deleter { + void operator()(wasmtime_error_t* p) const { + wasmtime_error_delete(p); + } + }; + + std::unique_ptr ptr; + + public: + /// \brief Creates an error from the raw C API representation + /// + /// Takes ownership of the provided `error`. + Error(wasmtime_error_t* error) : ptr(error) { + } + + /// \brief Returns the error message associated with this error. + std::string message() const { + wasm_byte_vec_t msg_bytes; + wasmtime_error_message(ptr.get(), &msg_bytes); + auto ret = std::string(msg_bytes.data, msg_bytes.size); + wasm_byte_vec_delete(&msg_bytes); + return ret; + } + + /// If this trap represents a call to `exit` for WASI, this will return the + /// optional error code associated with the exit trap. + std::optional i32_exit() const { + int32_t status = 0; + if (wasmtime_error_exit_status(ptr.get(), &status)) { + return status; + } + return std::nullopt; + } + + /// Returns the trace of WebAssembly frames associated with this error. + /// + /// Note that the `trace` cannot outlive this error object. + Trace trace() const; +}; + +/// \brief Used to print an error. +inline std::ostream& operator<<(std::ostream& os, const Error& e) { + os << e.message(); + return os; +} + +/** + * \brief Fallible result type used for Wasmtime. + * + * This type is used as the return value of many methods in the Wasmtime API. + * This behaves similarly to Rust's `Result` and will be replaced with a + * C++ standard when it exists. + */ +template class [[nodiscard]] Result { + std::variant data; + + public: + /// \brief Creates a `Result` from its successful value. + Result(T t) : data(std::move(t)) { + } + /// \brief Creates a `Result` from an error value. + Result(E e) : data(std::move(e)) { + } + + /// \brief Returns `true` if this result is a success, `false` if it's an + /// error + explicit operator bool() const { + return data.index() == 0; + } + + /// \brief Returns the error, if present, aborts if this is not an error. + E&& err() { + return std::get(std::move(data)); + } + /// \brief Returns the error, if present, aborts if this is not an error. + const E&& err() const { + return std::get(std::move(data)); + } + + /// \brief Returns the success, if present, aborts if this is an error. + T&& ok() { + return std::get(std::move(data)); + } + /// \brief Returns the success, if present, aborts if this is an error. + const T&& ok() const { + return std::get(std::move(data)); + } + + /// \brief Returns the success, if present, aborts if this is an error. + T unwrap() { + if (*this) { + return this->ok(); + } + unwrap_failed(); + } + + private: + [[noreturn]] void unwrap_failed() { + fprintf(stderr, "error: %s\n", this->err().message().c_str()); // NOLINT + std::abort(); + } +}; + +/// \brief Strategies passed to `Config::strategy` +enum class Strategy { + /// Automatically selects the compilation strategy + Auto = WASMTIME_STRATEGY_AUTO, + /// Requires Cranelift to be used for compilation + Cranelift = WASMTIME_STRATEGY_CRANELIFT, +}; + +/// \brief Values passed to `Config::cranelift_opt_level` +enum class OptLevel { + /// No extra optimizations performed + None = WASMTIME_OPT_LEVEL_NONE, + /// Optimize for speed + Speed = WASMTIME_OPT_LEVEL_SPEED, + /// Optimize for speed and generated code size + SpeedAndSize = WASMTIME_OPT_LEVEL_SPEED_AND_SIZE, +}; + +/// \brief Values passed to `Config::profiler` +enum class ProfilingStrategy { + /// No profiling enabled + None = WASMTIME_PROFILING_STRATEGY_NONE, + /// Profiling hooks via perf's jitdump + Jitdump = WASMTIME_PROFILING_STRATEGY_JITDUMP, + /// Profiling hooks via VTune + Vtune = WASMTIME_PROFILING_STRATEGY_VTUNE, +}; + +/** + * \brief Configuration for Wasmtime. + * + * This class is used to configure Wasmtime's compilation and various other + * settings such as enabled WebAssembly proposals. + * + * For more information be sure to consult the [rust + * documentation](https://docs.wasmtime.dev/api/wasmtime/struct.Config.html). + */ +class Config { + friend class Engine; + + struct deleter { + void operator()(wasm_config_t* p) const { + wasm_config_delete(p); + } + }; + + std::unique_ptr ptr; + + public: + /// \brief Creates configuration with all the default settings. + Config() : ptr(wasm_config_new()) { + } + + /// \brief Configures whether dwarf debuginfo is emitted for assisting + /// in-process debugging. + /// + /// https://docs.wasmtime.dev/api/wasmtime/struct.Config.html#method.debug_info + void debug_info(bool enable) { + wasmtime_config_debug_info_set(ptr.get(), enable); + } + + /// \brief Configures whether epochs are enabled which can be used to + /// interrupt currently executing WebAssembly. + /// + /// https://docs.wasmtime.dev/api/wasmtime/struct.Config.html#method.epoch_interruption + void epoch_interruption(bool enable) { + wasmtime_config_epoch_interruption_set(ptr.get(), enable); + } + + /// \brief Configures whether WebAssembly code will consume fuel and trap when + /// it runs out. + /// + /// https://docs.wasmtime.dev/api/wasmtime/struct.Config.html#method.consume_fuel + void consume_fuel(bool enable) { + wasmtime_config_consume_fuel_set(ptr.get(), enable); + } + + /// \brief Configures the maximum amount of native stack wasm can consume. + /// + /// https://docs.wasmtime.dev/api/wasmtime/struct.Config.html#method.max_wasm_stack + void max_wasm_stack(size_t stack) { + wasmtime_config_max_wasm_stack_set(ptr.get(), stack); + } + + /// \brief Configures whether the WebAssembly threads proposal is enabled + /// + /// https://docs.wasmtime.dev/api/wasmtime/struct.Config.html#method.wasm_threads + void wasm_threads(bool enable) { + wasmtime_config_wasm_threads_set(ptr.get(), enable); + } + + /// \brief Configures whether the WebAssembly reference types proposal is + /// enabled + /// + /// https://docs.wasmtime.dev/api/wasmtime/struct.Config.html#method.wasm_reference_types + void wasm_reference_types(bool enable) { + wasmtime_config_wasm_reference_types_set(ptr.get(), enable); + } + + /// \brief Configures whether the WebAssembly simd proposal is enabled + /// + /// https://docs.wasmtime.dev/api/wasmtime/struct.Config.html#method.wasm_simd + void wasm_simd(bool enable) { + wasmtime_config_wasm_simd_set(ptr.get(), enable); + } + + /// \brief Configures whether the WebAssembly bulk memory proposal is enabled + /// + /// https://docs.wasmtime.dev/api/wasmtime/struct.Config.html#method.wasm_bulk_memory + void wasm_bulk_memory(bool enable) { + wasmtime_config_wasm_bulk_memory_set(ptr.get(), enable); + } + + /// \brief Configures whether the WebAssembly multi value proposal is enabled + /// + /// https://docs.wasmtime.dev/api/wasmtime/struct.Config.html#method.wasm_multi_value + void wasm_multi_value(bool enable) { + wasmtime_config_wasm_multi_value_set(ptr.get(), enable); + } + + /// \brief Configures compilation strategy for wasm code. + /// + /// https://docs.wasmtime.dev/api/wasmtime/struct.Config.html#method.strategy + void strategy(Strategy strategy) { + wasmtime_config_strategy_set(ptr.get(), static_cast(strategy)); + } + + /// \brief Configures whether cranelift's debug verifier is enabled + /// + /// https://docs.wasmtime.dev/api/wasmtime/struct.Config.html#method.cranelift_debug_verifier + void cranelift_debug_verifier(bool enable) { + wasmtime_config_cranelift_debug_verifier_set(ptr.get(), enable); + } + + /// \brief Configures cranelift's optimization level + /// + /// https://docs.wasmtime.dev/api/wasmtime/struct.Config.html#method.cranelift_opt_level + void cranelift_opt_level(OptLevel level) { + wasmtime_config_cranelift_opt_level_set(ptr.get(), static_cast(level)); + } + + /// \brief Configures an active wasm profiler + /// + /// https://docs.wasmtime.dev/api/wasmtime/struct.Config.html#method.profiler + void profiler(ProfilingStrategy profiler) { + wasmtime_config_profiler_set(ptr.get(), static_cast(profiler)); + } + + /// \brief Configures the maximum size of memory to use a "static memory" + /// + /// https://docs.wasmtime.dev/api/wasmtime/struct.Config.html#method.static_memory_maximum_size + void static_memory_maximum_size(size_t size) { + wasmtime_config_static_memory_maximum_size_set(ptr.get(), size); + } + + /// \brief Configures the size of static memory's guard region + /// + /// https://docs.wasmtime.dev/api/wasmtime/struct.Config.html#method.static_memory_guard_size + void static_memory_guard_size(size_t size) { + wasmtime_config_static_memory_guard_size_set(ptr.get(), size); + } + + /// \brief Configures the size of dynamic memory's guard region + /// + /// https://docs.wasmtime.dev/api/wasmtime/struct.Config.html#method.dynamic_memory_guard_size + void dynamic_memory_guard_size(size_t size) { + wasmtime_config_dynamic_memory_guard_size_set(ptr.get(), size); + } + + /// \brief Loads the default cache configuration present on the system. + /// + /// https://docs.wasmtime.dev/api/wasmtime/struct.Config.html#method.cache_config_load_default + Result cache_load_default() { + auto* error = wasmtime_config_cache_config_load(ptr.get(), nullptr); + if (error != nullptr) { + return Error(error); + } + return std::monostate(); + } + + /// \brief Loads cache configuration from the specified filename. + /// + /// https://docs.wasmtime.dev/api/wasmtime/struct.Config.html#method.cache_config_load + Result cache_load(const std::string& path) { + auto* error = wasmtime_config_cache_config_load(ptr.get(), path.c_str()); + if (error != nullptr) { + return Error(error); + } + return std::monostate(); + } +}; + +/** + * \brief Global compilation state in Wasmtime. + * + * Created with either default configuration or with a specified instance of + * configuration, an `Engine` is used as an umbrella "session" for all other + * operations in Wasmtime. + */ +class Engine { + friend class Store; + friend class Module; + friend class Linker; + + struct deleter { + void operator()(wasm_engine_t* p) const { + wasm_engine_delete(p); + } + }; + + std::unique_ptr ptr; + + public: + /// \brief Creates an engine with default compilation settings. + Engine() : ptr(wasm_engine_new()) { + } + /// \brief Creates an engine with the specified compilation settings. + explicit Engine(Config config) : ptr(wasm_engine_new_with_config(config.ptr.release())) { + } + + /// \brief Increments the current epoch which may result in interrupting + /// currently executing WebAssembly in connected stores if the epoch is now + /// beyond the configured threshold. + void increment_epoch() const { + wasmtime_engine_increment_epoch(ptr.get()); + } +}; + +/** + * \brief Converts the WebAssembly text format into the WebAssembly binary + * format. + * + * This will parse the text format and attempt to translate it to the binary + * format. Note that the text parser assumes that all WebAssembly features are + * enabled and will parse syntax of future proposals. The exact syntax here + * parsed may be tweaked over time. + * + * Returns either an error if parsing failed or the wasm binary. + */ +inline Result> wat2wasm(std::string_view wat) { + wasm_byte_vec_t ret; + auto* error = wasmtime_wat2wasm(wat.data(), wat.size(), &ret); + if (error != nullptr) { + return Error(error); + } + std::vector vec; + // NOLINTNEXTLINE TODO can this be done without triggering lints? + Span raw(reinterpret_cast(ret.data), ret.size); + vec.assign(raw.begin(), raw.end()); + wasm_byte_vec_delete(&ret); + return vec; +} + +/// Different kinds of types accepted by Wasmtime. +enum class ValKind { + /// WebAssembly's `i32` type + I32, + /// WebAssembly's `i64` type + I64, + /// WebAssembly's `f32` type + F32, + /// WebAssembly's `f64` type + F64, + /// WebAssembly's `v128` type from the simd proposal + V128, + /// WebAssembly's `externref` type from the reference types + ExternRef, + /// WebAssembly's `funcref` type from the reference types + FuncRef, +}; + +/// Helper X macro to construct statement for each enumerator in `ValKind`. +/// X(enumerator in `ValKind`, name string, enumerator in `wasm_valkind_t`) +#define WASMTIME_FOR_EACH_VAL_KIND(X) \ + X(I32, "i32", WASM_I32) \ + X(I64, "i64", WASM_I64) \ + X(F32, "f32", WASM_F32) \ + X(F64, "f64", WASM_F64) \ + X(ExternRef, "externref", WASM_EXTERNREF) \ + X(FuncRef, "funcref", WASM_FUNCREF) \ + X(V128, "v128", WASMTIME_V128) + +/// \brief Used to print a ValKind. +inline std::ostream& operator<<(std::ostream& os, const ValKind& e) { + switch (e) { +#define CASE_KIND_PRINT_NAME(kind, name, ignore) \ + case ValKind::kind: \ + os << name; \ + break; + WASMTIME_FOR_EACH_VAL_KIND(CASE_KIND_PRINT_NAME) +#undef CASE_KIND_PRINT_NAME + default: + abort(); + } + return os; +} + +/** + * \brief Type information about a WebAssembly value. + * + * Currently mostly just contains the `ValKind`. + */ +class ValType { + friend class TableType; + friend class GlobalType; + friend class FuncType; + + struct deleter { + void operator()(wasm_valtype_t* p) const { + wasm_valtype_delete(p); + } + }; + + std::unique_ptr ptr; + + static wasm_valkind_t kind_to_c(ValKind kind) { + switch (kind) { +#define CASE_KIND_TO_C(kind, ignore, ckind) \ + case ValKind::kind: \ + return ckind; + WASMTIME_FOR_EACH_VAL_KIND(CASE_KIND_TO_C) +#undef CASE_KIND_TO_C + default: + abort(); + } + } + + public: + /// \brief Non-owning reference to a `ValType`, must not be used after the + /// original `ValType` is deleted. + class Ref { + friend class ValType; + + const wasm_valtype_t* ptr; + + public: + /// \brief Instantiates from the raw C API representation. + Ref(const wasm_valtype_t* ptr) : ptr(ptr) { + } + /// Copy constructor + Ref(const ValType& ty) : Ref(ty.ptr.get()) { + } + + /// \brief Returns the corresponding "kind" for this type. + ValKind kind() const { + switch (wasm_valtype_kind(ptr)) { +#define CASE_C_TO_KIND(kind, ignore, ckind) \ + case ckind: \ + return ValKind::kind; + WASMTIME_FOR_EACH_VAL_KIND(CASE_C_TO_KIND) +#undef CASE_C_TO_KIND + } + std::abort(); + } + }; + + /// \brief Non-owning reference to a list of `ValType` instances. Must not be + /// used after the original owner is deleted. + class ListRef { + const wasm_valtype_vec_t* list; + + public: + /// Creates a list from the raw underlying C API. + ListRef(const wasm_valtype_vec_t* list) : list(list) { + } + + /// This list iterates over a list of `ValType::Ref` instances. + typedef const Ref* iterator; + + /// Pointer to the beginning of iteration + iterator begin() const { + return reinterpret_cast(&list->data[0]); // NOLINT + } + + /// Pointer to the end of iteration + iterator end() const { + return reinterpret_cast(&list->data[list->size]); // NOLINT + } + + /// Returns how many types are in this list. + size_t size() const { + return list->size; + } + }; + + private: + Ref ref; + ValType(wasm_valtype_t* ptr) : ptr(ptr), ref(ptr) { + } + + public: + /// Creates a new type from its kind. + ValType(ValKind kind) : ValType(wasm_valtype_new(kind_to_c(kind))) { + } + /// Copies a `Ref` to a new owned value. + ValType(Ref other) : ValType(wasm_valtype_copy(other.ptr)) { + } + /// Copies one type to a new one. + ValType(const ValType& other) : ValType(wasm_valtype_copy(other.ptr.get())) { + } + /// Copies the contents of another type into this one. + ValType& operator=(const ValType& other) { + ptr.reset(wasm_valtype_copy(other.ptr.get())); + ref = other.ref; + return *this; + } + ~ValType() = default; + /// Moves the memory owned by another value type into this one. + ValType(ValType&& other) = default; + /// Moves the memory owned by another value type into this one. + ValType& operator=(ValType&& other) = default; + + /// \brief Returns the underlying `Ref`, a non-owning reference pointing to + /// this instance. + Ref* operator->() { + return &ref; + } + /// \brief Returns the underlying `Ref`, a non-owning reference pointing to + /// this instance. + Ref* operator*() { + return &ref; + } +}; + +/** + * \brief Type information about a WebAssembly linear memory + */ +class MemoryType { + friend class Memory; + + struct deleter { + void operator()(wasm_memorytype_t* p) const { + wasm_memorytype_delete(p); + } + }; + + std::unique_ptr ptr; + + public: + /// \brief Non-owning reference to a `MemoryType`, must not be used after the + /// original owner has been deleted. + class Ref { + friend class MemoryType; + + const wasm_memorytype_t* ptr; + + public: + /// Creates a reference from the raw C API representation. + Ref(const wasm_memorytype_t* ptr) : ptr(ptr) { + } + /// Creates a reference from an original `MemoryType`. + Ref(const MemoryType& ty) : Ref(ty.ptr.get()) { + } + + /// Returns the minimum size, in WebAssembly pages, of this memory. + uint64_t min() const { + return wasmtime_memorytype_minimum(ptr); + } + + /// Returns the maximum size, in WebAssembly pages, of this memory, if + /// specified. + std::optional max() const { + uint64_t max = 0; + auto present = wasmtime_memorytype_maximum(ptr, &max); + if (present) { + return max; + } + return std::nullopt; + } + + /// Returns whether or not this is a 64-bit memory type. + bool is_64() const { + return wasmtime_memorytype_is64(ptr); + } + }; + + private: + Ref ref; + MemoryType(wasm_memorytype_t* ptr) : ptr(ptr), ref(ptr) { + } + + public: + /// Creates a new 32-bit wasm memory type with the specified minimum number of + /// pages for the minimum size. The created type will have no maximum memory + /// size. + explicit MemoryType(uint32_t min) : MemoryType(wasmtime_memorytype_new(min, false, 0, false)) { + } + /// Creates a new 32-bit wasm memory type with the specified minimum number of + /// pages for the minimum size, and maximum number of pages for the max size. + MemoryType(uint32_t min, uint32_t max) + : MemoryType(wasmtime_memorytype_new(min, true, max, false)) { + } + + /// Same as the `MemoryType` constructor, except creates a 64-bit memory. + static MemoryType New64(uint64_t min) { + return MemoryType(wasmtime_memorytype_new(min, false, 0, true)); + } + + /// Same as the `MemoryType` constructor, except creates a 64-bit memory. + static MemoryType New64(uint64_t min, uint64_t max) { + return MemoryType(wasmtime_memorytype_new(min, true, max, true)); + } + + /// Creates a new wasm memory type from the specified ref, making a fresh + /// owned value. + MemoryType(Ref other) : MemoryType(wasm_memorytype_copy(other.ptr)) { + } + /// Copies the provided type into a new type. + MemoryType(const MemoryType& other) : MemoryType(wasm_memorytype_copy(other.ptr.get())) { + } + /// Copies the provided type into a new type. + MemoryType& operator=(const MemoryType& other) { + ptr.reset(wasm_memorytype_copy(other.ptr.get())); + return *this; + } + ~MemoryType() = default; + /// Moves the type information from another type into this one. + MemoryType(MemoryType&& other) = default; + /// Moves the type information from another type into this one. + MemoryType& operator=(MemoryType&& other) = default; + + /// \brief Returns the underlying `Ref`, a non-owning reference pointing to + /// this instance. + Ref* operator->() { + return &ref; + } + /// \brief Returns the underlying `Ref`, a non-owning reference pointing to + /// this instance. + Ref* operator*() { + return &ref; + } +}; + +/** + * \brief Type information about a WebAssembly table. + */ +class TableType { + friend class Table; + + struct deleter { + void operator()(wasm_tabletype_t* p) const { + wasm_tabletype_delete(p); + } + }; + + std::unique_ptr ptr; + + public: + /// Non-owning reference to a `TableType`, must not be used after the original + /// owner is deleted. + class Ref { + friend class TableType; + + const wasm_tabletype_t* ptr; + + public: + /// Creates a reference from the raw underlying C API representation. + Ref(const wasm_tabletype_t* ptr) : ptr(ptr) { + } + /// Creates a reference to the provided `TableType`. + Ref(const TableType& ty) : Ref(ty.ptr.get()) { + } + + /// Returns the minimum size of this table type, in elements. + uint32_t min() const { + return wasm_tabletype_limits(ptr)->min; + } + + /// Returns the maximum size of this table type, in elements, if present. + std::optional max() const { + const auto* limits = wasm_tabletype_limits(ptr); + if (limits->max == wasm_limits_max_default) { + return std::nullopt; + } + return limits->max; + } + + /// Returns the type of value that is stored in this table. + ValType::Ref element() const { + return wasm_tabletype_element(ptr); + } + }; + + private: + Ref ref; + TableType(wasm_tabletype_t* ptr) : ptr(ptr), ref(ptr) { + } + + public: + /// Creates a new table type from the specified value type and minimum size. + /// The returned table will have no maximum size. + TableType(ValType ty, uint32_t min) : ptr(nullptr), ref(nullptr) { + wasm_limits_t limits; + limits.min = min; + limits.max = wasm_limits_max_default; + ptr.reset(wasm_tabletype_new(ty.ptr.release(), &limits)); + ref = ptr.get(); + } + /// Creates a new table type from the specified value type, minimum size, and + /// maximum size. + TableType(ValType ty, uint32_t min, uint32_t max) // NOLINT + : ptr(nullptr), ref(nullptr) { + wasm_limits_t limits; + limits.min = min; + limits.max = max; + ptr.reset(wasm_tabletype_new(ty.ptr.release(), &limits)); + ref = ptr.get(); + } + /// Clones the given reference into a new table type. + TableType(Ref other) : TableType(wasm_tabletype_copy(other.ptr)) { + } + /// Copies another table type into this one. + TableType(const TableType& other) : TableType(wasm_tabletype_copy(other.ptr.get())) { + } + /// Copies another table type into this one. + TableType& operator=(const TableType& other) { + ptr.reset(wasm_tabletype_copy(other.ptr.get())); + return *this; + } + ~TableType() = default; + /// Moves the table type resources from another type to this one. + TableType(TableType&& other) = default; + /// Moves the table type resources from another type to this one. + TableType& operator=(TableType&& other) = default; + + /// \brief Returns the underlying `Ref`, a non-owning reference pointing to + /// this instance. + Ref* operator->() { + return &ref; + } + /// \brief Returns the underlying `Ref`, a non-owning reference pointing to + /// this instance. + Ref* operator*() { + return &ref; + } +}; + +/** + * \brief Type information about a WebAssembly global + */ +class GlobalType { + friend class Global; + + struct deleter { + void operator()(wasm_globaltype_t* p) const { + wasm_globaltype_delete(p); + } + }; + + std::unique_ptr ptr; + + public: + /// Non-owning reference to a `Global`, must not be used after the original + /// owner is deleted. + class Ref { + friend class GlobalType; + const wasm_globaltype_t* ptr; + + public: + /// Creates a new reference from the raw underlying C API representation. + Ref(const wasm_globaltype_t* ptr) : ptr(ptr) { + } + /// Creates a new reference to the specified type. + Ref(const GlobalType& ty) : Ref(ty.ptr.get()) { + } + + /// Returns whether or not this global type is mutable. + bool is_mutable() const { + return wasm_globaltype_mutability(ptr) == WASM_VAR; + } + + /// Returns the type of value stored within this global type. + ValType::Ref content() const { + return wasm_globaltype_content(ptr); + } + }; + + private: + Ref ref; + GlobalType(wasm_globaltype_t* ptr) : ptr(ptr), ref(ptr) { + } + + public: + /// Creates a new global type from the specified value type and mutability. + GlobalType(ValType ty, bool mut) + : GlobalType(wasm_globaltype_new(ty.ptr.release(), + (wasm_mutability_t)(mut ? WASM_VAR : WASM_CONST))) { + } + /// Clones a reference into a uniquely owned global type. + GlobalType(Ref other) : GlobalType(wasm_globaltype_copy(other.ptr)) { + } + /// Copies other type information into this one. + GlobalType(const GlobalType& other) : GlobalType(wasm_globaltype_copy(other.ptr.get())) { + } + /// Copies other type information into this one. + GlobalType& operator=(const GlobalType& other) { + ptr.reset(wasm_globaltype_copy(other.ptr.get())); + return *this; + } + ~GlobalType() = default; + /// Moves the underlying type information from another global into this one. + GlobalType(GlobalType&& other) = default; + /// Moves the underlying type information from another global into this one. + GlobalType& operator=(GlobalType&& other) = default; + + /// \brief Returns the underlying `Ref`, a non-owning reference pointing to + /// this instance. + Ref* operator->() { + return &ref; + } + /// \brief Returns the underlying `Ref`, a non-owning reference pointing to + /// this instance. + Ref* operator*() { + return &ref; + } +}; + +/** + * \brief Type information for a WebAssembly function. + */ +class FuncType { + friend class Func; + friend class Linker; + + struct deleter { + void operator()(wasm_functype_t* p) const { + wasm_functype_delete(p); + } + }; + + std::unique_ptr ptr; + + public: + /// Non-owning reference to a `FuncType`, must not be used after the original + /// owner has been deleted. + class Ref { + friend class FuncType; + const wasm_functype_t* ptr; + + public: + /// Creates a new reference from the underlying C API representation. + Ref(const wasm_functype_t* ptr) : ptr(ptr) { + } + /// Creates a new reference to the given type. + Ref(const FuncType& ty) : Ref(ty.ptr.get()) { + } + + /// Returns the list of types this function type takes as parameters. + ValType::ListRef params() const { + return wasm_functype_params(ptr); + } + + /// Returns the list of types this function type returns. + ValType::ListRef results() const { + return wasm_functype_results(ptr); + } + }; + + private: + Ref ref; + FuncType(wasm_functype_t* ptr) : ptr(ptr), ref(ptr) { + } + + public: + /// Creates a new function type from the given list of parameters and results. + FuncType(std::initializer_list params, std::initializer_list results) + : ref(nullptr) { + *this = FuncType::from_iters(params, results); + } + + /// Copies a reference into a uniquely owned function type. + FuncType(Ref other) : FuncType(wasm_functype_copy(other.ptr)) { + } + /// Copies another type's information into this one. + FuncType(const FuncType& other) : FuncType(wasm_functype_copy(other.ptr.get())) { + } + /// Copies another type's information into this one. + FuncType& operator=(const FuncType& other) { + ptr.reset(wasm_functype_copy(other.ptr.get())); + return *this; + } + ~FuncType() = default; + /// Moves type information from another type into this one. + FuncType(FuncType&& other) = default; + /// Moves type information from another type into this one. + FuncType& operator=(FuncType&& other) = default; + + /// Creates a new function type from the given list of parameters and results. + template static FuncType from_iters(P params, R results) { + wasm_valtype_vec_t param_vec; + wasm_valtype_vec_t result_vec; + wasm_valtype_vec_new_uninitialized(¶m_vec, params.size()); + wasm_valtype_vec_new_uninitialized(&result_vec, results.size()); + size_t i = 0; + + for (auto val : params) { + param_vec.data[i++] = val.ptr.release(); // NOLINT + } + i = 0; + for (auto val : results) { + result_vec.data[i++] = val.ptr.release(); // NOLINT + } + + return wasm_functype_new(¶m_vec, &result_vec); + } + + /// \brief Returns the underlying `Ref`, a non-owning reference pointing to + /// this instance. + Ref* operator->() { + return &ref; + } + /// \brief Returns the underlying `Ref`, a non-owning reference pointing to + /// this instance. + Ref* operator*() { + return &ref; + } +}; + +/** + * \brief Type information about a WebAssembly import. + */ +class ImportType { + public: + /// Non-owning reference to an `ImportType`, must not be used after the + /// original owner is deleted. + class Ref { + friend class ExternType; + + const wasm_importtype_t* ptr; + + // TODO: can this circle be broken another way? + const wasm_externtype_t* raw_type() { + return wasm_importtype_type(ptr); + } + + public: + /// Creates a new reference from the raw underlying C API representation. + Ref(const wasm_importtype_t* ptr) : ptr(ptr) { + } + + /// Returns the module name associated with this import. + std::string_view module() { + const auto* name = wasm_importtype_module(ptr); + return std::string_view(name->data, name->size); + } + + /// Returns the field name associated with this import. + std::string_view name() { + const auto* name = wasm_importtype_name(ptr); + return std::string_view(name->data, name->size); + } + }; + + /// An owned list of `ImportType` instances. + class List { + friend class Module; + wasm_importtype_vec_t list; + + public: + /// Creates an empty list + List() : list{} { + list.size = 0; + list.data = nullptr; + } + List(const List& other) = delete; + /// Moves another list into this one. + List(List&& other) noexcept : list(other.list) { + other.list.size = 0; + } + ~List() { + if (list.size > 0) { + wasm_importtype_vec_delete(&list); + } + } + + List& operator=(const List& other) = delete; + /// Moves another list into this one. + List& operator=(List&& other) noexcept { + std::swap(list, other.list); + return *this; + } + + /// Iterator type, which is a list of non-owning `ImportType::Ref` + /// instances. + typedef const Ref* iterator; + /// Returns the start of iteration. + iterator begin() const { + return reinterpret_cast(&list.data[0]); // NOLINT + } + /// Returns the end of iteration. + iterator end() const { + return reinterpret_cast(&list.data[list.size]); // NOLINT + } + /// Returns the size of this list. + size_t size() const { + return list.size; + } + }; +}; + +/** + * \brief Type information about a WebAssembly export + */ +class ExportType { + public: + /// \brief Non-owning reference to an `ExportType`. + /// + /// Note to get type information you can use `ExternType::from_export`. + class Ref { + friend class ExternType; + + const wasm_exporttype_t* ptr; + + const wasm_externtype_t* raw_type() { + return wasm_exporttype_type(ptr); + } + + public: + /// Creates a new reference from the raw underlying C API representation. + Ref(const wasm_exporttype_t* ptr) : ptr(ptr) { + } + + /// Returns the name of this export. + std::string_view name() { + const auto* name = wasm_exporttype_name(ptr); + return std::string_view(name->data, name->size); + } + }; + + /// An owned list of `ExportType` instances. + class List { + friend class Module; + wasm_exporttype_vec_t list; + + public: + /// Creates an empty list + List() : list{} { + list.size = 0; + list.data = nullptr; + } + List(const List& other) = delete; + /// Moves another list into this one. + List(List&& other) noexcept : list(other.list) { + other.list.size = 0; + } + ~List() { + if (list.size > 0) { + wasm_exporttype_vec_delete(&list); + } + } + + List& operator=(const List& other) = delete; + /// Moves another list into this one. + List& operator=(List&& other) noexcept { + std::swap(list, other.list); + return *this; + } + + /// Iterator type, which is a list of non-owning `ExportType::Ref` + /// instances. + typedef const Ref* iterator; + /// Returns the start of iteration. + iterator begin() const { + return reinterpret_cast(&list.data[0]); // NOLINT + } + /// Returns the end of iteration. + iterator end() const { + return reinterpret_cast(&list.data[list.size]); // NOLINT + } + /// Returns the size of this list. + size_t size() const { + return list.size; + } + }; +}; + +/** + * \brief Generic type of a WebAssembly item. + */ +class ExternType { + friend class ExportType; + friend class ImportType; + + public: + /// \typedef Ref + /// \brief Non-owning reference to an item's type + /// + /// This cannot be used after the original owner has been deleted, and + /// otherwise this is used to determine what the actual type of the outer item + /// is. + typedef std::variant Ref; + + /// Extract the type of the item imported by the provided type. + static Ref from_import(ImportType::Ref ty) { + // TODO: this would ideally be some sort of implicit constructor, unsure how + // to do that though... + return ref_from_c(ty.raw_type()); + } + + /// Extract the type of the item exported by the provided type. + static Ref from_export(ExportType::Ref ty) { + // TODO: this would ideally be some sort of implicit constructor, unsure how + // to do that though... + return ref_from_c(ty.raw_type()); + } + + private: + static Ref ref_from_c(const wasm_externtype_t* ptr) { + switch (wasm_externtype_kind(ptr)) { + case WASM_EXTERN_FUNC: + return wasm_externtype_as_functype_const(ptr); + case WASM_EXTERN_GLOBAL: + return wasm_externtype_as_globaltype_const(ptr); + case WASM_EXTERN_TABLE: + return wasm_externtype_as_tabletype_const(ptr); + case WASM_EXTERN_MEMORY: + return wasm_externtype_as_memorytype_const(ptr); + } + std::abort(); + } +}; + +/** + * \brief Non-owning reference to a WebAssembly function frame as part of a + * `Trace` + * + * A `FrameRef` represents a WebAssembly function frame on the stack which was + * collected as part of a trap. + */ +class FrameRef { + wasm_frame_t* frame; + + public: + /// Returns the WebAssembly function index of this function, in the original + /// module. + uint32_t func_index() const { + return wasm_frame_func_index(frame); + } + /// Returns the offset, in bytes from the start of the function in the + /// original module, to this frame's program counter. + size_t func_offset() const { + return wasm_frame_func_offset(frame); + } + /// Returns the offset, in bytes from the start of the original module, + /// to this frame's program counter. + size_t module_offset() const { + return wasm_frame_module_offset(frame); + } + + /// Returns the name, if present, associated with this function. + /// + /// Note that this requires that the `name` section is present in the original + /// WebAssembly binary. + std::optional func_name() const { + const auto* name = wasmtime_frame_func_name(frame); + if (name != nullptr) { + return std::string_view(name->data, name->size); + } + return std::nullopt; + } + + /// Returns the name, if present, associated with this function's module. + /// + /// Note that this requires that the `name` section is present in the original + /// WebAssembly binary. + std::optional module_name() const { + const auto* name = wasmtime_frame_module_name(frame); + if (name != nullptr) { + return std::string_view(name->data, name->size); + } + return std::nullopt; + } +}; + +/** + * \brief An owned vector of `FrameRef` instances representing the WebAssembly + * call-stack on a trap. + * + * This can be used to iterate over the frames of a trap and determine what was + * running when a trap happened. + */ +class Trace { + friend class Trap; + friend class Error; + + wasm_frame_vec_t vec; + + Trace(wasm_frame_vec_t vec) : vec(vec) { + } + + public: + ~Trace() { + wasm_frame_vec_delete(&vec); + } + + Trace(const Trace& other) = delete; + Trace(Trace&& other) = delete; + Trace& operator=(const Trace& other) = delete; + Trace& operator=(Trace&& other) = delete; + + /// Iterator used to iterate over this trace. + typedef const FrameRef* iterator; + + /// Returns the start of iteration + iterator begin() const { + return reinterpret_cast(&vec.data[0]); // NOLINT + } + /// Returns the end of iteration + iterator end() const { + return reinterpret_cast(&vec.data[vec.size]); // NOLINT + } + /// Returns the size of this trace, or how many frames it contains. + size_t size() const { + return vec.size; + } +}; + +inline Trace Error::trace() const { + wasm_frame_vec_t frames; + wasmtime_error_wasm_trace(ptr.get(), &frames); + return Trace(frames); +} + +/** + * \brief Information about a WebAssembly trap. + * + * Traps can happen during normal wasm execution (such as the `unreachable` + * instruction) but they can also happen in host-provided functions to a host + * function can simulate raising a trap. + * + * Traps have a message associated with them as well as a trace of WebAssembly + * frames on the stack. + */ +class Trap { + friend class Linker; + friend class Instance; + friend class Func; + template friend class TypedFunc; + + struct deleter { + void operator()(wasm_trap_t* p) const { + wasm_trap_delete(p); + } + }; + + std::unique_ptr ptr; + + Trap(wasm_trap_t* ptr) : ptr(ptr) { + } + + public: + /// Creates a new host-defined trap with the specified message. + explicit Trap(std::string_view msg) : Trap(wasmtime_trap_new(msg.data(), msg.size())) { + } + + /// Returns the descriptive message associated with this trap + std::string message() const { + wasm_byte_vec_t msg; + wasm_trap_message(ptr.get(), &msg); + std::string ret(msg.data, msg.size - 1); + wasm_byte_vec_delete(&msg); + return ret; + } + + /// Returns the trace of WebAssembly frames associated with this trap. + /// + /// Note that the `trace` cannot outlive this error object. + Trace trace() const { + wasm_frame_vec_t frames; + wasm_trap_trace(ptr.get(), &frames); + return Trace(frames); + } +}; + +/// Structure used to represent either a `Trap` or an `Error`. +struct TrapError { + /// Storage for what this trap represents. + std::variant data; + + /// Creates a new `TrapError` from a `Trap` + TrapError(Trap t) : data(std::move(t)) { + } + /// Creates a new `TrapError` from an `Error` + TrapError(Error e) : data(std::move(e)) { + } + + /// Dispatches internally to return the message associated with this error. + std::string message() const { + if (const auto* trap = std::get_if(&data)) { + return trap->message(); + } + if (const auto* error = std::get_if(&data)) { + return std::string(error->message()); + } + std::abort(); + } +}; + +/// Result used by functions which can fail because of invariants being violated +/// (such as a type error) as well as because of a WebAssembly trap. +template using TrapResult = Result; + +/** + * \brief Representation of a compiled WebAssembly module. + * + * This type contains JIT code of a compiled WebAssembly module. A `Module` is + * connected to an `Engine` and can only be instantiated within that `Engine`. + * You can inspect a `Module` for its type information. This is passed as an + * argument to other APIs to instantiate it. + */ +class Module { + friend class Store; + friend class Instance; + friend class Linker; + + struct deleter { + void operator()(wasmtime_module_t* p) const { + wasmtime_module_delete(p); + } + }; + + std::unique_ptr ptr; + + Module(wasmtime_module_t* raw) : ptr(raw) { + } + + public: + /// Copies another module into this one. + Module(const Module& other) : ptr(wasmtime_module_clone(other.ptr.get())) { + } + /// Copies another module into this one. + Module& operator=(const Module& other) { + ptr.reset(wasmtime_module_clone(other.ptr.get())); + return *this; + } + ~Module() = default; + /// Moves resources from another module into this one. + Module(Module&& other) = default; + /// Moves resources from another module into this one. + Module& operator=(Module&& other) = default; + + /** + * \brief Compiles a module from the WebAssembly text format. + * + * This function will automatically use `wat2wasm` on the input and then + * delegate to the #compile function. + */ + static Result compile(Engine& engine, std::string_view wat) { + auto wasm = wat2wasm(wat); + if (!wasm) { + return wasm.err(); + } + auto bytes = wasm.ok(); + return compile(engine, bytes); + } + + /** + * \brief Compiles a module from the WebAssembly binary format. + * + * This function compiles the provided WebAssembly binary specified by `wasm` + * within the compilation settings configured by `engine`. This method is + * synchronous and will not return until the module has finished compiling. + * + * This function can fail if the WebAssembly binary is invalid or doesn't + * validate (or similar). + */ + static Result compile(Engine& engine, Span wasm) { + wasmtime_module_t* ret = nullptr; + auto* error = wasmtime_module_new(engine.ptr.get(), wasm.data(), wasm.size(), &ret); + if (error != nullptr) { + return Error(error); + } + return Module(ret); + } + + /** + * \brief Validates the provided WebAssembly binary without compiling it. + * + * This function will validate whether the provided binary is indeed valid + * within the compilation settings of the `engine` provided. + */ + static Result validate(Engine& engine, Span wasm) { + auto* error = wasmtime_module_validate(engine.ptr.get(), wasm.data(), wasm.size()); + if (error != nullptr) { + return Error(error); + } + return std::monostate(); + } + + /** + * \brief Deserializes a previous list of bytes created with `serialize`. + * + * This function is intended to be much faster than `compile` where it uses + * the artifacts of a previous compilation to quickly create an in-memory + * module ready for instantiation. + * + * It is not safe to pass arbitrary input to this function, it is only safe to + * pass in output from previous calls to `serialize`. For more information see + * the Rust documentation - + * https://docs.wasmtime.dev/api/wasmtime/struct.Module.html#method.deserialize + */ + static Result deserialize(Engine& engine, Span wasm) { + wasmtime_module_t* ret = nullptr; + auto* error = wasmtime_module_deserialize(engine.ptr.get(), wasm.data(), wasm.size(), &ret); + if (error != nullptr) { + return Error(error); + } + return Module(ret); + } + + /** + * \brief Deserializes a module from an on-disk file. + * + * This function is the same as `deserialize` except that it reads the data + * for the serialized module from the path on disk. This can be faster than + * the alternative which may require copying the data around. + * + * It is not safe to pass arbitrary input to this function, it is only safe to + * pass in output from previous calls to `serialize`. For more information see + * the Rust documentation - + * https://docs.wasmtime.dev/api/wasmtime/struct.Module.html#method.deserialize + */ + static Result deserialize_file(Engine& engine, const std::string& path) { + wasmtime_module_t* ret = nullptr; + auto* error = wasmtime_module_deserialize_file(engine.ptr.get(), path.c_str(), &ret); + if (error != nullptr) { + return Error(error); + } + return Module(ret); + } + + /// Returns the list of types imported by this module. + ImportType::List imports() const { + ImportType::List list; + wasmtime_module_imports(ptr.get(), &list.list); + return list; + } + + /// Returns the list of types exported by this module. + ExportType::List exports() const { + ExportType::List list; + wasmtime_module_exports(ptr.get(), &list.list); + return list; + } + + /** + * \brief Serializes this module to a list of bytes. + * + * The returned bytes can then be used to later pass to `deserialize` to + * quickly recreate this module in a different process perhaps. + */ + Result> serialize() const { + wasm_byte_vec_t bytes; + auto* error = wasmtime_module_serialize(ptr.get(), &bytes); + if (error != nullptr) { + return Error(error); + } + std::vector ret; + // NOLINTNEXTLINE TODO can this be done without triggering lints? + Span raw(reinterpret_cast(bytes.data), bytes.size); + ret.assign(raw.begin(), raw.end()); + wasm_byte_vec_delete(&bytes); + return ret; + } +}; + +/** + * \brief Configuration for an instance of WASI. + * + * This is inserted into a store with `Store::Context::set_wasi`. + */ +class WasiConfig { + friend class Store; + + struct deleter { + void operator()(wasi_config_t* p) const { + wasi_config_delete(p); + } + }; + + std::unique_ptr ptr; + + public: + /// Creates a new configuration object with default settings. + WasiConfig() : ptr(wasi_config_new()) { + } + + /// Configures the argv explicitly with the given string array. + void argv(const std::vector& args) { + std::vector ptrs; + ptrs.reserve(args.size()); + for (const auto& arg : args) { + ptrs.push_back(arg.c_str()); + } + + wasi_config_set_argv(ptr.get(), (int)args.size(), ptrs.data()); + } + + /// Configures the argv for wasm to be inherited from this process itself. + void inherit_argv() { + wasi_config_inherit_argv(ptr.get()); + } + + /// Configures the environment variables available to wasm, specified here as + /// a list of pairs where the first element of the pair is the key and the + /// second element is the value. + void env(const std::vector>& env) { + std::vector names; + std::vector values; + for (const auto& [name, value] : env) { + names.push_back(name.c_str()); + values.push_back(value.c_str()); + } + wasi_config_set_env(ptr.get(), (int)env.size(), names.data(), values.data()); + } + + /// Indicates that the entire environment of this process should be inherited + /// by the wasi configuration. + void inherit_env() { + wasi_config_inherit_env(ptr.get()); + } + + /// Configures the provided file to be used for the stdin of this WASI + /// configuration. + [[nodiscard]] bool stdin_file(const std::string& path) { + return wasi_config_set_stdin_file(ptr.get(), path.c_str()); + } + + /// Configures this WASI configuration to inherit its stdin from the host + /// process. + void inherit_stdin() { + return wasi_config_inherit_stdin(ptr.get()); + } + + /// Configures the provided file to be created and all stdout output will be + /// written there. + [[nodiscard]] bool stdout_file(const std::string& path) { + return wasi_config_set_stdout_file(ptr.get(), path.c_str()); + } + + /// Configures this WASI configuration to inherit its stdout from the host + /// process. + void inherit_stdout() { + return wasi_config_inherit_stdout(ptr.get()); + } + + /// Configures the provided file to be created and all stderr output will be + /// written there. + [[nodiscard]] bool stderr_file(const std::string& path) { + return wasi_config_set_stderr_file(ptr.get(), path.c_str()); + } + + /// Configures this WASI configuration to inherit its stdout from the host + /// process. + void inherit_stderr() { + return wasi_config_inherit_stderr(ptr.get()); + } + + /// Opens `path` to be opened as `guest_path` in the WASI pseudo-filesystem. + [[nodiscard]] bool preopen_dir(const std::string& path, const std::string& guest_path) { + return wasi_config_preopen_dir(ptr.get(), path.c_str(), guest_path.c_str()); + } +}; + +class Caller; +template class TypedFunc; + +/** + * \brief Owner of all WebAssembly objects + * + * A `Store` owns all WebAssembly objects such as instances, globals, functions, + * memories, etc. A `Store` is one of the main central points about working with + * WebAssembly since it's an argument to almost all APIs. The `Store` serves as + * a form of "context" to give meaning to the pointers of `Func` and friends. + * + * A `Store` can be sent between threads but it cannot generally be shared + * concurrently between threads. Memory associated with WebAssembly instances + * will be deallocated when the `Store` is deallocated. + */ +class Store { + struct deleter { + void operator()(wasmtime_store_t* p) const { + wasmtime_store_delete(p); + } + }; + + std::unique_ptr ptr; + + static void finalizer(void* ptr) { + std::unique_ptr _ptr(static_cast(ptr)); + } + + public: + /// Creates a new `Store` within the provided `Engine`. + explicit Store(Engine& engine) : ptr(wasmtime_store_new(engine.ptr.get(), nullptr, finalizer)) { + } + + /** + * \brief An interior pointer into a `Store`. + * + * A `Context` object is created from either a `Store` or a `Caller`. It is an + * interior pointer into a `Store` and cannot be used outside the lifetime of + * the original object it was created from. + * + * This object is an argument to most APIs in Wasmtime but typically doesn't + * need to be constructed explicitly since it can be created from a `Store&` + * or a `Caller&`. + */ + class Context { + friend class Global; + friend class Table; + friend class Memory; + friend class Func; + friend class Instance; + friend class Linker; + friend class ExternRef; + friend class Val; + wasmtime_context_t* ptr; + + Context(wasmtime_context_t* ptr) : ptr(ptr) { + } + + public: + /// Creates a context referencing the provided `Store`. + Context(Store& store) : Context(wasmtime_store_context(store.ptr.get())) { + } + /// Creates a context referencing the provided `Store`. + Context(Store* store) : Context(*store) { + } + /// Creates a context referencing the provided `Caller`. + Context(Caller& caller); + /// Creates a context referencing the provided `Caller`. + Context(Caller* caller); + + /// Runs a garbage collection pass in the referenced store to collect loose + /// `externref` values, if any are available. + void gc() { + wasmtime_context_gc(ptr); + } + + /// Injects fuel to be consumed within this store. + /// + /// Stores start with 0 fuel and if `Config::consume_fuel` is enabled then + /// this is required if you want to let WebAssembly actually execute. + /// + /// Returns an error if fuel consumption isn't enabled. + Result set_fuel(uint64_t fuel) { + auto* error = wasmtime_context_set_fuel(ptr, fuel); + if (error != nullptr) { + return Error(error); + } + return std::monostate(); + } + + /// Returns the amount of fuel consumed so far by executing WebAssembly. + /// + /// Returns `std::nullopt` if fuel consumption is not enabled. + Result get_fuel() const { + uint64_t fuel = 0; + auto* error = wasmtime_context_get_fuel(ptr, &fuel); + if (error != nullptr) { + return Error(error); + } + return fuel; + } + + /// Set user specified data associated with this store. + void set_data(std::any data) const { + finalizer(static_cast(wasmtime_context_get_data(ptr))); + wasmtime_context_set_data(ptr, std::make_unique(std::move(data)).release()); + } + + /// Get user specified data associated with this store. + std::any& get_data() const { + return *static_cast(wasmtime_context_get_data(ptr)); + } + + /// Configures the WASI state used by this store. + /// + /// This will only have an effect if used in conjunction with + /// `Linker::define_wasi` because otherwise no host functions will use the + /// WASI state. + Result set_wasi(WasiConfig config) { + auto* error = wasmtime_context_set_wasi(ptr, config.ptr.release()); + if (error != nullptr) { + return Error(error); + } + return std::monostate(); + } + + /// Configures this store's epoch deadline to be the specified number of + /// ticks beyond the engine's current epoch. + /// + /// By default the deadline is the current engine's epoch, immediately + /// interrupting code if epoch interruption is enabled. This must be called + /// to extend the deadline to allow interruption. + void set_epoch_deadline(uint64_t ticks_beyond_current) { + wasmtime_context_set_epoch_deadline(ptr, ticks_beyond_current); + } + + /// Returns the raw context pointer for the C API. + wasmtime_context_t* raw_context() { + return ptr; + } + }; + + /// \brief Provides limits for a store. Used by hosts to limit resource + /// consumption of instances. Use negative value to keep the default value + /// for the limit. + /// + /// \param memory_size the maximum number of bytes a linear memory can grow + /// to. Growing a linear memory beyond this limit will fail. By default, + /// linear memory will not be limited. + /// + /// \param table_elements the maximum number of elements in a table. + /// Growing a table beyond this limit will fail. By default, table elements + /// will not be limited. + /// + /// \param instances the maximum number of instances that can be created + /// for a Store. Module instantiation will fail if this limit is exceeded. + /// This value defaults to 10,000. + /// + /// \param tables the maximum number of tables that can be created for a + /// Store. Module instantiation will fail if this limit is exceeded. This + /// value defaults to 10,000. + /// + /// \param memories the maximum number of linear + /// memories that can be created for a Store. Instantiation will fail with an + /// error if this limit is exceeded. This value defaults to 10,000. + /// + /// Use any negative value for the parameters that should be kept on + /// the default values. + /// + /// Note that the limits are only used to limit the creation/growth of + /// resources in the future, this does not retroactively attempt to apply + /// limits to the store. + void limiter(int64_t memory_size, int64_t table_elements, int64_t instances, int64_t tables, + int64_t memories) { + wasmtime_store_limiter(ptr.get(), memory_size, table_elements, instances, tables, memories); + } + + /// Explicit function to acquire a `Context` from this store. + Context context() { + return this; + } +}; + +/** + * \brief Representation of a WebAssembly `externref` value. + * + * This class represents an value that cannot be forged by WebAssembly itself. + * All `ExternRef` values are guaranteed to be created by the host and its + * embedding. It's suitable to place private data structures in here which + * WebAssembly will not have access to, only other host functions will have + * access to them. + * + * Note that `ExternRef` values are rooted within a `Store` and must be manually + * unrooted via the `unroot` function. If this is not used then values will + * never be candidates for garbage collection. + */ +class ExternRef { + friend class Val; + + wasmtime_externref_t val; + + static void finalizer(void* ptr) { + std::unique_ptr _ptr(static_cast(ptr)); + } + + public: + /// Creates a new `ExternRef` directly from its C-API representation. + explicit ExternRef(wasmtime_externref_t val) : val(val) { + } + + /// Creates a new `externref` value from the provided argument. + /// + /// Note that `val` should be safe to send across threads and should own any + /// memory that it points to. Also note that `ExternRef` is similar to a + /// `std::shared_ptr` in that there can be many references to the same value. + template explicit ExternRef(Store::Context cx, T val) { + void* ptr = std::make_unique(std::move(val)).release(); + bool ok = wasmtime_externref_new(cx.ptr, ptr, finalizer, &this->val); + if (!ok) { + fprintf(stderr, "failed to allocate a new externref"); + abort(); + } + } + + /// Creates a new `ExternRef` which is separately rooted from this one. + ExternRef clone(Store::Context cx) { + wasmtime_externref_t other; + wasmtime_externref_clone(cx.ptr, &val, &other); + return ExternRef(other); + } + + /// Returns the underlying host data associated with this `ExternRef`. + std::any& data(Store::Context cx) { + return *static_cast(wasmtime_externref_data(cx.ptr, &val)); + } + + /// Unroots this value from the context provided, enabling a future GC to + /// collect the internal object if there are no more references. + void unroot(Store::Context cx) { + wasmtime_externref_unroot(cx.ptr, &val); + } + + /// Returns the raw underlying C API value. + /// + /// This class still retains ownership of the pointer. + const wasmtime_externref_t* raw() const { + return &val; + } +}; + +class Func; +class Global; +class Instance; +class Memory; +class Table; + +/// \typedef Extern +/// \brief Representation of an external WebAssembly item +typedef std::variant Extern; + +/// \brief Container for the `v128` WebAssembly type. +struct V128 { + /// \brief The little-endian bytes of the `v128` value. + wasmtime_v128 v128; + + /// \brief Creates a new zero-value `v128`. + V128() : v128{} { + memset(&v128[0], 0, sizeof(wasmtime_v128)); + } + + /// \brief Creates a new `V128` from its C API representation. + V128(const wasmtime_v128& v) : v128{} { + memcpy(&v128[0], &v[0], sizeof(wasmtime_v128)); + } +}; + +/** + * \brief Representation of a generic WebAssembly value. + * + * This is roughly equivalent to a tagged union of all possible WebAssembly + * values. This is later used as an argument with functions, globals, tables, + * etc. + * + * Note that a `Val` can represent owned GC pointers. In this case the `unroot` + * method must be used to ensure that they can later be garbage-collected. + */ +class Val { + friend class Global; + friend class Table; + friend class Func; + + wasmtime_val_t val; + + Val() : val{} { + val.kind = WASMTIME_I32; + val.of.i32 = 0; + } + Val(wasmtime_val_t val) : val(val) { + } + + public: + /// Creates a new `i32` WebAssembly value. + Val(int32_t i32) : val{} { + val.kind = WASMTIME_I32; + val.of.i32 = i32; + } + /// Creates a new `i64` WebAssembly value. + Val(int64_t i64) : val{} { + val.kind = WASMTIME_I64; + val.of.i64 = i64; + } + /// Creates a new `f32` WebAssembly value. + Val(float f32) : val{} { + val.kind = WASMTIME_F32; + val.of.f32 = f32; + } + /// Creates a new `f64` WebAssembly value. + Val(double f64) : val{} { + val.kind = WASMTIME_F64; + val.of.f64 = f64; + } + /// Creates a new `v128` WebAssembly value. + Val(const V128& v128) : val{} { + val.kind = WASMTIME_V128; + memcpy(&val.of.v128[0], &v128.v128[0], sizeof(wasmtime_v128)); + } + /// Creates a new `funcref` WebAssembly value. + Val(std::optional func); + /// Creates a new `funcref` WebAssembly value which is not `ref.null func`. + Val(Func func); + /// Creates a new `externref` value. + Val(std::optional ptr) : val{} { + val.kind = WASMTIME_EXTERNREF; + if (ptr) { + val.of.externref = ptr->val; + } else { + wasmtime_externref_set_null(&val.of.externref); + } + } + /// Creates a new `externref` WebAssembly value which is not `ref.null + /// extern`. + Val(ExternRef ptr); + + /// Returns the kind of value that this value has. + ValKind kind() const { + switch (val.kind) { + case WASMTIME_I32: + return ValKind::I32; + case WASMTIME_I64: + return ValKind::I64; + case WASMTIME_F32: + return ValKind::F32; + case WASMTIME_F64: + return ValKind::F64; + case WASMTIME_FUNCREF: + return ValKind::FuncRef; + case WASMTIME_EXTERNREF: + return ValKind::ExternRef; + case WASMTIME_V128: + return ValKind::V128; + } + std::abort(); + } + + /// Returns the underlying `i32`, requires `kind() == KindI32` or aborts the + /// process. + int32_t i32() const { + if (val.kind != WASMTIME_I32) { + std::abort(); + } + return val.of.i32; + } + + /// Returns the underlying `i64`, requires `kind() == KindI64` or aborts the + /// process. + int64_t i64() const { + if (val.kind != WASMTIME_I64) { + std::abort(); + } + return val.of.i64; + } + + /// Returns the underlying `f32`, requires `kind() == KindF32` or aborts the + /// process. + float f32() const { + if (val.kind != WASMTIME_F32) { + std::abort(); + } + return val.of.f32; + } + + /// Returns the underlying `f64`, requires `kind() == KindF64` or aborts the + /// process. + double f64() const { + if (val.kind != WASMTIME_F64) { + std::abort(); + } + return val.of.f64; + } + + /// Returns the underlying `v128`, requires `kind() == KindV128` or aborts + /// the process. + V128 v128() const { + if (val.kind != WASMTIME_V128) { + std::abort(); + } + return val.of.v128; + } + + /// Returns the underlying `externref`, requires `kind() == KindExternRef` or + /// aborts the process. + /// + /// Note that `externref` is a nullable reference, hence the `optional` return + /// value. + std::optional externref(Store::Context cx) const { + if (val.kind != WASMTIME_EXTERNREF) { + std::abort(); + } + if (val.of.externref.store_id == 0) { + return std::nullopt; + } + wasmtime_externref_t other; + wasmtime_externref_clone(cx.ptr, &val.of.externref, &other); + return ExternRef(other); + } + + /// Returns the underlying `funcref`, requires `kind() == KindFuncRef` or + /// aborts the process. + /// + /// Note that `funcref` is a nullable reference, hence the `optional` return + /// value. + std::optional funcref() const; + + /// Unroots any GC references this `Val` points to within the `cx` provided. + void unroot(Store::Context cx) { + wasmtime_val_unroot(cx.ptr, &val); + } +}; + +/** + * \brief Structure provided to host functions to lookup caller information or + * acquire a `Store::Context`. + * + * This structure is passed to all host functions created with `Func`. It can be + * used to create a `Store::Context`. + */ +class Caller { + friend class Func; + friend class Store; + wasmtime_caller_t* ptr; + Caller(wasmtime_caller_t* ptr) : ptr(ptr) { + } + + public: + /// Attempts to load an exported item from the calling instance. + /// + /// For more information see the Rust documentation - + /// https://docs.wasmtime.dev/api/wasmtime/struct.Caller.html#method.get_export + std::optional get_export(std::string_view name); + + /// Explicitly acquire a `Store::Context` from this `Caller`. + Store::Context context() { + return this; + } +}; + +inline Store::Context::Context(Caller& caller) : Context(wasmtime_caller_context(caller.ptr)) { +} +inline Store::Context::Context(Caller* caller) : Context(*caller) { +} + +namespace detail { + +/// A "trait" for native types that correspond to WebAssembly types for use with +/// `Func::wrap` and `TypedFunc::call` +template struct WasmType { static const bool valid = false; }; + +/// Helper macro to define `WasmType` definitions for primitive types like +/// int32_t and such. +// NOLINTNEXTLINE +#define NATIVE_WASM_TYPE(native, valkind, field) \ + template <> struct WasmType { \ + static const bool valid = true; \ + static const ValKind kind = ValKind::valkind; \ + static void store(Store::Context cx, wasmtime_val_raw_t* p, const native& t) { \ + p->field = t; \ + } \ + static native load(Store::Context cx, wasmtime_val_raw_t* p) { \ + return p->field; \ + } \ + }; + +NATIVE_WASM_TYPE(int32_t, I32, i32) +NATIVE_WASM_TYPE(uint32_t, I32, i32) +NATIVE_WASM_TYPE(int64_t, I64, i64) +NATIVE_WASM_TYPE(uint64_t, I64, i64) +NATIVE_WASM_TYPE(float, F32, f32) +NATIVE_WASM_TYPE(double, F64, f64) + +#undef NATIVE_WASM_TYPE + +/// Type information for `externref`, represented on the host as an optional +/// `ExternRef`. +template <> struct WasmType> { + static const bool valid = true; + static const ValKind kind = ValKind::ExternRef; + static void store(Store::Context cx, wasmtime_val_raw_t* p, const std::optional& ref) { + if (ref) { + p->externref = wasmtime_externref_to_raw(cx.raw_context(), ref->raw()); + } else { + p->externref = 0; + } + } + static std::optional load(Store::Context cx, wasmtime_val_raw_t* p) { + if (p->externref == 0) { + return std::nullopt; + } + wasmtime_externref_t val; + wasmtime_externref_from_raw(cx.raw_context(), p->externref, &val); + return ExternRef(val); + } +}; + +/// Type information for the `V128` host value used as a wasm value. +template <> struct WasmType { + static const bool valid = true; + static const ValKind kind = ValKind::V128; + static void store(Store::Context cx, wasmtime_val_raw_t* p, const V128& t) { + memcpy(&p->v128[0], &t.v128[0], sizeof(wasmtime_v128)); + } + static V128 load(Store::Context cx, wasmtime_val_raw_t* p) { + return p->v128; + } +}; + +/// A "trait" for a list of types and operations on them, used for `Func::wrap` +/// and `TypedFunc::call` +/// +/// The base case is a single type which is a list of one element. +template struct WasmTypeList { + static const bool valid = WasmType::valid; + static const size_t size = 1; + static bool matches(ValType::ListRef types) { + return WasmTypeList>::matches(types); + } + static void store(Store::Context cx, wasmtime_val_raw_t* storage, const T& t) { + WasmType::store(cx, storage, t); + } + static T load(Store::Context cx, wasmtime_val_raw_t* storage) { + return WasmType::load(cx, storage); + } + static std::vector types() { + return {WasmType::kind}; + } +}; + +/// std::monostate translates to an empty list of types. +template <> struct WasmTypeList { + static const bool valid = true; + static const size_t size = 0; + static bool matches(ValType::ListRef types) { + return types.size() == 0; + } + static void store(Store::Context cx, wasmtime_val_raw_t* storage, const std::monostate& t) { + } + static std::monostate load(Store::Context cx, wasmtime_val_raw_t* storage) { + return std::monostate(); + } + static std::vector types() { + return {}; + } +}; + +/// std::tuple<> translates to the corresponding list of types +template struct WasmTypeList> { + static const bool valid = (WasmType::valid && ...); + static const size_t size = sizeof...(T); + static bool matches(ValType::ListRef types) { + if (types.size() != size) { + return false; + } + size_t n = 0; + return ((WasmType::kind == types.begin()[n++].kind()) && ...); + } + static void store(Store::Context cx, wasmtime_val_raw_t* storage, const std::tuple& t) { + size_t n = 0; + std::apply( + [&](const auto&... val) { + (WasmType::store(cx, &storage[n++], val), ...); // NOLINT + }, + t); + } + static std::tuple load(Store::Context cx, wasmtime_val_raw_t* storage) { + size_t n = 0; + return std::tuple{WasmType::load(cx, &storage[n++])...}; // NOLINT + } + static std::vector types() { + return {WasmType::kind...}; + } +}; + +/// A "trait" for what can be returned from closures specified to `Func::wrap`. +/// +/// The base case here is a bare return value like `int32_t`. +template struct WasmHostRet { + using Results = WasmTypeList; + + template + static std::optional invoke(F f, Caller cx, wasmtime_val_raw_t* raw, A... args) { + auto ret = f(args...); + Results::store(cx, raw, ret); + return std::nullopt; + } +}; + +/// Host functions can return nothing +template <> struct WasmHostRet { + using Results = WasmTypeList>; + + template + static std::optional invoke(F f, Caller cx, wasmtime_val_raw_t* raw, A... args) { + f(args...); + return std::nullopt; + } +}; + +// Alternative method of returning "nothing" (also enables `std::monostate` in +// the `R` type of `Result` below) +template <> struct WasmHostRet : public WasmHostRet {}; + +/// Host functions can return a result which allows them to also possibly return +/// a trap. +template struct WasmHostRet> { + using Results = WasmTypeList; + + template + static std::optional invoke(F f, Caller cx, wasmtime_val_raw_t* raw, A... args) { + Result ret = f(args...); + if (!ret) { + return ret.err(); + } + Results::store(cx, raw, ret.ok()); + return std::nullopt; + } +}; + +template struct WasmHostFunc; + +/// Base type information for host function pointers being used as wasm +/// functions +template struct WasmHostFunc { + using Params = WasmTypeList>; + using Results = typename WasmHostRet::Results; + + template + static std::optional invoke(F& f, Caller cx, wasmtime_val_raw_t* raw) { + auto params = Params::load(cx, raw); + return std::apply( + [&](const auto&... val) { return WasmHostRet::invoke(f, cx, raw, val...); }, params); + } +}; + +/// Function type information, but with a `Caller` first parameter +template +struct WasmHostFunc : public WasmHostFunc { + // Override `invoke` here to pass the `cx` as the first parameter + template + static std::optional invoke(F& f, Caller cx, wasmtime_val_raw_t* raw) { + auto params = WasmTypeList>::load(cx, raw); + return std::apply( + [&](const auto&... val) { return WasmHostRet::invoke(f, cx, raw, cx, val...); }, params); + } +}; + +/// Function type information, but with as a host method. +template +struct WasmHostFunc : public WasmHostFunc {}; + +/// Function type information, but with as a const host method. +template +struct WasmHostFunc : public WasmHostFunc {}; + +/// Function type information, but with as a host method with a `Caller` first +/// parameter. +template +struct WasmHostFunc : public WasmHostFunc {}; + +/// Function type information, but with as a host const method with a `Caller` +/// first parameter. +template +struct WasmHostFunc : public WasmHostFunc {}; + +// Forward... something? Not entirely sure but this makes things work. +template +struct WasmHostFunc> + : public WasmHostFunc {}; + +} // namespace detail + +using namespace detail; + +/** + * \brief Representation of a WebAssembly function. + * + * This class represents a WebAssembly function, either created through + * instantiating a module or a host function. + * + * Note that this type does not itself own any resources. It points to resources + * owned within a `Store` and the `Store` must be passed in as the first + * argument to the functions defined on `Func`. Note that if the wrong `Store` + * is passed in then the process will be aborted. + */ +class Func { + friend class Val; + friend class Instance; + friend class Linker; + template friend class TypedFunc; + + wasmtime_func_t func; + + template + static wasm_trap_t* raw_callback(void* env, wasmtime_caller_t* caller, const wasmtime_val_t* args, + size_t nargs, wasmtime_val_t* results, size_t nresults) { + static_assert(alignof(Val) == alignof(wasmtime_val_t)); + static_assert(sizeof(Val) == sizeof(wasmtime_val_t)); + F* func = reinterpret_cast(env); // NOLINT + Span args_span(reinterpret_cast(args), // NOLINT + nargs); + Span results_span(reinterpret_cast(results), // NOLINT + nresults); + Result result = (*func)(Caller(caller), args_span, results_span); + if (!result) { + return result.err().ptr.release(); + } + return nullptr; + } + + template + static wasm_trap_t* raw_callback_unchecked(void* env, wasmtime_caller_t* caller, + wasmtime_val_raw_t* args_and_results, + size_t nargs_and_results) { + using HostFunc = WasmHostFunc; + Caller cx(caller); + F* func = reinterpret_cast(env); // NOLINT + auto trap = HostFunc::invoke(*func, cx, args_and_results); + if (trap) { + return trap->ptr.release(); + } + return nullptr; + } + + template static void raw_finalize(void* env) { + std::unique_ptr ptr(reinterpret_cast(env)); // NOLINT + } + + public: + /// Creates a new function from the raw underlying C API representation. + Func(wasmtime_func_t func) : func(func) { + } + + /** + * \brief Creates a new host-defined function. + * + * This constructor is used to create a host function within the store + * provided. This is how WebAssembly can call into the host and make use of + * external functionality. + * + * > **Note**: host functions created this way are more flexible but not + * > as fast to call as those created by `Func::wrap`. + * + * \param cx the store to create the function within + * \param ty the type of the function that will be created + * \param f the host callback to be executed when this function is called. + * + * The parameter `f` is expected to be a lambda (or a lambda lookalike) which + * takes three parameters: + * + * * The first parameter is a `Caller` to get recursive access to the store + * and other caller state. + * * The second parameter is a `Span` which is the list of + * parameters to the function. These parameters are guaranteed to be of the + * types specified by `ty` when constructing this function. + * * The last argument is `Span` which is where to write the return + * values of the function. The function must produce the types of values + * specified by `ty` or otherwise a trap will be raised. + * + * The parameter `f` is expected to return `Result`. + * This allows `f` to raise a trap if desired, or otherwise return no trap and + * finish successfully. If a trap is raised then the results pointer does not + * need to be written to. + */ + template , F, + Caller, Span, Span>, + bool> = true> + Func(Store::Context cx, const FuncType& ty, F f) : func({}) { + wasmtime_func_new(cx.ptr, ty.ptr.get(), raw_callback, std::make_unique(f).release(), + raw_finalize, &func); + } + + /** + * \brief Creates a new host function from the provided callback `f`, + * inferring the WebAssembly function type from the host signature. + * + * This function is akin to the `Func` constructor except that the WebAssembly + * type does not need to be specified and additionally the signature of `f` + * is different. The main goal of this function is to enable WebAssembly to + * call the function `f` as-fast-as-possible without having to validate any + * types or such. + * + * The function `f` can optionally take a `Caller` as its first parameter, + * but otherwise its arguments are translated to WebAssembly types: + * + * * `int32_t`, `uint32_t` - `i32` + * * `int64_t`, `uint64_t` - `i64` + * * `float` - `f32` + * * `double` - `f64` + * * `std::optional` - `funcref` + * * `std::optional` - `externref` + * * `wasmtime::V128` - `v128` + * + * The function may only take these arguments and if it takes any other kinds + * of arguments then it will fail to compile. + * + * The function may return a few different flavors of return values: + * + * * `void` - interpreted as returning nothing + * * Any type above - interpreted as a singular return value. + * * `std::tuple` where `T` is one of the valid argument types - + * interpreted as returning multiple values. + * * `Result` where `T` is another valid return type - interpreted as + * a function that returns `T` to wasm but is optionally allowed to also + * raise a trap. + * + * It's recommended, if possible, to use this function over the `Func` + * constructor since this is generally easier to work with and also enables + * a faster path for WebAssembly to call this function. + */ + template ::Params::valid, bool> = true, + std::enable_if_t::Results::valid, bool> = true> + static Func wrap(Store::Context cx, F f) { + using HostFunc = WasmHostFunc; + auto params = HostFunc::Params::types(); + auto results = HostFunc::Results::types(); + auto ty = FuncType::from_iters(params, results); + wasmtime_func_t func; + wasmtime_func_new_unchecked(cx.ptr, ty.ptr.get(), raw_callback_unchecked, + std::make_unique(f).release(), raw_finalize, &func); + return func; + } + + /** + * \brief Invoke a WebAssembly function. + * + * This function will execute this WebAssembly function. This function muts be + * defined within the `cx`'s store provided. The `params` argument is the list + * of parameters that are passed to the wasm function, and the types of the + * values within `params` must match the type signature of this function. + * + * This may return one of three values: + * + * * First the function could succeed, returning a vector of values + * representing the results of the function. + * * Otherwise a `Trap` might be generated by the WebAssembly function. + * * Finally an `Error` could be returned indicating that `params` were not of + * the right type. + * + * > **Note**: for optimized calls into WebAssembly where the function + * > signature is statically known it's recommended to use `Func::typed` and + * > `TypedFunc::call`. + */ + template + TrapResult> call(Store::Context cx, const I& begin, const I& end) const { + std::vector raw_params; + raw_params.reserve(end - begin); + for (auto i = begin; i != end; i++) { + raw_params.push_back(i->val); + } + size_t nresults = this->type(cx)->results().size(); + std::vector raw_results(nresults); + + wasm_trap_t* trap = nullptr; + auto* error = wasmtime_func_call(cx.ptr, &func, raw_params.data(), raw_params.size(), + raw_results.data(), raw_results.capacity(), &trap); + if (error != nullptr) { + return TrapError(Error(error)); + } + if (trap != nullptr) { + return TrapError(Trap(trap)); + } + + std::vector results; + results.reserve(nresults); + for (size_t i = 0; i < nresults; i++) { + results.push_back(raw_results[i]); + } + return results; + } + + TrapResult> call(Store::Context cx, const std::vector& params) const { + return this->call(cx, params.begin(), params.end()); + } + + TrapResult> call(Store::Context cx, + const std::initializer_list& params) const { + return this->call(cx, params.begin(), params.end()); + } + + /// Returns the type of this function. + FuncType type(Store::Context cx) const { + return wasmtime_func_type(cx.ptr, &func); + } + + /** + * \brief Statically checks this function against the provided types. + * + * This function will check whether it takes the statically known `Params` + * and returns the statically known `Results`. If the type check succeeds then + * a `TypedFunc` is returned which enables a faster method of invoking + * WebAssembly functions. + * + * The `Params` and `Results` specified as template parameters here are the + * parameters and results of the wasm function. They can either be a bare + * type which means that the wasm takes/returns one value, or they can be a + * `std::tuple` of types to represent multiple arguments or multiple + * returns. + * + * The valid types for this function are those mentioned as the arguments + * for `Func::wrap`. + */ + template ::valid, bool> = true, + std::enable_if_t::valid, bool> = true> + Result, Trap> typed(Store::Context cx) const { + auto ty = this->type(cx); + if (!WasmTypeList::matches(ty->params()) || + !WasmTypeList::matches(ty->results())) { + return Trap("static type for this function does not match actual type"); + } + TypedFunc ret(*this); + return ret; + } + + /// Returns the raw underlying C API function this is using. + const wasmtime_func_t& raw_func() const { + return func; + } +}; + +/** + * \brief A version of a WebAssembly `Func` where the type signature of the + * function is statically known. + */ +template class TypedFunc { + friend class Func; + Func f; + TypedFunc(Func func) : f(func) { + } + + public: + /** + * \brief Calls this function with the provided parameters. + * + * This function is akin to `Func::call` except that since static type + * information is available it statically takes its parameters and statically + * returns its results. + * + * Note that this function still may return a `Trap` indicating that calling + * the WebAssembly function failed. + */ + TrapResult call(Store::Context cx, Params params) const { + std::array::size, WasmTypeList::size)> + storage; + wasmtime_val_raw_t* ptr = storage.data(); + if (ptr == nullptr) + ptr = reinterpret_cast(alignof(wasmtime_val_raw_t)); + WasmTypeList::store(cx, ptr, params); + wasm_trap_t* trap = nullptr; + auto* error = + wasmtime_func_call_unchecked(cx.raw_context(), &f.func, ptr, storage.size(), &trap); + if (error != nullptr) { + return TrapError(Error(error)); + } + if (trap != nullptr) { + return TrapError(Trap(trap)); + } + return WasmTypeList::load(cx, ptr); + } + + /// Returns the underlying un-typed `Func` for this function. + const Func& func() const { + return f; + } +}; + +inline Val::Val(std::optional func) : val{} { + val.kind = WASMTIME_FUNCREF; + if (func) { + val.of.funcref = (*func).func; + } else { + wasmtime_funcref_set_null(&val.of.funcref); + } +} + +inline Val::Val(Func func) : Val(std::optional(func)) { +} +inline Val::Val(ExternRef ptr) : Val(std::optional(ptr)) { +} + +inline std::optional Val::funcref() const { + if (val.kind != WASMTIME_FUNCREF) { + std::abort(); + } + if (val.of.funcref.store_id == 0) { + return std::nullopt; + } + return Func(val.of.funcref); +} + +/// Definition for the `funcref` native wasm type +template <> struct detail::WasmType> { + /// @private + static const bool valid = true; + /// @private + static const ValKind kind = ValKind::FuncRef; + /// @private + static void store(Store::Context cx, wasmtime_val_raw_t* p, const std::optional func) { + if (func) { + p->funcref = wasmtime_func_to_raw(cx.raw_context(), &func->raw_func()); + } else { + p->funcref = 0; + } + } + /// @private + static std::optional load(Store::Context cx, wasmtime_val_raw_t* p) { + if (p->funcref == 0) { + return std::nullopt; + } + wasmtime_func_t ret; + wasmtime_func_from_raw(cx.raw_context(), p->funcref, &ret); + return ret; + } +}; + +/** + * \brief A WebAssembly global. + * + * This class represents a WebAssembly global, either created through + * instantiating a module or a host global. Globals contain a WebAssembly value + * and can be read and optionally written to. + * + * Note that this type does not itself own any resources. It points to resources + * owned within a `Store` and the `Store` must be passed in as the first + * argument to the functions defined on `Global`. Note that if the wrong `Store` + * is passed in then the process will be aborted. + */ +class Global { + friend class Instance; + wasmtime_global_t global; + + public: + /// Creates as global from the raw underlying C API representation. + Global(wasmtime_global_t global) : global(global) { + } + + /** + * \brief Create a new WebAssembly global. + * + * \param cx the store in which to create the global + * \param ty the type that this global will have + * \param init the initial value of the global + * + * This function can fail if `init` does not have a value that matches `ty`. + */ + static Result create(Store::Context cx, const GlobalType& ty, const Val& init) { + wasmtime_global_t global; + auto* error = wasmtime_global_new(cx.ptr, ty.ptr.get(), &init.val, &global); + if (error != nullptr) { + return Error(error); + } + return Global(global); + } + + /// Returns the type of this global. + GlobalType type(Store::Context cx) const { + return wasmtime_global_type(cx.ptr, &global); + } + + /// Returns the current value of this global. + Val get(Store::Context cx) const; + + /// Sets this global to a new value. + /// + /// This can fail if `val` has the wrong type or if this global isn't mutable. + Result set(Store::Context cx, const Val& val) const { + auto* error = wasmtime_global_set(cx.ptr, &global, &val.val); + if (error != nullptr) { + return Error(error); + } + return std::monostate(); + } +}; + +/** + * \brief A WebAssembly table. + * + * This class represents a WebAssembly table, either created through + * instantiating a module or a host table. Tables are contiguous vectors of + * WebAssembly reference types, currently either `externref` or `funcref`. + * + * Note that this type does not itself own any resources. It points to resources + * owned within a `Store` and the `Store` must be passed in as the first + * argument to the functions defined on `Table`. Note that if the wrong `Store` + * is passed in then the process will be aborted. + */ +class Table { + friend class Instance; + wasmtime_table_t table; + + public: + /// Creates a new table from the raw underlying C API representation. + Table(wasmtime_table_t table) : table(table) { + } + + /** + * \brief Creates a new host-defined table. + * + * \param cx the store in which to create the table. + * \param ty the type of the table to be created + * \param init the initial value for all table slots. + * + * Returns an error if `init` has the wrong value for the `ty` specified. + */ + static Result create(Store::Context cx, const TableType& ty, const Val& init) { + wasmtime_table_t table; + auto* error = wasmtime_table_new(cx.ptr, ty.ptr.get(), &init.val, &table); + if (error != nullptr) { + return Error(error); + } + return Table(table); + } + + /// Returns the type of this table. + TableType type(Store::Context cx) const { + return wasmtime_table_type(cx.ptr, &table); + } + + /// Returns the size, in elements, that the table currently has. + size_t size(Store::Context cx) const { + return wasmtime_table_size(cx.ptr, &table); + } + + /// Loads a value from the specified index in this table. + /// + /// Returns `std::nullopt` if `idx` is out of bounds. + std::optional get(Store::Context cx, uint32_t idx) const { + Val val; + if (wasmtime_table_get(cx.ptr, &table, idx, &val.val)) { + return val; + } + return std::nullopt; + } + + /// Stores a value into the specified index in this table. + /// + /// Returns an error if `idx` is out of bounds or if `val` has the wrong type. + Result set(Store::Context cx, uint32_t idx, const Val& val) const { + auto* error = wasmtime_table_set(cx.ptr, &table, idx, &val.val); + if (error != nullptr) { + return Error(error); + } + return std::monostate(); + } + + /// Grow this table. + /// + /// \param cx the store that owns this table. + /// \param delta the number of new elements to be added to this table. + /// \param init the initial value of all new elements in this table. + /// + /// Returns an error if `init` has the wrong type for this table. Otherwise + /// returns the previous size of the table before growth. + Result grow(Store::Context cx, uint32_t delta, const Val& init) const { + uint32_t prev = 0; + auto* error = wasmtime_table_grow(cx.ptr, &table, delta, &init.val, &prev); + if (error != nullptr) { + return Error(error); + } + return prev; + } +}; + +// gcc 8.3.0 seems to require that this comes after the definition of `Table`. I +// don't know why... +inline Val Global::get(Store::Context cx) const { + Val val; + wasmtime_global_get(cx.ptr, &global, &val.val); + return val; +} + +/** + * \brief A WebAssembly linear memory. + * + * This class represents a WebAssembly memory, either created through + * instantiating a module or a host memory. + * + * Note that this type does not itself own any resources. It points to resources + * owned within a `Store` and the `Store` must be passed in as the first + * argument to the functions defined on `Table`. Note that if the wrong `Store` + * is passed in then the process will be aborted. + */ +class Memory { + friend class Instance; + wasmtime_memory_t memory; + + public: + /// Creates a new memory from the raw underlying C API representation. + Memory(wasmtime_memory_t memory) : memory(memory) { + } + + /// Creates a new host-defined memory with the type specified. + static Result create(Store::Context cx, const MemoryType& ty) { + wasmtime_memory_t memory; + auto* error = wasmtime_memory_new(cx.ptr, ty.ptr.get(), &memory); + if (error != nullptr) { + return Error(error); + } + return Memory(memory); + } + + /// Returns the type of this memory. + MemoryType type(Store::Context cx) const { + return wasmtime_memory_type(cx.ptr, &memory); + } + + /// Returns the size, in WebAssembly pages, of this memory. + uint64_t size(Store::Context cx) const { + return wasmtime_memory_size(cx.ptr, &memory); + } + + /// Returns a `span` of where this memory is located in the host. + /// + /// Note that embedders need to be very careful in their usage of the returned + /// `span`. It can be invalidated with calls to `grow` and/or calls into + /// WebAssembly. + Span data(Store::Context cx) const { + auto* base = wasmtime_memory_data(cx.ptr, &memory); + auto size = wasmtime_memory_data_size(cx.ptr, &memory); + return {base, size}; + } + + /// Grows the memory by `delta` WebAssembly pages. + /// + /// On success returns the previous size of this memory in units of + /// WebAssembly pages. + Result grow(Store::Context cx, uint64_t delta) const { + uint64_t prev = 0; + auto* error = wasmtime_memory_grow(cx.ptr, &memory, delta, &prev); + if (error != nullptr) { + return Error(error); + } + return prev; + } +}; + +/** + * \brief A WebAssembly instance. + * + * This class represents a WebAssembly instance, created by instantiating a + * module. An instance is the collection of items exported by the module, which + * can be accessed through the `Store` that owns the instance. + * + * Note that this type does not itself own any resources. It points to resources + * owned within a `Store` and the `Store` must be passed in as the first + * argument to the functions defined on `Instance`. Note that if the wrong + * `Store` is passed in then the process will be aborted. + */ +class Instance { + friend class Linker; + friend class Caller; + + wasmtime_instance_t instance; + + static Extern cvt(wasmtime_extern_t& e) { + switch (e.kind) { + case WASMTIME_EXTERN_FUNC: + return Func(e.of.func); + case WASMTIME_EXTERN_GLOBAL: + return Global(e.of.global); + case WASMTIME_EXTERN_MEMORY: + return Memory(e.of.memory); + case WASMTIME_EXTERN_TABLE: + return Table(e.of.table); + } + std::abort(); + } + + static void cvt(const Extern& e, wasmtime_extern_t& raw) { + if (const auto* func = std::get_if(&e)) { + raw.kind = WASMTIME_EXTERN_FUNC; + raw.of.func = func->func; + } else if (const auto* global = std::get_if(&e)) { + raw.kind = WASMTIME_EXTERN_GLOBAL; + raw.of.global = global->global; + } else if (const auto* table = std::get_if
(&e)) { + raw.kind = WASMTIME_EXTERN_TABLE; + raw.of.table = table->table; + } else if (const auto* memory = std::get_if(&e)) { + raw.kind = WASMTIME_EXTERN_MEMORY; + raw.of.memory = memory->memory; + } else { + std::abort(); + } + } + + public: + /// Creates a new instance from the raw underlying C API representation. + Instance(wasmtime_instance_t instance) : instance(instance) { + } + + /** + * \brief Instantiates the module `m` with the provided `imports` + * + * \param cx the store in which to instantiate the provided module + * \param m the module to instantiate + * \param imports the list of imports to use to instantiate the module + * + * This `imports` parameter is expected to line up 1:1 with the imports + * required by the `m`. The type of `m` can be inspected to determine in which + * order to provide the imports. Note that this is a relatively low-level API + * and it's generally recommended to use `Linker` instead for name-based + * instantiation. + * + * This function can return an error if any of the `imports` have the wrong + * type, or if the wrong number of `imports` is provided. + */ + static TrapResult create(Store::Context cx, const Module& m, + const std::vector& imports) { + std::vector raw_imports; + for (const auto& item : imports) { + raw_imports.push_back(wasmtime_extern_t{}); + auto& last = raw_imports.back(); + Instance::cvt(item, last); + } + wasmtime_instance_t instance; + wasm_trap_t* trap = nullptr; + auto* error = wasmtime_instance_new(cx.ptr, m.ptr.get(), raw_imports.data(), raw_imports.size(), + &instance, &trap); + if (error != nullptr) { + return TrapError(Error(error)); + } + if (trap != nullptr) { + return TrapError(Trap(trap)); + } + return Instance(instance); + } + + /** + * \brief Load an instance's export by name. + * + * This function will look for an export named `name` on this instance and, if + * found, return it as an `Extern`. + */ + std::optional get(Store::Context cx, std::string_view name) { + wasmtime_extern_t e; + if (!wasmtime_instance_export_get(cx.ptr, &instance, name.data(), name.size(), &e)) { + return std::nullopt; + } + return Instance::cvt(e); + } + + /** + * \brief Load an instance's export by index. + * + * This function will look for the `idx`th export of this instance. This will + * return both the name of the export as well as the exported item itself. + */ + std::optional> get(Store::Context cx, size_t idx) { + wasmtime_extern_t e; + // I'm not sure why clang-tidy thinks this is using va_list or anything + // related to that... + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-vararg) + char* name = nullptr; + size_t len = 0; + if (!wasmtime_instance_export_nth(cx.ptr, &instance, idx, &name, &len, &e)) { + return std::nullopt; + } + std::string_view n(name, len); + return std::pair(n, Instance::cvt(e)); + } +}; + +inline std::optional Caller::get_export(std::string_view name) { + wasmtime_extern_t item; + if (wasmtime_caller_export_get(ptr, name.data(), name.size(), &item)) { + return Instance::cvt(item); + } + return std::nullopt; +} + +/** + * \brief Helper class for linking modules together with name-based resolution. + * + * This class is used for easily instantiating `Module`s by defining names into + * the linker and performing name-based resolution during instantiation. A + * `Linker` can also be used to link in WASI functions to instantiate a module. + */ +class Linker { + struct deleter { + void operator()(wasmtime_linker_t* p) const { + wasmtime_linker_delete(p); + } + }; + + std::unique_ptr ptr; + + public: + /// Creates a new linker which will instantiate in the given engine. + explicit Linker(Engine& engine) : ptr(wasmtime_linker_new(engine.ptr.get())) { + } + + /// Configures whether shadowing previous names is allowed or not. + /// + /// By default shadowing is not allowed. + void allow_shadowing(bool allow) { + wasmtime_linker_allow_shadowing(ptr.get(), allow); + } + + /// Defines the provided item into this linker with the given name. + Result define(Store::Context cx, std::string_view module, std::string_view name, + const Extern& item) { + wasmtime_extern_t raw; + Instance::cvt(item, raw); + auto* error = wasmtime_linker_define(ptr.get(), cx.ptr, module.data(), module.size(), + name.data(), name.size(), &raw); + if (error != nullptr) { + return Error(error); + } + return std::monostate(); + } + + /// Defines WASI functions within this linker. + /// + /// Note that `Store::Context::set_wasi` must also be used for instantiated + /// modules to have access to configured WASI state. + Result define_wasi() { + auto* error = wasmtime_linker_define_wasi(ptr.get()); + if (error != nullptr) { + return Error(error); + } + return std::monostate(); + } + + /// Defines all exports of the `instance` provided in this linker with the + /// given module name of `name`. + Result define_instance(Store::Context cx, std::string_view name, + Instance instance) { + auto* error = wasmtime_linker_define_instance(ptr.get(), cx.ptr, name.data(), name.size(), + &instance.instance); + if (error != nullptr) { + return Error(error); + } + return std::monostate(); + } + + /// Instantiates the module `m` provided within the store `cx` using the items + /// defined within this linker. + TrapResult instantiate(Store::Context cx, const Module& m) { + wasmtime_instance_t instance; + wasm_trap_t* trap = nullptr; + auto* error = wasmtime_linker_instantiate(ptr.get(), cx.ptr, m.ptr.get(), &instance, &trap); + if (error != nullptr) { + return TrapError(Error(error)); + } + if (trap != nullptr) { + return TrapError(Trap(trap)); + } + return Instance(instance); + } + + /// Defines instantiations of the module `m` within this linker under the + /// given `name`. + Result module(Store::Context cx, std::string_view name, const Module& m) { + auto* error = wasmtime_linker_module(ptr.get(), cx.ptr, name.data(), name.size(), m.ptr.get()); + if (error != nullptr) { + return Error(error); + } + return std::monostate(); + } + + /// Attempts to load the specified named item from this linker, returning + /// `std::nullopt` if it was not defined. + [[nodiscard]] std::optional get(Store::Context cx, std::string_view module, + std::string_view name) { + wasmtime_extern_t item; + if (wasmtime_linker_get(ptr.get(), cx.ptr, module.data(), module.size(), name.data(), + name.size(), &item)) { + return Instance::cvt(item); + } + return std::nullopt; + } + + /// Defines a new function in this linker in the style of the `Func` + /// constructor. + template , F, + Caller, Span, Span>, + bool> = true> + Result func_new(std::string_view module, std::string_view name, + const FuncType& ty, F f) { + auto* error = wasmtime_linker_define_func( + ptr.get(), module.data(), module.length(), name.data(), name.length(), ty.ptr.get(), + Func::raw_callback, std::make_unique(f).release(), Func::raw_finalize); + + if (error != nullptr) { + return Error(error); + } + + return std::monostate(); + } + + /// Defines a new function in this linker in the style of the `Func::wrap` + /// constructor. + template ::Params::valid, bool> = true, + std::enable_if_t::Results::valid, bool> = true> + Result func_wrap(std::string_view module, std::string_view name, F f) { + using HostFunc = WasmHostFunc; + auto params = HostFunc::Params::types(); + auto results = HostFunc::Results::types(); + auto ty = FuncType::from_iters(params, results); + auto* error = wasmtime_linker_define_func_unchecked( + ptr.get(), module.data(), module.length(), name.data(), name.length(), ty.ptr.get(), + Func::raw_callback_unchecked, std::make_unique(f).release(), Func::raw_finalize); + + if (error != nullptr) { + return Error(error); + } + + return std::monostate(); + } + + /// Loads the "default" function, according to WASI commands and reactors, of + /// the module named `name` in this linker. + Result get_default(Store::Context cx, std::string_view name) { + wasmtime_func_t item; + auto* error = wasmtime_linker_get_default(ptr.get(), cx.ptr, name.data(), name.size(), &item); + if (error != nullptr) { + return Error(error); + } + return Func(item); + } +}; + +} // namespace wasmtime + +#endif // WASMTIME_HH From 7390de50aef2969da6f05f14ea321db1b9f87a75 Mon Sep 17 00:00:00 2001 From: Kostas Kyrimis Date: Thu, 16 May 2024 23:20:22 +0300 Subject: [PATCH 2/6] chore(wasm): add sdk and move loading wasm modules at startup via flag (#3053) --- src/server/wasm/wasm_family.cc | 61 +++++++++++++++++++--------- src/server/wasm/wasm_family.h | 5 ++- src/server/wasm/wasm_registry.cc | 24 ++++++++--- src/server/wasm/wasm_registry.h | 13 +++--- src/wasm_sdk/README.md | 26 ++++++++++++ src/wasm_sdk/cpp/README.md | 6 +++ src/wasm_sdk/cpp/examples/example.cc | 15 +++++++ src/wasm_sdk/cpp/include/dragonfly.h | 25 ++++++++++++ 8 files changed, 144 insertions(+), 31 deletions(-) create mode 100644 src/wasm_sdk/README.md create mode 100644 src/wasm_sdk/cpp/README.md create mode 100644 src/wasm_sdk/cpp/examples/example.cc create mode 100644 src/wasm_sdk/cpp/include/dragonfly.h diff --git a/src/server/wasm/wasm_family.cc b/src/server/wasm/wasm_family.cc index c6130b62f0bc..83dd1c968b64 100644 --- a/src/server/wasm/wasm_family.cc +++ b/src/server/wasm/wasm_family.cc @@ -4,11 +4,15 @@ #include "server/wasm/wasm_family.h" #include "absl/strings/str_cat.h" +#include "base/flags.h" #include "facade/facade_types.h" #include "server/acl/acl_commands_def.h" #include "server/command_registry.h" #include "server/conn_context.h" +ABSL_FLAG(std::string, wasmpaths, "", + "Comma separated list of paths (including wasm file) to load WASM modules from"); + namespace dfly { namespace wasm { @@ -20,43 +24,60 @@ CommandId::Handler HandlerFunc(WasmFamily* wasm, MemberFunc f) { #define HFUNC(x) SetHandler(HandlerFunc(this, &WasmFamily::x)) +WasmFamily::WasmFamily() { + if (auto wasm_modules = absl::GetFlag(FLAGS_wasmpaths); !wasm_modules.empty()) { + registry_ = std::make_unique(); + } +} + void WasmFamily::Register(dfly::CommandRegistry* registry) { using CI = dfly::CommandId; registry->StartFamily(); - *registry << CI{"WASMCALL", dfly::CO::LOADING, 2, 0, 0, acl::WASM}.HFUNC(Call); - *registry << CI{"WASMLOAD", dfly::CO::LOADING, 2, 0, 0, acl::WASM}.HFUNC(Load); - *registry << CI{"WASMDEL", dfly::CO::LOADING, 2, 0, 0, acl::WASM}.HFUNC(Delete); + *registry << CI{"WASMCALL", dfly::CO::LOADING, 3, 0, 0, acl::WASM}.HFUNC(Call); + // *registry << CI{"WASMLOAD", dfly::CO::LOADING, 2, 0, 0, acl::WASM}.HFUNC(Load); + // *registry << CI{"WASMDEL", dfly::CO::LOADING, 2, 0, 0, acl::WASM}.HFUNC(Delete); } void WasmFamily::Load(CmdArgList args, ConnectionContext* cntx) { - auto path = absl::StrCat(facade::ToSV(args[0]), "\0"); - if (auto res = registry_.Add(path); !res.empty()) { - cntx->SendError(res); - return; - } - auto slash = path.rfind('/'); - auto name = path; - if (slash != path.npos) { - name = name.substr(slash + 1); - } - cntx->SendOk(); + // TODO figure out how to load modules dynamically + // auto path = absl::StrCat(facade::ToSV(args[0]), "\0"); + // if (auto res = registry_->Add(path); !res.empty()) { + // cntx->SendError(res); + // return; + // } + // auto slash = path.rfind('/'); + // auto name = path; + // if (slash != path.npos) { + // name = name.substr(slash + 1); + // } + // cntx->SendOk(); } void WasmFamily::Call(CmdArgList args, ConnectionContext* cntx) { - auto name = facade::ToSV(args[0]); - auto res = registry_.GetInstanceFromModule(name); + if (!registry_) { + cntx->SendError("Wasm is not enabled"); + return; + } + auto module_name = facade::ToSV(args[0]); + auto exported_fun_name = facade::ToSV(args[1]); + auto res = registry_->GetInstanceFromModule(module_name); if (!res) { - cntx->SendError(absl::StrCat("Could not find module with ", name)); + cntx->SendError(absl::StrCat("Could not find module with ", module_name)); + return; + } + auto& wasm_function = *res; + auto wasm_result = wasm_function(exported_fun_name); + if (!wasm_result.empty()) { + cntx->SendError(wasm_result); return; } - auto& wasm_instance = *res; - wasm_instance(); cntx->SendOk(); } void WasmFamily::Delete(CmdArgList args, ConnectionContext* cntx) { + // TODO figure out how to load modules dynamically auto name = facade::ToSV(args[0]); - cntx->SendLong(registry_.Delete(name)); + cntx->SendLong(registry_->Delete(name)); } } // namespace wasm diff --git a/src/server/wasm/wasm_family.h b/src/server/wasm/wasm_family.h index 6b1bdaa70990..af52dda01675 100644 --- a/src/server/wasm/wasm_family.h +++ b/src/server/wasm/wasm_family.h @@ -4,6 +4,8 @@ #pragma once +#include + #include "facade/facade_types.h" #include "server/command_registry.h" #include "server/wasm/wasm_registry.h" @@ -15,6 +17,7 @@ namespace wasm { class WasmFamily final { public: + WasmFamily(); void Register(CommandRegistry* registry); private: @@ -22,7 +25,7 @@ class WasmFamily final { void Call(facade::CmdArgList args, ConnectionContext* cntx); void Delete(facade::CmdArgList args, ConnectionContext* cntx); - WasmRegistry registry_; + std::unique_ptr registry_; }; } // namespace wasm diff --git a/src/server/wasm/wasm_registry.cc b/src/server/wasm/wasm_registry.cc index 6f06e5567090..23cb86adaf30 100644 --- a/src/server/wasm/wasm_registry.cc +++ b/src/server/wasm/wasm_registry.cc @@ -11,11 +11,15 @@ #include "absl/cleanup/cleanup.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_split.h" +#include "base/flags.h" #include "base/logging.h" #include "io/file_util.h" #include "server/wasm/api.h" #include "server/wasm/wasmtime.hh" +ABSL_DECLARE_FLAG(std::string, wasmpaths); + namespace dfly::wasm { WasmRegistry::WasmRegistry() @@ -37,17 +41,28 @@ WasmRegistry::WasmRegistry() store_.context().set_wasi(std::move(wasi)).unwrap(); linker_.define_wasi().unwrap(); + + InstantiateAndLinkModules(); } WasmRegistry::~WasmRegistry() { } +void WasmRegistry::InstantiateAndLinkModules() { + auto wasm_modules = absl::GetFlag(FLAGS_wasmpaths); + std::vector modules = absl::StrSplit(wasm_modules, ","); + for (auto mod_path : modules) { + Add(mod_path); + } +} + std::string WasmRegistry::Add(std::string_view path) { // 1. Read the wasm file in path auto is_file_read = io::ReadFileToString(path); if (!is_file_read) { - return absl::StrCat("File error for path: ", path, " with error ", - is_file_read.error().message()); + LOG(ERROR) << "File error for path: " << path << " with error " + << is_file_read.error().message(); + exit(1); } // In this context the cast is safe @@ -57,7 +72,8 @@ std::string WasmRegistry::Add(std::string_view path) { // 2. Setup && compile auto result = wasmtime::Module::compile(engine_, wasm_bin); if (!result) { - return absl::StrCat("Error compiling file: ", path, " with error: ", result.err().message()); + LOG(ERROR) << "Error compiling file: " << path << " with error: " << result.err().message(); + exit(1); } // 3. Insert to registry @@ -65,9 +81,7 @@ std::string WasmRegistry::Add(std::string_view path) { auto name = path; if (slash != path.npos) { name = name.substr(slash + 1); - // HELLO } - std::unique_lock lock(mu_); modules_.emplace(name, std::move(result.ok())); return {}; diff --git a/src/server/wasm/wasm_registry.h b/src/server/wasm/wasm_registry.h index 8208edc0ad0b..7bdc5f1db1d9 100644 --- a/src/server/wasm/wasm_registry.h +++ b/src/server/wasm/wasm_registry.h @@ -10,6 +10,7 @@ #include #include "absl/container/flat_hash_map.h" +#include "absl/strings/str_cat.h" #include "base/logging.h" #include "server/wasm/api.h" #include "server/wasm/wasmtime.hh" @@ -41,7 +42,6 @@ class WasmRegistry { WasmRegistry(const WasmRegistry&) = delete; WasmRegistry(WasmRegistry&&) = delete; ~WasmRegistry(); - std::string Add(std::string_view path); bool Delete(std::string_view name); // Very light-weight. Each Module is compiled *once* but each UDF call, e,g, `CALLWASM` @@ -54,16 +54,16 @@ class WasmRegistry { : instance_{instance}, store_(store) { } - void operator()() { + std::string operator()(std::string_view export_func_name) { // Users will export functions for their modules via the attribute // __attribute__((export_name(func_name))). We will expose this in our sdk - auto extern_def = instance_.get(*store_, "my_fun"); + auto extern_def = instance_.get(*store_, export_func_name); if (!extern_def) { - // return error - return; + return absl::StrCat("No exported function with name ", export_func_name, " found"); } auto run = std::get(*extern_def); run.call(store_, {}).unwrap(); + return {}; } private: @@ -74,6 +74,9 @@ class WasmRegistry { std::optional GetInstanceFromModule(std::string_view module_name); private: + std::string Add(std::string_view path); + void InstantiateAndLinkModules(); + absl::flat_hash_map modules_; mutable util::fb2::SharedMutex mu_; diff --git a/src/wasm_sdk/README.md b/src/wasm_sdk/README.md new file mode 100644 index 000000000000..af8d79bfd34c --- /dev/null +++ b/src/wasm_sdk/README.md @@ -0,0 +1,26 @@ +# Public facing API for experimental Dragonfly wasm functions + +This is the top folder for our public API. Each subfolder should serve as the +sdk for each language we support. For example: +`/python` subfolder would contain the declarations (but not the definitions! these will be +provided/exported by dragonfly and users wasm binaries will be linked against them +to resolve the symbols) for our API. A client function would be: + +``` +#include + +DF_EXPORT("my_fun") +void my_fun() { + // work here + return 1; +} +``` + +Loading the WASM binary in df is done via the flag `wasmpaths="path/to/mod1.wasm,path/to/mode2.wasm` + +Calling an exported function is as simple as: + +``` +> CALL mode2.wasm my_fun +> 1 +``` diff --git a/src/wasm_sdk/cpp/README.md b/src/wasm_sdk/cpp/README.md new file mode 100644 index 000000000000..a6f7b17a9ad4 --- /dev/null +++ b/src/wasm_sdk/cpp/README.md @@ -0,0 +1,6 @@ +# Dragonfly C++ SDK for wasm + +1. Download the wasi-sdk found in https://github.com/WebAssembly/wasi-sdk.git +2. Compile your source files with wasi-sdk clang: `wasi-sdk-22.0/bin/clang++` +3. Load the module at startup via `WASMPATH` +4. Call an exported function via `WASMCALL module.wasm function_name` diff --git a/src/wasm_sdk/cpp/examples/example.cc b/src/wasm_sdk/cpp/examples/example.cc new file mode 100644 index 000000000000..b015da4fb06b --- /dev/null +++ b/src/wasm_sdk/cpp/examples/example.cc @@ -0,0 +1,15 @@ +#include "dragonfly.h" + +// wasi-sdk-22.0/bin/clang++ -std=c++11 example.cc -o example.wasm +// +// Launch Dragonfly via `./dragonfly --alsologtostderr --wasmpaths="/path/to/example.wasm" +// And call it via redis-cli: +// > CALLWASM example.wasm my_fun +// +// You can also export multiple functions per module + +DF_EXPORT("my_fun") +void my_fun() { + // call exported dragonfly function hello() + hello(); +} diff --git a/src/wasm_sdk/cpp/include/dragonfly.h b/src/wasm_sdk/cpp/include/dragonfly.h new file mode 100644 index 000000000000..2aef03094e83 --- /dev/null +++ b/src/wasm_sdk/cpp/include/dragonfly.h @@ -0,0 +1,25 @@ +// Copyright 2022, DragonflyDB authors. All rights reserved. +// See LICENSE for licensing terms. +// + +namespace dragonfly { + +/* PUBLIC FACING API */ + +extern "C" { +#define WASM_IMPORT(mod, name) __attribute__((import_module(#mod), import_name(#name))) + +WASM_IMPORT(dragonfly, hello) +void hello(); + +// Add rest of functions here +} + +/* Used to export functions from wasm modules */ + +// Use this macro if you want your function to be available within dragonfly, +// that is if you want to call `my_fun` from module `module.wasm` via `WASMCALL` +// then you have to register it first via DF_EXPORT("my_fun") +#define DF_EXPORT(name) __attribute__((export_name(name))) + +} // namespace dragonfly From 852f956aabd5c65c0a2848bbadee488711b68f09 Mon Sep 17 00:00:00 2001 From: Kostas Kyrimis Date: Thu, 16 May 2024 23:43:48 +0300 Subject: [PATCH 3/6] chore(wasm): Expiriment with memory (#3055) chore(wasm): experiment with memory --- src/server/wasm/api.h | 4 +++- src/server/wasm/wasm_family.cc | 3 +++ src/server/wasm/wasm_registry.cc | 34 ++++++++++++++++++++++------ src/server/wasm/wasm_registry.h | 6 ++++- src/wasm_sdk/cpp/examples/example.cc | 7 ++++-- src/wasm_sdk/cpp/include/dragonfly.h | 32 +++++++++++++++++++++++++- 6 files changed, 74 insertions(+), 12 deletions(-) diff --git a/src/server/wasm/api.h b/src/server/wasm/api.h index 05919062a2be..727dd13269bf 100644 --- a/src/server/wasm/api.h +++ b/src/server/wasm/api.h @@ -17,7 +17,9 @@ namespace dfly::wasm::api { template bool RegisterApiFunction(std::string_view name, Fn f, wasmtime::Linker* linker) { std::string module_name = "dragonfly"; - auto args_signature = wasmtime::FuncType({}, {}); + + auto args_signature = wasmtime::FuncType({}, {wasmtime::ValKind::I32}); + auto res = linker->func_new(module_name, name, args_signature, f); return (bool)res; } diff --git a/src/server/wasm/wasm_family.cc b/src/server/wasm/wasm_family.cc index 83dd1c968b64..c5c0c2273b6f 100644 --- a/src/server/wasm/wasm_family.cc +++ b/src/server/wasm/wasm_family.cc @@ -66,6 +66,9 @@ void WasmFamily::Call(CmdArgList args, ConnectionContext* cntx) { return; } auto& wasm_function = *res; + // This is fine because we block and is atomic in respect to proactor. + // When we switch to async execution of the runtime this will be able to suspend and resume + // which will make the current approach invalid. auto wasm_result = wasm_function(exported_fun_name); if (!wasm_result.empty()) { cntx->SendError(wasm_result); diff --git a/src/server/wasm/wasm_registry.cc b/src/server/wasm/wasm_registry.cc index 23cb86adaf30..049674643951 100644 --- a/src/server/wasm/wasm_registry.cc +++ b/src/server/wasm/wasm_registry.cc @@ -24,13 +24,33 @@ namespace dfly::wasm { WasmRegistry::WasmRegistry() : engine_(WasmRegistry::GetConfig()), linker_(engine_), store_(engine_) { - api::RegisterApiFunction( - "hello", - [](auto...) { - LOG(INFO) << "Hello from WASM"; - return std::monostate(); - }, - &linker_); + auto hellofunc = [](wasmtime::Caller caller, auto params, auto results) { + auto res = caller.get_export("allocate_on_guest_mem"); + std::string value = "Hello world from wasm!"; + value.push_back('\0'); + + // Call the exported alloc to allocate memory on the guest + auto alloc = std::get(*res); + const int32_t alloc_size = static_cast(value.size()); + auto result = alloc.call(caller.context(), {wasmtime::Val{alloc_size}}); + if (!result) { + // handle errors + } + auto wasm_value = result.ok().front(); + auto offset = wasm_value.i32(); + + wasmtime::Memory memory = std::get(*caller.get_export("memory")); + + uint8_t* data = memory.data(caller.context()).data() + offset; + absl::little_endian::Store32(data, value.size()); + // TODO inject payload size at the front so we dont have to push an extra \0 + memcpy(data, value.c_str(), value.size()); + + results[0] = wasm_value; + return std::monostate(); + }; + + api::RegisterApiFunction("hello", hellofunc, &linker_); wasmtime::WasiConfig wasi; wasi.inherit_argv(); diff --git a/src/server/wasm/wasm_registry.h b/src/server/wasm/wasm_registry.h index 7bdc5f1db1d9..a14d8c225212 100644 --- a/src/server/wasm/wasm_registry.h +++ b/src/server/wasm/wasm_registry.h @@ -62,10 +62,14 @@ class WasmRegistry { return absl::StrCat("No exported function with name ", export_func_name, " found"); } auto run = std::get(*extern_def); - run.call(store_, {}).unwrap(); + auto res = run.call(store_, {}).unwrap(); return {}; } + wasmtime::Instance* GetInstance() { + return &instance_; + } + private: wasmtime::Instance instance_; wasmtime::Store* store_; diff --git a/src/wasm_sdk/cpp/examples/example.cc b/src/wasm_sdk/cpp/examples/example.cc index b015da4fb06b..075e81bc9867 100644 --- a/src/wasm_sdk/cpp/examples/example.cc +++ b/src/wasm_sdk/cpp/examples/example.cc @@ -10,6 +10,9 @@ DF_EXPORT("my_fun") void my_fun() { - // call exported dragonfly function hello() - hello(); + std::string result = dragonfly::hello_world(); + // passes + assert(result == "Hello world from wasm!"); + // when we return the bson here we need to be carefull with the allocation. + // does that mean the host will need to free the memory? e.g, export a `free` function? } diff --git a/src/wasm_sdk/cpp/include/dragonfly.h b/src/wasm_sdk/cpp/include/dragonfly.h index 2aef03094e83..c8d2ec209fb3 100644 --- a/src/wasm_sdk/cpp/include/dragonfly.h +++ b/src/wasm_sdk/cpp/include/dragonfly.h @@ -2,6 +2,11 @@ // See LICENSE for licensing terms. // +#include +#include +#include +#include + namespace dragonfly { /* PUBLIC FACING API */ @@ -10,11 +15,17 @@ extern "C" { #define WASM_IMPORT(mod, name) __attribute__((import_module(#mod), import_name(#name))) WASM_IMPORT(dragonfly, hello) -void hello(); +uint8_t* hello(); // Add rest of functions here } +inline std::string deserialize(uint8_t* ptr); + +inline std::string hello_world() { + return deserialize(hello()); +} + /* Used to export functions from wasm modules */ // Use this macro if you want your function to be available within dragonfly, @@ -22,4 +33,23 @@ void hello(); // then you have to register it first via DF_EXPORT("my_fun") #define DF_EXPORT(name) __attribute__((export_name(name))) +/* Private and NOT part of the public API */ +DF_EXPORT("allocate_on_guest_mem") +inline uint8_t* allocate_on_guest_mem(size_t bytes) { + return new uint8_t[bytes]; +} + +/* Entry point to deserialize data coming from Dragonfly */ +/* For now this is hardcoded and only returns a string and should */ +/* be extended with json */ +inline std::string deserialize(/*Get ownership*/ uint8_t* ptr) { + // TODO Figure out how to reduce copies. This is two copies: + // 1. Host allocates via allocate_on_guest_mem and copies data + // 2. Data is deserialized on a new location + char* start = reinterpret_cast(ptr); + std::string res(start); + delete[] ptr; + return res; +} + } // namespace dragonfly From 7969f777a162e0cd55611be95b6809f6cf832f78 Mon Sep 17 00:00:00 2001 From: Vladislav Date: Fri, 17 May 2024 14:38:14 +0300 Subject: [PATCH 4/6] wasm: Calling commands (#3048) * chore: back and forth --- src/facade/CMakeLists.txt | 2 +- src/facade/reply_capture.h | 4 - src/facade/reply_formats.cc | 177 +++++++++++++++++++++++++++ src/facade/reply_formats.h | 9 ++ src/server/http_api.cc | 156 +---------------------- src/server/main_service.cc | 5 +- src/server/main_service.h | 3 +- src/server/wasm/api.h | 4 +- src/server/wasm/wasm_family.cc | 5 +- src/server/wasm/wasm_family.h | 3 +- src/server/wasm/wasm_registry.cc | 87 +++++++++---- src/server/wasm/wasm_registry.h | 15 ++- src/wasm_sdk/cpp/examples/example.cc | 15 ++- src/wasm_sdk/cpp/include/dragonfly.h | 48 ++++---- 14 files changed, 315 insertions(+), 218 deletions(-) create mode 100644 src/facade/reply_formats.cc create mode 100644 src/facade/reply_formats.h diff --git a/src/facade/CMakeLists.txt b/src/facade/CMakeLists.txt index feef1da4883a..ec8b139caee8 100644 --- a/src/facade/CMakeLists.txt +++ b/src/facade/CMakeLists.txt @@ -1,6 +1,6 @@ add_library(dfly_facade conn_context.cc dragonfly_listener.cc dragonfly_connection.cc facade.cc memcache_parser.cc redis_parser.cc reply_builder.cc op_status.cc service_interface.cc - reply_capture.cc resp_expr.cc cmd_arg_parser.cc tls_error.cc) + reply_capture.cc resp_expr.cc cmd_arg_parser.cc tls_error.cc reply_formats.cc) if (DF_USE_SSL) set(TLS_LIB tls_lib) diff --git a/src/facade/reply_capture.h b/src/facade/reply_capture.h index 7fe2843d23a7..d571071a0b06 100644 --- a/src/facade/reply_capture.h +++ b/src/facade/reply_capture.h @@ -14,14 +14,10 @@ namespace facade { -struct CaptureVisitor; - // CapturingReplyBuilder allows capturing replies and retrieveing them with Take(). // Those replies can be stored standalone and sent with // CapturingReplyBuilder::Apply() to another reply builder. class CapturingReplyBuilder : public RedisReplyBuilder { - friend struct CaptureVisitor; - public: void SendError(std::string_view str, std::string_view type = {}) override; void SendError(ErrorReply error) override; diff --git a/src/facade/reply_formats.cc b/src/facade/reply_formats.cc new file mode 100644 index 000000000000..5d5ed2a4ec50 --- /dev/null +++ b/src/facade/reply_formats.cc @@ -0,0 +1,177 @@ +#include "facade/reply_formats.h" + +#include + +#include "absl/strings/str_cat.h" +#include "base/logging.h" +#include "facade/reply_capture.h" + +namespace facade { + +namespace { + +using namespace std; + +// Escape a string so that it is legal to print it in JSON text. +std::string JsonEscape(string_view input) { + auto hex_digit = [](unsigned c) -> char { + DCHECK_LT(c, 0xFu); + return c < 10 ? c + '0' : c - 10 + 'a'; + }; + + string out; + out.reserve(input.size() + 2); + out.push_back('\"'); + + auto p = input.begin(); + auto e = input.end(); + + while (p < e) { + uint8_t c = *p; + if (c == '\\' || c == '\"') { + out.push_back('\\'); + out.push_back(*p++); + } else if (c <= 0x1f) { + switch (c) { + case '\b': + out.append("\\b"); + p++; + break; + case '\f': + out.append("\\f"); + p++; + break; + case '\n': + out.append("\\n"); + p++; + break; + case '\r': + out.append("\\r"); + p++; + break; + case '\t': + out.append("\\t"); + p++; + break; + default: + // this condition captures non readable chars with value < 32, + // so size = 1 byte (e.g control chars). + out.append("\\u00"); + out.push_back(hex_digit((c & 0xf0) >> 4)); + out.push_back(hex_digit(c & 0xf)); + p++; + } + } else { + out.push_back(*p++); + } + } + + out.push_back('\"'); + return out; +} + +struct CaptureVisitor { + CaptureVisitor() { + str = R"({"result":)"; + } + + void operator()(monostate) { + } + + void operator()(long v) { + absl::StrAppend(&str, v); + } + + void operator()(double v) { + absl::StrAppend(&str, v); + } + + void operator()(const CapturingReplyBuilder::SimpleString& ss) { + absl::StrAppend(&str, "\"", ss, "\""); + } + + void operator()(const CapturingReplyBuilder::BulkString& bs) { + absl::StrAppend(&str, JsonEscape(bs)); + } + + void operator()(CapturingReplyBuilder::Null) { + absl::StrAppend(&str, "null"); + } + + void operator()(CapturingReplyBuilder::Error err) { + str = absl::StrCat(R"({"error": ")", err.first, "\""); + } + + void operator()(facade::OpStatus status) { + absl::StrAppend(&str, "\"", facade::StatusToMsg(status), "\""); + } + + void operator()(const CapturingReplyBuilder::StrArrPayload& sa) { + absl::StrAppend(&str, "["); + for (const auto& val : sa.arr) { + absl::StrAppend(&str, JsonEscape(val), ","); + } + if (sa.arr.size()) + str.pop_back(); + absl::StrAppend(&str, "]"); + } + + void operator()(const unique_ptr& cp) { + if (!cp) { + absl::StrAppend(&str, "null"); + return; + } + if (cp->len == 0 && cp->type == facade::RedisReplyBuilder::ARRAY) { + absl::StrAppend(&str, "[]"); + return; + } + + absl::StrAppend(&str, "["); + for (auto& pl : cp->arr) { + visit(*this, std::move(pl)); + } + } + + void operator()(const facade::SinkReplyBuilder::MGetResponse& resp) { + absl::StrAppend(&str, "["); + for (const auto& val : resp.resp_arr) { + if (val) { + absl::StrAppend(&str, JsonEscape(val->value), ","); + } else { + absl::StrAppend(&str, "null,"); + } + } + + if (resp.resp_arr.size()) + str.pop_back(); + absl::StrAppend(&str, "]"); + } + + void operator()(const CapturingReplyBuilder::ScoredArray& sarr) { + absl::StrAppend(&str, "["); + for (const auto& [key, score] : sarr.arr) { + absl::StrAppend(&str, "{", JsonEscape(key), ":", score, "},"); + } + if (sarr.arr.size() > 0) { + str.pop_back(); + } + absl::StrAppend(&str, "]"); + } + + string Take() { + absl::StrAppend(&str, "}\r\n"); + return std::move(str); + } + + string str; +}; + +} // namespace + +std::string FormatToJson(CapturingReplyBuilder::Payload&& payload) { + CaptureVisitor visitor{}; + std::visit(visitor, payload); + return visitor.Take(); +} + +}; // namespace facade diff --git a/src/facade/reply_formats.h b/src/facade/reply_formats.h new file mode 100644 index 000000000000..1822ebdb6c2a --- /dev/null +++ b/src/facade/reply_formats.h @@ -0,0 +1,9 @@ +#pragma once + +#include "facade/reply_capture.h" + +namespace facade { + +std::string FormatToJson(facade::CapturingReplyBuilder::Payload&& value); + +}; diff --git a/src/server/http_api.cc b/src/server/http_api.cc index c00d2a3725d6..bfc1004d5937 100644 --- a/src/server/http_api.cc +++ b/src/server/http_api.cc @@ -8,6 +8,7 @@ #include "core/flatbuffers.h" #include "facade/conn_context.h" #include "facade/reply_builder.h" +#include "facade/reply_formats.h" #include "server/main_service.h" #include "util/http/http_common.h" @@ -37,155 +38,6 @@ bool IsVectorOfStrings(flexbuffers::Reference req) { return true; } -// Escape a string so that it is legal to print it in JSON text. -std::string JsonEscape(string_view input) { - auto hex_digit = [](unsigned c) -> char { - DCHECK_LT(c, 0xFu); - return c < 10 ? c + '0' : c - 10 + 'a'; - }; - - string out; - out.reserve(input.size() + 2); - out.push_back('\"'); - - auto p = input.begin(); - auto e = input.end(); - - while (p < e) { - uint8_t c = *p; - if (c == '\\' || c == '\"') { - out.push_back('\\'); - out.push_back(*p++); - } else if (c <= 0x1f) { - switch (c) { - case '\b': - out.append("\\b"); - p++; - break; - case '\f': - out.append("\\f"); - p++; - break; - case '\n': - out.append("\\n"); - p++; - break; - case '\r': - out.append("\\r"); - p++; - break; - case '\t': - out.append("\\t"); - p++; - break; - default: - // this condition captures non readable chars with value < 32, - // so size = 1 byte (e.g control chars). - out.append("\\u00"); - out.push_back(hex_digit((c & 0xf0) >> 4)); - out.push_back(hex_digit(c & 0xf)); - p++; - } - } else { - out.push_back(*p++); - } - } - - out.push_back('\"'); - return out; -} - -struct CaptureVisitor { - CaptureVisitor() { - str = R"({"result":)"; - } - - void operator()(monostate) { - } - - void operator()(long v) { - absl::StrAppend(&str, v); - } - - void operator()(double v) { - absl::StrAppend(&str, v); - } - - void operator()(const CapturingReplyBuilder::SimpleString& ss) { - absl::StrAppend(&str, "\"", ss, "\""); - } - - void operator()(const CapturingReplyBuilder::BulkString& bs) { - absl::StrAppend(&str, JsonEscape(bs)); - } - - void operator()(CapturingReplyBuilder::Null) { - absl::StrAppend(&str, "null"); - } - - void operator()(CapturingReplyBuilder::Error err) { - str = absl::StrCat(R"({"error": ")", err.first, "\""); - } - - void operator()(facade::OpStatus status) { - absl::StrAppend(&str, "\"", facade::StatusToMsg(status), "\""); - } - - void operator()(const CapturingReplyBuilder::StrArrPayload& sa) { - absl::StrAppend(&str, "["); - for (const auto& val : sa.arr) { - absl::StrAppend(&str, JsonEscape(val), ","); - } - if (sa.arr.size()) - str.pop_back(); - absl::StrAppend(&str, "]"); - } - - void operator()(unique_ptr cp) { - if (!cp) { - absl::StrAppend(&str, "null"); - return; - } - if (cp->len == 0 && cp->type == facade::RedisReplyBuilder::ARRAY) { - absl::StrAppend(&str, "[]"); - return; - } - - absl::StrAppend(&str, "["); - for (auto& pl : cp->arr) { - visit(*this, std::move(pl)); - } - } - - void operator()(facade::SinkReplyBuilder::MGetResponse resp) { - absl::StrAppend(&str, "["); - for (const auto& val : resp.resp_arr) { - if (val) { - absl::StrAppend(&str, JsonEscape(val->value), ","); - } else { - absl::StrAppend(&str, "null,"); - } - } - - if (resp.resp_arr.size()) - str.pop_back(); - absl::StrAppend(&str, "]"); - } - - void operator()(const CapturingReplyBuilder::ScoredArray& sarr) { - absl::StrAppend(&str, "["); - for (const auto& [key, score] : sarr.arr) { - absl::StrAppend(&str, "{", JsonEscape(key), ":", score, "},"); - } - if (sarr.arr.size() > 0) { - str.pop_back(); - } - absl::StrAppend(&str, "]"); - } - - string str; -}; - } // namespace void HttpAPI(const http::QueryArgs& args, HttpRequest&& req, Service* service, @@ -231,16 +83,12 @@ void HttpAPI(const http::QueryArgs& args, HttpRequest&& req, Service* service, auto* prev = context->Inject(&reply_builder); // TODO: to finish this. service->DispatchCommand(absl::MakeSpan(cmd_slices), context); - facade::CapturingReplyBuilder::Payload payload = reply_builder.Take(); context->Inject(prev); auto response = http::MakeStringResponse(); http::SetMime(http::kJsonMime, &response); - CaptureVisitor visitor; - std::visit(visitor, std::move(payload)); - visitor.str.append("}\r\n"); - response.body() = visitor.str; + response.body() = facade::FormatToJson(reply_builder.Take()); http_cntx->Invoke(std::move(response)); } diff --git a/src/server/main_service.cc b/src/server/main_service.cc index 4244e273323b..ddd9fd8e8b51 100644 --- a/src/server/main_service.cc +++ b/src/server/main_service.cc @@ -789,7 +789,8 @@ Service::Service(ProactorPool* pp) : pp_(*pp), acl_family_(&user_registry_, pp), server_family_(this), - cluster_family_(&server_family_) { + cluster_family_(&server_family_), + wasm_family_(*this) { CHECK(pp); CHECK(shard_set == NULL); @@ -1255,7 +1256,7 @@ void Service::DispatchCommand(CmdArgList args, facade::ConnectionContext* cntx) // if this is a read command, and client tracking has enabled, // start tracking all the updates to the keys in this read command - if ((cid->opt_mask() & CO::READONLY) && dfly_cntx->conn()->IsTrackingOn() && + if ((cid->opt_mask() & CO::READONLY) && dfly_cntx->conn() && dfly_cntx->conn()->IsTrackingOn() && cid->IsTransactional()) { facade::Connection::WeakRef conn_ref = dfly_cntx->conn()->Borrow(); auto cb = [&, conn_ref](Transaction* t, EngineShard* shard) { diff --git a/src/server/main_service.h b/src/server/main_service.h index b70d539f2631..0efe05728f5b 100644 --- a/src/server/main_service.h +++ b/src/server/main_service.h @@ -184,9 +184,10 @@ class Service : public facade::ServiceInterface { acl::AclFamily acl_family_; ServerFamily server_family_; cluster::ClusterFamily cluster_family_; + wasm::WasmFamily wasm_family_; + CommandRegistry registry_; absl::flat_hash_map unknown_cmds_; - wasm::WasmFamily wasm_family_; const CommandId* exec_cid_; // command id of EXEC command for pipeline squashing diff --git a/src/server/wasm/api.h b/src/server/wasm/api.h index 727dd13269bf..099c536ee43f 100644 --- a/src/server/wasm/api.h +++ b/src/server/wasm/api.h @@ -4,6 +4,7 @@ #pragma once +#include #include #include #include @@ -14,11 +15,12 @@ namespace dfly::wasm::api { +// TODO Make extensible or don't use it at all template bool RegisterApiFunction(std::string_view name, Fn f, wasmtime::Linker* linker) { std::string module_name = "dragonfly"; - auto args_signature = wasmtime::FuncType({}, {wasmtime::ValKind::I32}); + auto args_signature = wasmtime::FuncType({wasmtime::ValKind::I32}, {}); auto res = linker->func_new(module_name, name, args_signature, f); return (bool)res; diff --git a/src/server/wasm/wasm_family.cc b/src/server/wasm/wasm_family.cc index c5c0c2273b6f..14c9c466a4ea 100644 --- a/src/server/wasm/wasm_family.cc +++ b/src/server/wasm/wasm_family.cc @@ -6,6 +6,7 @@ #include "absl/strings/str_cat.h" #include "base/flags.h" #include "facade/facade_types.h" +#include "facade/service_interface.h" #include "server/acl/acl_commands_def.h" #include "server/command_registry.h" #include "server/conn_context.h" @@ -24,9 +25,9 @@ CommandId::Handler HandlerFunc(WasmFamily* wasm, MemberFunc f) { #define HFUNC(x) SetHandler(HandlerFunc(this, &WasmFamily::x)) -WasmFamily::WasmFamily() { +WasmFamily::WasmFamily(facade::ServiceInterface& service) { if (auto wasm_modules = absl::GetFlag(FLAGS_wasmpaths); !wasm_modules.empty()) { - registry_ = std::make_unique(); + registry_ = std::make_unique(service); } } diff --git a/src/server/wasm/wasm_family.h b/src/server/wasm/wasm_family.h index af52dda01675..1233caaae7b9 100644 --- a/src/server/wasm/wasm_family.h +++ b/src/server/wasm/wasm_family.h @@ -7,6 +7,7 @@ #include #include "facade/facade_types.h" +#include "facade/service_interface.h" #include "server/command_registry.h" #include "server/wasm/wasm_registry.h" @@ -17,7 +18,7 @@ namespace wasm { class WasmFamily final { public: - WasmFamily(); + WasmFamily(facade::ServiceInterface& service); void Register(CommandRegistry* registry); private: diff --git a/src/server/wasm/wasm_registry.cc b/src/server/wasm/wasm_registry.cc index 049674643951..90ed5adfe300 100644 --- a/src/server/wasm/wasm_registry.cc +++ b/src/server/wasm/wasm_registry.cc @@ -4,8 +4,12 @@ #include "server/wasm/wasm_registry.h" +#include +#include +#include #include +#include #include #include @@ -14,7 +18,12 @@ #include "absl/strings/str_split.h" #include "base/flags.h" #include "base/logging.h" +#include "facade/facade_types.h" +#include "facade/reply_capture.h" +#include "facade/reply_formats.h" +#include "facade/service_interface.h" #include "io/file_util.h" +#include "server/conn_context.h" #include "server/wasm/api.h" #include "server/wasm/wasmtime.hh" @@ -22,35 +31,69 @@ ABSL_DECLARE_FLAG(std::string, wasmpaths); namespace dfly::wasm { -WasmRegistry::WasmRegistry() - : engine_(WasmRegistry::GetConfig()), linker_(engine_), store_(engine_) { - auto hellofunc = [](wasmtime::Caller caller, auto params, auto results) { - auto res = caller.get_export("allocate_on_guest_mem"); - std::string value = "Hello world from wasm!"; - value.push_back('\0'); - - // Call the exported alloc to allocate memory on the guest - auto alloc = std::get(*res); - const int32_t alloc_size = static_cast(value.size()); - auto result = alloc.call(caller.context(), {wasmtime::Val{alloc_size}}); - if (!result) { - // handle errors - } - auto wasm_value = result.ok().front(); - auto offset = wasm_value.i32(); +namespace { + +std::vector ParseArguments(uint8_t* data) { + uint32_t parts; + memcpy(&parts, data, sizeof(uint32_t)); + data += sizeof(uint32_t); + + std::vector out; + for (int32_t i = 0; i < parts; i++) { + uint32_t length; + memcpy(&length, data, sizeof(uint32_t)); + data += sizeof(uint32_t); + + out.emplace_back(reinterpret_cast(data), length); + data += length; + } + + return out; +} + +std::string CallCommand(facade::ServiceInterface* service, + std::vector arguments) { + facade::CapturingReplyBuilder capture; + ConnectionContext cntx(nullptr, nullptr); + delete cntx.Inject(&capture); + + service->DispatchCommand(absl::MakeSpan(arguments), &cntx); + cntx.Inject(nullptr); + return facade::FormatToJson(capture.Take()); +} + +std::optional> ProvideMemory(wasmtime::Caller* caller, wasmtime::Memory* memory, + size_t bytes) { + auto res = caller->get_export("provide_buffer"); + auto alloc_func = std::get(*res); + + auto alloc_res = alloc_func.call(caller->context(), {wasmtime::Val{int32_t(bytes)}}); + if (!alloc_res) + return std::nullopt; + + int32_t offset = alloc_res.ok().front().i32(); + uint8_t* ptr = memory->data(caller->context()).data() + offset; + return {{ptr, bytes}}; +} + +} // namespace + +WasmRegistry::WasmRegistry(facade::ServiceInterface& service) + : engine_(WasmRegistry::GetConfig()), linker_(engine_), store_(engine_) { + auto callfunc = [&service](wasmtime::Caller caller, wasmtime::Span params, + auto results) { wasmtime::Memory memory = std::get(*caller.get_export("memory")); - uint8_t* data = memory.data(caller.context()).data() + offset; - absl::little_endian::Store32(data, value.size()); - // TODO inject payload size at the front so we dont have to push an extra \0 - memcpy(data, value.c_str(), value.size()); + std::string result = CallCommand( + &service, ParseArguments(memory.data(caller.context()).data() + params[0].i32())); - results[0] = wasm_value; + auto result_buffer = ProvideMemory(&caller, &memory, result.size()); + memcpy(result_buffer->data(), result.data(), result.size()); return std::monostate(); }; - api::RegisterApiFunction("hello", hellofunc, &linker_); + api::RegisterApiFunction("call", callfunc, &linker_); wasmtime::WasiConfig wasi; wasi.inherit_argv(); diff --git a/src/server/wasm/wasm_registry.h b/src/server/wasm/wasm_registry.h index a14d8c225212..8d1919cb5b3e 100644 --- a/src/server/wasm/wasm_registry.h +++ b/src/server/wasm/wasm_registry.h @@ -12,6 +12,7 @@ #include "absl/container/flat_hash_map.h" #include "absl/strings/str_cat.h" #include "base/logging.h" +#include "facade/service_interface.h" #include "server/wasm/api.h" #include "server/wasm/wasmtime.hh" #include "util/fibers/synchronization.h" @@ -38,10 +39,11 @@ class WasmModule { class WasmRegistry { public: - WasmRegistry(); + WasmRegistry(facade::ServiceInterface& service); WasmRegistry(const WasmRegistry&) = delete; WasmRegistry(WasmRegistry&&) = delete; ~WasmRegistry(); + bool Delete(std::string_view name); // Very light-weight. Each Module is compiled *once* but each UDF call, e,g, `CALLWASM` @@ -78,6 +80,12 @@ class WasmRegistry { std::optional GetInstanceFromModule(std::string_view module_name); private: + static wasmtime::Config GetConfig() { + wasmtime::Config config; + config.epoch_interruption(false); + return config; + } + std::string Add(std::string_view path); void InstantiateAndLinkModules(); @@ -89,11 +97,6 @@ class WasmRegistry { wasmtime::Engine engine_; wasmtime::Linker linker_; wasmtime::Store store_; - static wasmtime::Config GetConfig() { - wasmtime::Config config; - config.epoch_interruption(false); - return config; - } }; } // namespace dfly::wasm diff --git a/src/wasm_sdk/cpp/examples/example.cc b/src/wasm_sdk/cpp/examples/example.cc index 075e81bc9867..c89df330ff43 100644 --- a/src/wasm_sdk/cpp/examples/example.cc +++ b/src/wasm_sdk/cpp/examples/example.cc @@ -1,3 +1,5 @@ +#include + #include "dragonfly.h" // wasi-sdk-22.0/bin/clang++ -std=c++11 example.cc -o example.wasm @@ -10,9 +12,16 @@ DF_EXPORT("my_fun") void my_fun() { - std::string result = dragonfly::hello_world(); - // passes - assert(result == "Hello world from wasm!"); + std::string result; + + result = dragonfly::call({"GET", "A"}); + printf("%s", result.c_str()); + + result = dragonfly::call({"SET", "A", "WORKS"}); + printf("%s", result.c_str()); + + result = dragonfly::call({"GET", "A"}); + printf("%s", result.c_str()); // when we return the bson here we need to be carefull with the allocation. // does that mean the host will need to free the memory? e.g, export a `free` function? } diff --git a/src/wasm_sdk/cpp/include/dragonfly.h b/src/wasm_sdk/cpp/include/dragonfly.h index c8d2ec209fb3..11c8d9eea6cc 100644 --- a/src/wasm_sdk/cpp/include/dragonfly.h +++ b/src/wasm_sdk/cpp/include/dragonfly.h @@ -3,7 +3,9 @@ // #include +#include #include +#include #include #include @@ -14,16 +16,32 @@ namespace dragonfly { extern "C" { #define WASM_IMPORT(mod, name) __attribute__((import_module(#mod), import_name(#name))) -WASM_IMPORT(dragonfly, hello) -uint8_t* hello(); +WASM_IMPORT(dragonfly, call) +void call(uint32_t); // Add rest of functions here } -inline std::string deserialize(uint8_t* ptr); +// guard against multiple defines +std::string call_buffer; -inline std::string hello_world() { - return deserialize(hello()); +inline std::string_view call(std::initializer_list arguments) { + std::string data(4, 'x'); + + uint32_t parts = arguments.size(); + memcpy((void*)data.data(), &parts, 4); + + for (const std::string& str : arguments) { + data.append(4, 'x'); + + uint32_t strsize = str.size(); + memcpy((void*)(data.data() + data.size() - 4), &strsize, 4); + + data += str; + } + + call((uint64_t)data.data()); + return call_buffer; } /* Used to export functions from wasm modules */ @@ -34,22 +52,10 @@ inline std::string hello_world() { #define DF_EXPORT(name) __attribute__((export_name(name))) /* Private and NOT part of the public API */ -DF_EXPORT("allocate_on_guest_mem") -inline uint8_t* allocate_on_guest_mem(size_t bytes) { - return new uint8_t[bytes]; -} - -/* Entry point to deserialize data coming from Dragonfly */ -/* For now this is hardcoded and only returns a string and should */ -/* be extended with json */ -inline std::string deserialize(/*Get ownership*/ uint8_t* ptr) { - // TODO Figure out how to reduce copies. This is two copies: - // 1. Host allocates via allocate_on_guest_mem and copies data - // 2. Data is deserialized on a new location - char* start = reinterpret_cast(ptr); - std::string res(start); - delete[] ptr; - return res; +DF_EXPORT("provide_buffer") +inline char* provide_buffer(size_t bytes) { + call_buffer.resize(bytes); + return call_buffer.data(); } } // namespace dragonfly From 1380986e5f9b81b33032058333c5f077548f7e4c Mon Sep 17 00:00:00 2001 From: Vladislav Date: Fri, 24 May 2024 10:51:22 +0300 Subject: [PATCH 5/6] wasm: Rust support (#3061) --- src/server/wasm/api.h | 1 + src/server/wasm/wasm_family.cc | 11 ++-- src/server/wasm/wasm_registry.h | 21 +++++- src/wasm_sdk/cpp/include/dragonfly.h | 2 +- src/wasm_sdk/rustw/Cargo.lock | 95 ++++++++++++++++++++++++++++ src/wasm_sdk/rustw/Cargo.toml | 13 ++++ src/wasm_sdk/rustw/README.md | 3 + src/wasm_sdk/rustw/src/lib.rs | 68 ++++++++++++++++++++ 8 files changed, 205 insertions(+), 9 deletions(-) create mode 100644 src/wasm_sdk/rustw/Cargo.lock create mode 100644 src/wasm_sdk/rustw/Cargo.toml create mode 100644 src/wasm_sdk/rustw/README.md create mode 100644 src/wasm_sdk/rustw/src/lib.rs diff --git a/src/server/wasm/api.h b/src/server/wasm/api.h index 099c536ee43f..5cc3edec5d95 100644 --- a/src/server/wasm/api.h +++ b/src/server/wasm/api.h @@ -23,6 +23,7 @@ bool RegisterApiFunction(std::string_view name, Fn f, wasmtime::Linker* linker) auto args_signature = wasmtime::FuncType({wasmtime::ValKind::I32}, {}); auto res = linker->func_new(module_name, name, args_signature, f); + linker->func_new("env", name, args_signature, f); return (bool)res; } diff --git a/src/server/wasm/wasm_family.cc b/src/server/wasm/wasm_family.cc index 14c9c466a4ea..0c5a05519488 100644 --- a/src/server/wasm/wasm_family.cc +++ b/src/server/wasm/wasm_family.cc @@ -5,6 +5,7 @@ #include "absl/strings/str_cat.h" #include "base/flags.h" +#include "core/overloaded.h" #include "facade/facade_types.h" #include "facade/service_interface.h" #include "server/acl/acl_commands_def.h" @@ -71,11 +72,11 @@ void WasmFamily::Call(CmdArgList args, ConnectionContext* cntx) { // When we switch to async execution of the runtime this will be able to suspend and resume // which will make the current approach invalid. auto wasm_result = wasm_function(exported_fun_name); - if (!wasm_result.empty()) { - cntx->SendError(wasm_result); - return; - } - cntx->SendOk(); + + Overloaded result_handler{[cntx](std::monostate) { cntx->SendOk(); }, + [cntx](facade::ErrorReply err) { cntx->SendError(err); }, + [cntx](std::string result) { cntx->SendSimpleString(result); }}; + visit(result_handler, wasm_result); } void WasmFamily::Delete(CmdArgList args, ConnectionContext* cntx) { diff --git a/src/server/wasm/wasm_registry.h b/src/server/wasm/wasm_registry.h index 8d1919cb5b3e..c64d90cb7a76 100644 --- a/src/server/wasm/wasm_registry.h +++ b/src/server/wasm/wasm_registry.h @@ -11,7 +11,9 @@ #include "absl/container/flat_hash_map.h" #include "absl/strings/str_cat.h" +#include "base/expected.hpp" #include "base/logging.h" +#include "facade/facade_types.h" #include "facade/service_interface.h" #include "server/wasm/api.h" #include "server/wasm/wasmtime.hh" @@ -56,16 +58,29 @@ class WasmRegistry { : instance_{instance}, store_(store) { } - std::string operator()(std::string_view export_func_name) { + std::variant operator()( + std::string_view export_func_name) { // Users will export functions for their modules via the attribute // __attribute__((export_name(func_name))). We will expose this in our sdk auto extern_def = instance_.get(*store_, export_func_name); if (!extern_def) { - return absl::StrCat("No exported function with name ", export_func_name, " found"); + return facade::ErrorReply( + absl::StrCat("No exported function with name ", export_func_name, " found")); } + auto run = std::get(*extern_def); auto res = run.call(store_, {}).unwrap(); - return {}; + + if (res.size() == 1) { + uint32_t offset = res[0].i32(); + wasmtime::Memory mem = + std::get(*instance_.get(store_->context(), "memory")); + + char* ptr = reinterpret_cast(mem.data(store_->context()).data() + offset); + return std::string{ptr, strlen(ptr)}; + } + + return std::monostate{}; } wasmtime::Instance* GetInstance() { diff --git a/src/wasm_sdk/cpp/include/dragonfly.h b/src/wasm_sdk/cpp/include/dragonfly.h index 11c8d9eea6cc..f481f87479c7 100644 --- a/src/wasm_sdk/cpp/include/dragonfly.h +++ b/src/wasm_sdk/cpp/include/dragonfly.h @@ -53,7 +53,7 @@ inline std::string_view call(std::initializer_list arguments) { /* Private and NOT part of the public API */ DF_EXPORT("provide_buffer") -inline char* provide_buffer(size_t bytes) { +inline char* provide_buffer(int32_t bytes) { call_buffer.resize(bytes); return call_buffer.data(); } diff --git a/src/wasm_sdk/rustw/Cargo.lock b/src/wasm_sdk/rustw/Cargo.lock new file mode 100644 index 000000000000..7d307bdb63b5 --- /dev/null +++ b/src/wasm_sdk/rustw/Cargo.lock @@ -0,0 +1,95 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 3 + +[[package]] +name = "build_html" +version = "2.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3108fe6fe7ac796fb7625bdde8fa2b67b5a7731496251ca57c7b8cadd78a16a1" + +[[package]] +name = "itoa" +version = "1.0.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b" + +[[package]] +name = "proc-macro2" +version = "1.0.82" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ad3d49ab951a01fbaafe34f2ec74122942fe18a3f9814c3268f1bb72042131b" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "quote" +version = "1.0.36" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fa76aaf39101c457836aec0ce2316dbdc3ab723cdda1c6bd4e6ad4208acaca7" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "rustw" +version = "0.1.0" +dependencies = [ + "build_html", + "serde_json", +] + +[[package]] +name = "ryu" +version = "1.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f" + +[[package]] +name = "serde" +version = "1.0.202" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "226b61a0d411b2ba5ff6d7f73a476ac4f8bb900373459cd00fab8512828ba395" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.202" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6048858004bcff69094cd972ed40a32500f153bd3be9f716b2eed2e8217c4838" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "serde_json" +version = "1.0.117" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "455182ea6142b14f93f4bc5320a2b31c1f266b66a4a5c858b013302a5d8cbfc3" +dependencies = [ + "itoa", + "ryu", + "serde", +] + +[[package]] +name = "syn" +version = "2.0.64" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7ad3dee41f36859875573074334c200d1add8e4a87bb37113ebd31d926b7b11f" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "unicode-ident" +version = "1.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" diff --git a/src/wasm_sdk/rustw/Cargo.toml b/src/wasm_sdk/rustw/Cargo.toml new file mode 100644 index 000000000000..46ff737a416c --- /dev/null +++ b/src/wasm_sdk/rustw/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "rustw" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[lib] +crate-type = ["cdylib"] + +[dependencies] +build_html = "2.4.0" +serde_json = "1.0.117" diff --git a/src/wasm_sdk/rustw/README.md b/src/wasm_sdk/rustw/README.md new file mode 100644 index 000000000000..9e70099d5d96 --- /dev/null +++ b/src/wasm_sdk/rustw/README.md @@ -0,0 +1,3 @@ +https://github.com/bytecodealliance/cargo-wasi/tree/main + +then just run `cargo wasi build` diff --git a/src/wasm_sdk/rustw/src/lib.rs b/src/wasm_sdk/rustw/src/lib.rs new file mode 100644 index 000000000000..b292e2be2458 --- /dev/null +++ b/src/wasm_sdk/rustw/src/lib.rs @@ -0,0 +1,68 @@ +use std::io::Write; + +use build_html::{Html, HtmlContainer}; + +extern "C" { + pub fn call(cmd: *const u8); +} + +static mut BUFFER: Option = None; + +pub fn run(command: &[&str]) -> serde_json::Value { + let mut bytes = Vec::::new(); + + let res = unsafe { + bytes + .write(&std::mem::transmute::(command.len() as i32)) + .unwrap(); + + for part in command { + bytes + .write(&std::mem::transmute::(part.len() as i32)) + .unwrap(); + bytes.write(part.as_bytes()).unwrap(); + } + + call(bytes.as_ptr()); + + BUFFER.take().unwrap() + }; + + println!("GOT! {}", res); + serde_json::from_str(&res).unwrap() +} + +#[no_mangle] +pub unsafe fn provide_buffer(bytes: i32) -> *mut u8 { + BUFFER.insert(str::repeat(" ", bytes as usize)).as_mut_ptr() +} + +pub unsafe fn leak_string(s: String) -> *const u8{ + let len = s.len(); + let ptr = s.leak().as_mut_ptr(); + *ptr.wrapping_add(len) = b'\0'; + ptr +} + +#[no_mangle] +pub fn my_fun() -> *const u8 { + let titles: Vec = run(&["LRANGE", "TITLES", "0", "-1"]) + .get("result") + .unwrap() + .as_array() + .unwrap() + .iter() + .map(|v| v.as_str().unwrap().to_owned()) + .collect(); + + let list_items = titles.into_iter() + .map(|t| format!("Title item {}", t)) + .fold(build_html::Container::new(build_html::ContainerType::OrderedList), |a, n| a.with_paragraph(n)); + + let page = build_html::HtmlPage::new() + .with_title("MY ENTRIES") + .with_container(build_html::Container::default() + .with_container(list_items) + ).to_html_string(); + return unsafe { leak_string(page) }; +} From c3764d5d60c8eb5f7bbb7eca69618d2154e54e34 Mon Sep 17 00:00:00 2001 From: Kostas Kyrimis Date: Fri, 24 May 2024 10:54:23 +0300 Subject: [PATCH 6/6] chore(wasm): add Go SDK (#3059) --- src/wasm_sdk/go/README.md | 6 +++++ src/wasm_sdk/go/example.go | 49 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 55 insertions(+) create mode 100644 src/wasm_sdk/go/README.md create mode 100644 src/wasm_sdk/go/example.go diff --git a/src/wasm_sdk/go/README.md b/src/wasm_sdk/go/README.md new file mode 100644 index 000000000000..eb74e560e152 --- /dev/null +++ b/src/wasm_sdk/go/README.md @@ -0,0 +1,6 @@ +# Dragonfly Go SDK for wasm + +1. Download tiny go +2. Compile with wasi as target: `tinygo build -o go.wasm -target=wasi example.go` +3. Load the module at startup via `WASMPATH` +4. Call an exported function via `WASMCALL go.wasm go_hi` diff --git a/src/wasm_sdk/go/example.go b/src/wasm_sdk/go/example.go new file mode 100644 index 000000000000..5297d74a61d2 --- /dev/null +++ b/src/wasm_sdk/go/example.go @@ -0,0 +1,49 @@ +package main + +import "encoding/binary" +import "strings" + +var buffer []byte + +//export provide_buffer +func provide_buffer(bytes int) *byte { + buffer = []byte(strings.Repeat("x", bytes)) + return &buffer[0] +} + +//go:wasm-module dragonfly +//export call +func call(str *byte) + +func toByteArray(i int) (arr []byte) { + arr = []byte("xxxx") + binary.LittleEndian.PutUint32(arr[0:4], uint32(i)) + return +} + +func send(args... string) string { + var data []byte + data = append(data, toByteArray(len(args))...) + data = data[0:4] + + for _, element := range args { + var sz = toByteArray(len(element)) + var payload = []byte(element) + data = append(data, sz...) + data = append(data, payload...) + } + + call(&data[0]) + return string(buffer) +} + +//export go_hi +func go_hi() { + var res = send("set", "foo", "bar"); + println("Result is ", res) + res = send("get", "foo"); + println("Result is ", res) +} + +// main is required for the `wasi` target, even if it isn't used. +func main() {}