Skip to content

Commit f7639bd

Browse files
authored
Repeat operation (#2090)
* renaming repeat to repeat_dim * implementing repeat function * renaming repeat files to repeat_dim * renaming part 2 * renaming part 3 * renaming part 4 * renaming part 5 * adding test file * adding unit test * adding rust book documentation * adding function args doc * fixing tests * changing repeat api to match pytorch equivalent * fixing clippy error
1 parent bb13729 commit f7639bd

File tree

40 files changed

+478
-174
lines changed

40 files changed

+478
-174
lines changed

burn-book/src/building-blocks/tensor.md

Lines changed: 35 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -131,40 +131,41 @@ for the sake of simplicity, we ignore type signatures. For more details, refer t
131131

132132
Those operations are available for all tensor kinds: `Int`, `Float`, and `Bool`.
133133

134-
| Burn | PyTorch Equivalent |
135-
| ------------------------------------- | ------------------------------------ |
136-
| `Tensor::cat(tensors, dim)` | `torch.cat(tensors, dim)` |
137-
| `Tensor::empty(shape, device)` | `torch.empty(shape, device=device)` |
138-
| `Tensor::from_primitive(primitive)` | N/A |
139-
| `Tensor::stack(tensors, dim)` | `torch.stack(tensors, dim)` |
140-
| `tensor.all()` | `tensor.all()` |
141-
| `tensor.all_dim(dim)` | `tensor.all(dim)` |
142-
| `tensor.any()` | `tensor.any()` |
143-
| `tensor.any_dim(dim)` | `tensor.any(dim)` |
144-
| `tensor.chunk(num_chunks, dim)` | `tensor.chunk(num_chunks, dim)` |
145-
| `tensor.device()` | `tensor.device` |
146-
| `tensor.dims()` | `tensor.size()` |
147-
| `tensor.equal(other)` | `x == y` |
148-
| `tensor.expand(shape)` | `tensor.expand(shape)` |
149-
| `tensor.flatten(start_dim, end_dim)` | `tensor.flatten(start_dim, end_dim)` |
150-
| `tensor.flip(axes)` | `tensor.flip(axes)` |
151-
| `tensor.into_data()` | N/A |
152-
| `tensor.into_primitive()` | N/A |
153-
| `tensor.into_scalar()` | `tensor.item()` |
154-
| `tensor.narrow(dim, start, length)` | `tensor.narrow(dim, start, length)` |
155-
| `tensor.not_equal(other)` | `x != y` |
156-
| `tensor.permute(axes)` | `tensor.permute(axes)` |
157-
| `tensor.movedim(src, dst)` | `tensor.movedim(src, dst)` |
158-
| `tensor.repeat(2, 4)` | `tensor.repeat([1, 1, 4])` |
159-
| `tensor.reshape(shape)` | `tensor.view(shape)` |
160-
| `tensor.shape()` | `tensor.shape` |
161-
| `tensor.slice(ranges)` | `tensor[(*ranges,)]` |
162-
| `tensor.slice_assign(ranges, values)` | `tensor[(*ranges,)] = values` |
163-
| `tensor.squeeze(dim)` | `tensor.squeeze(dim)` |
164-
| `tensor.to_data()` | N/A |
165-
| `tensor.to_device(device)` | `tensor.to(device)` |
166-
| `tensor.unsqueeze()` | `tensor.unsqueeze(0)` |
167-
| `tensor.unsqueeze_dim(dim)` | `tensor.unsqueeze(dim)` |
134+
| Burn | PyTorch Equivalent |
135+
| ------------------------------------- | ------------------------------------------------------------------------ |
136+
| `Tensor::cat(tensors, dim)` | `torch.cat(tensors, dim)` |
137+
| `Tensor::empty(shape, device)` | `torch.empty(shape, device=device)` |
138+
| `Tensor::from_primitive(primitive)` | N/A |
139+
| `Tensor::stack(tensors, dim)` | `torch.stack(tensors, dim)` |
140+
| `tensor.all()` | `tensor.all()` |
141+
| `tensor.all_dim(dim)` | `tensor.all(dim)` |
142+
| `tensor.any()` | `tensor.any()` |
143+
| `tensor.any_dim(dim)` | `tensor.any(dim)` |
144+
| `tensor.chunk(num_chunks, dim)` | `tensor.chunk(num_chunks, dim)` |
145+
| `tensor.device()` | `tensor.device` |
146+
| `tensor.dims()` | `tensor.size()` |
147+
| `tensor.equal(other)` | `x == y` |
148+
| `tensor.expand(shape)` | `tensor.expand(shape)` |
149+
| `tensor.flatten(start_dim, end_dim)` | `tensor.flatten(start_dim, end_dim)` |
150+
| `tensor.flip(axes)` | `tensor.flip(axes)` |
151+
| `tensor.into_data()` | N/A |
152+
| `tensor.into_primitive()` | N/A |
153+
| `tensor.into_scalar()` | `tensor.item()` |
154+
| `tensor.narrow(dim, start, length)` | `tensor.narrow(dim, start, length)` |
155+
| `tensor.not_equal(other)` | `x != y` |
156+
| `tensor.permute(axes)` | `tensor.permute(axes)` |
157+
| `tensor.movedim(src, dst)` | `tensor.movedim(src, dst)` |
158+
| `tensor.repeat_dim(dim, times)` | `tensor.repeat(*[times if i == dim else 1 for i in range(tensor.dim())])`|
159+
| `tensor.repeat(sizes)` | `tensor.repeat(sizes)` |
160+
| `tensor.reshape(shape)` | `tensor.view(shape)` |
161+
| `tensor.shape()` | `tensor.shape` |
162+
| `tensor.slice(ranges)` | `tensor[(*ranges,)]` |
163+
| `tensor.slice_assign(ranges, values)` | `tensor[(*ranges,)] = values` |
164+
| `tensor.squeeze(dim)` | `tensor.squeeze(dim)` |
165+
| `tensor.to_data()` | N/A |
166+
| `tensor.to_device(device)` | `tensor.to(device)` |
167+
| `tensor.unsqueeze()` | `tensor.unsqueeze(0)` |
168+
| `tensor.unsqueeze_dim(dim)` | `tensor.unsqueeze(dim)` |
168169

169170
### Numeric Operations
170171

crates/burn-autodiff/src/ops/bool_tensor.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -132,11 +132,11 @@ impl<B: Backend, C: CheckpointStrategy> BoolTensorOps<Self> for Autodiff<B, C> {
132132
B::bool_expand(tensor, shape)
133133
}
134134

135-
fn bool_repeat<const D: usize>(
135+
fn bool_repeat_dim<const D: usize>(
136136
tensor: BoolTensor<B, D>,
137137
dim: usize,
138138
times: usize,
139139
) -> BoolTensor<B, D> {
140-
B::bool_repeat(tensor, dim, times)
140+
B::bool_repeat_dim(tensor, dim, times)
141141
}
142142
}

crates/burn-autodiff/src/ops/int_tensor.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -162,12 +162,12 @@ impl<B: Backend, C: CheckpointStrategy> IntTensorOps<Self> for Autodiff<B, C> {
162162
B::int_mean_dim(tensor, dim)
163163
}
164164

165-
fn int_repeat<const D: usize>(
165+
fn int_repeat_dim<const D: usize>(
166166
tensor: IntTensor<B, D>,
167167
dim: usize,
168168
times: usize,
169169
) -> IntTensor<B, D> {
170-
B::int_repeat(tensor, dim, times)
170+
B::int_repeat_dim(tensor, dim, times)
171171
}
172172

173173
fn int_greater<const D: usize>(lhs: IntTensor<B, D>, rhs: IntTensor<B, D>) -> BoolTensor<B, D> {

crates/burn-autodiff/src/ops/tensor.rs

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2418,7 +2418,7 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
24182418
B::float_argsort(tensor.primitive, dim, descending)
24192419
}
24202420

2421-
fn float_repeat<const D: usize>(
2421+
fn float_repeat_dim<const D: usize>(
24222422
tensor: FloatTensor<Self, D>,
24232423
dim: usize,
24242424
times: usize,
@@ -2437,7 +2437,7 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
24372437
impl<B: Backend, const D: usize> RetroForward for RetroRepeat<B, D> {
24382438
fn forward(&self, states: &mut BackwardStates, out_node: NodeID) {
24392439
let tensor = states.get_state::<B::FloatTensorPrimitive<D>>(&self.tensor_id);
2440-
let out = B::float_repeat(tensor, self.dim, self.times);
2440+
let out = B::float_repeat_dim(tensor, self.dim, self.times);
24412441
states.save(out_node, out)
24422442
}
24432443
}
@@ -2467,9 +2467,11 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
24672467
.stateful()
24682468
{
24692469
OpsKind::Tracked(prep) => {
2470-
prep.finish(dim, B::float_repeat(tensor.primitive, dim, times))
2470+
prep.finish(dim, B::float_repeat_dim(tensor.primitive, dim, times))
2471+
}
2472+
OpsKind::UnTracked(prep) => {
2473+
prep.finish(B::float_repeat_dim(tensor.primitive, dim, times))
24712474
}
2472-
OpsKind::UnTracked(prep) => prep.finish(B::float_repeat(tensor.primitive, dim, times)),
24732475
}
24742476
}
24752477

crates/burn-autodiff/src/tests/mod.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ mod permute;
4747
mod pow;
4848
mod recip;
4949
mod relu;
50-
mod repeat;
50+
mod repeat_dim;
5151
mod reshape;
5252
mod select;
5353
mod sigmoid;
@@ -133,6 +133,6 @@ macro_rules! testgen_all {
133133
burn_autodiff::testgen_ad_sign!();
134134
burn_autodiff::testgen_ad_expand!();
135135
burn_autodiff::testgen_ad_sort!();
136-
burn_autodiff::testgen_ad_repeat!();
136+
burn_autodiff::testgen_ad_repeat_dim!();
137137
};
138138
}

crates/burn-autodiff/src/tests/repeat.rs renamed to crates/burn-autodiff/src/tests/repeat_dim.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
#[burn_tensor_testgen::testgen(ad_repeat)]
1+
#[burn_tensor_testgen::testgen(ad_repeat_dim)]
22
mod tests {
33
use super::*;
44
use burn_tensor::{activation, TensorData};
@@ -12,7 +12,7 @@ mod tests {
1212
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
1313
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
1414

15-
let tensor_3 = tensor_2.clone().repeat(1, 3);
15+
let tensor_3 = tensor_2.clone().repeat_dim(1, 3);
1616

1717
let tensor_3 = tensor_1.matmul(tensor_3);
1818
let grads = tensor_3.backward();

crates/burn-candle/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ mod tests {
9494
// burn_tensor::testgen_powf!();
9595

9696
burn_tensor::testgen_random!();
97-
burn_tensor::testgen_repeat!();
97+
burn_tensor::testgen_repeat_dim!();
9898
burn_tensor::testgen_reshape!();
9999
burn_tensor::testgen_select!();
100100
burn_tensor::testgen_sin!();

crates/burn-core/src/nn/attention/mask.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ pub fn generate_autoregressive_mask<B: Backend>(
1919
mask = mask.slice_assign([0..1, i..i + 1, i + 1..seq_length], values);
2020
}
2121

22-
mask = mask.repeat(0, batch_size);
22+
mask = mask.repeat_dim(0, batch_size);
2323

2424
mask.equal_elem(1_i64.elem::<i64>())
2525
}

crates/burn-core/src/nn/loss/cross_entropy.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ impl<B: Backend> CrossEntropyLoss<B> {
152152
* weights
153153
.clone()
154154
.reshape([1, nr_classes])
155-
.repeat(0, batch_size);
155+
.repeat_dim(0, batch_size);
156156
let weights = weights.clone().gather(0, targets);
157157
let tensor = Self::apply_mask_2d(tensor, mask);
158158
tensor.sum().neg() / weights.sum()
@@ -224,7 +224,7 @@ impl<B: Backend> CrossEntropyLoss<B> {
224224
fn apply_mask_2d(mut tensor: Tensor<B, 2>, mask: Option<Tensor<B, 1, Bool>>) -> Tensor<B, 2> {
225225
if let Some(mask) = mask {
226226
let [batch_size, nr_classes] = tensor.dims();
227-
tensor = tensor.mask_fill(mask.reshape([batch_size, 1]).repeat(1, nr_classes), 0);
227+
tensor = tensor.mask_fill(mask.reshape([batch_size, 1]).repeat_dim(1, nr_classes), 0);
228228
}
229229

230230
tensor
@@ -312,7 +312,7 @@ mod tests {
312312
* targets_logits
313313
* Tensor::<TestBackend, 1>::from_floats(weights.as_slice(), &device)
314314
.unsqueeze()
315-
.repeat(0, 4);
315+
.repeat_dim(0, 4);
316316
let loss_2 = loss_2.sum().neg() / (1. + 2. + 3. + 5.);
317317
loss_1.into_data().assert_approx_eq(&loss_2.into_data(), 3);
318318
}

crates/burn-core/src/nn/rope_encoding.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ impl RotaryEncodingConfig {
5858
.float()
5959
.unsqueeze()
6060
.transpose()
61-
.repeat(1, self.d_model / 2)
61+
.repeat_dim(1, self.d_model / 2)
6262
* theta_i.unsqueeze();
6363

6464
// Convert frequency values to complex numbers (polar form)
@@ -71,7 +71,7 @@ impl RotaryEncodingConfig {
7171
.reshape([self.max_sequence_length, 2, self.d_model / 2])
7272
.transpose()
7373
.unsqueeze_dim::<4>(2)
74-
.repeat(2, 2)
74+
.repeat_dim(2, 2)
7575
.reshape([self.max_sequence_length, self.d_model, 2]);
7676

7777
RotaryEncoding {

0 commit comments

Comments
 (0)