// Copyright 2024 The IREE Authors // // Licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception #include #include #include #include #include #include "iree/base/api.h" #include "iree/base/internal/cpu.h" #include "iree/base/internal/flags.h" #include "iree/base/internal/math.h" #include "iree/base/internal/path.h" #include "iree/hal/api.h" #include "iree/modules/hal/module.h" #include "iree/tooling/context_util.h" #include "iree/tooling/device_util.h" #include "iree/vm/api.h" #include "iree/vm/native_module_cc.h" IREE_FLAG(bool, require_exact_results, true, "Requires floating point result elements to match exactly."); IREE_FLAG( float, acceptable_fp_delta, 1e-5f, "Maximum absolute difference allowed with inexact floating point results."); IREE_FLAG( int32_t, max_elements_to_check, 10000, "Maximum number of matrix elements to check for each matmul. For larger " "matrices, only every n-th element will be checked for some n chosed to " "stay just under that threshold and to avoid being a divisor of the inner " "dimension size to avoid special patterns. As the check uses a slow " "reference implementation, this is a trade-off between test latency and " "coverage. The value 0 means check all elements."); //===----------------------------------------------------------------------===// // Utilities //===----------------------------------------------------------------------===// static const char* emoji(bool good) { return good ? "🦄" : "🐞"; } static int calculate_check_every(iree_hal_dim_t m_size, iree_hal_dim_t n_size) { int check_every = 1; if (FLAG_max_elements_to_check) { check_every = ((m_size * n_size) + FLAG_max_elements_to_check - 1) / FLAG_max_elements_to_check; if (check_every < 1) check_every = 1; if (check_every > 1) while ((n_size % check_every) == 0) ++check_every; } return check_every; } // Defines the type of a primitive value. 
typedef enum iree_e2e_test_value_type_e { // Not a value type. IREE_E2E_TEST_VALUE_TYPE_NONE = 0, // int8_t. IREE_E2E_TEST_VALUE_TYPE_I8 = 1, // int16_t. IREE_E2E_TEST_VALUE_TYPE_I16 = 2, // int32_t. IREE_E2E_TEST_VALUE_TYPE_I32 = 3, // int64_t. IREE_E2E_TEST_VALUE_TYPE_I64 = 4, // halft_t. IREE_E2E_TEST_VALUE_TYPE_F16 = 5, // float. IREE_E2E_TEST_VALUE_TYPE_F32 = 6, // double. IREE_E2E_TEST_VALUE_TYPE_F64 = 7, // bfloat16 IREE_E2E_TEST_VALUE_TYPE_BF16 = 8, } iree_e2e_test_value_type_t; // Maximum size, in bytes, of any value type we can represent. #define IREE_E2E_TEST_VALUE_STORAGE_SIZE 8 // A variant value type. typedef struct iree_e2e_test_value_t { iree_e2e_test_value_type_t type; union { int8_t i8; int16_t i16; int32_t i32; int64_t i64; float f32; uint16_t f16_u16; uint16_t bf16_u16; double f64; uint8_t value_storage[IREE_E2E_TEST_VALUE_STORAGE_SIZE]; // max size of all // value types }; } iree_e2e_test_value_t; static inline iree_e2e_test_value_t iree_e2e_test_value_make_none() { iree_e2e_test_value_t result; result.type = IREE_E2E_TEST_VALUE_TYPE_NONE; return result; } static inline iree_e2e_test_value_t iree_e2e_test_value_make_i8(int8_t value) { iree_e2e_test_value_t result; result.type = IREE_E2E_TEST_VALUE_TYPE_I8; result.i8 = value; return result; } static inline iree_e2e_test_value_t iree_e2e_test_value_make_i16( int16_t value) { iree_e2e_test_value_t result; result.type = IREE_E2E_TEST_VALUE_TYPE_I16; result.i16 = value; return result; } static inline iree_e2e_test_value_t iree_e2e_test_value_make_i32( int32_t value) { iree_e2e_test_value_t result; result.type = IREE_E2E_TEST_VALUE_TYPE_I32; result.i32 = value; return result; } static inline iree_e2e_test_value_t iree_e2e_test_value_make_f16( uint16_t value) { iree_e2e_test_value_t result; result.type = IREE_E2E_TEST_VALUE_TYPE_F16; result.f16_u16 = value; return result; } static inline iree_e2e_test_value_t iree_e2e_test_value_make_bf16( uint16_t value) { iree_e2e_test_value_t result; result.type = 
IREE_E2E_TEST_VALUE_TYPE_BF16; result.bf16_u16 = value; return result; } static inline iree_e2e_test_value_t iree_e2e_test_value_make_f32(float value) { iree_e2e_test_value_t result; result.type = IREE_E2E_TEST_VALUE_TYPE_F32; result.f32 = value; return result; } //===----------------------------------------------------------------------===// // Reference matmul //===----------------------------------------------------------------------===// // Reads an element from a mapped row-major matrix buffer. static iree_e2e_test_value_t read_matrix_element( iree_hal_dim_t m_size, iree_hal_dim_t n_size, iree_hal_element_type_t result_type, const void* data, iree_hal_dim_t m, iree_hal_dim_t n) { iree_host_size_t index = n + m * n_size; (void)m_size; if (iree_hal_element_type_is_integer(result_type, 8)) { return iree_e2e_test_value_make_i8(((int8_t*)data)[index]); } else if (iree_hal_element_type_is_integer(result_type, 16)) { return iree_e2e_test_value_make_i16(((int16_t*)data)[index]); } else if (iree_hal_element_type_is_integer(result_type, 32)) { return iree_e2e_test_value_make_i32(((int32_t*)data)[index]); } else if (result_type == IREE_HAL_ELEMENT_TYPE_FLOAT_16) { return iree_e2e_test_value_make_f16(((uint16_t*)data)[index]); } else if (result_type == IREE_HAL_ELEMENT_TYPE_BFLOAT_16) { return iree_e2e_test_value_make_bf16(((uint16_t*)data)[index]); } else if (result_type == IREE_HAL_ELEMENT_TYPE_FLOAT_32) { return iree_e2e_test_value_make_f32(((float*)data)[index]); } iree_status_abort(iree_make_status(IREE_STATUS_INVALID_ARGUMENT, "unhandled matmul result type")); return iree_e2e_test_value_make_none(); } // Get the shape of a buffer_view that is a matrix, i.e. 2D shape. 
static iree_status_t get_matrix_shape(iree_hal_buffer_view_t* buffer_view, iree_hal_dim_t* dims) { iree_host_size_t shape_rank = iree_hal_buffer_view_shape_rank(buffer_view); if (shape_rank != 2) { return iree_make_status( IREE_STATUS_INVALID_ARGUMENT, "expected a matrix (2D tensor) shape, got a %" PRIhsz "-dimensional shape", shape_rank); } dims[0] = iree_hal_buffer_view_shape_dim(buffer_view, 0); dims[1] = iree_hal_buffer_view_shape_dim(buffer_view, 1); if (!(dims[0] > 0 && dims[1] > 0)) { return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, "expected matrix dims to be positive, got %" PRIdim "x%" PRIdim, dims[0], dims[1]); } return iree_ok_status(); } #define REFERENCE_MATMUL(LHSTYPE, RHSTYPE, RESTYPE, ACCTYPE) \ static void reference_matmul_##LHSTYPE##_##RHSTYPE##_##RESTYPE##_##ACCTYPE( \ iree_hal_dim_t m_size, iree_hal_dim_t k_size, iree_hal_dim_t n_size, \ iree_hal_element_type_t lhs_type, iree_hal_element_type_t rhs_type, \ iree_hal_element_type_t acc_type, bool transpose_rhs, \ const LHSTYPE* lhs_data, const RHSTYPE* rhs_data, \ const ACCTYPE* acc_data, RESTYPE* result_data, iree_hal_dim_t m, \ iree_hal_dim_t n) { \ ACCTYPE acc = acc_data ? acc_data[n + m * n_size] : 0; \ for (iree_hal_dim_t k = 0; k < k_size; ++k) { \ LHSTYPE lhs_value = lhs_data[k + m * k_size]; \ RHSTYPE rhs_value = \ transpose_rhs ? rhs_data[k + n * k_size] : rhs_data[n + k * n_size]; \ acc += (ACCTYPE)lhs_value * (ACCTYPE)rhs_value; \ } \ result_data[n + m * n_size] = acc; \ } // Reference mamtul instantiations from macro REFERENCE_MATMUL // for the f32 input, f32 accumlation, and f32 result. // [float <= float * float + float] REFERENCE_MATMUL(float, float, float, float) // Reference mamtul instantiations from macro REFERENCE_MATMUL // for the int8_t input, int32_t accumlation, and int32_t result. 
// [i32 <= i8 * i8 + i32] REFERENCE_MATMUL(int8_t, int8_t, int32_t, int32_t) // Reference mamtul instantiations from macro REFERENCE_MATMUL // for the int32_t input, int32_t accumlation, and int32_t result. // [i32 <= i32 * i32 + i32] REFERENCE_MATMUL(int32_t, int32_t, int32_t, int32_t) // Reference mamtul for the f16 input, f16 accumlation, and f16 result. // [f16 <= f16 * f16 + f16] static void reference_matmul_f16_f16_f16_f16( iree_hal_dim_t m_size, iree_hal_dim_t k_size, iree_hal_dim_t n_size, iree_hal_element_type_t lhs_type, iree_hal_element_type_t rhs_type, iree_hal_element_type_t acc_type, bool transpose_rhs, const uint16_t* lhs_data, const uint16_t* rhs_data, const uint16_t* acc_data, uint16_t* result_data, iree_hal_dim_t m, iree_hal_dim_t n) { float acc = acc_data ? iree_math_f16_to_f32(acc_data[n + m * n_size]) : 0.f; for (iree_hal_dim_t k = 0; k < k_size; ++k) { int64_t rhs_index = transpose_rhs ? k + n * k_size : n + k * n_size; acc += iree_math_f16_to_f32(lhs_data[k + m * k_size]) * iree_math_f16_to_f32(rhs_data[rhs_index]); } result_data[n + m * n_size] = iree_math_f32_to_f16(acc); } // Reference mamtul for the f16 input, f32 accumlation, and f32 result. // [f32 <= f16 * f16 + f32] static void reference_matmul_f16_f16_f32_f32( iree_hal_dim_t m_size, iree_hal_dim_t k_size, iree_hal_dim_t n_size, iree_hal_element_type_t lhs_type, iree_hal_element_type_t rhs_type, iree_hal_element_type_t acc_type, bool transpose_rhs, const uint16_t* lhs_data, const uint16_t* rhs_data, const float* acc_data, float* result_data, iree_hal_dim_t m, iree_hal_dim_t n) { float acc = acc_data ? acc_data[n + m * n_size] : 0.f; for (iree_hal_dim_t k = 0; k < k_size; ++k) { int64_t rhs_index = transpose_rhs ? k + n * k_size : n + k * n_size; acc += iree_math_f16_to_f32(lhs_data[k + m * k_size]) * iree_math_f16_to_f32(rhs_data[rhs_index]); } result_data[n + m * n_size] = acc; } // Reference mamtul for the bf16 input, bf16 accumlation, and bf16 result. 
// [bf16 <= bf16 * bf16 + bf16] static void reference_matmul_bf16_bf16_bf16_bf16( iree_hal_dim_t m_size, iree_hal_dim_t k_size, iree_hal_dim_t n_size, iree_hal_element_type_t lhs_type, iree_hal_element_type_t rhs_type, iree_hal_element_type_t acc_type, bool transpose_rhs, const uint16_t* lhs_data, const uint16_t* rhs_data, const uint16_t* acc_data, uint16_t* result_data, iree_hal_dim_t m, iree_hal_dim_t n) { float acc = acc_data ? iree_math_bf16_to_f32(acc_data[n + m * n_size]) : 0.f; for (iree_hal_dim_t k = 0; k < k_size; ++k) { int64_t rhs_index = transpose_rhs ? k + n * k_size : n + k * n_size; acc += iree_math_bf16_to_f32(lhs_data[k + m * k_size]) * iree_math_bf16_to_f32(rhs_data[rhs_index]); } result_data[n + m * n_size] = iree_math_f32_to_bf16(acc); } // Reference mamtul for the bf16 input, f32 accumlation, and f32 result. // [f32 <= bf16 * bf16 + f32] static void reference_matmul_bf16_bf16_f32_f32( iree_hal_dim_t m_size, iree_hal_dim_t k_size, iree_hal_dim_t n_size, iree_hal_element_type_t lhs_type, iree_hal_element_type_t rhs_type, iree_hal_element_type_t acc_type, bool transpose_rhs, const uint16_t* lhs_data, const uint16_t* rhs_data, const float* acc_data, float* result_data, iree_hal_dim_t m, iree_hal_dim_t n) { float acc = acc_data ? acc_data[n + m * n_size] : 0.f; for (iree_hal_dim_t k = 0; k < k_size; ++k) { int64_t rhs_index = transpose_rhs ? k + n * k_size : n + k * n_size; acc += iree_math_bf16_to_f32(lhs_data[k + m * k_size]) * iree_math_bf16_to_f32(rhs_data[rhs_index]); } result_data[n + m * n_size] = acc; } // Helper for reference_matmul. // Computes one element in the result matrix. 
// Selects the typed reference kernel from the (lhs, rhs, acc) element types.
// Each type-erased buffer is cast to that kernel's pointer types, and the
// kernel computes result element (m, n).
//
// Returns IREE_STATUS_INVALID_ARGUMENT for type combinations with no kernel.
static iree_status_t reference_matmul_element(
    iree_hal_dim_t m_size, iree_hal_dim_t k_size, iree_hal_dim_t n_size,
    iree_hal_element_type_t lhs_type, iree_hal_element_type_t rhs_type,
    iree_hal_element_type_t acc_type, bool transpose_rhs, void* lhs_data,
    void* rhs_data, void* acc_data, void* result_data, iree_hal_dim_t m,
    iree_hal_dim_t n) {
  // f32 * f32 + f32 -> f32.
  if (lhs_type == IREE_HAL_ELEMENT_TYPE_FLOAT_32 &&
      rhs_type == IREE_HAL_ELEMENT_TYPE_FLOAT_32 &&
      acc_type == IREE_HAL_ELEMENT_TYPE_FLOAT_32) {
    reference_matmul_float_float_float_float(
        m_size, k_size, n_size, lhs_type, rhs_type, acc_type, transpose_rhs,
        (const float*)lhs_data, (const float*)rhs_data,
        (const float*)acc_data, (float*)result_data, m, n);
    return iree_ok_status();
  }

  // 8-bit integer inputs, 32-bit integer accumulator/result.
  if (iree_hal_element_type_is_integer(lhs_type, 8) &&
      iree_hal_element_type_is_integer(rhs_type, 8) &&
      iree_hal_element_type_is_integer(acc_type, 32)) {
    reference_matmul_int8_t_int8_t_int32_t_int32_t(
        m_size, k_size, n_size, lhs_type, rhs_type, acc_type, transpose_rhs,
        (const int8_t*)lhs_data, (const int8_t*)rhs_data,
        (const int32_t*)acc_data, (int32_t*)result_data, m, n);
    return iree_ok_status();
  }

  // 32-bit integers throughout.
  if (iree_hal_element_type_is_integer(lhs_type, 32) &&
      iree_hal_element_type_is_integer(rhs_type, 32) &&
      iree_hal_element_type_is_integer(acc_type, 32)) {
    reference_matmul_int32_t_int32_t_int32_t_int32_t(
        m_size, k_size, n_size, lhs_type, rhs_type, acc_type, transpose_rhs,
        (const int32_t*)lhs_data, (const int32_t*)rhs_data,
        (const int32_t*)acc_data, (int32_t*)result_data, m, n);
    return iree_ok_status();
  }

  // f16 inputs with either an f16 or an f32 accumulator.
  const bool f16_inputs = lhs_type == IREE_HAL_ELEMENT_TYPE_FLOAT_16 &&
                          rhs_type == IREE_HAL_ELEMENT_TYPE_FLOAT_16;
  if (f16_inputs && acc_type == IREE_HAL_ELEMENT_TYPE_FLOAT_16) {
    reference_matmul_f16_f16_f16_f16(
        m_size, k_size, n_size, lhs_type, rhs_type, acc_type, transpose_rhs,
        (const uint16_t*)lhs_data, (const uint16_t*)rhs_data,
        (const uint16_t*)acc_data, (uint16_t*)result_data, m, n);
    return iree_ok_status();
  }
  if (f16_inputs && acc_type == IREE_HAL_ELEMENT_TYPE_FLOAT_32) {
    reference_matmul_f16_f16_f32_f32(
        m_size, k_size, n_size, lhs_type, rhs_type, acc_type, transpose_rhs,
        (const uint16_t*)lhs_data, (const uint16_t*)rhs_data,
        (const float*)acc_data, (float*)result_data, m, n);
    return iree_ok_status();
  }

  // bf16 inputs with either a bf16 or an f32 accumulator.
  const bool bf16_inputs = lhs_type == IREE_HAL_ELEMENT_TYPE_BFLOAT_16 &&
                           rhs_type == IREE_HAL_ELEMENT_TYPE_BFLOAT_16;
  if (bf16_inputs && acc_type == IREE_HAL_ELEMENT_TYPE_BFLOAT_16) {
    reference_matmul_bf16_bf16_bf16_bf16(
        m_size, k_size, n_size, lhs_type, rhs_type, acc_type, transpose_rhs,
        (const uint16_t*)lhs_data, (const uint16_t*)rhs_data,
        (const uint16_t*)acc_data, (uint16_t*)result_data, m, n);
    return iree_ok_status();
  }
  if (bf16_inputs && acc_type == IREE_HAL_ELEMENT_TYPE_FLOAT_32) {
    reference_matmul_bf16_bf16_f32_f32(
        m_size, k_size, n_size, lhs_type, rhs_type, acc_type, transpose_rhs,
        (const uint16_t*)lhs_data, (const uint16_t*)rhs_data,
        (const float*)acc_data, (float*)result_data, m, n);
    return iree_ok_status();
  }

  return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                          "unhandled combination of element types in matmul");
}

