// Copyright 2015 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "eight_bit_int_gemm.h"

#include <cassert>
#include <cstdint>
#include <cstdlib>
#include <memory>
#include <tuple>

// gemmlowp symbols should have hidden visibility.
// currently this is ensured in the build system by
// passing -fvisibility-inlines-hidden. TODO: it would be
// safer to hardcode it here with some #pragma's.
#include "../public/gemmlowp.h"

// Define GEMMLOWP_USE_META_FASTPATH in order to use the fastpath ARM/NEON
// code. This code path consists of a number of meta-programmed, automatically
// generated GEMM kernels that are suitable for some sizes of input matrices.
// Due to the fact that the generated code relies heavily on loop unrolling,
// inlining and currying of runtime parameters the size of the generated binary
// is quite significant (approx. 200kb) which might be prohibitive in
// low-memory situations.

#if defined(GEMMLOWP_USE_META_FASTPATH) && defined(GEMMLOWP_NEON)
#include "../meta/legacy_multi_thread_gemm.h"
#else

#if defined(GEMMLOWP_USE_META_FASTPATH)
#warning "META fast path turned on without NEON!"
#endif

#endif

namespace gemmlowp {
namespace eight_bit_int_gemm {
namespace {

// To be used as template parameter for GlobalLock.
// GlobalLock<EightBitIntGemmLockId> is the global lock
// on EightBitIntGemm entry points, protecting
// EightBitIntGemm's global state.
struct EightBitIntGemmLockId;

// Global state: consists of one global GemmContext instance.
GemmContext* global_context; GemmContext* GetOrCreateGlobalContext() { if (!global_context) { global_context = new GemmContext; } return global_context; } void DestroyGlobalContext() { delete global_context; global_context = nullptr; } template void EightBitIntGemmImpl(GemmContext* context, int m, int n, int k, const std::uint8_t* a, std::int32_t a_offset, int lda, const std::uint8_t* b, std::int32_t b_offset, int ldb, std::uint8_t* c, std::int32_t c_offset, std::int32_t c_mult_int, std::int32_t c_shift, int ldc, BitDepthSetting bit_depth) { const int lhs_offset = a_offset; const int rhs_offset = b_offset; const int result_offset = c_offset; const int result_mult_int = c_mult_int; const int result_shift = c_shift; static const MapOrder ResultOrder = transpose_c ? MapOrder::RowMajor : MapOrder::ColMajor; static const MapOrder LhsOrder = transpose_a ? MapOrder::RowMajor : MapOrder::ColMajor; static const MapOrder RhsOrder = transpose_b ? MapOrder::RowMajor : MapOrder::ColMajor; MatrixMap lhs(a, m, k, lda); MatrixMap rhs(b, k, n, ldb); MatrixMap result(c, m, n, ldc); switch (bit_depth) { #define GEMMLOWP_HANDLE_BIT_DEPTH(BIT_DEPTH_SETTING, BIT_DEPTH_PARAMS) \ case BitDepthSetting::BIT_DEPTH_SETTING: \ Gemm( \ context, lhs, rhs, &result, lhs_offset, rhs_offset, result_offset, \ result_mult_int, result_shift); \ return; GEMMLOWP_HANDLE_BIT_DEPTH(A8B8, DefaultL8R8BitDepthParams) GEMMLOWP_HANDLE_BIT_DEPTH(A5B7, DefaultL7R5BitDepthParams) default: abort(); #undef GEMMLOWP_HANDLE_BIT_DEPTH } } template void EightBitIntGemmInt32Impl(GemmContext* context, int m, int n, int k, const std::uint8_t* a, std::int32_t a_offset, int lda, const std::uint8_t* b, std::int32_t b_offset, int ldb, std::int32_t* c, int ldc, BitDepthSetting bit_depth) { const int lhs_offset = a_offset; const int rhs_offset = b_offset; static const MapOrder ResultOrder = transpose_c ? MapOrder::RowMajor : MapOrder::ColMajor; static const MapOrder LhsOrder = transpose_a ? 
MapOrder::RowMajor : MapOrder::ColMajor; static const MapOrder RhsOrder = transpose_b ? MapOrder::RowMajor : MapOrder::ColMajor; MatrixMap lhs(a, m, k, lda); MatrixMap rhs(b, k, n, ldb); MatrixMap result(c, m, n, ldc); auto empty_pipeline = std::make_tuple(); switch (bit_depth) { #define GEMMLOWP_HANDLE_BIT_DEPTH_INT32(BIT_DEPTH_SETTING, BIT_DEPTH_PARAMS) \ case BitDepthSetting::BIT_DEPTH_SETTING: \ GemmWithOutputPipeline( \ context, lhs, rhs, &result, lhs_offset, rhs_offset, empty_pipeline); \ return; GEMMLOWP_HANDLE_BIT_DEPTH_INT32(A8B8, DefaultL8R8BitDepthParams) GEMMLOWP_HANDLE_BIT_DEPTH_INT32(A5B7, DefaultL7R5BitDepthParams) default: abort(); #undef GEMMLOWP_HANDLE_BIT_DEPTH_INT32 } } class Scratch { public: Scratch() : buffer_(), buffer_32_(nullptr), size_(0) {} void AssureSize(std::int32_t required_size) { if (size_ >= required_size) { return; } buffer_.reset(new std::uint8_t[required_size + 32]); buffer_32_ = buffer_.get() + ((32 - (reinterpret_cast(buffer_.get()) % 32)) % 32); assert((reinterpret_cast(buffer_32_) % 32) == 0); size_ = required_size; } void Clear() { buffer_.reset(nullptr); buffer_32_ = nullptr; size_ = 0; } std::uint8_t* buffer() { return buffer_32_; } private: std::unique_ptr buffer_; std::uint8_t* buffer_32_; std::int32_t size_; }; Scratch* global_scratch = nullptr; Scratch* GetOrCreateGlobalScratch() { if (global_scratch == nullptr) { global_scratch = new Scratch(); } return global_scratch; } void DestroyGlobalScratch() { delete global_scratch; global_scratch = nullptr; } #if defined(GEMMLOWP_USE_META_FASTPATH) && defined(GEMMLOWP_NEON) bool IsRowMajorOrVector(bool transpose, int stride, int rows, int cols) { // Is it row major and nicely packed? if (transpose && stride == cols) { return true; } // Is it a one row vector? (a vector is both row and column major) if (rows == 1) { return true; } return false; } bool IsColumnMajorOrVector(bool transpose, int stride, int rows, int cols) { // Is it column major and nicely packed? 
if (!transpose && stride == rows) { return true; } // Is it a one column vector? (a vector is both row and column major) if (cols == 1) { return true; } return false; } bool CanHandleMetaFastpath(bool transpose_a, bool transpose_b, bool transpose_c, int m, int n, int k, int lda, int ldb, int ldc, BitDepthSetting depth_setting) { // Meta fastpath only supports 8bit x 8bit and k between 8 and 2048. if (depth_setting != BitDepthSetting::A8B8 || k < 8 || k > 2048) { return false; } // The first operand needs to be a row major matrix or a vector. if (!IsRowMajorOrVector(transpose_a, lda, m, k)) { return false; } // The second operand needs to be a column major matrix or a vector. if (!IsColumnMajorOrVector(transpose_b, ldb, k, n)) { return false; } // The result can either be a row major matrix, a column major matrix or // a vector. if (IsRowMajorOrVector(transpose_c, ldc, m, n)) { return true; } if (IsColumnMajorOrVector(transpose_c, ldc, m, n)) { return true; } return false; } // Assure enough scratch memory is allocated and run the fast path gemm. 
void MetaGemmQuantized8Bit(GemmContext* context, const std::uint8_t* lhs, const std::uint8_t* rhs, int m, int n, int k, std::int32_t lhs_offset, std::int32_t rhs_offset, std::int32_t sum_offset, std::int32_t multiplicative_offset, std::int32_t shift, bool result_transpose, std::int32_t result_stride, std::uint8_t* result) { Scratch* scratch = GetOrCreateGlobalScratch(); const std::int32_t max_num_threads = context->max_num_threads(); if (IsRowMajorOrVector(result_transpose, result_stride, m, n)) { scratch->AssureSize(meta::gemm_q8_scratch(m, n, k, max_num_threads)); meta::multi_thread_gemm_q8(context->workers_pool(), max_num_threads, scratch->buffer(), lhs, rhs, m, n, k, lhs_offset, rhs_offset, sum_offset, multiplicative_offset, shift, result); } else { scratch->AssureSize(meta::gemm_q8_scratch(n, m, k, max_num_threads)); meta::multi_thread_gemm_q8(context->workers_pool(), max_num_threads, scratch->buffer(), rhs, lhs, n, m, k, rhs_offset, lhs_offset, sum_offset, multiplicative_offset, shift, result); } } // Assure enough scratch memory is allocated and run the 8bit to float fast // path gemm. 
void MetaGemmFloat(GemmContext* context, const std::uint8_t* lhs, const std::uint8_t* rhs, int m, int n, int k, std::int32_t lhs_offset, std::int32_t rhs_offset, float result_offset, bool result_transpose, std::int32_t result_stride, float* result) { Scratch* scratch = GetOrCreateGlobalScratch(); const std::int32_t max_num_threads = context->max_num_threads(); if (IsRowMajorOrVector(result_transpose, result_stride, m, n)) { scratch->AssureSize(meta::gemm_f_scratch(m, n, k, max_num_threads)); meta::multi_thread_gemm_f(context->workers_pool(), max_num_threads, scratch->buffer(), lhs, rhs, m, n, k, lhs_offset, rhs_offset, result_offset, result); } else { scratch->AssureSize(meta::gemm_f_scratch(n, m, k, max_num_threads)); meta::multi_thread_gemm_f(context->workers_pool(), max_num_threads, scratch->buffer(), rhs, lhs, n, m, k, rhs_offset, lhs_offset, result_offset, result); } } #endif } // end anonymous namespace // Public interface entry points void EightBitIntGemm(bool transpose_a, bool transpose_b, bool transpose_c, int m, int n, int k, const std::uint8_t* a, std::int32_t a_offset, int lda, const std::uint8_t* b, std::int32_t b_offset, int ldb, std::uint8_t* c, std::int32_t c_offset, std::int32_t c_mult_int, std::int32_t c_shift, int ldc, BitDepthSetting bit_depth) { ScopedLock sl(GlobalMutexes::EightBitIntGemm()); GemmContext* context = GetOrCreateGlobalContext(); #if defined(GEMMLOWP_USE_META_FASTPATH) && defined(GEMMLOWP_NEON) if (CanHandleMetaFastpath(transpose_a, transpose_b, transpose_c, m, n, k, lda, ldb, ldc, bit_depth)) { MetaGemmQuantized8Bit(context, a, b, m, n, k, a_offset, b_offset, c_offset, c_mult_int, c_shift, transpose_c, ldc, c); return; } #endif #define GEMMLOWP_HANDLE_CASE(ta, tb, tc) \ if (transpose_a == ta && transpose_b == tb && transpose_c == tc) { \ EightBitIntGemmImpl(context, m, n, k, a, a_offset, lda, b, \ b_offset, ldb, c, c_offset, c_mult_int, \ c_shift, ldc, bit_depth); \ } GEMMLOWP_HANDLE_CASE(false, false, false) 
GEMMLOWP_HANDLE_CASE(false, false, true) GEMMLOWP_HANDLE_CASE(false, true, false) GEMMLOWP_HANDLE_CASE(false, true, true) GEMMLOWP_HANDLE_CASE(true, false, false) GEMMLOWP_HANDLE_CASE(true, false, true) GEMMLOWP_HANDLE_CASE(true, true, false) GEMMLOWP_HANDLE_CASE(true, true, true) #undef GEMMLOWP_HANDLE_CASE } void EightBitIntGemm(bool transpose_a, bool transpose_b, bool transpose_c, int m, int n, int k, const std::uint8_t* a, std::int32_t a_offset, std::int32_t lda, const std::uint8_t* b, std::int32_t b_offset, std::int32_t ldb, float* c, float c_offset, std::int32_t ldc, BitDepthSetting bit_depth) { ScopedLock sl(GlobalMutexes::EightBitIntGemm()); GemmContext* context = GetOrCreateGlobalContext(); #if defined(GEMMLOWP_USE_META_FASTPATH) && defined(GEMMLOWP_NEON) if (CanHandleMetaFastpath(transpose_a, transpose_b, transpose_c, m, n, k, lda, ldb, ldc, bit_depth)) { MetaGemmFloat(context, a, b, m, n, k, a_offset, b_offset, c_offset, transpose_c, ldc, c); return; } #endif // TODO(maciekc): implement a float output stage, get rid of scratch memory. Scratch* scratch = GetOrCreateGlobalScratch(); if (transpose_c) { scratch->AssureSize(m * ldc * sizeof(std::int32_t)); } else { scratch->AssureSize(n * ldc * sizeof(std::int32_t)); } std::int32_t* temp_c = reinterpret_cast(scratch->buffer()); #define GEMMLOWP_HANDLE_INT32_CASE(ta, tb, tc) \ if (transpose_a == ta && transpose_b == tb && transpose_c == tc) { \ EightBitIntGemmInt32Impl(context, m, n, k, a, a_offset, lda, \ b, b_offset, ldb, temp_c, ldc, \ bit_depth); \ } GEMMLOWP_HANDLE_INT32_CASE(false, false, false) GEMMLOWP_HANDLE_INT32_CASE(false, false, true) GEMMLOWP_HANDLE_INT32_CASE(false, true, false) GEMMLOWP_HANDLE_INT32_CASE(false, true, true) GEMMLOWP_HANDLE_INT32_CASE(true, false, false) GEMMLOWP_HANDLE_INT32_CASE(true, false, true) GEMMLOWP_HANDLE_INT32_CASE(true, true, false) GEMMLOWP_HANDLE_INT32_CASE(true, true, true) #undef GEMMLOWP_HANDLE_INT32_CASE if (transpose_c) { // Row major. 
for (int i = 0; i < m; ++i) { float* dest_row = c + i * ldc; std::int32_t* src_row = temp_c + i * ldc; for (int j = 0; j < n; ++j) { dest_row[j] = static_cast(src_row[j]) * c_offset; } } } else { // Column major. for (int i = 0; i < n; ++i) { float* dest_column = c + i * ldc; std::int32_t* src_column = temp_c + i * ldc; for (int j = 0; j < m; ++j) { dest_column[j] = static_cast(src_column[j]) * c_offset; } } } } void SetMaxNumThreads(int n) { ScopedLock sl(GlobalMutexes::EightBitIntGemm()); GemmContext* context = GetOrCreateGlobalContext(); context->set_max_num_threads(n); } void FreePersistentResources() { ScopedLock sl(GlobalMutexes::EightBitIntGemm()); DestroyGlobalContext(); DestroyGlobalScratch(); } } // namespace eight_bit_int_gemm } // namespace gemmlowp