Skip to content

Commit

Permalink
Improved tests to use all and eq and added scatter test
Browse files Browse the repository at this point in the history
  • Loading branch information
DoeringChristian committed Nov 21, 2024
1 parent 26dde80 commit 21a8b62
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 39 deletions.
67 changes: 44 additions & 23 deletions tests/record.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ TEST_BOTH(01_basic_replay) {

auto reference = func(input);

jit_assert(jit_var_all(jit_var_eq(result.index(), reference.index())));
jit_assert(all(eq(result, reference)));
}
}

Expand All @@ -44,10 +44,8 @@ TEST_BOTH(02_MIMO) {

auto reference = func(x, y);

jit_assert(jit_var_all(jit_var_eq(std::get<0>(result).index(),
std::get<0>(reference).index())));
jit_assert(jit_var_all(jit_var_eq(std::get<1>(result).index(),
std::get<1>(reference).index())));
jit_assert(all(eq(std::get<0>(result), std::get<0>(reference))));
jit_assert(all(eq(std::get<1>(result), std::get<1>(reference))));
}
}

Expand All @@ -69,10 +67,8 @@ TEST_BOTH(03_deduplicating_input) {

auto reference = func(x, x);

jit_assert(jit_var_all(jit_var_eq(std::get<0>(result).index(),
std::get<0>(reference).index())));
jit_assert(jit_var_all(jit_var_eq(std::get<1>(result).index(),
std::get<1>(reference).index())));
jit_assert(all(eq(std::get<0>(result), std::get<0>(reference))));
jit_assert(all(eq(std::get<1>(result), std::get<1>(reference))));
}
}

Expand All @@ -98,7 +94,7 @@ TEST_BOTH(04_sequential_kernels) {

auto reference = func(x);

jit_assert(jit_var_all(jit_var_eq(result.index(), reference.index())));
jit_assert(all(eq(result, reference)));
}
}

Expand All @@ -123,10 +119,8 @@ TEST_BOTH(05_parallel_kernels) {

auto reference = func(x, y);

jit_assert(jit_var_all(jit_var_eq(std::get<0>(result).index(),
std::get<0>(reference).index())));
jit_assert(jit_var_all(jit_var_eq(std::get<1>(result).index(),
std::get<1>(reference).index())));
jit_assert(all(eq(std::get<0>(result), std::get<0>(reference))));
jit_assert(all(eq(std::get<1>(result), std::get<1>(reference))));
}
}

Expand All @@ -147,7 +141,7 @@ TEST_BOTH(06_reduce_hsum) {

auto reference = func(x);

jit_assert(jit_var_all(jit_var_eq(result.index(), reference.index())));
jit_assert(all(eq(result, reference)));
}
}

Expand All @@ -168,7 +162,7 @@ TEST_BOTH(07_prefix_sum) {

auto reference = func(x);

jit_assert(jit_var_all(jit_var_eq(result.index(), reference.index())));
jit_assert(all(eq(result, reference)));
}
}

Expand All @@ -194,10 +188,8 @@ TEST_BOTH(08_input_passthrough) {

auto reference = func(x);

jit_assert(jit_var_all(jit_var_eq(std::get<0>(result).index(),
std::get<0>(reference).index())));
jit_assert(jit_var_all(jit_var_eq(std::get<1>(result).index(),
std::get<1>(reference).index())));
jit_assert(all(eq(std::get<0>(result), std::get<0>(reference))));
jit_assert(all(eq(std::get<1>(result), std::get<1>(reference))));
}
}

Expand All @@ -215,17 +207,46 @@ TEST_LLVM(09_dry_run) {
FrozenFunction frozen(Backend, func);

for (uint32_t i = 0; i < 4; i++) {
auto src = full<UInt32>(1, 10+i);
auto src = full<UInt32>(1, 10 + i);
src.make_opaque();

auto result = full<UInt32>(0, (i + 2));
result.make_opaque();
result = frozen(result, src);

auto reference = full<UInt32>(0, (i + 2));
reference.make_opaque();
reference = frozen(reference, src);

jit_assert(jit_var_all(jit_var_eq(result.index(), reference.index())));
jit_assert(all(eq(result, reference)));
}
}

/**
* Tests that scattering to a variable does not modify variables depending on
* the scatter target. This is ensured by the borrowing reference to the inputs
* in the FrozenFunction, which causes \c scatter to add a \c memcpy_async in
* the recording.
*/
TEST_LLVM(10_scatter) {
auto func = [](UInt32 x) {
scatter(x, UInt32(0), arange<UInt32>(x.size()));
// We have to return the input, since we do not perform input
// re-assignment in the \c FrozenFunction for the tests.
return x;
};

FrozenFunction frozen(Backend, func);

for (uint32_t i = 0; i < 4; i++) {
auto x = arange<UInt32>(10 + i);
x.make_opaque();

auto y = x + 1;

x = frozen(x);

jit_assert(all(eq(x, full<UInt32>(0, 10 + i))));
jit_assert(all(eq(y, arange<UInt32>(10 + i) + 1)));
}
}
20 changes: 4 additions & 16 deletions tests/test.h
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,10 @@ template <typename Func> class FrozenFunction {
make_opaque(output, args...);
}

// Traverse output and (changed) input for \c jit_freeze_stop
// Traverse output for \c jit_freeze_stop
// NOTE: in the implementation in drjit, we would also schedule the
// input and re-assign it. Since we pass variables by value, modified
// inputs have to be passed to the output explicitly.
std::vector<uint32_t> output_vector;
{
auto op = [&output_vector](uint32_t index) {
Expand All @@ -295,7 +298,6 @@ template <typename Func> class FrozenFunction {
};

traversable<Output>::apply(op, output);
apply_arguments(op, args...);
}

m_recording = jit_freeze_stop(m_backend, output_vector.data(),
Expand All @@ -312,20 +314,6 @@ template <typename Func> class FrozenFunction {
});
}

// Re-assign input
{
auto op = [&output_vector, &counter](uint32_t index) {
// Borrow from the output_vector
uint32_t new_index = output_vector[counter++];
jit_var_inc_ref(new_index);

// Release old index
jit_var_dec_ref(index);
return new_index;
};
apply_arguments(op, args...);
}

// Output does not have to be released, as it is not borrowed, just
// referenced

Expand Down

0 comments on commit 21a8b62

Please sign in to comment.