/* * This code is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation; either version 2 of the License, or * (at your option) any later version. * * This code is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this code; if not, write to the Free Software * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA */ /* Copyright (C) 2020-2023 Max-Planck-Society Author: Martin Reinecke */ #ifndef DUCC0_PYBIND_UTILS_H #define DUCC0_PYBIND_UTILS_H #include #include #include #include #include "ducc0/infra/error_handling.h" #include "ducc0/infra/mav.h" #include "ducc0/infra/misc_utils.h" namespace ducc0 { namespace detail_pybind { using shape_t=fmav_info::shape_t; using stride_t=fmav_info::stride_t; namespace py = pybind11; py::object normalizeDtype(const py::object &dtype) { static py::object converter = py::module_::import("numpy").attr("dtype"); return converter(dtype); } bool isPyarr(const py::object &obj) { return py::isinstance(obj); } template bool isPyarr(const py::object &obj) { return py::isinstance>(obj); } template py::array_t toPyarr(const py::object &obj) { auto tmp = obj.cast>(); MR_assert(tmp.is(obj), "error during array conversion"); return tmp; } shape_t copy_shape(const py::array &arr) { shape_t res(size_t(arr.ndim())); for (size_t i=0; i stride_t copy_strides(const py::array &arr, bool rw) { stride_t res(size_t(arr.ndim())); constexpr auto st = ptrdiff_t(sizeof(T)); for (size_t i=0; i std::array copy_fixshape(const py::array &arr) { MR_assert(size_t(arr.ndim())==ndim, "incorrect number of dimensions"); std::array res; for (size_t i=0; i std::array copy_fixstrides(const py::array &arr, bool rw) { MR_assert(size_t(arr.ndim())==ndim, "incorrect number of dimensions"); std::array res; constexpr auto st = ptrdiff_t(sizeof(T)); for (size_t i=0; i cfmav to_cfmav(const py::object &obj) { auto arr = toPyarr(obj); return cfmav(reinterpret_cast(arr.data()), copy_shape(arr), copy_strides(arr, false)); } template vfmav to_vfmav(const py::object &obj) { auto arr = toPyarr(obj); return vfmav(reinterpret_cast(arr.mutable_data()), copy_shape(arr), copy_strides(arr, true)); } template cmav to_cmav(const py::array &obj) { auto arr = toPyarr(obj); return cmav(reinterpret_cast(arr.data()), copy_fixshape(arr), copy_fixstrides(arr, false)); } template cmav to_cmav_with_optional_leading_dimensions(const py::array &obj) { auto tmp = to_cfmav(obj); MR_assert(tmp.ndim()<=ndim, "array has too many dimensions"); typename cmav::shape_t newshape; typename cmav::stride_t newstride; size_t add=ndim-tmp.ndim(); for (size_t i=0; i(tmp.data(), newshape, newstride); } template vmav to_vmav(const py::array &obj) { auto arr = toPyarr(obj); return vmav(reinterpret_cast(arr.mutable_data()), copy_fixshape(arr), copy_fixstrides(arr, true)); } template vmav to_vmav_with_optional_leading_dimensions(const py::array &obj) { auto tmp = to_vfmav(obj); MR_assert(tmp.ndim()<=ndim, "array has too many dimensions"); typename vmav::shape_t newshape; typename vmav::stride_t newstride; size_t add=ndim-tmp.ndim(); for (size_t i=0; i(tmp.data(), newshape, newstride); } template array to_array(const py::object &obj) { auto vec = py::cast>(obj); MR_assert(vec.size()==len, "unexpected number of elements"); array res; for (size_t i=0;i void zero_Pyarr(py::array_t &arr, size_t nthreads=1) { auto arr2 = to_vfmav(arr); mav_apply([](T &v){ v=T(0); }, nthreads, arr2); } template py::array_t make_Pyarr(const shape_t &dims, bool zero=false) { auto res=py::array_t(dims); if (zero) zero_Pyarr(res); return res; } template py::array_t make_Pyarr (const std::array &dims, bool zero=false) { auto res=py::array_t(shape_t(dims.begin(), dims.end())); if (zero) zero_Pyarr(res); return res; } template py::array_t make_noncritical_Pyarr(const shape_t &shape) { auto ndim = shape.size(); if (ndim==1) return make_Pyarr(shape); auto shape2 = noncritical_shape(shape, sizeof(T)); py::array_t tarr(shape2); py::list slices; for (size_t i=0; i sub(tarr[py::tuple(slices)]); return sub; } template py::array_t get_Pyarr(py::object &arr_, size_t ndims) { MR_assert(isPyarr(arr_), "incorrect data type"); auto tmp = toPyarr(arr_); MR_assert(ndims==size_t(tmp.ndim()), "dimension mismatch"); return tmp; } template py::array_t get_optional_Pyarr(py::object &arr_, const shape_t &dims, bool zero_if_new=false) { if (arr_.is_none()) return make_Pyarr(dims, zero_if_new); MR_assert(isPyarr(arr_), "incorrect data type"); auto tmp = toPyarr(arr_); MR_assert(dims.size()==size_t(tmp.ndim()), "dimension mismatch"); for (size_t i=0; i py::array_t get_optional_Pyarr_minshape (py::object &arr_, const shape_t &dims) { if (arr_.is_none()) return make_Pyarr(dims); MR_assert(isPyarr(arr_), "incorrect data type"); auto tmp = toPyarr(arr_); MR_assert(dims.size()==size_t(tmp.ndim()), "dimension mismatch"); for (size_t i=0; i py::array_t get_optional_const_Pyarr( const py::object &arr_, const shape_t &dims) { if (arr_.is_none()) return py::array_t(shape_t(dims.size(), 0)); MR_assert(isPyarr(arr_), "incorrect data type"); auto tmp = toPyarr(arr_); MR_assert(dims.size()==size_t(tmp.ndim()), "dimension mismatch"); for (size_t i=0; i bool isDtype(const py::object &dtype) { static const auto tmp = make_Pyarr({}).dtype(); return tmp.is(dtype); } } using detail_pybind::isPyarr; using detail_pybind::make_Pyarr; using detail_pybind::make_noncritical_Pyarr; using detail_pybind::get_Pyarr; using detail_pybind::get_optional_Pyarr; using detail_pybind::get_optional_Pyarr_minshape; using detail_pybind::get_optional_const_Pyarr; using detail_pybind::to_cfmav; using detail_pybind::to_vfmav; using detail_pybind::to_cmav; using detail_pybind::to_cmav_with_optional_leading_dimensions; using detail_pybind::to_vmav; using detail_pybind::to_vmav_with_optional_leading_dimensions; using detail_pybind::to_array; using detail_pybind::normalizeDtype; using detail_pybind::isDtype; } #endif