#pragma once #include "windows_customizations.h" #include #include namespace diskann { enum Metric { L2 = 0, INNER_PRODUCT = 1, COSINE = 2, FAST_L2 = 3 }; template class Distance { public: DISKANN_DLLEXPORT Distance(diskann::Metric dist_metric) : _distance_metric(dist_metric) { } // distance comparison function DISKANN_DLLEXPORT virtual float compare(const T *a, const T *b, uint32_t length) const = 0; // Needed only for COSINE-BYTE and INNER_PRODUCT-BYTE DISKANN_DLLEXPORT virtual float compare(const T *a, const T *b, const float normA, const float normB, uint32_t length) const; // For MIPS, normalization adds an extra dimension to the vectors. // This function lets callers know if the normalization process // changes the dimension. DISKANN_DLLEXPORT virtual uint32_t post_normalization_dimension(uint32_t orig_dimension) const; DISKANN_DLLEXPORT virtual diskann::Metric get_metric() const; // This is for efficiency. If no normalization is required, the callers // can simply ignore the normalize_data_for_build() function. DISKANN_DLLEXPORT virtual bool preprocessing_required() const; // Check the preprocessing_required() function before calling this. // Clients can call the function like this: // // if (metric->preprocessing_required()){ // T* normalized_data_batch; // Split data into batches of batch_size and for each, call: // metric->preprocess_base_points(data_batch, batch_size); // // TODO: This does not take into account the case for SSD inner product // where the dimensions change after normalization. DISKANN_DLLEXPORT virtual void preprocess_base_points(T *original_data, const size_t orig_dim, const size_t num_points); // Invokes normalization for a single vector during search. The scratch space // has to be created by the caller keeping track of the fact that // normalization might change the dimension of the query vector. DISKANN_DLLEXPORT virtual void preprocess_query(const T *query_vec, const size_t query_dim, T *scratch_query); // If an algorithm has a requirement that some data be aligned to a certain // boundary it can use this function to indicate that requirement. Currently, // we are setting it to 8 because that works well for AVX2. If we have AVX512 // implementations of distance algos, they might have to set this to 16 // (depending on how they are implemented) DISKANN_DLLEXPORT virtual size_t get_required_alignment() const; // Providing a default implementation for the virtual destructor because we // don't expect most metric implementations to need it. DISKANN_DLLEXPORT virtual ~Distance() = default; protected: diskann::Metric _distance_metric; size_t _alignment_factor = 8; }; class DistanceCosineInt8 : public Distance { public: DistanceCosineInt8() : Distance(diskann::Metric::COSINE) { } DISKANN_DLLEXPORT virtual float compare(const int8_t *a, const int8_t *b, uint32_t length) const override; }; class DistanceL2Int8 : public Distance { public: DistanceL2Int8() : Distance(diskann::Metric::L2) { } DISKANN_DLLEXPORT virtual float compare(const int8_t *a, const int8_t *b, uint32_t size) const override; }; // AVX implementations. Borrowed from HNSW code. class AVXDistanceL2Int8 : public Distance { public: AVXDistanceL2Int8() : Distance(diskann::Metric::L2) { } DISKANN_DLLEXPORT virtual float compare(const int8_t *a, const int8_t *b, uint32_t length) const override; }; class DistanceCosineFloat : public Distance { public: DistanceCosineFloat() : Distance(diskann::Metric::COSINE) { } DISKANN_DLLEXPORT virtual float compare(const float *a, const float *b, uint32_t length) const override; }; class DistanceL2Float : public Distance { public: DistanceL2Float() : Distance(diskann::Metric::L2) { } #ifdef _WINDOWS DISKANN_DLLEXPORT virtual float compare(const float *a, const float *b, uint32_t size) const; #else DISKANN_DLLEXPORT virtual float compare(const float *a, const float *b, uint32_t size) const override __attribute__((hot)); #endif }; class AVXDistanceL2Float : public Distance { public: AVXDistanceL2Float() : Distance(diskann::Metric::L2) { } DISKANN_DLLEXPORT virtual float compare(const float *a, const float *b, uint32_t length) const override; }; template class SlowDistanceL2 : public Distance { public: SlowDistanceL2() : Distance(diskann::Metric::L2) { } DISKANN_DLLEXPORT virtual float compare(const T *a, const T *b, uint32_t length) const override; }; class SlowDistanceCosineUInt8 : public Distance { public: SlowDistanceCosineUInt8() : Distance(diskann::Metric::COSINE) { } DISKANN_DLLEXPORT virtual float compare(const uint8_t *a, const uint8_t *b, uint32_t length) const override; }; class DistanceL2UInt8 : public Distance { public: DistanceL2UInt8() : Distance(diskann::Metric::L2) { } DISKANN_DLLEXPORT virtual float compare(const uint8_t *a, const uint8_t *b, uint32_t size) const override; }; template class DistanceInnerProduct : public Distance { public: DistanceInnerProduct() : Distance(diskann::Metric::INNER_PRODUCT) { } DistanceInnerProduct(diskann::Metric metric) : Distance(metric) { } inline float inner_product(const T *a, const T *b, unsigned size) const; inline float compare(const T *a, const T *b, unsigned size) const override { float result = inner_product(a, b, size); // if (result < 0) // return std::numeric_limits::max(); // else return -result; } }; template class DistanceFastL2 : public DistanceInnerProduct { // currently defined only for float. // templated for future use. public: DistanceFastL2() : DistanceInnerProduct(diskann::Metric::FAST_L2) { } float norm(const T *a, unsigned size) const; float compare(const T *a, const T *b, float norm, unsigned size) const; }; class AVXDistanceInnerProductFloat : public Distance { public: AVXDistanceInnerProductFloat() : Distance(diskann::Metric::INNER_PRODUCT) { } DISKANN_DLLEXPORT virtual float compare(const float *a, const float *b, uint32_t length) const override; }; class AVXNormalizedCosineDistanceFloat : public Distance { private: AVXDistanceInnerProductFloat _innerProduct; protected: void normalize_and_copy(const float *a, uint32_t length, float *a_norm) const; public: AVXNormalizedCosineDistanceFloat() : Distance(diskann::Metric::COSINE) { } DISKANN_DLLEXPORT virtual float compare(const float *a, const float *b, uint32_t length) const override { // Inner product returns negative values to indicate distance. // This will ensure that cosine is between -1 and 1. return 1.0f + _innerProduct.compare(a, b, length); } DISKANN_DLLEXPORT virtual uint32_t post_normalization_dimension(uint32_t orig_dimension) const override; DISKANN_DLLEXPORT virtual bool preprocessing_required() const override; DISKANN_DLLEXPORT virtual void preprocess_base_points(float *original_data, const size_t orig_dim, const size_t num_points) override; DISKANN_DLLEXPORT virtual void preprocess_query(const float *query_vec, const size_t query_dim, float *scratch_query_vector) override; }; class VsagDistanceL2Float : public Distance { public: VsagDistanceL2Float(size_t dimension); DISKANN_DLLEXPORT virtual float compare(const float *a, const float *b, uint32_t size) const override; private: float(*dist_func_)(const void *, const void *, const void *); }; class VsagDistanceInnerProductFloat : public Distance { public: VsagDistanceInnerProductFloat(size_t dimension); DISKANN_DLLEXPORT virtual float compare(const float *a, const float *b, uint32_t size) const override; private: float(*dist_func_)(const void *, const void *, const void *); }; template Distance *get_distance_function(Metric m); } // namespace diskann