// Reference matmul implementation, used to compare matmul results against.
static iree_status_t reference_matmul( iree_hal_dim_t m_size, iree_hal_dim_t k_size, iree_hal_dim_t n_size, iree_hal_element_type_t lhs_type, iree_hal_element_type_t rhs_type, iree_hal_element_type_t acc_type, bool transpose_rhs, iree_byte_span_t lhs_contents, iree_byte_span_t rhs_contents, iree_byte_span_t acc_contents, iree_byte_span_t result_contents, int compute_every) { IREE_TRACE_ZONE_BEGIN(z0); IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, m_size); IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, k_size); IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, n_size); iree_host_size_t count = 0; for (iree_hal_dim_t m = 0; m < m_size; ++m) { for (iree_hal_dim_t n = 0; n < n_size; ++n) { if (++count < compute_every) continue; count = 0; IREE_RETURN_IF_ERROR(reference_matmul_element( m_size, k_size, n_size, lhs_type, rhs_type, acc_type, transpose_rhs, lhs_contents.data, rhs_contents.data, acc_contents.data, result_contents.data, m, n)); } } IREE_TRACE_ZONE_END(z0); return iree_ok_status(); } //===----------------------------------------------------------------------===// // Matmul comparison/logging //===----------------------------------------------------------------------===// typedef struct { iree_allocator_t host_allocator; iree_hal_dim_t m; iree_hal_dim_t k; iree_hal_dim_t n; iree_hal_element_type_t lhs_type; iree_hal_element_type_t rhs_type; iree_hal_element_type_t acc_type; iree_hal_element_type_t result_type; bool transpose_rhs; iree_byte_span_t lhs_contents; iree_byte_span_t rhs_contents; iree_byte_span_t acc_contents; iree_byte_span_t actual_contents; iree_byte_span_t expected_contents; } matmul_results_t; static void matmul_results_deinitialize(matmul_results_t* results); static iree_status_t matmul_results_initialize( iree_hal_device_t* device, iree_hal_dim_t m_size, iree_hal_dim_t k_size, iree_hal_dim_t n_size, uint32_t transpose_rhs, iree_hal_buffer_view_t* lhs, iree_hal_buffer_view_t* rhs, iree_hal_buffer_view_t* acc, iree_hal_buffer_view_t* result, iree_allocator_t host_allocator, 
matmul_results_t* out_results) { IREE_TRACE_ZONE_BEGIN(z0); memset(out_results, 0, sizeof(*out_results)); out_results->host_allocator = host_allocator; out_results->m = m_size; out_results->k = k_size; out_results->n = n_size; out_results->lhs_type = iree_hal_buffer_view_element_type(lhs); out_results->rhs_type = iree_hal_buffer_view_element_type(rhs); out_results->acc_type = iree_hal_buffer_view_element_type(result); out_results->result_type = iree_hal_buffer_view_element_type(result); out_results->transpose_rhs = transpose_rhs != 0; iree_hal_buffer_t* lhs_buffer = iree_hal_buffer_view_buffer(lhs); iree_hal_buffer_t* rhs_buffer = iree_hal_buffer_view_buffer(rhs); iree_hal_buffer_t* acc_buffer = acc ? iree_hal_buffer_view_buffer(acc) : NULL; iree_hal_buffer_t* result_buffer = iree_hal_buffer_view_buffer(result); iree_status_t status = iree_ok_status(); if (iree_status_is_ok(status)) { out_results->lhs_contents.data_length = iree_hal_buffer_byte_length(lhs_buffer); status = iree_allocator_malloc(host_allocator, out_results->lhs_contents.data_length, (void**)&out_results->lhs_contents.data); } if (iree_status_is_ok(status)) { status = iree_hal_device_transfer_d2h( device, lhs_buffer, 0, out_results->lhs_contents.data, out_results->lhs_contents.data_length, IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT, iree_infinite_timeout()); } if (iree_status_is_ok(status)) { out_results->rhs_contents.data_length = iree_hal_buffer_byte_length(rhs_buffer); status = iree_allocator_malloc(host_allocator, out_results->rhs_contents.data_length, (void**)&out_results->rhs_contents.data); } if (iree_status_is_ok(status)) { status = iree_hal_device_transfer_d2h( device, rhs_buffer, 0, out_results->rhs_contents.data, out_results->rhs_contents.data_length, IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT, iree_infinite_timeout()); } if (acc_buffer) { if (iree_status_is_ok(status)) { out_results->acc_contents.data_length = iree_hal_buffer_byte_length(acc_buffer); status = iree_allocator_malloc(host_allocator, 
out_results->acc_contents.data_length, (void**)&out_results->acc_contents.data); } if (iree_status_is_ok(status)) { status = iree_hal_device_transfer_d2h( device, acc_buffer, 0, out_results->acc_contents.data, out_results->acc_contents.data_length, IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT, iree_infinite_timeout()); } } if (iree_status_is_ok(status)) { out_results->actual_contents.data_length = iree_hal_buffer_byte_length(result_buffer); status = iree_allocator_malloc(host_allocator, out_results->actual_contents.data_length, (void**)&out_results->actual_contents.data); } if (iree_status_is_ok(status)) { status = iree_hal_device_transfer_d2h( device, result_buffer, 0, out_results->actual_contents.data, out_results->actual_contents.data_length, IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT, iree_infinite_timeout()); } if (iree_status_is_ok(status)) { out_results->expected_contents.data_length = iree_hal_buffer_byte_length(result_buffer); status = iree_allocator_malloc( host_allocator, out_results->expected_contents.data_length, (void**)&out_results->expected_contents.data); } if (!iree_status_is_ok(status)) { matmul_results_deinitialize(out_results); } IREE_TRACE_ZONE_END(z0); return status; } static void matmul_results_deinitialize(matmul_results_t* results) { IREE_TRACE_ZONE_BEGIN(z0); iree_allocator_free(results->host_allocator, results->lhs_contents.data); iree_allocator_free(results->host_allocator, results->rhs_contents.data); if (!iree_byte_span_is_empty(results->acc_contents)) { iree_allocator_free(results->host_allocator, results->acc_contents.data); } iree_allocator_free(results->host_allocator, results->actual_contents.data); iree_allocator_free(results->host_allocator, results->expected_contents.data); IREE_TRACE_ZONE_END(z0); } // Enum controlling how many decimals to print floats with. typedef enum precision_e { PRECISION_LOW, PRECISION_HIGH, } precision_t; // Prints a iree_e2e_test_value_t to a string buffer. Returns the number of // characters written. 
Like snprintf. static int snprintf_value(char* buf, size_t bufsize, iree_e2e_test_value_t value, precision_t precision) { switch (value.type) { case IREE_E2E_TEST_VALUE_TYPE_I8: return snprintf(buf, bufsize, "%" PRIi8, value.i8); case IREE_E2E_TEST_VALUE_TYPE_I16: return snprintf(buf, bufsize, "%" PRIi16, value.i16); case IREE_E2E_TEST_VALUE_TYPE_I32: return snprintf(buf, bufsize, "%" PRIi32, value.i32); case IREE_E2E_TEST_VALUE_TYPE_I64: return snprintf(buf, bufsize, "%" PRIi64, value.i64); case IREE_E2E_TEST_VALUE_TYPE_F16: return snprintf(buf, bufsize, precision == PRECISION_HIGH ? "%.5g" : "%.4g", iree_math_f16_to_f32(value.f16_u16)); case IREE_E2E_TEST_VALUE_TYPE_BF16: return snprintf(buf, bufsize, precision == PRECISION_HIGH ? "%.5g" : "%.4g", iree_math_bf16_to_f32(value.bf16_u16)); case IREE_E2E_TEST_VALUE_TYPE_F32: return snprintf(buf, bufsize, precision == PRECISION_HIGH ? "%.8g" : "%.4g", value.f32); case IREE_E2E_TEST_VALUE_TYPE_F64: return snprintf(buf, bufsize, precision == PRECISION_HIGH ? "%.16g" : "%.4g", value.f64); default: iree_status_abort(iree_make_status(IREE_STATUS_INVALID_ARGUMENT, "unhandled value type")); return 0; } } // Returns true if |expected| and |actual| agree to tolerable accuracy. static bool matmul_result_elements_agree(iree_e2e_test_value_t expected, iree_e2e_test_value_t actual) { if (expected.type != actual.type) { iree_status_abort( iree_make_status(IREE_STATUS_INVALID_ARGUMENT, "mismatched types")); return false; } switch (expected.type) { case IREE_E2E_TEST_VALUE_TYPE_I32: return actual.i32 == expected.i32; // Since we fill buffers with small integers for floating point GEMMs // functional testing, we can test for bit-exactness on the actual and // expected values. Inexact results are only permitted when the // `require_exact_results` flag is set to `false`. 
case IREE_E2E_TEST_VALUE_TYPE_F16: if (actual.f16_u16 == expected.f16_u16) return true; if (FLAG_require_exact_results) return false; return fabsf(iree_math_f16_to_f32(actual.f16_u16) - iree_math_f16_to_f32(expected.f16_u16)) < FLAG_acceptable_fp_delta; case IREE_E2E_TEST_VALUE_TYPE_BF16: if (actual.bf16_u16 == expected.bf16_u16) return true; if (FLAG_require_exact_results) return false; return fabsf(iree_math_bf16_to_f32(actual.bf16_u16) - iree_math_bf16_to_f32(expected.bf16_u16)) < FLAG_acceptable_fp_delta; case IREE_E2E_TEST_VALUE_TYPE_F32: if (actual.f32 == expected.f32) return true; if (FLAG_require_exact_results) return false; return fabsf(actual.f32 - expected.f32) < FLAG_acceptable_fp_delta; default: iree_status_abort(iree_make_status(IREE_STATUS_INVALID_ARGUMENT, "unhandled value type")); return false; } } // Returns the largest number of characters to print any matrix element. static int get_max_elem_width(precision_t precision, iree_hal_dim_t rows, iree_hal_dim_t row_start, iree_hal_dim_t row_end, iree_hal_dim_t cols, iree_hal_dim_t col_start, iree_hal_dim_t col_end, iree_hal_element_type_t element_type, const uint8_t* matrix) { int max_elem_width = 0; for (int row = row_start; row < row_end; row++) { for (int col = col_start; col < col_end; col++) { iree_e2e_test_value_t elem = read_matrix_element(rows, cols, element_type, matrix, row, col); // NOTE: iree_max is a macro and may evaluate its args twice. char buf[64]; int this_elem_width = snprintf_value(buf, sizeof(buf), elem, precision); max_elem_width = iree_max(max_elem_width, this_elem_width); } } return max_elem_width; } // Prints |matrix| to |file|, with |label| as caption. // |precision| controls how many decimals are printed for float values. // // If |other_matrix| is not NULL, then any matrix entries that disagree // between |matrix| and |other_matrix| (according to // matmul_result_elements_agree) are highlighted. 
// // |highlight| is either NULL or is a UTF-8 string that will be printed next to // any entry of |matrix| that disagrees with the corresponding entry of // |other_matrix|. // // |highlight| should be NULL if and only if |other_matrix| is NULL. // // In order for matrix columns to be properly laid out, the rendering of // |highlight| in a fixed-width font should have the width of two regular Latin // characters. According to // https://www.unicode.org/reports/tr11/#Recommendations, a single emoji // character should meet that requirement. static void print_matrix(FILE* file, const char* label, precision_t precision, iree_hal_dim_t rows, iree_hal_dim_t row_start, iree_hal_dim_t row_end, iree_hal_dim_t cols, iree_hal_dim_t col_start, iree_hal_dim_t col_end, iree_hal_element_type_t element_type, const uint8_t* matrix, const uint8_t* other_matrix, const char* highlight) { IREE_ASSERT((other_matrix == NULL) == (highlight == NULL)); int max_elem_width = get_max_elem_width(precision, rows, row_start, row_end, cols, col_start, col_end, element_type, matrix); if (other_matrix) { // NOTE: iree_max is a macro and may evaluate its args twice. 
int other_matrix_max_elem_width = get_max_elem_width(precision, rows, row_start, row_end, cols, col_start, col_end, element_type, other_matrix); max_elem_width = iree_max(max_elem_width, other_matrix_max_elem_width); } fprintf(file, "%s (rows %" PRIdsz "..%" PRIdsz " out of 0..%" PRIdsz ", columns %" PRIdsz "..%" PRIdsz " out of 0..%" PRIdsz ")\n", label, row_start, row_end - 1, rows - 1, col_start, col_end - 1, cols - 1); for (int row = row_start; row < row_end; row++) { for (int col = col_start; col < col_end; col++) { iree_e2e_test_value_t element = read_matrix_element(rows, cols, element_type, matrix, row, col); bool disagree = false; if (other_matrix) { iree_e2e_test_value_t other_element = read_matrix_element( rows, cols, element_type, other_matrix, row, col); disagree = !matmul_result_elements_agree(element, other_element); } char buf[64]; snprintf_value(buf, sizeof(buf), element, precision); fprintf(file, "%*s", max_elem_width, buf); // See comment on |highlight| function parameter for why 2 spaces. // A 3rd space is added unconditionally to make it clear that a highlight // concerns the matrix entry to its left. fprintf(file, "%s ", disagree ? highlight : " "); } fprintf(file, "\n"); } } // Helper for check_matmul_results: handler for the failure case. // If |file| is not NULL, detailed logging is written to it. static iree_status_t check_matmul_failure(FILE* file, const matmul_results_t* results, iree_e2e_test_value_t actual_value, iree_e2e_test_value_t expected_value, iree_hal_dim_t row, iree_hal_dim_t col, int check_every) { if (!file || check_every > 1) { // No logging of errors with check_every>1 as most of the reference matrix // elements have not been computed. The caller is expected to retry with // check_every=1. 
return iree_make_status(IREE_STATUS_ABORTED); } IREE_TRACE_ZONE_BEGIN(z0); fprintf(file, "\n\nerror: the actual and expected result matrices disagree " "at row %" PRIdim ", column %" PRIdim ".\n\n", row, col); char actual_value_buf[32]; char expected_value_buf[32]; snprintf_value(actual_value_buf, sizeof(actual_value_buf), actual_value, PRECISION_HIGH); snprintf_value(expected_value_buf, sizeof(expected_value_buf), expected_value, PRECISION_HIGH); fprintf(file, "actual value: %s\n", actual_value_buf); fprintf(file, "expected value: %s\n", expected_value_buf); iree_hal_dim_t context = 8; const char* context_env = getenv("IREE_MATMUL_TEST_SHOW_CONTEXT"); if (context_env) { if (1 != sscanf(context_env, "%" PRIdim, &context)) { return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, "failed to parse IREE_MATMUL_TEST_SHOW_CONTEXT " "as \"%%" PRIdim "\"; got \"%s\"", context_env); } } iree_hal_dim_t m_start = (iree_hal_dim_t)iree_max(0, (int64_t)row - (int64_t)context); iree_hal_dim_t m_end = iree_min(results->m, row + context); iree_hal_dim_t n_start = (iree_hal_dim_t)iree_max(0, (int64_t)col - (int64_t)context); iree_hal_dim_t n_end = iree_min(results->n, col + context); iree_hal_dim_t k_start = 0; iree_hal_dim_t k_end = iree_min(results->k, 2 * context); // [k_start, k_end) could be arbitrarily long at this point. Constrain it a // bit to avoid huge output. 
k_end = iree_min(k_end, k_start + 4 * context); fprintf(file, "\n"); print_matrix(file, "left-hand side", PRECISION_LOW, results->m, m_start, m_end, results->k, k_start, k_end, results->lhs_type, results->lhs_contents.data, NULL, NULL); fprintf(file, "\n"); print_matrix(file, "right-hand side", PRECISION_LOW, results->k, k_start, k_end, results->n, n_start, n_end, results->rhs_type, results->rhs_contents.data, NULL, NULL); fprintf(file, "\n"); if (results->acc_contents.data) { print_matrix(file, "input accumulator", PRECISION_LOW, results->m, m_start, m_end, results->n, n_start, n_end, results->acc_type, results->acc_contents.data, NULL, NULL); fprintf(file, "\n"); } print_matrix(file, "expected result", PRECISION_LOW, results->m, m_start, m_end, results->n, n_start, n_end, results->result_type, results->expected_contents.data, results->actual_contents.data, emoji(true)); fprintf(file, "\n"); print_matrix(file, "actual result", PRECISION_LOW, results->m, m_start, m_end, results->n, n_start, n_end, results->result_type, results->actual_contents.data, results->expected_contents.data, emoji(false)); fprintf(file, "\n"); IREE_TRACE_ZONE_END(z0); return iree_make_status(IREE_STATUS_ABORTED); } // Helper for check_matmul_results: the actual interesting part once we've // obtained and validated the {m,k,n}_size values. On error, detailed logging is // written to |file| if it is not NULL. 
static iree_status_t check_matmul_results_impl(FILE* file, const matmul_results_t* results, int check_every) { IREE_TRACE_ZONE_BEGIN(z0); IREE_RETURN_AND_END_ZONE_IF_ERROR( z0, reference_matmul( results->m, results->k, results->n, results->lhs_type, results->rhs_type, results->acc_type, results->transpose_rhs, results->lhs_contents, results->rhs_contents, results->acc_contents, results->expected_contents, check_every)); int count = 0; for (iree_hal_dim_t m = 0; m < results->m; ++m) { for (iree_hal_dim_t n = 0; n < results->n; ++n) { if (++count < check_every) continue; count = 0; iree_e2e_test_value_t actual_value = read_matrix_element(results->m, results->n, results->result_type, results->actual_contents.data, m, n); iree_e2e_test_value_t expected_value = read_matrix_element(results->m, results->n, results->result_type, results->expected_contents.data, m, n); if (!matmul_result_elements_agree(actual_value, expected_value)) { iree_status_t status = check_matmul_failure( file, results, actual_value, expected_value, m, n, check_every); IREE_TRACE_ZONE_END(z0); return status; } } } IREE_TRACE_ZONE_END(z0); return iree_ok_status(); } // Given an actual matmul's inputs and output (all host-local), uses a reference // matmul implementation on the same inputs to check if the output is correct. // On error, detailed logging is written to |file| if it is not NULL. static iree_status_t check_matmul_results(FILE* file, const matmul_results_t* results) { IREE_TRACE_ZONE_BEGIN(z0); int check_every = calculate_check_every(results->m, results->n); iree_status_t status = check_matmul_results_impl(file, results, check_every); if (!iree_status_is_ok(status) && check_every > 1) { // If we got a failure with check_every>1, that didn't log a useful // numerical summary, as most of the reference matrix entries hadn't been // computed. Rerun now with check_every=1 to get that numerical logging. 
iree_status_ignore(status); status = check_matmul_results_impl(file, results, 1); } IREE_TRACE_ZONE_END(z0); return status; } //===----------------------------------------------------------------------===// // RNG utilities //===----------------------------------------------------------------------===// // Parameter for locally defined lcg similar to std::minstd_rand. #define IREE_PRNG_MULTIPLIER 48271 #define IREE_PRNG_MODULUS 2147483647 // Writes an element of the given |element_type| with the given integral |value| // to |dst|. static void write_element(iree_hal_element_type_t element_type, int32_t value, void* dst) { #define WRITE_ELEMENT_CASE(ETYPE, CTYPE) \ case IREE_HAL_ELEMENT_TYPE_##ETYPE: \ *(CTYPE*)dst = (CTYPE)value; \ break; switch (element_type) { WRITE_ELEMENT_CASE(INT_8, int8_t) WRITE_ELEMENT_CASE(INT_16, int16_t) WRITE_ELEMENT_CASE(INT_32, int32_t) WRITE_ELEMENT_CASE(INT_64, int64_t) WRITE_ELEMENT_CASE(SINT_8, int8_t) WRITE_ELEMENT_CASE(SINT_16, int16_t) WRITE_ELEMENT_CASE(SINT_32, int32_t) WRITE_ELEMENT_CASE(SINT_64, int64_t) WRITE_ELEMENT_CASE(UINT_8, uint8_t) WRITE_ELEMENT_CASE(UINT_16, uint16_t) WRITE_ELEMENT_CASE(UINT_32, uint32_t) WRITE_ELEMENT_CASE(UINT_64, uint64_t) // clang-format off case IREE_HAL_ELEMENT_TYPE_FLOAT_16: *(uint16_t*)dst = iree_math_f32_to_f16((float)value); break; case IREE_HAL_ELEMENT_TYPE_BFLOAT_16: *(uint16_t*)dst = iree_math_f32_to_bf16((float)value); break; WRITE_ELEMENT_CASE(FLOAT_32, float) WRITE_ELEMENT_CASE(FLOAT_64, double) // clang-format on default: IREE_ASSERT(false, "unhandled element type"); break; } #undef WRITE_ELEMENT_CASE } // Simple deterministic pseudorandom generator. // This function is same as C++'s std::minstd_rand. static uint32_t pseudorandom_uint32(uint32_t* state) { *state = (*state * IREE_PRNG_MULTIPLIER) % IREE_PRNG_MODULUS; return *state; } // Returns a random uint32_t in the range [0, range). 
static inline uint32_t pseudorandom_range(uint32_t* state, uint32_t range) { return pseudorandom_uint32(state) % range; } // Get minimum and maximum for integer-valued uniform distribution. static void get_min_max_for_element_type(iree_hal_element_type_t element_type, int32_t* min, int32_t* max) { switch (element_type) { case IREE_HAL_ELEMENT_TYPE_INT_8: case IREE_HAL_ELEMENT_TYPE_SINT_8: *min = -2; *max = +2; break; case IREE_HAL_ELEMENT_TYPE_UINT_8: *min = 0; *max = +2; break; case IREE_HAL_ELEMENT_TYPE_INT_16: case IREE_HAL_ELEMENT_TYPE_SINT_16: case IREE_HAL_ELEMENT_TYPE_FLOAT_16: *min = -4; *max = +4; break; case IREE_HAL_ELEMENT_TYPE_BFLOAT_16: *min = -2; *max = +2; break; case IREE_HAL_ELEMENT_TYPE_UINT_16: *min = 0; *max = +4; break; case IREE_HAL_ELEMENT_TYPE_INT_32: case IREE_HAL_ELEMENT_TYPE_SINT_32: case IREE_HAL_ELEMENT_TYPE_FLOAT_32: *min = -8; *max = +8; break; case IREE_HAL_ELEMENT_TYPE_UINT_32: *min = 0; *max = +8; break; case IREE_HAL_ELEMENT_TYPE_INT_64: case IREE_HAL_ELEMENT_TYPE_SINT_64: case IREE_HAL_ELEMENT_TYPE_FLOAT_64: *min = -16; *min = +16; break; case IREE_HAL_ELEMENT_TYPE_UINT_64: *min = 0; *max = +16; break; default: IREE_ASSERT(false, "unhandled element type"); break; } } //===----------------------------------------------------------------------===// // `matmul_test` custom module //===----------------------------------------------------------------------===// // This uses the C++ wrapper to keep things simple. Though easier to use it's // got additional overhead/code-size bloat that doesn't matter in a test like // this. Making a C module builder API that removes the boilerplate there is TBD // so this file is written in C besides this module so that we can swap it back // to being pure C in the future. 
// Everything in this anonymous namespace is local to this translation unit.
namespace {

using namespace iree;

// Per-context state of the `matmul_test` module. It only holds the host
// allocator it was created with, which CheckMatmulResults forwards to
// matmul_results_initialize.
class MatmulTestModuleState final {
 public:
  explicit MatmulTestModuleState(iree_allocator_t host_allocator)
      : host_allocator_(host_allocator) {}
  ~MatmulTestModuleState() = default;

