// Copyright 2022 The IREE Authors // // Licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception #include "./invoke.h" #include #include #include "./hal.h" #include "./vm.h" #include "iree/base/api.h" #include "iree/hal/api.h" #include "iree/modules/hal/module.h" #include "iree/vm/api.h" namespace iree { namespace python { namespace { class InvokeContext { public: InvokeContext(HalDevice &device) : device_(device) {} HalDevice &device() { return device_; } HalAllocator allocator() { // TODO: Unfortunate that we inc ref here but that is how our object model // is set up. return HalAllocator::BorrowFromRawPtr(device().allocator()); } private: HalDevice device_; }; using PackCallback = std::function; class InvokeStatics { public: ~InvokeStatics() { for (auto it : py_type_to_pack_callbacks_) { py::handle(it.first).dec_ref(); } } py::str kNamedTag = py::str("named"); py::str kSlistTag = py::str("slist"); py::str kStupleTag = py::str("stuple"); py::str kSdictTag = py::str("sdict"); py::int_ kZero = py::int_(0); py::int_ kOne = py::int_(1); py::int_ kTwo = py::int_(2); py::str kAsArray = py::str("asarray"); py::str kMapDtypeToElementTypeAttr = py::str("map_dtype_to_element_type"); py::str kContiguousArg = py::str("C"); py::str kArrayProtocolAttr = py::str("__array__"); py::str kDtypeAttr = py::str("dtype"); // Primitive type names. py::str kF32 = py::str("f32"); py::str kF64 = py::str("f64"); py::str kI1 = py::str("i1"); py::str kI8 = py::str("i8"); py::str kI16 = py::str("i16"); py::str kI32 = py::str("i32"); py::str kI64 = py::str("i64"); // Compound types names. py::str kNdarray = py::str("ndarray"); // Attribute names. py::str kAttrBufferView = py::str("_buffer_view"); // Module 'numpy'. py::module_ &numpy_module() { return numpy_module_; } py::object &runtime_module() { if (!runtime_module_) { runtime_module_ = py::module_::import_("iree.runtime"); } return *runtime_module_; } py::module_ &array_interop_module() { if (!array_interop_module_) { array_interop_module_ = py::module_::import_("iree.runtime.array_interop"); } return *array_interop_module_; } py::object &device_array_type() { if (!device_array_type_) { device_array_type_ = runtime_module().attr("DeviceArray"); } return *device_array_type_; } py::type_object &hal_buffer_view_type() { return hal_buffer_view_type_; } py::object MapElementAbiTypeToDtype(py::object &element_abi_type) { try { return abi_type_to_dtype_[element_abi_type]; } catch (std::exception &) { std::string msg("could not map abi type "); msg.append(py::cast(py::repr(element_abi_type))); msg.append(" to numpy dtype"); throw std::invalid_argument(std::move(msg)); } } enum iree_hal_element_types_t MapDtypeToElementType(py::object dtype) { // TODO: Consider porting this from a py func to C++ as it can be on // the critical path. try { py::object element_type = array_interop_module().attr(kMapDtypeToElementTypeAttr)(dtype); if (element_type.is_none()) { throw std::invalid_argument("mapping not found"); } return py::cast(element_type); } catch (std::exception &e) { std::string msg("could not map dtype "); msg.append(py::cast(py::repr(dtype))); msg.append(" to element type: "); msg.append(e.what()); throw std::invalid_argument(std::move(msg)); } } PackCallback AbiTypeToPackCallback(py::handle desc) { return AbiTypeToPackCallback( std::move(desc), /*desc_is_list=*/py::isinstance(desc)); } // Given an ABI desc, return a callback that can pack a corresponding py // value into a list. For efficiency, the caller must specify whether the // desc is a list (this check already needs to be done typically so // passed in). PackCallback AbiTypeToPackCallback(py::handle desc, bool desc_is_list) { // Switch based on descriptor type. if (desc_is_list) { // Compound type. py::object compound_type = desc[kZero]; if (compound_type.equal(kNdarray)) { // Has format: // ["ndarray", "f32", dim0, dim1, ...] // Extract static information about the target. std::vector abi_shape(py::len(desc) - 2); for (size_t i = 0, e = abi_shape.size(); i < e; ++i) { py::handle dim = desc[py::int_(i + 2)]; abi_shape[i] = dim.is_none() ? -1 : py::cast(dim); } // Map abi element type to dtype. py::object abi_type = desc[kOne]; py::object target_dtype = MapElementAbiTypeToDtype(abi_type); auto hal_element_type = MapDtypeToElementType(target_dtype); return [this, target_dtype = std::move(target_dtype), hal_element_type, abi_shape = std::move(abi_shape)](InvokeContext &c, iree_vm_list_t *list, py::handle py_value) { IREE_TRACE_SCOPE_NAMED("ArgumentPacker::ReflectionNdarray"); HalBufferView *bv = nullptr; py::object retained_bv; if (is_instance_of_type_object(py_value, device_array_type())) { // Short-circuit: If a DeviceArray is provided, assume it is // correct. IREE_TRACE_SCOPE_NAMED("PackDeviceArray"); bv = py::cast(py_value.attr(kAttrBufferView)); } else if (is_instance_of_type_object(py_value, hal_buffer_view_type())) { // Short-circuit: If a HalBufferView is provided directly. IREE_TRACE_SCOPE_NAMED("PackBufferView"); bv = py::cast(py_value); } else { // Fall back to the array protocol to generate a host side // array and then convert that. IREE_TRACE_SCOPE_NAMED("PackHostArray"); py::object host_array; try { host_array = numpy_module().attr(kAsArray)(py_value, target_dtype, kContiguousArg); } catch (std::exception &e) { std::string msg("could not convert value to numpy array: dtype="); msg.append(py::cast(py::repr(target_dtype))); msg.append(", error='"); msg.append(e.what()); msg.append("', value="); msg.append(py::cast(py::repr(py_value))); throw std::invalid_argument(std::move(msg)); } retained_bv = c.allocator().AllocateBufferCopy( IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL, IREE_HAL_BUFFER_USAGE_DEFAULT | IREE_HAL_BUFFER_USAGE_MAPPING, c.device(), host_array, hal_element_type); bv = py::cast(retained_bv); } // TODO: Add some shape verification. Not strictly necessary as the VM // will check, but may make error reporting nicer. // TODO: It is theoretically possible to enqueue further conversions // on the device, but for now we require things to line up closely. // TODO: If adding further manipulation here, please make this common // with the generic access case. iree_vm_ref_t buffer_view_ref = iree_hal_buffer_view_retain_ref(bv->raw_ptr()); CheckApiStatus(iree_vm_list_push_ref_move(list, &buffer_view_ref), "could not push buffer view to list"); }; } else if (compound_type.equal(kSlistTag) || compound_type.equal(kStupleTag)) { // Tuple/list extraction. // When decoding a list or tuple, the desc object is like: // ['slist', [...value_type_0...], ...] // Where the type is either 'slist' or 'stuple'. std::vector sub_packers(py::len(desc) - 1); for (size_t i = 0; i < sub_packers.size(); i++) { sub_packers[i] = AbiTypeToPackCallback(desc[py::int_(i + 1)]); } return [sub_packers = std::move(sub_packers)](InvokeContext &c, iree_vm_list_t *list, py::handle py_value) { if (py::len(py_value) != sub_packers.size()) { std::string msg("expected a sequence with "); msg.append(std::to_string(sub_packers.size())); msg.append(" values. got: "); msg.append(py::cast(py::repr(py_value))); throw std::invalid_argument(std::move(msg)); } VmVariantList item_list = VmVariantList::Create(sub_packers.size()); for (size_t i = 0; i < sub_packers.size(); ++i) { py::object item_py_value; try { item_py_value = py_value[py::int_(i)]; } catch (std::exception &e) { std::string msg("could not get item "); msg.append(std::to_string(i)); msg.append(" from: "); msg.append(py::cast(py::repr(py_value))); msg.append(": "); msg.append(e.what()); throw std::invalid_argument(std::move(msg)); } sub_packers[i](c, item_list.raw_ptr(), item_py_value); } // Push the sub list. iree_vm_ref_t retained = iree_vm_list_move_ref(item_list.steal_raw_ptr()); iree_vm_list_push_ref_move(list, &retained); }; } else if (compound_type.equal(kSdictTag)) { // Dict extraction. // The descriptor for an sdict is like: // ['sdict', ['key1', value1], ...] std::vector> sub_packers( py::len(desc) - 1); for (size_t i = 0; i < sub_packers.size(); i++) { py::object sub_desc = desc[py::int_(i + 1)]; py::object key = sub_desc[kZero]; py::object value_desc = sub_desc[kOne]; sub_packers[i] = std::make_pair(std::move(key), AbiTypeToPackCallback(value_desc)); } return [sub_packers = std::move(sub_packers)](InvokeContext &c, iree_vm_list_t *list, py::handle py_value) { if (py::len(py_value) != sub_packers.size()) { std::string msg("expected a dict with "); msg.append(std::to_string(sub_packers.size())); msg.append(" values. got: "); msg.append(py::cast(py::repr(py_value))); throw std::invalid_argument(std::move(msg)); } VmVariantList item_list = VmVariantList::Create(sub_packers.size()); for (size_t i = 0; i < sub_packers.size(); ++i) { py::object item_py_value; try { item_py_value = py_value[sub_packers[i].first]; } catch (std::exception &e) { std::string msg("could not get item "); msg.append(py::cast(py::repr(sub_packers[i].first))); msg.append(" from: "); msg.append(py::cast(py::repr(py_value))); msg.append(": "); msg.append(e.what()); throw std::invalid_argument(std::move(msg)); } sub_packers[i].second(c, item_list.raw_ptr(), item_py_value); } // Push the sub list. iree_vm_ref_t retained = iree_vm_list_move_ref(item_list.steal_raw_ptr()); iree_vm_list_push_ref_move(list, &retained); }; } else { std::string message("Unrecognized reflection compound type: "); message.append(py::cast(compound_type)); throw std::invalid_argument(message); } } else { // Primitive type. py::str prim_type = py::cast(desc); if (prim_type.equal(kF32)) { // f32 return [](InvokeContext &c, iree_vm_list_t *list, py::handle py_value) { iree_vm_value_t vm_value = iree_vm_value_make_f32(py::cast(py_value)); CheckApiStatus(iree_vm_list_push_value(list, &vm_value), "could not append value"); }; } else if (prim_type.equal(kF64)) { // f64 return [](InvokeContext &c, iree_vm_list_t *list, py::handle py_value) { iree_vm_value_t vm_value = iree_vm_value_make_f64(py::cast(py_value)); CheckApiStatus(iree_vm_list_push_value(list, &vm_value), "could not append value"); }; } else if (prim_type.equal(kI32)) { // i32. return [](InvokeContext &c, iree_vm_list_t *list, py::handle py_value) { iree_vm_value_t vm_value = iree_vm_value_make_i32(py::cast(py_value)); CheckApiStatus(iree_vm_list_push_value(list, &vm_value), "could not append value"); }; } else if (prim_type.equal(kI64)) { // i64. return [](InvokeContext &c, iree_vm_list_t *list, py::handle py_value) { iree_vm_value_t vm_value = iree_vm_value_make_i64(py::cast(py_value)); CheckApiStatus(iree_vm_list_push_value(list, &vm_value), "could not append value"); }; } else if (prim_type.equal(kI8)) { // i8. return [](InvokeContext &c, iree_vm_list_t *list, py::handle py_value) { iree_vm_value_t vm_value = iree_vm_value_make_i8(py::cast(py_value)); CheckApiStatus(iree_vm_list_push_value(list, &vm_value), "could not append value"); }; } else if (prim_type.equal(kI16)) { // i16. return [](InvokeContext &c, iree_vm_list_t *list, py::handle py_value) { iree_vm_value_t vm_value = iree_vm_value_make_i16(py::cast(py_value)); CheckApiStatus(iree_vm_list_push_value(list, &vm_value), "could not append value"); }; } else { std::string message("Unrecognized reflection primitive type: "); message.append(py::cast(prim_type)); throw std::invalid_argument(message); } } } PackCallback GetGenericPackCallbackFor(py::handle arg) { PopulatePyTypeToPackCallbacks(); py::handle clazz = arg.type(); auto found_it = py_type_to_pack_callbacks_.find(clazz.ptr()); if (found_it == py_type_to_pack_callbacks_.end()) { // Probe to see if we have a host array. if (py::hasattr(arg, kArrayProtocolAttr)) { return GetGenericPackCallbackForNdarray(); } return {}; } return found_it->second; } private: PackCallback GetGenericPackCallbackForNdarray() { return [this](InvokeContext &c, iree_vm_list_t *list, py::handle py_value) { IREE_TRACE_SCOPE_NAMED("ArgumentPacker::GenericNdarray"); py::object host_array; try { host_array = numpy_module().attr(kAsArray)( py_value, /*dtype=*/py::none(), kContiguousArg); } catch (std::exception &e) { std::string msg("could not convert value to numpy array: "); msg.append("error='"); msg.append(e.what()); msg.append("', value="); msg.append(py::cast(py::repr(py_value))); throw std::invalid_argument(std::move(msg)); } auto hal_element_type = MapDtypeToElementType(host_array.attr(kDtypeAttr)); // Put it on the device. py::object retained_bv = c.allocator().AllocateBufferCopy( IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL, IREE_HAL_BUFFER_USAGE_DEFAULT | IREE_HAL_BUFFER_USAGE_MAPPING, c.device(), host_array, hal_element_type); HalBufferView *bv = py::cast(retained_bv); // TODO: If adding further manipulation here, please make this common // with the reflection access case. iree_vm_ref_t buffer_view_ref = iree_hal_buffer_view_retain_ref(bv->raw_ptr()); CheckApiStatus(iree_vm_list_push_ref_move(list, &buffer_view_ref), "could not append value"); }; } void PopulatePyTypeToPackCallbacks() { if (!py_type_to_pack_callbacks_.empty()) return; // We only care about int and double in the numeric hierarchy. Since Python // has no further refinement of these, just treat them as vm 64 bit int and // floats and let the VM take care of it. There isn't much else we can do. AddPackCallback( py::cast(1).type(), [](InvokeContext &c, iree_vm_list_t *list, py::handle py_value) { iree_vm_value_t vm_value = iree_vm_value_make_i64(py::cast(py_value)); CheckApiStatus(iree_vm_list_push_value(list, &vm_value), "could not append value"); }); AddPackCallback( py::cast(1.0).type(), [](InvokeContext &c, iree_vm_list_t *list, py::handle py_value) { iree_vm_value_t vm_value = iree_vm_value_make_f64(py::cast(py_value)); CheckApiStatus(iree_vm_list_push_value(list, &vm_value), "could not append value"); }); // List/tuple. auto sequence_callback = [this](InvokeContext &c, iree_vm_list_t *list, py::handle py_value) { auto py_seq = py::cast(py_value); VmVariantList item_list = VmVariantList::Create(py::len(py_seq)); for (py::handle py_item : py_seq) { PackCallback sub_packer = GetGenericPackCallbackFor(py_item); if (!sub_packer) { std::string message("could not convert python value to VM: "); message.append(py::cast(py::repr(py_item))); throw std::invalid_argument(std::move(message)); } sub_packer(c, item_list.raw_ptr(), py_item); } // Push the sub list. iree_vm_ref_t retained = iree_vm_list_move_ref(item_list.steal_raw_ptr()); iree_vm_list_push_ref_move(list, &retained); }; AddPackCallback((py::list{}).type(), sequence_callback); AddPackCallback((create_empty_tuple()).type(), sequence_callback); // Dict. auto dict_callback = [this](InvokeContext &c, iree_vm_list_t *list, py::handle py_value) { // Gets all dict items and sorts (by key). auto py_dict = py::cast(py_value); py::list py_keys; for (std::pair it : py_dict) { py_keys.append(it.first); } py_keys.attr("sort")(); VmVariantList item_list = VmVariantList::Create(py_keys.size()); for (auto py_key : py_keys) { py::object py_item = py_dict[py_key]; PackCallback sub_packer = GetGenericPackCallbackFor(py_item); if (!sub_packer) { std::string message("could not convert python value to VM: "); message.append(py::cast(py::repr(py_item))); throw std::invalid_argument(std::move(message)); } sub_packer(c, item_list.raw_ptr(), py_item); } // Push the sub list. iree_vm_ref_t retained = iree_vm_list_move_ref(item_list.steal_raw_ptr()); iree_vm_list_push_ref_move(list, &retained); }; AddPackCallback((py::dict{}).type(), dict_callback); // HalBufferView. AddPackCallback( py::type(), [](InvokeContext &c, iree_vm_list_t *list, py::handle py_value) { HalBufferView *bv = py::cast(py_value); iree_vm_ref_t buffer_view_ref = iree_hal_buffer_view_retain_ref(bv->raw_ptr()); CheckApiStatus(iree_vm_list_push_ref_move(list, &buffer_view_ref), "could not append value"); }); // DeviceArray. AddPackCallback( device_array_type(), [this](InvokeContext &c, iree_vm_list_t *list, py::handle py_value) { HalBufferView *bv = py::cast(py_value.attr(kAttrBufferView)); iree_vm_ref_t buffer_view_ref = iree_hal_buffer_view_retain_ref(bv->raw_ptr()); CheckApiStatus(iree_vm_list_push_ref_move(list, &buffer_view_ref), "could not append value"); }); } void AddPackCallback(py::handle t, PackCallback pcb) { assert(py_type_to_pack_callbacks_.count(t.ptr()) == 0 && "duplicate types"); t.inc_ref(); py_type_to_pack_callbacks_.insert(std::make_pair(t.ptr(), std::move(pcb))); } py::dict BuildAbiTypeToDtype() { auto d = py::dict(); d[kF32] = numpy_module().attr("float32"); d[kF64] = numpy_module().attr("float64"); d[kI1] = numpy_module().attr("bool_"); d[kI8] = numpy_module().attr("int8"); d[kI16] = numpy_module().attr("int16"); d[kI64] = numpy_module().attr("int64"); d[kI32] = numpy_module().attr("int32"); return d; } // Cached modules and types. Those that involve recursive lookup within // our top level module, we defer. Those outside, we cache at creation. py::module_ numpy_module_ = py::module_::import_("numpy"); std::optional runtime_module_; std::optional array_interop_module_; std::optional device_array_type_; py::type_object hal_buffer_view_type_ = py::cast(py::type()); // Maps Python type to a PackCallback that can generically code it. // This will have inc_ref() called on them when added. std::unordered_map py_type_to_pack_callbacks_; // Dict of str (ABI dtype like 'f32') to numpy dtype. py::dict abi_type_to_dtype_ = BuildAbiTypeToDtype(); }; /// Object that can pack Python arguments into a VM List for a specific /// function. class ArgumentPacker { public: ArgumentPacker(InvokeStatics &statics, std::optional arg_descs) : statics_(statics) { IREE_TRACE_SCOPE_NAMED("ArgumentPacker::Init"); if (!arg_descs) { dynamic_dispatch_ = true; } else { // Reflection dispatch. for (py::handle desc : *arg_descs) { int arg_index = flat_arg_packers_.size(); std::optional kwarg_name; py::object retained_sub_desc; bool desc_is_list = py::isinstance(desc); // Check if named. // ["named", "kwarg_name", sub_desc] // If found, then we set kwarg_name and reset desc to the sub_desc. if (desc_is_list) { py::object maybe_named_field = desc[statics.kZero]; if (maybe_named_field.equal(statics.kNamedTag)) { py::object name_field = desc[statics.kOne]; retained_sub_desc = desc[statics.kTwo]; kwarg_name = py::cast(name_field); desc = retained_sub_desc; desc_is_list = py::isinstance(desc); kwarg_to_index_[name_field] = arg_index; } } if (!kwarg_name) { pos_only_arg_count_ += 1; } flat_arg_packers_.push_back( statics.AbiTypeToPackCallback(desc, desc_is_list)); } } } /// Packs positional/kw arguments into a suitable VmVariantList and returns /// it. VmVariantList Pack(InvokeContext &invoke_context, py::sequence pos_args, py::dict kw_args) { // Dynamic dispatch. if (dynamic_dispatch_) { IREE_TRACE_SCOPE_NAMED("ArgumentPacker::PackDynamic"); if (kw_args.size() != 0) { throw std::invalid_argument( "kwargs not supported for dynamic dispatch functions"); } VmVariantList arg_list = VmVariantList::Create(py::len(pos_args)); for (py::handle py_arg : pos_args) { PackCallback packer = statics_.GetGenericPackCallbackFor(py_arg); if (!packer) { std::string message("could not convert python value to VM: "); message.append(py::cast(py::repr(py_arg))); throw std::invalid_argument(std::move(message)); } // TODO: Better error handling by catching the exception and // reporting which arg has a problem. packer(invoke_context, arg_list.raw_ptr(), py_arg); } return arg_list; } else { IREE_TRACE_SCOPE_NAMED("ArgumentPacker::PackReflection"); // Reflection based dispatch. std::vector py_args(flat_arg_packers_.size()); auto pos_args_size = py::len(pos_args); if (pos_args_size > pos_only_arg_count_) { std::string message("mismatched call arity: expected "); message.append(std::to_string(pos_only_arg_count_)); message.append(" got "); message.append(std::to_string(pos_args_size)); throw std::invalid_argument(std::move(message)); } // Positional args. size_t pos_index = 0; for (py::handle py_arg : pos_args) { py_args[pos_index++] = py_arg; } // Keyword args. for (auto it : kw_args) { int found_index; try { found_index = py::cast(kwarg_to_index_[it.first]); } catch (std::exception &) { std::string message("specified kwarg '"); message.append(py::cast(it.first)); message.append("' is unknown"); throw std::invalid_argument(std::move(message)); } if (py_args[found_index]) { std::string message( "mismatched call arity: duplicate keyword argument '"); message.append(py::cast(it.first)); message.append("'"); throw std::invalid_argument(std::move(message)); } py_args[found_index] = it.second; } // Now check to see that all args are set. for (size_t i = 0; i < py_args.size(); ++i) { if (!py_args[i]) { std::string message( "mismatched call arity: expected a value for argument "); message.append(std::to_string(i)); throw std::invalid_argument(std::move(message)); } } // Start packing into the list. VmVariantList arg_list = VmVariantList::Create(flat_arg_packers_.size()); for (size_t i = 0; i < py_args.size(); ++i) { // TODO: Better error handling by catching the exception and // reporting which arg has a problem. flat_arg_packers_[i](invoke_context, arg_list.raw_ptr(), py_args[i]); } return arg_list; } } private: InvokeStatics &statics_; int pos_only_arg_count_ = 0; // Dictionary of py::str -> py::int_ mapping kwarg names to position in // the argument list. We store this as a py::dict because it is optimized // for py::str lookup. py::dict kwarg_to_index_; std::vector flat_arg_packers_; // If true, then there is no dispatch metadata and we process fully // dynamically. bool dynamic_dispatch_ = false; }; } // namespace void SetupInvokeBindings(nanobind::module_ &m) { py::class_(m, "_InvokeStatics"); py::class_(m, "InvokeContext").def(py::init()); py::class_(m, "ArgumentPacker") .def(py::init>(), py::arg("statics"), py::arg("arg_descs") = py::none()) .def("pack", &ArgumentPacker::Pack); m.attr("_invoke_statics") = py::cast(InvokeStatics()); } } // namespace python } // namespace iree