#include #include #include #include #include #include #include "nanobind/nanobind.h" #include "nanobind/stl/pair.h" #include "nanobind/stl/string.h" #include "nanobind/stl/string_view.h" #include "nanobind/stl/tuple.h" #include "nanobind/stl/vector.h" #include "absl/status/statusor.h" #include "jaxlib/absl_status_casters.h" #include "jaxlib/gpu/gpu_kernel_helpers.h" #include "jaxlib/gpu/triton.pb.h" #include "jaxlib/gpu/triton_kernels.h" #include "jaxlib/gpu/triton_utils.h" #include "jaxlib/gpu/vendor.h" #include "jaxlib/kernel_nanobind_helpers.h" #define GPU_RETURN_IF_ERROR(expr) JAX_RETURN_IF_ERROR(JAX_AS_STATUS(expr)) namespace nb = nanobind; namespace jax::JAX_GPU_NAMESPACE { NB_MODULE(_triton, m) { nb::class_(m, "TritonKernel") .def(nb::init()); nb::class_(m, "TritonParameter"); m.def("create_array_parameter", [](size_t bytes_to_zero, size_t ptr_divisibility) { return KernelCall::Parameter{ KernelCall::Parameter::Array{bytes_to_zero, ptr_divisibility}}; }); m.def("create_scalar_parameter", ValueOrThrowWrapper([](bool value, std::string_view dtype) -> absl::StatusOr { if ((dtype == "i1") || (dtype == "B")) { return KernelCall::Parameter{value}; } else { return absl::InvalidArgumentError(std::string("unknown dtype: ") + dtype.data()); } })); m.def("create_scalar_parameter", ValueOrThrowWrapper([](nb::int_ value, std::string_view dtype) -> absl::StatusOr { if (dtype == "i32") { return KernelCall::Parameter{static_cast(value)}; } else if (dtype == "u32") { return KernelCall::Parameter{static_cast(value)}; } else if (dtype == "i64") { return KernelCall::Parameter{static_cast(value)}; } else if (dtype == "u64") { return KernelCall::Parameter{static_cast(value)}; } else { return absl::InvalidArgumentError(std::string("unknown dtype: ") + dtype.data()); } })); m.def("create_scalar_parameter", ValueOrThrowWrapper([](double value, std::string_view dtype) -> absl::StatusOr { if (dtype == "fp32") { return KernelCall::Parameter{static_cast(value)}; } else if (dtype == "fp64") { return KernelCall::Parameter{static_cast(value)}; } else { return absl::InvalidArgumentError(std::string("unknown dtype: ") + dtype.data()); } })); nb::class_(m, "TritonKernelCall") .def(nb::init>()) .def("to_proto", [](const KernelCall& kernel_call, std::string name, nb::bytes metadata) { jax_triton::TritonAnyKernelCall proto; *proto.mutable_kernel_call() = kernel_call.ToProto(); proto.set_name(std::move(name)); proto.set_metadata(metadata.c_str(), metadata.size()); std::string s = proto.SerializeAsString(); return nb::bytes(s.c_str(), s.size()); }); nb::class_(m, "TritonAutotunedKernelCall") .def("__init__", [](AutotunedKernelCall* call, std::string name, std::vector> calls_and_descriptions, std::vector> input_output_aliases) { std::vector configs; configs.reserve(calls_and_descriptions.size()); for (auto& [kernel_call, desc] : calls_and_descriptions) { configs.push_back({std::move(kernel_call), std::move(desc)}); } new (call) AutotunedKernelCall(std::move(name), std::move(configs), std::move(input_output_aliases)); }) .def("to_proto", [](const AutotunedKernelCall& kernel_call, std::string name, nb::bytes metadata) { jax_triton::TritonAnyKernelCall proto; *proto.mutable_autotuned_kernel_call() = kernel_call.ToProto(); proto.set_name(std::move(name)); proto.set_metadata(metadata.c_str(), metadata.size()); std::string s = proto.SerializeAsString(); return nb::bytes(s.c_str(), s.size()); }); m.def("get_custom_call", [] { return EncapsulateFunction(&TritonKernelCall); }); m.def("get_compute_capability", ValueOrThrowWrapper([](int device) -> absl::StatusOr { int major, minor; GPU_RETURN_IF_ERROR(gpuInit(device)); GPU_RETURN_IF_ERROR(gpuDeviceGetAttribute( &major, GPU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, device)); GPU_RETURN_IF_ERROR(gpuDeviceGetAttribute( &minor, GPU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, device)); return major * 10 + minor; })); m.def("get_serialized_metadata", ValueOrThrowWrapper( [](nb::bytes opaque) -> absl::StatusOr { JAX_ASSIGN_OR_RETURN( std::string metadata, GetTritonKernelCallSerializedMetadata( absl::string_view(opaque.c_str(), opaque.size()))); return nb::bytes(metadata.c_str(), metadata.size()); })); } } // namespace jax::JAX_GPU_NAMESPACE