| Field | Value |
| --- | --- |
| Crates.io | tch-tensor-like |
| lib.rs | tch-tensor-like |
| version | 0.6.0 |
| source | src |
| created_at | 2020-11-16 11:37:05.112524 |
| updated_at | 2022-04-03 07:19:13.620301 |
| description | Derive convenient methods for struct or enum of tch tensors |
| homepage | https://github.com/jerry73204/tch-tensor-like |
| repository | https://github.com/jerry73204/tch-tensor-like.git |
| id | 312910 |
| size | 22,007 |
If you use tch-rs, you may have worked with a composite model input type like this one:

```rust
use tch::Tensor;

struct ModelInput {
    pub images: Vec<Tensor>,
    pub kind: Tensor,
    pub label: Option<Tensor>,
}
```
Before you feed a batch of this type into a model, you have to move it to the appropriate device, and calling `tensor.to_device()` on every field by hand is tedious. The `TensorLike` derive macro comes to your rescue:
```rust
use tch::Tensor;
use tch_tensor_like::TensorLike;

#[derive(TensorLike)]
struct ModelInput {
    pub images: Vec<Tensor>,
    pub kind: Tensor,
    pub label: Option<Tensor>,
}
```
Deriving the macro gives you `to_device()`, `to_kind()` and `shallow_clone()` out of the box:
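```rust
use tch::{Device, Kind};

// `fetch_data()` is a placeholder for your own data-loading code.
let input: ModelInput = fetch_data();
let input = input
    .to_device(Device::cuda_if_available())
    .to_kind(Kind::Float)
    .shallow_clone();
```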
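To make the mechanics concrete, here is a rough hand-written sketch of what such a derive could expand to. It does not depend on tch: `Tensor`, `Device` and `Kind` are hypothetical mock types that only record device and element kind, and the `TensorLike` trait below is an assumed approximation, not the crate's real definition. The idea is that each method is implemented for a single tensor, forwarded through `Vec` and `Option`, and then applied field by field to the struct.

```rust
// Illustrative sketch, not the crate's actual code: `Tensor`, `Device` and
// `Kind` are hypothetical mock types standing in for tch's, and `TensorLike`
// is an assumed trait approximating what the derive provides.

#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Device { Cpu, Cuda(usize) }

#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Kind { Int64, Float }

/// Mock tensor that only records its device and element kind.
#[derive(Debug, PartialEq)]
pub struct Tensor {
    pub device: Device,
    pub kind: Kind,
}

/// Hypothetical trait with the three methods listed above.
pub trait TensorLike: Sized {
    fn to_device(&self, device: Device) -> Self;
    fn to_kind(&self, kind: Kind) -> Self;
    fn shallow_clone(&self) -> Self;
}

impl TensorLike for Tensor {
    fn to_device(&self, device: Device) -> Self {
        Tensor { device, kind: self.kind }
    }
    fn to_kind(&self, kind: Kind) -> Self {
        Tensor { device: self.device, kind }
    }
    // Real tch shares the underlying storage; the mock just copies metadata.
    fn shallow_clone(&self) -> Self {
        Tensor { device: self.device, kind: self.kind }
    }
}

// Containers forward every call to each element they hold.
impl<T: TensorLike> TensorLike for Vec<T> {
    fn to_device(&self, device: Device) -> Self {
        self.iter().map(|t| t.to_device(device)).collect()
    }
    fn to_kind(&self, kind: Kind) -> Self {
        self.iter().map(|t| t.to_kind(kind)).collect()
    }
    fn shallow_clone(&self) -> Self {
        self.iter().map(|t| t.shallow_clone()).collect()
    }
}

impl<T: TensorLike> TensorLike for Option<T> {
    fn to_device(&self, device: Device) -> Self {
        self.as_ref().map(|t| t.to_device(device))
    }
    fn to_kind(&self, kind: Kind) -> Self {
        self.as_ref().map(|t| t.to_kind(kind))
    }
    fn shallow_clone(&self) -> Self {
        self.as_ref().map(|t| t.shallow_clone())
    }
}

pub struct ModelInput {
    pub images: Vec<Tensor>,
    pub kind: Tensor,
    pub label: Option<Tensor>,
}

// Roughly what a derive could generate: apply the method to every field.
impl TensorLike for ModelInput {
    fn to_device(&self, device: Device) -> Self {
        ModelInput {
            images: self.images.to_device(device),
            kind: self.kind.to_device(device),
            label: self.label.to_device(device),
        }
    }
    fn to_kind(&self, kind: Kind) -> Self {
        ModelInput {
            images: self.images.to_kind(kind),
            kind: self.kind.to_kind(kind),
            label: self.label.to_kind(kind),
        }
    }
    fn shallow_clone(&self) -> Self {
        ModelInput {
            images: self.images.shallow_clone(),
            kind: self.kind.shallow_clone(),
            label: self.label.shallow_clone(),
        }
    }
}
```

With this sketch, `input.to_device(Device::Cuda(0)).to_kind(Kind::Float)` converts every tensor in `images`, `kind` and `label` in one chain, which is exactly the per-field boilerplate the derive saves you from writing.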
For non-tensor fields, use attributes to copy or clone the value instead:
```rust
use tch_tensor_like::TensorLike;

#[derive(TensorLike)]
struct ModelInput {
    // primitives are copied by default
    pub number: i32,
    // copy the field
    #[tensor_like(copy)]
    pub text: &'static str,
    // clone the field
    #[tensor_like(clone)]
    pub desc: String,
}
```
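A similar hand-written sketch shows how generated code could treat non-tensor fields, assuming that, like tch's own `Tensor::to_device`, the generated methods take `&self`. A field cannot be moved out of a borrowed struct, so `Copy` values are copied and values such as `String` must be cloned. `Tensor` and `Device` here are hypothetical mock stand-ins, and the impl is illustrative, not the macro's actual output.

```rust
// Illustrative sketch, not the macro's real output: `Tensor` and `Device` are
// hypothetical mock types, and the impl below is written by hand.

#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Device { Cpu, Cuda(usize) }

/// Mock tensor that only records which device it lives on.
#[derive(Debug, PartialEq)]
pub struct Tensor {
    pub device: Device,
}

impl Tensor {
    pub fn to_device(&self, device: Device) -> Tensor {
        Tensor { device }
    }
}

pub struct ModelInput {
    pub image: Tensor,
    pub number: i32,        // primitive: copied by default
    pub text: &'static str, // would carry #[tensor_like(copy)]
    pub desc: String,       // would carry #[tensor_like(clone)]
}

impl ModelInput {
    // `&self` means no field can be moved out, so each field is rebuilt:
    // tensors are converted, `Copy` values are copied, `String` is cloned.
    pub fn to_device(&self, device: Device) -> Self {
        ModelInput {
            image: self.image.to_device(device),
            number: self.number,
            text: self.text,
            desc: self.desc.clone(),
        }
    }
}
```

Dropping `.clone()` from the `desc` line would fail to compile with a "cannot move out of a shared reference" error, which is presumably why the derive has to be told how to handle such fields.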
The crate is not published to crates.io yet. To use it in your project, add the Git repository as a dependency in your `Cargo.toml`:
```toml
[dependencies]
tch-tensor-like = { git = "https://github.com/jerry73204/tch-tensor-like.git", features = ["derive"] }
```
MIT License. See the LICENSE file.