|
1 |
| -use cubecl::{calculate_cube_count_elemwise, prelude::*}; |
| 1 | +use cubecl::{calculate_cube_count_elemwise, linalg::tensor::StridedLayout, prelude::*}; |
| 2 | +use cubecl_std::FastDivmod; |
2 | 3 |
|
3 |
| -use crate::{CubeRuntime, FloatElement, tensor::CubeTensor}; |
| 4 | +use crate::{ |
| 5 | + CubeRuntime, FloatElement, |
| 6 | + kernel::utils::{shape_divmod, strided_layout}, |
| 7 | + ops::max_line_size, |
| 8 | + tensor::CubeTensor, |
| 9 | +}; |
4 | 10 |
|
5 | 11 | #[cube(launch)]
|
6 |
| -fn interpolate_bicubic_kernel<F: Float>(input: &Tensor<F>, output: &mut Tensor<F>) { |
| 12 | +fn interpolate_bicubic_kernel<F: Float>( |
| 13 | + input: &Tensor<Line<F>>, |
| 14 | + output: &mut Tensor<Line<F>>, |
| 15 | + shape_out: Sequence<FastDivmod>, |
| 16 | + out_layout: StridedLayout, |
| 17 | +) { |
7 | 18 | if ABSOLUTE_POS >= output.len() {
|
8 | 19 | terminate!();
|
9 | 20 | }
|
10 | 21 |
|
11 |
| - let batch = ABSOLUTE_POS / output.stride(0) % output.shape(0); |
12 |
| - let channel = ABSOLUTE_POS / output.stride(1) % output.shape(1); |
13 |
| - let y = ABSOLUTE_POS / output.stride(2) % output.shape(2); |
14 |
| - let x = ABSOLUTE_POS / output.stride(3) % output.shape(3); |
| 22 | + let line_size = input.line_size(); |
| 23 | + let out_idx = out_layout.index(output, ABSOLUTE_POS); |
15 | 24 |
|
16 |
| - let input_height = input.shape(2) - 1; |
17 |
| - let output_height = F::cast_from(Max::max(output.shape(2) - 1, 1)); |
18 |
| - let numerator = F::cast_from(y * input_height); |
| 25 | + let (rem, c) = shape_out.index(3).div_mod(ABSOLUTE_POS * line_size); |
| 26 | + let (rem, x) = shape_out.index(2).div_mod(rem); |
| 27 | + let (b, y) = shape_out.index(1).div_mod(rem); |
19 | 28 |
|
20 |
| - let frac = numerator / output_height; |
| 29 | + let input_height = input.shape(1) - 1; |
| 30 | + let output_height = f32::cast_from(Max::max(output.shape(1) - 1, 1)); |
| 31 | + let numerator = f32::cast_from(y * input_height); |
| 32 | + |
| 33 | + let frac = f32::cast_from(numerator / output_height); |
21 | 34 | let y_in_f = Floor::floor(frac);
|
22 | 35 | let y_in = u32::cast_from(y_in_f);
|
23 |
| - let yw = frac - y_in_f; |
| 36 | + let yw = Line::empty(line_size).fill(F::cast_from(frac - y_in_f)); |
24 | 37 |
|
25 | 38 | let y0 = select(y_in != 0, y_in - 1, 0);
|
26 | 39 | let y1 = y_in;
|
27 | 40 | let y2 = Min::min(y_in + 1, input_height);
|
28 | 41 | let y3 = Min::min(y_in + 2, input_height);
|
29 | 42 |
|
30 |
| - let input_width = input.shape(3) - 1; |
31 |
| - let output_width = F::cast_from(Max::max(output.shape(3) - 1, 1)); |
32 |
| - let numerator = F::cast_from(x * input_width); |
| 43 | + let input_width = input.shape(2) - 1; |
| 44 | + let output_width = f32::cast_from(Max::max(output.shape(2) - 1, 1)); |
| 45 | + let numerator = f32::cast_from(x * input_width); |
33 | 46 | let frac = numerator / output_width;
|
34 | 47 | let x_in_f = Floor::floor(frac);
|
35 | 48 | let x_in = u32::cast_from(x_in_f);
|
36 |
| - let xw = frac - x_in_f; |
| 49 | + let xw = Line::empty(line_size).fill(F::cast_from(frac - x_in_f)); |
37 | 50 |
|
38 | 51 | let x0 = select(x_in != 0, x_in - 1, 0);
|
39 | 52 | let x1 = x_in;
|
40 | 53 | let x2 = Min::min(x_in + 1, input_width);
|
41 | 54 | let x3 = Min::min(x_in + 2, input_width);
|
42 | 55 |
|
43 |
| - let index_base = batch * input.stride(0) + channel * input.stride(1); |
44 |
| - let in_stride_y = input.stride(2); |
45 |
| - let in_stride_x = input.stride(3); |
| 56 | + let index_base = b * input.stride(0) + c * input.stride(3); |
| 57 | + let in_stride_y = input.stride(1); |
| 58 | + let in_stride_x = input.stride(2); |
46 | 59 |
|
47 | 60 | let y0_stride = y0 * in_stride_y;
|
48 | 61 | let y1_stride = y1 * in_stride_y;
|
@@ -89,51 +102,69 @@ fn interpolate_bicubic_kernel<F: Float>(input: &Tensor<F>, output: &mut Tensor<F
|
89 | 102 | yw,
|
90 | 103 | );
|
91 | 104 |
|
92 |
| - output[ABSOLUTE_POS] = val; |
| 105 | + output[out_idx] = val; |
93 | 106 | }
|
94 | 107 |
|
95 | 108 | #[cube]
|
96 |
| -fn cubic_interp_1d<F: Float>(x0: F, x1: F, x2: F, x3: F, t: F) -> F { |
97 |
| - let a = F::new(-0.75); |
98 |
| - |
99 |
| - let coeffs0 = cubic_convolution_2::<F>(t + F::new(1.0), a); |
| 109 | +fn cubic_interp_1d<F: Float>( |
| 110 | + x0: Line<F>, |
| 111 | + x1: Line<F>, |
| 112 | + x2: Line<F>, |
| 113 | + x3: Line<F>, |
| 114 | + t: Line<F>, |
| 115 | +) -> Line<F> { |
| 116 | + let a = lined(&x0, -0.75); |
| 117 | + |
| 118 | + let coeffs0 = cubic_convolution_2::<F>(t + lined(&x0, 1.0), a); |
100 | 119 | let coeffs1 = cubic_convolution_1::<F>(t, a);
|
101 |
| - let coeffs2 = cubic_convolution_1::<F>(F::new(1.0) - t, a); |
102 |
| - let coeffs3 = cubic_convolution_2::<F>(F::new(2.0) - t, a); |
| 120 | + let coeffs2 = cubic_convolution_1::<F>(lined(&x0, 1.0) - t, a); |
| 121 | + let coeffs3 = cubic_convolution_2::<F>(lined(&x0, 2.0) - t, a); |
103 | 122 |
|
104 | 123 | x0 * coeffs0 + x1 * coeffs1 + x2 * coeffs2 + x3 * coeffs3
|
105 | 124 | }
|
106 | 125 |
|
107 | 126 | #[cube]
|
108 |
| -fn cubic_convolution_1<F: Float>(x: F, a: F) -> F { |
109 |
| - let conv = (a + F::new(2.0)) * x; |
110 |
| - let tmp = a + F::new(3.0); |
111 |
| - (conv - tmp) * x * x + F::new(1.0) |
| 127 | +fn cubic_convolution_1<F: Float>(x: Line<F>, a: Line<F>) -> Line<F> { |
| 128 | + let conv = (a + lined(&x, 2.0)) * x; |
| 129 | + let tmp = a + lined(&x, 3.0); |
| 130 | + (conv - tmp) * x * x + lined(&x, 1.0) |
112 | 131 | }
|
113 | 132 |
|
114 | 133 | #[cube]
|
115 |
| -fn cubic_convolution_2<F: Float>(x: F, a: F) -> F { |
| 134 | +fn cubic_convolution_2<F: Float>(x: Line<F>, a: Line<F>) -> Line<F> { |
116 | 135 | let conv = a * x;
|
117 |
| - let conv = (conv - F::new(5.0) * a) * x; |
118 |
| - let tmp = F::new(8.0) * a; |
| 136 | + let conv = (conv - lined(&x, 5.0) * a) * x; |
| 137 | + let tmp = lined(&x, 8.0) * a; |
119 | 138 | let conv = (conv + tmp) * x;
|
120 | 139 |
|
121 |
| - conv - F::new(4.0) * a |
| 140 | + conv - lined(&x, 4.0) * a |
| 141 | +} |
| 142 | + |
/// Builds a `Line` with the same line size as `x`, every lane filled with
/// the comptime constant `v` converted to `F`. Used to broadcast scalar
/// kernel coefficients across a vectorized line.
#[cube]
fn lined<F: Float>(x: &Line<F>, #[comptime] v: f32) -> Line<F> {
    Line::empty(x.size()).fill(F::new(v))
}
|
123 | 147 |
|
124 | 148 | pub(crate) fn interpolate_bicubic_launch<R: CubeRuntime, E: FloatElement>(
|
125 | 149 | input: CubeTensor<R>,
|
126 | 150 | output: CubeTensor<R>,
|
127 | 151 | ) -> CubeTensor<R> {
|
| 152 | + let line_size = max_line_size(&input); |
| 153 | + let out_shape = shape_divmod(&output); |
| 154 | + let out_layout = strided_layout(&output); |
| 155 | + |
128 | 156 | let cube_dim = CubeDim::default();
|
129 |
| - let cube_count = calculate_cube_count_elemwise(output.shape.num_elements(), cube_dim); |
| 157 | + let cube_count = |
| 158 | + calculate_cube_count_elemwise(output.shape.num_elements() / line_size as usize, cube_dim); |
130 | 159 |
|
131 | 160 | interpolate_bicubic_kernel::launch::<E, R>(
|
132 | 161 | &input.client,
|
133 | 162 | cube_count,
|
134 | 163 | cube_dim,
|
135 |
| - input.as_tensor_arg::<E>(1), |
136 |
| - output.as_tensor_arg::<E>(1), |
| 164 | + input.as_tensor_arg::<E>(line_size), |
| 165 | + output.as_tensor_arg::<E>(line_size), |
| 166 | + out_shape, |
| 167 | + out_layout, |
137 | 168 | );
|
138 | 169 |
|
139 | 170 | output
|
|
0 commit comments