Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Call: allow isolation per (variant, domain, scope) #109

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 20 additions & 17 deletions include/drjit-core/jit.h
Original file line number Diff line number Diff line change
Expand Up @@ -467,10 +467,10 @@ extern JIT_EXPORT void *jit_malloc_migrate(void *ptr, JIT_ENUM AllocType type,
*
* This function registers the specified pointer \c ptr with the registry,
* returning the associated ID value, which is guaranteed to be unique within
* the specified domain \c domain. The domain is normally an identifier that is
* associated with the "flavor" of the pointer (e.g. instances of a particular
* class), and which ensures that the returned ID values are as low as
* possible.
* the specified domain identified by the \c (variant, domain) strings.
* The domain is normally an identifier that is associated with the "flavor"
* of the pointer (e.g. instances of a particular * class), and which ensures
* that the returned ID values are as low as possible.
*
* Caution: for reasons of efficiency, the \c domain parameter is assumed to a
* static constant that will remain alive. The RTTI identifier
Expand All @@ -480,7 +480,7 @@ extern JIT_EXPORT void *jit_malloc_migrate(void *ptr, JIT_ENUM AllocType type,
* Raises an exception when ``ptr`` is ``nullptr``, or when it has already been
* registered with *any* domain.
*/
extern JIT_EXPORT uint32_t jit_registry_put(JIT_ENUM JitBackend backend,
extern JIT_EXPORT uint32_t jit_registry_put(const char *variant,
const char *domain, void *ptr);

/**
Expand All @@ -494,15 +494,16 @@ extern JIT_EXPORT void jit_registry_remove(const void *ptr);
extern JIT_EXPORT uint32_t jit_registry_id(const void *ptr);

/// Return the largest instance ID for the given domain
extern JIT_EXPORT uint32_t jit_registry_id_bound(JitBackend backend,
extern JIT_EXPORT uint32_t jit_registry_id_bound(const char *variant,
const char *domain);

/// Return the pointer value associated with a given instance ID
extern JIT_EXPORT void *jit_registry_ptr(JitBackend backend,
extern JIT_EXPORT void *jit_registry_ptr(const char *variant,
const char *domain, uint32_t id);

/// Return an arbitrary pointer value associated with a given domain
extern JIT_EXPORT void *jit_registry_peek(JitBackend backend, const char *domain);
extern JIT_EXPORT void *jit_registry_peek(const char *variant,
const char *domain);

/// Disable any instances that are currently registered in the registry
extern JIT_EXPORT void jit_registry_clear();
Expand Down Expand Up @@ -2143,13 +2144,14 @@ struct CallBucket {
*
* This function expects an array of integers, whose entries correspond to
* pointers that have previously been registered by calling \ref
* jit_registry_put() with domain \c domain. It then invokes \ref jit_mkperm()
* to compute a permutation that reorders the array into coherent buckets. The
* buckets are returned using an array of type \ref CallBucket, which contains
* both the resolved pointer address (obtained via \ref
* jit_registry_get_ptr()) and the variable index of an unsigned 32 bit array
* containing the corresponding entries of the input array. The total number of
* buckets is returned via the \c bucket_count_inout argument.
* jit_registry_put() with domain \c (variant, domain).
* It then invokes \ref jit_mkperm() to compute a permutation that reorders
* the array into coherent buckets. The buckets are returned using an array
* of type \ref CallBucket, which contains both the resolved pointer address
* (obtained via \ref * jit_registry_get_ptr()) and the variable index of an
* unsigned 32 bit array * containing the corresponding entries of the
* input array.
* The total number of buckets is returned via the \c bucket_count_inout argument.
*
* Alternatively, this function can be used to to dispatch using an arbitrary
* index list. In this case, \c domain should be set to \c nullptr and the
Expand All @@ -2166,8 +2168,9 @@ struct CallBucket {
* set of instances.
*/
extern JIT_EXPORT struct CallBucket *
jit_var_call_reduce(JIT_ENUM JitBackend backend, const char *domain,
uint32_t index, uint32_t *bucket_count_inout);
jit_var_call_reduce(JIT_ENUM JitBackend backend, const char *variant,
const char *domain, uint32_t index,
uint32_t *bucket_count_inout);

/**
* \brief Insert a function call to a ray tracing functor into the LLVM program
Expand Down
25 changes: 13 additions & 12 deletions src/api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -963,9 +963,9 @@ uint32_t jit_mkperm(JitBackend backend, const uint32_t *values, uint32_t size,
return jitc_mkperm(backend, values, size, bucket_count, perm, offsets);
}

uint32_t jit_registry_put(JitBackend backend, const char *domain, void *ptr) {
uint32_t jit_registry_put(const char *variant, const char *domain, void *ptr) {
lock_guard guard(state.lock);
return jitc_registry_put(backend, domain, ptr);
return jitc_registry_put(variant, domain, ptr);
}

void jit_registry_remove(const void *ptr) {
Expand All @@ -978,19 +978,19 @@ uint32_t jit_registry_id(const void *ptr) {
return jitc_registry_id(ptr);
}

uint32_t jit_registry_id_bound(JitBackend backend, const char *domain) {
uint32_t jit_registry_id_bound(const char *variant, const char *domain) {
lock_guard guard(state.lock);
return jitc_registry_id_bound(backend, domain);
return jitc_registry_id_bound(variant, domain);
}

void *jit_registry_ptr(JitBackend backend, const char *domain, uint32_t id) {
void *jit_registry_ptr(const char *variant, const char *domain, uint32_t id) {
lock_guard guard(state.lock);
return jitc_registry_ptr(backend, domain, id);
return jitc_registry_ptr(variant, domain, id);
}

void *jit_registry_peek(JitBackend backend, const char *domain) {
void *jit_registry_peek(const char *variant, const char *domain) {
lock_guard guard(state.lock);
return jitc_registry_peek(backend, domain);
return jitc_registry_peek(variant, domain);
}

void jit_registry_clear() {
Expand Down Expand Up @@ -1025,11 +1025,12 @@ void jit_aggregate(JitBackend backend, void *dst, AggregationEntry *agg,
return jitc_aggregate(backend, dst, agg, size);
}

struct CallBucket *
jit_var_call_reduce(JitBackend backend, const char *domain, uint32_t index,
uint32_t *bucket_count_inout) {
struct CallBucket *jit_var_call_reduce(JitBackend backend, const char *variant,
const char *domain, uint32_t index,
uint32_t *bucket_count_inout) {
lock_guard guard(state.lock);
return jitc_var_call_reduce(backend, domain, index, bucket_count_inout);
return jitc_var_call_reduce(backend, variant, domain, index,
bucket_count_inout);
}

void jit_kernel_history_clear() {
Expand Down
25 changes: 14 additions & 11 deletions src/call.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,20 @@
license that can be found in the LICENSE file.
*/

#include <set>
#include <string.h>

#include "call.h"
#include "eval.h"
#include "internal.h"
#include "log.h"
#include "var.h"
#include "eval.h"
#include "registry.h"
#include "util.h"
#include "loop.h"
#include "op.h"
#include "profile.h"
#include "loop.h"
#include "registry.h"
#include "trace.h"
#include "call.h"
#include <set>
#include "util.h"
#include "var.h"

std::vector<CallData *> calls_assembled;

Expand Down Expand Up @@ -688,8 +690,9 @@ void jitc_call_upload(ThreadState *ts) {
}

// Compute a permutation to reorder an array of registered pointers
CallBucket *jitc_var_call_reduce(JitBackend backend, const char *domain,
uint32_t index, uint32_t *bucket_count_inout) {
CallBucket *jitc_var_call_reduce(JitBackend backend, const char *variant,
const char *domain, uint32_t index,
uint32_t *bucket_count_inout) {

struct CallReduceRecord {
CallBucket *buckets;
Expand All @@ -711,7 +714,7 @@ CallBucket *jitc_var_call_reduce(JitBackend backend, const char *domain,

uint32_t bucket_count;
if (domain)
bucket_count = jitc_registry_id_bound(backend, domain);
bucket_count = jitc_registry_id_bound(variant, domain);
else
bucket_count = *bucket_count_inout;

Expand Down Expand Up @@ -796,7 +799,7 @@ CallBucket *jitc_var_call_reduce(JitBackend backend, const char *domain,

CallBucket bucket_out;
if (domain)
bucket_out.ptr = jitc_registry_ptr(backend, domain, bucket.id);
bucket_out.ptr = jitc_registry_ptr(variant, domain, bucket.id);
else
bucket_out.ptr = nullptr;

Expand Down
4 changes: 2 additions & 2 deletions src/call.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,8 @@ extern void jitc_var_call(const char *domain, bool symbolic, uint32_t self,

extern void jitc_call_upload(ThreadState *ts);

extern CallBucket *jitc_var_call_reduce(JitBackend backend, const char *domain,
uint32_t index,
extern CallBucket *jitc_var_call_reduce(JitBackend backend, const char *variant,
const char *domain, uint32_t index,
uint32_t *bucket_count_out);

extern void jitc_var_call_assemble(CallData *call, uint32_t call_reg,
Expand Down
40 changes: 22 additions & 18 deletions src/registry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,25 +7,29 @@
license that can be found in the LICENSE file.
*/

#include <queue>
#include <string.h>

#include "registry.h"
#include "log.h"
#include <queue>

// Dr.Jit maintains an ID registry per backend and class. This class separates
// multiple parallel data structures maintaining this information.
// Dr.Jit maintains an ID registry per variant and domain (e.g. class).
// This class separates multiple parallel data structures maintaining
// this information.
struct DomainKey {
JitBackend backend;
const char *variant;
const char *domain;

struct Eq {
bool operator()(DomainKey k1, DomainKey k2) const {
return k1.backend == k2.backend && strcmp(k1.domain, k2.domain) == 0;
return (strcmp(k1.variant, k2.variant) == 0)
&& (strcmp(k1.domain, k2.domain) == 0);
}
};

struct Hash {
size_t operator()(DomainKey k) const {
return hash_str(k.domain, (size_t) k.backend);
return hash_str(k.variant, 0) ^ hash_str(k.domain, 1);
merlinND marked this conversation as resolved.
Show resolved Hide resolved
}
};
};
Expand Down Expand Up @@ -60,7 +64,7 @@ struct Registry {
static Registry registry;

/// Register a pointer with Dr.Jit's pointer registry
uint32_t jitc_registry_put(JitBackend backend, const char *domain_name, void *ptr) {
uint32_t jitc_registry_put(const char *variant, const char *domain_name, void *ptr) {
Registry &r = registry;

auto [it1, result1] =
Expand All @@ -69,8 +73,8 @@ uint32_t jitc_registry_put(JitBackend backend, const char *domain_name, void *pt
jitc_raise("jit_registry_put(domain=\"%s\", ptr=%p): pointer is "
"already registered!", domain_name, ptr);

// Allocate a domain entry for the key (backend, domain) if unregistered
auto [it2, result2] = r.domain_ids.try_emplace(DomainKey{ backend, domain_name },
// Allocate a domain entry for the key (variant, domain) if unregistered
auto [it2, result2] = r.domain_ids.try_emplace(DomainKey{ variant, domain_name },
(uint32_t) r.domains.size());
if (result2) {
r.domains.emplace_back();
Expand Down Expand Up @@ -147,29 +151,29 @@ uint32_t jitc_registry_id(const void *ptr) {
return it->second.index + 1;
}

uint32_t jitc_registry_id_bound(JitBackend backend, const char *domain) {
uint32_t jitc_registry_id_bound(const char *variant, const char *domain) {
Registry &r = registry;
auto it = r.domain_ids.find(DomainKey{ backend, domain });
auto it = r.domain_ids.find(DomainKey{ variant, domain });
if (it == r.domain_ids.end())
return 0;
else
return r.domains[it->second].id_bound;
}

void *jitc_registry_ptr(JitBackend backend, const char *domain_name, uint32_t id) {
void *jitc_registry_ptr(const char *variant, const char *domain_name, uint32_t id) {
if (id == 0)
return nullptr;

Registry &r = registry;
auto it = r.domain_ids.find(DomainKey{ backend, domain_name });
auto it = r.domain_ids.find(DomainKey{ variant, domain_name });
void *ptr = nullptr;

if (it != r.domain_ids.end()) {
Domain &domain = r.domains[it->second];
if (id - 1 >= domain.fwd_map.size())
jitc_raise("jit_registry_ptr(domain=\"%s\", id=%u): instance is "
"not registered!",
domain_name, id);
jitc_raise("jit_registry_ptr(variant=\"%s\", domain=\"%s\", id=%u):"
" instance is not registered!",
variant, domain_name, id);
Ptr entry = domain.fwd_map[id - 1];
if (entry.active)
ptr = entry.ptr;
Expand All @@ -178,9 +182,9 @@ void *jitc_registry_ptr(JitBackend backend, const char *domain_name, uint32_t id
return ptr;
}

void *jitc_registry_peek(JitBackend backend, const char *domain_name) {
void *jitc_registry_peek(const char *variant, const char *domain) {
Registry &r = registry;
auto it = r.domain_ids.find(DomainKey{ backend, domain_name });
auto it = r.domain_ids.find(DomainKey{ variant, domain });
void *ptr = nullptr;

if (it != r.domain_ids.end()) {
Expand Down
9 changes: 5 additions & 4 deletions src/registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
#include "internal.h"

/// Register a pointer with Dr.Jit's pointer registry
extern uint32_t jitc_registry_put(JitBackend backend, const char *domain,
extern uint32_t jitc_registry_put(const char *variant, const char *domain,
void *ptr);

/// Remove a pointer from the registry
Expand All @@ -22,13 +22,14 @@ extern void jitc_registry_remove(const void *ptr);
extern uint32_t jitc_registry_id(const void *ptr);

/// Return the largest instance ID for the given domain
extern uint32_t jitc_registry_id_bound(JitBackend backend, const char *domain);
extern uint32_t jitc_registry_id_bound(const char *variant, const char *domain);

/// Return the pointer value associated with a given instance ID
extern void *jitc_registry_ptr(JitBackend backend, const char *domain, uint32_t id);
extern void *jitc_registry_ptr(const char *variant, const char *domain,
uint32_t id);

/// Return an arbitrary pointer value associated with a given domain
extern void *jitc_registry_peek(JitBackend backend, const char *domain);
extern void *jitc_registry_peek(const char *variant, const char *domain);

/// Check for leaks in the registry
extern void jitc_registry_shutdown();
Expand Down
Loading