ort2

Crates.io: ort2
lib.rs: ort2
version: 0.1.2
created_at: 2025-02-08 09:06:54.006966+00
updated_at: 2025-02-11 03:50:18.225015+00
description: onnxruntime wrapper for the C/C++ API
homepage:
repository:
max_upload_size:
id: 1547875
size: 66,751
(yexiangyu)

documentation

README

ort2

  • onnxruntime wrapper for rust 2

  • tested on:

| OS   | Linux | Windows | macOS       |
|------|-------|---------|-------------|
| CPU  | Y     | Y       | Y (aarch64) |
| CUDA | Y     | Y       | N/A         |

Prerequisites

  • download an onnxruntime release from the onnxruntime GitHub repository and unzip it
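  • set the following environment variables so they point at the unzipped onnxruntime `include` and `lib` directories (the paths below are examples):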
    • ORT_INC_PATH=/opt/homebrew/opt/onnxruntime/include
    • ORT_LIB_PATH=/opt/homebrew/opt/onnxruntime/lib

Getting Started

use ort2::prelude::*;

// Embed the ONNX model file into the binary at compile time (`include_bytes!`
// yields a `&'static [u8; N]`, hence the `.as_ref()` below).
let model = include_bytes!("models/mnist-8.onnx");

// Create an inference session from the in-memory model bytes.
let session = Session::builder()
    .build(model.as_ref())
    .expect("failed to create session");

// Dummy input: 28 * 28 zero-valued f32s.
let input = vec![0.0f32;28 * 28];

// Wrap the input buffer in a tensor value of shape [1, 1, 28, 28] with f32
// elements. NOTE(review): `borrow` presumably references `input` rather than
// copying it, so `input` must outlive `value` — confirm against the crate docs.
let value = Value::tensor()
    .with_shape([1, 1, 28, 28])
    .with_typ(ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT)
    .borrow(&input)
    .expect("failed to build value");

// Run the session on the single input and take the first output.
let output = session.run([&value])
    .expect("failed to run")
    .into_iter()
    .next()
    .expect("failed to get outputs");

// View the output as an ndarray array of f32.
let output = output
    .view::<f32>()
    .expect("failed to view output");

// The second output dimension is expected to be 10 (one score per MNIST digit).
assert_eq!(output.shape()[1], 10);

Run with IoBinding

use ort2::prelude::*;

// Embed the ONNX model file into the binary at compile time.
let model = include_bytes!("models/mnist-8.onnx");

// Create an inference session from the in-memory model bytes.
let session = Session::builder()
    .build(model.as_ref())
    .expect("failed to create session");

// Dummy input: 28 * 28 zero-valued f32s.
let input = vec![0.0f32;28 * 28];

// Wrap the input buffer in a tensor value of shape [1, 1, 28, 28] with f32
// elements (presumably borrowed, not copied — `input` must outlive `value`).
let value = Value::tensor()
    .with_shape([1, 1, 28, 28])
    .with_typ(ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT)
    .borrow(&input)
    .expect("failed to build value");

// Create an IoBinding for this session; inputs and outputs are bound to it
// explicitly instead of being passed to `run`.
let mut iobinding = session.iobinding()
    .expect("failed to create iobinding");

// Bind `value` to the model's first input, looked up by name.
iobinding.bind_input(
        &session.get_inputs()
        .expect("failed to get input")[0].name,
        &value
    )
    .expect("failed to bind input");

// Memory info describing where the output should be placed.
// NOTE(review): `MemoryInfo::default()` presumably means CPU memory — confirm.
let mem_info = MemoryInfo::default();

// Bind the model's first output (by name) to the device described by `mem_info`.

iobinding.bind_output_to_device(
        &session.get_outputs()
        .expect("failed to get outputs")[0].name,
        &mem_info
    )
    .expect("failed to bind outputs");

// Run inference using the bound inputs and outputs.
session.run_with_iobinding(&mut iobinding)
    .expect("failed to run");

// Allocator handed to `get_bound_outputs`.
// NOTE(review): presumably used to allocate the returned output values — confirm.
let alloc = DefaultAllocator::default();

// Fetch the bound outputs and take the first one.
let output = iobinding
    .get_bound_outputs(&alloc)
    .expect("failed to get output from iobinding")
    .into_iter()
    .next()
    .expect("failed to get output");

// View the output as an array of f32.
let output = output
    .view::<f32>()
    .expect("failed to view output");

// The second output dimension is expected to be 10 (one score per MNIST digit).
assert_eq!(output.shape()[1], 10);
Commit count: 0

cargo fmt