#include "c_api.h"

// NOTE(review): the angle-bracket #include targets were not legible in the reviewed copy of this
// file; the list below is reconstructed from the names this translation unit uses. Confirm it
// against the build before merging.
#include <catboost/libs/cat_feature/cat_feature.h>
#include <catboost/libs/data/data_provider_builders.h>
#include <catboost/libs/helpers/exception.h>
#include <catboost/libs/model/model.h>

#include <library/cpp/threading/local_executor/local_executor.h>

#include <util/generic/array_ref.h>
#include <util/generic/ptr.h>
#include <util/generic/singleton.h>
#include <util/generic/strbuf.h>
#include <util/generic/string.h>
#include <util/generic/vector.h>
#include <util/generic/yexception.h>
#include <util/string/builder.h>
#include <util/string/cast.h>
#include <util/system/types.h>

#include <numeric>
#include <utility>

/// Heap-allocated payload behind an opaque ModelCalcerHandle; owns the model it evaluates.
struct TModelHandleContent {
    THolder<TFullModel> FullModel;
};

#define MODEL_HANDLE_CONTENT_PTR(x) ((TModelHandleContent*)(x))
#define FULL_MODEL_PTR(x) (MODEL_HANDLE_CONTENT_PTR(x)->FullModel)
#define EVALUATOR_PTR(x) (MODEL_HANDLE_CONTENT_PTR(x)->FullModel->GetCurrentEvaluator())
#define DATA_WRAPPER_PTR(x) ((TFeaturesDataWrapper*)(x))

/// Stores the message of the last exception caught by an entry point of this C API;
/// GetErrorString() returns it.
/// NOTE(review): this is a single Singleton instance, so failures on different threads may
/// overwrite each other's message.
struct TErrorMessageHolder {
    TString Message;
};

/// Collects caller-owned, per-feature column arrays and converts them into an NCB data provider.
/// The Add*Features methods store only the pointers they are given, so every array passed in must
/// stay alive until BuildDataProvider() has been called.
class TFeaturesDataWrapper {
public:
    explicit TFeaturesDataWrapper(size_t docsCount)
        : DocsCount(docsCount)
    {
    }

    // dim : floatFeaturesSize x docsCount
    void AddFloatFeatures(const float** floatFeatures, size_t floatFeaturesSize) {
        FloatFeatures.emplace_back(floatFeatures, floatFeaturesSize);
    }

    // dim : catFeaturesSize x docsCount (catFeatures[feature][doc] is a C string)
    void AddCatFeatures(const char*** catFeatures, size_t catFeaturesSize) {
        CatFeatures.emplace_back(catFeatures, catFeaturesSize);
    }

    // dim : textFeaturesSize x docsCount (textFeatures[feature][doc] is a C string)
    void AddTextFeatures(const char*** textFeatures, size_t textFeaturesSize) {
        TextFeatures.emplace_back(textFeatures, textFeaturesSize);
    }