  // Returns a new |dim0| x |dim1| dense row-major buffer view, allocated from
  // |device|'s allocator, filled with pseudorandom values of the given
  // |element_type|. Every value is an integer in the [min, max] range given by
  // get_min_max_for_element_type. The given |seed| is the initial state of the
  // pseudorandom generator, so the values are reproducible both across runs
  // and across machines.
  // NOTE(review): a |seed| of 0 keeps the generator state at 0, so every
  // element would equal that range's minimum.
  StatusOr> GenerateRandomMatrix(
      const vm::ref device, int64_t dim0, int64_t dim1,
      iree_hal_element_type_t element_type, int32_t seed) {
    iree_hal_dim_t dims[2] = {
        (iree_hal_dim_t)dim0,
        (iree_hal_dim_t)dim1,
    };
    iree_hal_buffer_params_t buffer_params = {0};
    buffer_params.usage = IREE_HAL_BUFFER_USAGE_DEFAULT;
    buffer_params.access = IREE_HAL_MEMORY_ACCESS_ALL;
    buffer_params.type = IREE_HAL_MEMORY_TYPE_OPTIMAL_FOR_DEVICE;
    vm::ref result_view;
    // Handed to the fill callback below through its |user_data| pointer.
    struct callback_state_t {
      iree_hal_element_type_t element_type;
      int32_t seed;
    } callback_state = {
        element_type,
        seed,
    };
    IREE_RETURN_IF_ERROR(iree_hal_buffer_view_generate_buffer(
        device.get(), iree_hal_device_allocator(device.get()),
        IREE_ARRAYSIZE(dims), dims, element_type,
        IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, buffer_params,
        // The unary + converts this capture-less lambda into a plain function
        // pointer so it can be passed as the C fill callback.
        +[](iree_hal_buffer_mapping_t* mapping, void* user_data) {
          callback_state_t callback_state = *(callback_state_t*)user_data;
          iree_byte_span_t span = mapping->contents;
          // Generate "uniform" integer-valued numbers in the range [min, max].
          int32_t min = 0;
          int32_t max = 0;
          get_min_max_for_element_type(callback_state.element_type, &min,
                                       &max);
          uint32_t range = (max - min + 1);
          iree_host_size_t element_byte_count =
              iree_hal_element_dense_byte_count(callback_state.element_type);
          uint8_t* data_end = span.data + span.data_length;
          uint32_t state = callback_state.seed;
          // Write one element every |element_byte_count| bytes of the mapping.
          for (uint8_t* data = span.data; data < data_end;
               data += element_byte_count) {
            int32_t value = (int32_t)pseudorandom_range(&state, range) + min;
            write_element(callback_state.element_type, value, data);
          }
          return iree_ok_status();
        },
        &callback_state, &result_view));
    return std::move(result_view);
  }

