tch-tensor-like

Crates.io: tch-tensor-like
lib.rs: tch-tensor-like
version: 0.6.0
source: src
created_at: 2020-11-16 11:37:05.112524
updated_at: 2022-04-03 07:19:13.620301
description: Derive convenient methods for struct or enum of tch tensors
homepage: https://github.com/jerry73204/tch-tensor-like
repository: https://github.com/jerry73204/tch-tensor-like.git
id: 312910
size: 22,007
owner: jerry73204
documentation: https://docs.rs/tch-tensor-like/

README

tch-tensor-like: Derive Tensor-like Types for tch-rs

About this crate

If you use tch-rs, you may have worked with a complex model input type like this:

use tch::Tensor;

struct ModelInput {
    pub images: Vec<Tensor>,
    pub kind: Tensor,
    pub label: Option<Tensor>,
}

Before feeding a batch of this type into a model, you have to move it to the appropriate device. Calling tensor.to_device() on every member of the type by hand is tedious. The TensorLike derive macro does this for you.

use tch::Tensor;
use tch_tensor_like::TensorLike;

#[derive(TensorLike)]
struct ModelInput {
    pub images: Vec<Tensor>,
    pub kind: Tensor,
    pub label: Option<Tensor>,
}

By deriving TensorLike, you get to_device(), to_kind() and shallow_clone() out of the box.

use tch::{Device, Kind};

// fetch_data() stands for your own data-loading code
let input: ModelInput = fetch_data();
let input = input.to_device(Device::cuda_if_available())
                 .to_kind(Kind::Float)
                 .shallow_clone();
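To make concrete what the derive saves you from writing, here is a std-only sketch of the per-field pattern, assuming the derived methods borrow self and return a new value, as tch's Tensor::to_device does. FakeTensor, Device, Kind and the TensorLikeSketch trait are hypothetical stand-ins so the example runs without libtorch; this is not the crate's actual generated code.

```rust
// Std-only sketch. FakeTensor, Device and Kind are hypothetical stand-ins
// for tch::Tensor, tch::Device and tch::Kind; TensorLikeSketch is an
// illustrative trait, not the crate's real TensorLike trait.
#[derive(Debug, Clone, Copy, PartialEq)]
enum Device { Cpu, Cuda(usize) }

#[derive(Debug, Clone, Copy, PartialEq)]
enum Kind { Int64, Float }

#[derive(Debug, PartialEq)]
struct FakeTensor { device: Device, kind: Kind }

trait TensorLikeSketch: Sized {
    fn to_device(&self, device: Device) -> Self;
    fn to_kind(&self, kind: Kind) -> Self;
    fn shallow_clone(&self) -> Self;
}

// Leaf case: the "tensor" itself returns a converted value.
impl TensorLikeSketch for FakeTensor {
    fn to_device(&self, device: Device) -> Self { FakeTensor { device, kind: self.kind } }
    fn to_kind(&self, kind: Kind) -> Self { FakeTensor { device: self.device, kind } }
    fn shallow_clone(&self) -> Self { FakeTensor { device: self.device, kind: self.kind } }
}

// Containers forward every call to their elements.
impl<T: TensorLikeSketch> TensorLikeSketch for Vec<T> {
    fn to_device(&self, device: Device) -> Self { self.iter().map(|t| t.to_device(device)).collect() }
    fn to_kind(&self, kind: Kind) -> Self { self.iter().map(|t| t.to_kind(kind)).collect() }
    fn shallow_clone(&self) -> Self { self.iter().map(|t| t.shallow_clone()).collect() }
}

impl<T: TensorLikeSketch> TensorLikeSketch for Option<T> {
    fn to_device(&self, device: Device) -> Self { self.as_ref().map(|t| t.to_device(device)) }
    fn to_kind(&self, kind: Kind) -> Self { self.as_ref().map(|t| t.to_kind(kind)) }
    fn shallow_clone(&self) -> Self { self.as_ref().map(|t| t.shallow_clone()) }
}

struct ModelInput {
    images: Vec<FakeTensor>,
    kind: FakeTensor,
    label: Option<FakeTensor>,
}

// The per-field boilerplate a derive writes for you: call the same method on every field.
impl TensorLikeSketch for ModelInput {
    fn to_device(&self, device: Device) -> Self {
        ModelInput {
            images: self.images.to_device(device),
            kind: self.kind.to_device(device),
            label: self.label.to_device(device),
        }
    }
    fn to_kind(&self, kind: Kind) -> Self {
        ModelInput {
            images: self.images.to_kind(kind),
            kind: self.kind.to_kind(kind),
            label: self.label.to_kind(kind),
        }
    }
    fn shallow_clone(&self) -> Self {
        ModelInput {
            images: self.images.shallow_clone(),
            kind: self.kind.shallow_clone(),
            label: self.label.shallow_clone(),
        }
    }
}

fn cpu_tensor() -> FakeTensor {
    FakeTensor { device: Device::Cpu, kind: Kind::Int64 }
}

fn main() {
    let input = ModelInput {
        images: vec![cpu_tensor(), cpu_tensor()],
        kind: cpu_tensor(),
        label: Some(cpu_tensor()),
    };
    let input = input.to_device(Device::Cuda(0)).to_kind(Kind::Float).shallow_clone();
    println!("{:?}", input.label); // Some(FakeTensor { device: Cuda(0), kind: Float })
}
```

Because the Vec and Option impls only forward to their elements, a derive written this way works for any field whose type implements the trait. That is presumably how fields such as Vec<Tensor> and Option<Tensor> are handled, though the crate's exact impls may differ.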

For non-tensor fields, add an attribute so the value is copied or cloned instead.

#[derive(TensorLike)]
struct ModelInput {
    // primitives are copied by default
    pub number: i32,

    // copy the field
    #[tensor_like(copy)]
    pub text: &'static str,

    // clone the field
    #[tensor_like(clone)]
    pub desc: String,
}
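The attributes above can be read as choosing how each non-tensor field is carried into the new value. The following std-only sketch hand-writes that behaviour for to_device only, assuming the method borrows self; under that assumption a non-Copy field such as String cannot be moved out and must be cloned. FakeTensor, Device and the extra image field are hypothetical additions for illustration, not part of the crate.

```rust
// Std-only sketch. FakeTensor and Device are hypothetical stand-ins for
// tch::Tensor and tch::Device; the `image` field is added for illustration.
#[derive(Debug, Clone, Copy, PartialEq)]
enum Device { Cpu, Cuda(usize) }

#[derive(Debug, PartialEq)]
struct FakeTensor { device: Device }

impl FakeTensor {
    fn to_device(&self, device: Device) -> FakeTensor {
        FakeTensor { device }
    }
}

struct ModelInput {
    image: FakeTensor,
    number: i32,
    text: &'static str,
    desc: String,
}

impl ModelInput {
    // Hand-written counterpart of the attributes above. Because the method
    // only borrows `self`, a non-Copy field such as String must be cloned.
    fn to_device(&self, device: Device) -> Self {
        ModelInput {
            image: self.image.to_device(device), // tensor: converted
            number: self.number,                 // primitive: copied by default
            text: self.text,                     // #[tensor_like(copy)]: copied
            desc: self.desc.clone(),             // #[tensor_like(clone)]: cloned
        }
    }
}

fn main() {
    let input = ModelInput {
        image: FakeTensor { device: Device::Cpu },
        number: 7,
        text: "cat",
        desc: String::from("a photo"),
    };
    let moved = input.to_device(Device::Cuda(0));
    // prints: FakeTensor { device: Cuda(0) } 7 cat a photo
    println!("{:?} {} {} {}", moved.image, moved.number, moved.text, moved.desc);
}
```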

Usage

Add the Git repository as a dependency to include this crate in your project. The crate is also listed on crates.io (see the metadata at the top of this page).

[dependencies]
tch-tensor-like = { git = "https://github.com/jerry73204/tch-tensor-like.git", features = ["derive"] }

License

MIT License. See LICENSE file.
