// Low-level (BLAS-like) primitives. #pragma once #include "ctranslate2/devices.h" #include "ctranslate2/types.h" namespace ctranslate2 { template struct primitives { template static T at(const T* x, dim_t index); template static void fill(T* x, T a, dim_t size); template static void strided_fill(T* x, T a, dim_t inc_x, dim_t size); template static void indexed_fill(T* x, T a, const int32_t* indices, dim_t num_indices); template static void copy(const T* x, T* y, dim_t size); template static void convert(const U* x, V* y, dim_t size); template static T sum(const T* array, dim_t size); template static T mean(const T* array, dim_t size) { return sum(array, size) / size; } template static dim_t max_element(const T* array, dim_t size); template static T max(const T* array, dim_t size); template static T amax(const T* array, dim_t size); template static void add(T a, const T* x, T* y, dim_t size); template static void add(T a, T* y, dim_t size) { add(a, y, y, size); } template static void add(const T* a, const T* b, T* c, dim_t size); template static void add(const T* x, T* y, dim_t size) { add(x, y, y, size); } template static void add_batch_broadcast(const T* a, const T* b, T* c, dim_t a_size, dim_t b_size); template static void add_batch_broadcast(const T* x, T* y, dim_t x_size, dim_t y_size) { add_batch_broadcast(x, y, y, x_size, y_size); } template static void add_depth_broadcast(const T* a, const T* b, T* c, dim_t a_size, dim_t b_size); template static void add_depth_broadcast(const T* x, T* y, dim_t x_size, dim_t y_size) { add_depth_broadcast(x, y, y, x_size, y_size); } template static void sub(T a, const T* x, T* y, dim_t size) { T a_rev = -a; add(a_rev, x, y, size); } template static void sub(T a, T* y, dim_t size) { sub(a, y, y, size); } template static void sub(const T* a, const T* b, T* c, dim_t size); template static void max(T a, const T* x, T* y, dim_t size); template static void max(const T* a, const T* b, T* c, dim_t size); template static void max(T a, T* y, dim_t size) { max(a, y, y, size); } template static void min(T a, const T* x, T* y, dim_t size); template static void min(const T* a, const T* b, T* c, dim_t size); template static void min(T a, T* y, dim_t size) { min(a, y, y, size); } template static void mul(T a, const T* x, T* y, dim_t size); template static void mul(T a, T* y, dim_t size) { mul(a, y, y, size); } template static void mul_batch_broadcast(const T* a, const T* b, T* c, dim_t a_size, dim_t b_size); template static void mul_batch_broadcast(const T* x, T* y, dim_t x_size, dim_t y_size) { mul_batch_broadcast(x, y, y, x_size, y_size); } template static void mul(const T* a, const T* b, T* c, dim_t size); template static void mul(const T* x, T* y, dim_t size) { mul(x, y, y, size); } template static void penalize_previous_tokens(T* scores, const T* previous_scores, const int32_t* previous_ids, T penalty, dim_t batch_size, dim_t length, dim_t vocabulary_size); static void prepare_length_mask(const int32_t* lengths, dim_t batch_size, dim_t num_heads, dim_t num_queries, bool mask_future, bool multi_query, int32_t* mask); template static void transpose_2d(const T* a, const dim_t* dims, T* b); template static void transpose_3d(const T* a, const dim_t* dims, const dim_t* perm, T* b); template static void transpose_4d(const T* a, const dim_t* dims, const dim_t* perm, T* b); template static float logsumexp(const T* x, dim_t size); template static void exp(const T* x, T* y, dim_t size); template static void log(const T* x, T* y, dim_t size); template static void cos(const T* x, T* y, dim_t size); template static void sin(const T* x, T* y, dim_t size); template static void tanh(const T* x, T* y, dim_t size); template static void relu(const T* x, T* y, dim_t size); template static void gelu(const T* x, T* y, dim_t size); template static void gelu_tanh(const T* x, T* y, dim_t size); template static void gelu_sigmoid(const T* x, T* y, dim_t size); template static void swish(const T* x, T* y, dim_t size); static void compute_u8_compensation(const int8_t* b, bool transpose_b, dim_t k, dim_t n, float alpha, int32_t* compensation); // If dest is not passed, returns the number of bytes required to store the packed data, // or 0 if packing is not supported. template static dim_t gemm_pack_b(const T* b, const bool transpose_b, const dim_t k, const dim_t n, const float alpha, T* dest = nullptr); template static void gemm(bool a_is_packed, bool b_is_packed, bool transpose_a, bool transpose_b, dim_t m, dim_t n, dim_t k, float alpha, const In* a, dim_t lda, const In* b, dim_t ldb, float beta, Out* c, dim_t ldc, const Out* a_shift_compensation = nullptr); template static void gemm_batch_strided(bool transpose_a, bool transpose_b, dim_t m, dim_t n, dim_t k, float alpha, const In* a, dim_t lda, dim_t stridea, const In* b, dim_t ldb, dim_t strideb, float beta, Out* c, dim_t ldc, dim_t stridec, dim_t batch_size); }; template struct cross_device_primitives { template static void copy(const T* x, T* y, dim_t size); }; }