  // Gathers the matmul inputs and |actual_result| via
  // matmul_results_initialize, then verifies them with check_matmul_results.
  // That function writes details of any mismatch to stderr. Returns the check
  // status; the gathered results are deinitialized either way.
  Status CheckMatmulResults(
      const vm::ref device, int64_t m, int64_t k, int64_t n,
      int32_t transpose_rhs, const vm::ref lhs,
      const vm::ref rhs, const vm::ref acc,
      const vm::ref actual_result) {
    matmul_results_t results = {};
    IREE_RETURN_IF_ERROR(matmul_results_initialize(
        device.get(), (iree_hal_dim_t)m, (iree_hal_dim_t)k, (iree_hal_dim_t)n,
        transpose_rhs, lhs.get(), rhs.get(), acc.get(), actual_result.get(),
        host_allocator_, &results));
    iree_status_t status = check_matmul_results(stderr, &results);
    matmul_results_deinitialize(&results);
    return status;
  }

 private:
  iree_allocator_t host_allocator_;
};

// Functions exported by the module, keyed by their VM-visible names.
static const vm::NativeFunction kMatmulTestModuleFunctions[] = {
    vm::MakeNativeFunction("generate_random_matrix",
                           &MatmulTestModuleState::GenerateRandomMatrix),
    vm::MakeNativeFunction("check_matmul_results",
                           &MatmulTestModuleState::CheckMatmulResults),
};

// Native module whose CreateState returns a fresh MatmulTestModuleState.
struct MatmulTestModule final : public vm::NativeModule {
  using vm::NativeModule::NativeModule;
  StatusOr> CreateState(
      iree_allocator_t host_allocator) override {
    return std::make_unique(host_allocator);
  }
};

}  // namespace