    /// Builds a data provider from every feature batch added so far.
    ///
    /// Flat feature indices are assigned in this order: all float features (in the order they
    /// were added), then all categorical features, then all text features. The result is also
    /// kept in the DataProvider member. The raw pointer handed out by the C API therefore refers
    /// to an object this wrapper keeps a reference to until it is destroyed or this method is
    /// called again.
    NCB::TDataProviderPtr BuildDataProvider() {
        size_t floatFeaturesCount = 0;
        size_t catFeaturesCount = 0;
        size_t textFeaturesCount = 0;
        for (auto [_, count] : FloatFeatures) {
            floatFeaturesCount += count;
        }
        for (auto [_, count] : CatFeatures) {
            catFeaturesCount += count;
        }
        for (auto [_, count] : TextFeatures) {
            textFeaturesCount += count;
        }

        // Categorical features come right after the float block; text features follow both.
        TVector<ui32> catFeaturesIndices(catFeaturesCount);
        std::iota(catFeaturesIndices.begin(), catFeaturesIndices.end(), static_cast<ui32>(floatFeaturesCount));
        TVector<ui32> textFeaturesIndices(textFeaturesCount);
        std::iota(
            textFeaturesIndices.begin(),
            textFeaturesIndices.end(),
            static_cast<ui32>(floatFeaturesCount + catFeaturesCount));

        NCB::TDataMetaInfo metaInfo;
        metaInfo.TargetType = NCB::ERawTargetType::Float;
        metaInfo.TargetCount = 1;
        metaInfo.FeaturesLayout = MakeIntrusive<NCB::TFeaturesLayout>(
            (ui32)(floatFeaturesCount + catFeaturesCount + textFeaturesCount),
            catFeaturesIndices,
            textFeaturesIndices,
            TVector<ui32>{},
            TVector<TString>{}
        );

        NCB::TDataProviderClosure dataProviderClosure(
            NCB::EDatasetVisitorType::RawFeaturesOrder,
            NCB::TDataProviderBuilderOptions(),
            &NPar::LocalExecutor()
        );
        auto* visitor = dataProviderClosure.GetVisitor<NCB::IRawFeaturesOrderDataVisitor>();
        CB_ENSURE(visitor);
        visitor->Start(metaInfo, DocsCount, NCB::EObjectsOrder::Undefined, {});

        {
            ui32 addedFloatFeaturesCount = 0;
            for (auto [arr, count] : FloatFeatures) {
                for (size_t i = 0; i < count; ++i, ++arr, ++addedFloatFeaturesCount) {
                    const float* column = *arr;
                    // The column is copied, so the caller's float buffer is not referenced afterwards.
                    visitor->AddFloatFeature(
                        addedFloatFeaturesCount,
                        MakeIntrusive<NCB::TTypeCastArrayHolder<float, float>>(
                            TVector<float>(column, column + DocsCount))
                    );
                }
            }
        }
        {
            // Strings are materialized into a member so the storage outlives this call.
            CatFeaturesVec.assign(catFeaturesIndices.size(), TVector<TString>(DocsCount));
            ui32 addedCatFeaturesCount = 0;
            for (auto [arr, count] : CatFeatures) {
                for (size_t i = 0; i < count; ++i, ++addedCatFeaturesCount) {
                    for (size_t d = 0; d < DocsCount; ++d) {
                        CatFeaturesVec[addedCatFeaturesCount][d] = arr[i][d];
                    }
                    visitor->AddCatFeature(
                        catFeaturesIndices[addedCatFeaturesCount],
                        CatFeaturesVec[addedCatFeaturesCount]
                    );
                }
            }
        }
        {
            TextFeaturesVec.assign(textFeaturesIndices.size(), TVector<TString>(DocsCount));
            ui32 addedTextFeaturesCount = 0;
            for (auto [arr, count] : TextFeatures) {
                for (size_t i = 0; i < count; ++i, ++addedTextFeaturesCount) {
                    for (size_t d = 0; d < DocsCount; ++d) {
                        TextFeaturesVec[addedTextFeaturesCount][d] = arr[i][d];
                    }
                    visitor->AddTextFeature(
                        textFeaturesIndices[addedTextFeaturesCount],
                        TextFeaturesVec[addedTextFeaturesCount]
                    );
                }
            }
        }

        visitor->Finish();

        DataProvider = dataProviderClosure.GetResult();
        return DataProvider;
    }

private:
    // (column-pointer array, number of columns) as passed by the caller; not owned.
    TVector<std::pair<const float**, size_t>> FloatFeatures;
    TVector<std::pair<const char***, size_t>> CatFeatures;
    TVector<std::pair<const char***, size_t>> TextFeatures;
    // Owned string copies of categorical / text columns, indexed by feature then document.
    TVector<TVector<TString>> CatFeaturesVec;
    TVector<TVector<TString>> TextFeaturesVec;
    NCB::TDataProviderPtr DataProvider;
    size_t DocsCount = 0;
};

