Crates.io | candle-ext |
lib.rs | candle-ext |
version | 0.1.7 |
source | src |
created_at | 2023-11-05 08:47:31.820991 |
updated_at | 2023-12-12 14:39:20.851523 |
description | An extension library to Candle that provides PyTorch functions not currently available in Candle |
repository | https://github.com/mokeyish/candle-ext |
id | 1025824 |
size | 76,868 |
An extension library to Candle that provides PyTorch functions not currently available in Candle
use candle_ext::{
    candle::{D, DType, Device, Result, Tensor},
    TensorExt, F,
};

fn main() -> Result<()> {
    let device = Device::Cpu;

    // Query, key and value tensors; the last two dims are (sequence length, head dim).
    let q = Tensor::randn(0., 1., (3, 3, 2, 4), &device)?;
    let k = Tensor::randn(0., 1., (1, 3, 3, 4), &device)?;
    let v = Tensor::randn(0., 1., (1, 3, 3, 4), &device)?;

    // Lower-triangular mask of shape (query length, key length), here (2, 3).
    let m = Tensor::ones((q.dim(D::Minus2)?, k.dim(D::Minus2)?), DType::U8, &device)?.tril(0)?;

    let o = F::scaled_dot_product_attention(&q, &k, &v, Some(&m), None, None, None)?;
    Ok(())
}
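To make the computation behind the example concrete, here is a minimal plain-Rust sketch (no Candle dependency) of scaled dot-product attention for a single head, following the formula documented for PyTorch: softmax(Q Kᵀ / sqrt(d)) V, with masked positions excluded. The function name `sdpa_reference`, the nested-`Vec` layout, and the boolean mask type are assumptions made for illustration; this is not how candle-ext implements `F::scaled_dot_product_attention`.

```rust
/// Illustrative reference only (hypothetical helper, not part of candle-ext).
/// `q` is (L, d), `k` is (S, d), `v` is (S, e); `mask[i][j] == true` means
/// query `i` may attend to key `j`. Assumes every query row has at least one
/// allowed key, otherwise the softmax below would divide zero by zero.
fn sdpa_reference(
    q: &[Vec<f64>],
    k: &[Vec<f64>],
    v: &[Vec<f64>],
    mask: Option<&[Vec<bool>]>,
) -> Vec<Vec<f64>> {
    let scale = 1.0 / (q[0].len() as f64).sqrt();
    let mut out = Vec::with_capacity(q.len());
    for (i, qi) in q.iter().enumerate() {
        // scores[j] = (q_i . k_j) * scale, or -inf where the mask forbids attention.
        let scores: Vec<f64> = k
            .iter()
            .enumerate()
            .map(|(j, kj)| {
                if mask.map_or(true, |m| m[i][j]) {
                    qi.iter().zip(kj).map(|(a, b)| a * b).sum::<f64>() * scale
                } else {
                    f64::NEG_INFINITY
                }
            })
            .collect();

        // Numerically stable softmax over the key axis.
        let max = scores.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
        let exps: Vec<f64> = scores.iter().map(|s| (s - max).exp()).collect();
        let sum: f64 = exps.iter().sum();

        // Output row i is the attention-weighted sum of the value rows.
        let mut row = vec![0.0; v[0].len()];
        for (w, vj) in exps.iter().zip(v) {
            for (r, x) in row.iter_mut().zip(vj) {
                *r += w / sum * x;
            }
        }
        out.push(row);
    }
    out
}
```

With a lower-triangular mask, the first query can only see the first key, so its output row equals the first value row; later queries blend progressively more value rows.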
Currently provides (see also tests):
F::scaled_dot_product_attention
F::chunk2..5 / Tensor::chunk2..5
F::cumsum / Tensor::cumsum
F::equal / Tensor::equal
F::eye / Tensor::eye
F::full / Tensor::full
F::full_like / Tensor::full_like
F::triu / Tensor::triu
F::tril / Tensor::tril
F::masked_fill / Tensor::masked_fill
F::logical_not / Tensor::logical_not
F::logical_or / Tensor::logical_or
F::outer / Tensor::outer
F::unbind / Tensor::unbind / F::unbind2..5 / Tensor::unbind2..5
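For readers unfamiliar with the PyTorch functions these names mirror, the following plain-Rust sketch illustrates the semantics of a few of them on nested `Vec`s. The helper names match the list above for readability only; the signatures and data layout are assumptions made for this illustration and do not reflect candle-ext's actual API, which operates on Candle tensors.

```rust
// Standalone illustrations of PyTorch-style semantics (hypothetical helpers,
// not candle-ext code).

/// Keep elements on or below the `diagonal`-th diagonal; zero the rest.
fn tril(m: &[Vec<u8>], diagonal: i64) -> Vec<Vec<u8>> {
    m.iter()
        .enumerate()
        .map(|(i, row)| {
            row.iter()
                .enumerate()
                .map(|(j, &x)| if j as i64 - i as i64 <= diagonal { x } else { 0 })
                .collect()
        })
        .collect()
}

/// Keep elements on or above the `diagonal`-th diagonal; zero the rest.
fn triu(m: &[Vec<u8>], diagonal: i64) -> Vec<Vec<u8>> {
    m.iter()
        .enumerate()
        .map(|(i, row)| {
            row.iter()
                .enumerate()
                .map(|(j, &x)| if j as i64 - i as i64 >= diagonal { x } else { 0 })
                .collect()
        })
        .collect()
}

/// Element-wise negation of a 0/1 mask.
fn logical_not(mask: &[Vec<u8>]) -> Vec<Vec<u8>> {
    mask.iter()
        .map(|r| r.iter().map(|&m| (m == 0) as u8).collect())
        .collect()
}

/// Replace `x[i][j]` with `value` wherever `mask[i][j]` is nonzero.
fn masked_fill(x: &[Vec<f64>], mask: &[Vec<u8>], value: f64) -> Vec<Vec<f64>> {
    x.iter()
        .zip(mask)
        .map(|(xr, mr)| {
            xr.iter()
                .zip(mr)
                .map(|(&xv, &mv)| if mv != 0 { value } else { xv })
                .collect()
        })
        .collect()
}

/// Running sum along a one-dimensional slice.
fn cumsum(x: &[f64]) -> Vec<f64> {
    x.iter()
        .scan(0.0, |acc, &v| {
            *acc += v;
            Some(*acc)
        })
        .collect()
}
```

A common combination is building a causal attention mask: take `tril` of an all-ones matrix, invert it with `logical_not`, and use `masked_fill` to write negative infinity into the disallowed score positions before the softmax.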
Licensed under either of the Apache License, Version 2.0 or the MIT license, at your option.
Unless you explicitly state otherwise, any contribution intentionally submitted for inclusion in the work by you, as defined in the Apache-2.0 license, shall be dual licensed as above, without any additional terms or conditions.