//! Demonstrates how to select or gather sub tensors (index) from tensors use dfdx::{ shapes::Rank3, tensor::{AsArray, AutoDevice, Tensor, TensorFrom}, tensor_ops::{GatherTo, SelectTo}, }; fn main() { let dev = AutoDevice::default(); let a: Tensor, f32, _> = dev.tensor([ [[0.00, 0.01, 0.02], [0.10, 0.11, 0.12]], [[1.00, 1.01, 1.02], [1.10, 1.11, 1.12]], [[2.00, 2.01, 2.02], [2.10, 2.11, 2.12]], [[3.00, 3.01, 3.02], [3.10, 3.11, 3.12]], ]); // the easiest thing to do is to `select` a single value from a given axis. // to do that, you need indices with a shape up to the axis you are select from. // for example, given shape (M, N, O), here are the index shapes for each axis: // - Axis 0: index shape () // - Axis 1: index shape (M, ) // - Axis 2: index shape (M, N) // here we select from axis 0 so we just need 1 value. let b = a.clone().select(dev.tensor(0)); assert_eq!(b.array(), a.array()[0]); // to `select` from axis 1, we use a tensor with shape (4,) let d = a.clone().select(dev.tensor([0, 1, 0, 1])); assert_eq!( d.array(), [ [0.00, 0.01, 0.02], [1.10, 1.11, 1.12], [2.00, 2.01, 2.02], [3.10, 3.11, 3.12] ] ); // We can also `gather` multiple elements from each axis. This lets you grab // the same elements multiple times! This requires an index with shape similar to select, // but with an extra dimension at the end that says how many elements to gather from the axis: // - Axis 0: index shape (Z, ) // - Axis 1: index shape (M, Z) // - Axis 2: index shape (M, N, Z) // here, we `gather` from axis 0 because we have a 1d tensor. the new size will be (6, 2, 3)! let c = a.clone().gather(dev.tensor([0, 0, 1, 1, 2, 2])); dbg!(c.array()); // and similarly, we can `gather` from axis 1 with a 2d tensor. the new size will be (4, 6, 3)! let e = a.gather(dev.tensor([[0; 6], [1; 6], [1, 1, 1, 0, 0, 0], [0, 0, 0, 1, 1, 1]])); dbg!(e.array()); }