// Copyright 2020 The IREE Authors // // Licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception #include "runtime/bindings/tflite/model.h" #include #include #include "iree/modules/hal/module.h" #include "iree/vm/bytecode/module.h" static iree_status_t _TfLiteModelCalculateFunctionIOCounts( const iree_vm_function_signature_t* signature, int32_t* out_input_count, int32_t* out_output_count) { iree_string_view_t arguments, results; IREE_RETURN_IF_ERROR(iree_vm_function_call_get_cconv_fragments( signature, &arguments, &results)); // NOTE: today we only pass 1:1 buffer views with what tflite does. // That means that both these should be one `r` per buffer view and our counts // are just the number of chars in the cconv. *out_input_count = (int32_t)arguments.size; *out_output_count = (int32_t)results.size; return iree_ok_status(); } static iree_status_t _TfLiteModelInitializeModule(const void* flatbuffer_data, size_t flatbuffer_size, iree_allocator_t allocator, TfLiteModel* model) { IREE_TRACE_ZONE_BEGIN(z0); IREE_RETURN_AND_END_ZONE_IF_ERROR( z0, iree_vm_instance_create(IREE_VM_TYPE_CAPACITY_DEFAULT, allocator, &model->instance)); IREE_RETURN_AND_END_ZONE_IF_ERROR( z0, iree_hal_module_register_all_types(model->instance)); iree_const_byte_span_t flatbuffer_span = iree_make_const_byte_span(flatbuffer_data, flatbuffer_size); iree_allocator_t flatbuffer_allocator = iree_allocator_null(); IREE_RETURN_AND_END_ZONE_IF_ERROR( z0, iree_vm_bytecode_module_create(model->instance, flatbuffer_span, flatbuffer_allocator, allocator, &model->module), "error creating bytecode module"); IREE_RETURN_AND_END_ZONE_IF_ERROR( z0, iree_vm_module_lookup_function_by_name( model->module, IREE_VM_FUNCTION_LINKAGE_EXPORT, iree_make_cstring_view("_tflite_main"), &model->exports._main), "unable to find '_tflite_main' export in module, module must be compiled " "with tflite bindings support"); // Get the input and output counts of the function; this is useful for being // able to preallocate storage when creating interpreters. iree_vm_function_signature_t main_signature = iree_vm_function_signature(&model->exports._main); IREE_RETURN_IF_ERROR(_TfLiteModelCalculateFunctionIOCounts( &main_signature, &model->input_count, &model->output_count)); // NOTE: the input shape query is not required as it's possible (though // silly) for a model to have no inputs. In testing this can happen a lot // but in the wild it's rare ... says someone who previously filed bugs // against tflite because they didn't support models with no inputs when I // was being silly and needed them ;) IREE_IGNORE_ERROR(iree_vm_module_lookup_function_by_name( model->module, IREE_VM_FUNCTION_LINKAGE_EXPORT, iree_make_cstring_view("_tflite_main_query_input_shape"), &model->exports._query_input_shape)); // NOTE: the input shape resizing function is only required if the model has // dynamic shapes. IREE_IGNORE_ERROR(iree_vm_module_lookup_function_by_name( model->module, IREE_VM_FUNCTION_LINKAGE_EXPORT, iree_make_cstring_view("_tflite_main_resize_input_shape"), &model->exports._resize_input_shape)); IREE_RETURN_AND_END_ZONE_IF_ERROR( z0, iree_vm_module_lookup_function_by_name( model->module, IREE_VM_FUNCTION_LINKAGE_EXPORT, iree_make_cstring_view("_tflite_main_query_output_shape"), &model->exports._query_output_shape), "unable to find '_tflite_main_query_output_shape' export in module"); // It's OK for this to fail; the model may not have variables. IREE_IGNORE_ERROR(iree_vm_module_lookup_function_by_name( model->module, IREE_VM_FUNCTION_LINKAGE_EXPORT, iree_make_cstring_view("_tflite_main_reset_variables"), &model->exports._reset_variables)); IREE_TRACE_ZONE_END(z0); return iree_ok_status(); } TFL_CAPI_EXPORT extern TfLiteModel* TfLiteModelCreate(const void* model_data, size_t model_size) { iree_allocator_t allocator = iree_allocator_system(); IREE_TRACE_ZONE_BEGIN(z0); TfLiteModel* model = NULL; iree_status_t status = iree_allocator_malloc(allocator, sizeof(*model), (void**)&model); if (!iree_status_is_ok(iree_status_consume_code(status))) { IREE_TRACE_MESSAGE(ERROR, "failed model allocation"); IREE_TRACE_ZONE_END(z0); return NULL; } memset(model, 0, sizeof(*model)); iree_atomic_ref_count_init(&model->ref_count); model->allocator = allocator; status = _TfLiteModelInitializeModule(model_data, model_size, allocator, model); if (!iree_status_is_ok(status)) { iree_status_fprint(stderr, status); iree_status_free(status); TfLiteModelDelete(model); IREE_TRACE_ZONE_END(z0); return NULL; } IREE_TRACE_ZONE_END(z0); return model; } TFL_CAPI_EXPORT extern TfLiteModel* TfLiteModelCreateFromFile( const char* model_path) { iree_allocator_t allocator = iree_allocator_system(); IREE_TRACE_ZONE_BEGIN(z0); // TODO(#3909): use file mapping C API. FILE* file = fopen(model_path, "r"); if (!file) { IREE_TRACE_MESSAGE(ERROR, "failed to open model file"); IREE_TRACE_MESSAGE_DYNAMIC(ERROR, model_path, strlen(model_path)); IREE_TRACE_ZONE_END(z0); return NULL; } fseek(file, 0, SEEK_END); size_t file_size = ftell(file); fseek(file, 0, SEEK_SET); TfLiteModel* model = NULL; iree_status_t status = iree_allocator_malloc( allocator, sizeof(TfLiteModel) + file_size, (void**)&model); if (!iree_status_is_ok(iree_status_consume_code(status))) { IREE_TRACE_MESSAGE(ERROR, "failed model+data allocation"); IREE_TRACE_ZONE_END(z0); return NULL; } memset(model, 0, sizeof(*model)); iree_atomic_ref_count_init(&model->ref_count); model->allocator = allocator; model->owned_model_data = (uint8_t*)model + file_size; int ret = fread(model->owned_model_data, 1, file_size, file); fclose(file); if (ret != file_size) { TfLiteModelDelete(model); IREE_TRACE_MESSAGE(ERROR, "failed model+data read"); IREE_TRACE_ZONE_END(z0); return NULL; } status = _TfLiteModelInitializeModule(model->owned_model_data, file_size, allocator, model); if (!iree_status_is_ok(iree_status_consume_code(status))) { TfLiteModelDelete(model); IREE_TRACE_ZONE_END(z0); return NULL; } IREE_TRACE_ZONE_END(z0); return model; } void _TfLiteModelRetain(TfLiteModel* model) { if (model) { iree_atomic_ref_count_inc(&model->ref_count); } } void _TfLiteModelRelease(TfLiteModel* model) { if (model && iree_atomic_ref_count_dec(&model->ref_count) == 1) { IREE_TRACE_ZONE_BEGIN(z0); iree_vm_module_release(model->module); iree_vm_instance_release(model->instance); iree_allocator_free(model->allocator, model); IREE_TRACE_ZONE_END(z0); } } TFL_CAPI_EXPORT extern void TfLiteModelDelete(TfLiteModel* model) { IREE_TRACE_ZONE_BEGIN(z0); _TfLiteModelRelease(model); IREE_TRACE_ZONE_END(z0); }