diff --git a/src/python/freeze.cpp b/src/python/freeze.cpp index 2a4db1b5..f7badc1a 100644 --- a/src/python/freeze.cpp +++ b/src/python/freeze.cpp @@ -22,8 +22,6 @@ #include #include #include -#include -#include #include struct ProfilerPhase { @@ -63,89 +61,47 @@ struct ADScopeContext { ~ADScopeContext() { ad_scope_leave(process_postponed); } }; -using index64_vector = drjit::detail::index64_vector; static const char *doc_freeze = R"( )"; -enum class LayoutFlag : uint32_t { - SingletonArray = (1 << 0), - Unaligned = (1 << 1), - GradEnabled = (1 << 2), - Postponed = (1 << 3), - Registry = (1 << 4), -}; +bool Layout::operator==(const Layout &rhs) const { + if (!(this->type.equal(rhs.type))) + return false; -/// Stores information about python objects, such as their type, their number of -/// sub-elements or their field keys. This can be used to reconstruct a PyTree -/// from a flattened variable array. -struct Layout { - /// Nanobind type of the container/variable - nb::type_object type; - /// Number of members in this container. - /// Can be used to traverse the layout without knowing the type. - uint32_t num = 0; - /// Optional field identifiers of the container - /// for example: keys in dictionary - std::vector fields; - /// Optional drjit type of the variable - VarType vt = VarType::Void; - /// Optional evaluation state of the variable - VarState vs = VarState::Invalid; - uint32_t flags = 0; - /// The literal data - uint64_t literal = 0; - /// The index in the flat_variables array of this variable. - /// This can be used to determine aliasing. - uint32_t index = 0; - /// We have to track the condition, where two variables have the same size - /// during recording but don't when replaying. - /// Therefore we de-duplicate the size. - uint32_t size_index = 0; - - /// If a non drjit type is passed as function arguments or result, we simply - /// cache it here. - /// TODO: possibly do the same for literals? 
- nb::object py_object = nb::none(); - - bool operator==(const Layout &rhs) const { - if (!(this->type.equal(rhs.type))) - return false; + if (this->num != rhs.num) + return false; - if (this->num != rhs.num) - return false; + if (this->fields.size() != rhs.fields.size()) + return false; - if (this->fields.size() != rhs.fields.size()) - return false; - - for (uint32_t i = 0; i < this->fields.size(); ++i) { - if (!(this->fields[i].equal(rhs.fields[i]))) - return false; - } - if (this->vt != rhs.vt) + for (uint32_t i = 0; i < this->fields.size(); ++i) { + if (!(this->fields[i].equal(rhs.fields[i]))) return false; + } + if (this->vt != rhs.vt) + return false; - if (this->vs != rhs.vs) - return false; + if (this->vs != rhs.vs) + return false; - if (this->flags != rhs.flags) - return false; + if (this->flags != rhs.flags) + return false; - if (this->index != rhs.index) - return false; + if (this->index != rhs.index) + return false; - if (this->size_index != rhs.size_index) - return false; + if (this->size_index != rhs.size_index) + return false; - if (this->literal != rhs.literal) - return false; - if (!this->py_object.equal(rhs.py_object)) - return false; + if (this->literal != rhs.literal) + return false; + if (!this->py_object.equal(rhs.py_object)) + return false; - return true; - } -}; + return true; +} static void log_layouts(const std::vector &layouts, std::ostream &os, uint32_t &index, std::string &padding) { @@ -184,756 +140,702 @@ static void log_layouts(const std::vector &layouts, std::ostream &os, } } -// Additional context required when traversing the inputs -struct TraverseContext { - /// Set of postponed ad nodes, used to mark inputs to functions. - const tsl::robin_set *postponed = nullptr; -}; - /** - * A flattened representation of the PyTree. + * Adds a variable to the flattened array, deduplicating it. + * This allows for checking for aliasing conditions, as aliasing inputs map + * to the same flat variable index. */ -struct FlatVariables { - - // Variable, used to iterate over the variables/layouts when constructing - // python objects - uint32_t layout_index = 0; - - /// The flattened variable indices of the input/output to a frozen function - std::vector variables; - /// Mapping from drjit variable index to index in flat variables - tsl::robin_map index_to_slot; - - /// We have to track the condition, where two variables have the same size - /// during recording but don't when replaying. - /// Therefore we de-duplicate the size. - std::vector sizes; - tsl::robin_map size_to_slot; - - /// This saves information about the type, size and fields of pytree - /// objects. The information is stored in DFS order. 
- std::vector layout; - JitBackend backend = JitBackend::None; - - // Wether variables should be borrowed, instead of stealing them - bool borrow = true; - - FlatVariables() {} - FlatVariables(bool borrow) : borrow(borrow) {} - - void clear() { - this->layout_index = 0; - this->variables.clear(); - this->index_to_slot.clear(); - this->layout.clear(); - this->backend = JitBackend::None; - } - void release() { - for (uint32_t &index : this->variables) { - jit_var_dec_ref(index); - } +uint32_t FlatVariables::add_variable_index(uint32_t variable_index) { + uint32_t next_slot = this->variables.size(); + auto result = this->index_to_slot.try_emplace(variable_index, next_slot); + auto it = result.first; + bool inserted = result.second; + + if (inserted) { + if (borrow) + jit_var_inc_ref(variable_index); + this->variables.push_back(variable_index); + return next_slot; + } else { + return it.value(); } +} - /** - * Adds a variable to the flattened array, deduplicating it. - * This allows for checking for aliasing conditions, as aliasing inputs map - * to the same flat variable index. - */ - uint32_t add_variable_index(uint32_t variable_index) { - uint32_t next_slot = this->variables.size(); - auto result = - this->index_to_slot.try_emplace(variable_index, next_slot); - auto it = result.first; - bool inserted = result.second; - - if (inserted) { - if (borrow) - jit_var_inc_ref(variable_index); - this->variables.push_back(variable_index); - return next_slot; - } else { - return it.value(); - } +/** + * This function returns an index of an equivalence class for the variable + * size in the flattened variables. + * It uses a hashmap and vector to deduplicate sizes. + * + * This is necessary, to catch cases, where two variables had the same size + * when freezing a function and two different sizes when replaying. + * In that case one kernel would be recorded, that evaluates both variables. + * However, when replaying two kernels would have to be launched since the + * now differently sized variables cannot be evaluated by the same kernel. + */ +uint32_t FlatVariables::add_size(uint32_t size) { + uint32_t next_slot = this->sizes.size(); + auto result = this->size_to_slot.try_emplace(size, next_slot); + auto it = result.first; + bool inserted = result.second; + + if (inserted) { + this->sizes.push_back(size); + return next_slot; + } else { + return it.value(); } +} - /** - * This function returns an index of an equivalence class for the variable - * size in the flattened variables. - * It uses a hashmap and vector to deduplicate sizes. - * - * This is necessary, to catch cases, where two variables had the same size - * when freezing a function and two different sizes when replaying. - * In that case one kernel would be recorded, that evaluates both variables. - * However, when replaying two kernels would have to be launched since the - * now differently sized variables cannot be evaluated by the same kernel. - */ - uint32_t add_size(uint32_t size) { - uint32_t next_slot = this->sizes.size(); - auto result = this->size_to_slot.try_emplace(size, next_slot); - auto it = result.first; - bool inserted = result.second; - - if (inserted) { - this->sizes.push_back(size); - return next_slot; - } else { - return it.value(); - } +/** + * Traverse the variable referenced by a jit index and add it to the flat + * variables. An optional type python type can be supplied if it is known. 
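+ *
+ * In rough terms (the body below has the exact rules): a literal variable is
+ * captured by value in ``Layout::literal``, an evaluated variable is
+ * deduplicated into a slot of ``variables`` via ``add_variable_index()``,
+ * and its size is mapped to an equivalence class via ``add_size()`` so that
+ * aliasing and size changes can be detected when replaying.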
+ */ +void FlatVariables::traverse_jit_index(uint32_t index, TraverseContext &ctx, + nb::handle tp) { + // ProfilerPhase profiler("traverse_jit_index"); + VarInfo info = jit_set_backend(index); + JitBackend var_backend = info.backend; + + if (backend == var_backend || this->backend == JitBackend::None) { + backend = var_backend; + } else { + jit_raise("freeze(): backend missmatch error (backend of this " + "variable %s does not match backend of others %s)!", + var_backend == JitBackend::CUDA ? "CUDA" : "LLVM", + backend == JitBackend::CUDA ? "CUDA" : "LLVM"); } - /** - * Traverse the variable referenced by a jit index and add it to the flat - * variables. An optional type python type can be supplied if it is known. - */ - void traverse_jit_index(uint32_t index, TraverseContext &ctx, - nb::handle tp = nb::none()) { - // ProfilerPhase profiler("traverse_jit_index"); - VarInfo info = jit_set_backend(index); - JitBackend var_backend = info.backend; - - if (backend == var_backend || this->backend == JitBackend::None) { - backend = var_backend; - } else { - jit_raise("freeze(): backend missmatch error (backend of this " - "variable %s does not match backend of others %s)!", - var_backend == JitBackend::CUDA ? "CUDA" : "LLVM", - backend == JitBackend::CUDA ? "CUDA" : "LLVM"); - } + if (jit_var_type(index) == VarType::Pointer) { + // We do not support pointers as inputs. It might be possible with + // some extra handling, but they are never used directly. + jit_raise("Pointer inputs not yet supported!"); + } - if (jit_var_type(index) == VarType::Pointer) { - // We do not support pointers as inputs. It might be possible with - // some extra handling, but they are never used directly. - jit_raise("Pointer inputs not yet supported!"); - } + uint32_t var_size = jit_var_size(index); + + Layout layout; + VarState vs = jit_var_state(index); + layout.type = nb::borrow(tp); + layout.vs = vs; + layout.vt = jit_var_type(index); + layout.size_index = this->add_size(var_size); + + if (vs == VarState::Literal) { + // jit_fail("test r%u", index); + // Special case, where the variable is a literal. This should not + // occur, as all literals are made opaque in beforehand, however it + // is nice to have a fallback. + jit_var_read(index, 0, &layout.literal); + // Store size in index variable, as this is not used for literals + layout.index = var_size; + } else if (vs == VarState::Evaluated) { + // Special case, handling evaluated/opaque variables. + + void *data = nullptr; + uint32_t tmp = jit_var_data(index, &data); + if (tmp != index) + jit_fail("traverse(): An evaluated variable changed during " + "evaluation!"); + jit_var_dec_ref(tmp); + + layout.index = this->add_variable_index(index); + bool unaligned = jit_var_is_unaligned(index); + + layout.flags |= + (var_size == 1 ? (uint32_t) LayoutFlag::SingletonArray : 0); + layout.flags |= + (jit_var_is_unaligned(index) ? (uint32_t) LayoutFlag::Unaligned + : 0); - uint32_t var_size = jit_var_size(index); + } else { + jit_raise("collect(): found variable %u in unsupported state %u!", + index, (uint32_t) vs); + } + this->layout.push_back(layout); +} +/** + * Add an ad variable by it's index. Both the value and gradient are added + * to the flattened variables. If the ad index has been marked as postponed + * in the \c TraverseContext.postponed field, we mark the resulting layout + * with that flag. This will cause the gradient edges to be propagated when + * assigning to the input. The function takes an optional python-type if + * it is known. 
+ */ +void FlatVariables::traverse_ad_index(uint64_t index, TraverseContext &ctx, + nb::handle tp) { + // ProfilerPhase profiler("traverse_ad_index"); + int grad_enabled = ad_grad_enabled(index); + jit_log(LogLevel::Debug, "traverse_ad_index(): a%u, r%u", + (uint32_t) (index >> 32), (uint32_t) index, grad_enabled); + if (grad_enabled) { + uint32_t ad_index = (uint32_t) (index >> 32); Layout layout; - VarState vs = jit_var_state(index); - layout.type = nb::borrow(tp); - layout.vs = vs; - layout.vt = jit_var_type(index); - layout.size_index = this->add_size(var_size); - - if (vs == VarState::Literal) { - // jit_fail("test r%u", index); - // Special case, where the variable is a literal. This should not - // occur, as all literals are made opaque in beforehand, however it - // is nice to have a fallback. - jit_var_read(index, 0, &layout.literal); - // Store size in index variable, as this is not used for literals - layout.index = var_size; - } else if (vs == VarState::Evaluated) { - // Special case, handling evaluated/opaque variables. - - void *data = nullptr; - uint32_t tmp = jit_var_data(index, &data); - if (tmp != index) - jit_fail("traverse(): An evaluated variable changed during " - "evaluation!"); - jit_var_dec_ref(tmp); - - layout.index = this->add_variable_index(index); - bool unaligned = jit_var_is_unaligned(index); - - layout.flags |= - (var_size == 1 ? (uint32_t) LayoutFlag::SingletonArray : 0); - layout.flags |= - (jit_var_is_unaligned(index) ? (uint32_t) LayoutFlag::Unaligned - : 0); - - } else { - jit_raise("collect(): found variable %u in unsupported state %u!", - index, (uint32_t) vs); + layout.type = nb::borrow(tp); + layout.num = 2; + layout.vt = jit_var_type(index); + + // Set flags + layout.flags |= (uint32_t) LayoutFlag::GradEnabled; + // If the edge with this node as it's target has been postponed by + // the isolate gradient scope, it has been enqueued and we mark the + // ad variable as such. + if (ctx.postponed && ctx.postponed->contains(ad_index)) { + layout.flags |= (uint32_t) LayoutFlag::Postponed; } - this->layout.push_back(layout); - } - /** - * Add an ad variable by it's index. Both the value and gradient are added - * to the flattened variables. If the ad index has been marked as postponed - * in the \c TraverseContext.postponed field, we mark the resulting layout - * with that flag. This will cause the gradient edges to be propagated when - * assigning to the input. The function takes an optional python-type if - * it is known. - */ - void traverse_ad_index(uint64_t index, TraverseContext &ctx, - nb::handle tp = nb::none()) { - // ProfilerPhase profiler("traverse_ad_index"); - int grad_enabled = ad_grad_enabled(index); - jit_log(LogLevel::Debug, "traverse_ad_index(): a%u, r%u", - (uint32_t) (index >> 32), (uint32_t) index, grad_enabled); - if (grad_enabled) { - uint32_t ad_index = (uint32_t) (index >> 32); - Layout layout; - layout.type = nb::borrow(tp); - layout.num = 2; - layout.vt = jit_var_type(index); - - // Set flags - layout.flags |= (uint32_t) LayoutFlag::GradEnabled; - // If the edge with this node as it's target has been postponed by - // the isolate gradient scope, it has been enqueued and we mark the - // ad variable as such. 
- if (ctx.postponed && ctx.postponed->contains(ad_index)) { - layout.flags |= (uint32_t) LayoutFlag::Postponed; - } - - this->layout.push_back(layout); + this->layout.push_back(layout); - traverse_jit_index((uint32_t) index, ctx, tp); - uint32_t grad = ad_grad(index); - traverse_jit_index(grad, ctx, tp); - jit_var_dec_ref(grad); - } else { - traverse_jit_index(index, ctx, tp); - } + traverse_jit_index((uint32_t) index, ctx, tp); + uint32_t grad = ad_grad(index); + traverse_jit_index(grad, ctx, tp); + jit_var_dec_ref(grad); + } else { + traverse_jit_index(index, ctx, tp); } +} - /** - * Wrapper aground traverse_ad_index for a python handle. - */ - void traverse_ad_var(nb::handle h, TraverseContext &ctx) { - auto s = supp(h.type()); +/** + * Wrapper aground traverse_ad_index for a python handle. + */ +void FlatVariables::traverse_ad_var(nb::handle h, TraverseContext &ctx) { + auto s = supp(h.type()); - raise_if(s.index == nullptr, "freeze(): ArraySupplement index function " - "pointer is nullptr."); + raise_if(s.index == nullptr, "freeze(): ArraySupplement index function " + "pointer is nullptr."); - uint64_t index = s.index(inst_ptr(h)); + uint64_t index = s.index(inst_ptr(h)); - this->traverse_ad_index(index, ctx, h.type()); - } + this->traverse_ad_index(index, ctx, h.type()); +} - /** - * Traverse a c++ tree using it's `traverse_1_cb_ro` callback. - */ - void traverse_cb(const drjit::TraversableBase *traversable, - TraverseContext &ctx, nb::object type = nb::none()) { - ProfilerPhase profiler(traversable); +/** + * Traverse a c++ tree using it's `traverse_1_cb_ro` callback. + */ +void FlatVariables::traverse_cb(const drjit::TraversableBase *traversable, + TraverseContext &ctx, nb::object type) { + ProfilerPhase profiler(traversable); - Layout layout; - layout.type = nb::borrow(type); - size_t layout_index = this->layout.size(); - this->layout.push_back(layout); + Layout layout; + layout.type = nb::borrow(type); + size_t layout_index = this->layout.size(); + this->layout.push_back(layout); - uint32_t num_fileds = 0; + uint32_t num_fileds = 0; - struct Payload { - FlatVariables *flat_vars; - uint32_t num_fields; - TraverseContext *ctx; - }; - Payload payload{ this, 0, &ctx }; - traversable->traverse_1_cb_ro( - (void *) &payload, [](void *p, uint64_t index) { - if (!index) - return; - Payload *payload = (Payload *) p; - payload->num_fields++; - payload->flat_vars->traverse_ad_index(index, *payload->ctx); - }); - - this->layout[layout_index].num = payload.num_fields; - } + struct Payload { + FlatVariables *flat_vars; + uint32_t num_fields; + TraverseContext *ctx; + }; + Payload payload{ this, 0, &ctx }; + traversable->traverse_1_cb_ro( + (void *) &payload, [](void *p, uint64_t index) { + if (!index) + return; + Payload *payload = (Payload *) p; + payload->num_fields++; + payload->flat_vars->traverse_ad_index(index, *payload->ctx); + }); - /** - * Traverses a PyTree in DFS order, and records it's layout in the - * `layout` vector. - * - * When hitting a drjit primitive type, it calls the - * `traverse_dr_var` method, which will add their indices to the - * `flat_variables` vector. The collect method will also record metadata - * about the drjit variable in the layout. Therefore, the layout can be used - * as an identifier to the recording of the frozen function. 
- */ - void traverse(nb::handle h, TraverseContext &ctx) { - ProfilerPhase profiler("traverse"); - nb::handle tp = h.type(); - - auto tp_name = nb::type_name(tp).c_str(); - jit_log(LogLevel::Debug, "FlatVariables::traverse(): %s {", tp_name); + this->layout[layout_index].num = payload.num_fields; +} - try { - if (is_drjit_type(tp)) { - const ArraySupplement &s = supp(tp); - if (s.is_tensor) { - nb::handle array = s.tensor_array(h.ptr()); - - Layout layout; - layout.type = nb::borrow(tp); - layout.py_object = shape(h); - layout.literal = width(array); - this->layout.push_back(layout); - - traverse(nb::steal(array), ctx); - } else if (s.ndim != 1) { - Py_ssize_t len = s.shape[0]; - if (len == DRJIT_DYNAMIC) - len = s.len(inst_ptr(h)); - - Layout layout; - layout.type = nb::borrow(tp); - layout.num = len; - this->layout.push_back(layout); - - for (Py_ssize_t i = 0; i < len; ++i) - traverse(nb::steal(s.item(h.ptr(), i)), ctx); - } else { - traverse_ad_var(h, ctx); - } - } else if (tp.is(&PyTuple_Type)) { - nb::tuple tuple = nb::borrow(h); +/** + * Traverses a PyTree in DFS order, and records it's layout in the + * `layout` vector. + * + * When hitting a drjit primitive type, it calls the + * `traverse_dr_var` method, which will add their indices to the + * `flat_variables` vector. The collect method will also record metadata + * about the drjit variable in the layout. Therefore, the layout can be used + * as an identifier to the recording of the frozen function. + */ +void FlatVariables::traverse(nb::handle h, TraverseContext &ctx) { + ProfilerPhase profiler("traverse"); + nb::handle tp = h.type(); - Layout layout; - layout.type = nb::borrow(tp); - layout.num = tuple.size(); - this->layout.push_back(layout); + auto tp_name = nb::type_name(tp).c_str(); + jit_log(LogLevel::Debug, "FlatVariables::traverse(): %s {", tp_name); - for (nb::handle h2 : tuple) { - traverse(h2, ctx); - } - } else if (tp.is(&PyList_Type)) { - nb::list list = nb::borrow(h); + try { + if (is_drjit_type(tp)) { + const ArraySupplement &s = supp(tp); + if (s.is_tensor) { + nb::handle array = s.tensor_array(h.ptr()); Layout layout; - layout.type = nb::borrow(tp); - layout.num = list.size(); + layout.type = nb::borrow(tp); + layout.py_object = shape(h); + layout.literal = width(array); this->layout.push_back(layout); - for (nb::handle h2 : list) { - traverse(h2, ctx); - } - } else if (tp.is(&PyDict_Type)) { - nb::dict dict = nb::borrow(h); + traverse(nb::steal(array), ctx); + } else if (s.ndim != 1) { + Py_ssize_t len = s.shape[0]; + if (len == DRJIT_DYNAMIC) + len = s.len(inst_ptr(h)); Layout layout; layout.type = nb::borrow(tp); - layout.num = dict.size(); - layout.fields.reserve(layout.num); - for (auto k : dict.keys()) { - layout.fields.push_back(nb::borrow(k)); - } + layout.num = len; this->layout.push_back(layout); - for (auto [k, v] : dict) { - traverse(v, ctx); - } - } else if (nb::dict ds = get_drjit_struct(tp); ds.is_valid()) { - - Layout layout; - layout.type = nb::borrow(tp); - layout.num = ds.size(); - layout.fields.reserve(layout.num); - for (auto k : ds.keys()) { - layout.fields.push_back(nb::borrow(k)); - } - this->layout.push_back(layout); + for (Py_ssize_t i = 0; i < len; ++i) + traverse(nb::steal(s.item(h.ptr(), i)), ctx); + } else { + traverse_ad_var(h, ctx); + } + } else if (tp.is(&PyTuple_Type)) { + nb::tuple tuple = nb::borrow(h); - for (auto [k, v] : ds) { - traverse(nb::getattr(h, k), ctx); - } - } else if (nb::object df = get_dataclass_fields(tp); - df.is_valid()) { + Layout layout; + layout.type = 
nb::borrow(tp); + layout.num = tuple.size(); + this->layout.push_back(layout); - Layout layout; - layout.type = nb::borrow(tp); - for (auto field : df) { - nb::object k = field.attr(DR_STR(name)); - layout.fields.push_back(nb::borrow(k)); - } - layout.num = layout.fields.size(); - this->layout.push_back(layout); + for (nb::handle h2 : tuple) { + traverse(h2, ctx); + } + } else if (tp.is(&PyList_Type)) { + nb::list list = nb::borrow(h); - for (nb::handle field : df) { - nb::object k = field.attr(DR_STR(name)); - traverse(nb::getattr(h, k), ctx); - } - } else if (nb::object cb = get_traverse_cb_ro(tp); cb.is_valid()) { - ProfilerPhase profiler("traverse cb"); - Layout layout; - layout.type = nb::borrow(tp); - size_t layout_index = this->layout.size(); - this->layout.push_back(layout); + Layout layout; + layout.type = nb::borrow(tp); + layout.num = list.size(); + this->layout.push_back(layout); - uint32_t num_fields = 0; - - // Traverse the opaque C++ object - cb(h, nb::cpp_function([&](uint64_t index) { - if (!index) - return; - jit_log(LogLevel::Debug, - "traverse(): traverse_cb[%u] = a%u r%u", - num_fields, (uint32_t) (index >> 32), - (uint32_t) index); - num_fields++; - this->traverse_ad_index(index, ctx, nb::none()); - return; - })); + for (nb::handle h2 : list) { + traverse(h2, ctx); + } + } else if (tp.is(&PyDict_Type)) { + nb::dict dict = nb::borrow(h); - // Update layout number of fields - this->layout[layout_index].num = num_fields; - } else if (tp.is(&_PyNone_Type)) { - Layout layout; - layout.type = nb::borrow(tp); - this->layout.push_back(layout); - } else { - jit_log(LogLevel::Warn, - "traverse(): You passed a value to a frozen function, " - "that could not be converted to Dr.Jit types. This is " - "not recommended and the value will be cached.", - nb::type_name(tp).c_str()); + Layout layout; + layout.type = nb::borrow(tp); + layout.num = dict.size(); + layout.fields.reserve(layout.num); + for (auto k : dict.keys()) { + layout.fields.push_back(nb::borrow(k)); + } + this->layout.push_back(layout); - Layout layout; - layout.type = nb::borrow(tp); - layout.py_object = nb::borrow(h); - this->layout.push_back(layout); + for (auto [k, v] : dict) { + traverse(v, ctx); } - } catch (nb::python_error &e) { - nb::raise_from(e, PyExc_RuntimeError, - "FlatVariables::traverse(): error encountered while " - "processing an argument of type '%U' (see above).", - nb::type_name(tp).ptr()); - } catch (const std::exception &e) { - nb::chain_error(PyExc_RuntimeError, - "FlatVariables::traverse(): error encountered " - "while processing an argument of type '%U': %s", - nb::type_name(tp).ptr(), e.what()); - nb::raise_python_error(); - } + } else if (nb::dict ds = get_drjit_struct(tp); ds.is_valid()) { - jit_log(LogLevel::Debug, "}"); - } + Layout layout; + layout.type = nb::borrow(tp); + layout.num = ds.size(); + layout.fields.reserve(layout.num); + for (auto k : ds.keys()) { + layout.fields.push_back(nb::borrow(k)); + } + this->layout.push_back(layout); - /** - * First traverses the PyTree, then the registry. This ensures that - * additional data to vcalls is tracked correctly. 
- */ - void traverse_with_registry(nb::handle h, TraverseContext &ctx) { + for (auto [k, v] : ds) { + traverse(nb::getattr(h, k), ctx); + } + } else if (nb::object df = get_dataclass_fields(tp); df.is_valid()) { - // Traverse the handle - traverse(h, ctx); + Layout layout; + layout.type = nb::borrow(tp); + for (auto field : df) { + nb::object k = field.attr(DR_STR(name)); + layout.fields.push_back(nb::borrow(k)); + } + layout.num = layout.fields.size(); + this->layout.push_back(layout); - // Traverse the registry - { - ProfilerPhase profiler("traverse_registry"); + for (nb::handle field : df) { + nb::object k = field.attr(DR_STR(name)); + traverse(nb::getattr(h, k), ctx); + } + } else if (nb::object cb = get_traverse_cb_ro(tp); cb.is_valid()) { + ProfilerPhase profiler("traverse cb"); Layout layout; - layout.type = nb::borrow(nb::none()); + layout.type = nb::borrow(tp); size_t layout_index = this->layout.size(); this->layout.push_back(layout); uint32_t num_fields = 0; - jit_log(LogLevel::Debug, "registry{"); - uint32_t registry_bound = jit_registry_id_bound(backend, nullptr); - std::vector registry_pointers; - registry_pointers.resize(registry_bound); - jit_registry_get_pointers(backend, registry_pointers.data()); - - jit_log(LogLevel::Debug, "registry_bound=%u", registry_bound); - jit_log(LogLevel::Debug, "layout_index=%u", this->layout.size()); - for (void *ptr : registry_pointers) { - jit_log(LogLevel::Debug, "ptr=%p", ptr); - if (!ptr) - continue; - - // WARN: very unsafe cast! - auto base = (nb::intrusive_base *) ptr; - auto self = base->self_py(); - - if (self) - traverse(self, ctx); - - const drjit::TraversableBase *traversable = - dynamic_cast(base); - - if (!traversable) { - int status; - jit_fail( - "Could not cast intrusive_base to TraversableBase! " - "The typename was: %s", - abi::__cxa_demangle(typeid(*base).name(), nullptr, - nullptr, &status)); - continue; - } - - traverse_cb(traversable, ctx); - num_fields++; - } - jit_log(LogLevel::Debug, "}"); - + // Traverse the opaque C++ object + cb(h, nb::cpp_function([&](uint64_t index) { + if (!index) + return; + jit_log(LogLevel::Debug, + "traverse(): traverse_cb[%u] = a%u r%u", num_fields, + (uint32_t) (index >> 32), (uint32_t) index); + num_fields++; + this->traverse_ad_index(index, ctx, nb::none()); + return; + })); + + // Update layout number of fields this->layout[layout_index].num = num_fields; + } else if (tp.is(&_PyNone_Type)) { + Layout layout; + layout.type = nb::borrow(tp); + this->layout.push_back(layout); + } else { + jit_log(LogLevel::Warn, + "traverse(): You passed a value to a frozen function, " + "that could not be converted to Dr.Jit types. This is " + "not recommended and the value will be cached.", + nb::type_name(tp).c_str()); + + Layout layout; + layout.type = nb::borrow(tp); + layout.py_object = nb::borrow(h); + this->layout.push_back(layout); } + } catch (nb::python_error &e) { + nb::raise_from(e, PyExc_RuntimeError, + "FlatVariables::traverse(): error encountered while " + "processing an argument of type '%U' (see above).", + nb::type_name(tp).ptr()); + } catch (const std::exception &e) { + nb::chain_error(PyExc_RuntimeError, + "FlatVariables::traverse(): error encountered " + "while processing an argument of type '%U': %s", + nb::type_name(tp).ptr(), e.what()); + nb::raise_python_error(); } - /** - * Construct a variable, given it's layout. - * This is the counterpart to `traverse_jit_index`. 
- */ - uint32_t construct_jit_index(const Layout &layout) { - if (layout.vs == VarState::Literal) { - uint32_t index = jit_var_literal(this->backend, layout.vt, - &layout.literal, layout.index); + jit_log(LogLevel::Debug, "}"); +} - return index; - } else { - uint32_t index = this->variables[layout.index]; - jit_log(LogLevel::Debug, " uses output[%u] = r%u", layout.index, - index); +/** + * First traverses the PyTree, then the registry. This ensures that + * additional data to vcalls is tracked correctly. + */ +void FlatVariables::traverse_with_registry(nb::handle h, TraverseContext &ctx) { - jit_var_inc_ref(index); + // Traverse the handle + traverse(h, ctx); - return index; - } - } + // Traverse the registry + { + ProfilerPhase profiler("traverse_registry"); + Layout layout; + layout.type = nb::borrow(nb::none()); + size_t layout_index = this->layout.size(); + this->layout.push_back(layout); - /** - * Construct/assign the variable index given a layout. - * This corresponds to `traverse_ad_index`> - * - * This function is also used for assignment to ad-variables. - * If a `prev_index` is provided, and it is an ad-variable the gradient and - * value of the flat variables will be applied to the ad variable, - * preserving the ad_idnex. - * - * It returns an owning reference. - */ - uint64_t construct_ad_index(const Layout &layout, uint32_t shrink = 0, - uint64_t prev_index = 0) { - uint64_t index; - if ((layout.flags & (uint32_t) LayoutFlag::GradEnabled) != 0) { - bool postponed = (layout.flags & (uint32_t) LayoutFlag::Postponed); - - Layout &val_layout = this->layout[layout_index++]; - uint32_t val = construct_jit_index(val_layout); - - Layout &grad_layout = this->layout[layout_index++]; - uint32_t grad = construct_jit_index(grad_layout); - - // Resize the gradient if it is a literal - if ((VarState) jit_var_state(grad) == VarState::Literal) { - uint32_t new_grad = jit_var_resize(grad, jit_var_size(val)); - jit_var_dec_ref(grad); - grad = new_grad; - } + uint32_t num_fields = 0; - // If the prev_index variable is provided we assign the new value - // and gradient to the ad variable of that index instead of creating - // a new one. - uint32_t ad_index = (uint32_t) (prev_index >> 32); - if (ad_index) { - index = (((uint64_t) ad_index) << 32) | ((uint64_t) val); - ad_var_inc_ref(index); - } else - index = ad_var_new(val); - - jit_log(LogLevel::Debug, " -> ad_var r%zu", index); - jit_var_dec_ref(val); - - // Equivalent to set_grad - ad_clear_grad(index); - ad_accum_grad(index, grad); - jit_var_dec_ref(grad); + jit_log(LogLevel::Debug, "registry{"); + uint32_t registry_bound = jit_registry_id_bound(backend, nullptr); + std::vector registry_pointers; + registry_pointers.resize(registry_bound); + jit_registry_get_pointers(backend, registry_pointers.data()); - // Variables, that have been postponed by the isolate gradient scope - // will be enqueued, which propagates their gradeint to previous - // functions. - if (ad_index && postponed) { - ad_enqueue(drjit::ADMode::Backward, index); + jit_log(LogLevel::Debug, "registry_bound=%u", registry_bound); + jit_log(LogLevel::Debug, "layout_index=%u", this->layout.size()); + for (void *ptr : registry_pointers) { + jit_log(LogLevel::Debug, "ptr=%p", ptr); + if (!ptr) + continue; + + // WARN: very unsafe cast! 
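+            // The cast below relies on the assumption that every pointer in
+            // the Dr.Jit instance registry refers to an object derived from
+            // nb::intrusive_base; only the subsequent dynamic_cast to
+            // TraversableBase is verified (and reported via jit_fail).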
+ auto base = (nb::intrusive_base *) ptr; + auto self = base->self_py(); + + if (self) + traverse(self, ctx); + + const drjit::TraversableBase *traversable = + dynamic_cast(base); + + if (!traversable) { + int status; + jit_fail("Could not cast intrusive_base to TraversableBase! " + "The typename was: %s", + abi::__cxa_demangle(typeid(*base).name(), nullptr, + nullptr, &status)); + continue; } - } else { - index = construct_jit_index(layout); + + traverse_cb(traversable, ctx); + num_fields++; } + jit_log(LogLevel::Debug, "}"); - if (shrink > 0) - index = ad_var_shrink(index, shrink); - return index; + this->layout[layout_index].num = num_fields; } +} - /** - * Construct an ad variable given it's layout. - * This corresponds to `traverse_ad_var` - */ - nb::object construct_ad_var(const Layout &layout, uint32_t shrink = 0) { - uint64_t index = construct_ad_index(layout, shrink); +/** + * Construct a variable, given it's layout. + * This is the counterpart to `traverse_jit_index`. + */ +uint32_t FlatVariables::construct_jit_index(const Layout &layout) { + if (layout.vs == VarState::Literal) { + uint32_t index = jit_var_literal(this->backend, layout.vt, + &layout.literal, layout.index); - auto result = nb::inst_alloc_zero(layout.type); - const ArraySupplement &s = supp(result.type()); - s.init_index(index, inst_ptr(result)); + return index; + } else { + uint32_t index = this->variables[layout.index]; + jit_log(LogLevel::Debug, " uses output[%u] = r%u", layout.index, + index); - // We have to release the reference, since assignment will borrow from - // it. - ad_var_dec_ref(index); + jit_var_inc_ref(index); - return result; + return index; } +} - /** - * This is the counterpart to the traverse method, used to construct the - * output of a frozen function. Given a layout vector and flat_variables, it - * re-constructs the PyTree. - */ - nb::object construct() { - if (this->layout.size() == 0) { - return nb::none(); +/** + * Construct/assign the variable index given a layout. + * This corresponds to `traverse_ad_index`> + * + * This function is also used for assignment to ad-variables. + * If a `prev_index` is provided, and it is an ad-variable the gradient and + * value of the flat variables will be applied to the ad variable, + * preserving the ad_idnex. + * + * It returns an owning reference. + */ +uint64_t FlatVariables::construct_ad_index(const Layout &layout, + uint32_t shrink, + uint64_t prev_index) { + uint64_t index; + if ((layout.flags & (uint32_t) LayoutFlag::GradEnabled) != 0) { + bool postponed = (layout.flags & (uint32_t) LayoutFlag::Postponed); + + Layout &val_layout = this->layout[layout_index++]; + uint32_t val = construct_jit_index(val_layout); + + Layout &grad_layout = this->layout[layout_index++]; + uint32_t grad = construct_jit_index(grad_layout); + + // Resize the gradient if it is a literal + if ((VarState) jit_var_state(grad) == VarState::Literal) { + uint32_t new_grad = jit_var_resize(grad, jit_var_size(val)); + jit_var_dec_ref(grad); + grad = new_grad; } - const Layout &layout = this->layout[layout_index++]; + // If the prev_index variable is provided we assign the new value + // and gradient to the ad variable of that index instead of creating + // a new one. 
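+        // A combined 64-bit index packs the AD index into the upper 32 bits
+        // and the JIT index into the lower 32 bits. For illustration, for an
+        // AD variable a5 whose primal is r17:
+        //     uint64_t idx = (uint64_t(5) << 32) | 17;
+        //     uint32_t ad  = uint32_t(idx >> 32); // 5
+        //     uint32_t jit = uint32_t(idx);       // 17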
+ uint32_t ad_index = (uint32_t) (prev_index >> 32); + if (ad_index) { + index = (((uint64_t) ad_index) << 32) | ((uint64_t) val); + ad_var_inc_ref(index); + } else + index = ad_var_new(val); + + jit_log(LogLevel::Debug, " -> ad_var r%zu", index); + jit_var_dec_ref(val); - auto tp_name = nb::type_name(layout.type).c_str(); - jit_log(LogLevel::Debug, "FlatVariables::construct(): %s {", tp_name); + // Equivalent to set_grad + ad_clear_grad(index); + ad_accum_grad(index, grad); + jit_var_dec_ref(grad); - if (layout.type.is(nb::none().type())) { - return nb::none(); + // Variables, that have been postponed by the isolate gradient scope + // will be enqueued, which propagates their gradeint to previous + // functions. + if (ad_index && postponed) { + ad_enqueue(drjit::ADMode::Backward, index); } - try { - if (is_drjit_type(layout.type)) { - const ArraySupplement &s = supp(layout.type); - if (s.is_tensor) { - const Layout &array_layout = this->layout[layout_index++]; - nb::object array = - construct_ad_var(array_layout, layout.literal); - - return layout.type(array, layout.py_object); - } else if (s.ndim != 1) { - auto result = nb::inst_alloc_zero(layout.type); - dr::ArrayBase *p = inst_ptr(result); - size_t size = s.shape[0]; - if (size == DRJIT_DYNAMIC) { - size = s.len(p); - s.init(size, p); - } - for (size_t i = 0; i < size; ++i) { - result[i] = construct(); - } - return result; - } else { - return construct_ad_var(layout); - } - } else if (layout.type.is(&PyTuple_Type)) { - nb::list list; - for (uint32_t i = 0; i < layout.num; ++i) { - list.append(construct()); - } - return nb::tuple(list); - } else if (layout.type.is(&PyList_Type)) { - nb::list list; - for (uint32_t i = 0; i < layout.num; ++i) { - list.append(construct()); - } - return list; - } else if (layout.type.is(&PyDict_Type)) { - nb::dict dict; - for (auto k : layout.fields) { - dict[k] = construct(); - } - return dict; - } else if (nb::dict ds = get_drjit_struct(layout.type); - ds.is_valid()) { - nb::object tmp = layout.type(); - // TODO: validation against `ds` - for (auto k : layout.fields) { - nb::setattr(tmp, k, construct()); + } else { + index = construct_jit_index(layout); + } + + if (shrink > 0) + index = ad_var_shrink(index, shrink); + return index; +} + +/** + * Construct an ad variable given it's layout. + * This corresponds to `traverse_ad_var` + */ +nb::object FlatVariables::construct_ad_var(const Layout &layout, + uint32_t shrink) { + uint64_t index = construct_ad_index(layout, shrink); + + auto result = nb::inst_alloc_zero(layout.type); + const ArraySupplement &s = supp(result.type()); + s.init_index(index, inst_ptr(result)); + + // We have to release the reference, since assignment will borrow from + // it. + ad_var_dec_ref(index); + + return result; +} + +/** + * This is the counterpart to the traverse method, used to construct the + * output of a frozen function. Given a layout vector and flat_variables, it + * re-constructs the PyTree. 
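+ *
+ * As a sketch of the correspondence: if ``traverse()`` recorded a dict with
+ * one Dr.Jit array entry as two layout nodes (dict, then leaf), then
+ * ``construct()`` consumes the same two nodes in order, rebuilding the dict
+ * from ``Layout::fields`` and the leaf either from
+ * ``variables[layout.index]`` or, for literals, from ``Layout::literal``.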
+ */ +nb::object FlatVariables::construct() { + if (this->layout.size() == 0) { + return nb::none(); + } + + const Layout &layout = this->layout[layout_index++]; + + auto tp_name = nb::type_name(layout.type).c_str(); + jit_log(LogLevel::Debug, "FlatVariables::construct(): %s {", tp_name); + + if (layout.type.is(nb::none().type())) { + return nb::none(); + } + try { + if (is_drjit_type(layout.type)) { + const ArraySupplement &s = supp(layout.type); + if (s.is_tensor) { + const Layout &array_layout = this->layout[layout_index++]; + nb::object array = + construct_ad_var(array_layout, layout.literal); + + return layout.type(array, layout.py_object); + } else if (s.ndim != 1) { + auto result = nb::inst_alloc_zero(layout.type); + dr::ArrayBase *p = inst_ptr(result); + size_t size = s.shape[0]; + if (size == DRJIT_DYNAMIC) { + size = s.len(p); + s.init(size, p); } - return tmp; - } else if (nb::object df = get_dataclass_fields(layout.type); - df.is_valid()) { - nb::dict dict; - for (auto k : layout.fields) { - dict[k] = construct(); + for (size_t i = 0; i < size; ++i) { + result[i] = construct(); } - return layout.type(**dict); + return result; } else { - if (layout.py_object.is_none()) { - nb::raise("Tried to construct a variable that is not " - "constructable!"); - } - return layout.py_object; + return construct_ad_var(layout); } - } catch (nb::python_error &e) { - nb::raise_from( - e, PyExc_RuntimeError, - "FlatVariables::construct(): error encountered while " - "processing an argument of type '%U' (see above).", - nb::type_name(layout.type).ptr()); - } catch (const std::exception &e) { - nb::chain_error(PyExc_RuntimeError, - "FlatVariables::construct(): error encountered " - "while processing an argument of type '%U': %s", - nb::type_name(layout.type).ptr(), e.what()); - nb::raise_python_error(); + } else if (layout.type.is(&PyTuple_Type)) { + nb::list list; + for (uint32_t i = 0; i < layout.num; ++i) { + list.append(construct()); + } + return nb::tuple(list); + } else if (layout.type.is(&PyList_Type)) { + nb::list list; + for (uint32_t i = 0; i < layout.num; ++i) { + list.append(construct()); + } + return list; + } else if (layout.type.is(&PyDict_Type)) { + nb::dict dict; + for (auto k : layout.fields) { + dict[k] = construct(); + } + return dict; + } else if (nb::dict ds = get_drjit_struct(layout.type); ds.is_valid()) { + nb::object tmp = layout.type(); + // TODO: validation against `ds` + for (auto k : layout.fields) { + nb::setattr(tmp, k, construct()); + } + return tmp; + } else if (nb::object df = get_dataclass_fields(layout.type); + df.is_valid()) { + nb::dict dict; + for (auto k : layout.fields) { + dict[k] = construct(); + } + return layout.type(**dict); + } else { + if (layout.py_object.is_none()) { + nb::raise("Tried to construct a variable that is not " + "constructable!"); + } + return layout.py_object; } + } catch (nb::python_error &e) { + nb::raise_from(e, PyExc_RuntimeError, + "FlatVariables::construct(): error encountered while " + "processing an argument of type '%U' (see above).", + nb::type_name(layout.type).ptr()); + } catch (const std::exception &e) { + nb::chain_error(PyExc_RuntimeError, + "FlatVariables::construct(): error encountered " + "while processing an argument of type '%U': %s", + nb::type_name(layout.type).ptr(), e.what()); + nb::raise_python_error(); } +} - /** - * Assigns an ad variable. - * Corresponds to `traverse_ad_var`. 
- * This uses `construct_ad_index` to either construct a new ad variable or - * assign the value and gradient to an already existing one. - */ - void assign_ad_var(Layout &layout, nb::handle dst) { - const ArraySupplement &s = supp(layout.type); - - uint64_t index; - if (s.index) { - // ``construct_ad_index`` is used for assignment - index = construct_ad_index(layout, 0, s.index(inst_ptr(dst))); - } else - index = construct_ad_index(layout); - - s.reset_index(index, inst_ptr(dst)); - jit_log(LogLevel::Debug, - "index=%zu, grad_enabled=%u, ad_grad_enabled=%u", index, - grad_enabled(dst), ad_grad_enabled(index)); - - // Release reference, since ``construct_ad_index`` returns owning - // reference and ``s.reset_index`` borrows from it. - ad_var_dec_ref(index); - } +/** + * Assigns an ad variable. + * Corresponds to `traverse_ad_var`. + * This uses `construct_ad_index` to either construct a new ad variable or + * assign the value and gradient to an already existing one. + */ +void FlatVariables::assign_ad_var(Layout &layout, nb::handle dst) { + const ArraySupplement &s = supp(layout.type); + + uint64_t index; + if (s.index) { + // ``construct_ad_index`` is used for assignment + index = construct_ad_index(layout, 0, s.index(inst_ptr(dst))); + } else + index = construct_ad_index(layout); + + s.reset_index(index, inst_ptr(dst)); + jit_log(LogLevel::Debug, "index=%zu, grad_enabled=%u, ad_grad_enabled=%u", + index, grad_enabled(dst), ad_grad_enabled(index)); + + // Release reference, since ``construct_ad_index`` returns owning + // reference and ``s.reset_index`` borrows from it. + ad_var_dec_ref(index); +} - /** - * Helper function, used to assign a callback variable. - * - * \param tmp - * This vector is populated with the indices to variables that have been - * constructed. It is required to release the references, since the - * references created by `construct_ad_index` are owning and they are - * borrowed after the callback returns. - */ - uint64_t assign_cb_internal(uint64_t index, index64_vector &tmp) { - if (!index) - return index; - Layout &layout = this->layout[layout_index++]; +/** + * Helper function, used to assign a callback variable. + * + * \param tmp + * This vector is populated with the indices to variables that have been + * constructed. It is required to release the references, since the + * references created by `construct_ad_index` are owning and they are + * borrowed after the callback returns. + */ +uint64_t FlatVariables::assign_cb_internal(uint64_t index, + index64_vector &tmp) { + if (!index) + return index; + Layout &layout = this->layout[layout_index++]; - uint64_t new_index = this->construct_ad_index(layout, 0, index); + uint64_t new_index = this->construct_ad_index(layout, 0, index); - if (layout.vt != (VarType) jit_var_type(index)) - jit_raise("VarType missmatch %u != %u while assigning (a%u, r%u) " - "-> (a%u, r%u)!", - (uint32_t) layout.vt, (uint32_t) jit_var_type(index), - (uint32_t) (index >> 32), (uint32_t) index, - (uint32_t) (new_index >> 32), (uint32_t) new_index); + if (layout.vt != (VarType) jit_var_type(index)) + jit_raise("VarType missmatch %u != %u while assigning (a%u, r%u) " + "-> (a%u, r%u)!", + (uint32_t) layout.vt, (uint32_t) jit_var_type(index), + (uint32_t) (index >> 32), (uint32_t) index, + (uint32_t) (new_index >> 32), (uint32_t) new_index); - tmp.push_back_steal(new_index); - return new_index; - } + tmp.push_back_steal(new_index); + return new_index; +} - /** - * Assigns variables using it's `traverse_cb_rw` callback. 
- * This corresponds to `traverse_cb`. - */ - void assign_cb(drjit::TraversableBase *traversable) { - Layout &layout = this->layout[layout_index++]; +/** + * Assigns variables using it's `traverse_cb_rw` callback. + * This corresponds to `traverse_cb`. + */ +void FlatVariables::assign_cb(drjit::TraversableBase *traversable) { + Layout &layout = this->layout[layout_index++]; - struct Payload { - FlatVariables *flat_vars; - index64_vector tmp; - uint32_t num_fields; - uint32_t field_counter; - }; - jit_log(LogLevel::Debug, " layout.num=%u", layout.num); - Payload payload{ this, index64_vector(), (uint32_t) layout.num, 0 }; - traversable->traverse_1_cb_rw((void *) &payload, [](void *p, - uint64_t index) { + struct Payload { + FlatVariables *flat_vars; + index64_vector tmp; + uint32_t num_fields; + uint32_t field_counter; + }; + jit_log(LogLevel::Debug, " layout.num=%u", layout.num); + Payload payload{ this, index64_vector(), (uint32_t) layout.num, 0 }; + traversable->traverse_1_cb_rw( + (void *) &payload, [](void *p, uint64_t index) { if (!index) return index; Payload *payload = (Payload *) p; @@ -948,184 +850,181 @@ struct FlatVariables { return payload->flat_vars->assign_cb_internal(index, payload->tmp); }); - if (payload.field_counter != layout.num) - jit_raise("While traversing and object for assigning inputs, the " - "number of variables to assign did not match the number " - "of variables traversed when recording!"); - } - - /** - * Assigns the flattened variables to an already existing PyTree. - * This is used when input variables have changed. - */ - void assign(nb::handle dst) { - nb::handle tp = dst.type(); - Layout &layout = this->layout[layout_index++]; - - auto tp_name = nb::type_name(tp).c_str(); - auto layout_tp_name = nb::type_name(layout.type).c_str(); - jit_log(LogLevel::Debug, "FlatVariables::assign(): %s with %s {", - tp_name, layout_tp_name); - - if (!layout.type.equal(tp)) - nb::raise( - "Type missmatch! 
Type of the object when recording %s does not " - "match type of object that is assigned %s.", - nb::type_name(tp).c_str(), nb::type_name(layout.type).c_str()); - - try { - if (is_drjit_type(tp)) { - const ArraySupplement &s = supp(tp); - - if (s.is_tensor) { - nb::handle array = s.tensor_array(dst.ptr()); - - Layout &array_layout = this->layout[layout_index++]; - - assign_ad_var(array_layout, array); - } else if (s.ndim != 1) { - Py_ssize_t len = s.shape[0]; - if (len == DRJIT_DYNAMIC) - len = s.len(inst_ptr(dst)); + if (payload.field_counter != layout.num) + jit_raise("While traversing and object for assigning inputs, the " + "number of variables to assign did not match the number " + "of variables traversed when recording!"); +} - for (Py_ssize_t i = 0; i < len; ++i) - assign(dst[i]); - } else { - assign_ad_var(layout, dst); - } - } else if (tp.is(&PyTuple_Type)) { - nb::tuple tuple = nb::borrow(dst); - raise_if(tuple.size() != layout.num, ""); - - for (nb::handle h2 : tuple) - assign(h2); - } else if (tp.is(&PyList_Type)) { - nb::list list = nb::borrow(dst); - raise_if(list.size() != layout.num, ""); - - for (nb::handle h2 : list) - assign(h2); - } else if (tp.is(&PyDict_Type)) { - nb::dict dict = nb::borrow(dst); - for (auto &k : layout.fields) { - if (dict.contains(&k)) - assign(dict[k]); - else - dst[k] = construct(); - } - } else if (nb::dict ds = get_drjit_struct(dst); ds.is_valid()) { - for (auto &k : layout.fields) { - if (nb::hasattr(dst, k)) - assign(nb::getattr(dst, k)); - else - nb::setattr(dst, k, construct()); - } - } else if (nb::object df = get_dataclass_fields(tp); - df.is_valid()) { - for (auto k : layout.fields) { - if (nb::hasattr(dst, k)) - assign(nb::getattr(dst, k)); - else - nb::setattr(dst, k, construct()); - } - } else if (nb::object cb = get_traverse_cb_rw(tp); cb.is_valid()) { - index64_vector tmp; - uint32_t num_fields = 0; - - cb(dst, nb::cpp_function([&](uint64_t index) { - if (!index) - return index; - jit_log(LogLevel::Debug, - "assign(): traverse_cb[%u] was a%u r%u", - num_fields, (uint32_t) (index >> 32), - (uint32_t) index); - num_fields++; - if (num_fields > layout.num) - jit_raise( - "While traversing the object of type %s " - "for assigning inputs, the number of variables " - "to assign did not match the number of " - "variables traversed when recording!", - nb::str(tp).c_str()); - return assign_cb_internal(index, tmp); - })); - if (num_fields != layout.num) - jit_raise("While traversing the object of type %s " - "for assigning inputs, the number of variables " - "to assign did not match the number of variables " - "traversed when recording!", - nb::str(tp).c_str()); +/** + * Assigns the flattened variables to an already existing PyTree. + * This is used when input variables have changed. + */ +void FlatVariables::assign(nb::handle dst) { + nb::handle tp = dst.type(); + Layout &layout = this->layout[layout_index++]; + + auto tp_name = nb::type_name(tp).c_str(); + auto layout_tp_name = nb::type_name(layout.type).c_str(); + jit_log(LogLevel::Debug, "FlatVariables::assign(): %s with %s {", tp_name, + layout_tp_name); + + if (!layout.type.equal(tp)) + nb::raise( + "Type missmatch! 
Type of the object when recording %s does not " + "match type of object that is assigned %s.", + nb::type_name(tp).c_str(), nb::type_name(layout.type).c_str()); + + try { + if (is_drjit_type(tp)) { + const ArraySupplement &s = supp(tp); + + if (s.is_tensor) { + nb::handle array = s.tensor_array(dst.ptr()); + + Layout &array_layout = this->layout[layout_index++]; + + assign_ad_var(array_layout, array); + } else if (s.ndim != 1) { + Py_ssize_t len = s.shape[0]; + if (len == DRJIT_DYNAMIC) + len = s.len(inst_ptr(dst)); + + for (Py_ssize_t i = 0; i < len; ++i) + assign(dst[i]); } else { + assign_ad_var(layout, dst); } - } catch (nb::python_error &e) { - nb::raise_from(e, PyExc_RuntimeError, - "FlatVariables::assign(): error encountered while " - "processing an argument " - "of type '%U' (see above).", - nb::type_name(tp).ptr()); - } catch (const std::exception &e) { - nb::chain_error(PyExc_RuntimeError, - "FlatVariables::assign(): error encountered " - "while processing an argument " - "of type '%U': %s", - nb::type_name(tp).ptr(), e.what()); - nb::raise_python_error(); - } + } else if (tp.is(&PyTuple_Type)) { + nb::tuple tuple = nb::borrow(dst); + raise_if(tuple.size() != layout.num, ""); + + for (nb::handle h2 : tuple) + assign(h2); + } else if (tp.is(&PyList_Type)) { + nb::list list = nb::borrow(dst); + raise_if(list.size() != layout.num, ""); + + for (nb::handle h2 : list) + assign(h2); + } else if (tp.is(&PyDict_Type)) { + nb::dict dict = nb::borrow(dst); + for (auto &k : layout.fields) { + if (dict.contains(&k)) + assign(dict[k]); + else + dst[k] = construct(); + } + } else if (nb::dict ds = get_drjit_struct(dst); ds.is_valid()) { + for (auto &k : layout.fields) { + if (nb::hasattr(dst, k)) + assign(nb::getattr(dst, k)); + else + nb::setattr(dst, k, construct()); + } + } else if (nb::object df = get_dataclass_fields(tp); df.is_valid()) { + for (auto k : layout.fields) { + if (nb::hasattr(dst, k)) + assign(nb::getattr(dst, k)); + else + nb::setattr(dst, k, construct()); + } + } else if (nb::object cb = get_traverse_cb_rw(tp); cb.is_valid()) { + index64_vector tmp; + uint32_t num_fields = 0; - jit_log(LogLevel::Debug, "}"); + cb(dst, nb::cpp_function([&](uint64_t index) { + if (!index) + return index; + jit_log(LogLevel::Debug, + "assign(): traverse_cb[%u] was a%u r%u", num_fields, + (uint32_t) (index >> 32), (uint32_t) index); + num_fields++; + if (num_fields > layout.num) + jit_raise( + "While traversing the object of type %s " + "for assigning inputs, the number of variables " + "to assign did not match the number of " + "variables traversed when recording!", + nb::str(tp).c_str()); + return assign_cb_internal(index, tmp); + })); + if (num_fields != layout.num) + jit_raise("While traversing the object of type %s " + "for assigning inputs, the number of variables " + "to assign did not match the number of variables " + "traversed when recording!", + nb::str(tp).c_str()); + } else { + } + } catch (nb::python_error &e) { + nb::raise_from(e, PyExc_RuntimeError, + "FlatVariables::assign(): error encountered while " + "processing an argument " + "of type '%U' (see above).", + nb::type_name(tp).ptr()); + } catch (const std::exception &e) { + nb::chain_error(PyExc_RuntimeError, + "FlatVariables::assign(): error encountered " + "while processing an argument " + "of type '%U': %s", + nb::type_name(tp).ptr(), e.what()); + nb::raise_python_error(); } - /** - * First assigns the registry and then the PyTree. - * Corresponds to `traverse_with_registry`. 
- */ - void assign_with_registry(nb::handle dst) { - - // Assign the handle - assign(dst); - - // Assign registry - Layout &layout = this->layout[layout_index++]; - uint32_t num_fields = 0; - jit_log(LogLevel::Debug, "registry{"); - uint32_t registry_bound = jit_registry_id_bound(backend, nullptr); - std::vector registry_pointers; - registry_pointers.resize(registry_bound); - jit_registry_get_pointers(backend, registry_pointers.data()); - - jit_log(LogLevel::Debug, "registry_bound=%u", registry_bound); - jit_log(LogLevel::Debug, "layout_index=%u", this->layout_index); - for (void *ptr : registry_pointers) { - jit_log(LogLevel::Debug, "ptr=%p", ptr); - if (!ptr) - continue; - - // WARN: very unsafe cast! - auto base = (nb::intrusive_base *) ptr; - auto self = base->self_py(); - - if (self) - assign(self); + jit_log(LogLevel::Debug, "}"); +} - drjit::TraversableBase *traversable = - dynamic_cast(base); +/** + * First assigns the registry and then the PyTree. + * Corresponds to `traverse_with_registry`. + */ +void FlatVariables::assign_with_registry(nb::handle dst) { - if (!traversable) { - int status; - // TODO: should we put that behind the debug flag? - jit_raise("Could not cast intrusive_base to TraversableBase! " - "The typename was: %s", - abi::__cxa_demangle(typeid(*base).name(), nullptr, - nullptr, &status)); - continue; - } + // Assign the handle + assign(dst); - assign_cb(traversable); - num_fields++; + // Assign registry + Layout &layout = this->layout[layout_index++]; + uint32_t num_fields = 0; + jit_log(LogLevel::Debug, "registry{"); + uint32_t registry_bound = jit_registry_id_bound(backend, nullptr); + std::vector registry_pointers; + registry_pointers.resize(registry_bound); + jit_registry_get_pointers(backend, registry_pointers.data()); + + jit_log(LogLevel::Debug, "registry_bound=%u", registry_bound); + jit_log(LogLevel::Debug, "layout_index=%u", this->layout_index); + for (void *ptr : registry_pointers) { + jit_log(LogLevel::Debug, "ptr=%p", ptr); + if (!ptr) + continue; + + // WARN: very unsafe cast! + auto base = (nb::intrusive_base *) ptr; + auto self = base->self_py(); + + if (self) + assign(self); + + drjit::TraversableBase *traversable = + dynamic_cast(base); + + if (!traversable) { + int status; + // TODO: should we put that behind the debug flag? + jit_raise("Could not cast intrusive_base to TraversableBase! 
" + "The typename was: %s", + abi::__cxa_demangle(typeid(*base).name(), nullptr, + nullptr, &status)); + continue; } - jit_log(LogLevel::Debug, "}"); + + assign_cb(traversable); + num_fields++; } -}; + jit_log(LogLevel::Debug, "}"); +} std::ostream &operator<<(std::ostream &os, const FlatVariables &r) { std::string offset = " "; @@ -1427,19 +1326,6 @@ inline void hash_combine(size_t &seed, size_t value) { seed = b * mult; } -struct RecordingKey { - std::vector layout; - uint32_t flags; - - RecordingKey() {} - RecordingKey(std::vector layout, uint32_t flags) - : layout(std::move(layout)), flags(flags) {} - - bool operator==(const RecordingKey &rhs) const { - return this->layout == rhs.layout && this->flags == rhs.flags; - } -}; - std::ostream &operator<<(std::ostream &os, const RecordingKey &r) { std::string offset = " "; @@ -1461,268 +1347,213 @@ std::ostream &operator<<(std::ostream &os, const RecordingKey &r) { return os; } -struct RecordingKeyHasher { - size_t operator()(const RecordingKey &key) const { - // Hash the layout - size_t hash = key.layout.size(); - for (const Layout &layout : key.layout) { - hash_combine(hash, py_object_hash(layout.type)); - hash_combine(hash, layout.num); - hash_combine(hash, layout.fields.size()); - for (auto &field : layout.fields) { - hash_combine(hash, py_object_hash(field)); - } - hash_combine(hash, (size_t) layout.vt); - hash_combine(hash, (size_t) layout.vs); - hash_combine(hash, (size_t) layout.flags); - hash_combine(hash, (size_t) layout.literal); - hash_combine(hash, (size_t) layout.index); - hash_combine(hash, (size_t) layout.size_index); - hash_combine(hash, py_object_hash(layout.py_object)); +size_t RecordingKeyHasher::operator()(const RecordingKey &key) const { + // Hash the layout + size_t hash = key.layout.size(); + for (const Layout &layout : key.layout) { + hash_combine(hash, py_object_hash(layout.type)); + hash_combine(hash, layout.num); + hash_combine(hash, layout.fields.size()); + for (auto &field : layout.fields) { + hash_combine(hash, py_object_hash(field)); } - - hash_combine(hash, (size_t) key.flags); - - return hash; + hash_combine(hash, (size_t) layout.vt); + hash_combine(hash, (size_t) layout.vs); + hash_combine(hash, (size_t) layout.flags); + hash_combine(hash, (size_t) layout.literal); + hash_combine(hash, (size_t) layout.index); + hash_combine(hash, (size_t) layout.size_index); + hash_combine(hash, py_object_hash(layout.py_object)); } -}; - -struct FunctionRecording; - -using RecordingMap = - tsl::robin_map, - RecordingKeyHasher>; - -struct FrozenFunction { - nb::callable func; - - RecordingMap recordings; - RecordingKey prev_key; - uint32_t recording_counter = 0; - - FrozenFunction(nb::callable func) : func(func) {} - ~FrozenFunction() {} - FrozenFunction(const FrozenFunction &) = delete; - FrozenFunction &operator=(const FrozenFunction &) = delete; - FrozenFunction(FrozenFunction &&) = default; - FrozenFunction &operator=(FrozenFunction &&) = default; + hash_combine(hash, (size_t) key.flags); - uint32_t saved_recordings() { return this->recordings.size(); } - - nb::object operator()(nb::args args, nb::kwargs kwargs); -}; - -struct FunctionRecording { - Recording *recording = nullptr; - FlatVariables out_variables; - - FunctionRecording() : out_variables(false) {} - FunctionRecording(const FunctionRecording &) = delete; - FunctionRecording &operator=(const FunctionRecording &) = delete; - FunctionRecording(FunctionRecording &&) = default; - FunctionRecording &operator=(FunctionRecording &&) = default; + return hash; +} - 
~FunctionRecording() { - if (this->recording) { - jit_freeze_destroy(this->recording); - } - this->recording = nullptr; +/* + * Record a function, given it's python input and flattened input. + */ +nb::object FunctionRecording::record(nb::callable func, + FrozenFunction *frozen_func, + nb::list input, + const FlatVariables &in_variables) { + ProfilerPhase profiler("record"); + JitBackend backend = in_variables.backend; + frozen_func->recording_counter++; + + jit_log(LogLevel::Info, + "Recording (n_inputs=%u):", in_variables.variables.size()); + jit_freeze_start(backend, in_variables.variables.data(), + in_variables.variables.size()); + + // Record the function + // bool tmp = jit_flag(JitFlag::KernelFreezing); + jit_set_flag(JitFlag::KernelFreezing, false); + nb::object output; + { + ProfilerPhase profiler("function"); + output = func(*input[0], **input[1]); } + jit_set_flag(JitFlag::KernelFreezing, true); - void clear() { - if (this->recording) { - jit_freeze_destroy(this->recording); - } - this->recording = nullptr; - this->out_variables = FlatVariables(false); - } + // output.append(result); + // output.append(input); - /* - * Record a function, given it's python input and flattened input. - */ - nb::object record(nb::callable func, FrozenFunction *frozen_func, - nb::list input, const FlatVariables &in_variables) { - ProfilerPhase profiler("record"); - JitBackend backend = in_variables.backend; - frozen_func->recording_counter++; - - jit_log(LogLevel::Info, - "Recording (n_inputs=%u):", in_variables.variables.size()); - jit_freeze_start(backend, in_variables.variables.data(), - in_variables.variables.size()); - - // Record the function - // bool tmp = jit_flag(JitFlag::KernelFreezing); - jit_set_flag(JitFlag::KernelFreezing, false); - nb::object output; + // Eval the input and output and it's gradients. + jit_log(LogLevel::Debug, "Evaluating output:"); + { + ProfilerPhase profiler("evaluate input + output"); + // Enter Resume scope, so we can track gradients + ADScopeContext ad_scope(drjit::ADScope::Resume, 0, nullptr, -1, false); { - ProfilerPhase profiler("function"); - output = func(*input[0], **input[1]); + ProfilerPhase profiler("schedule input"); + deep_make_opaque(input, false, true); } - jit_set_flag(JitFlag::KernelFreezing, true); - - // output.append(result); - // output.append(input); - - // Eval the input and output and it's gradients. - jit_log(LogLevel::Debug, "Evaluating output:"); { - ProfilerPhase profiler("evaluate input + output"); - // Enter Resume scope, so we can track gradients - ADScopeContext ad_scope(drjit::ADScope::Resume, 0, nullptr, -1, - false); - { - ProfilerPhase profiler("schedule input"); - deep_make_opaque(input, false, true); - } - { - ProfilerPhase profiler("schedule output"); - deep_eval(output, false); - } - { - nb::gil_scoped_release guard; - jit_eval(); - } + ProfilerPhase profiler("schedule output"); + deep_eval(output, false); } - - // Pause recording before traversal as to not accidentally record - // unwanted operations. - // jit_freeze_pause(backend); - - // TODO: validate, that gradients wheren't enabled for inputs inside the - // frozen function. - - // Collect nodes, that have been postponed by the `Isolate` scope in a - // hash set. - // These are the targets of postponed edges, as the isolate gradient - // scope only handles backward mode differentiation. - // If they are, then we have to enqueue them when replaying the - // recording. 
- tsl::robin_set postponed; { - drjit::vector postponed_vec; - ad_scope_postponed(postponed_vec); - for (uint32_t index : postponed_vec) - postponed.insert(index); + nb::gil_scoped_release guard; + jit_eval(); } + } - jit_log(LogLevel::Info, "Traversing output"); - { - ProfilerPhase profiler("traverse output"); - // Enter Resume scope, so we can track gradients - ADScopeContext ad_scope(drjit::ADScope::Resume, 0, nullptr, -1, - false); + // Pause recording before traversal as to not accidentally record + // unwanted operations. + // jit_freeze_pause(backend); - TraverseContext ctx; - ctx.postponed = &postponed; - out_variables.traverse(output, ctx); - out_variables.traverse_with_registry(input, ctx); - } + // TODO: validate, that gradients wheren't enabled for inputs inside the + // frozen function. - if ((out_variables.variables.size() > 0 && - in_variables.variables.size() > 0) && - out_variables.backend != backend) { - Recording *recording = jit_freeze_stop(backend, nullptr, 0); - jit_freeze_destroy(recording); - - nb::raise("freeze(): backend missmatch error (backend %u of " - "output " - "variables did not match backend %u of input " - "variables)", - (uint32_t) out_variables.backend, (uint32_t) backend); - } + // Collect nodes, that have been postponed by the `Isolate` scope in a + // hash set. + // These are the targets of postponed edges, as the isolate gradient + // scope only handles backward mode differentiation. + // If they are, then we have to enqueue them when replaying the + // recording. + tsl::robin_set postponed; + { + drjit::vector postponed_vec; + ad_scope_postponed(postponed_vec); + for (uint32_t index : postponed_vec) + postponed.insert(index); + } - recording = jit_freeze_stop(backend, out_variables.variables.data(), - out_variables.variables.size()); + jit_log(LogLevel::Info, "Traversing output"); + { + ProfilerPhase profiler("traverse output"); + // Enter Resume scope, so we can track gradients + ADScopeContext ad_scope(drjit::ADScope::Resume, 0, nullptr, -1, false); + + TraverseContext ctx; + ctx.postponed = &postponed; + out_variables.traverse(output, ctx); + out_variables.traverse_with_registry(input, ctx); + } - jit_log(LogLevel::Info, "Recording done (n_outputs=%u)", - out_variables.variables.size()); + if ((out_variables.variables.size() > 0 && + in_variables.variables.size() > 0) && + out_variables.backend != backend) { + Recording *recording = jit_freeze_stop(backend, nullptr, 0); + jit_freeze_destroy(recording); + + nb::raise("freeze(): backend missmatch error (backend %u of " + "output " + "variables did not match backend %u of input " + "variables)", + (uint32_t) out_variables.backend, (uint32_t) backend); + } - // For catching input assignment missmatches, we asign the input and - // output - { - // Enter Resume scope, so we can track gradients - ADScopeContext ad_scope(drjit::ADScope::Resume, 0, nullptr, -1, - false); + recording = jit_freeze_stop(backend, out_variables.variables.data(), + out_variables.variables.size()); - out_variables.layout_index = 0; - jit_log(LogLevel::Debug, "Construct:"); - output = nb::borrow(out_variables.construct()); - // NOTE: temporarily disable this to not enqueue twice - // jit_log(LogLevel::Debug, "Assign:"); - // out_variables.assign(input); - out_variables.layout_index = 0; - } + jit_log(LogLevel::Info, "Recording done (n_outputs=%u)", + out_variables.variables.size()); - return output; + // For catching input assignment missmatches, we asign the input and + // output + { + // Enter Resume scope, so we can track 
gradients + ADScopeContext ad_scope(drjit::ADScope::Resume, 0, nullptr, -1, false); + + out_variables.layout_index = 0; + jit_log(LogLevel::Debug, "Construct:"); + output = nb::borrow(out_variables.construct()); + // NOTE: temporarily disable this to not enqueue twice + // jit_log(LogLevel::Debug, "Assign:"); + // out_variables.assign(input); + out_variables.layout_index = 0; } - /* - * Replays the recording. - * - * This constructs the output and re-assigns the input. - */ - nb::object replay(nb::callable func, FrozenFunction *frozen_func, - nb::list input, const FlatVariables &in_variables) { - ProfilerPhase profiler("replay"); - - jit_log(LogLevel::Info, "Replaying:"); - int dryrun_success; - { - ProfilerPhase profiler("dry run"); - dryrun_success = - jit_freeze_dry_run(recording, in_variables.variables.data()); - } - if (!dryrun_success) { - // Dry run has failed. Re-record the function. - jit_log(LogLevel::Warn, "re-recording"); - this->clear(); - try { - return this->record(func, frozen_func, input, in_variables); - } catch (nb::python_error &e) { - nb::raise_from( - e, PyExc_RuntimeError, - "replay(): error encountered while re-recording a " - "function (see above)."); - } catch (const std::exception &e) { - jit_freeze_abort(in_variables.backend); - nb::chain_error(PyExc_RuntimeError, "record(): %s", e.what()); - nb::raise_python_error(); - } - } else { - ProfilerPhase profiler("jit replay"); - nb::gil_scoped_release guard; - jit_freeze_replay(recording, in_variables.variables.data(), - out_variables.variables.data()); + return output; +} +/* + * Replays the recording. + * + * This constructs the output and re-assigns the input. + */ +nb::object FunctionRecording::replay(nb::callable func, + FrozenFunction *frozen_func, + nb::list input, + const FlatVariables &in_variables) { + ProfilerPhase profiler("replay"); + + jit_log(LogLevel::Info, "Replaying:"); + int dryrun_success; + { + ProfilerPhase profiler("dry run"); + dryrun_success = + jit_freeze_dry_run(recording, in_variables.variables.data()); + } + if (!dryrun_success) { + // Dry run has failed. Re-record the function. 
+ jit_log(LogLevel::Warn, "re-recording"); + this->clear(); + try { + return this->record(func, frozen_func, input, in_variables); + } catch (nb::python_error &e) { + nb::raise_from(e, PyExc_RuntimeError, + "replay(): error encountered while re-recording a " + "function (see above)."); + } catch (const std::exception &e) { + jit_freeze_abort(in_variables.backend); + + nb::chain_error(PyExc_RuntimeError, "record(): %s", e.what()); + nb::raise_python_error(); } - jit_log(LogLevel::Info, "Replaying done:"); + } else { + ProfilerPhase profiler("jit replay"); + nb::gil_scoped_release guard; + jit_freeze_replay(recording, in_variables.variables.data(), + out_variables.variables.data()); + } + jit_log(LogLevel::Info, "Replaying done:"); - // Construct Output variables - nb::object output; + // Construct Output variables + nb::object output; + { + // Enter Resume scope, so we can track gradients + ADScopeContext ad_scope(drjit::ADScope::Resume, 0, nullptr, -1, false); + out_variables.layout_index = 0; { - // Enter Resume scope, so we can track gradients - ADScopeContext ad_scope(drjit::ADScope::Resume, 0, nullptr, -1, - false); - out_variables.layout_index = 0; - { - ProfilerPhase profiler("construct output"); - output = nb::borrow(out_variables.construct()); - } - { - ProfilerPhase profiler("assign input"); - out_variables.assign_with_registry(input); - } + ProfilerPhase profiler("construct output"); + output = nb::borrow(out_variables.construct()); } + { + ProfilerPhase profiler("assign input"); + out_variables.assign_with_registry(input); + } + } - // out_variables is assigned by jit_record_replay, which transfers - // ownership to this array. Therefore, we have to drop the variables - // afterwards. - out_variables.release(); + // out_variables is assigned by jit_record_replay, which transfers + // ownership to this array. Therefore, we have to drop the variables + // afterwards. + out_variables.release(); - return output; - } -}; + return output; +} nb::object FrozenFunction::operator()(nb::args args, nb::kwargs kwargs) { nb::object result; diff --git a/src/python/freeze.h b/src/python/freeze.h index e3245585..ff19c6c0 100644 --- a/src/python/freeze.h +++ b/src/python/freeze.h @@ -11,8 +11,333 @@ #pragma once #include "common.h" +#include "drjit/autodiff.h" #include "functional" +#include +#include +#include + +using index64_vector = drjit::detail::index64_vector; + +enum class LayoutFlag : uint32_t { + SingletonArray = (1 << 0), + Unaligned = (1 << 1), + GradEnabled = (1 << 2), + Postponed = (1 << 3), + Registry = (1 << 4), +}; + +/// Stores information about python objects, such as their type, their number of +/// sub-elements or their field keys. This can be used to reconstruct a PyTree +/// from a flattened variable array. +struct Layout { + /// Nanobind type of the container/variable + nb::type_object type; + /// Number of members in this container. + /// Can be used to traverse the layout without knowing the type. + uint32_t num = 0; + /// Optional field identifiers of the container + /// for example: keys in dictionary + std::vector fields; + /// Optional drjit type of the variable + VarType vt = VarType::Void; + /// Optional evaluation state of the variable + VarState vs = VarState::Invalid; + uint32_t flags = 0; + /// The literal data + uint64_t literal = 0; + /// The index in the flat_variables array of this variable. + /// This can be used to determine aliasing. 
+    uint32_t index = 0;
+    /// We have to track the condition where two variables have the same size
+    /// during recording but not when replaying.
+    /// Therefore we de-duplicate the size.
+    uint32_t size_index = 0;
+
+    /// If a non-drjit type is passed as a function argument or result, we
+    /// simply cache it here.
+    /// TODO: possibly do the same for literals?
+    nb::object py_object = nb::none();
+
+    bool operator==(const Layout &rhs) const;
+};
+
+// Additional context required when traversing the inputs
+struct TraverseContext {
+    /// Set of postponed ad nodes, used to mark inputs to functions.
+    const tsl::robin_set<uint32_t> *postponed = nullptr;
+};
+
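The split between a layout vector (recorded in DFS order) and a flat value array, which the `Layout` struct above and the `FlatVariables` class below rely on, can be illustrated on a toy tree. The following self-contained sketch uses only the standard library; `Node`, `ToyLayout`, `flatten` and `unflatten` are illustrative names and are not part of this patch.

// Sketch only: a toy analogue of the traverse/construct round trip.
#include <cstdint>
#include <vector>

struct Node {
    uint32_t value = 0;         // leaf payload (stands in for a JIT variable index)
    std::vector<Node> children; // empty for leaves
};

struct ToyLayout {
    uint32_t num; // number of children (0 == leaf)
};

// DFS traversal: one ToyLayout entry per node, leaf values pushed into the
// flat array (analogous to FlatVariables::traverse()).
void flatten(const Node &n, std::vector<ToyLayout> &layout,
             std::vector<uint32_t> &flat) {
    layout.push_back({ (uint32_t) n.children.size() });
    if (n.children.empty())
        flat.push_back(n.value);
    for (const Node &c : n.children)
        flatten(c, layout, flat);
}

// Counterpart to flatten(): rebuild the tree from layout + flat values
// (analogous to FlatVariables::construct()).
Node unflatten(const std::vector<ToyLayout> &layout,
               const std::vector<uint32_t> &flat,
               size_t &layout_index, size_t &flat_index) {
    ToyLayout l = layout[layout_index++];
    Node n;
    if (l.num == 0) {
        n.value = flat[flat_index++];
    } else {
        n.children.reserve(l.num);
        for (uint32_t i = 0; i < l.num; ++i)
            n.children.push_back(unflatten(layout, flat, layout_index, flat_index));
    }
    return n;
}

The real implementation additionally records types, field names and variable metadata in each Layout entry, which is what lets the layout vector double as the cache key of a recording.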
+/**
+ * A flattened representation of the PyTree.
+ */
+struct FlatVariables {
+
+    // Index used to iterate over the variables/layouts when constructing
+    // Python objects
+    uint32_t layout_index = 0;
+
+    /// The flattened and de-duplicated variable indices of the input/output to
+    /// a frozen function
+    std::vector<uint32_t> variables;
+    /// Mapping from drjit variable index to index in flat variables
+    tsl::robin_map<uint32_t, uint32_t> index_to_slot;
+
+    /// We have to track the condition where two variables have the same size
+    /// during recording but not when replaying.
+    /// Therefore we construct equivalence classes of sizes.
+    /// This vector represents the different sizes encountered during
+    /// traversal. The algorithm used to "add" a size is the same as for adding
+    /// a variable index.
+    std::vector<uint32_t> sizes;
+    /// Mapping from the size to its index in the ``sizes`` vector.
+    tsl::robin_map<uint32_t, uint32_t> size_to_slot;
+
+    /// This saves information about the type, size and fields of pytree
+    /// objects. The information is stored in DFS order.
+    std::vector<Layout> layout;
+    JitBackend backend = JitBackend::None;
+
+    // Whether variables should be borrowed instead of stolen
+    bool borrow = true;
+
+    FlatVariables() {}
+    FlatVariables(bool borrow) : borrow(borrow) {}
+
+    void clear() {
+        this->layout_index = 0;
+        this->variables.clear();
+        this->index_to_slot.clear();
+        this->layout.clear();
+        this->backend = JitBackend::None;
+    }
+    void release() {
+        for (uint32_t &index : this->variables) {
+            jit_var_dec_ref(index);
+        }
+    }
+
+    /**
+     * Adds a variable to the flattened array, deduplicating it.
+     * This allows for checking for aliasing conditions, as aliasing inputs map
+     * to the same flat variable index.
+     */
+    uint32_t add_variable_index(uint32_t variable_index);
+
+    /**
+     * This function returns an index into the ``sizes`` vector, representing
+     * an equivalence class for the variable size. It uses a hash map and a
+     * vector to deduplicate sizes.
+     *
+     * This is necessary to catch cases where two variables had the same size
+     * when freezing a function but different sizes when replaying. In that
+     * case, a single kernel would be recorded that evaluates both variables;
+     * when replaying, however, two kernels would have to be launched, since
+     * the now differently sized variables can no longer be evaluated by the
+     * same kernel.
+     */
+    uint32_t add_size(uint32_t size);
+
+    /**
+     * Traverse the variable referenced by a jit index and add it to the flat
+     * variables. An optional Python type can be supplied if it is known.
+     */
+    void traverse_jit_index(uint32_t index, TraverseContext &ctx,
+                            nb::handle tp = nb::none());
+    /**
+     * Add an ad variable by its index. Both the value and gradient are added
+     * to the flattened variables. If the ad index has been marked as postponed
+     * in the \c TraverseContext.postponed field, we mark the resulting layout
+     * with that flag. This will cause the gradient edges to be propagated when
+     * assigning to the input. The function takes an optional Python type if
+     * it is known.
+     */
+    void traverse_ad_index(uint64_t index, TraverseContext &ctx,
+                           nb::handle tp = nb::none());
+
+    /**
+     * Wrapper around traverse_ad_index for a Python handle.
+     */
+    void traverse_ad_var(nb::handle h, TraverseContext &ctx);
+
+    /**
+     * Traverse a C++ tree using its `traverse_1_cb_ro` callback.
+     */
+    void traverse_cb(const drjit::TraversableBase *traversable,
+                     TraverseContext &ctx, nb::object type = nb::none());
+
+    /**
+     * Traverses a PyTree in DFS order and records its layout in the
+     * `layout` vector.
+     *
+     * When hitting a drjit primitive type, it calls the
+     * `traverse_dr_var` method, which adds its indices to the
+     * `flat_variables` vector. It also records metadata about the drjit
+     * variable in the layout. Therefore, the layout can be used as an
+     * identifier for the recording of the frozen function.
+     */
+    void traverse(nb::handle h, TraverseContext &ctx);
+
+    /**
+     * First traverses the PyTree, then the registry. This ensures that
+     * additional data to vcalls is tracked correctly.
+     */
+    void traverse_with_registry(nb::handle h, TraverseContext &ctx);
+
+    /**
+     * Construct a variable, given its layout.
+     * This is the counterpart to `traverse_jit_index`.
+     */
+    uint32_t construct_jit_index(const Layout &layout);
+
+    /**
+     * Construct/assign the variable index given a layout.
+     * This corresponds to `traverse_ad_index`.
+     *
+     * This function is also used for assignment to ad variables.
+     * If a `prev_index` is provided and it is an ad variable, the gradient
+     * and value of the flat variables will be applied to it, preserving the
+     * ad_index.
+     *
+     * It returns an owning reference.
+     */
+    uint64_t construct_ad_index(const Layout &layout, uint32_t shrink = 0,
+                                uint64_t prev_index = 0);
+
+    /**
+     * Construct an ad variable given its layout.
+     * This corresponds to `traverse_ad_var`.
+     */
+    nb::object construct_ad_var(const Layout &layout, uint32_t shrink = 0);
+
+    /**
+     * This is the counterpart to the traverse method, used to construct the
+     * output of a frozen function. Given a layout vector and flat_variables,
+     * it re-constructs the PyTree.
+     */
+    nb::object construct();
+
+    /**
+     * Assigns an ad variable.
+     * Corresponds to `traverse_ad_var`.
+     * This uses `construct_ad_index` to either construct a new ad variable or
+     * assign the value and gradient to an already existing one.
+     */
+    void assign_ad_var(Layout &layout, nb::handle dst);
+
+    /**
+     * Helper function, used to assign a callback variable.
+     *
+     * \param tmp
+     *     This vector is populated with the indices of variables that have
+     *     been constructed. It is required to release the references, since
+     *     the references created by `construct_ad_index` are owning and they
+     *     are borrowed after the callback returns.
+     */
+    uint64_t assign_cb_internal(uint64_t index, index64_vector &tmp);
+
+    /**
+     * Assigns variables using its `traverse_cb_rw` callback.
+     * This corresponds to `traverse_cb`.
+     */
+    void assign_cb(drjit::TraversableBase *traversable);
+
+    /**
+     * Assigns the flattened variables to an already existing PyTree.
+     * This is used when input variables have changed.
+     */
+    void assign(nb::handle dst);
+
+    /**
+     * First assigns the registry and then the PyTree.
+     * Corresponds to `traverse_with_registry`.
+     */
+    void assign_with_registry(nb::handle dst);
+};
+
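The deduplication scheme that `add_variable_index` and `add_size` describe can be sketched in isolation. The snippet below is a self-contained illustration using std::unordered_map instead of tsl::robin_map; `SlotTable` and its members are hypothetical names, not part of this patch.

// Sketch only: the "value -> slot" deduplication pattern described above.
#include <cstdint>
#include <unordered_map>
#include <vector>

struct SlotTable {
    std::vector<uint32_t> values;                   // deduplicated values (indices or sizes)
    std::unordered_map<uint32_t, uint32_t> to_slot; // value -> slot in `values`

    // Returns the slot of `value`, inserting it on first use. Two equal
    // inputs therefore map to the same slot, which is how aliasing inputs
    // and equally sized variables are detected.
    uint32_t add(uint32_t value) {
        auto it = to_slot.find(value);
        if (it != to_slot.end())
            return it->second;
        uint32_t slot = (uint32_t) values.size();
        values.push_back(value);
        to_slot.emplace(value, slot);
        return slot;
    }
};

In the patch, the same pattern appears twice: once for variable indices (index_to_slot) and once for sizes (size_to_slot), so that a Layout entry records only the equivalence class of a size rather than the size itself.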
+struct RecordingKey {
+    std::vector<Layout> layout;
+    uint32_t flags;
+
+    RecordingKey() {}
+    RecordingKey(std::vector<Layout> layout, uint32_t flags)
+        : layout(std::move(layout)), flags(flags) {}
+
+    bool operator==(const RecordingKey &rhs) const {
+        return this->layout == rhs.layout && this->flags == rhs.flags;
+    }
+};
+
+struct RecordingKeyHasher {
+    size_t operator()(const RecordingKey &key) const;
+};
+
+struct FrozenFunction;
+
+struct FunctionRecording {
+    Recording *recording = nullptr;
+    FlatVariables out_variables;
+
+    FunctionRecording() : out_variables(false) {}
+    FunctionRecording(const FunctionRecording &) = delete;
+    FunctionRecording &operator=(const FunctionRecording &) = delete;
+    FunctionRecording(FunctionRecording &&) = default;
+    FunctionRecording &operator=(FunctionRecording &&) = default;
+
+    ~FunctionRecording() {
+        if (this->recording) {
+            jit_freeze_destroy(this->recording);
+        }
+        this->recording = nullptr;
+    }
+
+    void clear() {
+        if (this->recording) {
+            jit_freeze_destroy(this->recording);
+        }
+        this->recording = nullptr;
+        this->out_variables = FlatVariables(false);
+    }
+
+    /*
+     * Record a function, given its Python input and flattened input.
+     */
+    nb::object record(nb::callable func, FrozenFunction *frozen_func,
+                      nb::list input, const FlatVariables &in_variables);
+    /*
+     * Replays the recording.
+     *
+     * This constructs the output and re-assigns the input.
+     */
+    nb::object replay(nb::callable func, FrozenFunction *frozen_func,
+                      nb::list input, const FlatVariables &in_variables);
+};
+
+using RecordingMap =
+    tsl::robin_map<RecordingKey, std::unique_ptr<FunctionRecording>,
+                   RecordingKeyHasher>;
+
+struct FrozenFunction {
+    nb::callable func;
+
+    RecordingMap recordings;
+    RecordingKey prev_key;
+    uint32_t recording_counter = 0;
+
+    FrozenFunction(nb::callable func) : func(func) {}
+    ~FrozenFunction() {}
+
+    FrozenFunction(const FrozenFunction &) = delete;
+    FrozenFunction &operator=(const FrozenFunction &) = delete;
+    FrozenFunction(FrozenFunction &&) = default;
+    FrozenFunction &operator=(FrozenFunction &&) = default;
+
+    uint32_t saved_recordings() { return this->recordings.size(); }
+
+    nb::object operator()(nb::args args, nb::kwargs kwargs);
+};
+
 struct FrozenFunction;
 
 extern FrozenFunction freeze(nb::callable);
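The caching scheme implied by RecordingKey, RecordingKeyHasher and FrozenFunction::operator() can be illustrated independently of Dr.Jit. The sketch below only assumes that recordings are cached under a hash of the flattened layout plus flags; `CacheKey`, `ToyRecording` and `lookup_or_record` are hypothetical names, and the mixer is a common boost-style hash_combine standing in for the one defined in freeze.cpp.

// Sketch only: a layout-keyed recording cache mirroring RecordingMap above.
#include <cstddef>
#include <cstdint>
#include <memory>
#include <unordered_map>
#include <vector>

inline void hash_combine(size_t &seed, size_t value) {
    // Boost-style mixing; stands in for the hash_combine used in freeze.cpp.
    seed ^= value + 0x9e3779b9 + (seed << 6) + (seed >> 2);
}

struct CacheKey {
    std::vector<uint32_t> layout_signature; // stands in for std::vector<Layout>
    uint32_t flags = 0;
    bool operator==(const CacheKey &rhs) const {
        return layout_signature == rhs.layout_signature && flags == rhs.flags;
    }
};

struct CacheKeyHasher {
    size_t operator()(const CacheKey &key) const {
        size_t hash = key.layout_signature.size();
        for (uint32_t v : key.layout_signature)
            hash_combine(hash, (size_t) v);
        hash_combine(hash, (size_t) key.flags);
        return hash;
    }
};

struct ToyRecording { /* stands in for the result of jit_freeze_stop() */ };

using ToyRecordingMap =
    std::unordered_map<CacheKey, std::unique_ptr<ToyRecording>, CacheKeyHasher>;

// On every call, the frozen function builds a key from the traversed inputs,
// records on a cache miss and replays on a hit -- the same control flow that
// FrozenFunction::operator() implements in freeze.cpp.
ToyRecording &lookup_or_record(ToyRecordingMap &map, const CacheKey &key) {
    auto it = map.find(key);
    if (it == map.end())
        it = map.emplace(key, std::make_unique<ToyRecording>()).first;
    return *it->second;
}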