// Creates the `matmul_test` native module (version 0) exporting
// kMatmulTestModuleFunctions and returns it in |out_module|. Ownership of the
// module is released to the caller, who is responsible for releasing it.
static iree_status_t matmul_test_module_create(iree_vm_instance_t* instance,
                                               iree_allocator_t host_allocator,
                                               iree_vm_module_t** out_module) {
  IREE_ASSERT_ARGUMENT(out_module);
  *out_module = NULL;
  auto module = std::make_unique(
      "matmul_test", /*version=*/0, instance, host_allocator,
      iree::span>(
          kMatmulTestModuleFunctions));
  *out_module = module.release()->interface();
  return iree_ok_status();
}

//===----------------------------------------------------------------------===//
// Test runner
//===----------------------------------------------------------------------===//

// Sets |out_is_valid| to whether |function| is a supported callable test
// function: it must not be an internal compiler/runtime function (__ prefixed)
// and must take/return no args/results. Only publicly exported functions are
// expected to be passed in. Returns an error only if the function's signature
// cannot be parsed.
static iree_status_t check_test_function(iree_vm_function_t function,
                                         bool* out_is_valid) {
  *out_is_valid = true;

  iree_string_view_t function_name = iree_vm_function_name(&function);
  if (iree_string_view_starts_with(function_name,
                                   iree_make_cstring_view("__"))) {
    // Internal compiler/runtime support function.
    *out_is_valid = false;
  }

  iree_vm_function_signature_t function_signature =
      iree_vm_function_signature(&function);
  iree_host_size_t argument_count = 0;
  iree_host_size_t result_count = 0;
  IREE_RETURN_IF_ERROR(iree_vm_function_call_count_arguments_and_results(
      &function_signature, &argument_count, &result_count));
  if (argument_count || result_count) {
    // Takes args or has results we don't expect.
    *out_is_valid = false;
  }

  return iree_ok_status();
}

