Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Frozen ThreadState #107

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ add_library(
src/malloc.h src/malloc.cpp
src/registry.h src/registry.cpp
src/util.h src/util.cpp
src/record_ts.h src/record_ts.cpp

# CUDA backend
src/cuda_api.h
Expand Down
154 changes: 154 additions & 0 deletions include/drjit-core/jit.h
Original file line number Diff line number Diff line change
Expand Up @@ -494,9 +494,15 @@ extern JIT_EXPORT void jit_registry_remove(const void *ptr);
extern JIT_EXPORT uint32_t jit_registry_id(const void *ptr);

/// Return the largest instance ID for the given domain
/// If the \c domain is a nullptr, it returns the number of active entries in
/// all domains for the given backend
extern JIT_EXPORT uint32_t jit_registry_id_bound(JitBackend backend,
const char *domain);

/// Fills the \c dest pointer array with all pointers registered in the registry
/// \c dest has to point to an array with \c jit_registry_id_bound(backend, nullptr) entries
extern JIT_EXPORT void jit_registry_get_pointers(JitBackend backend, void **dest);

/// Return the pointer value associated with a given instance ID
extern JIT_EXPORT void *jit_registry_ptr(JitBackend backend,
const char *domain, uint32_t id);
Expand Down Expand Up @@ -2474,6 +2480,154 @@ extern JIT_EXPORT uint32_t jit_array_write(uint32_t target, uint32_t offset,
extern JIT_EXPORT uint32_t jit_array_read(uint32_t source, uint32_t offset,
uint32_t mask);

// ====================================================================
// Core functionality to enable kernel freezing via @dr.freeze
// ====================================================================

/// Opaque data structure, storing a sequence of Dr.Jit operations recorded
/// using the API below.
struct Recording;

/**
* \brief Start a recording session. This causes Dr.Jit to track all backend
* operations such as memory copies and kernel launches, storing them into a
* representation that can be replayed without needing to re-trace code.
*
* \param backend
* The backend for which recording should be started.
*
* \param inputs
* An array of input variable indices, which have to be specified
* when starting the recording. Different indices, representing other
* variables, might be provided when replaying. This function borrows the
* indices and does not increment their refcount.
*
* \param n_inputs
* The number of input variables for the recording
*/
extern JIT_EXPORT void
jit_freeze_start(JitBackend backend, const uint32_t *inputs, uint32_t n_inputs);

/**
* \brief Stop recording operations and return a struct containing the
* recording.
*
* The recording is returned as an opaque pointer and has to be destroyed
* afterwards by calling the \c jitc_freeze_destroy function.
*
* \param backend
* The backend on which recording should be stopped.
*
* \param outputs
* An array of output variable indices. When replaying the recording, these
* variables are returned from the replay function. This function borrows
* the indices and does not modify their refcount.
*
* \param n_outputs
* The number of output variables of the recording.
*/
extern JIT_EXPORT Recording *jit_freeze_stop(JitBackend backend,
const uint32_t *outputs,
uint32_t n_outputs);

/**
* \brief Replay a recording with different inputs.
*
* Replaying a recording with different inputs results in different output
* variables. Their variable indices will be written into the outputs array.
*
* \param recording
* The recording to replay given different inputs.
*
* \param inputs
* An array of input variable indices for replaying the recording. The
* number of inputs taken from the array is equal to the number of inputs
* supplied to the jit_start_record function. The variable indices are only
* referenced and their refcount is not changed by this function.
*
* \param outputs
* This array is filled with the output variable indices, created when
* replaying the recording. The size of the array has to be equal to the
* number of output variables supplied to the jit_record_stop function.
* The variables in the output are borrowing references and have to be
* released.
*/
DoeringChristian marked this conversation as resolved.
Show resolved Hide resolved
extern JIT_EXPORT void jit_freeze_replay(Recording *recording,
const uint32_t *inputs,
uint32_t *outputs);

/**
* \brief Perform a dry run replay of a recording (if required), which does not
* perform any actual work.
*
* A dry run of the recording calculates the widths of all kernels and returns
* ``0`` if the function has to be re-recorded. It returns ``1`` if the function
* can be replayed. This is required to catch cases where the size of an input
* variable changes the compiled kernels, for example when performing scatter
* reductions in LLVM mode.
*
* \param recording
* The recording to replay given different inputs.
*
* \param inputs
* An array of input variable indices for replaying the recording. The
* number of inputs taken from the array is equal to the number of inputs
* supplied to the jit_start_record function. No actual changes are
* performed on the variables in dry-run mode.
*
* \return 0 if retracing the function is required, 1 otherwise
*/
extern JIT_EXPORT int jit_freeze_dry_run(Recording *recording,
const uint32_t *inputs);

/**
* \brief Pause recording the ThreadState for this backend.
*
* Returns an integer indicating the pause state before calling this function,
* i.e. returns ``1`` if recording was already paused, and ``0`` otherwise. If
* no recording is in progress, this function fails. This is useful to prevent
* accidentally recording operations when traversing the output of a frozen
* function.
*
* \param backend
* The backend for which to pause recording the thread state.
*/
extern JIT_EXPORT int jit_freeze_pause(JitBackend backend);

/**
* \brief Resume recording the ThreadState for this backend.
*
* Returns an integer indicating the pause state before calling this function,
* i.e. returns ``1`` if recording was paused, and ``0`` otherwise. If no
* recording is in progress, this function fails.
*
* \param backend
* The backend for which to pause recording the thread state.
*/
extern JIT_EXPORT int jit_freeze_resume(JitBackend backend);

/**
* \brief Abort recording the ThreadState for this backend.
*
* This will abort the recording process and restore the state to the state it
* was in before starting the recording. Aborting a recording has the same
* effect as never starting the recording. All operations on variables are still
* performed and their state might have changed. If no recording is in progress,
* this function will run without effect.
*
* \param backend
* The backend for which to abort recording the thread state.
*/
extern JIT_EXPORT void jit_freeze_abort(JitBackend backend);

/**
* \brief Destroys a recording and frees the associated memory.
*
* \param recording
* The recording to destroy.
*/
extern JIT_EXPORT void jit_freeze_destroy(Recording *recording);

#if defined(__cplusplus)
}
#endif
60 changes: 51 additions & 9 deletions src/api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "cond.h"
#include "profile.h"
#include "array.h"
#include "record_ts.h"
#include <thread>
#include <condition_variable>
#include <drjit-core/half.h>
Expand Down Expand Up @@ -159,18 +160,11 @@ uint32_t jit_flags() {
}

void jit_set_flag(JitFlag flag, int enable) {
uint32_t flags = jitc_flags();

if (enable)
flags |= (uint32_t) flag;
else
flags &= ~(uint32_t) flag;

jitc_set_flags(flags);
jitc_set_flag(flag, enable);
}

int jit_flag(JitFlag flag) {
return (jitc_flags() & (uint32_t) flag) ? 1 : 0;
return jitc_flag(flag);
}

uint32_t jit_record_checkpoint(JitBackend backend) {
Expand Down Expand Up @@ -983,6 +977,11 @@ uint32_t jit_registry_id_bound(JitBackend backend, const char *domain) {
return jitc_registry_id_bound(backend, domain);
}

void jit_registry_get_pointers(JitBackend backend, void **dest) {
lock_guard guard(state.lock);
return jitc_registry_get_pointers(backend, dest);
}

void *jit_registry_ptr(JitBackend backend, const char *domain, uint32_t id) {
lock_guard guard(state.lock);
return jitc_registry_ptr(backend, domain, id);
Expand Down Expand Up @@ -1497,3 +1496,46 @@ int jit_leak_warnings() {
void jit_set_leak_warnings(int value) {
state.leak_warnings = (bool) value;
}

void jit_freeze_start(JitBackend backend, const uint32_t *inputs,
uint32_t n_inputs) {
lock_guard guard(state.lock);
return jitc_freeze_start(backend, inputs, n_inputs);
}

Recording *jit_freeze_stop(JitBackend backend, const uint32_t *outputs,
uint32_t n_outputs) {
lock_guard guard(state.lock);
return jitc_freeze_stop(backend, outputs, n_outputs);
}

int jit_freeze_pause(JitBackend backend) {
lock_guard guard(state.lock);
return jitc_freeze_pause(backend);
}

int jit_freeze_resume(JitBackend backend) {
lock_guard guard(state.lock);
return jitc_freeze_resume(backend);
}

void jit_freeze_abort(JitBackend backend) {
lock_guard guard(state.lock);
return jitc_freeze_abort(backend);
}

void jit_freeze_replay(Recording *recording, const uint32_t *inputs,
uint32_t *outputs) {
lock_guard guard(state.lock);
jitc_freeze_replay(recording, inputs, outputs);
}

int jit_freeze_dry_run(Recording *recording, const uint32_t *inputs) {
lock_guard guard(state.lock);
return jitc_freeze_dry_run(recording, inputs);
}

void jit_freeze_destroy(Recording *recording) {
lock_guard guard(state.lock);
jitc_freeze_destroy(recording);
}
4 changes: 1 addition & 3 deletions src/eval.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -247,9 +247,7 @@ void jitc_assemble(ThreadState *ts, ScheduledGroup group) {
width = jitc_llvm_vector_width;

if (backend == JitBackend::CUDA) {
uintptr_t size = 0;
memcpy(&size, &group.size, sizeof(uint32_t));
kernel_params.push_back((void *) size);
kernel_params.push_back((void *) (uintptr_t) group.size);

// The first 3 variables are reserved on the CUDA backend
n_regs = 4;
Expand Down
15 changes: 15 additions & 0 deletions src/init.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -660,6 +660,21 @@ uint32_t jitc_flags() {
return jitc_flags_v;
}

void jitc_set_flag(JitFlag flag, int enable) {
uint32_t flags = jitc_flags();

if (enable)
flags |= (uint32_t) flag;
else
flags &= ~(uint32_t) flag;

jitc_set_flags(flags);
}

int jitc_flag(JitFlag flag) {
return (jitc_flags() & (uint32_t) flag) ? 1 : 0;
}

/// ==========================================================================

KernelHistory::KernelHistory() : m_data(nullptr), m_size(0), m_capacity(0) { }
Expand Down
6 changes: 6 additions & 0 deletions src/internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -967,6 +967,12 @@ extern void jitc_set_flags(uint32_t flags);

extern uint32_t jitc_flags();

/// Selectively enables/disables flags
extern void jitc_set_flag(JitFlag flag, int enable);

/// Checks whether a given flag is active. Returns zero or one.
extern int jitc_flag(JitFlag flag);

/// Push a new label onto the prefix stack
extern void jitc_prefix_push(JitBackend backend, const char *label);

Expand Down
7 changes: 7 additions & 0 deletions src/llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,3 +107,10 @@ extern void jitc_llvm_set_target(const char *target_cpu,
/// Insert a ray tracing function call into the LLVM program
extern void jitc_llvm_ray_trace(uint32_t func, uint32_t scene, int shadow_ray,
const uint32_t *in, uint32_t *out);

/// Computes the workers and replication_per_worker factors for the
/// ``jitc_var_expand`` function, given the size and type size.
/// ``jitc_var_expand`` Expands a variable to a larger storage area to avoid
/// atomic scatter.
extern std::pair<uint32_t, uint32_t>
jitc_llvm_expand_replication_factor(uint32_t size, uint32_t tsize);
9 changes: 9 additions & 0 deletions src/llvm_core.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -490,3 +490,12 @@ void jitc_llvm_compile(Kernel &kernel) {
jitc_fail("jit_llvm_compile(): VirtualProtect() failed: %u", GetLastError());
#endif
}

std::pair<uint32_t, uint32_t> jitc_llvm_expand_replication_factor(uint32_t size, uint32_t tsize) {
uint32_t workers = pool_size() + 1;
// 1 cache line per worker for scalar targets, otherwise be a bit more
// reasonable
uint32_t replication_per_worker = size == 1u ? (64u / tsize) : 1u;

return std::pair(workers, replication_per_worker);
}
Loading