From 49e6f12e4fe0c3fa6832f77ca4b3b648f15cda7d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Christian=20D=C3=B6ring?= Date: Wed, 4 Dec 2024 08:38:45 +0100 Subject: [PATCH] Added tests for ThreadState freezing --- tests/CMakeLists.txt | 2 +- tests/record.cpp | 250 +++++++++++++++++++++++++++++++++++++++++++ tests/test.h | 245 ++++++++++++++++++++++++++++++++++++++++++ tests/vcall.cpp | 81 ++++++++++++++ 4 files changed, 577 insertions(+), 1 deletion(-) create mode 100644 tests/record.cpp diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 7c4573e5..117dfe68 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -1,6 +1,6 @@ enable_testing() -set(TEST_FILES basics.cpp mem.cpp loop.cpp vcall.cpp graphviz.cpp reductions.cpp array.cpp) +set(TEST_FILES basics.cpp mem.cpp loop.cpp vcall.cpp graphviz.cpp reductions.cpp array.cpp record.cpp) add_executable(test_half half.cpp) add_test(NAME test_half COMMAND test_half diff --git a/tests/record.cpp b/tests/record.cpp new file mode 100644 index 00000000..481b7e27 --- /dev/null +++ b/tests/record.cpp @@ -0,0 +1,250 @@ +#include "drjit-core/array.h" +#include "drjit-core/jit.h" +#include "test.h" + +/** + * Basic addition test. + * Supplying a different input should replay the operation, with this input. + * In this case, the input at replay is incremented and should result in an + * incremented output. + */ +TEST_BOTH(01_basic_replay) { + + auto func = [](UInt32 input) { return input + 1; }; + + FrozenFunction frozen(Backend, func); + + for (uint32_t i = 0; i < 3; i++) { + auto input = arange(10 + i); + + auto result = frozen(input); + + auto reference = func(input); + + jit_assert(all(eq(result, reference))); + } +} + +/** + * This tests a single kernel with multiple unique inputs and outputs. 
+ */ +TEST_BOTH(02_MIMO) { + + auto func = [](UInt32 x, UInt32 y) { + return std::make_tuple(x + y, x * y); + }; + + FrozenFunction frozen(Backend, func); + + for (uint32_t i = 0; i < 3; i++) { + auto x = arange(10 + i); + auto y = arange(10 + i) + 1; + + auto result = frozen(x, y); + + auto reference = func(x, y); + + jit_assert(all(eq(std::get<0>(result), std::get<0>(reference)))); + jit_assert(all(eq(std::get<1>(result), std::get<1>(reference)))); + } +} + +/** + * This tests if the recording feature works when supplying the same variable + * twice in the input. In the final implementation this test-case should never + * occur, as variables would be deduplicated beforehand. + */ +TEST_BOTH(03_deduplicating_input) { + + auto func = [](UInt32 x, UInt32 y) { return std::tuple(x + y, x * y); }; + + FrozenFunction frozen(Backend, func); + + for (uint32_t i = 0; i < 3; i++) { + auto x = arange(10 + i); + + auto result = frozen(x, x); + + auto reference = func(x, x); + + jit_assert(all(eq(std::get<0>(result), std::get<0>(reference)))); + jit_assert(all(eq(std::get<1>(result), std::get<1>(reference)))); + } +} + +/** + * This tests whether it is possible to record multiple kernels in sequence. + * The input of the second kernel relies on the execution of the first. + * On LLVM, the correctness of barrier operations is therefore tested. + */ +TEST_BOTH(04_sequential_kernels) { + + auto func = [](UInt32 x) { + auto y = x + 1; + y.eval(); + return y + x; + }; + + FrozenFunction frozen(Backend, func); + + for (uint32_t i = 0; i < 3; i++) { + auto x = arange(10 + i); + + auto result = frozen(x); + + auto reference = func(x); + + jit_assert(all(eq(result, reference))); + } +} + +/** + * This tests whether it is possible to record multiple independent kernels in + * the same recording. + * The variables of the kernels are of different size, therefore two kernels are + * generated. At replay these can be executed in parallel (LLVM) or sequence + * (CUDA).
+ */ +TEST_BOTH(05_parallel_kernels) { + + auto func = [](UInt32 x, UInt32 y) { return std::tuple(x + 1, y + 1); }; + + FrozenFunction frozen(Backend, func); + + for (uint32_t i = 0; i < 3; i++) { + auto x = arange(10 + i); + auto y = arange(11 + i); + + auto result = frozen(x, y); + + auto reference = func(x, y); + + jit_assert(all(eq(std::get<0>(result), std::get<0>(reference)))); + jit_assert(all(eq(std::get<1>(result), std::get<1>(reference)))); + } +} + +/** + * This tests the recording and replay of a horizontal reduction operation + * (hsum). + */ +TEST_BOTH(06_reduce_hsum) { + + auto func = [](UInt32 x) { return hsum(x + 1); }; + + FrozenFunction frozen(Backend, func); + + for (uint32_t i = 0; i < 3; i++) { + auto x = arange(10 + i); + + auto result = frozen(x); + + auto reference = func(x); + + jit_assert(all(eq(result, reference))); + } +} + +/** + * Tests recording of a prefix sum operation with different inputs at replay. + */ +TEST_BOTH(07_prefix_sum) { + + auto func = [](UInt32 x) { return block_prefix_sum(x, x.size()); }; + + FrozenFunction frozen(Backend, func); + + for (uint32_t i = 0; i < 3; i++) { + auto x = arange(10 + i); + + auto result = frozen(x); + + auto reference = func(x); + + jit_assert(all(eq(result, reference))); + } +} + +/** + * Tests that it is possible to pass a single input to multiple outputs + * including directly in a frozen function without any use after free + * conditions. 
+ */ +TEST_BOTH(08_input_passthrough) { + + auto func = [](UInt32 x) { + auto y = x + 1; + return std::tuple(y, x); + }; + + FrozenFunction frozen(Backend, func); + + for (uint32_t i = 0; i < 3; i++) { + auto x = arange(10 + i); + x.make_opaque(); + + auto result = frozen(x); + + auto reference = func(x); + + jit_assert(all(eq(std::get<0>(result), std::get<0>(reference)))); + jit_assert(all(eq(std::get<1>(result), std::get<1>(reference)))); + } +} + +/** + * Tests if the dry run mode catches the case where LLVM kernels have to be + * replayed due to size changes in a scatter reduce operation. + */ +TEST_LLVM(09_dry_run) { + auto func = [](UInt32 target, UInt32 src) { + scatter_reduce(ReduceOp::Add, target, src, + arange(src.size()) % 2); + return target; + }; + + FrozenFunction frozen(Backend, func); + + for (uint32_t i = 0; i < 4; i++) { + auto src = full(1, 10 + i); + src.make_opaque(); + + auto result = full(0, (i + 2)); + result.make_opaque(); + result = frozen(result, src); + + auto reference = full(0, (i + 2)); + reference.make_opaque(); + reference = frozen(reference, src); + + jit_assert(all(eq(result, reference))); + } +} + +/** + * Tests that scattering to a variable does not modify variables depending on + * the scatter target. This is ensured by the borrowing reference to the inputs + * in the FrozenFunction, which causes \c scatter to add a \c memcpy_async in + * the recording. + */ +TEST_LLVM(10_scatter) { + auto func = [](UInt32 x) { + scatter(x, UInt32(0), arange(x.size())); + // We have to return the input, since we do not perform input + // re-assignment in the \c FrozenFunction for the tests. 
+ return x; + }; + + FrozenFunction frozen(Backend, func); + + for (uint32_t i = 0; i < 4; i++) { + auto x = arange(10 + i); + + auto y = x + 1; + + x = frozen(x); + + jit_assert(all(eq(x, full(0, 10 + i)))); + jit_assert(all(eq(y, arange(10 + i) + 1))); + } +} diff --git a/tests/test.h b/tests/test.h index 88238553..22f1fd7a 100644 --- a/tests/test.h +++ b/tests/test.h @@ -3,8 +3,13 @@ #include #include #include +#include #include #include +#include +#include +#include +#include using namespace drjit; @@ -144,3 +149,243 @@ struct scoped_set_log_level { LogLevel m_stderr_level; }; +/// Operation, that can be applied to nested C++ traversable types. +/// The function receives a non-borrowing variable index from the \c JitArray it +/// is applied to and has to return an owned reference, transferring the +/// ownership back to the \c JitArray. +using apply_op = std::function; + +/// Traversable type, used to traverse frozen function inputs and outputs. +template struct traversable { + static constexpr bool value = false; + /// Apply the operation \c op to the C++ value v + static void apply(const apply_op &cb, T &v) { + (void) v; + (void) cb; + } +}; + +template +struct traversable().index())>> { + static constexpr bool value = true; + static void apply(const apply_op &cb, T &v) { v = T::steal(cb(v.index())); } +}; + +template +static void apply_tuple(const apply_op &cb, Tuple &t, + std::index_sequence) { + // Expands in left-to-right order + (traversable(t))>>::apply(cb, + std::get(t)), + ...); +} + +template struct traversable> { + static constexpr bool value = true; + static void apply(const apply_op &cb, std::tuple &v) { + apply_tuple(cb, v, std::index_sequence_for{}); + } +}; + +template +static void apply_arguments(const apply_op &cb, Args &&...args) { + (traversable>::apply(cb, args), ...); +} + +template static void make_opaque(Args &&...args) { + auto op = [](uint32_t index) { + int rv; + uint32_t new_index = jit_var_schedule_force(index, &rv); + return 
new_index; + }; + apply_arguments(op, args...); + jit_eval(); +} + +/// Constructable type, used to construct frozen function outputs +template struct constructable { + static constexpr bool value = false; + static T construct(const std::function &cb) { + static_assert(sizeof(T) == 0, "Could not construct type!"); + } +}; + +/// Construct any variable that has the \c borrow function +/// We have to use the \c borrow function instead of \c steal, since we would +/// split ownership between outputs if they appear twice in the type. +template +struct constructable> { + static constexpr bool value = true; + static T construct(const std::function &cb) { + return T::borrow(cb()); + } +}; + +template struct constructable> { + static constexpr bool value = true; + static std::tuple construct(const std::function &cb) { + // NOTE: initializer list to guarantee order of construct evaluation + return std::tuple{ constructable::construct(cb)... }; + } +}; + +/** + * \brief Minimal implementation of a FrozenFunction using the \c + * RecordThreadState + * + * This struct contains a single recording, that will be recorded when the + * function is first called. + * There are no checks that validate that the input layout hasn't changed. + * The registry is also not traversed, as this requires the registry pointers to + * inherit from \c nanobind::intrusive_base. 
+ */ +template class FrozenFunction { + + JitBackend m_backend; + Func m_func; + + uint32_t m_outputs = 0; + Recording *m_recording = nullptr; + +public: + FrozenFunction(JitBackend backend, Func func) + : m_backend(backend), m_func(func), m_outputs(0) { + jit_log(LogLevel::Debug, "FrozenFunction()"); + } + ~FrozenFunction() { + if (m_recording) + jit_freeze_destroy(m_recording); + m_recording = nullptr; + } + + void clear() { + if (m_recording) { + jit_freeze_destroy(m_recording); + m_recording = nullptr; + m_outputs = 0; + } + } + + template + auto record(std::vector &input_vector, Args &&...args) { + using Output = typename std::invoke_result::type; + Output output; + + jit_log(LogLevel::Debug, "record:"); + + jit_freeze_start(m_backend, input_vector.data(), input_vector.size()); + + // Record the function, including evaluation of all side effects on the + // inputs and outputs + { + output = m_func(std::forward(args)...); + + make_opaque(output, args...); + } + + // Traverse output for \c jit_freeze_stop + // NOTE: in the implementation in drjit, we would also schedule the + // input and re-assign it. Since we pass variables by value, modified + // inputs have to be passed to the output explicitly. 
+ std::vector output_vector; + { + auto op = [&output_vector](uint32_t index) { + // Take non borrowing reference to the index + output_vector.push_back(index); + + // Transfer ownership back to the \c JitArray + jit_var_inc_ref(index); + return index; + }; + + traversable::apply(op, output); + } + + m_recording = jit_freeze_stop(m_backend, output_vector.data(), + output_vector.size()); + m_outputs = (uint32_t) output_vector.size(); + + uint32_t counter = 0; + + // Construct output + { + output = + constructable::construct([&counter, &output_vector] { + return output_vector[counter++]; + }); + } + + // Output does not have to be released, as it is not borrowed, just + // referenced + + return output; + } + + template + auto replay(std::vector &input_vector, Args &&...args) { + using Output = typename std::invoke_result::type; + + jit_log(LogLevel::Debug, "dry run:"); + + int dryrun_success = + jit_freeze_dry_run(m_recording, input_vector.data()); + + if (!dryrun_success) { + clear(); + + return record(input_vector, args...); + } else { + std::vector output_vector(m_outputs, 0); + + jit_log(LogLevel::Debug, "replay:"); + // replay adds borrowing references to the \c output_vector + jit_freeze_replay(m_recording, input_vector.data(), + output_vector.data()); + + // Construct output + uint32_t counter = 0; + Output output = + constructable::construct([&counter, &output_vector] { + return output_vector[counter++]; + }); + + // Release the borrowed indices + for (uint32_t index : output_vector) + jit_var_dec_ref(index); + + return output; + } + } + + template auto operator()(Args &&...args) { + using Output = typename std::invoke_result::type; + + make_opaque(args...); + + // Make input opaque and add it to \c input_vector, borrowing it + std::vector input_vector; + auto op = [&input_vector](uint32_t index) { + // Borrow from the index and add it to the input_vector + jit_var_inc_ref(index); + input_vector.push_back(index); + + // Transfer ownership back to the \c 
JitArray + jit_var_inc_ref(index); + return index; + }; + apply_arguments(op, args...); + + Output output; + if (!m_recording) + output = record(input_vector, args...); + else + output = replay(input_vector, args...); + + // Release the borrowed indices + for (uint32_t i = 0; i < input_vector.size(); i++) + jit_var_dec_ref(input_vector[i]); + + return output; + } +}; diff --git a/tests/vcall.cpp b/tests/vcall.cpp index 18020e26..b1d6e342 100644 --- a/tests/vcall.cpp +++ b/tests/vcall.cpp @@ -1,3 +1,4 @@ +#include "drjit-core/jit.h" #include "test.h" #include "traits.h" @@ -1251,3 +1252,83 @@ TEST_BOTH(13_load_bool_data) { jit_registry_remove(&f2); } } + +/** + * This tests that it is possible to record vcalls in a frozen function. + * The registry has to stay constant between recording and replaying the frozen + * function. This is ensured in the python FrozenFunction. + * We do not test accessing a member field of any of the classes, as this would + * require registry traversal which requires nanobind for the \c + * nanobind::intrusive_base class. 
+ */ +TEST_BOTH(14_frozen_vcall) { + jit_set_flag(JitFlag::VCallOptimize, true); + jit_set_flag(JitFlag::SymbolicCalls, true); + + struct Base { + virtual UInt32 f(UInt32 x) = 0; + }; + + struct A1 : Base { + UInt32 f(UInt32 x) override { return x + 1; } + }; + + struct A2 : Base { + UInt32 f(UInt32 x) override { return x + 2; } + }; + + A1 a1; + A2 a2; + + const char *domain = "Base"; + const size_t n_callables = 2; + const size_t n_inputs = 1; + const size_t n_outputs = 1; + + uint32_t i1 = jit_registry_put(Backend, domain, &a1); + uint32_t i2 = jit_registry_put(Backend, domain, &a2); + jit_assert(i1 == 1 && i2 == 2); + + using BasePtr = Array; + + auto f_call = [](void *self, uint32_t *inputs, uint32_t *outputs) { + Base *base = (Base *) self; + UInt32 x = UInt32::borrow(inputs[0]); + UInt32 y = base->f(x); + jit_var_inc_ref(y.index()); + + outputs[0] = y.index(); + }; + + auto func = [n_inputs, n_outputs, &f_call, &domain](UInt32 self, UInt32 x) { + uint32_t vcall_inputs[n_inputs] = { x.index() }; + uint32_t vcall_outputs[n_outputs] = { 0 }; + + Mask mask = Mask::steal(jit_var_bool(Backend, true)); + + symbolic_call( + Backend, domain, false, self.index(), mask.index(), f_call, + vcall_inputs, vcall_outputs); + + auto result = UInt32::borrow(vcall_outputs[0]); + + return result; + }; + + FrozenFunction frozen(Backend, func); + + for (uint32_t i = 0; i < 3; i++) { + // The size of the base pointer is changed when replaying + BasePtr self = (arange(10 + i) + i) % 3; + UInt32 x = arange(10 + i) + i; + + auto result = frozen(self, x); + + auto reference = func(self, x); + + jit_assert(all(eq(result, reference))); + } + + jit_registry_remove(&a1); + jit_registry_remove(&a2); +}