// Copyright (c) 2022 Advanced Micro Devices, Inc. All rights reserved. // // This file is part of the AMD Render Pipeline Shaders SDK which is // released under the AMD INTERNAL EVALUATION LICENSE. // // See file LICENSE.RTF for full license details. #ifndef RPS_CALLBACK_WRAPPER_H #define RPS_CALLBACK_WRAPPER_H #include #include namespace rps { /// @brief Place holder for an unused argument of a callback function. /// /// It can be used to skip parameter marshalling during command node callbacks, /// while keeping the parameter ordinals match between the callback functions and node declarations. /// For example for node declaration: /// ``` /// node foo( rtv param0, srv param1 ); /// ``` /// If the callback function does not need to bind the render target param0 explicitly, it can be declared as: /// ``` /// void FooCallback( const RpsCmdCallbackContext* pContext, rps::UnusedArg unusedParam0, D3D12_CPU_DESCRIPTOR_HANDLE usedParam1 ); /// ``` /// So that the runtime will ignore unusedParam0, while still pass usedParam1 to the callback. /// /// @ingroup RpsRenderGraphCommandRecording struct UnusedArg { }; namespace details { template struct CommandArgUnwrapper { }; // Value types or const ref types template struct CommandArgUnwrapper::value>::type> { using ValueT = typename std::remove_cv::type>::type; const ValueT& operator()(const RpsCmdCallbackContext* pContext) { return *static_cast(pContext->ppArgs[Index]); } }; // Const pointer types template struct CommandArgUnwrapper< Index, T, typename std::enable_if::value && std::is_const::type>::value>::type> { T operator()(const RpsCmdCallbackContext* pContext) { return static_cast(pContext->ppArgs[Index]); } }; // Skipping unused args template struct CommandArgUnwrapper { rps::UnusedArg operator()(const RpsCmdCallbackContext* pContext) { return {}; } }; // Converting RpsBool to bool template struct CommandArgUnwrapper { bool operator()(const RpsCmdCallbackContext* pContext) { const RpsBool value = *static_cast(pContext->ppArgs[Index]); return !!value; } }; #if __cplusplus >= 201402L // Non-recursive argument unwrapping with index_sequence template struct FunctionWrapper { template static TRet Call(const RpsCmdCallbackContext* pContext, TRet (*pFn)(const RpsCmdCallbackContext*, TArgs...), std::index_sequence) { return pFn(pContext, CommandArgUnwrapper()(pContext)...); } template static TRet Call(const RpsCmdCallbackContext* pContext, TClass* pThis, TRet (TClass::*pFn)(const RpsCmdCallbackContext*, TArgs...), std::index_sequence) { return (pThis->*pFn)(pContext, CommandArgUnwrapper()(pContext)...); } template static TRet Call(const RpsCmdCallbackContext* pContext, std::function fn, std::index_sequence) { return fn(pContext, CommandArgUnwrapper()(pContext)...); } template static TRet Call(const RpsCmdCallbackContext* pContext, TRet (*pFn)(const RpsCmdCallbackContext&, TArgs...), std::index_sequence) { return pFn(*pContext, CommandArgUnwrapper()(pContext)...); } template static TRet Call(const RpsCmdCallbackContext* pContext, TClass* pThis, TRet (TClass::*pFn)(const RpsCmdCallbackContext&, TArgs...), std::index_sequence) { return (pThis->*pFn)(*pContext, CommandArgUnwrapper()(pContext)...); } template static TRet Call(const RpsCmdCallbackContext* pContext, std::function fn, std::index_sequence) { return fn(*pContext, CommandArgUnwrapper()(pContext)...); } }; template TRet WrappedFunction(const RpsCmdCallbackContext* pContext, TRet (*pFn)(const RpsCmdCallbackContext*, TArgs...)) { return FunctionWrapper::template Call<>(pContext, pFn, std::index_sequence_for{}); } template TRet WrappedFunction(const RpsCmdCallbackContext* pContext, TClass* pThis, TRet (TClass::*pFn)(const RpsCmdCallbackContext*, TArgs...)) { return FunctionWrapper::template Call( pContext, pThis, pFn, std::index_sequence_for{}); } template TRet WrappedFunction(const RpsCmdCallbackContext* pContext, std::function fn) { return FunctionWrapper::template Call<>(pContext, fn, std::index_sequence_for{}); } template TRet WrappedFunction(const RpsCmdCallbackContext* pContext, TRet (*pFn)(const RpsCmdCallbackContext&, TArgs...)) { return FunctionWrapper::template Call<>(pContext, pFn, std::index_sequence_for{}); } template TRet WrappedFunction(const RpsCmdCallbackContext* pContext, TClass* pThis, TRet (TClass::*pFn)(const RpsCmdCallbackContext&, TArgs...)) { return FunctionWrapper::template Call( pContext, pThis, pFn, std::index_sequence_for{}); } template TRet WrappedFunction(const RpsCmdCallbackContext* pContext, std::function fn) { return FunctionWrapper::template Call<>(pContext, fn, std::index_sequence_for{}); } #else //#if __cplusplus >= 201402L // Recursive argument unwrapping template struct FunctionWrapperRecursive { template static TRet Wrapped(const RpsCmdCallbackContext* pContext, TRet (*pFn)(const RpsCmdCallbackContext*, TArgs...), TUnwrappedArgs&&... unwrappedArgs) { using TupleType = std::tuple; return FunctionWrapperRecursive::Wrapped( pContext, pFn, CommandArgUnwrapper::type>()(pContext), unwrappedArgs...); } template static TRet Wrapped(const RpsCmdCallbackContext* pContext, TClass* pThis, TRet (TClass::*pFn)(const RpsCmdCallbackContext*, TArgs...), TUnwrappedArgs&&... unwrappedArgs) { using TupleType = std::tuple; return FunctionWrapperRecursive::Wrapped( pContext, pThis, pFn, CommandArgUnwrapper::type>()(pContext), unwrappedArgs...); } template static TRet Wrapped(const RpsCmdCallbackContext* pContext, std::function func, TUnwrappedArgs&&... unwrappedArgs) { using TupleType = std::tuple; return FunctionWrapperRecursive::Wrapped( pContext, func, CommandArgUnwrapper::type>()(pContext), unwrappedArgs...); } template static TRet Wrapped(const RpsCmdCallbackContext* pContext, TRet (*pFn)(const RpsCmdCallbackContext&, TArgs...), TUnwrappedArgs&&... unwrappedArgs) { using TupleType = std::tuple; return FunctionWrapperRecursive::Wrapped( pContext, pFn, CommandArgUnwrapper::type>()(pContext), unwrappedArgs...); } template static TRet Wrapped(const RpsCmdCallbackContext* pContext, TClass* pThis, TRet (TClass::*pFn)(const RpsCmdCallbackContext&, TArgs...), TUnwrappedArgs&&... unwrappedArgs) { using TupleType = std::tuple; return FunctionWrapperRecursive::Wrapped( pContext, pThis, pFn, CommandArgUnwrapper::type>()(pContext), unwrappedArgs...); } template static TRet Wrapped(const RpsCmdCallbackContext* pContext, std::function func, TUnwrappedArgs&&... unwrappedArgs) { using TupleType = std::tuple; return FunctionWrapperRecursive::Wrapped( pContext, func, CommandArgUnwrapper::type>()(pContext), unwrappedArgs...); } }; template <> struct FunctionWrapperRecursive<0> { template static TRet Wrapped(const RpsCmdCallbackContext* pContext, TRet (*pFn)(const RpsCmdCallbackContext*, TArgs...), TUnwrappedArgs&&... unwrappedArgs) { return pFn(pContext, unwrappedArgs...); } template static TRet Wrapped(const RpsCmdCallbackContext* pContext, TClass* pThis, TRet (TClass::*pFn)(const RpsCmdCallbackContext*, TArgs...), TUnwrappedArgs&&... unwrappedArgs) { return (pThis->*pFn)(pContext, unwrappedArgs...); } template static TRet Wrapped(const RpsCmdCallbackContext* pContext, std::function func, TUnwrappedArgs&&... unwrappedArgs) { return func(pContext, unwrappedArgs...); } template static TRet Wrapped(const RpsCmdCallbackContext* pContext, TRet (*pFn)(const RpsCmdCallbackContext&, TArgs...), TUnwrappedArgs&&... unwrappedArgs) { return pFn(*pContext, *unwrappedArgs...); } template static TRet Wrapped(const RpsCmdCallbackContext* pContext, TClass* pThis, TRet (TClass::*pFn)(const RpsCmdCallbackContext&, TArgs...), TUnwrappedArgs&&... unwrappedArgs) { return (pThis->*pFn)(*pContext, unwrappedArgs...); } template static TRet Wrapped(const RpsCmdCallbackContext* pContext, std::function func, TUnwrappedArgs&&... unwrappedArgs) { return func(*pContext, unwrappedArgs...); } }; template TRet WrappedFunction(const RpsCmdCallbackContext* pContext, TRet (*pFn)(const RpsCmdCallbackContext*, TArgs...)) { return FunctionWrapperRecursive::Wrapped(pContext, pFn); } template TRet WrappedFunction(const RpsCmdCallbackContext* pContext, TClass* pThis, TRet (TClass::*pFn)(const RpsCmdCallbackContext*, TArgs...)) { return FunctionWrapperRecursive::Wrapped(pContext, pThis, pFn); } template TRet WrappedFunction(const RpsCmdCallbackContext* pContext, std::function fn) { return FunctionWrapperRecursive::Wrapped(pContext, fn); } template TRet WrappedFunction(const RpsCmdCallbackContext* pContext, TRet (*pFn)(const RpsCmdCallbackContext&, TArgs...)) { return FunctionWrapperRecursive::Wrapped(pContext, pFn); } template TRet WrappedFunction(const RpsCmdCallbackContext* pContext, TClass* pThis, TRet (TClass::*pFn)(const RpsCmdCallbackContext&, TArgs...)) { return FunctionWrapperRecursive::Wrapped(pContext, pThis, pFn); } template TRet WrappedFunction(const RpsCmdCallbackContext* pContext, std::function fn) { return FunctionWrapperRecursive::Wrapped(pContext, fn); } #endif //#if __cplusplus >= 201402L template void WrappedFunction(const RpsCmdCallbackContext* pContext, TClass* pThis, std::nullptr_t n) { } template struct MemberNodeCallbackContext { TObject* target; TFunc method; MemberNodeCallbackContext(TObject* inTarget, TFunc inFunc) : target(inTarget) , method(inFunc) { } static void Callback(const RpsCmdCallbackContext* pContext) { auto pThis = static_cast*>(pContext->pCmdCallbackContext); details::WrappedFunction(pContext, pThis->target, pThis->method); } }; template struct NonMemberNodeCallbackContext { TFunc func; NonMemberNodeCallbackContext(TFunc inFunc) : func(inFunc) { } static void Callback(const RpsCmdCallbackContext* pContext) { auto pThis = static_cast*>(pContext->pCmdCallbackContext); details::WrappedFunction(pContext, pThis->func); } }; } // namespace details } // namespace rps #endif //RPS_CALLBACK_WRAPPER_H