// Copyright (c) 2021 Google LLC. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "source/opt/dataflow.h" #include #include #include "gmock/gmock.h" #include "gtest/gtest.h" #include "opt/function_utils.h" #include "source/opt/build_module.h" namespace spvtools { namespace opt { namespace { using DataFlowTest = ::testing::Test; // Simple analyses for testing: // Stores the result IDs of visited instructions in visit order. struct VisitOrder : public ForwardDataFlowAnalysis { std::vector visited_result_ids; VisitOrder(IRContext& context, LabelPosition label_position) : ForwardDataFlowAnalysis(context, label_position) {} VisitResult Visit(Instruction* inst) override { if (inst->HasResultId()) { visited_result_ids.push_back(inst->result_id()); } return DataFlowAnalysis::VisitResult::kResultFixed; } }; // For each block, stores the set of blocks it can be preceded by. // For example, with the following CFG: // V-----------. // -> 11 -> 12 -> 13 -> 15 // \-> 14 ---^ // // The answer is: // 11: 11, 12, 13 // 12: 11, 12, 13 // 13: 11, 12, 13 // 14: 11, 12, 13 // 15: 11, 12, 13, 14 struct BackwardReachability : public ForwardDataFlowAnalysis { std::map> reachable_from; BackwardReachability(IRContext& context) : ForwardDataFlowAnalysis( context, ForwardDataFlowAnalysis::LabelPosition::kLabelsOnly) {} VisitResult Visit(Instruction* inst) override { // Conditional branches can be enqueued from labels, so skip them. if (inst->opcode() != spv::Op::OpLabel) return DataFlowAnalysis::VisitResult::kResultFixed; uint32_t id = inst->result_id(); VisitResult ret = DataFlowAnalysis::VisitResult::kResultFixed; std::set& precedents = reachable_from[id]; for (uint32_t pred : context().cfg()->preds(id)) { bool pred_inserted = precedents.insert(pred).second; if (pred_inserted) { ret = DataFlowAnalysis::VisitResult::kResultChanged; } for (uint32_t block : reachable_from[pred]) { bool inserted = precedents.insert(block).second; if (inserted) { ret = DataFlowAnalysis::VisitResult::kResultChanged; } } } return ret; } void InitializeWorklist(Function* function, bool is_first_iteration) override { // Since successor function is exact, only need one pass. if (is_first_iteration) { ForwardDataFlowAnalysis::InitializeWorklist(function, true); } } }; TEST_F(DataFlowTest, ReversePostOrder) { // Note: labels and IDs are intentionally out of order. // // CFG: (order of branches is from bottom to top) // V-----------. // -> 50 -> 40 -> 20 -> 60 -> 70 // \-> 30 ---^ // DFS tree with RPO numbering: // -> 50[0] -> 40[1] -> 20[2] 60[4] -> 70[5] // \-> 30[3] ---^ const std::string text = R"( OpCapability Shader %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 OpEntryPoint Fragment %2 "main" OpExecutionMode %2 OriginUpperLeft OpSource GLSL 430 %3 = OpTypeVoid %4 = OpTypeFunction %3 %6 = OpTypeBool %5 = OpConstantTrue %6 %2 = OpFunction %3 None %4 %50 = OpLabel %51 = OpUndef %6 %52 = OpUndef %6 OpBranch %40 %70 = OpLabel %69 = OpUndef %6 OpReturn %60 = OpLabel %61 = OpUndef %6 OpBranchConditional %5 %70 %40 %30 = OpLabel %29 = OpUndef %6 OpBranch %60 %20 = OpLabel %21 = OpUndef %6 OpBranch %60 %40 = OpLabel %39 = OpUndef %6 OpBranchConditional %5 %30 %20 OpFunctionEnd )"; std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); ASSERT_NE(context, nullptr); Function* function = spvtest::GetFunction(context->module(), 2); std::map> expected_order; expected_order[ForwardDataFlowAnalysis::LabelPosition::kLabelsOnly] = { 50, 40, 20, 30, 60, 70, }; expected_order[ForwardDataFlowAnalysis::LabelPosition::kLabelsAtBeginning] = { 50, 51, 52, 40, 39, 20, 21, 30, 29, 60, 61, 70, 69, }; expected_order[ForwardDataFlowAnalysis::LabelPosition::kLabelsAtEnd] = { 51, 52, 50, 39, 40, 21, 20, 29, 30, 61, 60, 69, 70, }; expected_order[ForwardDataFlowAnalysis::LabelPosition::kNoLabels] = { 51, 52, 39, 21, 29, 61, 69, }; for (const auto& test_case : expected_order) { VisitOrder analysis(*context, test_case.first); analysis.Run(function); EXPECT_EQ(test_case.second, analysis.visited_result_ids); } } TEST_F(DataFlowTest, BackwardReachability) { // CFG: // V-----------. // -> 11 -> 12 -> 13 -> 15 // \-> 14 ---^ const std::string text = R"( OpCapability Shader %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 OpEntryPoint Fragment %2 "main" OpExecutionMode %2 OriginUpperLeft OpSource GLSL 430 %3 = OpTypeVoid %4 = OpTypeFunction %3 %6 = OpTypeBool %5 = OpConstantTrue %6 %2 = OpFunction %3 None %4 %11 = OpLabel OpBranch %12 %12 = OpLabel OpBranchConditional %5 %14 %13 %13 = OpLabel OpBranchConditional %5 %15 %11 %14 = OpLabel OpBranch %15 %15 = OpLabel OpReturn OpFunctionEnd )"; std::unique_ptr context = BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text, SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS); ASSERT_NE(context, nullptr); Function* function = spvtest::GetFunction(context->module(), 2); BackwardReachability analysis(*context); analysis.Run(function); std::map> expected_result; expected_result[11] = {11, 12, 13}; expected_result[12] = {11, 12, 13}; expected_result[13] = {11, 12, 13}; expected_result[14] = {11, 12, 13}; expected_result[15] = {11, 12, 13, 14}; EXPECT_EQ(expected_result, analysis.reachable_from); } } // namespace } // namespace opt } // namespace spvtools