From 311dc1f82c929ff2b3dc73c27e7a762acea848e0 Mon Sep 17 00:00:00 2001 From: Nicolas Roussel Date: Mon, 21 Nov 2022 08:22:45 +0100 Subject: [PATCH 1/4] Add VCallBranch JIT flag: virtual function calls are no longer dispatched by indirect function calls (CUDA) or direct callables (Optix) but converted into switch-like statements --- include/drjit-core/jit.h | 44 +++--- src/eval_cuda.cpp | 51 +++++-- src/optix_api.cpp | 28 ++-- src/vcall.cpp | 310 ++++++++++++++++++++++++--------------- tests/vcall.cpp | 15 +- 5 files changed, 274 insertions(+), 174 deletions(-) diff --git a/include/drjit-core/jit.h b/include/drjit-core/jit.h index 1aef7dd31..0dd76025e 100644 --- a/include/drjit-core/jit.h +++ b/include/drjit-core/jit.h @@ -1304,42 +1304,49 @@ enum class JitFlag : uint32_t { /// Record virtual function calls instead of splitting them into many small kernel launches VCallRecord = 16, + /** + * \brief Use branches instead of direct callables (in OptiX) or indirect + * function calls (in CUDA) for virtual function calls. + */ + VCallBranch = 32, + /// De-duplicate virtual function calls that produce the same code - VCallDeduplicate = 32, + VCallDeduplicate = 64, /// Enable constant propagation and elide unnecessary function arguments - VCallOptimize = 64, + VCallOptimize = 128, /** * \brief Inline calls if there is only a single instance? (off by default, * inlining can make kernels so large that they actually run slower in * CUDA/OptiX). */ - VCallInline = 128, + VCallInline = 256, /// Force execution through OptiX even if a kernel doesn't use ray tracing - ForceOptiX = 256, + ForceOptiX = 512, /// Temporarily postpone evaluation of statements with side effects - Recording = 512, + Recording = 1024, /// Print the intermediate representation of generated programs - PrintIR = 1024, + PrintIR = 2048, /// Enable writing of the kernel history - KernelHistory = 2048, + KernelHistory = 4096, /* Force synchronization after every kernel launch. This is useful to isolate crashes to a specific kernel, and to benchmark kernel runtime along with the KernelHistory feature. */ - LaunchBlocking = 4096, + LaunchBlocking = 8192, /// Exploit literal constants during AD (used in the Dr.Jit parent project) - ADOptimize = 8192, + ADOptimize = 16384, /// Default flags Default = (uint32_t) ConstProp | (uint32_t) ValueNumbering | (uint32_t) LoopRecord | (uint32_t) LoopOptimize | + //(uint32_t) VCallBranch | (uint32_t) VCallRecord | (uint32_t) VCallDeduplicate | (uint32_t) VCallOptimize | (uint32_t) ADOptimize }; @@ -1350,15 +1357,16 @@ enum JitFlag { JitFlagLoopRecord = 4, JitFlagLoopOptimize = 8, JitFlagVCallRecord = 16, - JitFlagVCallDeduplicate = 32, - JitFlagVCallOptimize = 64, - JitFlagVCallInline = 128, - JitFlagForceOptiX = 256, - JitFlagRecording = 512, - JitFlagPrintIR = 1024, - JitFlagKernelHistory = 2048, - JitFlagLaunchBlocking = 4096, - JitFlagADOptimize = 8192 + JitFlagVCallBranch = 32, + JitFlagVCallDeduplicate = 64, + JitFlagVCallOptimize = 128, + JitFlagVCallInline = 256, + JitFlagForceOptiX = 512, + JitFlagRecording = 1024, + JitFlagPrintIR = 2048, + JitFlagKernelHistory = 4096, + JitFlagLaunchBlocking = 8192, + JitFlagADOptimize = 16384, }; #endif diff --git a/src/eval_cuda.cpp b/src/eval_cuda.cpp index 9d4741bc9..d4af5553b 100644 --- a/src/eval_cuda.cpp +++ b/src/eval_cuda.cpp @@ -203,30 +203,48 @@ void jitc_assemble_cuda(ThreadState *ts, ScheduledGroup group, it.second.callable_index = ctr++; } - if (callable_count > 0 && !uses_optix) { + if (callable_count > 0) { size_t insertion_point = (char *) strstr(buffer.get(), ".address_size 64\n\n") - buffer.get() + 18, insertion_start = buffer.size(); - buffer.fmt(".extern .global .u64 callables[%u];\n\n", - callable_count_unique); + if (jit_flag(JitFlag::VCallBranch)) { + // Copy signatures to very beginning + for (const auto &it : globals_map) { + if (!it.first.callable) + continue; - jitc_insert_code_at(insertion_point, insertion_start); + const char* func_definition = globals.get() + it.second.start; + const char* signature_begin = strstr(func_definition, ".func"); + const char* signature_end = strstr(func_definition, "{"); - buffer.fmt("\n.visible .global .align 8 .u64 callables[%u] = {\n", - callable_count_unique); - for (auto const &it : globals_map) { - if (!it.first.callable) - continue; + buffer.put(".visible "); + buffer.put(signature_begin, + signature_end - 1 - signature_begin); + buffer.put(";\n"); + } + buffer.fmt("\n"); + jitc_insert_code_at(insertion_point, insertion_start); + } else if (!uses_optix) { + buffer.fmt(".extern .global .u64 callables[%u];\n\n", + callable_count_unique); + jitc_insert_code_at(insertion_point, insertion_start); + + buffer.fmt("\n.visible .global .align 8 .u64 callables[%u] = {\n", + callable_count_unique); + for (auto const &it : globals_map) { + if (!it.first.callable) + continue; + + buffer.fmt(" func_%016llx%016llx%s\n", + (unsigned long long) it.first.hash.high64, + (unsigned long long) it.first.hash.low64, + it.second.callable_index + 1 < callable_count_unique ? "," : ""); + } - buffer.fmt(" func_%016llx%016llx%s\n", - (unsigned long long) it.first.hash.high64, - (unsigned long long) it.first.hash.low64, - it.second.callable_index + 1 < callable_count_unique ? "," : ""); + buffer.put("};\n\n"); } - - buffer.put("};\n\n"); } jitc_vcall_upload(ts); @@ -245,8 +263,9 @@ void jitc_assemble_cuda_func(const char *name, uint32_t inst_id, buffer.put(".visible .func"); if (out_size) buffer.fmt(" (.param .align %u .b8 result[%u])", out_align, out_size); + bool uses_direct_callables = uses_optix && !(jit_flag(JitFlag::VCallBranch)); buffer.fmt(" %s^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^(", - uses_optix ? "__direct_callable__" : "func_"); + uses_direct_callables ? "__direct_callable__" : "func_"); if (use_self) { buffer.put(".reg .u32 self"); diff --git a/src/optix_api.cpp b/src/optix_api.cpp index 830ee6ec8..c4e90c542 100644 --- a/src/optix_api.cpp +++ b/src/optix_api.cpp @@ -662,19 +662,21 @@ bool jitc_optix_compile(ThreadState *ts, const char *buf, size_t buf_size, pgd[0].raygen.module = kernel.optix.mod; pgd[0].raygen.entryFunctionName = strdup(kern_name); - for (auto const &it : globals_map) { - if (!it.first.callable) - continue; - - char *name = (char *) malloc_check(52); - snprintf(name, 52, "__direct_callable__%016llx%016llx", - (unsigned long long) it.first.hash.high64, - (unsigned long long) it.first.hash.low64); - - uint32_t index = 1 + it.second.callable_index; - pgd[index].kind = OPTIX_PROGRAM_GROUP_KIND_CALLABLES; - pgd[index].callables.moduleDC = kernel.optix.mod; - pgd[index].callables.entryFunctionNameDC = name; + if (!jit_flag(JitFlag::VCallBranch)) { + for (auto const &it : globals_map) { + if (!it.first.callable) + continue; + + char *name = (char *) malloc_check(52); + snprintf(name, 52, "__direct_callable__%016llx%016llx", + (unsigned long long) it.first.hash.high64, + (unsigned long long) it.first.hash.low64); + + uint32_t index = 1 + it.second.callable_index; + pgd[index].kind = OPTIX_PROGRAM_GROUP_KIND_CALLABLES; + pgd[index].callables.moduleDC = kernel.optix.mod; + pgd[index].callables.entryFunctionNameDC = name; + } } kernel.optix.pg = new OptixProgramGroup[n_programs]; diff --git a/src/vcall.cpp b/src/vcall.cpp index dfeb946f6..84e81c8c1 100644 --- a/src/vcall.cpp +++ b/src/vcall.cpp @@ -69,6 +69,8 @@ struct VCall { /// Does this vcall need self as argument bool use_self = false; + CallablesSet callables_set; + ~VCall() { for (uint32_t index : out_nested) jitc_var_dec_ref(index); @@ -773,7 +775,7 @@ static void jitc_var_vcall_assemble(VCall *vcall, ThreadState *ts = thread_state(vcall->backend); - CallablesSet callables_set; + vcall->callables_set.clear(); for (uint32_t i = 0; i < vcall->n_inst; ++i) { XXH128_hash_t hash = jitc_assemble_func( ts, vcall->name, i, in_size, in_align, out_size, out_align, @@ -783,7 +785,7 @@ static void jitc_var_vcall_assemble(VCall *vcall, vcall->side_effects.data() + vcall->checkpoints[i], vcall->use_self); vcall->inst_hash[i] = hash; - callables_set.insert(hash); + vcall->callables_set.insert(hash); } size_t se_count = vcall->side_effects.size(); @@ -819,7 +821,7 @@ static void jitc_var_vcall_assemble(VCall *vcall, InfoSym, "jit_var_vcall_assemble(): indirect call (\"%s\") to %zu/%u instances, " "passing %u/%u inputs (%u/%u bytes), %u/%u outputs (%u/%u bytes), %zu side effects", - vcall->name, callables_set.size(), vcall->n_inst, n_in_active, + vcall->name, vcall->callables_set.size(), vcall->n_inst, n_in_active, vcall->in_count_initial, in_size, vcall->in_size_initial, n_out_active, n_out, out_size, vcall->out_size_initial, se_count); @@ -857,10 +859,13 @@ static void jitc_var_vcall_assemble_cuda(VCall *vcall, uint32_t vcall_reg, // 3. Turn callable ID into a function pointer // ===================================================== - if (!uses_optix) - buffer.fmt(" ld.global.u64 %%rd2, callables[%%r3];\n"); - else - buffer.put(" call (%rd2), _optix_call_direct_callable, (%r3);\n"); + bool branch_vcall = jit_flag(JitFlag::VCallBranch); + if (!branch_vcall) { + if (!uses_optix) + buffer.fmt(" ld.global.u64 %%rd2, callables[%%r3];\n"); + else + buffer.put(" call (%rd2), _optix_call_direct_callable, (%r3);\n"); + } // ===================================================== // 4. Obtain pointer to supplemental call data @@ -895,131 +900,169 @@ static void jitc_var_vcall_assemble_cuda(VCall *vcall, uint32_t vcall_reg, v2->reg_index, v2->reg_index); } - buffer.put(" {\n"); - - // Call prototype - buffer.put(" proto: .callprototype"); - if (out_size) - buffer.fmt(" (.param .align %u .b8 result[%u])", out_align, out_size); - buffer.put(" _("); - if (vcall->use_self) { - buffer.put(".reg .u32 self"); - if (data_reg || in_size) - buffer.put(", "); - } - if (data_reg) { - buffer.put(".reg .u64 data"); - if (in_size) - buffer.put(", "); + // Switch statement: branch to call + if (branch_vcall) { + for (size_t i = 0; i < vcall->callables_set.size(); ++i) { + buffer.fmt(" setp.eq.u32 %%p0, %%r3, %u;\n", (uint32_t) i); + buffer.fmt(" @%%p0 bra l_%u_%u;\n", vcall->id, (uint32_t) i); + } + buffer.put("\n"); } - if (in_size) - buffer.fmt(".param .align %u .b8 params[%u]", in_align, in_size); - buffer.put(");\n"); - - // Input/output parameter arrays - if (out_size) - buffer.fmt(" .param .align %u .b8 out[%u];\n", out_align, out_size); - if (in_size) - buffer.fmt(" .param .align %u .b8 in[%u];\n", in_align, in_size); - - // ===================================================== - // 5.1. Pass the input arguments - // ===================================================== - uint32_t offset = 0; - for (uint32_t in : vcall->in) { - auto it = state.variables.find(in); - if (it == state.variables.end()) - continue; - const Variable *v2 = &it->second; - uint32_t size = type_size[v2->type]; - - const char *tname = type_name_ptx[v2->type], - *prefix = type_prefix[v2->type]; - - // Special handling for predicates (pass via u8) - if ((VarType) v2->type == VarType::Bool) { - tname = "u8"; - prefix = "%w"; + uint32_t callable_id = 0; + for (XXH128_hash_t callable_hash: vcall->callables_set) { + if (!branch_vcall) { + // Call prototype + buffer.put(" {\n"); + buffer.put(" proto: .callprototype"); + if (out_size) + buffer.fmt(" (.param .align %u .b8 result[%u])", out_align, out_size); + buffer.put(" _("); + if (vcall->use_self) { + buffer.put(".reg .u32 self"); + if (data_reg || in_size) + buffer.put(", "); + } + if (data_reg) { + buffer.put(".reg .u64 data"); + if (in_size) + buffer.put(", "); + } + if (in_size) + buffer.fmt(".param .align %u .b8 params[%u]", in_align, in_size); + buffer.put(");\n"); + } else { + buffer.fmt(" l_%u_%u:\n", vcall->id, callable_id); + buffer.put(" {\n"); } - buffer.fmt(" st.param.%s [in+%u], %s%u;\n", tname, offset, - prefix, v2->reg_index); + // Input/output parameter arrays + if (out_size) + buffer.fmt(" .param .align %u .b8 out[%u];\n", out_align, out_size); + if (in_size) + buffer.fmt(" .param .align %u .b8 in[%u];\n", in_align, in_size); - offset += size; - } + // ===================================================== + // 5.1. Pass the input arguments + // ===================================================== - if (vcall->use_self) { - buffer.fmt(" call %s%%rd2, (%%r%u%s%s), proto;\n", - out_size ? "(out), " : "", self_reg, - data_reg ? ", %rd3" : "", - in_size ? ", in" : ""); - } else { - buffer.fmt(" call %s%%rd2, (%s%s%s), proto;\n", - out_size ? "(out), " : "", data_reg ? "%rd3" : "", - data_reg && in_size ? ", " : "", in_size ? "in" : ""); - } + uint32_t offset = 0; + for (uint32_t in : vcall->in) { + auto it = state.variables.find(in); + if (it == state.variables.end()) + continue; + const Variable *v2 = &it->second; + uint32_t size = type_size[v2->type]; - // ===================================================== - // 5.2. Read back the output arguments - // ===================================================== + const char *tname = type_name_ptx[v2->type], + *prefix = type_prefix[v2->type]; - offset = 0; - for (uint32_t i = 0; i < n_out; ++i) { - uint32_t index = vcall->out_nested[i], - index_2 = vcall->out[i]; - auto it = state.variables.find(index); - if (it == state.variables.end()) - continue; - uint32_t size = type_size[it->second.type], - load_offset = offset; - offset += size; + // Special handling for predicates (pass via u8) + if ((VarType) v2->type == VarType::Bool) { + tname = "u8"; + prefix = "%w"; + } - // Skip if outer access expired - auto it2 = state.variables.find(index_2); - if (it2 == state.variables.end()) - continue; + buffer.fmt(" st.param.%s [in+%u], %s%u;\n", tname, offset, + prefix, v2->reg_index); - const Variable *v2 = &it2.value(); - if (v2->reg_index == 0 || v2->param_type == ParamType::Input) - continue; + offset += size; + } - const char *tname = type_name_ptx[v2->type], - *prefix = type_prefix[v2->type]; + // ===================================================== + // 5.2. Setup the function call + // ===================================================== + + auto assemble_call = [&](const char* target) { + buffer.put(" "); + if (vcall->use_self) { + buffer.fmt("call %s%s, (%%r%u%s%s)%s;\n", + out_size ? "(out), " : "", + target, + self_reg, + data_reg ? ", %rd3" : "", + in_size ? ", in" : "", + branch_vcall ? "" : ", proto"); + } else { + buffer.fmt("call %s%s, (%s%s%s)%s;\n", + out_size ? "(out), " : "", + target, + data_reg ? "%rd3" : "", + data_reg && in_size ? ", " : "", + in_size ? "in" : "", + branch_vcall ? "" : ", proto"); + } + }; + + + // ===================================================== + // 5.3. Call the function and read the output arguments + // ===================================================== + + auto read_output_arguments = [&]() { + offset = 0; + for (uint32_t i = 0; i < n_out; ++i) { + uint32_t index = vcall->out_nested[i], + index_2 = vcall->out[i]; + auto it = state.variables.find(index); + if (it == state.variables.end()) + continue; + uint32_t size = type_size[it->second.type], + load_offset = offset; + offset += size; + + // Skip if outer access expired + auto it2 = state.variables.find(index_2); + if (it2 == state.variables.end()) + continue; + + const Variable *v2 = &it2.value(); + if (v2->reg_index == 0 || v2->param_type == ParamType::Input) + continue; + + const char *tname = type_name_ptx[v2->type], + *prefix = type_prefix[v2->type]; + + // Special handling for predicates (pass via u8) + if ((VarType) v2->type == VarType::Bool) { + tname = "u8"; + prefix = "%w"; + } - // Special handling for predicates (pass via u8) - if ((VarType) v2->type == VarType::Bool) { - tname = "u8"; - prefix = "%w"; - } + buffer.fmt(" ld.param.%s %s%u, [out+%u];\n", + tname, prefix, v2->reg_index, load_offset); - buffer.fmt(" ld.param.%s %s%u, [out+%u];\n", - tname, prefix, v2->reg_index, load_offset); - } + if ((VarType) v2->type == VarType::Bool) + buffer.fmt(" setp.ne.u16 %%p%u, %%w%u, 0;\n", + v2->reg_index, v2->reg_index); + } + }; - buffer.put(" }\n\n"); + if (!branch_vcall) { + const char* target = "%rd2"; + assemble_call(target); + read_output_arguments(); + } else { + char target[38]; + snprintf(target, sizeof(target), "func_%016llx%016llx", + (unsigned long long) callable_hash.high64, + (unsigned long long) callable_hash.low64); + assemble_call(target); + read_output_arguments(); + buffer.fmt(" bra.uni l_done_%u;\n", vcall_reg); + } - // ===================================================== - // 6. Special handling for predicates return value(s) - // ===================================================== + buffer.put(" }\n"); - for (uint32_t out : vcall->out) { - auto it = state.variables.find(out); - if (it == state.variables.end()) - continue; - const Variable *v2 = &it->second; - if ((VarType) v2->type != VarType::Bool) - continue; - if (v2->reg_index == 0 || v2->param_type == ParamType::Input) - continue; + if (!branch_vcall) { + break; + } - // Special handling for predicates - buffer.fmt(" setp.ne.u16 %%p%u, %%w%u, 0;\n", - v2->reg_index, v2->reg_index); + callable_id++; } - - buffer.fmt(" bra.uni l_done_%u;\n", vcall_reg); + if (!branch_vcall) + buffer.fmt(" bra.uni l_done_%u;\n", vcall_reg); buffer.put(" }\n"); // ===================================================== @@ -1370,15 +1413,40 @@ void jitc_vcall_upload(ThreadState *ts) { uint64_t *data = (uint64_t *) jitc_malloc(at, vcall->offset_size); memset(data, 0, vcall->offset_size); - for (uint32_t i = 0; i < vcall->n_inst; ++i) { - auto it = globals_map.find(GlobalKey(vcall->inst_hash[i], true)); - if (it == globals_map.end()) - jitc_fail("jitc_vcall_upload(): could not find callable!"); + if (ts->backend == JitBackend::CUDA && jit_flag(JitFlag::VCallBranch)) { + for (uint32_t i = 0; i < vcall->n_inst; ++i) { + uint32_t callable_index = 0; + bool found = false; + for (auto callable: vcall->callables_set) { + if (callable.high64 == vcall->inst_hash[i].high64 && + callable.low64 == vcall->inst_hash[i].low64) { + found = true; + break; + } + callable_index++; + } - // high part: instance data offset, low part: callable index - data[vcall->inst_id[i]] = - (((uint64_t) vcall->data_offset[i]) << 32) | - it->second.callable_index; + auto it = globals_map.find(GlobalKey(vcall->inst_hash[i], true)); + if (it == globals_map.end()) + jitc_fail("jitc_vcall_upload(): could not find callable!"); + + // high part: instance data offset, low part: callable index + data[vcall->inst_id[i]] = + (((uint64_t) vcall->data_offset[i]) << 32) | + callable_index; + } + } + else { + for (uint32_t i = 0; i < vcall->n_inst; ++i) { + auto it = globals_map.find(GlobalKey(vcall->inst_hash[i], true)); + if (it == globals_map.end()) + jitc_fail("jitc_vcall_upload(): could not find callable!"); + + // high part: instance data offset, low part: callable index + data[vcall->inst_id[i]] = + (((uint64_t) vcall->data_offset[i]) << 32) | + it->second.callable_index; + } } jitc_memcpy_async(ts->backend, vcall->offset, data, vcall->offset_size); diff --git a/tests/vcall.cpp b/tests/vcall.cpp index f7d2f69fa..f8fba88cc 100644 --- a/tests/vcall.cpp +++ b/tests/vcall.cpp @@ -229,6 +229,9 @@ TEST_BOTH(01_recorded_vcall) { A1 a1; A2 a2; + jit_set_flag(JitFlag::PrintIR, true); + jit_set_log_level_stderr(LogLevel::Trace); + // jit_llvm_set_target("skylake-avx512", "+avx512f,+avx512dq,+avx512vl,+avx512cd", 16); uint32_t i1 = jit_registry_put(Backend, "Base", &a1); uint32_t i2 = jit_registry_put(Backend, "Base", &a2); @@ -293,7 +296,7 @@ TEST_BOTH(02_calling_conventions) { jit_set_flag(JitFlag::VCallOptimize, i); using BasePtr = Array; - BasePtr self = arange(10) % 3; + BasePtr self = arange(12) % 4; Mask p0(false); Float p1(12); @@ -314,11 +317,11 @@ TEST_BOTH(02_calling_conventions) { jit_var_schedule(result.template get<3>().index()); jit_var_schedule(result.template get<4>().index()); - jit_assert(strcmp(result.template get<0>().str(), "[0, 0, 1, 0, 0, 1, 0, 0, 1, 0]") == 0); - jit_assert(strcmp(result.template get<1>().str(), "[0, 12, 13, 0, 12, 13, 0, 12, 13, 0]") == 0); - jit_assert(strcmp(result.template get<2>().str(), "[0, 34, 36, 0, 34, 36, 0, 34, 36, 0]") == 0); - jit_assert(strcmp(result.template get<3>().str(), "[0, 56, 59, 0, 56, 59, 0, 56, 59, 0]") == 0); - jit_assert(strcmp(result.template get<4>().str(), "[0, 1, 0, 0, 1, 0, 0, 1, 0, 0]") == 0); + jit_assert(strcmp(result.template get<0>().str(), "[0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0]") == 0); + jit_assert(strcmp(result.template get<1>().str(), "[0, 12, 13, 0, 0, 12, 13, 0, 0, 12, 13, 0]") == 0); + jit_assert(strcmp(result.template get<2>().str(), "[0, 34, 36, 0, 0, 34, 36, 0, 0, 34, 36, 0]") == 0); + jit_assert(strcmp(result.template get<3>().str(), "[0, 56, 59, 0, 0, 56, 59, 0, 0, 56, 59, 0]") == 0); + jit_assert(strcmp(result.template get<4>().str(), "[0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0]") == 0); } jit_registry_remove(Backend, &b1); From 86e6b27483f9bbd1346c6a1032086273ebfe4762 Mon Sep 17 00:00:00 2001 From: Nicolas Roussel Date: Mon, 21 Nov 2022 08:23:07 +0100 Subject: [PATCH 2/4] Add VCallBranchJumpTable JIT flag: when using VCallBranch use a jump table rather than a series of individual branches. --- include/drjit-core/jit.h | 35 +++++++++++++++++++---------------- src/vcall.cpp | 35 ++++++++++++++++++++++------------- 2 files changed, 41 insertions(+), 29 deletions(-) diff --git a/include/drjit-core/jit.h b/include/drjit-core/jit.h index 0dd76025e..072c83d63 100644 --- a/include/drjit-core/jit.h +++ b/include/drjit-core/jit.h @@ -1343,30 +1343,33 @@ enum class JitFlag : uint32_t { /// Exploit literal constants during AD (used in the Dr.Jit parent project) ADOptimize = 16384, + VCallBranchJumpTable = 32768, + /// Default flags Default = (uint32_t) ConstProp | (uint32_t) ValueNumbering | (uint32_t) LoopRecord | (uint32_t) LoopOptimize | - //(uint32_t) VCallBranch | (uint32_t) VCallRecord | (uint32_t) VCallDeduplicate | + (uint32_t) VCallBranch | (uint32_t) VCallBranchJumpTable | (uint32_t) VCallOptimize | (uint32_t) ADOptimize }; #else enum JitFlag { - JitFlagConstProp = 1, - JitFlagValueNumbering = 2, - JitFlagLoopRecord = 4, - JitFlagLoopOptimize = 8, - JitFlagVCallRecord = 16, - JitFlagVCallBranch = 32, - JitFlagVCallDeduplicate = 64, - JitFlagVCallOptimize = 128, - JitFlagVCallInline = 256, - JitFlagForceOptiX = 512, - JitFlagRecording = 1024, - JitFlagPrintIR = 2048, - JitFlagKernelHistory = 4096, - JitFlagLaunchBlocking = 8192, - JitFlagADOptimize = 16384, + JitFlagConstProp = 1, + JitFlagValueNumbering = 2, + JitFlagLoopRecord = 4, + JitFlagLoopOptimize = 8, + JitFlagVCallRecord = 16, + JitFlagVCallBranch = 32, + JitFlagVCallDeduplicate = 64, + JitFlagVCallOptimize = 128, + JitFlagVCallInline = 256, + JitFlagForceOptiX = 512, + JitFlagRecording = 1024, + JitFlagPrintIR = 2048, + JitFlagKernelHistory = 4096, + JitFlagLaunchBlocking = 8192, + JitFlagADOptimize = 16384, + JitFlagVCallBranchJumpTable = 32768, }; #endif diff --git a/src/vcall.cpp b/src/vcall.cpp index 84e81c8c1..2c3541385 100644 --- a/src/vcall.cpp +++ b/src/vcall.cpp @@ -836,6 +836,9 @@ static void jitc_var_vcall_assemble_cuda(VCall *vcall, uint32_t vcall_reg, uint32_t in_align, uint32_t out_size, uint32_t out_align) { + bool branch_vcall = jit_flag(JitFlag::VCallBranch); + bool jump_table = jit_flag(JitFlag::VCallBranchJumpTable); + // ===================================================== // 1. Conditional branch // ===================================================== @@ -859,7 +862,6 @@ static void jitc_var_vcall_assemble_cuda(VCall *vcall, uint32_t vcall_reg, // 3. Turn callable ID into a function pointer // ===================================================== - bool branch_vcall = jit_flag(JitFlag::VCallBranch); if (!branch_vcall) { if (!uses_optix) buffer.fmt(" ld.global.u64 %%rd2, callables[%%r3];\n"); @@ -876,7 +878,6 @@ static void jitc_var_vcall_assemble_cuda(VCall *vcall, uint32_t vcall_reg, " add.u64 %%rd3, %%rd3, %%rd%u;\n", data_reg); } - // %rd2: function pointer (if applicable) // %rd3: call data pointer with offset @@ -902,9 +903,20 @@ static void jitc_var_vcall_assemble_cuda(VCall *vcall, uint32_t vcall_reg, // Switch statement: branch to call if (branch_vcall) { - for (size_t i = 0; i < vcall->callables_set.size(); ++i) { - buffer.fmt(" setp.eq.u32 %%p0, %%r3, %u;\n", (uint32_t) i); - buffer.fmt(" @%%p0 bra l_%u_%u;\n", vcall->id, (uint32_t) i); + if (!jump_table) { + for (size_t i = 0; i < vcall->callables_set.size(); ++i) { + buffer.fmt(" setp.eq.u32 %%p0, %%r3, %u;\n", (uint32_t) i); + buffer.fmt(" @%%p0 bra l_%u_%u;\n", vcall_reg, (uint32_t) i); + } + } else { + buffer.put(" ts: .branchtargets "); + for (size_t i = 0; i < vcall->callables_set.size(); ++i) { + if (i != 0) { + buffer.put(", "); + } + buffer.fmt("l_%u_%u", vcall_reg, (uint32_t) i); + } + buffer.put(";\n brx.idx %r3, ts;\n"); } buffer.put("\n"); } @@ -932,7 +944,7 @@ static void jitc_var_vcall_assemble_cuda(VCall *vcall, uint32_t vcall_reg, buffer.fmt(".param .align %u .b8 params[%u]", in_align, in_size); buffer.put(");\n"); } else { - buffer.fmt(" l_%u_%u:\n", vcall->id, callable_id); + buffer.fmt(" l_%u_%u:\n", vcall_reg, callable_id); buffer.put(" {\n"); } @@ -1049,20 +1061,17 @@ static void jitc_var_vcall_assemble_cuda(VCall *vcall, uint32_t vcall_reg, (unsigned long long) callable_hash.low64); assemble_call(target); read_output_arguments(); - buffer.fmt(" bra.uni l_done_%u;\n", vcall_reg); } buffer.put(" }\n"); - - if (!branch_vcall) { - break; - } + buffer.fmt(" bra.uni l_done_%u;\n", vcall_reg); callable_id++; + + if (!branch_vcall) + break; } - if (!branch_vcall) - buffer.fmt(" bra.uni l_done_%u;\n", vcall_reg); buffer.put(" }\n"); // ===================================================== From bbf31367d2ea29f4e654196efd255704afb8ca43 Mon Sep 17 00:00:00 2001 From: Nicolas Roussel Date: Mon, 21 Nov 2022 08:23:11 +0100 Subject: [PATCH 3/4] Add VCallBranchBinarySearch JIT flag: when using VCallBranch perform a binary search to branch to the appropriate vcall. --- include/drjit-core/jit.h | 65 ++++++++++++---------- src/hash.h | 12 ++++ src/vcall.cpp | 117 +++++++++++++++++++++++++++------------ tests/vcall.cpp | 3 - 4 files changed, 129 insertions(+), 68 deletions(-) diff --git a/include/drjit-core/jit.h b/include/drjit-core/jit.h index 072c83d63..80beed7e4 100644 --- a/include/drjit-core/jit.h +++ b/include/drjit-core/jit.h @@ -1285,7 +1285,7 @@ extern JIT_EXPORT void jit_prefix_pop(JIT_ENUM JitBackend backend); * The default set of flags is: * * ConstProp | ValueNumbering | LoopRecord | LoopOptimize | - * VCallRecord | VCallOptimize | ADOptimize + * VCallRecord | VCallDeduplicate | VCallOptimize | ADOptimize */ #if defined(__cplusplus) enum class JitFlag : uint32_t { @@ -1306,70 +1306,75 @@ enum class JitFlag : uint32_t { /** * \brief Use branches instead of direct callables (in OptiX) or indirect - * function calls (in CUDA) for virtual function calls. + * function calls (in CUDA) for virtual function calls. The default + * branching strategy is a linear search among all targets. */ VCallBranch = 32, + /// Use a jump table to reach appropriate target when `VCallBranch` is enabled + VCallBranchJumpTable = 64, + + /// Perform a binary search to find the appropriate target when `VCallBranch` is enabled + VCallBranchBinarySearch = 128, + /// De-duplicate virtual function calls that produce the same code - VCallDeduplicate = 64, + VCallDeduplicate = 256, /// Enable constant propagation and elide unnecessary function arguments - VCallOptimize = 128, + VCallOptimize = 512, /** * \brief Inline calls if there is only a single instance? (off by default, * inlining can make kernels so large that they actually run slower in * CUDA/OptiX). */ - VCallInline = 256, + VCallInline = 1024, /// Force execution through OptiX even if a kernel doesn't use ray tracing - ForceOptiX = 512, + ForceOptiX = 2048, /// Temporarily postpone evaluation of statements with side effects - Recording = 1024, + Recording = 4096, /// Print the intermediate representation of generated programs - PrintIR = 2048, + PrintIR = 8192, /// Enable writing of the kernel history - KernelHistory = 4096, + KernelHistory = 16384, /* Force synchronization after every kernel launch. This is useful to isolate crashes to a specific kernel, and to benchmark kernel runtime along with the KernelHistory feature. */ - LaunchBlocking = 8192, + LaunchBlocking = 32768, /// Exploit literal constants during AD (used in the Dr.Jit parent project) - ADOptimize = 16384, - - VCallBranchJumpTable = 32768, + ADOptimize = 65536, /// Default flags Default = (uint32_t) ConstProp | (uint32_t) ValueNumbering | (uint32_t) LoopRecord | (uint32_t) LoopOptimize | (uint32_t) VCallRecord | (uint32_t) VCallDeduplicate | - (uint32_t) VCallBranch | (uint32_t) VCallBranchJumpTable | (uint32_t) VCallOptimize | (uint32_t) ADOptimize }; #else enum JitFlag { - JitFlagConstProp = 1, - JitFlagValueNumbering = 2, - JitFlagLoopRecord = 4, - JitFlagLoopOptimize = 8, - JitFlagVCallRecord = 16, - JitFlagVCallBranch = 32, - JitFlagVCallDeduplicate = 64, - JitFlagVCallOptimize = 128, - JitFlagVCallInline = 256, - JitFlagForceOptiX = 512, - JitFlagRecording = 1024, - JitFlagPrintIR = 2048, - JitFlagKernelHistory = 4096, - JitFlagLaunchBlocking = 8192, - JitFlagADOptimize = 16384, - JitFlagVCallBranchJumpTable = 32768, + JitFlagConstProp = 1, + JitFlagValueNumbering = 2, + JitFlagLoopRecord = 4, + JitFlagLoopOptimize = 8, + JitFlagVCallRecord = 16, + JitFlagVCallBranch = 32, + JitFlagVCallBranchJumpTable = 64, + JitFlagVCallBranchBinarySearch = 128, + JitFlagVCallDeduplicate = 256, + JitFlagVCallOptimize = 512, + JitFlagVCallInline = 1024, + JitFlagForceOptiX = 2048, + JitFlagRecording = 4096, + JitFlagPrintIR = 8192, + JitFlagKernelHistory = 16384, + JitFlagLaunchBlocking = 32768, + JitFlagADOptimize = 65536, }; #endif diff --git a/src/hash.h b/src/hash.h index a4fbaf1af..e31631004 100644 --- a/src/hash.h +++ b/src/hash.h @@ -71,6 +71,18 @@ struct XXH128Cmp { } }; +struct XXH128Eq { + bool operator()(const XXH128_hash_t &lhs, const XXH128_hash_t &rhs) const { + return lhs.high64 == rhs.high64 && lhs.low64 == rhs.low64; + } +}; + +struct XXH128Hash { + size_t operator()(const XXH128_hash_t &hash) const { + return hash.low64 ^ hash.high64; + } +}; + inline void hash_combine(size_t& seed, size_t value) { /// From CityHash (https://github.com/google/cityhash) const size_t mult = 0x9ddfea08eb382d69ull; diff --git a/src/vcall.cpp b/src/vcall.cpp index 2c3541385..674488a43 100644 --- a/src/vcall.cpp +++ b/src/vcall.cpp @@ -828,6 +828,76 @@ static void jitc_var_vcall_assemble(VCall *vcall, vcalls_assembled.push_back(vcall); } +static void jitc_var_vcall_branch_strategy_assemble_cuda(const VCall *vcall, + uint32_t vcall_reg) { + bool jump_table = jit_flag(JitFlag::VCallBranchJumpTable); + bool binary_search = jit_flag(JitFlag::VCallBranchBinarySearch); + + if (jump_table == binary_search) { + // Linear search + if (jump_table) + jitc_log(Warn, "jitc_var_vcall_assemble_cuda(): both " + "JitFlag::VCallBranchJumpTable and " + "JitFlag::VCallBranchBinarySearch are enabled, " + "defaulting back to linear search!"); + + for (size_t i = 0; i < vcall->callables_set.size(); ++i) { + buffer.fmt(" setp.eq.u32 %%p3, %%r3, %u;\n", (uint32_t) i); + buffer.fmt(" @%%p3 bra l_%u_%u;\n", vcall_reg, (uint32_t) i); + } + } else if (binary_search) { + uint32_t size = vcall->callables_set.size(); + + uint32_t max_depth = log2i_ceil(size); + for (uint32_t depth = 0; depth < max_depth; ++depth) { + for (uint32_t i = 0; i < (1 << depth); ++i) { + uint32_t range_start = i << (max_depth - depth); + if (size <= range_start) + break; + + uint32_t offset = 1 << (max_depth - depth - 1); + uint32_t spacing = offset * 2; + uint32_t mid = offset + i * spacing; + + uint32_t next_offset = offset >> 1; + uint32_t next_spacing = spacing >> 1; + uint32_t left = next_offset + (i * 2) * next_spacing; + uint32_t right = next_offset + ((i * 2) + 1) * next_spacing; + + if (depth != 0) + buffer.fmt(" l_%u_%u_%u:\n", vcall_reg, depth, mid); + + if (mid < size) { + buffer.fmt(" setp.lt.u32 %%p3, %%r3, %u;\n", mid); + if (depth + 1 < max_depth) { + buffer.fmt(" @%%p3 bra l_%u_%u_%u;\n", vcall_reg, depth + 1, left); + buffer.fmt(" bra.uni l_%u_%u_%u;\n", vcall_reg, depth + 1, right); + } else { + buffer.fmt(" @%%p3 bra.uni l_%u_%u;\n", vcall_reg, left); + buffer.fmt(" bra.uni l_%u_%u;\n", vcall_reg, right); + } + } else { + if (depth + 1 < max_depth) + buffer.fmt(" bra l_%u_%u_%u;\n", vcall_reg, depth + 1, left); + else + buffer.fmt(" bra.uni l_%u_%u;\n", vcall_reg, left); + } + } + } + } else { + // Jump table + buffer.put(" ts: .branchtargets "); + for (size_t i = 0; i < vcall->callables_set.size(); ++i) { + if (i != 0) + buffer.put(", "); + buffer.fmt("l_%u_%u", vcall_reg, (uint32_t) i); + } + buffer.put(";\n brx.idx %r3, ts;\n"); + } + + buffer.put("\n"); +} + /// Virtual function call code generation -- CUDA/PTX-specific bits static void jitc_var_vcall_assemble_cuda(VCall *vcall, uint32_t vcall_reg, uint32_t self_reg, uint32_t mask_reg, @@ -837,7 +907,6 @@ static void jitc_var_vcall_assemble_cuda(VCall *vcall, uint32_t vcall_reg, uint32_t out_align) { bool branch_vcall = jit_flag(JitFlag::VCallBranch); - bool jump_table = jit_flag(JitFlag::VCallBranchJumpTable); // ===================================================== // 1. Conditional branch @@ -901,30 +970,14 @@ static void jitc_var_vcall_assemble_cuda(VCall *vcall, uint32_t vcall_reg, v2->reg_index, v2->reg_index); } - // Switch statement: branch to call - if (branch_vcall) { - if (!jump_table) { - for (size_t i = 0; i < vcall->callables_set.size(); ++i) { - buffer.fmt(" setp.eq.u32 %%p0, %%r3, %u;\n", (uint32_t) i); - buffer.fmt(" @%%p0 bra l_%u_%u;\n", vcall_reg, (uint32_t) i); - } - } else { - buffer.put(" ts: .branchtargets "); - for (size_t i = 0; i < vcall->callables_set.size(); ++i) { - if (i != 0) { - buffer.put(", "); - } - buffer.fmt("l_%u_%u", vcall_reg, (uint32_t) i); - } - buffer.put(";\n brx.idx %r3, ts;\n"); - } - buffer.put("\n"); - } + // Switch statement: branch to call (multiple strategies) + if (branch_vcall) + jitc_var_vcall_branch_strategy_assemble_cuda(vcall, vcall_reg); uint32_t callable_id = 0; - for (XXH128_hash_t callable_hash: vcall->callables_set) { + for (const XXH128_hash_t &callable_hash : vcall->callables_set) { if (!branch_vcall) { - // Call prototype + // Generate call prototype buffer.put(" {\n"); buffer.put(" proto: .callprototype"); if (out_size) @@ -1006,7 +1059,6 @@ static void jitc_var_vcall_assemble_cuda(VCall *vcall, uint32_t vcall_reg, } }; - // ===================================================== // 5.3. Call the function and read the output arguments // ===================================================== @@ -1044,6 +1096,7 @@ static void jitc_var_vcall_assemble_cuda(VCall *vcall, uint32_t vcall_reg, buffer.fmt(" ld.param.%s %s%u, [out+%u];\n", tname, prefix, v2->reg_index, load_offset); + // Special handling for predicates if ((VarType) v2->type == VarType::Bool) buffer.fmt(" setp.ne.u16 %%p%u, %%w%u, 0;\n", v2->reg_index, v2->reg_index); @@ -1423,18 +1476,12 @@ void jitc_vcall_upload(ThreadState *ts) { memset(data, 0, vcall->offset_size); if (ts->backend == JitBackend::CUDA && jit_flag(JitFlag::VCallBranch)) { - for (uint32_t i = 0; i < vcall->n_inst; ++i) { - uint32_t callable_index = 0; - bool found = false; - for (auto callable: vcall->callables_set) { - if (callable.high64 == vcall->inst_hash[i].high64 && - callable.low64 == vcall->inst_hash[i].low64) { - found = true; - break; - } - callable_index++; - } + tsl::robin_map callable_indices; + uint32_t index = 0; + for (const XXH128_hash_t &callable : vcall->callables_set) + callable_indices[callable] = index++; + for (uint32_t i = 0; i < vcall->n_inst; ++i) { auto it = globals_map.find(GlobalKey(vcall->inst_hash[i], true)); if (it == globals_map.end()) jitc_fail("jitc_vcall_upload(): could not find callable!"); @@ -1442,7 +1489,7 @@ void jitc_vcall_upload(ThreadState *ts) { // high part: instance data offset, low part: callable index data[vcall->inst_id[i]] = (((uint64_t) vcall->data_offset[i]) << 32) | - callable_index; + callable_indices[vcall->inst_hash[i]]; } } else { diff --git a/tests/vcall.cpp b/tests/vcall.cpp index f8fba88cc..6c649af09 100644 --- a/tests/vcall.cpp +++ b/tests/vcall.cpp @@ -229,9 +229,6 @@ TEST_BOTH(01_recorded_vcall) { A1 a1; A2 a2; - jit_set_flag(JitFlag::PrintIR, true); - jit_set_log_level_stderr(LogLevel::Trace); - // jit_llvm_set_target("skylake-avx512", "+avx512f,+avx512dq,+avx512vl,+avx512cd", 16); uint32_t i1 = jit_registry_put(Backend, "Base", &a1); uint32_t i2 = jit_registry_put(Backend, "Base", &a2); From 5a2a82a4ad99c4a38e31cfcfde634c0c78aaa383 Mon Sep 17 00:00:00 2001 From: Nicolas Roussel Date: Tue, 6 Dec 2022 13:32:54 +0100 Subject: [PATCH 4/4] Add VCallBranch alternatives to tests --- tests/vcall.cpp | 388 ++++++++++++++++++++++++++++-------------------- 1 file changed, 225 insertions(+), 163 deletions(-) diff --git a/tests/vcall.cpp b/tests/vcall.cpp index 6c649af09..2d9383bff 100644 --- a/tests/vcall.cpp +++ b/tests/vcall.cpp @@ -239,10 +239,15 @@ TEST_BOTH(01_recorded_vcall) { BasePtr self = arange(10) % 3; for (uint32_t i = 0; i < 2; ++i) { - jit_set_flag(JitFlag::VCallOptimize, i); - Float y = vcall( - "Base", [](Base *self2, Float x2) { return self2->f(x2); }, self, x); - jit_assert(strcmp(y.str(), "[0, 22, 204, 0, 28, 210, 0, 34, 216, 0]") == 0); + for (uint32_t j = 0; j < 4; ++j) { + jit_set_flag(JitFlag::VCallBranch, j > 0); + jit_set_flag(JitFlag::VCallBranchBinarySearch, j == 2); + jit_set_flag(JitFlag::VCallBranchJumpTable, j == 3); + jit_set_flag(JitFlag::VCallOptimize, i); + Float y = vcall( + "Base", [](Base *self2, Float x2) { return self2->f(x2); }, self, x); + jit_assert(strcmp(y.str(), "[0, 22, 204, 0, 28, 210, 0, 34, 216, 0]") == 0); + } } jit_registry_remove(Backend, &a1); @@ -290,35 +295,40 @@ TEST_BOTH(02_calling_conventions) { (void) i1; (void) i2; (void) i3; for (uint32_t i = 0; i < 2; ++i) { - jit_set_flag(JitFlag::VCallOptimize, i); + for (uint32_t j = 0; j < 4; ++j) { + jit_set_flag(JitFlag::VCallBranch, j > 0); + jit_set_flag(JitFlag::VCallBranchBinarySearch, j == 2); + jit_set_flag(JitFlag::VCallBranchJumpTable, j == 3); + jit_set_flag(JitFlag::VCallOptimize, i); - using BasePtr = Array; - BasePtr self = arange(12) % 4; + using BasePtr = Array; + BasePtr self = arange(12) % 4; - Mask p0(false); - Float p1(12); - Double p2(34); - Float p3(56); - Mask p4(true); + Mask p0(false); + Float p1(12); + Double p2(34); + Float p3(56); + Mask p4(true); - auto result = vcall( - "Base", - [](Base *self2, Mask p0, Float p1, Double p2, Float p3, Mask p4) { - return self2->f(p0, p1, p2, p3, p4); - }, - self, p0, p1, p2, p3, p4); - - jit_var_schedule(result.template get<0>().index()); - jit_var_schedule(result.template get<1>().index()); - jit_var_schedule(result.template get<2>().index()); - jit_var_schedule(result.template get<3>().index()); - jit_var_schedule(result.template get<4>().index()); - - jit_assert(strcmp(result.template get<0>().str(), "[0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0]") == 0); - jit_assert(strcmp(result.template get<1>().str(), "[0, 12, 13, 0, 0, 12, 13, 0, 0, 12, 13, 0]") == 0); - jit_assert(strcmp(result.template get<2>().str(), "[0, 34, 36, 0, 0, 34, 36, 0, 0, 34, 36, 0]") == 0); - jit_assert(strcmp(result.template get<3>().str(), "[0, 56, 59, 0, 0, 56, 59, 0, 0, 56, 59, 0]") == 0); - jit_assert(strcmp(result.template get<4>().str(), "[0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0]") == 0); + auto result = vcall( + "Base", + [](Base *self2, Mask p0, Float p1, Double p2, Float p3, Mask p4) { + return self2->f(p0, p1, p2, p3, p4); + }, + self, p0, p1, p2, p3, p4); + + jit_var_schedule(result.template get<0>().index()); + jit_var_schedule(result.template get<1>().index()); + jit_var_schedule(result.template get<2>().index()); + jit_var_schedule(result.template get<3>().index()); + jit_var_schedule(result.template get<4>().index()); + + jit_assert(strcmp(result.template get<0>().str(), "[0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0]") == 0); + jit_assert(strcmp(result.template get<1>().str(), "[0, 12, 13, 0, 0, 12, 13, 0, 0, 12, 13, 0]") == 0); + jit_assert(strcmp(result.template get<2>().str(), "[0, 34, 36, 0, 0, 34, 36, 0, 0, 34, 36, 0]") == 0); + jit_assert(strcmp(result.template get<3>().str(), "[0, 56, 59, 0, 0, 56, 59, 0, 0, 56, 59, 0]") == 0); + jit_assert(strcmp(result.template get<4>().str(), "[0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0]") == 0); + } } jit_registry_remove(Backend, &b1); @@ -362,32 +372,36 @@ TEST_BOTH(03_optimize_away_outputs) { BasePtr self = arange(10) % 4; for (uint32_t i = 0; i < 2; ++i) { - i = 1; - jit_set_flag(JitFlag::VCallOptimize, i); + for (uint32_t j = 0; j < 4; ++j) { + jit_set_flag(JitFlag::VCallBranch, j > 0); + jit_set_flag(JitFlag::VCallBranchBinarySearch, j == 2); + jit_set_flag(JitFlag::VCallBranchJumpTable, j == 3); + jit_set_flag(JitFlag::VCallOptimize, i); - jit_assert(jit_var_ref(p3.index()) == 1); + jit_assert(jit_var_ref(p3.index()) == 1); - auto result = vcall( - "Base", - [](Base *self2, Float p1, Float p2, Float p3) { - return self2->f(p1, p2, p3); - }, - self, p1, p2, p3); + auto result = vcall( + "Base", + [](Base *self2, Float p1, Float p2, Float p3) { + return self2->f(p1, p2, p3); + }, + self, p1, p2, p3); - jit_assert(jit_var_ref(p1.index()) == 3); - jit_assert(jit_var_ref(p2.index()) == 3); + jit_assert(jit_var_ref(p1.index()) == 3); + jit_assert(jit_var_ref(p2.index()) == 3); - // Irrelevant input optimized away - jit_assert(jit_var_ref(p3.index()) == 2 - i); + // Irrelevant input optimized away + jit_assert(jit_var_ref(p3.index()) == 2 - i); - result.template get<0>() = Float(0); + result.template get<0>() = Float(0); - jit_assert(jit_var_ref(p1.index()) == 3); - jit_assert(jit_var_ref(p2.index()) == 3 - 2*i); - jit_assert(jit_var_ref(p3.index()) == 2 - i); + jit_assert(jit_var_ref(p1.index()) == 3); + jit_assert(jit_var_ref(p2.index()) == 3 - 2*i); + jit_assert(jit_var_ref(p3.index()) == 2 - i); - jit_assert(strcmp(jit_var_str(result.template get<1>().index()), - "[0, 13, 13, 14, 0, 13, 13, 14, 0, 13]") == 0); + jit_assert(strcmp(jit_var_str(result.template get<1>().index()), + "[0, 13, 13, 14, 0, 13, 13, 14, 0, 13]") == 0); + } } jit_registry_remove(Backend, &c1); @@ -424,47 +438,52 @@ TEST_BOTH(04_devirtualize) { for (uint32_t k = 0; k < 2; ++k) { for (uint32_t i = 0; i < 2; ++i) { - Float p1, p2; - if (k == 0) { - p1 = 12; - p2 = 34; - } else { - p1 = dr::opaque(12); - p2 = dr::opaque(34); + for (uint32_t j = 0; j < 4; ++j) { + jit_set_flag(JitFlag::VCallBranch, j > 0); + jit_set_flag(JitFlag::VCallBranchBinarySearch, j == 2); + jit_set_flag(JitFlag::VCallBranchJumpTable, j == 3); + Float p1, p2; + if (k == 0) { + p1 = 12; + p2 = 34; + } else { + p1 = dr::opaque(12); + p2 = dr::opaque(34); + } + + jit_set_flag(JitFlag::VCallOptimize, i); + uint32_t scope = jit_scope(Backend); + + auto result = vcall( + "Base", + [](Base *self2, Float p1, Float p2) { + return self2->f(p1, p2); + }, + self, p1, p2); + + jit_set_scope(Backend, scope + 1); + + Float p2_wrap = Float::steal(jit_var_wrap_vcall(p2.index())); + + Mask mask = neq(self, nullptr), + mask_combined = Mask::steal(jit_var_mask_apply(mask.index(), 10)); + + Float alt = (p2_wrap + 2) & mask_combined; + + jit_set_scope(Backend, scope + 2); + + jit_assert((result.template get<0>().index() == alt.index()) == (i == 1)); + jit_assert(jit_var_is_literal(result.template get<2>().index()) == (i == 1)); + + jit_var_schedule(result.template get<0>().index()); + jit_var_schedule(result.template get<1>().index()); + + jit_assert( + strcmp(jit_var_str(result.template get<0>().index()), + "[0, 36, 36, 0, 36, 36, 0, 36, 36, 0]") == 0); + jit_assert(strcmp(jit_var_str(result.template get<1>().index()), + "[0, 13, 14, 0, 13, 14, 0, 13, 14, 0]") == 0); } - - jit_set_flag(JitFlag::VCallOptimize, i); - uint32_t scope = jit_scope(Backend); - - auto result = vcall( - "Base", - [](Base *self2, Float p1, Float p2) { - return self2->f(p1, p2); - }, - self, p1, p2); - - jit_set_scope(Backend, scope + 1); - - Float p2_wrap = Float::steal(jit_var_wrap_vcall(p2.index())); - - Mask mask = neq(self, nullptr), - mask_combined = Mask::steal(jit_var_mask_apply(mask.index(), 10)); - - Float alt = (p2_wrap + 2) & mask_combined; - - jit_set_scope(Backend, scope + 2); - - jit_assert((result.template get<0>().index() == alt.index()) == (i == 1)); - jit_assert(jit_var_is_literal(result.template get<2>().index()) == (i == 1)); - - jit_var_schedule(result.template get<0>().index()); - jit_var_schedule(result.template get<1>().index()); - - jit_assert( - strcmp(jit_var_str(result.template get<0>().index()), - "[0, 36, 36, 0, 36, 36, 0, 36, 36, 0]") == 0); - jit_assert(strcmp(jit_var_str(result.template get<1>().index()), - "[0, 13, 14, 0, 13, 14, 0, 13, 14, 0]") == 0); } } jit_registry_remove(Backend, &d1); @@ -509,11 +528,16 @@ TEST_BOTH(05_extra_data) { } for (uint32_t i = 0; i < 2; ++i) { - jit_set_flag(JitFlag::VCallOptimize, i); - Float result = vcall( - "Base", [](Base *self2, Float x) { return self2->f(x); }, self, - x); - jit_assert(strcmp(result.str(), "[0, 9, 13, 0, 21, 28, 0, 33, 43, 0]") == 0); + for (uint32_t j = 0; j < 4; ++j) { + jit_set_flag(JitFlag::VCallBranch, j > 0); + jit_set_flag(JitFlag::VCallBranchBinarySearch, j == 2); + jit_set_flag(JitFlag::VCallBranchJumpTable, j == 3); + jit_set_flag(JitFlag::VCallOptimize, i); + Float result = vcall( + "Base", [](Base *self2, Float x) { return self2->f(x); }, self, + x); + jit_assert(strcmp(result.str(), "[0, 9, 13, 0, 21, 28, 0, 33, 43, 0]") == 0); + } } } jit_registry_remove(Backend, &e1); @@ -550,20 +574,25 @@ TEST_BOTH(06_side_effects) { BasePtr self = arange(11) % 3; for (uint32_t i = 0; i < 2; ++i) { - jit_set_flag(JitFlag::VCallOptimize, i); + for (uint32_t j = 0; j < 4; ++j) { + jit_set_flag(JitFlag::VCallBranch, j > 0); + jit_set_flag(JitFlag::VCallBranchBinarySearch, j == 2); + jit_set_flag(JitFlag::VCallBranchJumpTable, j == 3); + jit_set_flag(JitFlag::VCallOptimize, i); - F1 f1; F2 f2; - uint32_t i1 = jit_registry_put(Backend, "Base", &f1); - uint32_t i2 = jit_registry_put(Backend, "Base", &f2); - jit_assert(i1 == 1 && i2 == 2); + F1 f1; F2 f2; + uint32_t i1 = jit_registry_put(Backend, "Base", &f1); + uint32_t i2 = jit_registry_put(Backend, "Base", &f2); + jit_assert(i1 == 1 && i2 == 2); - vcall("Base", [](Base *self2) { self2->go(); }, self); - jit_assert(strcmp(f1.buffer.str(), "[0, 4, 0, 8, 0]") == 0); - jit_assert(strcmp(f2.buffer.str(), "[0, 1, 5, 3]") == 0); + vcall("Base", [](Base *self2) { self2->go(); }, self); + jit_assert(strcmp(f1.buffer.str(), "[0, 4, 0, 8, 0]") == 0); + jit_assert(strcmp(f2.buffer.str(), "[0, 1, 5, 3]") == 0); - jit_registry_remove(Backend, &f1); - jit_registry_remove(Backend, &f2); - jit_registry_trim(); + jit_registry_remove(Backend, &f1); + jit_registry_remove(Backend, &f2); + jit_registry_trim(); + } } } @@ -595,26 +624,31 @@ TEST_BOTH(07_side_effects_only_once) { BasePtr self = arange(11) % 3; for (uint32_t i = 0; i < 2; ++i) { - jit_set_flag(JitFlag::VCallOptimize, i); - - G1 g1; G2 g2; - uint32_t i1 = jit_registry_put(Backend, "Base", &g1); - uint32_t i2 = jit_registry_put(Backend, "Base", &g2); - jit_assert(i1 == 1 && i2 == 2); - - auto result = vcall("Base", [](Base *self2) { return self2->f(); }, self); - Float f1 = result.template get<0>(); - Float f2 = result.template get<1>(); - jit_assert(strcmp(f1.str(), "[0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1]") == 0); - jit_assert(strcmp(g1.buffer.str(), "[0, 4, 0, 0, 0]") == 0); - jit_assert(strcmp(g2.buffer.str(), "[0, 0, 3, 0, 0]") == 0); - jit_assert(strcmp(f2.str(), "[0, 2, 1, 0, 2, 1, 0, 2, 1, 0, 2]") == 0); - jit_assert(strcmp(g1.buffer.str(), "[0, 4, 0, 0, 0]") == 0); - jit_assert(strcmp(g2.buffer.str(), "[0, 0, 3, 0, 0]") == 0); - - jit_registry_remove(Backend, &g1); - jit_registry_remove(Backend, &g2); - jit_registry_trim(); + for (uint32_t j = 0; j < 4; ++j) { + jit_set_flag(JitFlag::VCallBranch, j > 0); + jit_set_flag(JitFlag::VCallBranchBinarySearch, j == 2); + jit_set_flag(JitFlag::VCallBranchJumpTable, j == 3); + jit_set_flag(JitFlag::VCallOptimize, i); + + G1 g1; G2 g2; + uint32_t i1 = jit_registry_put(Backend, "Base", &g1); + uint32_t i2 = jit_registry_put(Backend, "Base", &g2); + jit_assert(i1 == 1 && i2 == 2); + + auto result = vcall("Base", [](Base *self2) { return self2->f(); }, self); + Float f1 = result.template get<0>(); + Float f2 = result.template get<1>(); + jit_assert(strcmp(f1.str(), "[0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1]") == 0); + jit_assert(strcmp(g1.buffer.str(), "[0, 4, 0, 0, 0]") == 0); + jit_assert(strcmp(g2.buffer.str(), "[0, 0, 3, 0, 0]") == 0); + jit_assert(strcmp(f2.str(), "[0, 2, 1, 0, 2, 1, 0, 2, 1, 0, 2]") == 0); + jit_assert(strcmp(g1.buffer.str(), "[0, 4, 0, 0, 0]") == 0); + jit_assert(strcmp(g2.buffer.str(), "[0, 0, 3, 0, 0]") == 0); + + jit_registry_remove(Backend, &g1); + jit_registry_remove(Backend, &g2); + jit_registry_trim(); + } } } @@ -652,12 +686,17 @@ TEST_BOTH(08_multiple_calls) { for (uint32_t i = 0; i < 2; ++i) { - jit_set_flag(JitFlag::VCallOptimize, i); + for (uint32_t j = 0; j < 4; ++j) { + jit_set_flag(JitFlag::VCallBranch, j > 0); + jit_set_flag(JitFlag::VCallBranchBinarySearch, j == 2); + jit_set_flag(JitFlag::VCallBranchJumpTable, j == 3); + jit_set_flag(JitFlag::VCallOptimize, i); - Float y = vcall("Base", [](Base *self2, Float x2) { return self2->f(x2); }, self, x); - Float z = vcall("Base", [](Base *self2, Float x2) { return self2->f(x2); }, self, y); + Float y = vcall("Base", [](Base *self2, Float x2) { return self2->f(x2); }, self, x); + Float z = vcall("Base", [](Base *self2, Float x2) { return self2->f(x2); }, self, y); - jit_assert(strcmp(z.str(), "[0, 12, 14, 0, 12, 14, 0, 12, 14, 0]") == 0); + jit_assert(strcmp(z.str(), "[0, 12, 14, 0, 12, 14, 0, 12, 14, 0]") == 0); + } } jit_registry_remove(Backend, &h1); @@ -710,26 +749,31 @@ TEST_BOTH(09_big) { self2 = select(self2 <= n2, self2, 0); for (uint32_t i = 0; i < 2; ++i) { - jit_set_flag(JitFlag::VCallOptimize, i); + for (uint32_t j = 0; j < 4; ++j) { + jit_set_flag(JitFlag::VCallBranch, j > 0); + jit_set_flag(JitFlag::VCallBranchBinarySearch, j == 2); + jit_set_flag(JitFlag::VCallBranchJumpTable, j == 3); + jit_set_flag(JitFlag::VCallOptimize, i); - Float x = vcall("Base1", [](Base1 *self_) { return self_->f(); }, Base1Ptr(self1)); - Float y = vcall("Base2", [](Base2 *self_) { return self_->f(); }, Base2Ptr(self2)); + Float x = vcall("Base1", [](Base1 *self_) { return self_->f(); }, Base1Ptr(self1)); + Float y = vcall("Base2", [](Base2 *self_) { return self_->f(); }, Base2Ptr(self2)); - jit_var_schedule(x.index()); - jit_var_schedule(y.index()); + jit_var_schedule(x.index()); + jit_var_schedule(y.index()); - jit_assert(x.read(0) == 0); - jit_assert(y.read(0) == 0); + jit_assert(x.read(0) == 0); + jit_assert(y.read(0) == 0); - for (uint32_t j = 1; j <= n1; ++j) - jit_assert(x.read(j) == j - 1); - for (uint32_t j = 1; j <= n2; ++j) - jit_assert(y.read(j) == 100 + j - 1); + for (uint32_t j = 1; j <= n1; ++j) + jit_assert(x.read(j) == j - 1); + for (uint32_t j = 1; j <= n2; ++j) + jit_assert(y.read(j) == 100 + j - 1); - for (uint32_t j = n1 + 1; j < n; ++j) - jit_assert(x.read(j + 1) == 0); - for (uint32_t j = n2 + 1; j < n; ++j) - jit_assert(y.read(j + 1) == 0); + for (uint32_t j = n1 + 1; j < n; ++j) + jit_assert(x.read(j + 1) == 0); + for (uint32_t j = n2 + 1; j < n; ++j) + jit_assert(y.read(j + 1) == 0); + } } for (int i = 0; i < n1; ++i) @@ -754,12 +798,18 @@ TEST_BOTH(09_self) { uint32_t i2_id = jit_registry_put(Backend, "Base", &i2); UInt32 self(i1_id, i2_id); - UInt32 y = vcall( - "Base", - [](Base *self_) { return self_->f(); }, - BasePtr(self)); + for (uint32_t j = 0; j < 4; ++j) { + jit_set_flag(JitFlag::VCallBranch, j > 0); + jit_set_flag(JitFlag::VCallBranchBinarySearch, j == 2); + jit_set_flag(JitFlag::VCallBranchJumpTable, j == 3); - jit_assert(strcmp(y.str(), "[1, 2]") == 0); + UInt32 y = vcall( + "Base", + [](Base *self_) { return self_->f(); }, + BasePtr(self)); + + jit_assert(strcmp(y.str(), "[1, 2]") == 0); + } jit_registry_remove(Backend, &i1); jit_registry_remove(Backend, &i2); @@ -796,14 +846,20 @@ TEST_BOTH(10_recursion) { UInt32 self2(i21_id, i22_id); Float x(3.f, 5.f); - Float y = vcall( - "Base2", - [](Base2 *self_, const Base1Ptr &ptr_, const Float &x_) { - return self_->g(ptr_, x_); - }, - Base2Ptr(self2), Base1Ptr(self1), x); + for (uint32_t j = 0; j < 4; ++j) { + jit_set_flag(JitFlag::VCallBranch, j > 0); + jit_set_flag(JitFlag::VCallBranchBinarySearch, j == 2); + jit_set_flag(JitFlag::VCallBranchJumpTable, j == 3); - jit_assert(strcmp(y.str(), "[7, 16]") == 0); + Float y = vcall( + "Base2", + [](Base2 *self_, const Base1Ptr &ptr_, const Float &x_) { + return self_->g(ptr_, x_); + }, + Base2Ptr(self2), Base1Ptr(self1), x); + + jit_assert(strcmp(y.str(), "[7, 16]") == 0); + } jit_registry_remove(Backend, &i11); jit_registry_remove(Backend, &i12); @@ -842,14 +898,20 @@ TEST_BOTH(11_recursion_with_local) { UInt32 self2(i21_id, i22_id); Float x(3.f, 5.f); - Float y = vcall( - "Base2", - [](Base2 *self_, const Base1Ptr &ptr_, const Float &x_) { - return self_->g(ptr_, x_); - }, - Base2Ptr(self2), Base1Ptr(self1), x); + for (uint32_t j = 0; j < 4; ++j) { + jit_set_flag(JitFlag::VCallBranch, j > 0); + jit_set_flag(JitFlag::VCallBranchBinarySearch, j == 2); + jit_set_flag(JitFlag::VCallBranchJumpTable, j == 3); - jit_assert(strcmp(y.str(), "[7, 16]") == 0); + Float y = vcall( + "Base2", + [](Base2 *self_, const Base1Ptr &ptr_, const Float &x_) { + return self_->g(ptr_, x_); + }, + Base2Ptr(self2), Base1Ptr(self1), x); + + jit_assert(strcmp(y.str(), "[7, 16]") == 0); + } jit_registry_remove(Backend, &i11); jit_registry_remove(Backend, &i12);