1 change: 1 addition & 0 deletions burn-book/src/building-blocks/tensor.md
@@ -154,6 +154,7 @@ Those operations are available for all tensor kinds: `Int`, `Float`, and `Bool`.
| `tensor.into_scalar()` | `tensor.item()` (for single-element tensors) |
| `tensor.narrow(dim, start, length)` | `tensor.narrow(dim, start, length)` |
| `tensor.not_equal(other)` | `x != y` |
| `tensor.permute(axes)` | `tensor.permute(axes)` |
| `tensor.repeat(2, 4)` | `tensor.repeat([1, 1, 4])` |
| `tensor.reshape(shape)` | `tensor.view(shape)` |
| `tensor.shape()` | `tensor.shape` |
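A usage sketch for the new table row, as a minimal example (the function name `transpose_hw` is illustrative, not part of the API):

```rust
use burn::tensor::{backend::Backend, Tensor};

/// Reorder axes so that shape [batch, height, width] becomes [batch, width, height].
fn transpose_hw<B: Backend>(t: Tensor<B, 3>) -> Tensor<B, 3> {
    t.permute([0, 2, 1])
}
```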
7 changes: 7 additions & 0 deletions crates/burn-autodiff/src/ops/bool_tensor.rs
@@ -110,6 +110,13 @@ impl<B: Backend, C: CheckpointStrategy> BoolTensorOps<Self> for Autodiff<B, C> {
B::bool_chunk(tensor, chunks, dim)
}

fn bool_permute<const D: usize>(
tensor: BoolTensor<Self, D>,
axes: [usize; D],
) -> BoolTensor<Self, D> {
B::bool_permute(tensor, axes)
}

fn bool_argwhere<const D: usize>(tensor: BoolTensor<B, D>) -> IntTensor<B, 2> {
B::bool_argwhere(tensor)
}
7 changes: 7 additions & 0 deletions crates/burn-autodiff/src/ops/int_tensor.rs
@@ -345,4 +345,11 @@ impl<B: Backend, C: CheckpointStrategy> IntTensorOps<Self> for Autodiff<B, C> {
fn int_arange(range: std::ops::Range<i64>, device: &Device<Self>) -> IntTensor<Self, 1> {
B::int_arange(range, device)
}

fn int_permute<const D: usize>(
tensor: IntTensor<Self, D>,
axes: [usize; D],
) -> IntTensor<Self, D> {
B::int_permute(tensor, axes)
}
}
56 changes: 56 additions & 0 deletions crates/burn-autodiff/src/ops/tensor.rs
@@ -682,6 +682,62 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
}
}

fn float_permute<const D: usize>(
tensor: FloatTensor<Self, D>,
axes: [usize; D],
) -> FloatTensor<Self, D> {
#[derive(Debug)]
struct PermuteDim;

#[derive(new, Debug)]
struct RetroPermuteDims<B: Backend, const D: usize> {
input_id: NodeID,
axes: [usize; D],
_backend: PhantomData<B>,
}

impl<B: Backend, const D: usize> RetroForward for RetroPermuteDims<B, D> {
fn forward(&self, states: &mut BackwardStates, out_node: NodeID) {
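// Recompute the forward permutation from the checkpointed input state.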
let input = states.get_state::<B::FloatTensorPrimitive<D>>(&self.input_id);
let out = B::float_permute(input, self.axes);
states.save(out_node, out)
}
}

impl<B: Backend, const D: usize> Backward<B, D, 1> for PermuteDim {
type State = [usize; D];

fn backward(
self,
ops: Ops<Self::State, 1>,
grads: &mut Gradients,
_checkpointer: &mut Checkpointer,
) {
let axes = ops.state;

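// The gradient flows back through the inverse permutation:
// output axis `i` came from input axis `axes[i]`, so `inverse[axes[i]] = i`.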
let mut inverse: [usize; D] = [0; D];
axes.iter()
.enumerate()
.for_each(|(i, &axis)| inverse[axis] = i);

unary::<B, D, D, _>(ops.parents, ops.node, grads, |grad| {
B::float_permute(grad, inverse)
});
}
}

match PermuteDim
.prepare::<C>([tensor.node.clone()], [tensor.graph.clone()])
.memory_bound()
.retro_forward(RetroPermuteDims::<B, D>::new(tensor.node.id.clone(), axes))
.parents([&tensor])
.stateful()
{
OpsKind::Tracked(prep) => prep.finish(axes, B::float_permute(tensor.primitive, axes)),
OpsKind::UnTracked(prep) => prep.finish(B::float_permute(tensor.primitive, axes)),
}
}

fn float_reshape<const D1: usize, const D2: usize>(
tensor: FloatTensor<Self, D1>,
shape: Shape<D2>,
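The backward pass above hinges on the inverse-permutation rule. A minimal standalone sketch (plain Rust, no Burn dependency) showing that applying `axes` and then its inverse restores the original axis order:

```rust
fn inverse_permutation<const D: usize>(axes: [usize; D]) -> [usize; D] {
    let mut inverse = [0; D];
    for (i, &axis) in axes.iter().enumerate() {
        inverse[axis] = i;
    }
    inverse
}

fn main() {
    let axes = [2, 0, 1];
    let inverse = inverse_permutation(axes);
    assert_eq!(inverse, [1, 2, 0]);
    // Composing the permutation with its inverse yields the identity.
    let identity: Vec<usize> = (0..3).map(|i| axes[inverse[i]]).collect();
    assert_eq!(identity, vec![0, 1, 2]);
}
```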
2 changes: 2 additions & 0 deletions crates/burn-autodiff/src/tests/mod.rs
@@ -35,6 +35,7 @@ mod mul;
mod multithread;
mod neg;
mod nonzero;
mod permute;
mod pow;
mod recip;
mod relu;
@@ -109,6 +110,7 @@ macro_rules! testgen_all {
burn_autodiff::testgen_ad_tanh!();
burn_autodiff::testgen_ad_sigmoid!();
burn_autodiff::testgen_ad_transpose!();
burn_autodiff::testgen_ad_permute!();
burn_autodiff::testgen_ad_nonzero!();
};
}
28 changes: 28 additions & 0 deletions crates/burn-autodiff/src/tests/permute.rs
@@ -0,0 +1,28 @@
#[burn_tensor_testgen::testgen(ad_permute)]
mod tests {
use super::*;
use burn_tensor::Data;

#[test]
fn should_diff_permute() {
let data_1: Data<f32, 3> = Data::from([[[1.0, 7.0], [2.0, 3.0]]]); // 1x2x2
let data_2: Data<f32, 3> = Data::from([[[1.0, 7.0], [3.2, 2.0], [3.0, 3.0]]]); // 1x3x2

let device = Default::default();
let tensor_1 = TestAutodiffTensor::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();

let tensor_3 = tensor_2.clone().permute([0, 2, 1]);
let tensor_4 = tensor_1.clone().matmul(tensor_3);
let grads = tensor_4.backward();

let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();

assert_eq!(grad_1.to_data(), Data::from([[[7.2, 12.0], [7.2, 12.0]]])); // 1x2x2
assert_eq!(
grad_2.to_data(),
Data::from([[[3.0, 10.0], [3.0, 10.0], [3.0, 10.0]]]) // 1x3x2
);
}
}
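A quick hand check of the expected values (assuming `backward()` seeds the output gradient with ones, as in the other autodiff tests): with $G = \mathbf{1}_{1 \times 2 \times 3}$ and $x_3 = \mathrm{permute}(x_2, [0, 2, 1])$,

$$\nabla x_1 = G x_3^{\top} = \begin{bmatrix} 7.2 & 12 \\ 7.2 & 12 \end{bmatrix}, \qquad \nabla x_2 = \mathrm{permute}(x_1^{\top} G, [0, 2, 1]) = \begin{bmatrix} 3 & 10 \\ 3 & 10 \\ 3 & 10 \end{bmatrix},$$

which matches the assertions.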
1 change: 1 addition & 0 deletions crates/burn-candle/src/lib.rs
@@ -80,6 +80,7 @@ mod tests {
burn_tensor::testgen_maxmin!();
burn_tensor::testgen_mul!();
burn_tensor::testgen_neg!();
burn_tensor::testgen_permute!();
burn_tensor::testgen_argwhere_nonzero!();

// TODO: https://github.com/tracel-ai/burn/issues/1237
7 changes: 7 additions & 0 deletions crates/burn-candle/src/ops/base.rs
@@ -53,6 +53,13 @@ pub fn swap_dims<E: CandleElement, const D: usize>(
CandleTensor::new(tensor.tensor.transpose(dim1, dim2).unwrap())
}

pub fn permute<E: CandleElement, const D: usize>(
tensor: CandleTensor<E, D>,
axes: [usize; D],
) -> CandleTensor<E, D> {
CandleTensor::new(tensor.tensor.permute(axes).unwrap())
}

pub fn reshape<E: CandleElement, const D1: usize, const D2: usize>(
tensor: CandleTensor<E, D1>,
shape: Shape<D2>,
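For reference, a minimal sketch of the underlying candle call that the wrapper above delegates to (assuming `candle-core` as a direct dependency):

```rust
use candle_core::{DType, Device, Tensor};

fn main() -> candle_core::Result<()> {
    let t = Tensor::zeros((2, 3, 4), DType::F32, &Device::Cpu)?;
    // candle accepts the axes as a tuple; the Burn wrapper passes an array.
    let p = t.permute((0, 2, 1))?;
    assert_eq!(p.dims(), &[2, 4, 3]);
    Ok(())
}
```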
9 changes: 9 additions & 0 deletions crates/burn-candle/src/ops/bool_tensor.rs
@@ -8,6 +8,8 @@ use crate::{
Candle, CandleTensor,
};

use super::base::permute;

impl<F: FloatCandleElement, I: IntCandleElement> BoolTensorOps<Self> for Candle<F, I> {
fn bool_empty<const D: usize>(shape: Shape<D>, device: &Device<Self>) -> BoolTensor<Self, D> {
super::base::empty(shape, device)
@@ -126,4 +128,11 @@ impl<F: FloatCandleElement, I: IntCandleElement> BoolTensorOps<Self> for Candle<
) -> Vec<BoolTensor<Self, D>> {
super::base::chunk(tensor, chunks, dim)
}

fn bool_permute<const D: usize>(
tensor: BoolTensor<Self, D>,
axes: [usize; D],
) -> BoolTensor<Self, D> {
permute(tensor, axes)
}
}
9 changes: 9 additions & 0 deletions crates/burn-candle/src/ops/int_tensor.rs
@@ -8,6 +8,8 @@ use crate::{
Candle, CandleTensor,
};

use super::base::permute;

impl<F: FloatCandleElement, I: IntCandleElement> IntTensorOps<Self> for Candle<F, I> {
fn int_empty<const D: usize>(shape: Shape<D>, device: &Device<Self>) -> IntTensor<Self, D> {
super::base::empty(shape, device)
@@ -410,4 +412,11 @@ impl<F: FloatCandleElement, I: IntCandleElement> IntTensorOps<Self> for Candle<F
),
}
}

fn int_permute<const D: usize>(
tensor: IntTensor<Self, D>,
axes: [usize; D],
) -> IntTensor<Self, D> {
permute(tensor, axes)
}
}
9 changes: 9 additions & 0 deletions crates/burn-candle/src/ops/tensor.rs
@@ -11,6 +11,8 @@ use crate::{
Candle, CandleTensor,
};

use super::base::permute;

impl<F: FloatCandleElement, I: IntCandleElement> FloatTensorOps<Self> for Candle<F, I> {
fn float_from_data<const D: usize>(
data: Data<F, D>,
@@ -515,4 +517,11 @@ impl<F: FloatCandleElement, I: IntCandleElement> FloatTensorOps<Self> for Candle
.unwrap(),
)
}

fn float_permute<const D: usize>(
tensor: FloatTensor<Self, D>,
axes: [usize; D],
) -> FloatTensor<Self, D> {
permute(tensor, axes)
}
}
46 changes: 43 additions & 3 deletions crates/burn-fusion/src/ops/boolean.rs
@@ -4,9 +4,9 @@ use crate::{
ops::binary::binary_ops_shape,
stream::{
BaseOperationDescription, BinaryOperationDescription, BoolOperationDescription,
CatOperationDescription, Operation, OperationDescription, ReshapeDescription,
SliceAssignOperationDescription, SliceOperationDescription, StreamId, SwapDimsDescription,
UnaryOperationDescription,
CatOperationDescription, Operation, OperationDescription, PermuteOperationDescription,
ReshapeDescription, SliceAssignOperationDescription, SliceOperationDescription, StreamId,
SwapDimsDescription, UnaryOperationDescription,
},
Fusion, FusionBackend,
};
@@ -426,4 +426,44 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {

out
}

fn bool_permute<const D: usize>(
tensor: BoolTensor<Self, D>,
axes: [usize; D],
) -> BoolTensor<Self, D> {
#[derive(new)]
struct PermuteDimsOps<const D: usize> {
desc: PermuteOperationDescription,
}

impl<const D: usize, B: FusionBackend> Operation<B> for PermuteDimsOps<D> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
let input = handles.get_bool_tensor::<D>(&self.desc.input);
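// The description stores `axes` as a `Vec<usize>`; recover the fixed-size array here.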
let axes: [usize; D] = self.desc.axes.try_into().unwrap();
let output = B::bool_permute(input, axes);
handles.register_bool_tensor(&self.desc.out.id, output);
}
}

let stream = tensor.stream;

// The output shape is the input shape permuted by `axes`.
let shape = axes.into_iter().map(|x| tensor.shape[x]).collect();

let out = tensor.client.tensor_uninitialized(shape);

let desc = PermuteOperationDescription {
input: tensor.into_description(),
axes: axes.to_vec(),
out: out.to_description_out(),
};

out.client.register(
vec![stream],
OperationDescription::BaseBool(BaseOperationDescription::Permute(desc.clone())),
PermuteDimsOps::<D>::new(desc),
);

out
}
}
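The output-shape rule used by all three fusion implementations: the extent of output axis `i` is the input extent along `axes[i]`. A tiny standalone check (plain Rust; `permuted_shape` is a hypothetical helper, not part of burn-fusion):

```rust
fn permuted_shape(shape: &[usize], axes: &[usize]) -> Vec<usize> {
    axes.iter().map(|&a| shape[a]).collect()
}

fn main() {
    assert_eq!(permuted_shape(&[2, 3, 4], &[0, 2, 1]), vec![2, 4, 3]);
}
```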
42 changes: 41 additions & 1 deletion crates/burn-fusion/src/ops/float.rs
@@ -8,7 +8,7 @@ use crate::{
BaseOperationDescription, BinaryOperationDescription, CatOperationDescription,
ClampOperationDescription, FloatOperationDescription, GatherOperationDescription,
MaskFillOperationDescription, MaskWhereOperationDescription, NumericOperationDescription,
Operation, OperationDescription, RandomOperationDescription,
Operation, OperationDescription, PermuteOperationDescription, RandomOperationDescription,
ReduceDimWithIndicesDescription, ReshapeDescription, ScalarOperationDescription,
ScatterOperationDescription, SelectAssignOperationDescription, SelectOperationDescription,
SliceAssignOperationDescription, SliceOperationDescription, StreamId, SwapDimsDescription,
@@ -1806,4 +1806,44 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {

out
}

fn float_permute<const D: usize>(
tensor: FloatTensor<Self, D>,
axes: [usize; D],
) -> FloatTensor<Self, D> {
#[derive(new)]
struct PermuteDimsOps<const D: usize> {
desc: PermuteOperationDescription,
}

impl<const D: usize, B: FusionBackend> Operation<B> for PermuteDimsOps<D> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
let input = handles.get_float_tensor::<D>(&self.desc.input);
let axes: [usize; D] = self.desc.axes.try_into().unwrap();
let output = B::float_permute(input, axes);
handles.register_float_tensor(&self.desc.out.id, output);
}
}

let stream = tensor.stream;

// The output shape is the input shape permuted by `axes`.
let shape = axes.into_iter().map(|x| tensor.shape[x]).collect();

let out = tensor.client.tensor_uninitialized(shape);

let desc = PermuteOperationDescription {
input: tensor.into_description(),
axes: axes.to_vec(),
out: out.to_description_out(),
};

out.client.register(
vec![stream],
OperationDescription::BaseFloat(BaseOperationDescription::Permute(desc.clone())),
PermuteDimsOps::<D>::new(desc),
);

out
}
}
46 changes: 43 additions & 3 deletions crates/burn-fusion/src/ops/int.rs
@@ -8,9 +8,9 @@ use crate::{
self, BaseOperationDescription, BinaryOperationDescription, CatOperationDescription,
ClampOperationDescription, GatherOperationDescription, MaskFillOperationDescription,
MaskWhereOperationDescription, NumericOperationDescription, Operation,
OperationDescription, RandomOperationDescription, ReduceDimWithIndicesDescription,
ReshapeDescription, ScalarOperationDescription, ScatterOperationDescription,
SelectAssignOperationDescription, SelectOperationDescription,
OperationDescription, PermuteOperationDescription, RandomOperationDescription,
ReduceDimWithIndicesDescription, ReshapeDescription, ScalarOperationDescription,
ScatterOperationDescription, SelectAssignOperationDescription, SelectOperationDescription,
SliceAssignOperationDescription, SliceOperationDescription, StreamId, SwapDimsDescription,
UnaryOperationDescription,
},
@@ -1471,4 +1471,44 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {

out
}

fn int_permute<const D: usize>(
tensor: IntTensor<Self, D>,
axes: [usize; D],
) -> IntTensor<Self, D> {
#[derive(new)]
struct PermuteDimsOps<const D: usize> {
desc: PermuteOperationDescription,
}

impl<const D: usize, B: FusionBackend> Operation<B> for PermuteDimsOps<D> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
let input = handles.get_int_tensor::<D>(&self.desc.input);
let axes: [usize; D] = self.desc.axes.try_into().unwrap();
let output = B::int_permute(input, axes);
handles.register_int_tensor(&self.desc.out.id, output);
}
}

let stream = tensor.stream;

// The output shape is the input shape permuted by `axes`.
let shape = axes.into_iter().map(|x| tensor.shape[x]).collect();

let out = tensor.client.tensor_uninitialized(shape);

let desc = PermuteOperationDescription {
input: tensor.into_description(),
axes: axes.to_vec(),
out: out.to_description_out(),
};

out.client.register(
vec![stream],
OperationDescription::BaseInt(BaseOperationDescription::Permute(desc.clone())),
PermuteDimsOps::<D>::new(desc),
);

out
}
}