From bb38a8676605248bddd7926027de856a4ab92cac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Christian=20D=C3=B6ring?= Date: Wed, 23 Oct 2024 08:19:38 +0200 Subject: [PATCH] Added tests for ThreadState freezing --- tests/CMakeLists.txt | 2 +- tests/record.cpp | 553 +++++++++++++++++++++++++++++++++++++++++++ tests/vcall.cpp | 116 +++++++++ 3 files changed, 670 insertions(+), 1 deletion(-) create mode 100644 tests/record.cpp diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 7c4573e5..117dfe68 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -1,6 +1,6 @@ enable_testing() -set(TEST_FILES basics.cpp mem.cpp loop.cpp vcall.cpp graphviz.cpp reductions.cpp array.cpp) +set(TEST_FILES basics.cpp mem.cpp loop.cpp vcall.cpp graphviz.cpp reductions.cpp array.cpp record.cpp) add_executable(test_half half.cpp) add_test(NAME test_half COMMAND test_half diff --git a/tests/record.cpp b/tests/record.cpp new file mode 100644 index 00000000..a08c72da --- /dev/null +++ b/tests/record.cpp @@ -0,0 +1,553 @@ +#include "drjit-core/array.h" +#include "drjit-core/jit.h" +#include "test.h" + +/** + * Basic addition test. + * Supplying a different input should replay the operation, with this input. + * In this case, the input at replay is incremented and should result in an + * incremented output. + */ +TEST_BOTH(01_basic_replay) { + Recording *recording; + + jit_log(LogLevel::Info, "Recording:"); + { + UInt32 i0(0, 1, 2, 3, 4, 5, 6, 7, 8, 9); + UInt32 r0(1, 2, 3, 4, 5, 6, 7, 8, 9, 10); + + uint32_t inputs[] = {i0.index()}; + + jit_freeze_start(Backend, inputs, 1); + + UInt32 o0 = i0 + 1; + o0.eval(); + + uint32_t outputs[] = {o0.index()}; + + recording = jit_freeze_stop(Backend, outputs, 1); + + jit_log(LogLevel::Info, "o0: %s", jit_var_str(outputs[0])); + jit_assert(jit_var_all(jit_var_eq(r0.index(), outputs[0]))); + } + + jit_log(LogLevel::Info, "Replay:"); + { + UInt32 i0(1, 2, 3, 4, 5, 6, 7, 8, 9, 10); + UInt32 r0(2, 3, 4, 5, 6, 7, 8, 9, 10, 11); + + uint32_t inputs[] = {i0.index()}; + uint32_t outputs[1]; + + jit_freeze_replay(recording, inputs, outputs); + + jit_log(LogLevel::Info, "o0: %s", jit_var_str(outputs[0])); + jit_assert(jit_var_all(jit_var_eq(r0.index(), outputs[0]))); + } + + jit_freeze_destroy(recording); +} + +/** + * This tests a single kernel with multiple unique inputs and outputs. + */ +TEST_BOTH(02_MIMO) { + Recording *recording; + + jit_log(LogLevel::Info, "Recording:"); + { + UInt32 i0(0, 1, 2, 3, 4, 5, 6, 7, 8, 9); + UInt32 i1(0, 1, 2, 3, 4, 5, 6, 7, 8, 9); + UInt32 r0(0, 2, 4, 6, 8, 10, 12, 14, 16, 18); + UInt32 r1(0, 1, 4, 9, 16, 25, 36, 49, 64, 81); + + uint32_t inputs[] = { + i0.index(), + i1.index(), + }; + + jit_freeze_start(Backend, inputs, 2); + + UInt32 o0 = i0 + i1; + UInt32 o1 = i0 * i1; + o0.schedule(); + o1.schedule(); + jit_eval(); + + uint32_t outputs[] = { + o0.index(), + o1.index(), + }; + + recording = jit_freeze_stop(Backend, outputs, 2); + + jit_log(LogLevel::Info, "o0: %s", jit_var_str(outputs[0])); + jit_log(LogLevel::Info, "o1: %s", jit_var_str(outputs[1])); + jit_assert(jit_var_all(jit_var_eq(r0.index(), outputs[0]))); + jit_assert(jit_var_all(jit_var_eq(r1.index(), outputs[1]))); + } + + jit_log(LogLevel::Info, "Replay:"); + { + UInt32 i0(1, 2, 3, 4, 5, 6, 7, 8, 9, 10); + UInt32 i1(1, 2, 3, 4, 5, 6, 7, 8, 9, 10); + UInt32 r0(2, 4, 6, 8, 10, 12, 14, 16, 18, 20); + UInt32 r1(1, 4, 9, 16, 25, 36, 49, 64, 81, 100); + + uint32_t inputs[] = { + i0.index(), + i1.index(), + }; + uint32_t outputs[2]; + + jit_freeze_replay(recording, inputs, outputs); + + jit_log(LogLevel::Info, "o0: %s", jit_var_str(outputs[0])); + jit_log(LogLevel::Info, "o1: %s", jit_var_str(outputs[1])); + jit_assert(jit_var_all(jit_var_eq(r0.index(), outputs[0]))); + jit_assert(jit_var_all(jit_var_eq(r1.index(), outputs[1]))); + } + + jit_freeze_destroy(recording); +} + +/** + * This tests if the recording feature works, when supplying the same variable + * twice in the input. In the final implementation this test-case should never + * occur, as variables would be deduplicated in beforehand. + */ +TEST_BOTH(03_deduplicating_input) { + Recording *recording; + + jit_log(LogLevel::Info, "Recording:"); + { + UInt32 i0(0, 1, 2, 3, 4, 5, 6, 7, 8, 9); + UInt32 r0(0, 2, 4, 6, 8, 10, 12, 14, 16, 18); + UInt32 r1(0, 1, 4, 9, 16, 25, 36, 49, 64, 81); + + uint32_t inputs[] = { + i0.index(), + i0.index(), + }; + + jit_freeze_start(Backend, inputs, 2); + + UInt32 o0 = i0 + i0; + UInt32 o1 = i0 * i0; + o0.schedule(); + o1.schedule(); + jit_eval(); + + uint32_t outputs[] = { + o0.index(), + o1.index(), + }; + + recording = jit_freeze_stop(Backend, outputs, 2); + + jit_log(LogLevel::Info, "o0: %s", jit_var_str(outputs[0])); + jit_log(LogLevel::Info, "o1: %s", jit_var_str(outputs[1])); + jit_assert(jit_var_all(jit_var_eq(r0.index(), outputs[0]))); + jit_assert(jit_var_all(jit_var_eq(r1.index(), outputs[1]))); + } + + jit_log(LogLevel::Info, "Replay:"); + { + UInt32 i0(1, 2, 3, 4, 5, 6, 7, 8, 9, 10); + UInt32 r0(2, 4, 6, 8, 10, 12, 14, 16, 18, 20); + UInt32 r1(1, 4, 9, 16, 25, 36, 49, 64, 81, 100); + + uint32_t inputs[] = { + i0.index(), + i0.index(), + }; + uint32_t outputs[2]; + + jit_freeze_replay(recording, inputs, outputs); + + jit_log(LogLevel::Info, "o0: %s", jit_var_str(outputs[0])); + jit_log(LogLevel::Info, "o1: %s", jit_var_str(outputs[1])); + jit_assert(jit_var_all(jit_var_eq(r0.index(), outputs[0]))); + jit_assert(jit_var_all(jit_var_eq(r1.index(), outputs[1]))); + } + + jit_freeze_destroy(recording); +} + +/** + * This tests if the recording feature works, when supplying the same variable + * twice in the output. In the final implementation this test-case should never + * occur, as variables would be deduplicated in beforehand. + */ +TEST_BOTH(04_deduplicating_output) { + Recording *recording; + jit_set_log_level_stderr(LogLevel::Debug); + + jit_log(LogLevel::Info, "Recording:"); + { + UInt32 i0(0, 1, 2, 3, 4, 5, 6, 7, 8, 9); + UInt32 i1(0, 1, 2, 3, 4, 5, 6, 7, 8, 9); + UInt32 r0(0, 2, 4, 6, 8, 10, 12, 14, 16, 18); + UInt32 r1(0, 2, 4, 6, 8, 10, 12, 14, 16, 18); + + uint32_t inputs[] = { + i0.index(), + i1.index(), + }; + + jit_freeze_start(Backend, inputs, 2); + + UInt32 o0 = i0 + i1; + UInt32 o1 = i0 + i1; + o0.schedule(); + o1.schedule(); + jit_eval(); + + uint32_t outputs[] = { + o0.index(), + o1.index(), + }; + + recording = jit_freeze_stop(Backend, outputs, 2); + + jit_log(LogLevel::Info, "o0: %s", jit_var_str(outputs[0])); + jit_log(LogLevel::Info, "o1: %s", jit_var_str(outputs[1])); + jit_assert(jit_var_all(jit_var_eq(r0.index(), outputs[0]))); + jit_assert(jit_var_all(jit_var_eq(r1.index(), outputs[1]))); + } + + jit_log(LogLevel::Info, "Replay:"); + { + UInt32 i0(1, 2, 3, 4, 5, 6, 7, 8, 9, 10); + UInt32 i1(1, 2, 3, 4, 5, 6, 7, 8, 9, 10); + UInt32 r0(2, 4, 6, 8, 10, 12, 14, 16, 18, 20); + UInt32 r1(2, 4, 6, 8, 10, 12, 14, 16, 18, 20); + + uint32_t inputs[] = { + i0.index(), + i1.index(), + }; + uint32_t outputs[2]; + + jit_freeze_replay(recording, inputs, outputs); + + jit_log(LogLevel::Info, "o0: %s", jit_var_str(outputs[0])); + jit_log(LogLevel::Info, "o1: %s", jit_var_str(outputs[1])); + jit_assert(jit_var_all(jit_var_eq(r0.index(), outputs[0]))); + jit_assert(jit_var_all(jit_var_eq(r1.index(), outputs[1]))); + } + + jit_freeze_destroy(recording); +} + +/** + * This tests, weather it is possible to record multiple kernels in sequence. + * The input of the second kernel relies on the execution of the first. + * On LLVM, the correctness of barrier operations is therefore tested. + */ +TEST_BOTH(05_sequential_kernels) { + Recording *recording; + + jit_log(LogLevel::Info, "Recording:"); + { + UInt32 i0(0, 1, 2, 3, 4, 5, 6, 7, 8, 9); + UInt32 r0(2, 3, 4, 5, 6, 7, 8, 9, 10, 11); + + uint32_t inputs[] = { + i0.index(), + }; + + jit_freeze_start(Backend, inputs, 1); + + UInt32 tmp = i0 + 1; + tmp.schedule(); + jit_eval(); + UInt32 o0 = tmp + 1; + o0.schedule(); + jit_eval(); + + uint32_t outputs[] = { + o0.index(), + }; + + recording = jit_freeze_stop(Backend, outputs, 1); + + jit_log(LogLevel::Info, "o0: %s", jit_var_str(outputs[0])); + jit_assert(jit_var_all(jit_var_eq(r0.index(), outputs[0]))); + } + + jit_log(LogLevel::Info, "Replay:"); + { + UInt32 i0(0, 1, 2, 3, 4, 5, 6, 7, 8, 9); + UInt32 r0(2, 3, 4, 5, 6, 7, 8, 9, 10, 11); + + uint32_t inputs[] = { + i0.index(), + }; + uint32_t outputs[1]; + + jit_freeze_replay(recording, inputs, outputs); + + jit_log(LogLevel::Info, "o0: %s", jit_var_str(outputs[0])); + jit_assert(jit_var_all(jit_var_eq(r0.index(), outputs[0]))); + } + + jit_freeze_destroy(recording); +} + +/** + * This tests, weather it is possible to record multiple kernels in parallel. + * The variables of the kernels are of different size, therefore two kernels are + * generated. At replay these can be executed in parallel (LLVM) or sequence + * (CUDA). + */ +TEST_BOTH(06_parallel_kernels) { + Recording *recording; + + jit_log(LogLevel::Info, "Recording:"); + { + UInt32 i0(0, 1, 2, 3, 4, 5, 6, 7, 8, 9); + UInt32 i1(0, 1, 2, 3, 4, 5); + UInt32 r0(1, 2, 3, 4, 5, 6, 7, 8, 9, 10); + UInt32 r1(1, 2, 3, 4, 5, 6); + + uint32_t inputs[] = { + i0.index(), + i1.index(), + }; + + jit_freeze_start(Backend, inputs, 2); + + UInt32 o0 = i0 + 1; + UInt32 o1 = i1 + 1; + o0.schedule(); + o1.schedule(); + jit_eval(); + + uint32_t outputs[] = { + o0.index(), + o1.index(), + }; + + recording = jit_freeze_stop(Backend, outputs, 2); + + jit_log(LogLevel::Info, "o0: %s", jit_var_str(outputs[0])); + jit_log(LogLevel::Info, "o1: %s", jit_var_str(outputs[1])); + jit_assert(jit_var_all(jit_var_eq(r0.index(), outputs[0]))); + jit_assert(jit_var_all(jit_var_eq(r1.index(), outputs[1]))); + } + + jit_log(LogLevel::Info, "Replay:"); + { + UInt32 i0(1, 2, 3, 4, 5, 6, 7, 8, 9, 10); + UInt32 i1(1, 2, 3, 4, 5, 6); + UInt32 r0(2, 3, 4, 5, 6, 7, 8, 9, 10, 11); + UInt32 r1(2, 3, 4, 5, 6, 7); + + uint32_t inputs[] = { + i0.index(), + i1.index(), + }; + uint32_t outputs[2]; + + jit_freeze_replay(recording, inputs, outputs); + + jit_log(LogLevel::Info, "o0: %s", jit_var_str(outputs[0])); + jit_log(LogLevel::Info, "o1: %s", jit_var_str(outputs[1])); + jit_assert(jit_var_all(jit_var_eq(r0.index(), outputs[0]))); + jit_assert(jit_var_all(jit_var_eq(r1.index(), outputs[1]))); + } + + jit_freeze_destroy(recording); +} + +/** + * This tests recording and replay of a horizontal reduction operation (hsum). + */ +TEST_BOTH(07_reduce_hsum) { + Recording *recording; + + jit_log(LogLevel::Info, "Recording:"); + { + UInt32 i0(0, 1, 2, 3, 4, 5, 6, 7, 8, 9); + UInt32 r0 = opaque(45, 1); + + uint32_t inputs[] = { + i0.index(), + }; + + jit_freeze_start(Backend, inputs, 1); + + UInt32 o0 = hsum(i0); + o0.schedule(); + jit_eval(); + + uint32_t outputs[] = { + o0.index(), + }; + + recording = jit_freeze_stop(Backend, outputs, 1); + + jit_log(LogLevel::Info, "o0: %s", jit_var_str(outputs[0])); + jit_assert(jit_var_all(jit_var_eq(r0.index(), outputs[0]))); + } + + jit_log(LogLevel::Info, "Replay:"); + { + UInt32 i0(1, 2, 3, 4, 5, 6, 7, 8, 9, 10); + UInt32 r0 = opaque(55, 1); + + uint32_t inputs[] = { + i0.index(), + }; + uint32_t outputs[1]; + + jit_freeze_replay(recording, inputs, outputs); + + jit_log(LogLevel::Info, "o0: %s", jit_var_str(outputs[0])); + jit_assert(jit_var_all(jit_var_eq(r0.index(), outputs[0]))); + } + + jit_freeze_destroy(recording); +} + +/** + * Tests recording of a prefix sum operation with different inputs at replay. + */ +TEST_BOTH(08_prefix_sum) { + Recording *recording; + + jit_log(LogLevel::Info, "Recording:"); + { + UInt32 i0(0, 1, 2, 3, 4, 5, 6, 7, 8, 9); + UInt32 r0(0, 1, 3, 6, 10, 15, 21, 28, 36, 45); + + uint32_t inputs[] = { + i0.index(), + }; + + jit_freeze_start(Backend, inputs, 1); + + uint32_t o0 = jit_var_block_prefix_reduce( + ReduceOp::Add, i0.index(), jit_var_size(i0.index()), 0, 0); + jit_var_schedule(o0); + jit_eval(); + + uint32_t outputs[] = { + o0, + }; + + recording = jit_freeze_stop(Backend, outputs, 1); + + jit_log(LogLevel::Info, "o0: %s", jit_var_str(outputs[0])); + jit_assert(jit_var_all(jit_var_eq(r0.index(), outputs[0]))); + } + + jit_log(LogLevel::Info, "Replay:"); + { + UInt32 i0(1, 2, 3, 4, 5, 6, 7, 8, 9, 10); + UInt32 r0(1, 3, 6, 10, 15, 21, 28, 36, 45, 55); + + uint32_t inputs[] = { + i0.index(), + }; + uint32_t outputs[1]; + + jit_freeze_replay(recording, inputs, outputs); + + jit_log(LogLevel::Info, "o0: %s", jit_var_str(outputs[0])); + jit_assert(jit_var_all(jit_var_eq(r0.index(), outputs[0]))); + } + + jit_freeze_destroy(recording); +} + +/** + * Basic addition test. + * Supplying a different input should replay the operation, with this input. + * In this case, the input at replay is incremented and should result in an + * incremented output. + */ +TEST_BOTH(9_resized_input) { + Recording *recording; + + jit_log(LogLevel::Info, "Recording:"); + { + UInt32 i0(0, 1, 2); + UInt32 r0(1, 2, 3); + + uint32_t inputs[] = {i0.index()}; + + jit_freeze_start(Backend, inputs, 1); + + UInt32 o0 = i0 + 1; + o0.eval(); + + uint32_t outputs[] = {o0.index()}; + + recording = jit_freeze_stop(Backend, outputs, 1); + + jit_log(LogLevel::Info, "o0: %s", jit_var_str(outputs[0])); + jit_assert(jit_var_all(jit_var_eq(r0.index(), outputs[0]))); + } + + jit_log(LogLevel::Info, "Replay:"); + { + UInt32 i0(1, 2, 3, 4); + UInt32 r0(2, 3, 4, 5); + + uint32_t inputs[] = {i0.index()}; + uint32_t outputs[1]; + + jit_freeze_replay(recording, inputs, outputs); + + jit_log(LogLevel::Info, "o0: %s", jit_var_str(outputs[0])); + jit_assert(jit_var_all(jit_var_eq(r0.index(), outputs[0]))); + } + + jit_freeze_destroy(recording); +} + +TEST_BOTH(10_input_passthrough) { + Recording *recording; + + jit_log(LogLevel::Info, "Recording:"); + { + UInt32 i0(0, 1, 2, 3, 4, 5, 6, 7, 8, 9); + UInt32 r0(1, 2, 3, 4, 5, 6, 7, 8, 9, 10); + + uint32_t inputs[] = {i0.index()}; + + jit_freeze_start(Backend, inputs, 1); + + UInt32 o0 = i0 + 1; + o0.eval(); + + uint32_t outputs[] = {o0.index(), i0.index()}; + + recording = jit_freeze_stop(Backend, outputs, 2); + + jit_log(LogLevel::Info, "o0: %s", jit_var_str(outputs[0])); + jit_log(LogLevel::Info, "o1: %s", jit_var_str(outputs[1])); + jit_assert(jit_var_all(jit_var_eq(r0.index(), outputs[0]))); + jit_assert(jit_var_all(jit_var_eq(i0.index(), outputs[1]))); + } + + jit_log(LogLevel::Info, "Replay:"); + { + UInt32 i0(1, 2, 3, 4, 5, 6, 7, 8, 9, 10); + UInt32 r0(2, 3, 4, 5, 6, 7, 8, 9, 10, 11); + + uint32_t inputs[] = {i0.index()}; + uint32_t outputs[2]; + + jit_freeze_replay(recording, inputs, outputs); + + jit_log(LogLevel::Info, "o0: %s", jit_var_str(outputs[0])); + jit_log(LogLevel::Info, "o1: %s", jit_var_str(outputs[1])); + jit_assert(jit_var_all(jit_var_eq(r0.index(), outputs[0]))); + jit_assert(jit_var_all(jit_var_eq(i0.index(), outputs[1]))); + } + + jit_freeze_destroy(recording); +} diff --git a/tests/vcall.cpp b/tests/vcall.cpp index 18020e26..4582bbd2 100644 --- a/tests/vcall.cpp +++ b/tests/vcall.cpp @@ -1,3 +1,4 @@ +#include "drjit-core/jit.h" #include "test.h" #include "traits.h" @@ -1251,3 +1252,118 @@ TEST_BOTH(13_load_bool_data) { jit_registry_remove(&f2); } } + +TEST_BOTH(14_kernel_record) { + jit_set_flag(JitFlag::VCallOptimize, true); + jit_set_flag(JitFlag::SymbolicCalls, true); + + Recording *recording; + + struct Base { + virtual UInt32 f(UInt32 x) = 0; + }; + + struct A1 : Base { + UInt32 f(UInt32 x) override { return x + 1; } + }; + + struct A2 : Base { + UInt32 f(UInt32 x) override { return x + 2; } + }; + + A1 a1; + A2 a2; + + const char *domain = "Base"; + const size_t n_callables = 2; + const size_t n_inputs = 1; + const size_t n_outputs = 1; + + uint32_t i1 = jit_registry_put(Backend, domain, &a1); + uint32_t i2 = jit_registry_put(Backend, domain, &a2); + jit_assert(i1 == 1 && i2 == 2); + + using BasePtr = Array; + + auto f_call = [](void *self, uint32_t *inputs, uint32_t *outputs) { + Base *base = (Base *) self; + UInt32 i0 = UInt32::borrow(inputs[0]); + UInt32 o0 = base->f(i0); + jit_var_inc_ref(o0.index()); + + outputs[0] = o0.index(); + }; + + { + BasePtr self = arange(10) % 3; + self.eval(); + UInt32 i0 = arange(10); + i0.eval(); + UInt32 r0(0, 2, 4, 0, 5, 7, 0, 8, 10, 0); + + jit_log(LogLevel::Info, "Recording:"); + + uint32_t inputs[] = { + self.index(), + i0.index(), + }; + + jit_freeze_start(Backend, inputs, 2); + + uint32_t outputs[1]; + UInt32 o0; + + { + uint32_t vcall_inputs[n_inputs] = { i0.index() }; + uint32_t vcall_outputs[n_outputs] = { 0 }; + + Mask mask = Mask::steal(jit_var_bool(Backend, true)); + + jit_log(LogLevel::Info, "self: %u", self.index()); + jit_log(LogLevel::Info, "mask: %u", mask.index()); + jit_log(LogLevel::Info, "i0: %u", i0.index()); + symbolic_call( + Backend, domain, false, self.index(), mask.index(), f_call, + vcall_inputs, vcall_outputs); + + o0 = UInt32::borrow(vcall_outputs[0]); + o0.eval(); + + jit_log(LogLevel::Info, "o0: %u", o0.index()); + + outputs[0] = o0.index(); + } + + recording = jit_freeze_stop(Backend, outputs, 1); + + jit_log(LogLevel::Info, "o0: %s", jit_var_str(outputs[0])); + jit_log(LogLevel::Info, "r0: %s", jit_var_str(r0.index())); + + jit_assert(jit_var_all(jit_var_eq(r0.index(), outputs[0]))); + } + + jit_log(LogLevel::Info, "Replay:"); + { + BasePtr self = (arange(10) + 1) % 3; + self.eval(); + UInt32 i0 = arange(10); + i0.eval(); + UInt32 r0(1, 3, 0, 4, 6, 0, 7, 9, 0, 10); + + uint32_t inputs[] = { + self.index(), + i0.index(), + }; + uint32_t outputs[1]; + + jit_freeze_replay(recording, inputs, outputs); + + jit_log(LogLevel::Info, "o0: %s", jit_var_str(outputs[0])); + jit_assert(jit_var_all(jit_var_eq(r0.index(), outputs[0]))); + } + + jit_freeze_destroy(recording); + + jit_registry_remove(&a1); + jit_registry_remove(&a2); +}