// Copyright 2019 The IREE Authors // // Licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception #ifndef IREE_BINDINGS_PYTHON_IREE_RT_VM_H_ #define IREE_BINDINGS_PYTHON_IREE_RT_VM_H_ #include #include "./binding.h" #include "./status_utils.h" #include "iree/base/api.h" #include "iree/vm/api.h" #include "iree/vm/bytecode/module.h" namespace iree { namespace python { class FunctionAbi; //------------------------------------------------------------------------------ // Retain/release bindings //------------------------------------------------------------------------------ template <> struct ApiPtrAdapter { static void Retain(iree_vm_buffer_t* b) { iree_vm_buffer_retain(b); } static void Release(iree_vm_buffer_t* b) { iree_vm_buffer_release(b); } }; template <> struct ApiPtrAdapter { static void Retain(iree_vm_instance_t* b) { iree_vm_instance_retain(b); } static void Release(iree_vm_instance_t* b) { iree_vm_instance_release(b); } }; template <> struct ApiPtrAdapter { static void Retain(iree_vm_context_t* b) { iree_vm_context_retain(b); } static void Release(iree_vm_context_t* b) { iree_vm_context_release(b); } }; template <> struct ApiPtrAdapter { static void Retain(iree_vm_list_t* b) { iree_vm_list_retain(b); } static void Release(iree_vm_list_t* b) { iree_vm_list_release(b); } }; template <> struct ApiPtrAdapter { static void Retain(iree_vm_module_t* b) { iree_vm_module_retain(b); } static void Release(iree_vm_module_t* b) { iree_vm_module_release(b); } }; template <> struct ApiPtrAdapter { static void Retain(iree_vm_invocation_t* b) { iree_vm_invocation_retain(b); } static void Release(iree_vm_invocation_t* b) { iree_vm_invocation_release(b); } }; template <> struct ApiPtrAdapter { static void Retain(iree_vm_ref_t* b) { iree_vm_ref_t out_ref; std::memset(&out_ref, 0, sizeof(out_ref)); iree_vm_ref_retain(b, &out_ref); } static void Release(iree_vm_ref_t* b) { iree_vm_ref_release(b); } }; //------------------------------------------------------------------------------ // VmBuffer //------------------------------------------------------------------------------ class VmBuffer : public ApiRefCounted {}; //------------------------------------------------------------------------------ // VmVariantList // TODO: Rename to VmList //------------------------------------------------------------------------------ class VmVariantList : public ApiRefCounted { public: static VmVariantList Create(iree_host_size_t capacity) { iree_vm_list_t* list; CheckApiStatus( iree_vm_list_create(iree_vm_make_undefined_type_def(), capacity, iree_allocator_system(), &list), "Error allocating variant list"); return VmVariantList::StealFromRawPtr(list); } iree_host_size_t size() const { return iree_vm_list_size(raw_ptr()); } void AppendNullRef() { iree_vm_ref_t null_ref = {0}; CheckApiStatus(iree_vm_list_push_ref_move(raw_ptr(), &null_ref), "Error appending to list"); } std::string DebugString() const; void PushFloat(double fvalue); void PushInt(int64_t ivalue); void PushList(VmVariantList& other); void PushRef(py::handle ref_or_object); py::object GetAsList(int index); py::object GetAsRef(int index); py::object GetAsObject(int index, py::object clazz); py::object GetVariant(int index); py::object GetAsSerializedTraceValue(int index); }; //------------------------------------------------------------------------------ // VmInstance //------------------------------------------------------------------------------ class VmInstance : public ApiRefCounted { public: static VmInstance Create(); }; //------------------------------------------------------------------------------ // VmModule //------------------------------------------------------------------------------ class VmModule : public ApiRefCounted { public: static VmModule ResolveModuleDependency(VmInstance* instance, const std::string& name, uint32_t minimum_version); static VmModule WrapBuffer(VmInstance* instance, py::object buffer_obj, py::object destroy_callback, bool close_buffer); static VmModule MMap(VmInstance* instance, std::string filepath, py::object destroy_callback); static VmModule CopyBuffer(VmInstance* instance, py::object buffer_obj); static VmModule FromBuffer(VmInstance* instance, py::object buffer_obj, bool warn_if_copy); std::optional LookupFunction( const std::string& name, iree_vm_function_linkage_t linkage); std::string name() const { auto name_sv = iree_vm_module_name(raw_ptr()); return std::string(name_sv.data, name_sv.size); } py::object get_stashed_flatbuffer_blob() { return stashed_flatbuffer_blob; } private: // If the module was created from a FlatBuffer blob, we stash it here. // Note that this buffer will typically be captured here at the Python // instance level and within the deallocator of the native vm module. // Since this child field is destroyed first (before the base class wrapped // vm module), we ensure that there are no dangling references when // that deallocator is called. py::object stashed_flatbuffer_blob = py::none(); }; //------------------------------------------------------------------------------ // VmContext //------------------------------------------------------------------------------ class VmContext : public ApiRefCounted { public: // Creates a context, optionally with modules, which will make the context // static, disallowing further module registration (and may be more // efficient). static VmContext Create(VmInstance* instance, std::optional>& modules); // Registers additional modules. Only valid for non static contexts (i.e. // those created without modules. void RegisterModules(std::vector modules); // Unique id for this context. int context_id() const { return iree_vm_context_id(raw_ptr()); } // Synchronously invokes the given function. void Invoke(iree_vm_function_t f, VmVariantList& inputs, VmVariantList& outputs); }; class VmInvocation : public ApiRefCounted { }; //------------------------------------------------------------------------------ // VmRef (represents a pointer to an arbitrary reference object). //------------------------------------------------------------------------------ class VmRef { public: //---------------------------------------------------------------------------- // Binds the reference protocol to a VmRefObject bound class. // This defines three attributes: // __iree_vm_type__() // Gets the type from the object. // [readonly property] __iree_vm_ref__ : // Gets a VmRef from the object. // __iree_vm_cast__(ref) : // Dereferences the VmRef to the concrete type. Returns None on cast // failure. // // In addition, a user attribute of "ref" will be added that is an alias of // __iree_vm_ref__. // // An __eq__ method is added which returns true if the python objects refer // to the same vm object. // // The BindRefProtocol() helper is used on a py::class_ defined for a // reference object. It takes some of the C helper functions that are defined // for each type and is generic. //---------------------------------------------------------------------------- static const char* const kTypeAttr; static const char* const kRefAttr; static const char* const kCastAttr; template static void BindRefProtocol(PyClass& cls, TypeFunctor type, RetainRefFunctor retain_ref, DerefFunctor deref, IsaFunctor isa) { using WrapperType = typename PyClass::Type; auto ref_lambda = [=](WrapperType& self) { return VmRef::Steal(retain_ref(self.raw_ptr())); }; cls.def_static(VmRef::kTypeAttr, [=]() { return type(); }); cls.def_prop_ro(VmRef::kRefAttr, ref_lambda); cls.def_prop_ro("ref", ref_lambda); cls.def_static(VmRef::kCastAttr, [=](VmRef& ref) -> py::object { if (!isa(ref.ref())) { return py::none(); } return py::cast(WrapperType::BorrowFromRawPtr(deref(ref.ref())), py::rv_policy::move); }); cls.def("__eq__", [](WrapperType& self, WrapperType& other) { return self.raw_ptr() == other.raw_ptr(); }); cls.def("__eq__", [](WrapperType& self, py::object& other) { return false; }); } // Initializes a null ref. VmRef() { std::memset(&ref_, 0, sizeof(ref_)); } VmRef(VmRef&& other) : ref_(other.ref_) { std::memset(&other.ref_, 0, sizeof(other.ref_)); } VmRef(const VmRef&) = delete; VmRef& operator=(const VmRef&) = delete; ~VmRef() { if (ref_.ptr) { iree_vm_ref_release(&ref_); } } // Creates a VmRef from an owned ref, taking the reference count. static VmRef Steal(iree_vm_ref_t ref) { return VmRef(ref); } iree_vm_ref_t& ref() { return ref_; } py::object Deref(py::object ref_object_class, bool optional); bool IsInstance(py::object ref_object_class); std::string ToString(); private: // Initializes with an owned ref. VmRef(iree_vm_ref_t ref) : ref_(ref) {} iree_vm_ref_t ref_; }; void SetupVmBindings(nanobind::module_ m); } // namespace python } // namespace iree #endif // IREE_BINDINGS_PYTHON_IREE_RT_VM_H_