namespace {
    /// Copies column `classId` of a row-major (docCount x dim) prediction matrix into `result`,
    /// one value per document, for the first result.size() documents.
    ///
    /// Throws (via CB_ENSURE) instead of reading out of bounds when `classId` is not in [0, dim)
    /// or when `result` asks for more documents than `predictions` contains. Every caller invokes
    /// this inside a try block, so the error reaches GetErrorString().
    void GetSpecificClass(int classId, TConstArrayRef<double> predictions, size_t dim, TArrayRef<double> result) {
        CB_ENSURE(
            classId >= 0 && static_cast<size_t>(classId) < dim,
            "classId " << classId << " is out of range [0, " << dim << ")");
        CB_ENSURE(
            result.size() <= predictions.size() / dim,
            "result buffer has " << result.size() << " elements but only "
                << predictions.size() / dim << " documents were predicted");
        for (size_t docId = 0; docId < result.size(); ++docId) {
            result[docId] = predictions[docId * dim + classId];
        }
    }
} // namespace

extern "C" {
    // Error convention: fallible entry points catch every exception, store its message (see
    // GetErrorString) and report failure through their return value (nullptr / false).

    /// Creates an empty feature wrapper for `docsCount` documents. Returns nullptr on failure.
    /// The caller owns the handle and must release it with DataWrapperDelete().
    CATBOOST_API DataWrapperHandle* DataWrapperCreate(size_t docsCount) {
        try {
            return new TFeaturesDataWrapper(docsCount);
        } catch (...) {
            Singleton<TErrorMessageHolder>()->Message = CurrentExceptionMessage();
        }
        return nullptr;
    }

    /// Destroys a wrapper created by DataWrapperCreate(); a null handle is ignored.
    CATBOOST_API void DataWrapperDelete(DataWrapperHandle* dataWrapperHandle) {
        if (dataWrapperHandle != nullptr) {
            delete DATA_WRAPPER_PTR(dataWrapperHandle);
        }
    }

    /// Registers `floatFeaturesSize` float columns (each `docsCount` long). Only the pointers are
    /// stored; the arrays must outlive the next BuildDataProvider() call.
    CATBOOST_API void AddFloatFeatures(DataWrapperHandle* dataWrapperHandle, const float** floatFeatures, size_t floatFeaturesSize) {
        DATA_WRAPPER_PTR(dataWrapperHandle)->AddFloatFeatures(floatFeatures, floatFeaturesSize);
    }

    /// Registers `catFeaturesSize` categorical string columns; the pointers are stored, not copied.
    CATBOOST_API void AddCatFeatures(DataWrapperHandle* dataWrapperHandle, const char*** catFeatures, size_t catFeaturesSize) {
        DATA_WRAPPER_PTR(dataWrapperHandle)->AddCatFeatures(catFeatures, catFeaturesSize);
    }

    /// Registers `textFeaturesSize` text string columns; the pointers are stored, not copied.
    CATBOOST_API void AddTextFeatures(DataWrapperHandle* dataWrapperHandle, const char*** textFeatures, size_t textFeaturesSize) {
        DATA_WRAPPER_PTR(dataWrapperHandle)->AddTextFeatures(textFeatures, textFeaturesSize);
    }

    /// Builds a data provider from the wrapper's features. Returns nullptr on failure (message via
    /// GetErrorString()) instead of letting the exception escape this C entry point.
    /// The returned pointer is NOT owned by the caller: the wrapper keeps a reference to the
    /// provider until it is deleted or BuildDataProvider() is called on it again.
    CATBOOST_API DataProviderHandle* BuildDataProvider(DataWrapperHandle* dataWrapperHandle) {
        try {
            return DATA_WRAPPER_PTR(dataWrapperHandle)->BuildDataProvider().Get();
        } catch (...) {
            Singleton<TErrorMessageHolder>()->Message = CurrentExceptionMessage();
        }
        return nullptr;
    }

    /// Creates a handle holding an empty model. Returns nullptr on failure.
    /// The caller owns the handle and must release it with ModelCalcerDelete().
    CATBOOST_API ModelCalcerHandle* ModelCalcerCreate() {
        try {
            // Create the owning holder first. The handle's allocation runs before its initializer,
            // so a naked `new TFullModel` would leak if that allocation threw.
            THolder<TFullModel> fullModel = MakeHolder<TFullModel>();
            return new TModelHandleContent{.FullModel = std::move(fullModel)};
        } catch (...) {
            Singleton<TErrorMessageHolder>()->Message = CurrentExceptionMessage();
        }
        return nullptr;
    }

    /// Returns the message stored by the most recent failing call (empty if none has failed).
    CATBOOST_API const char* GetErrorString() {
        return Singleton<TErrorMessageHolder>()->Message.data();
    }

    /// Destroys a handle created by ModelCalcerCreate(); a null handle is ignored.
    CATBOOST_API void ModelCalcerDelete(ModelCalcerHandle* modelHandle) {
        if (modelHandle != nullptr) {
            delete MODEL_HANDLE_CONTENT_PTR(modelHandle);
        }
    }

    /// Replaces the handle's model with the one read from `filename`. Returns false on failure.
    CATBOOST_API bool LoadFullModelFromFile(ModelCalcerHandle* modelHandle, const char* filename) {
        try {
            *FULL_MODEL_PTR(modelHandle) = ReadModel(filename);
        } catch (...) {
            Singleton<TErrorMessageHolder>()->Message = CurrentExceptionMessage();
            return false;
        }
        return true;
    }

    /// Replaces the handle's model with one deserialized from an in-memory buffer.
    CATBOOST_API bool LoadFullModelFromBuffer(ModelCalcerHandle* modelHandle, const void* binaryBuffer, size_t binaryBufferSize) {
        try {
            *FULL_MODEL_PTR(modelHandle) = ReadModel(binaryBuffer, binaryBufferSize);
        } catch (...) {
            Singleton<TErrorMessageHolder>()->Message = CurrentExceptionMessage();
            return false;
        }
        return true;
    }

    /// Switches the model to the GPU evaluator. Only deviceId == 0 is accepted.
    CATBOOST_API bool EnableGPUEvaluation(ModelCalcerHandle* modelHandle, int deviceId) {
        try {
            //TODO(kirillovs): fix this after adding set evaluator props interface
            CB_ENSURE(deviceId == 0, "FIXME: Only device 0 is supported for now");
            FULL_MODEL_PTR(modelHandle)->SetEvaluatorType(EFormulaEvaluatorType::GPU);
        } catch (...) {
            Singleton<TErrorMessageHolder>()->Message = CurrentExceptionMessage();
            return false;
        }
        return true;
    }

    /// Sets the model's prediction type from the C enum. The value is cast directly, so the two
    /// enums must stay numerically in sync.
    CATBOOST_API bool SetPredictionType(ModelCalcerHandle* modelHandle, EApiPredictionType predictionType) {
        try {
            FULL_MODEL_PTR(modelHandle)->SetPredictionType(static_cast<EPredictionType>(predictionType));
        } catch (...) {
            Singleton<TErrorMessageHolder>()->Message = CurrentExceptionMessage();
            return false;
        }
        return true;
    }

    /// Sets the model's prediction type from its string name (parsed with FromString).
    CATBOOST_API bool SetPredictionTypeString(ModelCalcerHandle* modelHandle, const char* predictionTypeStr) {
        try {
            FULL_MODEL_PTR(modelHandle)->SetPredictionType(
                FromString<EPredictionType>(predictionTypeStr)
            );
        } catch (...) {
            Singleton<TErrorMessageHolder>()->Message = CurrentExceptionMessage();
            return false;
        }
        return true;
    }

    /// Evaluates `docCount` documents given as flat float rows (floatFeatures[doc][feature]).
    CATBOOST_API bool CalcModelPredictionFlat(
        ModelCalcerHandle* modelHandle,
        size_t docCount,
        const float** floatFeatures, size_t floatFeaturesSize,
        double* result, size_t resultSize)
    {
        try {
            if (docCount == 1) {
                FULL_MODEL_PTR(modelHandle)->CalcFlatSingle(
                    TConstArrayRef<float>(*floatFeatures, floatFeaturesSize),
                    TArrayRef<double>(result, resultSize));
            } else {
                TVector<TConstArrayRef<float>> featuresVec(docCount);
                for (size_t i = 0; i < docCount; ++i) {
                    featuresVec[i] = TConstArrayRef<float>(floatFeatures[i], floatFeaturesSize);
                }
                FULL_MODEL_PTR(modelHandle)->CalcFlat(featuresVec, TArrayRef<double>(result, resultSize));
            }
        } catch (...) {
            Singleton<TErrorMessageHolder>()->Message = CurrentExceptionMessage();
            return false;
        }
        return true;
    }

    /// Evaluates `docCount` documents with float rows and string categorical rows
    /// (catFeatures[doc][feature]).
    CATBOOST_API bool CalcModelPrediction(
        ModelCalcerHandle* modelHandle,
        size_t docCount,
        const float** floatFeatures, size_t floatFeaturesSize,
        const char*** catFeatures, size_t catFeaturesSize,
        double* result, size_t resultSize)
    {
        try {
            TVector<TConstArrayRef<float>> floatFeaturesVec(docCount);
            TVector<TVector<TStringBuf>> catFeaturesVec(docCount, TVector<TStringBuf>(catFeaturesSize));
            for (size_t i = 0; i < docCount; ++i) {
                if (floatFeaturesSize > 0) {
                    floatFeaturesVec[i] = TConstArrayRef<float>(floatFeatures[i], floatFeaturesSize);
                }
                for (size_t catFeatureIdx = 0; catFeatureIdx < catFeaturesSize; ++catFeatureIdx) {
                    catFeaturesVec[i][catFeatureIdx] = catFeatures[i][catFeatureIdx];
                }
            }
            FULL_MODEL_PTR(modelHandle)->Calc(floatFeaturesVec, catFeaturesVec, TArrayRef<double>(result, resultSize));
        } catch (...) {
            Singleton<TErrorMessageHolder>()->Message = CurrentExceptionMessage();
            return false;
        }
        return true;
    }

    /// Like CalcModelPrediction(), plus string text features (textFeatures[doc][feature]).
    CATBOOST_API bool CalcModelPredictionText(
        ModelCalcerHandle* modelHandle,
        size_t docCount,
        const float** floatFeatures, size_t floatFeaturesSize,
        const char*** catFeatures, size_t catFeaturesSize,
        const char*** textFeatures, size_t textFeaturesSize,
        double* result, size_t resultSize)
    {
        try {
            TVector<TConstArrayRef<float>> floatFeaturesVec(docCount);
            TVector<TVector<TStringBuf>> catFeaturesVec(docCount, TVector<TStringBuf>(catFeaturesSize));
            TVector<TVector<TStringBuf>> textFeaturesVec(docCount, TVector<TStringBuf>(textFeaturesSize));
            for (size_t i = 0; i < docCount; ++i) {
                if (floatFeaturesSize > 0) {
                    floatFeaturesVec[i] = TConstArrayRef<float>(floatFeatures[i], floatFeaturesSize);
                }
                for (size_t catFeatureIdx = 0; catFeatureIdx < catFeaturesSize; ++catFeatureIdx) {
                    catFeaturesVec[i][catFeatureIdx] = catFeatures[i][catFeatureIdx];
                }
                for (size_t textFeatureIdx = 0; textFeatureIdx < textFeaturesSize; ++textFeatureIdx) {
                    textFeaturesVec[i][textFeatureIdx] = textFeatures[i][textFeatureIdx];
                }
            }
            FULL_MODEL_PTR(modelHandle)->Calc(floatFeaturesVec, catFeaturesVec, textFeaturesVec, TArrayRef<double>(result, resultSize));
        } catch (...) {
            Singleton<TErrorMessageHolder>()->Message = CurrentExceptionMessage();
            return false;
        }
        return true;
    }

    /// Evaluates a single document given as one float row and one categorical string row.
    CATBOOST_API bool CalcModelPredictionSingle(
        ModelCalcerHandle* modelHandle,
        const float* floatFeatures, size_t floatFeaturesSize,
        const char** catFeatures, size_t catFeaturesSize,
        double* result, size_t resultSize)
    {
        try {
            TVector<TConstArrayRef<float>> floatFeaturesVec(1);
            TVector<TVector<TStringBuf>> catFeaturesVec(1, TVector<TStringBuf>(catFeaturesSize));
            if (floatFeaturesSize > 0) {
                floatFeaturesVec[0] = TConstArrayRef<float>(floatFeatures, floatFeaturesSize);
            }
            for (size_t catFeatureIdx = 0; catFeatureIdx < catFeaturesSize; ++catFeatureIdx) {
                catFeaturesVec[0][catFeatureIdx] = catFeatures[catFeatureIdx];
            }
            FULL_MODEL_PTR(modelHandle)->Calc(floatFeaturesVec, catFeaturesVec, TArrayRef<double>(result, resultSize));
        } catch (...) {
            Singleton<TErrorMessageHolder>()->Message = CurrentExceptionMessage();
            return false;
        }
        return true;
    }

    /// Evaluates documents whose categorical features are already hashed to ints
    /// (see GetStringCatFeatureHash / GetIntegerCatFeatureHash).
    CATBOOST_API bool CalcModelPredictionWithHashedCatFeatures(
        ModelCalcerHandle* modelHandle,
        size_t docCount,
        const float** floatFeatures, size_t floatFeaturesSize,
        const int** catFeatures, size_t catFeaturesSize,
        double* result, size_t resultSize)
    {
        try {
            TVector<TConstArrayRef<float>> floatFeaturesVec(docCount);
            TVector<TConstArrayRef<int>> catFeaturesVec(docCount);
            for (size_t i = 0; i < docCount; ++i) {
                if (floatFeaturesSize > 0) {
                    floatFeaturesVec[i] = TConstArrayRef<float>(floatFeatures[i], floatFeaturesSize);
                }
                if (catFeaturesSize > 0) {
                    catFeaturesVec[i] = TConstArrayRef<int>(catFeatures[i], catFeaturesSize);
                }
            }
            FULL_MODEL_PTR(modelHandle)->Calc(floatFeaturesVec, catFeaturesVec, TArrayRef<double>(result, resultSize));
        } catch (...) {
            Singleton<TErrorMessageHolder>()->Message = CurrentExceptionMessage();
            return false;
        }
        return true;
    }

    /// Like CalcModelPredictionWithHashedCatFeatures(), plus string text features.
    CATBOOST_API bool CalcModelPredictionWithHashedCatFeaturesAndTextFeatures(
        ModelCalcerHandle* modelHandle,
        size_t docCount,
        const float** floatFeatures, size_t floatFeaturesSize,
        const int** catFeatures, size_t catFeaturesSize,
        const char*** textFeatures, size_t textFeaturesSize,
        double* result, size_t resultSize)
    {
        try {
            TVector<TConstArrayRef<float>> floatFeaturesVec(docCount);
            TVector<TConstArrayRef<int>> catFeaturesVec(docCount);
            TVector<TVector<TStringBuf>> textFeaturesVec(docCount, TVector<TStringBuf>(textFeaturesSize));
            for (size_t i = 0; i < docCount; ++i) {
                if (floatFeaturesSize > 0) {
                    floatFeaturesVec[i] = TConstArrayRef<float>(floatFeatures[i], floatFeaturesSize);
                }
                if (catFeaturesSize > 0) {
                    catFeaturesVec[i] = TConstArrayRef<int>(catFeatures[i], catFeaturesSize);
                }
                for (size_t textFeatureIdx = 0; textFeatureIdx < textFeaturesSize; ++textFeatureIdx) {
                    textFeaturesVec[i][textFeatureIdx] = textFeatures[i][textFeatureIdx];
                }
            }
            FULL_MODEL_PTR(modelHandle)->CalcWithHashedCatAndText(
                floatFeaturesVec, catFeaturesVec, textFeaturesVec, TArrayRef<double>(result, resultSize));
        } catch (...) {
            Singleton<TErrorMessageHolder>()->Message = CurrentExceptionMessage();
            return false;
        }
        return true;
    }

    // The PredictSpecificClass* family evaluates into a (docCount x GetDimensionsCount()) scratch
    // buffer, then extracts column `classId` into `result`. It returns false (with a message) when
    // classId is out of range or resultSize exceeds the number of documents.

    CATBOOST_API bool PredictSpecificClassFlat(
        ModelCalcerHandle* modelHandle,
        size_t docCount,
        const float** floatFeatures, size_t floatFeaturesSize,
        int classId,
        double* result, size_t resultSize)
    {
        try {
            const size_t dim = FULL_MODEL_PTR(modelHandle)->GetDimensionsCount();
            TVector<double> rawResult(docCount * dim);
            if (!CalcModelPredictionFlat(modelHandle, docCount, floatFeatures, floatFeaturesSize, rawResult.data(), rawResult.size())) {
                return false;
            }
            GetSpecificClass(classId, rawResult, dim, TArrayRef<double>(result, resultSize));
        } catch (...) {
            Singleton<TErrorMessageHolder>()->Message = CurrentExceptionMessage();
            return false;
        }
        return true;
    }

    CATBOOST_API bool PredictSpecificClass(
        ModelCalcerHandle* modelHandle,
        size_t docCount,
        const float** floatFeatures, size_t floatFeaturesSize,
        const char*** catFeatures, size_t catFeaturesSize,
        int classId,
        double* result, size_t resultSize)
    {
        try {
            const size_t dim = FULL_MODEL_PTR(modelHandle)->GetDimensionsCount();
            TVector<double> rawResult(docCount * dim);
            if (!CalcModelPrediction(
                    modelHandle, docCount,
                    floatFeatures, floatFeaturesSize,
                    catFeatures, catFeaturesSize,
                    rawResult.data(), rawResult.size()))
            {
                return false;
            }
            GetSpecificClass(classId, rawResult, dim, TArrayRef<double>(result, resultSize));
        } catch (...) {
            Singleton<TErrorMessageHolder>()->Message = CurrentExceptionMessage();
            return false;
        }
        return true;
    }

    CATBOOST_API bool PredictSpecificClassText(
        ModelCalcerHandle* modelHandle,
        size_t docCount,
        const float** floatFeatures, size_t floatFeaturesSize,
        const char*** catFeatures, size_t catFeaturesSize,
        const char*** textFeatures, size_t textFeaturesSize,
        int classId,
        double* result, size_t resultSize)
    {
        try {
            const size_t dim = FULL_MODEL_PTR(modelHandle)->GetDimensionsCount();
            TVector<double> rawResult(docCount * dim);
            if (!CalcModelPredictionText(
                    modelHandle, docCount,
                    floatFeatures, floatFeaturesSize,
                    catFeatures, catFeaturesSize,
                    textFeatures, textFeaturesSize,
                    rawResult.data(), rawResult.size()))
            {
                return false;
            }
            GetSpecificClass(classId, rawResult, dim, TArrayRef<double>(result, resultSize));
        } catch (...) {
            Singleton<TErrorMessageHolder>()->Message = CurrentExceptionMessage();
            return false;
        }
        return true;
    }

    CATBOOST_API bool PredictSpecificClassSingle(
        ModelCalcerHandle* modelHandle,
        const float* floatFeatures, size_t floatFeaturesSize,
        const char** catFeatures, size_t catFeaturesSize,
        int classId,
        double* result, size_t resultSize)
    {
        try {
            const size_t dim = FULL_MODEL_PTR(modelHandle)->GetDimensionsCount();
            TVector<double> rawResult(dim);
            if (!CalcModelPredictionSingle(
                    modelHandle,
                    floatFeatures, floatFeaturesSize,
                    catFeatures, catFeaturesSize,
                    rawResult.data(), rawResult.size()))
            {
                return false;
            }
            GetSpecificClass(classId, rawResult, dim, TArrayRef<double>(result, resultSize));
        } catch (...) {
            Singleton<TErrorMessageHolder>()->Message = CurrentExceptionMessage();
            return false;
        }
        return true;
    }

    CATBOOST_API bool PredictSpecificClassWithHashedCatFeatures(
        ModelCalcerHandle* modelHandle,
        size_t docCount,
        const float** floatFeatures, size_t floatFeaturesSize,
        const int** catFeatures, size_t catFeaturesSize,
        int classId,
        double* result, size_t resultSize)
    {
        try {
            const size_t dim = FULL_MODEL_PTR(modelHandle)->GetDimensionsCount();
            TVector<double> rawResult(docCount * dim);
            if (!CalcModelPredictionWithHashedCatFeatures(
                    modelHandle, docCount,
                    floatFeatures, floatFeaturesSize,
                    catFeatures, catFeaturesSize,
                    rawResult.data(), rawResult.size()))
            {
                return false;
            }
            GetSpecificClass(classId, rawResult, dim, TArrayRef<double>(result, resultSize));
        } catch (...) {
            Singleton<TErrorMessageHolder>()->Message = CurrentExceptionMessage();
            return false;
        }
        return true;
    }

    CATBOOST_API bool PredictSpecificClassWithHashedCatFeaturesAndTextFeatures(
        ModelCalcerHandle* modelHandle,
        size_t docCount,
        const float** floatFeatures, size_t floatFeaturesSize,
        const int** catFeatures, size_t catFeaturesSize,
        const char*** textFeatures, size_t textFeaturesSize,
        int classId,
        double* result, size_t resultSize)
    {
        try {
            const size_t dim = FULL_MODEL_PTR(modelHandle)->GetDimensionsCount();
            TVector<double> rawResult(docCount * dim);
            if (!CalcModelPredictionWithHashedCatFeaturesAndTextFeatures(
                    modelHandle, docCount,
                    floatFeatures, floatFeaturesSize,
                    catFeatures, catFeaturesSize,
                    textFeatures, textFeaturesSize,
                    rawResult.data(), rawResult.size()))
            {
                return false;
            }
            GetSpecificClass(classId, rawResult, dim, TArrayRef<double>(result, resultSize));
        } catch (...) {
            Singleton<TErrorMessageHolder>()->Message = CurrentExceptionMessage();
            return false;
        }
        return true;
    }

    /// Hashes a categorical value given as a (not necessarily NUL-terminated) byte range.
    CATBOOST_API int GetStringCatFeatureHash(const char* data, size_t size) {
        return CalcCatFeatureHash(TStringBuf(data, size));
    }

    /// Hashes an integer categorical value via its decimal string representation.
    CATBOOST_API int GetIntegerCatFeatureHash(long long val) {
        TStringBuilder valStr;
        valStr << val;
        return CalcCatFeatureHash(valStr);
    }

    CATBOOST_API size_t GetFloatFeaturesCount(ModelCalcerHandle* modelHandle) {
        return FULL_MODEL_PTR(modelHandle)->GetNumFloatFeatures();
    }

    CATBOOST_API size_t GetCatFeaturesCount(ModelCalcerHandle* modelHandle) {
        return FULL_MODEL_PTR(modelHandle)->GetNumCatFeatures();
    }

    CATBOOST_API size_t GetTreeCount(ModelCalcerHandle* modelHandle) {
        return FULL_MODEL_PTR(modelHandle)->GetTreeCount();
    }

    CATBOOST_API size_t GetDimensionsCount(ModelCalcerHandle* modelHandle) {
        return FULL_MODEL_PTR(modelHandle)->GetDimensionsCount();
    }

    /// Output width reported by the model's current evaluator (may differ from GetDimensionsCount).
    CATBOOST_API size_t GetPredictionDimensionsCount(ModelCalcerHandle* modelHandle) {
        return EVALUATOR_PTR(modelHandle)->GetPredictionDimensions();
    }

    CATBOOST_API bool CheckModelMetadataHasKey(ModelCalcerHandle* modelHandle, const char* keyPtr, size_t keySize) {
        return FULL_MODEL_PTR(modelHandle)->ModelInfo.contains(TStringBuf(keyPtr, keySize));
    }

    /// Length of the metadata value for `key`, or 0 when the key is absent. This is
    /// indistinguishable from an empty value; use CheckModelMetadataHasKey() to tell them apart.
    CATBOOST_API size_t GetModelInfoValueSize(ModelCalcerHandle* modelHandle, const char* keyPtr, size_t keySize) {
        TStringBuf key(keyPtr, keySize);
        if (!FULL_MODEL_PTR(modelHandle)->ModelInfo.contains(key)) {
            return 0;
        }
        return FULL_MODEL_PTR(modelHandle)->ModelInfo.at(key).size();
    }

    /// Pointer to the metadata value stored in the model (not a copy), or nullptr when absent.
    CATBOOST_API const char* GetModelInfoValue(ModelCalcerHandle* modelHandle, const char* keyPtr, size_t keySize) {
        TStringBuf key(keyPtr, keySize);
        if (!FULL_MODEL_PTR(modelHandle)->ModelInfo.contains(key)) {
            return nullptr;
        }
        return FULL_MODEL_PTR(modelHandle)->ModelInfo.at(key).c_str();
    }
}