/** * \file imperative/tablegen/helper.h * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ #pragma once #include #include #include #include "llvm/Support/CommandLine.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/InitLLVM.h" #include "llvm/Support/Signals.h" #include "llvm/TableGen/Main.h" #include "llvm/TableGen/Record.h" #include "llvm/TableGen/TableGenBackend.h" #include "mlir/TableGen/Attribute.h" #include "mlir/TableGen/Format.h" #include "mlir/TableGen/Operator.h" using llvm::formatv; using llvm::Record; using llvm::StringRef; #define ASSERT(stmt, msg) \ if (!(stmt)) { \ std::cerr << "\033[1;31m" \ << "tablegen autogen abort due to: " << msg << "\033[0m" \ << std::endl; \ exit(1); \ } namespace mlir { namespace tblgen { template struct MgbInterface : public ConcreteType { MgbInterface() = delete; MgbInterface(const MgbInterface&) = delete; MgbInterface(MgbInterface&&) = delete; ~MgbInterface() = delete; }; struct MgbAttrWrapperBase : public MgbInterface { private: struct RecordVisitor : public MgbInterface { public: static bool classof(const Constraint*) { return true; } const llvm::Record* getDef() const { return def; } }; public: static bool classof(const Attribute* attr) { return attr->isSubClassOf("MgbAttrWrapperBase"); } const llvm::Record* getBaseRecord() const { auto baseAttr = getBaseAttr(); return llvm::cast(baseAttr).getDef(); } llvm::StringRef getUnderlyingType() const { return def->getValueAsString("underlyingType"); } }; struct MgbEnumAttrMixin : public MgbAttrWrapperBase { static bool classof(const Attribute* attr) { return attr->getBaseAttr().isSubClassOf("MgbEnumAttrMixin"); } llvm::StringRef getParentNamespace() const { return getBaseRecord()->getValueAsString("parentNamespace"); } llvm::StringRef getEnumName() const { return getBaseRecord()->getValueAsString("enumName"); } std::vector getEnumMembers() const { return getBaseRecord()->getValueAsListOfStrings("enumMembers"); } bool supportToString() const { return getBaseRecord()->getValueAsBit("supportToString"); } bool getEnumCombinedFlag() const { return getBaseRecord()->getValueAsBit("enumCombined"); } }; struct MgbHashableAttrMixin : public MgbAttrWrapperBase { static bool classof(const Attribute* attr) { return attr->getBaseAttr().isSubClassOf("MgbHashableAttrMixin"); } llvm::StringRef getHashFunctionTemplate() const { return getBaseRecord()->getValueAsString("hashFunction"); } llvm::StringRef getCmpFunctionTemplate() const { return getBaseRecord()->getValueAsString("cmpFunction"); } llvm::StringRef getReprFunctionTemplate() const { return getBaseRecord()->getValueAsString("reprFunction"); } }; struct MgbAliasAttrMixin : public MgbAttrWrapperBase { static bool classof(const Attribute* attr) { return attr->getBaseAttr().isSubClassOf("MgbAliasAttrMixin"); } Attribute getAliasBase() const { return Attribute(getBaseRecord()->getValueAsDef("aliasBase")); } }; class MgbPackedParam { public: MgbPackedParam(Record* def_) : def(def_) { auto&& dag = def->getValueAsDag("fields"); for (size_t i = 0; i < dag->getNumArgs(); ++i) { fields.push_back( {dag->getArgNameStr(i), Attribute(llvm::cast(dag->getArg(i)))}); } } llvm::StringRef getFullName() const { return def->getValueAsString("fullName"); } std::vector getFields() const { return fields; } llvm::StringRef getAccessor() const { return def->getValueAsString("paramAccessor"); } private: std::vector fields; Record* def; }; struct MgbOpBase : public MgbInterface { static bool isPackedParam(Record* def) { return def->isSubClassOf("MgbPackedParamBase"); } public: static bool classof(const Operator* op) { return op->getDef().isSubClassOf("MgbOp"); } std::vector getMgbAttributes() const { std::vector ret; for (auto&& i : getAttributes()) { if (isa(i.attr)) { ret.push_back(i); } } return ret; } std::vector getExtraArguments() const { std::vector ret; auto&& dag = getDef().getValueAsDag("extraArguments"); for (size_t i = 0; i < dag->getNumArgs(); ++i) { ret.push_back( {dag->getArgNameStr(i), Attribute(llvm::cast(dag->getArg(i)))}); } return ret; } llvm::Optional getExtraOpdefDecl() const { return getDef().getValueAsOptionalString("extraOpdefDecl"); } std::vector getPackedParams() const { std::vector ret; for (auto&& i : getDef().getValueAsListOfDefs("dnnParams")) { if (isPackedParam(i)) { ret.emplace_back(i); } } return ret; } std::string getNameFunctionTemplate() const { if (auto f = getDef().getValueAsOptionalString("nameFunction")) { return f.getValue().str(); } return formatv(" return \"{0}\";\n", getCppClassName()); } }; struct MgbHashableOpMixin : public MgbOpBase { private: std::string getDefaultHashFunction() const { std::string body = " size_t val = mgb::hash($_self.dyn_typeinfo());\n"; if (!getMgbAttributes().empty()) { auto getHashFunc = [&](auto&& iter) { auto&& attr = llvm::cast(iter.attr); return attr.getHashFunctionTemplate(); }; mlir::tblgen::FmtContext ctx; for (auto&& it : getMgbAttributes()) { body += formatv(" val = mgb::hash_pair_combine(val, {0});\n", mlir::tblgen::tgfmt( getHashFunc(it), &ctx, "$_self." + it.name)); } } body += " return val;\n"; return body; } std::string getDefaultCmpFunction() const { std::string body; if (!getMgbAttributes().empty()) { mlir::tblgen::FmtContext ctx; for (auto&& it : getMgbAttributes()) { auto&& attr = llvm::cast(it.attr); body += formatv(" if ({0}) return false;\n", mlir::tblgen::tgfmt( attr.getCmpFunctionTemplate(), &ctx, "$0." + it.name, "$1." + it.name)); } } body += " return true;\n"; return body; } std::string getDefaultPropsFunction() const { std::string body = " std::vector> props_;\n"; if (!getMgbAttributes().empty()) { mlir::tblgen::FmtContext ctx; for (auto&& it : getMgbAttributes()) { if (auto* enumAttr = llvm::dyn_cast(&it.attr)) { body += formatv(" switch ({0}){{\n", "$_self." + it.name); for (auto&& enumMember : enumAttr->getEnumMembers()) { size_t d1 = enumMember.find(' '); size_t d2 = enumMember.find('='); size_t d = d1 <= d2 ? d1 : d2; body += formatv( " case {0}::{1}::{2}:\n", getCppClassName(), enumAttr->getEnumName(), enumMember.substr(0, d)); body += formatv(" props_.emplace_back(\"{0}\", " "\"{1}\");\n", it.name, enumMember.substr(0, d)); body += " break;\n"; } body += " default: break;\n"; body += " }\n"; } else { auto&& attr = llvm::cast(it.attr); body += formatv(" props_.emplace_back(\"{0}\", {1});\n", it.name, mlir::tblgen::tgfmt( attr.getReprFunctionTemplate(), &ctx, "$_self." + it.name)); } } } body += " return props_;\n"; return body; } public: static bool classof(const Operator* op) { return op->getDef().isSubClassOf("MgbHashableOpMixin"); } std::string getHashFunctionTemplate() const { if (auto f = getDef().getValueAsOptionalString("hashFunction")) { return f.getValue().str(); } return getDefaultHashFunction(); } std::string getCmpFunctionTemplate() const { if (auto f = getDef().getValueAsOptionalString("cmpFunction")) { return f.getValue().str(); } return getDefaultCmpFunction(); } std::string getPropsFunctionTemplate() const { if (auto f = getDef().getValueAsOptionalString("propsFunction")) { return f.getValue().str(); } return getDefaultPropsFunction(); } }; using MgbAttrWrapper = mlir::tblgen::MgbAttrWrapperBase; using MgbEnumAttr = mlir::tblgen::MgbEnumAttrMixin; using MgbHashableAttr = mlir::tblgen::MgbHashableAttrMixin; using MgbAliasAttr = mlir::tblgen::MgbAliasAttrMixin; using MgbOp = mlir::tblgen::MgbOpBase; using MgbHashableOp = mlir::tblgen::MgbHashableOpMixin; static inline void foreach_operator( llvm::RecordKeeper& keeper, std::function callback) { auto op_base_class = keeper.getClass("Op"); ASSERT(op_base_class, "could not find base class Op"); for (auto&& i : keeper.getDefs()) { auto&& r = i.second; if (r->isSubClassOf(op_base_class)) { auto op = mlir::tblgen::Operator(r.get()); if (op.getDialectName().str() == "mgb") { std::cerr << "\033[34;15m" << "Generating " << r->getName().str() << "\033[0m" << std::endl; callback(llvm::cast(op)); } } } } } // namespace tblgen } // namespace mlir