// Synchronously runs a test |function|, printing its name (and its
// "description" attribute, if present) to stderr first.
// If the test fails then the failure status is returned to the caller.
static iree_status_t run_test_function(iree_vm_context_t* context,
                                       iree_vm_function_t function,
                                       iree_allocator_t host_allocator) {
  IREE_TRACE_ZONE_BEGIN(z0);

  // Announce the test by name, followed by its optional description.
  iree_string_view_t name = iree_vm_function_name(&function);
  IREE_TRACE_ZONE_APPEND_TEXT(z0, name.data, name.size);
  fprintf(stderr, "--- TEST[%.*s] ---\n", (int)name.size, name.data);
  iree_string_view_t description =
      iree_vm_function_lookup_attr_by_name(&function, IREE_SV("description"));
  if (!iree_string_view_is_empty(description)) {
    fprintf(stderr, "%.*s\n", (int)description.size, description.data);
  }

  // The function is invoked with no inputs and no outputs, so the invocation
  // status alone is the test verdict.
  iree_status_t invoke_status =
      iree_vm_invoke(context, function, IREE_VM_INVOCATION_FLAG_NONE,
                     /*policy=*/NULL, /*inputs=*/NULL, /*outputs=*/NULL,
                     host_allocator);
  IREE_TRACE_ZONE_END(z0);
  return invoke_status;
}

