diff --git a/CMakeLists.txt b/CMakeLists.txt
index 6ddaa51..93b1e92 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -69,9 +69,9 @@ write_basic_package_version_file(
     VERSION ${LLAMA_INSTALL_VERSION}
     COMPATIBILITY SameMajorVersion)
 
-install(FILES ${CMAKE_CURRENT_BINARY_DIR}/LlamaConfig.cmake
-              ${CMAKE_CURRENT_BINARY_DIR}/LlamaConfigVersion.cmake
-        DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/Llama)
+# install(FILES ${CMAKE_CURRENT_BINARY_DIR}/LlamaConfig.cmake
+#               ${CMAKE_CURRENT_BINARY_DIR}/LlamaConfigVersion.cmake
+#         DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/Llama)
 
-set_target_properties(llama PROPERTIES PUBLIC_HEADER ${CMAKE_CURRENT_SOURCE_DIR}/llama.h)
-install(TARGETS llama LIBRARY PUBLIC_HEADER)
\ No newline at end of file
+# set_target_properties(llama PROPERTIES PUBLIC_HEADER ${CMAKE_CURRENT_SOURCE_DIR}/llama.h)
+# install(TARGETS llama LIBRARY PUBLIC_HEADER)
\ No newline at end of file
diff --git a/include/bitnet-lut-kernels.h b/include/bitnet-lut-kernels.h
new file mode 100644
index 0000000..489c8cd
--- /dev/null
+++ b/include/bitnet-lut-kernels.h
@@ -0,0 +1,771 @@
+#if defined(GGML_BITNET_ARM_TL1)
+#include "ggml-bitnet.h"
+#define GGML_BITNET_MAX_NODES 8192
+static bool initialized = false;
+static bitnet_tensor_extra * bitnet_tensor_extras = nullptr;
+static size_t bitnet_tensor_extras_index = 0;
+static void * aligned_malloc(size_t size) {
+#if defined(_WIN32)
+    return _aligned_malloc(size, 64);
+#else
+    void * ptr = nullptr;
+    posix_memalign(&ptr, 64, size);
+    return ptr;
+#endif
+}
+static void aligned_free(void * ptr) {
+#if defined(_WIN32)
+    _aligned_free(ptr);
+#else
+    free(ptr);
+#endif
+}
+
+void per_tensor_quant(int k, void* lut_scales_, void* b_) {
+    bitnet_float_type* lut_scales = (bitnet_float_type*)lut_scales_;
+    bitnet_float_type* b = (bitnet_float_type*)b_;
+#ifdef __ARM_NEON
+    float32x4_t temp_max = vdupq_n_f32(0);
+    for (int i = 0; i < k / 4; i++) {
+        float32x4_t vec_bs = vld1q_f32(b + 4 * i);
+        float32x4_t abssum = vabsq_f32(vec_bs);
+        temp_max = vmaxq_f32(abssum, temp_max);
+    }
+    float32_t scales = 127 / vmaxvq_f32(temp_max);
+    *lut_scales = scales;
+#elif defined __AVX2__
+    __m256 max_vec = _mm256_set1_ps(0.f);
+    const __m256 vec_sign = _mm256_set1_ps(-0.0f);
+    // #pragma unroll
+    for (int i = 0; i < k / 8; i++) {
+        __m256 vec_b = _mm256_loadu_ps(b + i * 8);
+        __m256 vec_babs = _mm256_andnot_ps(vec_sign, vec_b);
+        max_vec = _mm256_max_ps(vec_babs, max_vec);
+    }
+    __m128 max1 = _mm_max_ps(_mm256_extractf128_ps(max_vec, 1), _mm256_castps256_ps128(max_vec));
+    max1 = _mm_max_ps(max1, _mm_movehl_ps(max1, max1));
+    max1 = _mm_max_ss(max1, _mm_movehdup_ps(max1));
+    float scales = 127 / _mm_cvtss_f32(max1);
+    *lut_scales = scales;
+#endif
+}
+
+void partial_max_reset(void* lut_scales_) {
+    bitnet_float_type* lut_scales = (bitnet_float_type*)lut_scales_;
+    *lut_scales = 0.0;
+}
+
+#ifdef __ARM_NEON
+inline void Transpose_8_8(
+        int16x8_t *v0,
+        int16x8_t *v1,
+        int16x8_t *v2,
+        int16x8_t *v3,
+        int16x8_t *v4,
+        int16x8_t *v5,
+        int16x8_t *v6,
+        int16x8_t *v7)
+{
+    int16x8x2_t q04 = vzipq_s16(*v0, *v4);
+    int16x8x2_t q15 = vzipq_s16(*v1, *v5);
+    int16x8x2_t q26 = vzipq_s16(*v2, *v6);
+    int16x8x2_t q37 = vzipq_s16(*v3, *v7);
+
+    int16x8x2_t q0246_0 = vzipq_s16(q04.val[0], q26.val[0]);
+    int16x8x2_t q0246_1 = vzipq_s16(q04.val[1], q26.val[1]);
+    int16x8x2_t q1357_0 = vzipq_s16(q15.val[0], q37.val[0]);
+    int16x8x2_t q1357_1 = vzipq_s16(q15.val[1], q37.val[1]);
+
+    int16x8x2_t q_fin_0 = vzipq_s16(q0246_0.val[0], q1357_0.val[0]);
+    int16x8x2_t q_fin_1 = vzipq_s16(q0246_0.val[1], q1357_0.val[1]);
+    int16x8x2_t q_fin_2 = vzipq_s16(q0246_1.val[0], q1357_1.val[0]);
+    int16x8x2_t q_fin_3 = vzipq_s16(q0246_1.val[1], q1357_1.val[1]);
+
+    *v0 = q_fin_0.val[0];
+    *v1 = q_fin_0.val[1];
+    *v2 = q_fin_1.val[0];
+    *v3 = q_fin_1.val[1];
+    *v4 = q_fin_2.val[0];
+    *v5 = q_fin_2.val[1];
+    *v6 = q_fin_3.val[0];
+    *v7 = q_fin_3.val[1];
+}
+#endif
+
+template<int act_k>
+inline void lut_ctor(int8_t* qlut, bitnet_float_type* b, bitnet_float_type* lut_scales) {
+#ifdef __ARM_NEON
+    int16x8_t vec_lut[16];
+    float32_t scales = *lut_scales;
+    uint8_t tbl_mask[16];
+    tbl_mask[0] = 0;
+    tbl_mask[1] = 2;
+    tbl_mask[2] = 4;
+    tbl_mask[3] = 6;
+    tbl_mask[4] = 8;
+    tbl_mask[5] = 10;
+    tbl_mask[6] = 12;
+    tbl_mask[7] = 14;
+    tbl_mask[8] = 1;
+    tbl_mask[9] = 3;
+    tbl_mask[10] = 5;
+    tbl_mask[11] = 7;
+    tbl_mask[12] = 9;
+    tbl_mask[13] = 11;
+    tbl_mask[14] = 13;
+    tbl_mask[15] = 15;
+    uint8x16_t tbl_mask_q = vld1q_u8(tbl_mask);
+#pragma unroll
+    for (int k = 0; k < act_k / 16; ++k) {
+        float32x4x2_t vec_bs_x0 = vld2q_f32(b + k * 16);
+        float32x4x2_t vec_bs_x1 = vld2q_f32(b + k * 16 + 8);
+        float32x4_t vec_f_0 = vmulq_n_f32(vec_bs_x0.val[0], scales);
+        float32x4_t vec_f_1 = vmulq_n_f32(vec_bs_x0.val[1], scales);
+        float32x4_t vec_f_2 = vmulq_n_f32(vec_bs_x1.val[0], scales);
+        float32x4_t vec_f_3 = vmulq_n_f32(vec_bs_x1.val[1], scales);
+        int32x4_t vec_b_0 = vcvtnq_s32_f32(vec_f_0);
+        int32x4_t vec_b_1 = vcvtnq_s32_f32(vec_f_1);
+        int32x4_t vec_b_2 = vcvtnq_s32_f32(vec_f_2);
+        int32x4_t vec_b_3 = vcvtnq_s32_f32(vec_f_3);
+        int16x4_t vec_b16_0 = vmovn_s32(vec_b_0);
+        int16x4_t vec_b16_1 = vmovn_s32(vec_b_1);
+        int16x4_t vec_b16_2 = vmovn_s32(vec_b_2);
+        int16x4_t vec_b16_3 = vmovn_s32(vec_b_3);
+        int16x8_t vec_bs_0 = vcombine_s16(vec_b16_0, vec_b16_2);
+        int16x8_t vec_bs_1 = vcombine_s16(vec_b16_1, vec_b16_3);
+        vec_lut[0] = vdupq_n_s16(0);
+        vec_lut[0] = vec_lut[0] - vec_bs_0;
+        vec_lut[0] = vec_lut[0] - vec_bs_1;
+        vec_lut[1] = vdupq_n_s16(0);
+        vec_lut[1] = vec_lut[1] - vec_bs_0;
+        vec_lut[2] = vdupq_n_s16(0);
+        vec_lut[2] = vec_lut[2] - vec_bs_0;
+        vec_lut[2] = vec_lut[2] + vec_bs_1;
+        vec_lut[3] = vdupq_n_s16(0);
+        vec_lut[3] = vec_lut[3] - vec_bs_1;
+        vec_lut[4] = vdupq_n_s16(0);
+        vec_lut[5] = vec_bs_1;
+        vec_lut[6] = vec_bs_0;
+        vec_lut[6] = vec_lut[6] - vec_bs_1;
+        vec_lut[7] = vec_bs_0;
+        vec_lut[8] = vec_bs_0;
+        vec_lut[8] = vec_lut[8] + vec_bs_1;
+        Transpose_8_8(&(vec_lut[0]), &(vec_lut[1]), &(vec_lut[2]), &(vec_lut[3]),
+                      &(vec_lut[4]), &(vec_lut[5]), &(vec_lut[6]), &(vec_lut[7]));
+        Transpose_8_8(&(vec_lut[8]), &(vec_lut[9]), &(vec_lut[10]), &(vec_lut[11]),
+                      &(vec_lut[12]), &(vec_lut[13]), &(vec_lut[14]), &(vec_lut[15]));
+#pragma unroll
+        for (int idx = 0; idx < 8; idx++) {
+            int8x16_t q0_s = vqtbl1q_s8(vreinterpretq_s8_s16(vec_lut[idx]), tbl_mask_q);
+            int8x8_t q0_low = vget_low_s8(q0_s);
+            int8x8_t q0_high = vget_high_s8(q0_s);
+            int8x16_t q1_s = vqtbl1q_s8(vreinterpretq_s8_s16(vec_lut[idx + 8]), tbl_mask_q);
+            int8x8_t q1_low = vget_low_s8(q1_s);
+            int8x8_t q1_high = vget_high_s8(q1_s);
+            vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2, q0_high);
+            vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2 + 8, q1_high);
+            vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2 + 16, q0_low);
+            vst1_s8(qlut + k * 16 * 8 * 2 + idx * 16 * 2 + 24, q1_low);
+        }
+    }
+#endif
+}
+
+static bool is_type_supported(enum ggml_type type) {
+    if (type == GGML_TYPE_Q4_0 ||
+        type == GGML_TYPE_TL1) {
+        return true;
+    } else {
+        return false;
+    }
+}
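The two routines above are the activation-side preprocessing of TL1: `per_tensor_quant` derives a per-tensor scale of 127 / max|b|, and `lut_ctor` precomputes, for every pair of activations (b0, b1), all nine partial sums w0·b0 + w1·b1 over ternary weights w ∈ {-1, 0, +1} — exactly the nine `vec_lut[0..8]` entries. The scalar sketch below restates that computation; it is written for this note, not part of the patch, and it approximates the rounding (the NEON path uses round-to-nearest-even via `vcvtnq_s32_f32`).

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>

// Illustration only: scalar model of per_tensor_quant + lut_ctor.
// For each activation pair, precompute all 9 sums w0*b0 + w1*b1,
// quantized with a per-tensor scale of 127 / max|b|.
void build_pair_lut_ref(const float* b, int k, int8_t lut[][9], float* lut_scale) {
    float amax = 0.0f;
    for (int i = 0; i < k; i++) amax = std::max(amax, std::fabs(b[i]));
    const float scale = 127.0f / amax;   // matches per_tensor_quant
    *lut_scale = scale;
    for (int i = 0; i < k / 2; i++) {
        const float b0 = b[2 * i] * scale, b1 = b[2 * i + 1] * scale;
        int idx = 0;                      // same order as vec_lut[0..8]
        for (int w0 = -1; w0 <= 1; w0++)
            for (int w1 = -1; w1 <= 1; w1++)
                lut[i][idx++] = (int8_t)std::lround(w0 * b0 + w1 * b1);
    }
}
```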
+#include <arm_neon.h>
+
+#define BM14336_4096 256
+#define BBK14336_4096 128
+inline void tbl_impl_14336_4096(int32_t* c, int8_t* lut, uint8_t* a) {
+#ifdef __ARM_NEON
+    const int KK = BBK14336_4096 / 2;
+    const uint8x16_t vec_mask = vdupq_n_u8(0x0f);
+    const int8x16_t vec_zero = vdupq_n_s16(0x0000);
+    int8x16_t vec_lut[2 * KK];
+    int16x8_t vec_c[4];
+#pragma unroll
+    for (int k = 0; k < 2 * KK; k++) {
+        vec_lut[k] = vld1q_s8(lut + k * 16);
+    }
+
+#pragma unroll
+    for (int i = 0; i < BM14336_4096; i += 32) {
+#pragma unroll
+        for (int i = 0; i < 4; i++) {
+            vec_c[i] = vandq_s16(vec_c[i], vec_zero);
+        }
+
+#pragma unroll
+        for (int k = 0; k < KK / 4; k++) {
+
+            uint8x16_t vec_a_0 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 0 * 16);
+            uint8x16_t vec_a0_top = vshrq_n_u8(vec_a_0, 4);
+            uint8x16_t vec_a0_bot = vandq_u8(vec_a_0, vec_mask);
+            int8x16_t vec_v_0_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 0], vec_a0_top);
+            int8x16_t vec_v_0_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 1], vec_a0_top);
+            int8x16_t vec_v_0_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 2], vec_a0_bot);
+            int8x16_t vec_v_0_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 3], vec_a0_bot);
+            int8x16x2_t vec_v_left_0 = vzipq_s8(vec_v_0_left_tmp1, vec_v_0_left_tmp0);
+            int8x16x2_t vec_v_right_0 = vzipq_s8(vec_v_0_right_tmp1, vec_v_0_right_tmp0);
+            vec_c[0] += vec_v_left_0.val[0];
+            vec_c[0] += vec_v_right_0.val[0];
+            vec_c[1] += vec_v_left_0.val[1];
+            vec_c[1] += vec_v_right_0.val[1];
+
+            uint8x16_t vec_a_1 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 1 * 16);
+            uint8x16_t vec_a1_top = vshrq_n_u8(vec_a_1, 4);
+            uint8x16_t vec_a1_bot = vandq_u8(vec_a_1, vec_mask);
+            int8x16_t vec_v_1_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 4], vec_a1_top);
+            int8x16_t vec_v_1_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 5], vec_a1_top);
+            int8x16_t vec_v_1_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 6], vec_a1_bot);
+            int8x16_t vec_v_1_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 7], vec_a1_bot);
+            int8x16x2_t vec_v_left_1 = vzipq_s8(vec_v_1_left_tmp1, vec_v_1_left_tmp0);
+            int8x16x2_t vec_v_right_1 = vzipq_s8(vec_v_1_right_tmp1, vec_v_1_right_tmp0);
+            vec_c[0] += vec_v_left_1.val[0];
+            vec_c[0] += vec_v_right_1.val[0];
+            vec_c[1] += vec_v_left_1.val[1];
+            vec_c[1] += vec_v_right_1.val[1];
+
+            uint8x16_t vec_a_2 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 2 * 16);
+            uint8x16_t vec_a2_top = vshrq_n_u8(vec_a_2, 4);
+            uint8x16_t vec_a2_bot = vandq_u8(vec_a_2, vec_mask);
+            int8x16_t vec_v_2_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 0], vec_a2_top);
+            int8x16_t vec_v_2_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 1], vec_a2_top);
+            int8x16_t vec_v_2_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 2], vec_a2_bot);
+            int8x16_t vec_v_2_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 3], vec_a2_bot);
+            int8x16x2_t vec_v_left_2 = vzipq_s8(vec_v_2_left_tmp1, vec_v_2_left_tmp0);
+            int8x16x2_t vec_v_right_2 = vzipq_s8(vec_v_2_right_tmp1, vec_v_2_right_tmp0);
+            vec_c[2] += vec_v_left_2.val[0];
+            vec_c[2] += vec_v_right_2.val[0];
+            vec_c[3] += vec_v_left_2.val[1];
+            vec_c[3] += vec_v_right_2.val[1];
+
+            uint8x16_t vec_a_3 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 3 * 16);
+            uint8x16_t vec_a3_top = vshrq_n_u8(vec_a_3, 4);
+            uint8x16_t vec_a3_bot = vandq_u8(vec_a_3, vec_mask);
+            int8x16_t vec_v_3_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 4], vec_a3_top);
+            int8x16_t vec_v_3_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 5], vec_a3_top);
+            int8x16_t vec_v_3_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 6], vec_a3_bot);
+            int8x16_t vec_v_3_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 7], vec_a3_bot);
+            int8x16x2_t vec_v_left_3 = vzipq_s8(vec_v_3_left_tmp1, vec_v_3_left_tmp0);
+            int8x16x2_t vec_v_right_3 = vzipq_s8(vec_v_3_right_tmp1, vec_v_3_right_tmp0);
+            vec_c[2] += vec_v_left_3.val[0];
+            vec_c[2] += vec_v_right_3.val[0];
+            vec_c[3] += vec_v_left_3.val[1];
+            vec_c[3] += vec_v_right_3.val[1];
+
+        }
+
+        int32x4_t vec_v_bot_low_low_0 = vmovl_s16(vget_low_s16(vec_c[0]));
+        int32x4_t vec_v_bot_low_high_0 = vmovl_high_s16(vec_c[0]);
+        vst1q_s32(c + i + 0, vld1q_s32(c + i + 0) + vec_v_bot_low_low_0);
+        vst1q_s32(c + i + 4, vld1q_s32(c + i + 4) + vec_v_bot_low_high_0);
+        int32x4_t vec_v_bot_low_low_1 = vmovl_s16(vget_low_s16(vec_c[1]));
+        int32x4_t vec_v_bot_low_high_1 = vmovl_high_s16(vec_c[1]);
+        vst1q_s32(c + i + 8, vld1q_s32(c + i + 8) + vec_v_bot_low_low_1);
+        vst1q_s32(c + i + 12, vld1q_s32(c + i + 12) + vec_v_bot_low_high_1);
+        int32x4_t vec_v_bot_low_low_2 = vmovl_s16(vget_low_s16(vec_c[2]));
+        int32x4_t vec_v_bot_low_high_2 = vmovl_high_s16(vec_c[2]);
+        vst1q_s32(c + i + 16, vld1q_s32(c + i + 16) + vec_v_bot_low_low_2);
+        vst1q_s32(c + i + 20, vld1q_s32(c + i + 20) + vec_v_bot_low_high_2);
+        int32x4_t vec_v_bot_low_low_3 = vmovl_s16(vget_low_s16(vec_c[3]));
+        int32x4_t vec_v_bot_low_high_3 = vmovl_high_s16(vec_c[3]);
+        vst1q_s32(c + i + 24, vld1q_s32(c + i + 24) + vec_v_bot_low_low_3);
+        vst1q_s32(c + i + 28, vld1q_s32(c + i + 28) + vec_v_bot_low_high_3);
+
+    }
+#endif
+}
+
+int32_t qgemm_lut_14336_4096(void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) {
+    alignas(32) uint32_t CBits[BM14336_4096];
+    memset(&(CBits[0]), 0, BM14336_4096 * sizeof(int32_t));
+#pragma unroll
+    for (int32_t k_outer = 0; k_outer < 4096 / BBK14336_4096; ++k_outer) {
+        tbl_impl_14336_4096((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BBK14336_4096 / 2 * 32)])), (&(((uint8_t*)A)[(k_outer * BBK14336_4096 / 2 / 2 * BM14336_4096)])));
+    }
+#pragma unroll
+    for (int i = 0; i < BM14336_4096; i++) {
+        ((bitnet_float_type*)C)[i] = (((int32_t*)CBits)[i]) / ((bitnet_float_type*)LUT_Scales)[0] * ((bitnet_float_type*)Scales)[0];
+    }
+    return 0;
+};
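In `tbl_impl_14336_4096` each byte of the packed weight tensor `a` carries two 4-bit LUT indices (`vshrq_n_u8` extracts the high nibble, `vandq_u8` the low one), and each 16-bit LUT entry is stored split across two byte tables that `vqtbl1q_s8` fetches and `vzipq_s8` re-pairs into `int16` lanes (the mixed-width `+=` accumulation relies on GCC's `-flax-vector-conversions`). A scalar model of one such lookup — my reading of the layout, written for this note and not part of the patch:

```cpp
#include <cstdint>

// Illustration only: scalar model of one TL1 table lookup. A 16-bit LUT
// entry is split into a low-byte table and a high-byte table; vqtbl1q_s8
// fetches each half for 16 indices at once, and vzipq_s8 re-pairs the
// halves into int16 lanes for accumulation.
inline int16_t tl1_lookup(uint8_t packed, bool high_nibble,
                          const int8_t* lut_hi, const int8_t* lut_lo) {
    uint8_t idx = high_nibble ? (uint8_t)(packed >> 4)   // vshrq_n_u8(a, 4)
                              : (uint8_t)(packed & 0x0f); // vandq_u8(a, 0x0f)
    uint8_t lo = (uint8_t)lut_lo[idx];   // low byte of the 16-bit entry
    int16_t hi = lut_hi[idx];            // high byte of the 16-bit entry
    return (int16_t)((hi << 8) | lo);    // what vzipq_s8 reassembles
}
```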
+#include <arm_neon.h>
+
+#define BM4096_14336 128
+#define BBK4096_14336 64
+inline void tbl_impl_4096_14336(int32_t* c, int8_t* lut, uint8_t* a) {
+#ifdef __ARM_NEON
+    const int KK = BBK4096_14336 / 2;
+    const uint8x16_t vec_mask = vdupq_n_u8(0x0f);
+    const int8x16_t vec_zero = vdupq_n_s16(0x0000);
+    int8x16_t vec_lut[2 * KK];
+    int16x8_t vec_c[8];
+#pragma unroll
+    for (int k = 0; k < 2 * KK; k++) {
+        vec_lut[k] = vld1q_s8(lut + k * 16);
+    }
+
+#pragma unroll
+    for (int i = 0; i < BM4096_14336; i += 64) {
+#pragma unroll
+        for (int i = 0; i < 8; i++) {
+            vec_c[i] = vandq_s16(vec_c[i], vec_zero);
+        }
+
+#pragma unroll
+        for (int k = 0; k < KK / 2; k++) {
+
+            uint8x16_t vec_a_0 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 0 * 16);
+            uint8x16_t vec_a0_top = vshrq_n_u8(vec_a_0, 4);
+            uint8x16_t vec_a0_bot = vandq_u8(vec_a_0, vec_mask);
+            int8x16_t vec_v_0_left_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 0], vec_a0_top);
+            int8x16_t vec_v_0_left_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 1], vec_a0_top);
+            int8x16_t vec_v_0_right_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 2], vec_a0_bot);
+            int8x16_t vec_v_0_right_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 3], vec_a0_bot);
+            int8x16x2_t vec_v_left_0 = vzipq_s8(vec_v_0_left_tmp1, vec_v_0_left_tmp0);
+            int8x16x2_t vec_v_right_0 = vzipq_s8(vec_v_0_right_tmp1, vec_v_0_right_tmp0);
+            vec_c[0] += vec_v_left_0.val[0];
+            vec_c[0] += vec_v_right_0.val[0];
+            vec_c[1] += vec_v_left_0.val[1];
+            vec_c[1] += vec_v_right_0.val[1];
+
+            uint8x16_t vec_a_1 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 1 * 16);
+            uint8x16_t vec_a1_top = vshrq_n_u8(vec_a_1, 4);
+            uint8x16_t vec_a1_bot = vandq_u8(vec_a_1, vec_mask);
+            int8x16_t vec_v_1_left_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 0], vec_a1_top);
+            int8x16_t vec_v_1_left_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 1], vec_a1_top);
+            int8x16_t vec_v_1_right_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 2], vec_a1_bot);
+            int8x16_t vec_v_1_right_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 3], vec_a1_bot);
+            int8x16x2_t vec_v_left_1 = vzipq_s8(vec_v_1_left_tmp1, vec_v_1_left_tmp0);
+            int8x16x2_t vec_v_right_1 = vzipq_s8(vec_v_1_right_tmp1, vec_v_1_right_tmp0);
+            vec_c[2] += vec_v_left_1.val[0];
+            vec_c[2] += vec_v_right_1.val[0];
+            vec_c[3] += vec_v_left_1.val[1];
+            vec_c[3] += vec_v_right_1.val[1];
+
+            uint8x16_t vec_a_2 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 2 * 16);
+            uint8x16_t vec_a2_top = vshrq_n_u8(vec_a_2, 4);
+            uint8x16_t vec_a2_bot = vandq_u8(vec_a_2, vec_mask);
+            int8x16_t vec_v_2_left_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 0], vec_a2_top);
+            int8x16_t vec_v_2_left_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 1], vec_a2_top);
+            int8x16_t vec_v_2_right_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 2], vec_a2_bot);
+            int8x16_t vec_v_2_right_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 3], vec_a2_bot);
+            int8x16x2_t vec_v_left_2 = vzipq_s8(vec_v_2_left_tmp1, vec_v_2_left_tmp0);
+            int8x16x2_t vec_v_right_2 = vzipq_s8(vec_v_2_right_tmp1, vec_v_2_right_tmp0);
+            vec_c[4] += vec_v_left_2.val[0];
+            vec_c[4] += vec_v_right_2.val[0];
+            vec_c[5] += vec_v_left_2.val[1];
+            vec_c[5] += vec_v_right_2.val[1];
+
+            uint8x16_t vec_a_3 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 3 * 16);
+            uint8x16_t vec_a3_top = vshrq_n_u8(vec_a_3, 4);
+            uint8x16_t vec_a3_bot = vandq_u8(vec_a_3, vec_mask);
+            int8x16_t vec_v_3_left_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 0], vec_a3_top);
+            int8x16_t vec_v_3_left_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 1], vec_a3_top);
+            int8x16_t vec_v_3_right_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 2], vec_a3_bot);
+            int8x16_t vec_v_3_right_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 3], vec_a3_bot);
+            int8x16x2_t vec_v_left_3 = vzipq_s8(vec_v_3_left_tmp1, vec_v_3_left_tmp0);
+            int8x16x2_t vec_v_right_3 = vzipq_s8(vec_v_3_right_tmp1, vec_v_3_right_tmp0);
+            vec_c[6] += vec_v_left_3.val[0];
+            vec_c[6] += vec_v_right_3.val[0];
+            vec_c[7] += vec_v_left_3.val[1];
+            vec_c[7] += vec_v_right_3.val[1];
+
+        }
+
+        int32x4_t vec_v_bot_low_low_0 = vmovl_s16(vget_low_s16(vec_c[0]));
+        int32x4_t vec_v_bot_low_high_0 = vmovl_high_s16(vec_c[0]);
+        vst1q_s32(c + i + 0, vld1q_s32(c + i + 0) + vec_v_bot_low_low_0);
+        vst1q_s32(c + i + 4, vld1q_s32(c + i + 4) + vec_v_bot_low_high_0);
+        int32x4_t vec_v_bot_low_low_1 = vmovl_s16(vget_low_s16(vec_c[1]));
+        int32x4_t vec_v_bot_low_high_1 = vmovl_high_s16(vec_c[1]);
+        vst1q_s32(c + i + 8, vld1q_s32(c + i + 8) + vec_v_bot_low_low_1);
+        vst1q_s32(c + i + 12, vld1q_s32(c + i + 12) + vec_v_bot_low_high_1);
+        int32x4_t vec_v_bot_low_low_2 = vmovl_s16(vget_low_s16(vec_c[2]));
+        int32x4_t vec_v_bot_low_high_2 = vmovl_high_s16(vec_c[2]);
+        vst1q_s32(c + i + 16, vld1q_s32(c + i + 16) + vec_v_bot_low_low_2);
+        vst1q_s32(c + i + 20, vld1q_s32(c + i + 20) + vec_v_bot_low_high_2);
+        int32x4_t vec_v_bot_low_low_3 = vmovl_s16(vget_low_s16(vec_c[3]));
+        int32x4_t vec_v_bot_low_high_3 = vmovl_high_s16(vec_c[3]);
+        vst1q_s32(c + i + 24, vld1q_s32(c + i + 24) + vec_v_bot_low_low_3);
+        vst1q_s32(c + i + 28, vld1q_s32(c + i + 28) + vec_v_bot_low_high_3);
+        int32x4_t vec_v_bot_low_low_4 = vmovl_s16(vget_low_s16(vec_c[4]));
+        int32x4_t vec_v_bot_low_high_4 = vmovl_high_s16(vec_c[4]);
+        vst1q_s32(c + i + 32, vld1q_s32(c + i + 32) + vec_v_bot_low_low_4);
+        vst1q_s32(c + i + 36, vld1q_s32(c + i + 36) + vec_v_bot_low_high_4);
+        int32x4_t vec_v_bot_low_low_5 = vmovl_s16(vget_low_s16(vec_c[5]));
+        int32x4_t vec_v_bot_low_high_5 = vmovl_high_s16(vec_c[5]);
+        vst1q_s32(c + i + 40, vld1q_s32(c + i + 40) + vec_v_bot_low_low_5);
+        vst1q_s32(c + i + 44, vld1q_s32(c + i + 44) + vec_v_bot_low_high_5);
+        int32x4_t vec_v_bot_low_low_6 = vmovl_s16(vget_low_s16(vec_c[6]));
+        int32x4_t vec_v_bot_low_high_6 = vmovl_high_s16(vec_c[6]);
+        vst1q_s32(c + i + 48, vld1q_s32(c + i + 48) + vec_v_bot_low_low_6);
+        vst1q_s32(c + i + 52, vld1q_s32(c + i + 52) + vec_v_bot_low_high_6);
+        int32x4_t vec_v_bot_low_low_7 = vmovl_s16(vget_low_s16(vec_c[7]));
+        int32x4_t vec_v_bot_low_high_7 = vmovl_high_s16(vec_c[7]);
+        vst1q_s32(c + i + 56, vld1q_s32(c + i + 56) + vec_v_bot_low_low_7);
+        vst1q_s32(c + i + 60, vld1q_s32(c + i + 60) + vec_v_bot_low_high_7);
+
+    }
+#endif
+}
+
+int32_t qgemm_lut_4096_14336(void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) {
+    alignas(32) uint32_t CBits[BM4096_14336];
+    memset(&(CBits[0]), 0, BM4096_14336 * sizeof(int32_t));
+#pragma unroll
+    for (int32_t k_outer = 0; k_outer < 14336 / BBK4096_14336; ++k_outer) {
+        tbl_impl_4096_14336((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BBK4096_14336 / 2 * 32)])), (&(((uint8_t*)A)[(k_outer * BBK4096_14336 / 2 / 2 * BM4096_14336)])));
+    }
+#pragma unroll
+    for (int i = 0; i < BM4096_14336; i++) {
+        ((bitnet_float_type*)C)[i] = (((int32_t*)CBits)[i]) / ((bitnet_float_type*)LUT_Scales)[0] * ((bitnet_float_type*)Scales)[0];
+    }
+    return 0;
+};
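All four `qgemm_lut_*` kernels share the same epilogue: the int32 partial sums in `CBits` are mapped back to floats as `C[i] = CBits[i] / LUT_Scales[0] * Scales[0]`, which undoes the activation scale (127 / max|b|) and applies the per-tensor weight scale. A plain-C++ restatement, for illustration only:

```cpp
#include <cstdint>

// Illustration only: the dequantization epilogue shared by all
// qgemm_lut_* kernels above.
void dequant_epilogue_ref(const int32_t* cbits, int bm,
                          float lut_scale,     // 127 / max|b|, from per_tensor_quant
                          float weight_scale,  // per-tensor weight scale
                          float* c) {
    for (int i = 0; i < bm; i++) {
        c[i] = (float)cbits[i] / lut_scale * weight_scale;
    }
}
```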
+#include <arm_neon.h>
+
+#define BM1024_4096 256
+#define BBK1024_4096 128
+inline void tbl_impl_1024_4096(int32_t* c, int8_t* lut, uint8_t* a) {
+#ifdef __ARM_NEON
+    const int KK = BBK1024_4096 / 2;
+    const uint8x16_t vec_mask = vdupq_n_u8(0x0f);
+    const int8x16_t vec_zero = vdupq_n_s16(0x0000);
+    int8x16_t vec_lut[2 * KK];
+    int16x8_t vec_c[4];
+#pragma unroll
+    for (int k = 0; k < 2 * KK; k++) {
+        vec_lut[k] = vld1q_s8(lut + k * 16);
+    }
+
+#pragma unroll
+    for (int i = 0; i < BM1024_4096; i += 32) {
+#pragma unroll
+        for (int i = 0; i < 4; i++) {
+            vec_c[i] = vandq_s16(vec_c[i], vec_zero);
+        }
+
+#pragma unroll
+        for (int k = 0; k < KK / 4; k++) {
+
+            uint8x16_t vec_a_0 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 0 * 16);
+            uint8x16_t vec_a0_top = vshrq_n_u8(vec_a_0, 4);
+            uint8x16_t vec_a0_bot = vandq_u8(vec_a_0, vec_mask);
+            int8x16_t vec_v_0_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 0], vec_a0_top);
+            int8x16_t vec_v_0_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 1], vec_a0_top);
+            int8x16_t vec_v_0_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 2], vec_a0_bot);
+            int8x16_t vec_v_0_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 3], vec_a0_bot);
+            int8x16x2_t vec_v_left_0 = vzipq_s8(vec_v_0_left_tmp1, vec_v_0_left_tmp0);
+            int8x16x2_t vec_v_right_0 = vzipq_s8(vec_v_0_right_tmp1, vec_v_0_right_tmp0);
+            vec_c[0] += vec_v_left_0.val[0];
+            vec_c[0] += vec_v_right_0.val[0];
+            vec_c[1] += vec_v_left_0.val[1];
+            vec_c[1] += vec_v_right_0.val[1];
+
+            uint8x16_t vec_a_1 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 1 * 16);
+            uint8x16_t vec_a1_top = vshrq_n_u8(vec_a_1, 4);
+            uint8x16_t vec_a1_bot = vandq_u8(vec_a_1, vec_mask);
+            int8x16_t vec_v_1_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 4], vec_a1_top);
+            int8x16_t vec_v_1_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 5], vec_a1_top);
+            int8x16_t vec_v_1_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 6], vec_a1_bot);
+            int8x16_t vec_v_1_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 7], vec_a1_bot);
+            int8x16x2_t vec_v_left_1 = vzipq_s8(vec_v_1_left_tmp1, vec_v_1_left_tmp0);
+            int8x16x2_t vec_v_right_1 = vzipq_s8(vec_v_1_right_tmp1, vec_v_1_right_tmp0);
+            vec_c[0] += vec_v_left_1.val[0];
+            vec_c[0] += vec_v_right_1.val[0];
+            vec_c[1] += vec_v_left_1.val[1];
+            vec_c[1] += vec_v_right_1.val[1];
+
+            uint8x16_t vec_a_2 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 2 * 16);
+            uint8x16_t vec_a2_top = vshrq_n_u8(vec_a_2, 4);
+            uint8x16_t vec_a2_bot = vandq_u8(vec_a_2, vec_mask);
+            int8x16_t vec_v_2_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 0], vec_a2_top);
+            int8x16_t vec_v_2_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 1], vec_a2_top);
+            int8x16_t vec_v_2_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 2], vec_a2_bot);
+            int8x16_t vec_v_2_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 3], vec_a2_bot);
+            int8x16x2_t vec_v_left_2 = vzipq_s8(vec_v_2_left_tmp1, vec_v_2_left_tmp0);
+            int8x16x2_t vec_v_right_2 = vzipq_s8(vec_v_2_right_tmp1, vec_v_2_right_tmp0);
+            vec_c[2] += vec_v_left_2.val[0];
+            vec_c[2] += vec_v_right_2.val[0];
+            vec_c[3] += vec_v_left_2.val[1];
+            vec_c[3] += vec_v_right_2.val[1];
+
+            uint8x16_t vec_a_3 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 3 * 16);
+            uint8x16_t vec_a3_top = vshrq_n_u8(vec_a_3, 4);
+            uint8x16_t vec_a3_bot = vandq_u8(vec_a_3, vec_mask);
+            int8x16_t vec_v_3_left_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 4], vec_a3_top);
+            int8x16_t vec_v_3_left_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 5], vec_a3_top);
+            int8x16_t vec_v_3_right_tmp0 = vqtbl1q_s8(vec_lut[8 * k + 6], vec_a3_bot);
+            int8x16_t vec_v_3_right_tmp1 = vqtbl1q_s8(vec_lut[8 * k + 7], vec_a3_bot);
+            int8x16x2_t vec_v_left_3 = vzipq_s8(vec_v_3_left_tmp1, vec_v_3_left_tmp0);
+            int8x16x2_t vec_v_right_3 = vzipq_s8(vec_v_3_right_tmp1, vec_v_3_right_tmp0);
+            vec_c[2] += vec_v_left_3.val[0];
+            vec_c[2] += vec_v_right_3.val[0];
+            vec_c[3] += vec_v_left_3.val[1];
+            vec_c[3] += vec_v_right_3.val[1];
+
+        }
+
+        int32x4_t vec_v_bot_low_low_0 = vmovl_s16(vget_low_s16(vec_c[0]));
+        int32x4_t vec_v_bot_low_high_0 = vmovl_high_s16(vec_c[0]);
+        vst1q_s32(c + i + 0, vld1q_s32(c + i + 0) + vec_v_bot_low_low_0);
+        vst1q_s32(c + i + 4, vld1q_s32(c + i + 4) + vec_v_bot_low_high_0);
+        int32x4_t vec_v_bot_low_low_1 = vmovl_s16(vget_low_s16(vec_c[1]));
+        int32x4_t vec_v_bot_low_high_1 = vmovl_high_s16(vec_c[1]);
+        vst1q_s32(c + i + 8, vld1q_s32(c + i + 8) + vec_v_bot_low_low_1);
+        vst1q_s32(c + i + 12, vld1q_s32(c + i + 12) + vec_v_bot_low_high_1);
+        int32x4_t vec_v_bot_low_low_2 = vmovl_s16(vget_low_s16(vec_c[2]));
+        int32x4_t vec_v_bot_low_high_2 = vmovl_high_s16(vec_c[2]);
+        vst1q_s32(c + i + 16, vld1q_s32(c + i + 16) + vec_v_bot_low_low_2);
+        vst1q_s32(c + i + 20, vld1q_s32(c + i + 20) + vec_v_bot_low_high_2);
+        int32x4_t vec_v_bot_low_low_3 = vmovl_s16(vget_low_s16(vec_c[3]));
+        int32x4_t vec_v_bot_low_high_3 = vmovl_high_s16(vec_c[3]);
+        vst1q_s32(c + i + 24, vld1q_s32(c + i + 24) + vec_v_bot_low_low_3);
+        vst1q_s32(c + i + 28, vld1q_s32(c + i + 28) + vec_v_bot_low_high_3);
+
+    }
+#endif
+}
+
+int32_t qgemm_lut_1024_4096(void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) {
+    alignas(32) uint32_t CBits[BM1024_4096];
+    memset(&(CBits[0]), 0, BM1024_4096 * sizeof(int32_t));
+#pragma unroll
+    for (int32_t k_outer = 0; k_outer < 4096 / BBK1024_4096; ++k_outer) {
+        tbl_impl_1024_4096((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BBK1024_4096 / 2 * 32)])), (&(((uint8_t*)A)[(k_outer * BBK1024_4096 / 2 / 2 * BM1024_4096)])))
;
+    }
+#pragma unroll
+    for (int i = 0; i < BM1024_4096; i++) {
+        ((bitnet_float_type*)C)[i] = (((int32_t*)CBits)[i]) / ((bitnet_float_type*)LUT_Scales)[0] * ((bitnet_float_type*)Scales)[0];
+    }
+    return 0;
+};
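The header specializes kernels for exactly four (m, k) shapes, with tile sizes that mirror `include/kernel_config.ini` further down in this patch. For reference, the mapping could equally be captured in one table; this is a sketch written for this note, not code from the patch:

```cpp
// Illustration only: the four hard-coded TL1 shapes and their tile sizes,
// matching include/kernel_config.ini. A table like this could replace the
// if/else dispatch in ggml_preprocessor / ggml_qgemm_lut below.
struct tl1_kernel_shape {
    int m, k;    // weight matrix dimensions
    int bm, bk;  // tile sizes (BM*_* and BBK*_*)
};

static const tl1_kernel_shape tl1_shapes[] = {
    { 14336,  4096, 256, 128 },
    {  4096, 14336, 128,  64 },
    {  1024,  4096, 256, 128 },
    {  4096,  4096, 128,  64 },
};
```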
+#include <arm_neon.h>
+
+#define BM4096_4096 128
+#define BBK4096_4096 64
+inline void tbl_impl_4096_4096(int32_t* c, int8_t* lut, uint8_t* a) {
+#ifdef __ARM_NEON
+    const int KK = BBK4096_4096 / 2;
+    const uint8x16_t vec_mask = vdupq_n_u8(0x0f);
+    const int8x16_t vec_zero = vdupq_n_s16(0x0000);
+    int8x16_t vec_lut[2 * KK];
+    int16x8_t vec_c[8];
+#pragma unroll
+    for (int k = 0; k < 2 * KK; k++) {
+        vec_lut[k] = vld1q_s8(lut + k * 16);
+    }
+
+#pragma unroll
+    for (int i = 0; i < BM4096_4096; i += 64) {
+#pragma unroll
+        for (int i = 0; i < 8; i++) {
+            vec_c[i] = vandq_s16(vec_c[i], vec_zero);
+        }
+
+#pragma unroll
+        for (int k = 0; k < KK / 2; k++) {
+
+            uint8x16_t vec_a_0 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 0 * 16);
+            uint8x16_t vec_a0_top = vshrq_n_u8(vec_a_0, 4);
+            uint8x16_t vec_a0_bot = vandq_u8(vec_a_0, vec_mask);
+            int8x16_t vec_v_0_left_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 0], vec_a0_top);
+            int8x16_t vec_v_0_left_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 1], vec_a0_top);
+            int8x16_t vec_v_0_right_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 2], vec_a0_bot);
+            int8x16_t vec_v_0_right_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 3], vec_a0_bot);
+            int8x16x2_t vec_v_left_0 = vzipq_s8(vec_v_0_left_tmp1, vec_v_0_left_tmp0);
+            int8x16x2_t vec_v_right_0 = vzipq_s8(vec_v_0_right_tmp1, vec_v_0_right_tmp0);
+            vec_c[0] += vec_v_left_0.val[0];
+            vec_c[0] += vec_v_right_0.val[0];
+            vec_c[1] += vec_v_left_0.val[1];
+            vec_c[1] += vec_v_right_0.val[1];
+
+            uint8x16_t vec_a_1 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 1 * 16);
+            uint8x16_t vec_a1_top = vshrq_n_u8(vec_a_1, 4);
+            uint8x16_t vec_a1_bot = vandq_u8(vec_a_1, vec_mask);
+            int8x16_t vec_v_1_left_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 0], vec_a1_top);
+            int8x16_t vec_v_1_left_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 1], vec_a1_top);
+            int8x16_t vec_v_1_right_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 2], vec_a1_bot);
+            int8x16_t vec_v_1_right_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 3], vec_a1_bot);
+            int8x16x2_t vec_v_left_1 = vzipq_s8(vec_v_1_left_tmp1, vec_v_1_left_tmp0);
+            int8x16x2_t vec_v_right_1 = vzipq_s8(vec_v_1_right_tmp1, vec_v_1_right_tmp0);
+            vec_c[2] += vec_v_left_1.val[0];
+            vec_c[2] += vec_v_right_1.val[0];
+            vec_c[3] += vec_v_left_1.val[1];
+            vec_c[3] += vec_v_right_1.val[1];
+
+            uint8x16_t vec_a_2 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 2 * 16);
+            uint8x16_t vec_a2_top = vshrq_n_u8(vec_a_2, 4);
+            uint8x16_t vec_a2_bot = vandq_u8(vec_a_2, vec_mask);
+            int8x16_t vec_v_2_left_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 0], vec_a2_top);
+            int8x16_t vec_v_2_left_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 1], vec_a2_top);
+            int8x16_t vec_v_2_right_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 2], vec_a2_bot);
+            int8x16_t vec_v_2_right_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 3], vec_a2_bot);
+            int8x16x2_t vec_v_left_2 = vzipq_s8(vec_v_2_left_tmp1, vec_v_2_left_tmp0);
+            int8x16x2_t vec_v_right_2 = vzipq_s8(vec_v_2_right_tmp1, vec_v_2_right_tmp0);
+            vec_c[4] += vec_v_left_2.val[0];
+            vec_c[4] += vec_v_right_2.val[0];
+            vec_c[5] += vec_v_left_2.val[1];
+            vec_c[5] += vec_v_right_2.val[1];
+
+            uint8x16_t vec_a_3 = vld1q_u8(a + i * KK / 2 + k * 32 * 2 + 3 * 16);
+            uint8x16_t vec_a3_top = vshrq_n_u8(vec_a_3, 4);
+            uint8x16_t vec_a3_bot = vandq_u8(vec_a_3, vec_mask);
+            int8x16_t vec_v_3_left_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 0], vec_a3_top);
+            int8x16_t vec_v_3_left_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 1], vec_a3_top);
+            int8x16_t vec_v_3_right_tmp0 = vqtbl1q_s8(vec_lut[4 * k + 2], vec_a3_bot);
+            int8x16_t vec_v_3_right_tmp1 = vqtbl1q_s8(vec_lut[4 * k + 3], vec_a3_bot);
+            int8x16x2_t vec_v_left_3 = vzipq_s8(vec_v_3_left_tmp1, vec_v_3_left_tmp0);
+            int8x16x2_t vec_v_right_3 = vzipq_s8(vec_v_3_right_tmp1, vec_v_3_right_tmp0);
+            vec_c[6] += vec_v_left_3.val[0];
+            vec_c[6] += vec_v_right_3.val[0];
+            vec_c[7] += vec_v_left_3.val[1];
+            vec_c[7] += vec_v_right_3.val[1];
+
+        }
+
+        int32x4_t vec_v_bot_low_low_0 = vmovl_s16(vget_low_s16(vec_c[0]));
+        int32x4_t vec_v_bot_low_high_0 = vmovl_high_s16(vec_c[0]);
+        vst1q_s32(c + i + 0, vld1q_s32(c + i + 0) + vec_v_bot_low_low_0);
+        vst1q_s32(c + i + 4, vld1q_s32(c + i + 4) + vec_v_bot_low_high_0);
+        int32x4_t vec_v_bot_low_low_1 = vmovl_s16(vget_low_s16(vec_c[1]));
+        int32x4_t vec_v_bot_low_high_1 = vmovl_high_s16(vec_c[1]);
+        vst1q_s32(c + i + 8, vld1q_s32(c + i + 8) + vec_v_bot_low_low_1);
+        vst1q_s32(c + i + 12, vld1q_s32(c + i + 12) + vec_v_bot_low_high_1);
+        int32x4_t vec_v_bot_low_low_2 = vmovl_s16(vget_low_s16(vec_c[2]));
+        int32x4_t vec_v_bot_low_high_2 = vmovl_high_s16(vec_c[2]);
+        vst1q_s32(c + i + 16, vld1q_s32(c + i + 16) + vec_v_bot_low_low_2);
+        vst1q_s32(c + i + 20, vld1q_s32(c + i + 20) + vec_v_bot_low_high_2);
+        int32x4_t vec_v_bot_low_low_3 = vmovl_s16(vget_low_s16(vec_c[3]));
+        int32x4_t vec_v_bot_low_high_3 = vmovl_high_s16(vec_c[3]);
+        vst1q_s32(c + i + 24, vld1q_s32(c + i + 24) + vec_v_bot_low_low_3);
+        vst1q_s32(c + i + 28, vld1q_s32(c + i + 28) + vec_v_bot_low_high_3);
+        int32x4_t vec_v_bot_low_low_4 = vmovl_s16(vget_low_s16(vec_c[4]));
+        int32x4_t vec_v_bot_low_high_4 = vmovl_high_s16(vec_c[4]);
+        vst1q_s32(c + i + 32, vld1q_s32(c + i + 32) + vec_v_bot_low_low_4);
+        vst1q_s32(c + i + 36, vld1q_s32(c + i + 36) + vec_v_bot_low_high_4);
+        int32x4_t vec_v_bot_low_low_5 = vmovl_s16(vget_low_s16(vec_c[5]));
+        int32x4_t vec_v_bot_low_high_5 = vmovl_high_s16(vec_c[5]);
+        vst1q_s32(c + i + 40, vld1q_s32(c + i + 40) + vec_v_bot_low_low_5);
+        vst1q_s32(c + i + 44, vld1q_s32(c + i + 44) + vec_v_bot_low_high_5);
+        int32x4_t vec_v_bot_low_low_6 = vmovl_s16(vget_low_s16(vec_c[6]));
+        int32x4_t vec_v_bot_low_high_6 = vmovl_high_s16(vec_c[6]);
+        vst1q_s32(c + i + 48, vld1q_s32(c + i + 48) + vec_v_bot_low_low_6);
+        vst1q_s32(c + i + 52, vld1q_s32(c + i + 52) + vec_v_bot_low_high_6);
+        int32x4_t vec_v_bot_low_low_7 = vmovl_s16(vget_low_s16(vec_c[7]));
+        int32x4_t vec_v_bot_low_high_7 = vmovl_high_s16(vec_c[7]);
+        vst1q_s32(c + i + 56, vld1q_s32(c + i + 56) + vec_v_bot_low_low_7);
+        vst1q_s32(c + i + 60, vld1q_s32(c + i + 60) + vec_v_bot_low_high_7);
+
+    }
+#endif
+}
+
+int32_t qgemm_lut_4096_4096(void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) {
+    alignas(32) uint32_t CBits[BM4096_4096];
+    memset(&(CBits[0]), 0, BM4096_4096 * sizeof(int32_t));
+#pragma unroll
+    for (int32_t k_outer = 0; k_outer < 4096 / BBK4096_4096; ++k_outer) {
+        tbl_impl_4096_4096((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BBK4096_4096 / 2 * 32)])), (&(((uint8_t*)A)[(k_outer * BBK4096_4096 / 2 / 2 * BM4096_4096)])));
+    }
+#pragma unroll
+    for (int i = 0; i < BM4096_4096; i++) {
+        ((bitnet_float_type*)C)[i] = (((int32_t*)CBits)[i]) / ((bitnet_float_type*)LUT_Scales)[0] * ((bitnet_float_type*)Scales)[0];
+    }
+    return 0;
+};
+
+template<int K>
+void preprocessor_k(void* B, void* LUT_Scales, void* QLUT) {
+    partial_max_reset((&(((bitnet_float_type*)LUT_Scales)[0])));
+    per_tensor_quant(K, (&(((bitnet_float_type*)LUT_Scales)[0])), (&(((bitnet_float_type*)B)[0])));
+
+    lut_ctor<K>((&(((int8_t*)QLUT)[0])), (&(((bitnet_float_type*)B)[0])), (&(((bitnet_float_type*)LUT_Scales)[0])));
+}
+void ggml_preprocessor(int m, int k, void* B, void* LUT_Scales, void* QLUT) {
+    if (m == 14336 && k == 4096) {
+        preprocessor_k<4096>(B, LUT_Scales, QLUT);
+    }
+    else if (m == 4096 && k == 14336) {
+        preprocessor_k<14336>(B, LUT_Scales, QLUT);
+    }
+    else if (m == 1024 && k == 4096) {
+        preprocessor_k<4096>(B, LUT_Scales, QLUT);
+    }
+    else if (m == 4096 && k == 4096) {
+        preprocessor_k<4096>(B, LUT_Scales, QLUT);
+    }
+}
+void ggml_qgemm_lut(int m, int k, void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) {
+    if (m == 14336 && k == 4096) {
+        qgemm_lut_14336_4096(A, LUT, Scales, LUT_Scales, C);
+    }
+    else if (m == 4096 && k == 14336) {
+        qgemm_lut_4096_14336(A, LUT, Scales, LUT_Scales, C);
+    }
+    else if (m == 1024 && k == 4096) {
+        qgemm_lut_1024_4096(A, LUT, Scales, LUT_Scales, C);
+    }
+    else if (m == 4096 && k == 4096) {
+        qgemm_lut_4096_4096(A, LUT, Scales, LUT_Scales, C);
+    }
+}
+
+void ggml_bitnet_transform_tensor(struct ggml_tensor * tensor) {
+    if (!(is_type_supported(tensor->type) && tensor->backend == GGML_BACKEND_TYPE_CPU && tensor->extra == nullptr)) {
+        return;
+    }
+
+    int k = tensor->ne[0];
+    int m = tensor->ne[1];
+    const int lut_scales_size = 1;
+    const int scales_size = 1;
+    int bk = 0;
+    int bm = 0;
+
+    if (m == 14336 && k == 4096) {
+        bm = BM14336_4096;
+        bk = BBK14336_4096;
+    }
+    else if (m == 4096 && k == 14336) {
+        bm = BM4096_14336;
+        bk = BBK4096_14336;
+    }
+    else if (m == 1024 && k == 4096) {
+        bm = BM1024_4096;
+        bk = BBK1024_4096;
+    }
+    else if (m == 4096 && k == 4096) {
+        bm = BM4096_4096;
+        bk = BBK4096_4096;
+    }
+
+    const int n_tile_num = m / bm;
+    const int BK = bk;
+    uint8_t * qweights;
+    bitnet_float_type * scales;
+
+    scales = (bitnet_float_type *) aligned_malloc(sizeof(bitnet_float_type));
+    qweights = (uint8_t *) tensor->data;
+    float * i2_scales = (float *)(qweights + k * m / 4);
+    scales[0] = (bitnet_float_type) i2_scales[0];
+
+    tensor->extra = bitnet_tensor_extras + bitnet_tensor_extras_index;
+    bitnet_tensor_extras[bitnet_tensor_extras_index++] = {
+        /* .lut_scales_size = */ lut_scales_size,
+        /* .BK              = */ BK,
+        /* .n_tile_num      = */ n_tile_num,
+        /* .qweights        = */ qweights,
+        /* .scales          = */ scales
+    };
+}
+#endif
\ No newline at end of file
diff --git a/include/kernel_config.ini b/include/kernel_config.ini
new file mode 100644
index 0000000..01eed07
--- /dev/null
+++ b/include/kernel_config.ini
@@ -0,0 +1,28 @@
+[Kernels_0]
+m = 14336
+k = 4096
+bm = 256
+bk = 128
+bmm = 32
+
+[Kernels_1]
+m = 4096
+k = 14336
+bm = 128
+bk = 64
+bmm = 64
+
+[Kernels_2]
+m = 1024
+k = 4096
+bm = 256
+bk = 128
+bmm = 32
+
+[Kernels_3]
+m = 4096
+k = 4096
+bm = 128
+bk = 64
+bmm = 64
+
diff --git a/setup_env.py b/setup_env.py
index 8a9c4b4..a9e93f3 100644
--- a/setup_env.py
+++ b/setup_env.py
@@ -172,11 +172,12 @@ def compile():
     # run_command(["cmake", "--build", "build", "--target", "llama-cli", "--config", "Release"])
     run_command(["cmake", "--build", "build", "--config", "Release"], log_step="compile")
 
+# Only run code generation here; do not download ggml, compile the binaries, or prepare the model.
 def main():
-    setup_gguf()
+    # setup_gguf()
     gen_code()
-    compile()
-    prepare_model()
+    # compile()
+    # prepare_model()
 
 def parse_args():
     _, arch = system_info()
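Taken together, a TL1 matmul under this patch is a two-step call sequence: `ggml_preprocessor` quantizes the activations and builds the lookup table, then `ggml_qgemm_lut` consumes it. The sketch below shows that flow for the 4096×4096 case; the LUT buffer sizing (k/2 activation pairs × 16 int16 entries = 16·k bytes) and the single-tile scope of each call (one BM-row tile, with the integration layer looping over `n_tile_num` tiles) are my assumptions, not code from the patch.

```cpp
#include <cstdint>
#include <vector>

// Hypothetical driver for one TL1 matmul tile (m = k = 4096); illustration only.
void tl1_matmul_example(uint8_t* packed_weights,          // 2-bit packed tile
                        bitnet_float_type* activations,   // k values
                        bitnet_float_type* weight_scale,  // per-tensor scale
                        bitnet_float_type* out) {         // one BM-row tile
    const int m = 4096, k = 4096;
    std::vector<int8_t> qlut(k * 16);   // quantized LUT (size assumed)
    bitnet_float_type lut_scale = 0;
    // 1) quantize activations and precompute the lookup table
    ggml_preprocessor(m, k, activations, &lut_scale, qlut.data());
    // 2) one tile of LUT-based GEMM; callers advance the weight/output
    //    pointers and repeat for each of the m / BM4096_4096 tiles
    ggml_qgemm_lut(m, k, packed_weights, qlut.data(), weight_scale, &lut_scale, out);
}
```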