/**
 * \file imperative/tablegen/targets/pybind11.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */

#include "./pybind11.h"
#include "../emitter.h"

#include <string>
#include <utility>
#include <vector>

namespace mlir::tblgen {
namespace {

/// Emits the pybind11 binding code for a single MgbOp: a `py::class_` object
/// named `<ClassName>Inst`, one `py::enum_` (or an alias to an already-bound
/// enum) per enum-typed attribute, and the class's constructors and
/// read/write attribute accessors.
class OpDefEmitter final : public EmitterBase {
public:
    OpDefEmitter(MgbOp& op_, raw_ostream& os_, Environment& env_)
            : EmitterBase(os_, env_), op(op_) {}

    /// Writes the complete binding for `op` to the output stream.
    void emit();

private:
    /// Emits a Python enum binding for every enum-typed attribute of `op`,
    /// reusing enums already recorded in `env().enumAlias`.
    void emitEnumBindings(llvm::StringRef className);

    /// Emits the `<ClassName>Inst` method chain: constructors and
    /// `def_readwrite` accessors for every attribute.
    void emitClassBinding(llvm::StringRef className);

    MgbOp& op;
};

void OpDefEmitter::emit() {
    auto className = op.getCppClassName();
    // The class object is declared first because the enum bindings below use
    // `<ClassName>Inst` as their scope.
    os << formatv(
            "py::class_<{0}, std::shared_ptr<{0}>, OpDef> {0}Inst(m, \"{0}\");\n\n",
            className);
    emitEnumBindings(className);
    emitClassBinding(className);
}

void OpDefEmitter::emitEnumBindings(llvm::StringRef className) {
    for (auto&& i : op.getMgbAttributes()) {
        auto attr = llvm::dyn_cast<MgbEnumAttrMixin>(&i.attr);
        if (!attr) {
            continue;  // only enum-typed attributes get a Python enum
        }

        // Key the enum by its base record ID. For an alias attribute, use the
        // record of the enum it aliases, so both map to the same entry.
        unsigned int enumID = 0;
        if (auto alias = llvm::dyn_cast<MgbAliasAttrMixin>(attr)) {
            auto&& aliasBase = alias->getAliasBase();
            enumID = llvm::cast<MgbEnumAttrMixin>(aliasBase).getBaseRecord()->getID();
        } else {
            enumID = attr->getBaseRecord()->getID();
        }

        auto&& enumAlias = env().enumAlias;
        auto iter = enumAlias.find(enumID);
        if (iter != enumAlias.end()) {
            // This enum was bound earlier: expose the existing Python type
            // under this class instead of defining a second one.
            os << formatv(
                    "{0}Inst.attr(\"{1}\") = {2}Inst.attr(\"{3}\");\n\n", className,
                    attr->getEnumName(), iter->second.first, iter->second.second);
            continue;
        }

        auto enumName = attr->getEnumName();
        os << formatv("py::enum_<{0}::{1}>({0}Inst, \"{1}\")", className, enumName);

        // One `if (str == "X") return ...;` branch per member, spliced into the
        // string constructor emitted below.
        std::vector<std::string> body;
        for (auto&& member : attr->getEnumMembers()) {
            // Keep only the identifier: cut at the first ' ' or '=', so e.g.
            // "NAME = 1" becomes "NAME". No separator keeps the whole member.
            auto memberName = member.substr(0, member.find_first_of(" ="));
            os << formatv(
                    "\n .value(\"{2}\", {0}::{1}::{2})", className, enumName,
                    memberName);
            body.push_back(formatv(
                                   "if (str == \"{2}\") return {0}::{1}::{2};",
                                   className, enumName, memberName)
                                   .str());
        }

        if (attr->getEnumCombinedFlag()) {
            //! define operator | (computed on the uint32_t values)
            os << formatv(
                    "\n .def(\"__or__\", []({0}::{1} s0, {0}::{1} s1) {{ "
                    "\n return static_cast<{0}::{1}>(uint32_t(s0) | "
                    "uint32_t(s1));"
                    "\n })",
                    className, enumName);
            //! define operator & (computed on the uint32_t values)
            os << formatv(
                    "\n .def(\"__and__\", []({0}::{1} s0, {0}::{1} s1) {{"
                    "\n return static_cast<{0}::{1}>(uint32_t(s0) & "
                    "uint32_t(s1));"
                    "\n })",
                    className, enumName);
        }

        // Constructor from a string: the input goes through normalize_enum(),
        // is matched against each member, and otherwise raises py::cast_error.
        // The opening brace is escaped as `{{` so formatv emits a literal '{'
        // instead of depending on how it parses an unterminated brace.
        os << formatv(
                "\n .def(py::init([](const std::string& in) {{"
                "\n auto&& str = normalize_enum(in);"
                "\n {0}"
                "\n throw py::cast_error(\"invalid enum value \" + in);"
                "\n }));\n",
                llvm::join(body, "\n "));
        // Allow a Python string wherever this enum is expected, routed through
        // the string constructor above. The template arguments are required:
        // a bare `py::implicitly_convertible();` is not valid pybind11.
        os << formatv(
                "py::implicitly_convertible<std::string, {0}::{1}>();\n\n",
                className, enumName);
        enumAlias.emplace(enumID, std::make_pair(className, enumName));
    }
}

void OpDefEmitter::emitClassBinding(llvm::StringRef className) {
    auto&& attrs = op.getMgbAttributes();
    os << formatv("{0}Inst", className);

    // A no-argument constructor is also emitted when the op has no attributes,
    // or when at least one attribute has no default value (the full
    // constructor then cannot be called without arguments).
    bool hasDefaultCtor = attrs.empty();
    if (!hasDefaultCtor) {
        // Full constructor: one parameter per attribute (in declaration order)
        // followed by a trailing `scope` string.
        os << "\n .def(py::init<";
        llvm::StringRef separator = "";
        for (auto&& i : attrs) {
            os << separator << i.attr.getReturnType();
            separator = ", ";
        }
        os << ", std::string>()";
        for (auto&& i : attrs) {
            os << formatv(", py::arg(\"{0}\")", i.name);
            auto defaultValue = i.attr.getDefaultValue();
            if (!defaultValue.empty()) {
                os << formatv(" = {0}", defaultValue);
            } else {
                hasDefaultCtor = true;
            }
        }
        os << ", py::arg(\"scope\") = {})";
    }
    if (hasDefaultCtor) {
        os << "\n .def(py::init<>())";
    }
    for (auto&& i : attrs) {
        os << formatv("\n .def_readwrite(\"{0}\", &{1}::{0})", i.name, className);
    }
    os << ";\n\n";
}

}  // namespace

/// Emits pybind11 bindings for every operator in `keeper` to `os`.
/// A single Environment is shared by all the per-op emitters, so an enum bound
/// while emitting one op is reused (via `enumAlias`) by later ops.
///
/// Always returns false; presumably "no error" under the TableGen backend
/// convention (NOTE(review): confirm against the caller).
bool gen_op_def_pybind11(raw_ostream& os, llvm::RecordKeeper& keeper) {
    Environment env;
    foreach_operator(keeper, [&](MgbOp& op) { OpDefEmitter(op, os, env).emit(); });
    return false;
}

}  // namespace mlir::tblgen