Commit 6d0db87

[Perf] Interpolate optimizations (#3077)
* Refactor interpolate to use NHWC and fix OOB issue
* Fix slice negative range
* Fix review comment
* Add regression test to `is_contiguous`
1 parent 4a28812 commit 6d0db87

File tree

23 files changed: +349 -234 lines changed

crates/burn-cubecl/src/kernel/conv/conv2d/implicit_gemm/launch.rs

Lines changed: 4 additions & 4 deletions
@@ -18,7 +18,7 @@ use cubecl::linalg::{
 
 use crate::{
     CubeElement, CubeRuntime, FloatElement,
-    ops::{numeric::empty_device_strided, permute},
+    ops::{numeric::empty_device_strided, permute_nchw_to_nhwc, permute_nhwc_to_nchw},
     tensor::CubeTensor,
 };
 
@@ -95,8 +95,8 @@ where
         width,
     );
 
-    let input = permute(input, &[0, 2, 3, 1]);
-    let weight = permute(weight, &[0, 2, 3, 1]);
+    let input = permute_nchw_to_nhwc(input);
+    let weight = permute_nchw_to_nhwc(weight);
 
     let out_shape = Shape::new([batch_size, out_h, out_w, out_channels]);
     let out =
@@ -117,5 +117,5 @@ where
         },
     )?;
 
-    Ok(permute(out, &[0, 3, 1, 2]))
+    Ok(permute_nhwc_to_nchw(out))
 }
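
For reference, the `&[0, 2, 3, 1]` permutation being replaced here is exactly the NCHW -> NHWC axis swap, and `&[0, 3, 1, 2]` is its inverse. A minimal CPU sketch of the same data movement (this standalone `nchw_to_nhwc_cpu` helper is illustrative only, not the crate's GPU implementation):

    /// Copy a contiguous NCHW buffer into NHWC order, the layout the
    /// conv/interpolate kernels now operate on.
    fn nchw_to_nhwc_cpu(input: &[f32], n: usize, c: usize, h: usize, w: usize) -> Vec<f32> {
        let mut out = vec![0.0; input.len()];
        for b in 0..n {
            for ch in 0..c {
                for y in 0..h {
                    for x in 0..w {
                        let src = ((b * c + ch) * h + y) * w + x; // NCHW offset
                        let dst = ((b * h + y) * w + x) * c + ch; // NHWC offset
                        out[dst] = input[src];
                    }
                }
            }
        }
        out
    }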

crates/burn-cubecl/src/kernel/conv/conv2d/layout_swap.rs

Lines changed: 2 additions & 26 deletions
@@ -3,8 +3,7 @@ use cubecl::{CubeCount, CubeDim, prelude::*};
 
 use crate::{
     CubeElement, CubeRuntime,
-    kernel::into_contiguous,
-    ops::{max_vectorization, numeric::empty_device_strided},
+    ops::{max_line_size, numeric::empty_device_strided},
     tensor::CubeTensor,
 };
 
@@ -55,7 +54,7 @@ pub fn nchw_to_nhwc<R: CubeRuntime, E: CubeElement>(input: CubeTensor<R>) -> CubeTensor<R> {
     };
     let cube_count = CubeCount::Static(cube_count_x, cube_count_y, cube_count_z);
 
-    let in_vec = max_vectorization(&input);
+    let in_vec = max_line_size(&input);
     let out_vec = R::supported_line_sizes()
         .iter()
         .copied()
@@ -201,26 +200,3 @@ pub fn swizzle(offset: u32, #[comptime] bank_count: i32) -> u32 {
 
     offset ^ ((offset & yyy_mask) >> mask_shift)
 }
-
-/// Transpose an NCHW tensor to NHWC.
-pub fn permute_nchw_to_nhwc<R: CubeRuntime, E: CubeElement>(input: CubeTensor<R>) -> CubeTensor<R> {
-    // Disabled for now, need to fix hard to track down bug
-
-    // let use_plane = if cfg!(target_family = "wasm") {
-    //     // Any plane op enables subgroups on wasm so we need to make sure it is entirely supported.
-    //     // Otherwise, `nchw_to_nhwc` will cause compilation issues on wasm.
-    //     input
-    //         .client
-    //         .properties()
-    //         .feature_enabled(cubecl::Feature::Plane)
-    // } else {
-    //     true // plane broadcast was/is always used on other platforms
-    // };
-
-    // if input.is_contiguous() && use_plane {
-    //     nchw_to_nhwc::<R, E>(input)
-    // } else {
-    //     crate::ops::permute(input, &[0, 2, 3, 1])
-    // }
-    into_contiguous(crate::ops::permute(input, &[0, 2, 3, 1]))
-}
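
`max_line_size` (renamed from `max_vectorization`, presumably to match cubecl's `Line` terminology) reports the widest vector width usable for a tensor. A plausible standalone sketch of that selection, assuming it picks the largest supported width that evenly divides the innermost contiguous dimension — the real implementation lives in `crate::ops`, so `max_line_size_for` below is a hypothetical stand-in:

    /// Pick the widest line (vector) width that divides the innermost dim.
    /// `supported` would come from something like `R::supported_line_sizes()`.
    fn max_line_size_for(supported: &[u8], innermost_dim: usize) -> u8 {
        supported
            .iter()
            .copied()
            .filter(|&width| innermost_dim % width as usize == 0)
            .max()
            .unwrap_or(1)
    }

    // Example: with supported widths [1, 2, 4, 8] and 64 channels, this picks 8.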

crates/burn-cubecl/src/kernel/conv/mod.rs

Lines changed: 1 addition & 4 deletions
@@ -10,7 +10,4 @@ pub(crate) use conv3d::*;
 pub(crate) use deform_conv_transpose2d::*;
 pub(crate) use deform_conv2d::*;
 
-pub use conv2d::{
-    Conv2dStrategy, ConvTranspose2dStrategy, conv_transpose2d, conv2d, nchw_to_nhwc,
-    permute_nchw_to_nhwc,
-};
+pub use conv2d::{Conv2dStrategy, ConvTranspose2dStrategy, conv_transpose2d, conv2d, nchw_to_nhwc};
crates/burn-cubecl/src/kernel/interpolate/base.rs

Lines changed: 20 additions & 18 deletions
@@ -1,5 +1,6 @@
 use crate::{
-    CubeRuntime, FloatElement, kernel::into_contiguous, ops::numeric::empty_device,
+    CubeRuntime, FloatElement,
+    ops::{numeric::empty_device_strided, permute_nchw_to_nhwc, permute_nhwc_to_nchw},
     tensor::CubeTensor,
 };
 use burn_tensor::{
@@ -20,18 +21,22 @@ pub fn interpolate<R: CubeRuntime, E: FloatElement>(
     output_size: [usize; 2],
     options: InterpolateOptions,
 ) -> CubeTensor<R> {
-    let input = into_contiguous(input);
     let [batch_size, channels, _, _] = input.shape.dims();
     let [out_height, out_width] = output_size;
 
-    let shape_out = Shape::new([batch_size, channels, out_height, out_width]);
-    let output = empty_device::<R, E>(input.client.clone(), input.device.clone(), shape_out);
+    let input = permute_nchw_to_nhwc(input);
 
-    match options.mode {
+    let shape_out = Shape::new([batch_size, out_height, out_width, channels]);
+    let output =
+        empty_device_strided::<R, E>(input.client.clone(), input.device.clone(), shape_out);
+
+    let output = match options.mode {
         InterpolateMode::Nearest => interpolate_nearest_launch::<R, E>(input, output),
         InterpolateMode::Bilinear => interpolate_bilinear_launch::<R, E>(input, output),
         InterpolateMode::Bicubic => interpolate_bicubic_launch::<R, E>(input, output),
-    }
+    };
+
+    permute_nhwc_to_nchw(output)
 }
 
 /// Backward interpolate operation
@@ -43,25 +48,22 @@ pub fn interpolate_backward<R: CubeRuntime, E: FloatElement>(
     _output_size: [usize; 2],
     options: InterpolateOptions,
 ) -> CubeTensor<R> {
-    let out_grad = into_contiguous(out_grad);
+    let input = permute_nchw_to_nhwc(input);
+    let out_grad = permute_nchw_to_nhwc(out_grad);
+
     let output_shape = input.shape.clone();
-    let num_elems = input.shape.num_elements();
-    let buffer = input.client.empty(num_elems * core::mem::size_of::<E>());
-    let output = CubeTensor::new_contiguous(
-        input.client.clone(),
-        input.device.clone(),
-        output_shape,
-        buffer,
-        input.dtype,
-    );
+    let output =
+        empty_device_strided::<R, E>(input.client.clone(), input.device.clone(), output_shape);
 
-    match options.mode {
+    let output = match options.mode {
         InterpolateMode::Nearest => interpolate_nearest_backward_launch::<R, E>(out_grad, output),
         InterpolateMode::Bilinear => {
             panic!("bilinear interpolation backward is not supported by JIT backend")
         }
         InterpolateMode::Bicubic => {
             panic!("bicubic interpolation backward is not supported by JIT backend")
         }
-    }
+    };
+
+    permute_nhwc_to_nchw(output)
 }
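
The forward and backward paths now share one pattern: permute the NCHW inputs to NHWC, allocate the output with `empty_device_strided`, launch the kernel, and permute the result back so callers still see NCHW. The payoff is that in contiguous NHWC the stride-1 axis is the channel axis, so one vectorized load or store covers adjacent channels of a single pixel instead of striding across the whole image. A small standalone check of the two layouts (`contiguous_strides` is a hypothetical helper, not from the crate):

    /// Row-major strides for a contiguous tensor of the given shape.
    fn contiguous_strides(shape: &[usize]) -> Vec<usize> {
        let mut strides = vec![1; shape.len()];
        for i in (0..shape.len() - 1).rev() {
            strides[i] = strides[i + 1] * shape[i + 1];
        }
        strides
    }

    fn main() {
        let (n, c, h, w) = (2, 64, 32, 32);
        // NCHW: one pixel's channels sit 1024 elements apart, so per-pixel
        // channel access strides across the image.
        assert_eq!(contiguous_strides(&[n, c, h, w]), [65536, 1024, 32, 1]);
        // NHWC: one pixel's channels are adjacent, ideal for `Line<F>` access.
        assert_eq!(contiguous_strides(&[n, h, w, c]), [65536, 2048, 64, 1]);
    }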

crates/burn-cubecl/src/kernel/interpolate/bicubic.rs

Lines changed: 68 additions & 37 deletions
@@ -1,48 +1,61 @@
-use cubecl::{calculate_cube_count_elemwise, prelude::*};
+use cubecl::{calculate_cube_count_elemwise, linalg::tensor::StridedLayout, prelude::*};
+use cubecl_std::FastDivmod;
 
-use crate::{CubeRuntime, FloatElement, tensor::CubeTensor};
+use crate::{
+    CubeRuntime, FloatElement,
+    kernel::utils::{shape_divmod, strided_layout},
+    ops::max_line_size,
+    tensor::CubeTensor,
+};
 
 #[cube(launch)]
-fn interpolate_bicubic_kernel<F: Float>(input: &Tensor<F>, output: &mut Tensor<F>) {
+fn interpolate_bicubic_kernel<F: Float>(
+    input: &Tensor<Line<F>>,
+    output: &mut Tensor<Line<F>>,
+    shape_out: Sequence<FastDivmod>,
+    out_layout: StridedLayout,
+) {
     if ABSOLUTE_POS >= output.len() {
         terminate!();
     }
 
-    let batch = ABSOLUTE_POS / output.stride(0) % output.shape(0);
-    let channel = ABSOLUTE_POS / output.stride(1) % output.shape(1);
-    let y = ABSOLUTE_POS / output.stride(2) % output.shape(2);
-    let x = ABSOLUTE_POS / output.stride(3) % output.shape(3);
+    let line_size = input.line_size();
+    let out_idx = out_layout.index(output, ABSOLUTE_POS);
 
-    let input_height = input.shape(2) - 1;
-    let output_height = F::cast_from(Max::max(output.shape(2) - 1, 1));
-    let numerator = F::cast_from(y * input_height);
+    let (rem, c) = shape_out.index(3).div_mod(ABSOLUTE_POS * line_size);
+    let (rem, x) = shape_out.index(2).div_mod(rem);
+    let (b, y) = shape_out.index(1).div_mod(rem);
 
-    let frac = numerator / output_height;
+    let input_height = input.shape(1) - 1;
+    let output_height = f32::cast_from(Max::max(output.shape(1) - 1, 1));
+    let numerator = f32::cast_from(y * input_height);
+
+    let frac = f32::cast_from(numerator / output_height);
     let y_in_f = Floor::floor(frac);
     let y_in = u32::cast_from(y_in_f);
-    let yw = frac - y_in_f;
+    let yw = Line::empty(line_size).fill(F::cast_from(frac - y_in_f));
 
     let y0 = select(y_in != 0, y_in - 1, 0);
     let y1 = y_in;
     let y2 = Min::min(y_in + 1, input_height);
     let y3 = Min::min(y_in + 2, input_height);
 
-    let input_width = input.shape(3) - 1;
-    let output_width = F::cast_from(Max::max(output.shape(3) - 1, 1));
-    let numerator = F::cast_from(x * input_width);
+    let input_width = input.shape(2) - 1;
+    let output_width = f32::cast_from(Max::max(output.shape(2) - 1, 1));
+    let numerator = f32::cast_from(x * input_width);
     let frac = numerator / output_width;
     let x_in_f = Floor::floor(frac);
     let x_in = u32::cast_from(x_in_f);
-    let xw = frac - x_in_f;
+    let xw = Line::empty(line_size).fill(F::cast_from(frac - x_in_f));
 
     let x0 = select(x_in != 0, x_in - 1, 0);
     let x1 = x_in;
    let x2 = Min::min(x_in + 1, input_width);
     let x3 = Min::min(x_in + 2, input_width);
 
-    let index_base = batch * input.stride(0) + channel * input.stride(1);
-    let in_stride_y = input.stride(2);
-    let in_stride_x = input.stride(3);
+    let index_base = b * input.stride(0) + c * input.stride(3);
+    let in_stride_y = input.stride(1);
+    let in_stride_x = input.stride(2);
 
     let y0_stride = y0 * in_stride_y;
     let y1_stride = y1 * in_stride_y;
@@ -89,51 +102,69 @@ fn interpolate_bicubic_kernel<F: Float>(input: &Tensor<F>, output: &mut Tensor<F>) {
         yw,
     );
 
-    output[ABSOLUTE_POS] = val;
+    output[out_idx] = val;
 }
 
 #[cube]
-fn cubic_interp_1d<F: Float>(x0: F, x1: F, x2: F, x3: F, t: F) -> F {
-    let a = F::new(-0.75);
-
-    let coeffs0 = cubic_convolution_2::<F>(t + F::new(1.0), a);
+fn cubic_interp_1d<F: Float>(
+    x0: Line<F>,
+    x1: Line<F>,
+    x2: Line<F>,
+    x3: Line<F>,
+    t: Line<F>,
+) -> Line<F> {
+    let a = lined(&x0, -0.75);
+
+    let coeffs0 = cubic_convolution_2::<F>(t + lined(&x0, 1.0), a);
     let coeffs1 = cubic_convolution_1::<F>(t, a);
-    let coeffs2 = cubic_convolution_1::<F>(F::new(1.0) - t, a);
-    let coeffs3 = cubic_convolution_2::<F>(F::new(2.0) - t, a);
+    let coeffs2 = cubic_convolution_1::<F>(lined(&x0, 1.0) - t, a);
+    let coeffs3 = cubic_convolution_2::<F>(lined(&x0, 2.0) - t, a);
 
     x0 * coeffs0 + x1 * coeffs1 + x2 * coeffs2 + x3 * coeffs3
 }
 
 #[cube]
-fn cubic_convolution_1<F: Float>(x: F, a: F) -> F {
-    let conv = (a + F::new(2.0)) * x;
-    let tmp = a + F::new(3.0);
-    (conv - tmp) * x * x + F::new(1.0)
+fn cubic_convolution_1<F: Float>(x: Line<F>, a: Line<F>) -> Line<F> {
+    let conv = (a + lined(&x, 2.0)) * x;
+    let tmp = a + lined(&x, 3.0);
+    (conv - tmp) * x * x + lined(&x, 1.0)
 }
 
 #[cube]
-fn cubic_convolution_2<F: Float>(x: F, a: F) -> F {
+fn cubic_convolution_2<F: Float>(x: Line<F>, a: Line<F>) -> Line<F> {
     let conv = a * x;
-    let conv = (conv - F::new(5.0) * a) * x;
-    let tmp = F::new(8.0) * a;
+    let conv = (conv - lined(&x, 5.0) * a) * x;
+    let tmp = lined(&x, 8.0) * a;
     let conv = (conv + tmp) * x;
 
-    conv - F::new(4.0) * a
+    conv - lined(&x, 4.0) * a
+}
+
+#[cube]
+fn lined<F: Float>(x: &Line<F>, #[comptime] v: f32) -> Line<F> {
+    Line::empty(x.size()).fill(F::new(v))
 }
 
 pub(crate) fn interpolate_bicubic_launch<R: CubeRuntime, E: FloatElement>(
     input: CubeTensor<R>,
     output: CubeTensor<R>,
 ) -> CubeTensor<R> {
+    let line_size = max_line_size(&input);
+    let out_shape = shape_divmod(&output);
+    let out_layout = strided_layout(&output);
+
     let cube_dim = CubeDim::default();
-    let cube_count = calculate_cube_count_elemwise(output.shape.num_elements(), cube_dim);
+    let cube_count =
+        calculate_cube_count_elemwise(output.shape.num_elements() / line_size as usize, cube_dim);
 
     interpolate_bicubic_kernel::launch::<E, R>(
         &input.client,
         cube_count,
         cube_dim,
-        input.as_tensor_arg::<E>(1),
-        output.as_tensor_arg::<E>(1),
+        input.as_tensor_arg::<E>(line_size),
+        output.as_tensor_arg::<E>(line_size),
+        out_shape,
+        out_layout,
    );
 
     output
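
Instead of deriving `(batch, channel, y, x)` from output strides with a division and a modulo per axis, the kernel now peels the flat element index apart with precomputed `FastDivmod` entries (the idea behind `FastDivmod` is to precompute per-axis division constants so the GPU avoids runtime integer division) and writes through a `StridedLayout`, which lets the output keep the strided NHWC layout produced by `empty_device_strided`. Note the chain is fed `ABSOLUTE_POS * line_size`, since each thread handles one line of channels. In plain integer arithmetic the decomposition looks like the following standalone sketch (`decompose_nhwc` is a hypothetical helper for illustration):

    /// Decompose a flat NHWC element index into (b, y, x, c),
    /// innermost axis first, mirroring the kernel's div_mod chain.
    fn decompose_nhwc(flat: usize, h: usize, w: usize, c: usize) -> (usize, usize, usize, usize) {
        let (rem, ch) = (flat / c, flat % c); // shape_out.index(3).div_mod(..)
        let (rem, x) = (rem / w, rem % w);    // shape_out.index(2).div_mod(..)
        let (b, y) = (rem / h, rem % h);      // shape_out.index(1).div_mod(..)
        (b, y, x, ch)
    }

    fn main() {
        // For shape [N=2, H=4, W=5, C=8]: 317 == ((1*4 + 3)*5 + 4)*8 + 5.
        assert_eq!(decompose_nhwc(317, 4, 5, 8), (1, 3, 4, 5));
    }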
