/* Copyright 2017 The OpenXLA Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This proto file defines messages which represent the HLO module. This is a
// full fidelity serialization of the C++ HLO constructs.
//
// Many of the protos below are simple 1-to-1 serializations of the
// corresponding C++ classes, e.g., HloModule, HloComputation, and
// HloInstruction.
//
// FIELD NAMES ARE IMPORTANT
//
// Unlike most protos, you can't safely change the names of fields, even if you
// keep the numeric ids the same. This is because we sometimes serialize these
// protos as JSON, which includes the field names in the serialization.

syntax = "proto3";

package xla;

import "google/protobuf/any.proto";
import "xla/xla_data.proto";

option cc_enable_arenas = true;

// The desired schedule for a custom-call instruction. Used by
// HloInstructionProto.custom_call_schedule.
enum CustomCallSchedule {
  SCHEDULE_NONE = 0;
  SCHEDULE_LATEST = 1;
  SCHEDULE_EARLIEST = 2;
}

// The version of the API used by the custom call function. The signatures for
// each version are given below.
// TODO(b/189822916): Remove this enum when all clients are migrated to the
// status-returning API.
enum CustomCallApiVersion {
  API_VERSION_UNSPECIFIED = 0;

  // The first version of the API, with the following signatures:
  //
  // CPU:
  //   void do_custom_call(void* out, const void** in);
  //
  // GPU:
  //   void do_custom_call(CUstream stream, void** buffers,
  //                       const char* opaque, size_t opaque_len);
  API_VERSION_ORIGINAL = 1;

  // When the ability to return success/failure status was added:
  //
  // CPU:
  //   void do_custom_call(void* out, const void** in,
  //                       XlaCustomCallStatus* status);
  //
  // GPU:
  //   void do_custom_call(CUstream stream, void** buffers,
  //                       const char* opaque, size_t opaque_len,
  //                       XlaCustomCallStatus* status);
  //
  API_VERSION_STATUS_RETURNING = 2;

  // Fixes the API signatures on the CPU side of the version STATUS_RETURNING by
  // adding the opaque string so that the custom call API is consistent across
  // CPUs and GPUs. For GPUs, the behaviors invoked by
  // API_VERSION_STATUS_RETURNING and API_VERSION_STATUS_RETURNING_UNIFIED are
  // the same.
  //
  // CPU:
  //   void do_custom_call(void* out, const void** in,
  //                       const char* opaque, size_t opaque_len,
  //                       XlaCustomCallStatus* status);
  //
  // GPU:
  //   void do_custom_call(CUstream stream, void** buffers,
  //                       const char* opaque, size_t opaque_len,
  //                       XlaCustomCallStatus* status);
  //
  API_VERSION_STATUS_RETURNING_UNIFIED = 3;

  // Api version implementing XLA runtime custom call calling convention. These
  // custom calls can be registered as an XLA runtime custom call (1) or as XLA
  // runtime FFI binding (2).
  //
  // This type of custom call uses custom ABI to pass type information along
  // with custom call arguments. Also it passes buffer arguments together with
  // data type, sizes and strides.
  //
  // Example: (XLA runtime custom call)
  //
  //   absl::Status DoCustomCall(StridedMemrefView arg, float attr);
  //
  //   CustomCall::Bind("custom_call")
  //     .Arg<StridedMemrefView>()
  //     .Attr<float>("attr")
  //     .To(DoCustomCall);
  //
  // (1) xla/runtime/custom_call.h
  // (2) xla/runtime/ffi/ffi.h
  API_VERSION_TYPED_FFI = 4;
}

// Serialization of HloInstruction.
// Next ID: 90
message HloInstructionProto {
  reserved 10;
  reserved "parameter_name";
  reserved 12;
  reserved "fused_instructions_computation";
  reserved 4;
  reserved "operand_names";
  reserved 5;
  reserved "control_predecessor_names";
  reserved 6;
  reserved "called_computation_names";
  reserved 44;
  reserved "replica_group_ids";
  // Use backend_config instead for custom_call_opaque.
  reserved 53;
  reserved "custom_call_opaque";
  // Use backend_config instead for all_reduce_barrier.
  reserved 46;
  reserved "all_reduce_barrier";

  string name = 1;
  string opcode = 2;
  xla.ShapeProto shape = 3;

  xla.OpMetadata metadata = 7;

  // Literal, only present for kConstant.
  xla.LiteralProto literal = 8;

  // Parameter number is only present for kParameter.
  int64 parameter_number = 9;

  // Fusion state, only present for kFusion.
  string fusion_kind = 11;

  // Index for kGetTupleElement.
  int64 tuple_index = 13;

  // Dimensions present for some operations that require reshaping or
  // broadcasting, including Reshape, Reduce, ReduceWindow, and Reverse.
  repeated int64 dimensions = 14;

  // Describes the window in a windowed operation such as convolution.
  xla.Window window = 15;

  // Describes the dimension numbers used for a convolution.
  xla.ConvolutionDimensionNumbers convolution_dimension_numbers = 16;

  // The number of feature groups. Used for a convolution. Must be a divisor of
  // the input feature dimension and output feature dimension. If not specified,
  // it will use a default value of 1.
  int64 feature_group_count = 50;

  int64 batch_group_count = 58;

  // Describes the [begin, end) index range and stride for slices.
  message SliceDimensions {
    int64 start = 1;
    int64 limit = 2;
    int64 stride = 3;
  }
  repeated SliceDimensions slice_dimensions = 17;

  // The bit sizes for a reduce-precision operation.
  int32 exponent_bits = 18;
  int32 mantissa_bits = 19;

  // Describes the [start, start + size) range size for a dynamic slice
  // ('start' is specified dynamically in the second operand of the operation).
  repeated int64 dynamic_slice_sizes = 20;

  // The padding configuration that describes the edge padding and interior
  // padding of this pad instruction. Only set for pad instructions.
  xla.PaddingConfig padding_config = 21;

  // Outfeed configuration information, only present for kOutfeed.
  bytes outfeed_config = 22;

  // The distribution requested for random number generation.
  // Only present for kRng.
  xla.RandomDistribution distribution = 23;

  // A small float number added to the variance to avoid divide-by-zero error.
  // Only present for kBatchNormTraining, kBatchNormInference, and
  // kBatchNormGrad.
  float epsilon = 24;

  // An integer value representing the index of the feature dimension.
  // Only present for kBatchNormTraining, kBatchNormInference, and
  // kBatchNormGrad.
  int64 feature_index = 25;

  // Represents a unique identifier for each Send/Recv instruction pair or
  // optionally for collective instructions (AllReduce, CollectivePermute,
  // AllToAll). Non-positive channel_id is equivalent to no channel id.
  int64 channel_id = 26;

  // The string representation of the infeed configuration.
  bytes infeed_config = 27;

  // Name of an external target (e.g., global symbol) to call, only present for
  // kCustomCall.
  string custom_call_target = 28;

  // Shape of outfeed request.
  xla.ShapeProto outfeed_shape = 29;

  // Describes the dimension numbers used for a dot operation.
  xla.DotDimensionNumbers dot_dimension_numbers = 30;

  // FFT type (FFT, IFFT, etc).
  xla.FftType fft_type = 31;

  // FFT length.
  repeated int64 fft_length = 32;

  // Comparison direction only used for kCompare.
  string comparison_direction = 63;

  // Gather dimension numbers.
  xla.GatherDimensionNumbers gather_dimension_numbers = 33;
  repeated int64 gather_slice_sizes = 34;

  // Used to be compute host-related fields.
  reserved 41;
  reserved 42;

  // The id of this instruction.
  int64 id = 35;

  repeated int64 operand_ids = 36;
  repeated int64 control_predecessor_ids = 37;
  repeated int64 called_computation_ids = 38;

  xla.OpSharding sharding = 40;

  // Backend configuration for the instruction. Has backend-specific meaning.
  bytes backend_config = 43;

  // Deprecated, but keeping for backward compatibility.
  // Use collective_device_list. Cross replica op fields.
  repeated ReplicaGroup replica_groups = 49 [deprecated = true];

  // Deprecated, but keeping it for backward compatibility. Use channel_id.
  // Non-positive all_reduce_id is equivalent to no all_reduce_id.
  int64 all_reduce_id = 45 [deprecated = true];

  // If true, interprets ids in ReplicaGroup as global device ids, which is
  // a linearized id of `replica_id * partition_count + partition_id`.
  bool use_global_device_ids = 71;

  // Whether this Send/Recv instruction transfers data to/from the host. Only
  // present for Send and Recv instructions and their SendDone and RecvDone
  // partners.
  bool is_host_transfer = 47;

  // Whether this Sort instruction should be stable.
  bool is_stable = 60;

  xla.ScatterDimensionNumbers scatter_dimension_numbers = 48;

  // Precision configuration for the instruction. Has backend-specific meaning.
  xla.PrecisionConfig precision_config = 51;

  // Collective permute field.
  repeated SourceTarget source_target_pairs = 52;

  // Sharding for kDomain instructions.
  xla.OpSharding domain_entry_sharding = 54;
  xla.OpSharding domain_exit_sharding = 55;

  // For custom call this indicates that the layouts are constrained. If
  // constrain_layout is true then the 'shape' field must contain a layout, and
  // 'operand_shapes_with_layout' must contain a shape with layout for each
  // operand.
  bool constrain_layout = 56;
  repeated xla.ShapeProto operand_shapes_with_layout = 57;

  // Options for TriangularSolve.
  xla.TriangularSolveOptions triangular_solve_options = 59;

  // Options for Cholesky.
  xla.CholeskyOptions cholesky_options = 62;

  // Describes how parameters behave with regards to replicas.
  xla.ParameterReplication parameter_replication = 61;

  reserved 64;

  // Whether the kCustomCall instruction has side-effects, only present for
  // kCustomCall.
  bool custom_call_has_side_effect = 65;

  // A list of OutputOperandAliasing pairs that specifies aliasing buffers
  // between output and operands for kCustomCall and kFusion.
  repeated xla.OutputOperandAliasing output_operand_aliasing = 74;

  // Specifies the desired schedule for the custom-call. The field is only
  // present for custom-call.
  CustomCallSchedule custom_call_schedule = 76;

  // The delta value for kRngGetAndUpdateState.
  int64 delta = 66;

  // Specifies if the gather/scatter indices are guaranteed to be sorted by the
  // caller.
  bool indices_are_sorted = 67;

  // Frontend attributes to pass to the XLA backend.
  xla.FrontendAttributes frontend_attributes = 68;

  // Specifies if all elements updated are guaranteed to be unique by
  // the caller.
  bool unique_indices = 69;

  // RNG algorithm used by kRngBitGenerator.
  xla.RandomAlgorithm rng_algorithm = 70;

  // The comparison type used for kCompare.
  string comparison_type = 72;

  // Specifies if this is a cross-program-prefetch, used by kCopyStart.
  // Deprecated and replaced by optional_cross_program_prefetch_index.
  bool is_cross_program_prefetch = 73 [deprecated = true];

  // Specifies the cross-program-prefetch index used by kCopyStart. Uses oneof
  // to emulate the 'optional' keyword for proto3 versions before v3.15.0
  // released 2021/2/18.
  oneof optional_cross_program_prefetch_index {
    int32 cross_program_prefetch_index = 80;
  }

  // If a convolution is dynamic, a dynamic padding type will be specified.
  xla.PaddingType padding_type = 75;

  // The API version used by the custom call function. This field is only
  // present for custom-call.
  // TODO(b/189822916): Remove this field when all clients are migrated to the
  // status-returning API.
  CustomCallApiVersion custom_call_api_version = 77;

  // Used to be async_group_id.
  reserved 78;

  // Represents a unique execution thread name for one or more async groups.
  // Each HLO module may contain a main thread and one or more parallel threads.
  // Empty async_execution_thread is equivalent to main thread.
  string async_execution_thread = 79;

  // Represents the K value for top-k.
  int64 k = 81;

  // Represents the largest flag for top-k.
  bool largest = 85;

  // Represents the information for tracking propagation of values within HLO
  // graph.
  xla.StatisticsViz statistics_viz = 82;

  // Used to be operation_queue_id.
  reserved 83;

  // Used to be wait_on_operation_queues.
  reserved 84;

  // Sparsity descriptor for dot operation.
  repeated xla.SparsityDescriptor dot_sparsity = 86;

  // Represents the list of devices that participate in a collective operation.
  xla.CollectiveDeviceListProto collective_device_list = 87;

  // For HLO value tracking.
  xla.OriginalValueProto original_value = 88;

  // Specifies if a call instruction is a composite.
  bool is_composite = 89;
}

// Serialization of HloComputation.
message HloComputationProto {
  reserved 3;
  reserved "root_name";

  string name = 1;

  // The array of instructions is always in a valid dependency order, where
  // operands appear before their users.
  repeated HloInstructionProto instructions = 2;

  // The program shape (with layout) of this computation.
  xla.ProgramShapeProto program_shape = 4;

  // The id of this computation.
  int64 id = 5;

  // The id of the root of the computation.
  int64 root_id = 6;

  // Whether this is a fusion computation. Fusion computations should use this
  // to determine whether they are a fusion in CreateFromProto since the
  // parent fusion_instruction_ may get removed and be nullptr.
  bool is_fusion_computation = 7;

  // The name of execution thread this computation belongs to.
  string execution_thread = 8;
}

// Serialization of an HLO schedule. An HLO schedule contains a total order of
// instructions for each non-fusion computation in the module.
message HloScheduleProto {
  // The ordered instruction ids of one computation.
  message InstructionSequence {
    repeated int64 instruction_ids = 1;
  }

  // Map from computation id to sequence.
  map<int64, InstructionSequence> sequences = 1;
}

// The kind of alias between an input and an output buffer. Used by
// HloInputOutputAliasProto.AliasEntryProto.kind.
enum Kind {
  // Define a UNDEFINED_ALIAS equal to zero to get around the default-0 proto3
  // behavior and missing has_*() APIs.
  UNDEFINED_ALIAS = 0;
  // The buffers may or may not alias at runtime.
  MAY_ALIAS = 1;
  // The buffers must alias at runtime.
  MUST_ALIAS = 2;
}

message HloInputOutputAliasProto {
  // The following proto describes an aliased pair of an input
  // (described by parameter number and a ShapeIndex of the parameter)
  // and an output (described by a ShapeIndex of the root
  // instruction). For example:
  //
  // entry = {
  //  output_shape_index={1},
  //  parameter_number=0,
  //  parameter_shape_index={1, 2},
  // }
  //
  // This entry indicates that the first parameter's {1, 2} element is
  // aliased with the {1} element of the root instruction.
  message AliasEntryProto {
    // ShapeIndex of the root hlo.
    repeated int64 output_shape_index = 1;
    // Number of the parameter in entry computation.
    int64 parameter_number = 2;
    // ShapeIndex of the parameter instruction.
    repeated int64 parameter_shape_index = 3;
    // The kind of alias to be setup.
    Kind kind = 4;
  }

  repeated AliasEntryProto entries = 1;
}

message HloBufferDonorProto {
  // The following proto describes an input (described by parameter number and a
  // ShapeIndex of the parameter) that can donate its buffer to any output
  // tensor. It is similar to HloInputOutputAliasProto, but without a paired
  // output. For example:
  //
  // entry = {
  //  parameter_number=0,
  //  parameter_shape_index={1, 2},
  // }
  //
  // This entry indicates that the first parameter's {1, 2} element can donate
  // its buffer.
  message BufferDonorEntryProto {
    // Number of the parameter in entry computation.
    int64 parameter_number = 1;
    // ShapeIndex of the parameter instruction.
    repeated int64 parameter_shape_index = 2;
  }

  repeated BufferDonorEntryProto entries = 1;
}

message CrossProgramPrefetch {
  int64 parameter = 1;
  repeated int64 index = 2;
  int64 offset = 3;
}

// Serialization of stack frames index representations.
// Stack frames index presented in four flat arrays:
// 1. File names array.
// 2. Function names array.
// 3. File location array.
// 4. Frame array.
// All reference ids in sub-protos are 1-based positions of the
// entity in the flat array.
// Ids are 1-based to keep 0 value as representation of non-set property.
message StackFrameIndexProto {
  // Serialization of file position.
  message FileLocation {
    // 1-based position of file name.
    int32 file_name_id = 1;
    // 1-based position of function name.
    int32 function_name_id = 2;
    // Line number.
    int32 line = 3;
    // Column number.
    int32 column = 4;
  }

  // Serialization of frame.
  message StackFrame {
    // 1-based position of file location.
    int32 file_location_id = 1;
    // 1-based position of the parent frame.
    int32 parent_frame_id = 2;
  }

  // Flat index array of file names.
  repeated string file_names = 1;
  // Flat index array of function names.
  repeated string function_names = 2;
  // Flat index array of file locations.
  repeated FileLocation file_locations = 3;
  // Flat index array of frames.
  repeated StackFrame stack_frames = 4;
}

// Serialization of HloModule.
message HloModuleProto {
  string name = 1;
  string entry_computation_name = 2;
  int64 entry_computation_id = 6;

  // The array of computations is always in a valid dependency order, where
  // callees appear before their callers.
  repeated HloComputationProto computations = 3;

  // The host program shape (with layout) of the entry computation.
  xla.ProgramShapeProto host_program_shape = 4;

  // The id of this module.
  int64 id = 5;

  // The schedule for this module.
  HloScheduleProto schedule = 7;

  // Describes alias information between inputs and outputs.
  HloInputOutputAliasProto input_output_alias = 8;

  // Describes the information of input buffer donors.
  HloBufferDonorProto buffer_donor = 18;

  repeated CrossProgramPrefetch cross_program_prefetches = 10;

  // True if the module contains dynamic computation.
  bool is_dynamic = 11;

  xla.OpSharding spmd_output_sharding = 12;

  repeated xla.OpSharding spmd_parameters_shardings = 14;

  // Uses AutoSharding pass or not.
  bool use_auto_spmd_partitioning = 16;

  // The type of optimization profile in use for module-level optimizations.
  enum ProfileType {
    INVALID = 0;
    FLAG = 1;
    FUSION = 2;
    LAYOUT = 3;
    DOT = 4;
    FLAGNET = 5;
  }

  // Information about the optimization profile that this module contains.
  message ProfileInfo {
    // The optimization profiles that this module contains.
    ProfileType profile_type = 1;
    // Speedup of tuned config compared to default config.
    double relative_speedup = 2;
    // The source of the optimization profile that this module contains.
    xla.ProfileSource profile_source = 3;
    // The compilation event that triggered the use of the profile.
    xla.CompilationEvent compilation_event = 4;
    // The fingerprint of the unoptimized module this profile was applied to.
    string fingerprint = 5;
  }

  // Profile information for the HLO module.
  repeated ProfileInfo profile_info = 13;

  // DeviceAssignment object information.
  DeviceAssignmentProto device_assignment = 15;

  // Stack frames index.
  StackFrameIndexProto stack_frame_index = 17;

  // Frontend attributes to pass to the XLA backend.
  xla.FrontendAttributes frontend_attributes = 19;

  reserved 9;
  reserved "dynamic_parameter_binding";
}

// Serialization of LogicalBuffer.
message LogicalBufferProto {
  // Location represents an instruction and its shape index, which uniquely
  // identifies a point where a buffer is needed.
  message Location {
    // TODO(b/239098765): Remove instruction_name and computation_name.
    string instruction_name = 2 [deprecated = true];
    int64 instruction_id = 4;
    repeated int64 shape_index = 3;

    reserved 1;
  }

  int64 id = 1;
  int64 size = 2;

  // The location where the buffer is defined.
  Location defined_at = 3;

  int64 color = 4;
}

// Serialization of BufferAllocation.
message BufferAllocationProto {
  // Assigned represents a single LogicalBuffer that is assigned to this
  // BufferAllocation.
  message Assigned {
    int64 logical_buffer_id = 1;
    int64 offset = 2;
    int64 size = 3;
  }

  int64 index = 1;
  int64 size = 2;
  bool is_thread_local = 3;
  bool is_tuple = 11;
  bool is_entry_computation_parameter = 5;
  bool is_constant = 12;
  int64 parameter_number = 6;
  repeated int64 parameter_shape_index = 10;
  bool maybe_live_out = 7;
  int64 color = 8;
  repeated Assigned assigned = 9;
}

// A trace of a HeapSimulator run.
message HeapSimulatorTrace {
  // The trace includes a list of events, where each event describes one action
  // performed by the heap simulator.
  message Event {
    enum Kind {
      ALLOC = 0;  // A memory region was allocated for the buffer.
      FREE = 1;   // A memory region was freed for the buffer.

      // A buffer was shared with another (canonical) buffer. This is similar to
      // ALLOC, except that instead of allocating a new region of memory, the
      // memory region of the canonical buffer is directly re-used. Multiple
      // buffers may share with the same canonical buffer. The lifetime of the
      // canonical buffer is extended to the union of all lifetimes.
      SHARE_WITH = 2;
    }
    Kind kind = 1;

    // The id of the LogicalBuffer that the event applies to.
    int64 buffer_id = 2;

    // The HloInstruction that the simulation was processing that caused this
    // event to occur, identified by its computation and instruction name. E.g.
    // buffers defined by instruction A are allocated when processing A.
    string computation_name = 3;
    string instruction_name = 4;

    // The id of the canonical LogicalBuffer that the buffer shares with. Only
    // set for SHARE_WITH events.
    int64 share_with_canonical_id = 5;
  }
  repeated Event events = 1;
  bool whole_module_simulation = 2;
  int64 buffer_allocation_index = 3;
}

// An abstraction representing a set of HLO module built to run concurrently
// across different devices.
message HloModuleGroupProto {
  string name = 1;
  repeated HloModuleProto hlo_modules = 2;
}

// Serialization of BufferAssignment.
message BufferAssignmentProto {
  // Alias represents a source LogicalBuffer, and the buffer location that
  // aliases it.
  message BufferAlias {
    int64 source_buffer_id = 1;
    LogicalBufferProto.Location location = 2;
  }

  repeated LogicalBufferProto logical_buffers = 1;
  repeated BufferAlias buffer_aliases = 2;
  repeated BufferAllocationProto buffer_allocations = 3;
  repeated HeapSimulatorTrace heap_simulator_traces = 4;
}

// Grouping message that contains all of the information above.
message HloProto {
  reserved 2;
  reserved "hlo_ordering";

  HloModuleProto hlo_module = 1;
  BufferAssignmentProto buffer_assignment = 3;
}

// Encapsulates HloProto together with the arguments, result, and
// execution_platform. This message is used for purposes such as
// analysis/replay/file-storage.
message HloSnapshot {
  // The hlo graph.
  HloProto hlo = 1;

  // The arguments passed to the graph.
  repeated LiteralProto arguments = 2;

  // The result of the graph.
  LiteralProto result = 3;

  // The name of the platform used to run the graph.
  string execution_platform = 4;
}

// Metadata for an HLO module. Dumped after HLO passes and before LLO lowering
// with filename module_####.metadata.textproto, where #### is
// canonical_module_id.
message HloModuleMetadataProto {
  // Uniquely identifies an HloModuleMetadata. Equal to the first unique_id
  // of the module (a module may go through multiple unique_ids). If a module
  // is partitioned into multiple modules, those modules will each have a new
  // HloModuleMetadata with a different canonical_module_id.
  int64 canonical_module_id = 1;

  // Name of the module group that the module is part of.
  string module_group_name = 2;

  // The canonical module id of the module that this one is partitioned from,
  // if applicable.
  int64 original_module_id = 3;

  // The canonical module ids of the modules that this one is partitioned into,
  // if applicable.
  repeated int64 partitioned_module_ids = 4;

  // Metadata for the HLO passes that are run on the module.
  repeated HloPassMetadata pass_metadata = 5;
}

// Metadata for one run of an HLO pass on a module. Provides more information
// when processing debug dumps of HloProtos about the order of HLO passes and
// various other stats like duration. `pass_id` may also be used to identify a
// particular run of a pass in debug info that propagates through stages of
// compilation.
message HloPassMetadata {
  // For a given module, pass_id uniquely identifies a run of an HLO pass on
  // that module. Note that a pass_id may not always refer to the same pass
  // because the order of passes during compilation may change. For finding
  // metadata for a particular pass, pass_name and pipeline_name would be more
  // reliable, although note that they may not be unique.
  int64 pass_id = 1;
  string pass_name = 2;
  string pipeline_name = 3;

  // Filenames of the dumps of the module after this pass ran. Module may be
  // dumped in multiple formats, and the order of formats in this field will
  // stay consistent across passes.
  repeated string dump_filenames = 4;

  // Return value of pass.Run(). True if this pass changed the module, or, in
  // the case where the module was run through this pass as part of a module
  // group, true if this pass changed any module in the same module group.
  bool module_changed = 5;

  // The unique_id of the module that this pass is run on. May be different from
  // the canonical_module_id of the HloModuleMetadata that this HloPassMetadata
  // is inside.
  int64 module_id = 6;

  // If the module went through this pass as part of a module group, this is
  // set as the ids of all the modules in the module group. Empty otherwise.
  repeated int64 module_group_module_ids = 7;

  // Timestamp before and after the pass is run. Note they may be equal.
  int64 start_timestamp_usec = 8;
  int64 end_timestamp_usec = 9;

  // Custom metadata for the pass.
  google.protobuf.Any custom_metadata = 10;
}