candle-ext

Crates.io: candle-ext
lib.rs: candle-ext
version: 0.1.7
source: src
created_at: 2023-11-05 08:47:31.820991
updated_at: 2023-12-12 14:39:20.851523
description: An extension library to Candle that provides PyTorch functions not currently available in Candle
repository: https://github.com/mokeyish/candle-ext
id: 1025824
size: 76,868
YISH (mokeyish)

Candle Extensions

An extension library for Candle that provides PyTorch functions not yet available in Candle itself.

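 use candle_ext::{
     candle::{D, DType, Device, Result, Tensor},
     TensorExt, F,
 };

 fn main() -> Result<()> {
     let device = Device::Cpu;
     // Query: (batch 3, heads 3, seq 2, dim 4); key/value use batch 1 against the query's batch of 3.
     let q = Tensor::randn(0., 1., (3, 3, 2, 4), &device)?;
     let k = Tensor::randn(0., 1., (1, 3, 3, 4), &device)?;
     let v = Tensor::randn(0., 1., (1, 3, 3, 4), &device)?;
     // Lower-triangular (causal) mask of shape (q_len, k_len) = (2, 3).
     let m = Tensor::ones((q.dim(D::Minus2)?, k.dim(D::Minus2)?), DType::U8, &device)?.tril(0)?;

     // The remaining optional arguments are left as None (defaults).
     let o = F::scaled_dot_product_attention(&q, &k, &v, Some(&m), None, None, None)?;
     println!("{o}");

     Ok(())
 }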
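To make the math behind that call concrete, here is a std-only Rust sketch of scaled dot-product attention for a single head, softmax(q·kᵀ·scale)·v with masked scores set to negative infinity. It is not candle-ext's implementation: the `attention` function, the plain `Vec<Vec<f64>>` matrices, the default scale of 1/sqrt(d) and the mask convention (non-zero = attend, 0 = blocked) are assumptions modeled on PyTorch's `scaled_dot_product_attention`.

```rust
/// Hypothetical std-only sketch of single-head scaled dot-product attention.
/// mask[i][j] == 0 blocks key j for query i (assumed PyTorch-style convention).
/// Assumes every query has at least one unmasked key, otherwise softmax yields NaN.
fn attention(
    q: &[Vec<f64>],
    k: &[Vec<f64>],
    v: &[Vec<f64>],
    mask: Option<&[Vec<u8>]>,
) -> Vec<Vec<f64>> {
    // Assumed default scale: 1 / sqrt(d_k).
    let scale = 1.0 / (q[0].len() as f64).sqrt();
    q.iter()
        .enumerate()
        .map(|(i, qi)| {
            // Scaled scores of query i against every key; blocked keys get -inf.
            let scores: Vec<f64> = k
                .iter()
                .enumerate()
                .map(|(j, kj)| {
                    if mask.map_or(true, |m| m[i][j] != 0) {
                        qi.iter().zip(kj).map(|(a, b)| a * b).sum::<f64>() * scale
                    } else {
                        f64::NEG_INFINITY
                    }
                })
                .collect();
            // Numerically stable softmax: subtract the row maximum before exp.
            let max = scores.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
            let exps: Vec<f64> = scores.iter().map(|s| (s - max).exp()).collect();
            let sum: f64 = exps.iter().sum();
            // Output row = softmax-weighted sum of the value rows.
            let mut out = vec![0.0; v[0].len()];
            for (w, vj) in exps.iter().zip(v) {
                for (o, x) in out.iter_mut().zip(vj) {
                    *o += w / sum * x;
                }
            }
            out
        })
        .collect()
}

fn main() {
    let q = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
    let k = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]];
    let v = vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]];
    // Same pattern as ones((2, 3)).tril(0): query i may attend to keys 0..=i.
    let mask = vec![vec![1u8, 0, 0], vec![1, 1, 0]];
    let out = attention(&q, &k, &v, Some(mask.as_slice()));
    println!("{:?}", out);
}
```

With this mask the first query can only see the first key, so its output row is exactly the first value row; the second query blends the first two value rows.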

Currently provides the following functions (see also the tests):

  • F::scaled_dot_product_attention

  • F::chunk2..5 / Tensor::chunk2..5

  • F::cumsum / Tensor::cumsum

  • F::equal / Tensor::equal

  • F::eye / Tensor::eye

  • F::full / Tensor::full

  • F::full_like / Tensor::full_like

  • F::triu / Tensor::triu

  • F::tril / Tensor::tril

  • F::masked_fill / Tensor::masked_fill

  • F::logical_not / Tensor::logical_not

  • F::logical_or / Tensor::logical_or

  • F::outer / Tensor::outer

  • F::unbind / Tensor::unbind / F::unbind2..5 / Tensor::unbind2..5
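Several of these functions follow simple element-wise rules inherited from PyTorch. The std-only Rust sketch below illustrates the assumed semantics of `tril`, `triu`, `cumsum`, `masked_fill` and `outer` on small integer inputs; the helper names and the `Vec`-based representation are hypothetical and say nothing about candle-ext's actual signatures, which operate on `Tensor`s.

```rust
/// Hypothetical helpers mirroring PyTorch semantics on small 1D/2D integer inputs.

/// Keep element (i, j) when j <= i + diagonal (PyTorch `tril`), zero it otherwise.
fn tril(m: &[Vec<i64>], diagonal: i64) -> Vec<Vec<i64>> {
    m.iter()
        .enumerate()
        .map(|(i, row)| {
            row.iter()
                .enumerate()
                .map(|(j, &x)| if (j as i64) <= i as i64 + diagonal { x } else { 0 })
                .collect()
        })
        .collect()
}

/// Keep element (i, j) when j >= i + diagonal (PyTorch `triu`), zero it otherwise.
fn triu(m: &[Vec<i64>], diagonal: i64) -> Vec<Vec<i64>> {
    m.iter()
        .enumerate()
        .map(|(i, row)| {
            row.iter()
                .enumerate()
                .map(|(j, &x)| if (j as i64) >= i as i64 + diagonal { x } else { 0 })
                .collect()
        })
        .collect()
}

/// Running sum along a 1D input: out[n] = xs[0] + ... + xs[n].
fn cumsum(xs: &[i64]) -> Vec<i64> {
    let mut acc = 0i64;
    xs.iter()
        .map(|x| {
            acc += *x;
            acc
        })
        .collect()
}

/// Replace elements where the mask is true with `value`, keep the rest.
fn masked_fill(xs: &[i64], mask: &[bool], value: i64) -> Vec<i64> {
    xs.iter().zip(mask).map(|(&x, &m)| if m { value } else { x }).collect()
}

/// Outer product: out[i][j] = a[i] * b[j].
fn outer(a: &[i64], b: &[i64]) -> Vec<Vec<i64>> {
    a.iter().map(|&x| b.iter().map(|&y| x * y).collect()).collect()
}

fn main() {
    let ones = vec![vec![1i64; 3]; 3];
    println!("{:?}", tril(&ones, 0)); // [[1, 0, 0], [1, 1, 0], [1, 1, 1]]
    println!("{:?}", triu(&ones, 1)); // [[0, 1, 1], [0, 0, 1], [0, 0, 0]]
    println!("{:?}", cumsum(&[1, 2, 3, 4])); // [1, 3, 6, 10]
    println!("{:?}", masked_fill(&[1, 2, 3], &[false, true, false], -1)); // [1, -1, 3]
    println!("{:?}", outer(&[1, 2], &[3, 4, 5])); // [[3, 4, 5], [6, 8, 10]]
}
```

The `tril(0)` case is the same pattern the README example uses to build its causal attention mask.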

License

Licensed under either of the Apache License, Version 2.0 or the MIT license, at your option.

Contribution

Unless you explicitly state otherwise, any contribution intentionally submitted for inclusion in the work by you, as defined in the Apache-2.0 license, shall be dual licensed as above, without any additional terms or conditions.

Commit count: 23