Skip to content

Commit

Permalink
ENH: properly format views with complex data types
Browse files Browse the repository at this point in the history
  • Loading branch information
NaderAlAwar committed Oct 14, 2024
1 parent e0cc7d4 commit 7bb41b3
Showing 1 changed file with 20 additions and 4 deletions.
24 changes: 20 additions & 4 deletions include/views.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,21 @@ RetT get_extents(Tp &m, std::index_sequence<Idx...>) {
template <typename Up, size_t Idx, typename Tp>
constexpr auto get_stride(Tp &m);

template <typename Tp>
inline std::string get_format() {
return py::format_descriptor<Tp>::format();
}

template <>
inline std::string get_format<Kokkos::complex<float>>() {
return py::format_descriptor<std::complex<float>>::format();
}

template <>
inline std::string get_format<Kokkos::complex<double>>() {
return py::format_descriptor<std::complex<double>>::format();
}

template <typename Up, typename Tp, size_t... Idx,
typename RetT = std::array<size_t, sizeof...(Idx)>>
RetT get_strides(Tp &m, std::index_sequence<Idx...>) {
Expand Down Expand Up @@ -322,12 +337,13 @@ void generate_view(py::module &_mod, const std::string &_name,
_view.def_buffer([_ndim](ViewT &m) -> py::buffer_info {
auto _extents = get_extents(m, std::make_index_sequence<DimIdx + 1>{});
auto _strides = get_stride<Tp>(m, std::make_index_sequence<DimIdx + 1>{});
auto _format = get_format<Tp>();
return py::buffer_info(m.data(), // Pointer to buffer
sizeof(Tp), // Size of one scalar
py::format_descriptor<Tp>::format(), // Descriptor
_ndim, // Number of dimensions
_extents, // Buffer dimensions
_strides // Strides (in bytes) for each index
_format, // Descriptor
_ndim, // Number of dimensions
_extents, // Buffer dimensions
_strides // Strides (in bytes) for each index
);
});

Expand Down

0 comments on commit 7bb41b3

Please sign in to comment.