syntax = "proto3";

package jax_triton;

// A single kernel together with the settings stored alongside it.
// NOTE(review): except where a comment is carried over from the original
// schema, field meanings are inferred from their names only; confirm them
// against the code that produces and consumes this message.
message TritonKernel {
  // Kernel function name within module.
  string kernel_name = 1;

  // Presumably the number of warps the kernel is launched with; TODO confirm.
  uint32 num_warps = 2;

  // Presumably the shared-memory size, in bytes, requested for the kernel;
  // TODO confirm.
  uint32 shared_mem_bytes = 3;

  // Presumably the PTX text of the module containing the kernel; TODO
  // confirm. The schema does not say how this relates to `ttir`.
  string ptx = 4;

  // Presumably the Triton IR text of the module containing the kernel; TODO
  // confirm.
  string ttir = 5;

  // Presumably the target GPU compute capability; the numeric encoding is
  // not specified by this schema. TODO confirm.
  uint32 compute_capability = 6;

  // Three cluster-dimension values, one field per axis index (0, 1, 2).
  // NOTE(review): semantics of an unset/zero value are not visible here.
  uint32 cluster_dim_0 = 7;
  uint32 cluster_dim_1 = 8;
  uint32 cluster_dim_2 = 9;
}

// A kernel plus the grid values and parameter list for one call of it.
message TritonKernelCall {
  // One entry of the call's parameter list. Exactly one member of `value`
  // can be set at a time (oneof semantics).
  message Parameter {
    // Descriptor used when the parameter is an array rather than a scalar.
    message Array {
      // Presumably the number of bytes of the array to zero before the
      // call; TODO confirm.
      uint64 bytes_to_zero = 1;

      // Presumably the alignment (divisibility) guaranteed for the array's
      // pointer; TODO confirm.
      uint64 ptr_divisibility = 2;
    }

    // The parameter payload: either an array descriptor or a scalar of one
    // of the listed types.
    oneof value {
      Array array = 1;
      bool bool_ = 2;
      int32 i32 = 3;
      uint32 u32 = 4;
      int64 i64 = 5;
      uint64 u64 = 6;
      float f32 = 7;
      double f64 = 8;
    }
  }

  // The kernel this call refers to.
  TritonKernel kernel = 1;

  // Three grid values, one field per axis index (0, 1, 2).
  // NOTE(review): presumably the launch grid dimensions; TODO confirm.
  uint32 grid_0 = 2;
  uint32 grid_1 = 3;
  uint32 grid_2 = 4;

  // Parameters of the call.
  // NOTE(review): presumably in kernel-argument order; TODO confirm.
  repeated Parameter parameters = 5;
}

// A set of candidate kernel calls (configs) grouped under one name, plus a
// list of input/output buffer aliases.
message TritonAutotunedKernelCall {
  // One candidate: a kernel call and a textual description of it.
  message Config {
    TritonKernelCall kernel_call = 1;

    // Free-form description of this config.
    // NOTE(review): presumably shown in auto-tuning output; TODO confirm.
    string description = 2;
  }

  // Pairs an input buffer index with an output buffer index and a size in
  // bytes.
  // NOTE(review): presumably marks buffers that share storage; how the size
  // is used is not visible here. TODO confirm.
  message InputOutputAlias {
    uint32 input_buffer_idx = 1;
    uint32 output_buffer_idx = 2;
    uint64 buffer_size_bytes = 3;
  }

  // Name used in auto-tuning log messages.
  string name = 1;

  // Candidate configurations.
  repeated Config configs = 2;

  // Input/output buffer aliases for this call.
  repeated InputOutputAlias input_output_aliases = 3;
}

// Wrapper that holds either a plain kernel call or an autotuned kernel call,
// together with opaque metadata and a name.
message TritonAnyKernelCall {
  // Which kind of call this is; at most one member is set.
  oneof value {
    TritonKernelCall kernel_call = 1;
    TritonAutotunedKernelCall autotuned_kernel_call = 2;
  }

  // Opaque bytes; the format is not defined by this schema.
  bytes metadata = 3;

  // User assigned name.
  string name = 4;
}