// Copyright (c) 2016 Google Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "source/opt/type_manager.h" #include #include #include #include #include "source/opt/ir_context.h" #include "source/opt/log.h" #include "source/opt/reflect.h" #include "source/util/make_unique.h" #include "source/util/string_utils.h" namespace spvtools { namespace opt { namespace analysis { namespace { constexpr int kSpvTypePointerStorageClass = 1; constexpr int kSpvTypePointerTypeIdInIdx = 2; } // namespace TypeManager::TypeManager(const MessageConsumer& consumer, IRContext* c) : consumer_(consumer), context_(c) { AnalyzeTypes(*c->module()); } Type* TypeManager::GetType(uint32_t id) const { auto iter = id_to_type_.find(id); if (iter != id_to_type_.end()) return (*iter).second; iter = id_to_incomplete_type_.find(id); if (iter != id_to_incomplete_type_.end()) return (*iter).second; return nullptr; } std::pair> TypeManager::GetTypeAndPointerType( uint32_t id, spv::StorageClass sc) const { Type* type = GetType(id); if (type) { return std::make_pair(type, MakeUnique(type, sc)); } else { return std::make_pair(type, std::unique_ptr()); } } uint32_t TypeManager::GetId(const Type* type) const { auto iter = type_to_id_.find(type); if (iter != type_to_id_.end()) { return (*iter).second; } return 0; } void TypeManager::AnalyzeTypes(const Module& module) { // First pass through the constants, as some will be needed when traversing // the types in the next pass. for (const auto* inst : module.GetConstants()) { id_to_constant_inst_[inst->result_id()] = inst; } // Then pass through the types. Any types that reference a forward pointer // (directly or indirectly) are incomplete, and are added to incomplete types. for (const auto* inst : module.GetTypes()) { RecordIfTypeDefinition(*inst); } if (incomplete_types_.empty()) { return; } // Get the real pointer definition for all of the forward pointers. for (auto& type : incomplete_types_) { if (type.type()->kind() == Type::kForwardPointer) { auto* t = GetType(type.id()); assert(t); auto* p = t->AsPointer(); assert(p); type.type()->AsForwardPointer()->SetTargetPointer(p); } } // Replaces the references to the forward pointers in the incomplete types. for (auto& type : incomplete_types_) { ReplaceForwardPointers(type.type()); } // Delete the forward pointers now that they are not referenced anymore. for (auto& type : incomplete_types_) { if (type.type()->kind() == Type::kForwardPointer) { type.ResetType(nullptr); } } // Compare the complete types looking for types that are the same. If there // are two types that are the same, then replace one with the other. // Continue until we reach a fixed point. bool restart = true; while (restart) { restart = false; for (auto it1 = incomplete_types_.begin(); it1 != incomplete_types_.end(); ++it1) { uint32_t id1 = it1->id(); Type* type1 = it1->type(); if (!type1) { continue; } for (auto it2 = it1 + 1; it2 != incomplete_types_.end(); ++it2) { uint32_t id2 = it2->id(); (void)(id2 + id1); Type* type2 = it2->type(); if (!type2) { continue; } if (type1->IsSame(type2)) { ReplaceType(type1, type2); it2->ResetType(nullptr); id_to_incomplete_type_[it2->id()] = type1; restart = true; } } } } // Add the remaining incomplete types to the type pool. for (auto& type : incomplete_types_) { if (type.type() && !type.type()->AsForwardPointer()) { std::vector decorations = context()->get_decoration_mgr()->GetDecorationsFor(type.id(), true); for (auto dec : decorations) { AttachDecoration(*dec, type.type()); } auto pair = type_pool_.insert(type.ReleaseType()); id_to_type_[type.id()] = pair.first->get(); type_to_id_[pair.first->get()] = type.id(); id_to_incomplete_type_.erase(type.id()); } } // Add a mapping for any ids that whose original type was replaced by an // equivalent type. for (auto& type : id_to_incomplete_type_) { id_to_type_[type.first] = type.second; } #ifndef NDEBUG // Check if the type pool contains two types that are the same. This // is an indication that the hashing and comparison are wrong. It // will cause a problem if the type pool gets resized and everything // is rehashed. for (auto& i : type_pool_) { for (auto& j : type_pool_) { Type* ti = i.get(); Type* tj = j.get(); assert((ti == tj || !ti->IsSame(tj)) && "Type pool contains two types that are the same."); } } #endif } void TypeManager::RemoveId(uint32_t id) { auto iter = id_to_type_.find(id); if (iter == id_to_type_.end()) return; auto& type = iter->second; if (!type->IsUniqueType(true)) { auto tIter = type_to_id_.find(type); if (tIter != type_to_id_.end() && tIter->second == id) { // |type| currently maps to |id|. // Search for an equivalent type to re-map. bool found = false; for (auto& pair : id_to_type_) { if (pair.first != id && *pair.second == *type) { // Equivalent ambiguous type, re-map type. type_to_id_.erase(type); type_to_id_[pair.second] = pair.first; found = true; break; } } // No equivalent ambiguous type, remove mapping. if (!found) type_to_id_.erase(tIter); } } else { // Unique type, so just erase the entry. type_to_id_.erase(type); } // Erase the entry for |id|. id_to_type_.erase(iter); } uint32_t TypeManager::GetTypeInstruction(const Type* type) { uint32_t id = GetId(type); if (id != 0) return id; std::unique_ptr typeInst; // TODO(1841): Handle id overflow. id = context()->TakeNextId(); if (id == 0) { return 0; } RegisterType(id, *type); switch (type->kind()) { #define DefineParameterlessCase(kind) \ case Type::k##kind: \ typeInst = MakeUnique(context(), spv::Op::OpType##kind, 0, \ id, std::initializer_list{}); \ break DefineParameterlessCase(Void); DefineParameterlessCase(Bool); DefineParameterlessCase(Sampler); DefineParameterlessCase(Event); DefineParameterlessCase(DeviceEvent); DefineParameterlessCase(ReserveId); DefineParameterlessCase(Queue); DefineParameterlessCase(PipeStorage); DefineParameterlessCase(NamedBarrier); DefineParameterlessCase(AccelerationStructureNV); DefineParameterlessCase(RayQueryKHR); DefineParameterlessCase(HitObjectNV); #undef DefineParameterlessCase case Type::kInteger: typeInst = MakeUnique( context(), spv::Op::OpTypeInt, 0, id, std::initializer_list{ {SPV_OPERAND_TYPE_LITERAL_INTEGER, {type->AsInteger()->width()}}, {SPV_OPERAND_TYPE_LITERAL_INTEGER, {(type->AsInteger()->IsSigned() ? 1u : 0u)}}}); break; case Type::kFloat: typeInst = MakeUnique( context(), spv::Op::OpTypeFloat, 0, id, std::initializer_list{ {SPV_OPERAND_TYPE_LITERAL_INTEGER, {type->AsFloat()->width()}}}); break; case Type::kVector: { uint32_t subtype = GetTypeInstruction(type->AsVector()->element_type()); if (subtype == 0) { return 0; } typeInst = MakeUnique(context(), spv::Op::OpTypeVector, 0, id, std::initializer_list{ {SPV_OPERAND_TYPE_ID, {subtype}}, {SPV_OPERAND_TYPE_LITERAL_INTEGER, {type->AsVector()->element_count()}}}); break; } case Type::kMatrix: { uint32_t subtype = GetTypeInstruction(type->AsMatrix()->element_type()); if (subtype == 0) { return 0; } typeInst = MakeUnique(context(), spv::Op::OpTypeMatrix, 0, id, std::initializer_list{ {SPV_OPERAND_TYPE_ID, {subtype}}, {SPV_OPERAND_TYPE_LITERAL_INTEGER, {type->AsMatrix()->element_count()}}}); break; } case Type::kImage: { const Image* image = type->AsImage(); uint32_t subtype = GetTypeInstruction(image->sampled_type()); if (subtype == 0) { return 0; } typeInst = MakeUnique( context(), spv::Op::OpTypeImage, 0, id, std::initializer_list{ {SPV_OPERAND_TYPE_ID, {subtype}}, {SPV_OPERAND_TYPE_DIMENSIONALITY, {static_cast(image->dim())}}, {SPV_OPERAND_TYPE_LITERAL_INTEGER, {image->depth()}}, {SPV_OPERAND_TYPE_LITERAL_INTEGER, {(image->is_arrayed() ? 1u : 0u)}}, {SPV_OPERAND_TYPE_LITERAL_INTEGER, {(image->is_multisampled() ? 1u : 0u)}}, {SPV_OPERAND_TYPE_LITERAL_INTEGER, {image->sampled()}}, {SPV_OPERAND_TYPE_SAMPLER_IMAGE_FORMAT, {static_cast(image->format())}}, {SPV_OPERAND_TYPE_ACCESS_QUALIFIER, {static_cast(image->access_qualifier())}}}); break; } case Type::kSampledImage: { uint32_t subtype = GetTypeInstruction(type->AsSampledImage()->image_type()); if (subtype == 0) { return 0; } typeInst = MakeUnique( context(), spv::Op::OpTypeSampledImage, 0, id, std::initializer_list{{SPV_OPERAND_TYPE_ID, {subtype}}}); break; } case Type::kArray: { uint32_t subtype = GetTypeInstruction(type->AsArray()->element_type()); if (subtype == 0) { return 0; } typeInst = MakeUnique( context(), spv::Op::OpTypeArray, 0, id, std::initializer_list{ {SPV_OPERAND_TYPE_ID, {subtype}}, {SPV_OPERAND_TYPE_ID, {type->AsArray()->LengthId()}}}); break; } case Type::kRuntimeArray: { uint32_t subtype = GetTypeInstruction(type->AsRuntimeArray()->element_type()); if (subtype == 0) { return 0; } typeInst = MakeUnique( context(), spv::Op::OpTypeRuntimeArray, 0, id, std::initializer_list{{SPV_OPERAND_TYPE_ID, {subtype}}}); break; } case Type::kStruct: { std::vector ops; const Struct* structTy = type->AsStruct(); for (auto ty : structTy->element_types()) { uint32_t member_type_id = GetTypeInstruction(ty); if (member_type_id == 0) { return 0; } ops.push_back(Operand(SPV_OPERAND_TYPE_ID, {member_type_id})); } typeInst = MakeUnique(context(), spv::Op::OpTypeStruct, 0, id, ops); break; } case Type::kOpaque: { const Opaque* opaque = type->AsOpaque(); // Convert to null-terminated packed UTF-8 string. std::vector words = spvtools::utils::MakeVector(opaque->name()); typeInst = MakeUnique( context(), spv::Op::OpTypeOpaque, 0, id, std::initializer_list{ {SPV_OPERAND_TYPE_LITERAL_STRING, words}}); break; } case Type::kPointer: { const Pointer* pointer = type->AsPointer(); uint32_t subtype = GetTypeInstruction(pointer->pointee_type()); if (subtype == 0) { return 0; } typeInst = MakeUnique( context(), spv::Op::OpTypePointer, 0, id, std::initializer_list{ {SPV_OPERAND_TYPE_STORAGE_CLASS, {static_cast(pointer->storage_class())}}, {SPV_OPERAND_TYPE_ID, {subtype}}}); break; } case Type::kFunction: { std::vector ops; const Function* function = type->AsFunction(); uint32_t return_type_id = GetTypeInstruction(function->return_type()); if (return_type_id == 0) { return 0; } ops.push_back(Operand(SPV_OPERAND_TYPE_ID, {return_type_id})); for (auto ty : function->param_types()) { uint32_t paramater_type_id = GetTypeInstruction(ty); if (paramater_type_id == 0) { return 0; } ops.push_back(Operand(SPV_OPERAND_TYPE_ID, {paramater_type_id})); } typeInst = MakeUnique(context(), spv::Op::OpTypeFunction, 0, id, ops); break; } case Type::kPipe: typeInst = MakeUnique( context(), spv::Op::OpTypePipe, 0, id, std::initializer_list{ {SPV_OPERAND_TYPE_ACCESS_QUALIFIER, {static_cast(type->AsPipe()->access_qualifier())}}}); break; case Type::kForwardPointer: typeInst = MakeUnique( context(), spv::Op::OpTypeForwardPointer, 0, 0, std::initializer_list{ {SPV_OPERAND_TYPE_ID, {type->AsForwardPointer()->target_id()}}, {SPV_OPERAND_TYPE_STORAGE_CLASS, {static_cast( type->AsForwardPointer()->storage_class())}}}); break; case Type::kCooperativeMatrixNV: { auto coop_mat = type->AsCooperativeMatrixNV(); uint32_t const component_type = GetTypeInstruction(coop_mat->component_type()); if (component_type == 0) { return 0; } typeInst = MakeUnique( context(), spv::Op::OpTypeCooperativeMatrixNV, 0, id, std::initializer_list{ {SPV_OPERAND_TYPE_ID, {component_type}}, {SPV_OPERAND_TYPE_SCOPE_ID, {coop_mat->scope_id()}}, {SPV_OPERAND_TYPE_ID, {coop_mat->rows_id()}}, {SPV_OPERAND_TYPE_ID, {coop_mat->columns_id()}}}); break; } default: assert(false && "Unexpected type"); break; } context()->AddType(std::move(typeInst)); context()->AnalyzeDefUse(&*--context()->types_values_end()); AttachDecorations(id, type); return id; } uint32_t TypeManager::FindPointerToType(uint32_t type_id, spv::StorageClass storage_class) { Type* pointeeTy = GetType(type_id); Pointer pointerTy(pointeeTy, storage_class); if (pointeeTy->IsUniqueType(true)) { // Non-ambiguous type. Get the pointer type through the type manager. return GetTypeInstruction(&pointerTy); } // Ambiguous type, do a linear search. Module::inst_iterator type_itr = context()->module()->types_values_begin(); for (; type_itr != context()->module()->types_values_end(); ++type_itr) { const Instruction* type_inst = &*type_itr; if (type_inst->opcode() == spv::Op::OpTypePointer && type_inst->GetSingleWordOperand(kSpvTypePointerTypeIdInIdx) == type_id && spv::StorageClass(type_inst->GetSingleWordOperand( kSpvTypePointerStorageClass)) == storage_class) return type_inst->result_id(); } // Must create the pointer type. // TODO(1841): Handle id overflow. uint32_t resultId = context()->TakeNextId(); std::unique_ptr type_inst( new Instruction(context(), spv::Op::OpTypePointer, 0, resultId, {{spv_operand_type_t::SPV_OPERAND_TYPE_STORAGE_CLASS, {uint32_t(storage_class)}}, {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {type_id}}})); context()->AddType(std::move(type_inst)); context()->get_type_mgr()->RegisterType(resultId, pointerTy); return resultId; } void TypeManager::AttachDecorations(uint32_t id, const Type* type) { for (auto vec : type->decorations()) { CreateDecoration(id, vec); } if (const Struct* structTy = type->AsStruct()) { for (auto pair : structTy->element_decorations()) { uint32_t element = pair.first; for (auto vec : pair.second) { CreateDecoration(id, vec, /* is_member */ true, element); } } } } void TypeManager::CreateDecoration(uint32_t target, const std::vector& decoration, bool is_member, uint32_t element) { std::vector ops; ops.push_back(Operand(SPV_OPERAND_TYPE_ID, {target})); if (is_member) { ops.push_back(Operand(SPV_OPERAND_TYPE_LITERAL_INTEGER, {element})); } ops.push_back(Operand(SPV_OPERAND_TYPE_DECORATION, {decoration[0]})); for (size_t i = 1; i < decoration.size(); ++i) { ops.push_back(Operand(SPV_OPERAND_TYPE_LITERAL_INTEGER, {decoration[i]})); } context()->AddAnnotationInst(MakeUnique( context(), (is_member ? spv::Op::OpMemberDecorate : spv::Op::OpDecorate), 0, 0, ops)); Instruction* inst = &*--context()->annotation_end(); context()->get_def_use_mgr()->AnalyzeInstUse(inst); } Type* TypeManager::RebuildType(const Type& type) { // The comparison and hash on the type pool will avoid inserting the rebuilt // type if an equivalent type already exists. The rebuilt type will be deleted // when it goes out of scope at the end of the function in that case. Repeated // insertions of the same Type will, at most, keep one corresponding object in // the type pool. std::unique_ptr rebuilt_ty; switch (type.kind()) { #define DefineNoSubtypeCase(kind) \ case Type::k##kind: \ rebuilt_ty.reset(type.Clone().release()); \ return type_pool_.insert(std::move(rebuilt_ty)).first->get() DefineNoSubtypeCase(Void); DefineNoSubtypeCase(Bool); DefineNoSubtypeCase(Integer); DefineNoSubtypeCase(Float); DefineNoSubtypeCase(Sampler); DefineNoSubtypeCase(Opaque); DefineNoSubtypeCase(Event); DefineNoSubtypeCase(DeviceEvent); DefineNoSubtypeCase(ReserveId); DefineNoSubtypeCase(Queue); DefineNoSubtypeCase(Pipe); DefineNoSubtypeCase(PipeStorage); DefineNoSubtypeCase(NamedBarrier); DefineNoSubtypeCase(AccelerationStructureNV); DefineNoSubtypeCase(RayQueryKHR); DefineNoSubtypeCase(HitObjectNV); #undef DefineNoSubtypeCase case Type::kVector: { const Vector* vec_ty = type.AsVector(); const Type* ele_ty = vec_ty->element_type(); rebuilt_ty = MakeUnique(RebuildType(*ele_ty), vec_ty->element_count()); break; } case Type::kMatrix: { const Matrix* mat_ty = type.AsMatrix(); const Type* ele_ty = mat_ty->element_type(); rebuilt_ty = MakeUnique(RebuildType(*ele_ty), mat_ty->element_count()); break; } case Type::kImage: { const Image* image_ty = type.AsImage(); const Type* ele_ty = image_ty->sampled_type(); rebuilt_ty = MakeUnique(RebuildType(*ele_ty), image_ty->dim(), image_ty->depth(), image_ty->is_arrayed(), image_ty->is_multisampled(), image_ty->sampled(), image_ty->format(), image_ty->access_qualifier()); break; } case Type::kSampledImage: { const SampledImage* image_ty = type.AsSampledImage(); const Type* ele_ty = image_ty->image_type(); rebuilt_ty = MakeUnique(RebuildType(*ele_ty)); break; } case Type::kArray: { const Array* array_ty = type.AsArray(); rebuilt_ty = MakeUnique(array_ty->element_type(), array_ty->length_info()); break; } case Type::kRuntimeArray: { const RuntimeArray* array_ty = type.AsRuntimeArray(); const Type* ele_ty = array_ty->element_type(); rebuilt_ty = MakeUnique(RebuildType(*ele_ty)); break; } case Type::kStruct: { const Struct* struct_ty = type.AsStruct(); std::vector subtypes; subtypes.reserve(struct_ty->element_types().size()); for (const auto* ele_ty : struct_ty->element_types()) { subtypes.push_back(RebuildType(*ele_ty)); } rebuilt_ty = MakeUnique(subtypes); Struct* rebuilt_struct = rebuilt_ty->AsStruct(); for (auto pair : struct_ty->element_decorations()) { uint32_t index = pair.first; for (const auto& dec : pair.second) { // Explicit copy intended. std::vector copy(dec); rebuilt_struct->AddMemberDecoration(index, std::move(copy)); } } break; } case Type::kPointer: { const Pointer* pointer_ty = type.AsPointer(); const Type* ele_ty = pointer_ty->pointee_type(); rebuilt_ty = MakeUnique(RebuildType(*ele_ty), pointer_ty->storage_class()); break; } case Type::kFunction: { const Function* function_ty = type.AsFunction(); const Type* ret_ty = function_ty->return_type(); std::vector param_types; param_types.reserve(function_ty->param_types().size()); for (const auto* param_ty : function_ty->param_types()) { param_types.push_back(RebuildType(*param_ty)); } rebuilt_ty = MakeUnique(RebuildType(*ret_ty), param_types); break; } case Type::kForwardPointer: { const ForwardPointer* forward_ptr_ty = type.AsForwardPointer(); rebuilt_ty = MakeUnique(forward_ptr_ty->target_id(), forward_ptr_ty->storage_class()); const Pointer* target_ptr = forward_ptr_ty->target_pointer(); if (target_ptr) { rebuilt_ty->AsForwardPointer()->SetTargetPointer( RebuildType(*target_ptr)->AsPointer()); } break; } case Type::kCooperativeMatrixNV: { const CooperativeMatrixNV* cm_type = type.AsCooperativeMatrixNV(); const Type* component_type = cm_type->component_type(); rebuilt_ty = MakeUnique( RebuildType(*component_type), cm_type->scope_id(), cm_type->rows_id(), cm_type->columns_id()); break; } default: assert(false && "Unhandled type"); return nullptr; } for (const auto& dec : type.decorations()) { // Explicit copy intended. std::vector copy(dec); rebuilt_ty->AddDecoration(std::move(copy)); } return type_pool_.insert(std::move(rebuilt_ty)).first->get(); } void TypeManager::RegisterType(uint32_t id, const Type& type) { // Rebuild |type| so it and all its constituent types are owned by the type // pool. Type* rebuilt = RebuildType(type); assert(rebuilt->IsSame(&type)); id_to_type_[id] = rebuilt; if (GetId(rebuilt) == 0) { type_to_id_[rebuilt] = id; } } Type* TypeManager::GetRegisteredType(const Type* type) { uint32_t id = GetTypeInstruction(type); if (id == 0) { return nullptr; } return GetType(id); } Type* TypeManager::RecordIfTypeDefinition(const Instruction& inst) { if (!IsTypeInst(inst.opcode())) return nullptr; Type* type = nullptr; switch (inst.opcode()) { case spv::Op::OpTypeVoid: type = new Void(); break; case spv::Op::OpTypeBool: type = new Bool(); break; case spv::Op::OpTypeInt: type = new Integer(inst.GetSingleWordInOperand(0), inst.GetSingleWordInOperand(1)); break; case spv::Op::OpTypeFloat: type = new Float(inst.GetSingleWordInOperand(0)); break; case spv::Op::OpTypeVector: type = new Vector(GetType(inst.GetSingleWordInOperand(0)), inst.GetSingleWordInOperand(1)); break; case spv::Op::OpTypeMatrix: type = new Matrix(GetType(inst.GetSingleWordInOperand(0)), inst.GetSingleWordInOperand(1)); break; case spv::Op::OpTypeImage: { const spv::AccessQualifier access = inst.NumInOperands() < 8 ? spv::AccessQualifier::ReadOnly : static_cast( inst.GetSingleWordInOperand(7)); type = new Image( GetType(inst.GetSingleWordInOperand(0)), static_cast(inst.GetSingleWordInOperand(1)), inst.GetSingleWordInOperand(2), inst.GetSingleWordInOperand(3) == 1, inst.GetSingleWordInOperand(4) == 1, inst.GetSingleWordInOperand(5), static_cast(inst.GetSingleWordInOperand(6)), access); } break; case spv::Op::OpTypeSampler: type = new Sampler(); break; case spv::Op::OpTypeSampledImage: type = new SampledImage(GetType(inst.GetSingleWordInOperand(0))); break; case spv::Op::OpTypeArray: { const uint32_t length_id = inst.GetSingleWordInOperand(1); const Instruction* length_constant_inst = id_to_constant_inst_[length_id]; assert(length_constant_inst); // How will we distinguish one length value from another? // Determine extra words required to distinguish this array length // from another. std::vector extra_words{Array::LengthInfo::kDefiningId}; // If it is a specialised constant, retrieve its SpecId. // Only OpSpecConstant has a SpecId. uint32_t spec_id = 0u; bool has_spec_id = false; if (length_constant_inst->opcode() == spv::Op::OpSpecConstant) { context()->get_decoration_mgr()->ForEachDecoration( length_id, uint32_t(spv::Decoration::SpecId), [&spec_id, &has_spec_id](const Instruction& decoration) { assert(decoration.opcode() == spv::Op::OpDecorate); spec_id = decoration.GetSingleWordOperand(2u); has_spec_id = true; }); } const auto opcode = length_constant_inst->opcode(); if (has_spec_id) { extra_words.push_back(spec_id); } if ((opcode == spv::Op::OpConstant) || (opcode == spv::Op::OpSpecConstant)) { // Always include the literal constant words. In the spec constant // case, the constant might not be overridden, so it's still // significant. extra_words.insert(extra_words.end(), length_constant_inst->GetOperand(2).words.begin(), length_constant_inst->GetOperand(2).words.end()); extra_words[0] = has_spec_id ? Array::LengthInfo::kConstantWithSpecId : Array::LengthInfo::kConstant; } else { assert(extra_words[0] == Array::LengthInfo::kDefiningId); extra_words.push_back(length_id); } assert(extra_words.size() >= 2); Array::LengthInfo length_info{length_id, extra_words}; type = new Array(GetType(inst.GetSingleWordInOperand(0)), length_info); if (id_to_incomplete_type_.count(inst.GetSingleWordInOperand(0))) { incomplete_types_.emplace_back(inst.result_id(), type); id_to_incomplete_type_[inst.result_id()] = type; return type; } } break; case spv::Op::OpTypeRuntimeArray: type = new RuntimeArray(GetType(inst.GetSingleWordInOperand(0))); if (id_to_incomplete_type_.count(inst.GetSingleWordInOperand(0))) { incomplete_types_.emplace_back(inst.result_id(), type); id_to_incomplete_type_[inst.result_id()] = type; return type; } break; case spv::Op::OpTypeStruct: { std::vector element_types; bool incomplete_type = false; for (uint32_t i = 0; i < inst.NumInOperands(); ++i) { uint32_t type_id = inst.GetSingleWordInOperand(i); element_types.push_back(GetType(type_id)); if (id_to_incomplete_type_.count(type_id)) { incomplete_type = true; } } type = new Struct(element_types); if (incomplete_type) { incomplete_types_.emplace_back(inst.result_id(), type); id_to_incomplete_type_[inst.result_id()] = type; return type; } } break; case spv::Op::OpTypeOpaque: { type = new Opaque(inst.GetInOperand(0).AsString()); } break; case spv::Op::OpTypePointer: { uint32_t pointee_type_id = inst.GetSingleWordInOperand(1); type = new Pointer( GetType(pointee_type_id), static_cast(inst.GetSingleWordInOperand(0))); if (id_to_incomplete_type_.count(pointee_type_id)) { incomplete_types_.emplace_back(inst.result_id(), type); id_to_incomplete_type_[inst.result_id()] = type; return type; } id_to_incomplete_type_.erase(inst.result_id()); } break; case spv::Op::OpTypeFunction: { bool incomplete_type = false; uint32_t return_type_id = inst.GetSingleWordInOperand(0); if (id_to_incomplete_type_.count(return_type_id)) { incomplete_type = true; } Type* return_type = GetType(return_type_id); std::vector param_types; for (uint32_t i = 1; i < inst.NumInOperands(); ++i) { uint32_t param_type_id = inst.GetSingleWordInOperand(i); param_types.push_back(GetType(param_type_id)); if (id_to_incomplete_type_.count(param_type_id)) { incomplete_type = true; } } type = new Function(return_type, param_types); if (incomplete_type) { incomplete_types_.emplace_back(inst.result_id(), type); id_to_incomplete_type_[inst.result_id()] = type; return type; } } break; case spv::Op::OpTypeEvent: type = new Event(); break; case spv::Op::OpTypeDeviceEvent: type = new DeviceEvent(); break; case spv::Op::OpTypeReserveId: type = new ReserveId(); break; case spv::Op::OpTypeQueue: type = new Queue(); break; case spv::Op::OpTypePipe: type = new Pipe( static_cast(inst.GetSingleWordInOperand(0))); break; case spv::Op::OpTypeForwardPointer: { // Handling of forward pointers is different from the other types. uint32_t target_id = inst.GetSingleWordInOperand(0); type = new ForwardPointer(target_id, static_cast( inst.GetSingleWordInOperand(1))); incomplete_types_.emplace_back(target_id, type); id_to_incomplete_type_[target_id] = type; return type; } case spv::Op::OpTypePipeStorage: type = new PipeStorage(); break; case spv::Op::OpTypeNamedBarrier: type = new NamedBarrier(); break; case spv::Op::OpTypeAccelerationStructureNV: type = new AccelerationStructureNV(); break; case spv::Op::OpTypeCooperativeMatrixNV: type = new CooperativeMatrixNV(GetType(inst.GetSingleWordInOperand(0)), inst.GetSingleWordInOperand(1), inst.GetSingleWordInOperand(2), inst.GetSingleWordInOperand(3)); break; case spv::Op::OpTypeRayQueryKHR: type = new RayQueryKHR(); break; case spv::Op::OpTypeHitObjectNV: type = new HitObjectNV(); break; default: SPIRV_UNIMPLEMENTED(consumer_, "unhandled type"); break; } uint32_t id = inst.result_id(); SPIRV_ASSERT(consumer_, id != 0, "instruction without result id found"); SPIRV_ASSERT(consumer_, type != nullptr, "type should not be nullptr at this point"); std::vector decorations = context()->get_decoration_mgr()->GetDecorationsFor(id, true); for (auto dec : decorations) { AttachDecoration(*dec, type); } std::unique_ptr unique(type); auto pair = type_pool_.insert(std::move(unique)); id_to_type_[id] = pair.first->get(); type_to_id_[pair.first->get()] = id; return type; } void TypeManager::AttachDecoration(const Instruction& inst, Type* type) { const spv::Op opcode = inst.opcode(); if (!IsAnnotationInst(opcode)) return; switch (opcode) { case spv::Op::OpDecorate: { const auto count = inst.NumOperands(); std::vector data; for (uint32_t i = 1; i < count; ++i) { data.push_back(inst.GetSingleWordOperand(i)); } type->AddDecoration(std::move(data)); } break; case spv::Op::OpMemberDecorate: { const auto count = inst.NumOperands(); const uint32_t index = inst.GetSingleWordOperand(1); std::vector data; for (uint32_t i = 2; i < count; ++i) { data.push_back(inst.GetSingleWordOperand(i)); } if (Struct* st = type->AsStruct()) { st->AddMemberDecoration(index, std::move(data)); } else { SPIRV_UNIMPLEMENTED(consumer_, "OpMemberDecorate non-struct type"); } } break; default: SPIRV_UNREACHABLE(consumer_); break; } } const Type* TypeManager::GetMemberType( const Type* parent_type, const std::vector& access_chain) { for (uint32_t element_index : access_chain) { if (const Struct* struct_type = parent_type->AsStruct()) { parent_type = struct_type->element_types()[element_index]; } else if (const Array* array_type = parent_type->AsArray()) { parent_type = array_type->element_type(); } else if (const RuntimeArray* runtime_array_type = parent_type->AsRuntimeArray()) { parent_type = runtime_array_type->element_type(); } else if (const Vector* vector_type = parent_type->AsVector()) { parent_type = vector_type->element_type(); } else if (const Matrix* matrix_type = parent_type->AsMatrix()) { parent_type = matrix_type->element_type(); } else { assert(false && "Trying to get a member of a type without members."); } } return parent_type; } void TypeManager::ReplaceForwardPointers(Type* type) { switch (type->kind()) { case Type::kArray: { const ForwardPointer* element_type = type->AsArray()->element_type()->AsForwardPointer(); if (element_type) { type->AsArray()->ReplaceElementType(element_type->target_pointer()); } } break; case Type::kRuntimeArray: { const ForwardPointer* element_type = type->AsRuntimeArray()->element_type()->AsForwardPointer(); if (element_type) { type->AsRuntimeArray()->ReplaceElementType( element_type->target_pointer()); } } break; case Type::kStruct: { auto& member_types = type->AsStruct()->element_types(); for (auto& member_type : member_types) { if (member_type->AsForwardPointer()) { member_type = member_type->AsForwardPointer()->target_pointer(); assert(member_type); } } } break; case Type::kPointer: { const ForwardPointer* pointee_type = type->AsPointer()->pointee_type()->AsForwardPointer(); if (pointee_type) { type->AsPointer()->SetPointeeType(pointee_type->target_pointer()); } } break; case Type::kFunction: { Function* func_type = type->AsFunction(); const ForwardPointer* return_type = func_type->return_type()->AsForwardPointer(); if (return_type) { func_type->SetReturnType(return_type->target_pointer()); } auto& param_types = func_type->param_types(); for (auto& param_type : param_types) { if (param_type->AsForwardPointer()) { param_type = param_type->AsForwardPointer()->target_pointer(); } } } break; default: break; } } void TypeManager::ReplaceType(Type* new_type, Type* original_type) { assert(original_type->kind() == new_type->kind() && "Types must be the same for replacement.\n"); for (auto& p : incomplete_types_) { Type* type = p.type(); if (!type) { continue; } switch (type->kind()) { case Type::kArray: { const Type* element_type = type->AsArray()->element_type(); if (element_type == original_type) { type->AsArray()->ReplaceElementType(new_type); } } break; case Type::kRuntimeArray: { const Type* element_type = type->AsRuntimeArray()->element_type(); if (element_type == original_type) { type->AsRuntimeArray()->ReplaceElementType(new_type); } } break; case Type::kStruct: { auto& member_types = type->AsStruct()->element_types(); for (auto& member_type : member_types) { if (member_type == original_type) { member_type = new_type; } } } break; case Type::kPointer: { const Type* pointee_type = type->AsPointer()->pointee_type(); if (pointee_type == original_type) { type->AsPointer()->SetPointeeType(new_type); } } break; case Type::kFunction: { Function* func_type = type->AsFunction(); const Type* return_type = func_type->return_type(); if (return_type == original_type) { func_type->SetReturnType(new_type); } auto& param_types = func_type->param_types(); for (auto& param_type : param_types) { if (param_type == original_type) { param_type = new_type; } } } break; default: break; } } } } // namespace analysis } // namespace opt } // namespace spvtools