Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Branch/Switch for virtual function calls #45

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 40 additions & 24 deletions include/drjit-core/jit.h
Original file line number Diff line number Diff line change
Expand Up @@ -1285,7 +1285,7 @@ extern JIT_EXPORT void jit_prefix_pop(JIT_ENUM JitBackend backend);
* The default set of flags is:
*
* <tt>ConstProp | ValueNumbering | LoopRecord | LoopOptimize |
* VCallRecord | VCallOptimize | ADOptimize</tt>
* VCallRecord | VCallDeduplicate | VCallOptimize | ADOptimize</tt>
*/
#if defined(__cplusplus)
enum class JitFlag : uint32_t {
Expand All @@ -1304,38 +1304,51 @@ enum class JitFlag : uint32_t {
/// Record virtual function calls instead of splitting them into many small kernel launches
VCallRecord = 16,

/**
* \brief Use branches instead of direct callables (in OptiX) or indirect
* function calls (in CUDA) for virtual function calls. The default
* branching strategy is a linear search among all targets.
*/
VCallBranch = 32,

/// Use a jump table to reach appropriate target when `VCallBranch` is enabled
VCallBranchJumpTable = 64,

/// Perform a binary search to find the appropriate target when `VCallBranch` is enabled
VCallBranchBinarySearch = 128,

/// De-duplicate virtual function calls that produce the same code
VCallDeduplicate = 32,
VCallDeduplicate = 256,

/// Enable constant propagation and elide unnecessary function arguments
VCallOptimize = 64,
VCallOptimize = 512,

/**
* \brief Inline calls if there is only a single instance? (off by default,
* inlining can make kernels so large that they actually run slower in
* CUDA/OptiX).
*/
VCallInline = 128,
VCallInline = 1024,

/// Force execution through OptiX even if a kernel doesn't use ray tracing
ForceOptiX = 256,
ForceOptiX = 2048,

/// Temporarily postpone evaluation of statements with side effects
Recording = 512,
Recording = 4096,

/// Print the intermediate representation of generated programs
PrintIR = 1024,
PrintIR = 8192,

/// Enable writing of the kernel history
KernelHistory = 2048,
KernelHistory = 16384,

/* Force synchronization after every kernel launch. This is useful to
isolate crashes to a specific kernel, and to benchmark kernel runtime
along with the KernelHistory feature. */
LaunchBlocking = 4096,
LaunchBlocking = 32768,

/// Exploit literal constants during AD (used in the Dr.Jit parent project)
ADOptimize = 8192,
ADOptimize = 65536,

/// Default flags
Default = (uint32_t) ConstProp | (uint32_t) ValueNumbering |
Expand All @@ -1345,20 +1358,23 @@ enum class JitFlag : uint32_t {
};
#else
enum JitFlag {
JitFlagConstProp = 1,
JitFlagValueNumbering = 2,
JitFlagLoopRecord = 4,
JitFlagLoopOptimize = 8,
JitFlagVCallRecord = 16,
JitFlagVCallDeduplicate = 32,
JitFlagVCallOptimize = 64,
JitFlagVCallInline = 128,
JitFlagForceOptiX = 256,
JitFlagRecording = 512,
JitFlagPrintIR = 1024,
JitFlagKernelHistory = 2048,
JitFlagLaunchBlocking = 4096,
JitFlagADOptimize = 8192
JitFlagConstProp = 1,
JitFlagValueNumbering = 2,
JitFlagLoopRecord = 4,
JitFlagLoopOptimize = 8,
JitFlagVCallRecord = 16,
JitFlagVCallBranch = 32,
JitFlagVCallBranchJumpTable = 64,
JitFlagVCallBranchBinarySearch = 128,
JitFlagVCallDeduplicate = 256,
JitFlagVCallOptimize = 512,
JitFlagVCallInline = 1024,
JitFlagForceOptiX = 2048,
JitFlagRecording = 4096,
JitFlagPrintIR = 8192,
JitFlagKernelHistory = 16384,
JitFlagLaunchBlocking = 32768,
JitFlagADOptimize = 65536,
};
#endif

Expand Down
51 changes: 35 additions & 16 deletions src/eval_cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -203,30 +203,48 @@ void jitc_assemble_cuda(ThreadState *ts, ScheduledGroup group,
it.second.callable_index = ctr++;
}

if (callable_count > 0 && !uses_optix) {
if (callable_count > 0) {
size_t insertion_point =
(char *) strstr(buffer.get(), ".address_size 64\n\n") -
buffer.get() + 18,
insertion_start = buffer.size();

buffer.fmt(".extern .global .u64 callables[%u];\n\n",
callable_count_unique);
if (jit_flag(JitFlag::VCallBranch)) {
// Copy signatures to very beginning
for (const auto &it : globals_map) {
if (!it.first.callable)
continue;

jitc_insert_code_at(insertion_point, insertion_start);
const char* func_definition = globals.get() + it.second.start;
const char* signature_begin = strstr(func_definition, ".func");
const char* signature_end = strstr(func_definition, "{");

buffer.fmt("\n.visible .global .align 8 .u64 callables[%u] = {\n",
callable_count_unique);
for (auto const &it : globals_map) {
if (!it.first.callable)
continue;
buffer.put(".visible ");
buffer.put(signature_begin,
signature_end - 1 - signature_begin);
buffer.put(";\n");
}
buffer.fmt("\n");
jitc_insert_code_at(insertion_point, insertion_start);
} else if (!uses_optix) {
buffer.fmt(".extern .global .u64 callables[%u];\n\n",
callable_count_unique);
jitc_insert_code_at(insertion_point, insertion_start);

buffer.fmt("\n.visible .global .align 8 .u64 callables[%u] = {\n",
callable_count_unique);
for (auto const &it : globals_map) {
if (!it.first.callable)
continue;

buffer.fmt(" func_%016llx%016llx%s\n",
(unsigned long long) it.first.hash.high64,
(unsigned long long) it.first.hash.low64,
it.second.callable_index + 1 < callable_count_unique ? "," : "");
}

buffer.fmt(" func_%016llx%016llx%s\n",
(unsigned long long) it.first.hash.high64,
(unsigned long long) it.first.hash.low64,
it.second.callable_index + 1 < callable_count_unique ? "," : "");
buffer.put("};\n\n");
}

buffer.put("};\n\n");
}

jitc_vcall_upload(ts);
Expand All @@ -245,8 +263,9 @@ void jitc_assemble_cuda_func(const char *name, uint32_t inst_id,

buffer.put(".visible .func");
if (out_size) buffer.fmt(" (.param .align %u .b8 result[%u])", out_align, out_size);
bool uses_direct_callables = uses_optix && !(jit_flag(JitFlag::VCallBranch));
buffer.fmt(" %s^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^(",
uses_optix ? "__direct_callable__" : "func_");
uses_direct_callables ? "__direct_callable__" : "func_");

if (use_self) {
buffer.put(".reg .u32 self");
Expand Down
12 changes: 12 additions & 0 deletions src/hash.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,18 @@ struct XXH128Cmp {
}
};

struct XXH128Eq {
bool operator()(const XXH128_hash_t &lhs, const XXH128_hash_t &rhs) const {
return lhs.high64 == rhs.high64 && lhs.low64 == rhs.low64;
}
};

struct XXH128Hash {
size_t operator()(const XXH128_hash_t &hash) const {
return hash.low64 ^ hash.high64;
}
};

inline void hash_combine(size_t& seed, size_t value) {
/// From CityHash (https://github.com/google/cityhash)
const size_t mult = 0x9ddfea08eb382d69ull;
Expand Down
28 changes: 15 additions & 13 deletions src/optix_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -662,19 +662,21 @@ bool jitc_optix_compile(ThreadState *ts, const char *buf, size_t buf_size,
pgd[0].raygen.module = kernel.optix.mod;
pgd[0].raygen.entryFunctionName = strdup(kern_name);

for (auto const &it : globals_map) {
if (!it.first.callable)
continue;

char *name = (char *) malloc_check(52);
snprintf(name, 52, "__direct_callable__%016llx%016llx",
(unsigned long long) it.first.hash.high64,
(unsigned long long) it.first.hash.low64);

uint32_t index = 1 + it.second.callable_index;
pgd[index].kind = OPTIX_PROGRAM_GROUP_KIND_CALLABLES;
pgd[index].callables.moduleDC = kernel.optix.mod;
pgd[index].callables.entryFunctionNameDC = name;
if (!jit_flag(JitFlag::VCallBranch)) {
for (auto const &it : globals_map) {
if (!it.first.callable)
continue;

char *name = (char *) malloc_check(52);
snprintf(name, 52, "__direct_callable__%016llx%016llx",
(unsigned long long) it.first.hash.high64,
(unsigned long long) it.first.hash.low64);

uint32_t index = 1 + it.second.callable_index;
pgd[index].kind = OPTIX_PROGRAM_GROUP_KIND_CALLABLES;
pgd[index].callables.moduleDC = kernel.optix.mod;
pgd[index].callables.entryFunctionNameDC = name;
}
}

kernel.optix.pg = new OptixProgramGroup[n_programs];
Expand Down
Loading