use crate::{ffi, Error, Tensor, TensorDescriptor, TensorMut};
use bitflags::bitflags;
use std::{
    convert::TryInto,
    ffi::CString,
    fmt::{self, Debug, Formatter},
    mem::MaybeUninit,
    ptr::NonNull,
};

/// A backend which can run inference on a model.
pub struct InferenceContext {
    ctx: NonNull<ffi::RuneCoralContext>,
}

impl InferenceContext {
    /// Create a new [`InferenceContext`].
    ///
    /// # Safety
    ///
    /// This takes ownership of the `ctx` pointer and will deallocate it on
    /// drop.
    pub(crate) unsafe fn new(ctx: NonNull<ffi::RuneCoralContext>) -> Self {
        InferenceContext { ctx }
    }

    pub fn infer(
        &mut self,
        inputs: &[Tensor<'_>],
        outputs: &mut [TensorMut<'_>],
    ) -> Result<(), InferError> {
        // Safety: We are effectively casting a &T to a *mut T here. This is
        // okay, but only as long as the infer() function doesn't mutate the
        // input tensors in any way (casting from *mut T to &mut T would still
        // be UB, though).
        unsafe {
            let inputs: Vec<_> = inputs.iter().map(|t| t.as_coral_tensor()).collect();
            let mut outputs: Vec<_> =
                outputs.iter_mut().map(|t| t.as_coral_tensor()).collect();

            let ret = ffi::infer(
                self.ctx.as_ptr(),
                inputs.as_ptr() as *mut _,
                inputs.len() as ffi::size_t,
                outputs.as_mut_ptr(),
                outputs.len() as ffi::size_t,
            );

            check_inference_error(ret)
        }
    }

    pub fn create_context(
        mimetype: &str,
        model: &[u8],
        acceleration_backend: AccelerationBackend,
    ) -> Result<Self, Error> {
        let mimetype = CString::new(mimetype)?;
        let mut inference_context = MaybeUninit::uninit();

        // Safety: Our inputs are sane by construction - `mimetype` is a valid,
        // NUL-terminated C string, the model pointer and length come from a
        // valid slice, and `inference_context` is a valid location for the
        // FFI call to write the new context pointer into.
        unsafe {
            let ret = ffi::create_inference_context(
                mimetype.as_ptr(),
                model.as_ptr().cast(),
                model.len() as ffi::size_t,
                (acceleration_backend.bits() as i32).try_into().unwrap(),
                inference_context.as_mut_ptr(),
            );
            check_load_result(ret)?;

            let inference_context = inference_context.assume_init();

            Ok(InferenceContext::new(
                NonNull::new(inference_context).expect("Should be initialized"),
            ))
        }
    }

    pub fn opcount(&self) -> u64 {
        unsafe { ffi::inference_opcount(self.ctx.as_ptr()).into() }
    }

    pub fn inputs(&self) -> impl Iterator<Item = TensorDescriptor<'_>> + '_ {
        unsafe {
            let mut inputs = MaybeUninit::uninit();
            let len = ffi::inference_inputs(self.ctx.as_ptr(), inputs.as_mut_ptr());
            descriptors(inputs.assume_init(), len.into())
        }
    }

    pub fn outputs(&self) -> impl Iterator<Item = TensorDescriptor<'_>> + '_ {
        unsafe {
            let mut outputs = MaybeUninit::uninit();
            let len = ffi::inference_outputs(self.ctx.as_ptr(), outputs.as_mut_ptr());
            descriptors(outputs.assume_init(), len.into())
        }
    }
}

/// Iterate over the [`TensorDescriptor`]s for a set of tensors.
///
/// # Safety
///
/// The caller must ensure the returned iterator (`'a`) doesn't outlive the data
/// pointed to by `tensors`.
unsafe fn descriptors<'a>(
    tensors: *const ffi::RuneCoralTensor,
    len: u64,
) -> impl Iterator<Item = TensorDescriptor<'a>> {
    // Safety: Assumes the tensors are valid. The caller guarantees the 'a
    // lifetime doesn't outlive the original tensors.
    let tensors = if len > 0 {
        std::slice::from_raw_parts(tensors, len as usize)
    } else {
        // Note: The tensors pointer may be null when len == 0, so swap it out
        // with an empty slice.
        &[]
    };

    tensors.iter().map(TensorDescriptor::from_rune_coral_tensor)
}

impl Debug for InferenceContext {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.debug_struct("InferenceContext").finish_non_exhaustive()
    }
}

impl Drop for InferenceContext {
    fn drop(&mut self) {
        unsafe {
            ffi::destroy_inference_context(self.ctx.as_ptr());
        }
    }
}

#[derive(Debug, Copy, Clone, PartialEq, thiserror::Error)]
pub enum LoadError {
    #[error("Incorrect mimetype")]
    IncorrectMimeType,
    #[error("Internal error")]
    InternalError,
    #[error("Unknown error {}", return_code)]
    Other {
        return_code: ffi::RuneCoralLoadResult,
    },
}

fn check_load_result(return_code: ffi::RuneCoralLoadResult) -> Result<(), LoadError> {
    match return_code {
        ffi::RuneCoralLoadResult__Ok => Ok(()),
        ffi::RuneCoralLoadResult__IncorrectMimeType => Err(LoadError::IncorrectMimeType),
        ffi::RuneCoralLoadResult__InternalError => Err(LoadError::InternalError),
        _ => Err(LoadError::Other { return_code }),
    }
}

// Safety: There shouldn't be any thread-specific state, so it's okay to move
// the inference context to another thread.
//
// The inference context is very much **not** thread-safe though, so we can't
// implement Sync.
unsafe impl Send for InferenceContext {}

fn check_inference_error(return_code: ffi::RuneCoralInferenceResult) -> Result<(), InferError> {
    match return_code {
        ffi::RuneCoralInferenceResult__Ok => Ok(()),
        ffi::RuneCoralInferenceResult__Error => Err(InferError::InterpreterError),
        ffi::RuneCoralInferenceResult__DelegateError => Err(InferError::DelegateError),
        ffi::RuneCoralInferenceResult__ApplicationError => Err(InferError::ApplicationError),
        _ => Err(InferError::Other { return_code }),
    }
}

#[derive(Debug, Copy, Clone, PartialEq, thiserror::Error)]
pub enum InferError {
    /// Generally referring to an error in the runtime (i.e. the interpreter).
    #[error("The TensorFlow Lite interpreter encountered an error")]
    InterpreterError,
    /// Generally referring to an error from a TfLiteDelegate itself.
    #[error("A delegate returned an error")]
    DelegateError,
    /// Generally referring to an error in applying a delegate due to
    /// incompatibility between the runtime and the delegate, e.g. this error
    /// is returned when trying to apply a TfLite delegate onto a model graph
    /// that's already immutable.
    #[error("Invalid model graph or incompatibility between runtime and delegates")]
    ApplicationError,
    #[error("Unknown inference error {}", return_code)]
    Other {
        return_code: ffi::RuneCoralInferenceResult,
    },
}

bitflags! {
    pub struct AccelerationBackend: u32 {
        const NONE = ffi::RuneCoralAccelerationBackend__None as u32;
        const EDGETPU = ffi::RuneCoralAccelerationBackend__Edgetpu as u32;
        const GPU = ffi::RuneCoralAccelerationBackend__Gpu as u32;
    }
}

impl AccelerationBackend {
    /// Get all [`AccelerationBackend`]s that are available on this device.
    pub fn currently_available() -> Self {
        unsafe {
            AccelerationBackend::from_bits(ffi::availableAccelerationBackends() as u32).unwrap()
        }
    }
}

#[cfg(test)]
mod tests {
    use std::sync::Mutex;

    use super::*;

    #[test]
    fn inference_context_is_only_send() {
        static_assertions::assert_impl_all!(InferenceContext: Send);
        static_assertions::assert_not_impl_any!(InferenceContext: Sync);

        // but we can wrap it in a mutex!
        static_assertions::assert_impl_all!(Mutex<InferenceContext>: Send, Sync);
    }
}
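
// A minimal usage sketch, not part of the bindings' public API: the names
// `usage_sketch`, `preferred_backend`, and `infer_shared` are illustrative
// only. The code relies solely on items defined in this file, the operators
// generated by `bitflags`, and `std::sync::Mutex`; tensors are taken as
// parameters so no assumptions are made about how `Tensor`/`TensorMut`
// values are constructed.
#[cfg(test)]
mod usage_sketch {
    use std::sync::Mutex;

    use super::*;
    use crate::{Tensor, TensorMut};

    /// Prefer the Edge TPU, then the GPU, and fall back to plain CPU
    /// inference when no accelerator is available.
    fn preferred_backend(available: AccelerationBackend) -> AccelerationBackend {
        if available.contains(AccelerationBackend::EDGETPU) {
            AccelerationBackend::EDGETPU
        } else if available.contains(AccelerationBackend::GPU) {
            AccelerationBackend::GPU
        } else {
            AccelerationBackend::NONE
        }
    }

    /// `InferenceContext` is `Send` but not `Sync`, so sharing one between
    /// threads means locking a `Mutex` around every call to `infer()`.
    #[allow(dead_code)]
    fn infer_shared(
        ctx: &Mutex<InferenceContext>,
        inputs: &[Tensor<'_>],
        outputs: &mut [TensorMut<'_>],
    ) -> Result<(), InferError> {
        ctx.lock().expect("lock poisoned").infer(inputs, outputs)
    }

    #[test]
    fn edgetpu_is_preferred_over_gpu() {
        let available = AccelerationBackend::EDGETPU | AccelerationBackend::GPU;
        assert_eq!(preferred_backend(available), AccelerationBackend::EDGETPU);
    }

    #[test]
    fn cpu_fallback_when_nothing_is_available() {
        assert_eq!(
            preferred_backend(AccelerationBackend::NONE),
            AccelerationBackend::NONE
        );
    }
}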