// Runs all test functions in |test_module|, in export-ordinal order, stopping
// at the first failure.
static iree_status_t run_all_test_functions(iree_vm_context_t* context,
                                            iree_vm_module_t* test_module,
                                            iree_allocator_t host_allocator) {
  IREE_TRACE_ZONE_BEGIN(z0);

  const iree_vm_module_signature_t signature =
      iree_vm_module_signature(test_module);
  for (iree_host_size_t ordinal = 0; ordinal < signature.export_function_count;
       ++ordinal) {
    iree_vm_function_t function;
    IREE_RETURN_AND_END_ZONE_IF_ERROR(
        z0, iree_vm_module_lookup_function_by_ordinal(
                test_module, IREE_VM_FUNCTION_LINKAGE_EXPORT, ordinal,
                &function));

    bool is_runnable = false;
    IREE_RETURN_AND_END_ZONE_IF_ERROR(
        z0, check_test_function(function, &is_runnable));
    if (!is_runnable) {
      // Internal helper or has args/results: not a test, skip it.
      continue;
    }

    IREE_RETURN_AND_END_ZONE_IF_ERROR(
        z0, run_test_function(context, function, host_allocator));
  }

  IREE_TRACE_ZONE_END(z0);
  return iree_ok_status();
}

// Returns OK if every CPU feature listed in the comma-separated
// "target_features" attribute of |module| is supported (trivially OK when the
// attribute is absent or empty). Otherwise returns UNAVAILABLE naming the first
// unsupported feature, indicating that the module should not be run. Errors
// from the feature lookup itself are propagated.
static iree_status_t check_module_requirements(iree_vm_module_t* module) { iree_string_view_t target_features = iree_vm_module_lookup_attr_by_name(module, IREE_SV("target_features")); while (!iree_string_view_is_empty(target_features)) { iree_string_view_t required_feature; iree_string_view_split(target_features, ',', &required_feature, &target_features); if (iree_string_view_is_empty(required_feature)) continue; int64_t feature_is_supported = 0; IREE_RETURN_IF_ERROR( iree_cpu_lookup_data_by_key(required_feature, &feature_is_supported)); if (!feature_is_supported) { return iree_make_status( // The error status matters. We distinguish "feature not supported" // which is a normal thing to happen from actual errors. IREE_STATUS_UNAVAILABLE, "target device does not have the required feature '%.*s'", (int)required_feature.size, required_feature.data); } } return iree_ok_status(); } static iree_status_t load_and_run_e2e_tests(iree_allocator_t host_allocator) { IREE_TRACE_ZONE_BEGIN(z0); iree_cpu_initialize(host_allocator); iree_vm_instance_t* instance = NULL; IREE_RETURN_AND_END_ZONE_IF_ERROR( z0, iree_tooling_create_instance(host_allocator, &instance)); iree_tooling_module_list_t module_list; iree_tooling_module_list_initialize(&module_list); // Create the test module providing helper functions used by test programs. iree_vm_module_t* matmul_test_module = NULL; iree_status_t status = matmul_test_module_create(instance, host_allocator, &matmul_test_module); if (iree_status_is_ok(status)) { status = iree_tooling_module_list_push_back(&module_list, matmul_test_module); } iree_vm_module_release(matmul_test_module); // Load all modules specified by --module= flags. if (iree_status_is_ok(status)) { status = iree_tooling_load_modules_from_flags(instance, host_allocator, &module_list); } iree_vm_module_t* test_module = iree_tooling_module_list_back(&module_list); // Create the context with our support module and all --module= flags. 
iree_vm_context_t* context = NULL; iree_hal_device_t* device = NULL; if (iree_status_is_ok(status)) { status = iree_tooling_create_context_from_flags( instance, module_list.count, module_list.values, /*default_device_uri=*/iree_string_view_empty(), host_allocator, &context, &device, /*out_device_allocator=*/NULL); } // Ensure the test module is possible to run. if (iree_status_is_ok(status)) { status = check_module_requirements(test_module); } iree_tooling_module_list_reset(&module_list); // Begin profiling (if enabled). if (iree_status_is_ok(status)) { status = iree_hal_begin_profiling_from_flags(device); } // Run all of the tests in the test module. if (iree_status_is_ok(status)) { status = run_all_test_functions(context, test_module, host_allocator); } // End profiling (if enabled). if (iree_status_is_ok(status)) { status = iree_hal_end_profiling_from_flags(device); } iree_hal_device_release(device); iree_vm_context_release(context); iree_vm_instance_release(instance); IREE_TRACE_ZONE_END(z0); return status; } int main(int argc, char** argv) { IREE_TRACE_APP_ENTER(); iree_flags_parse_checked(IREE_FLAGS_PARSE_MODE_DEFAULT, &argc, &argv); if (argc != 1) { fprintf(stderr, "use --module= flags to specify the modules to run\n"); IREE_TRACE_APP_EXIT(EXIT_FAILURE); return EXIT_FAILURE; } iree_status_t status = load_and_run_e2e_tests(iree_allocator_system()); int exit_code = EXIT_SUCCESS; if (!iree_status_is_ok(status)) { iree_status_fprint(stderr, status); bool is_unavailable = iree_status_is_unavailable(status); iree_status_free(status); exit_code = is_unavailable ? EXIT_SUCCESS : EXIT_FAILURE; } IREE_TRACE_APP_EXIT(exit_code); return exit_code; }