use hotg_rune_core::{Shape, TensorList, TensorListMut};
use alloc::{
vec::Vec,
string::{String, ToString},
};
use core::marker::PhantomData;
use crate::intrinsics::StringRef;
#[derive(Debug, Clone, PartialEq)]
pub struct Model {
id: u32,
input_shapes: Vec>,
output_shapes: Vec>,
_types: PhantomData Output>,
}
impl Model {
pub fn load(
mimetype: &str,
model_data: &[u8],
input_shapes: &[Shape<'static>],
output_shapes: &[Shape<'static>],
) -> Self {
let id = unsafe {
let input_shape_descriptors: Vec =
input_shapes.iter().map(|s| s.to_string()).collect();
let input_shape_descriptors: Vec<_> = input_shape_descriptors
.iter()
.map(|s| StringRef::from(s.as_str()))
.collect();
let output_shape_descriptors: Vec =
output_shapes.iter().map(|s| s.to_string()).collect();
let output_shape_descriptors: Vec<_> = output_shape_descriptors
.iter()
.map(|s| StringRef::from(s.as_str()))
.collect();
crate::intrinsics::rune_model_load(
mimetype.as_ptr(),
mimetype.len() as u32,
model_data.as_ptr(),
model_data.len() as u32,
input_shape_descriptors.as_ptr(),
input_shape_descriptors.len() as u32,
output_shape_descriptors.as_ptr(),
output_shape_descriptors.len() as u32,
)
};
Model {
id,
input_shapes: input_shapes.into(),
output_shapes: output_shapes.into(),
_types: PhantomData,
}
}
}
impl Model
where
for<'a> &'a Input: TensorList<'a>,
Output: TensorListMut,
{
pub fn transform(&mut self, inputs: Input) -> Output {
assert_eq!(
(&inputs).shape_list().as_ref(),
&self.input_shapes,
"The input had the wrong shape",
);
let mut outputs =