
Commit 9c9bf60

maxtremblaylaggui authored and committed
Refactor quantization scheme (tracel-ai#3042)
* refactor quantization scheme
* impl acc precision and output mode
* clean test
* cargo fmt
* remove unused import
* wip
* Cargo fmt
* Make it work
* Narrow, chunk and split are all high-level slice-based methods
* Add argwhere empty test
* Cleanup qtensor ops
* Better docstrings
* Remove unused
* Add propagate test example
* Add return type semantics description
* Fusion ops passthrough
* Cleanup
* Handle lhs float rhs quant for practical use cases
* Fix clippy
* Use matches
* Remove comment
* Cleaner
* Fix merged conflicts

---------

Co-authored-by: Guillaume Lagrange <lagrange.guillaume.1@gmail.com>
1 parent 38ba741 commit 9c9bf60
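
The core change, repeated across every file below, is that the enum-style `QuantizationScheme::PerTensor(mode, type)` is replaced by a struct-style `QuantScheme` with named fields (`level`, `mode`, `q_type`, plus accumulation precision and output mode). A minimal sketch, not code from this diff, of how call sites now match on the struct form; only the field names and variants that actually appear in the hunks below are assumed:

```rust
use burn_tensor::quantization::{QuantInputType, QuantLevel, QuantMode, QuantScheme};

// Sketch: match on the struct-style scheme the way the dequantize kernel below
// does, instead of the old QuantizationScheme::PerTensor(..) enum variant.
fn describe(scheme: &QuantScheme) -> &'static str {
    match scheme {
        QuantScheme {
            level: QuantLevel::Tensor,
            mode: QuantMode::Symmetric,
            q_type: QuantInputType::QInt8,
            ..
        } => "per-tensor symmetric int8",
        _ => "other scheme",
    }
}
```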

File tree

40 files changed, +1127 −805 lines changed


crates/burn-autodiff/src/ops/qtensor.rs

Lines changed: 3 additions & 3 deletions
@@ -4,7 +4,7 @@ use burn_tensor::{
     Device, Shape, TensorData,
     backend::Backend,
     ops::{FloatTensor, IntTensor, QTensorOps, QuantizedTensor},
-    quantization::{QuantizationParametersPrimitive, QuantizationScheme},
+    quantization::{QuantScheme, QuantizationParametersPrimitive},
 };
 
 use crate::{Autodiff, checkpoint::strategy::CheckpointStrategy};
@@ -16,15 +16,15 @@ impl<B: Backend, C: CheckpointStrategy> QTensorOps<Self> for Autodiff<B, C> {
 
     fn quantize(
         _tensor: FloatTensor<Self>,
-        _scheme: &QuantizationScheme,
+        _scheme: &QuantScheme,
         _qparams: QuantizationParametersPrimitive<Self>,
     ) -> QuantizedTensor<Self> {
         todo!() // required for QAT
     }
 
     fn quantize_dynamic(
         _tensor: FloatTensor<Self>,
-        _scheme: &QuantizationScheme,
+        _scheme: &QuantScheme,
     ) -> QuantizedTensor<Self> {
         todo!()
     }

crates/burn-candle/src/ops/qtensor.rs

Lines changed: 2 additions & 2 deletions
@@ -4,7 +4,7 @@ use burn_tensor::{
     DType, Device, Shape, TensorData,
     backend::Backend,
     ops::{FloatTensor, IntTensor, QTensorOps, QuantizedTensor},
-    quantization::{QuantizationParametersPrimitive, QuantizationScheme},
+    quantization::{QuantScheme, QuantizationParametersPrimitive},
 };
 
 use crate::{
@@ -19,7 +19,7 @@ impl<F: FloatCandleElement, I: IntCandleElement> QTensorOps<Self> for Candle<F,
 
     fn quantize(
         _tensor: FloatTensor<Self>,
-        _scheme: &QuantizationScheme,
+        _scheme: &QuantScheme,
         _qparams: QuantizationParametersPrimitive<Self>,
     ) -> QuantizedTensor<Self> {
         unimplemented!()

crates/burn-candle/src/tensor.rs

Lines changed: 3 additions & 3 deletions
@@ -1,6 +1,6 @@
 use burn_tensor::{
     DType, Element, Shape, TensorData, TensorMetadata,
-    quantization::{QTensorPrimitive, QuantizationScheme},
+    quantization::{QTensorPrimitive, QuantScheme},
 };
 
 use crate::{CandleDevice, element::CandleElement};
@@ -63,11 +63,11 @@ pub struct CandleQTensor {
     // NOTE: candle does not implement `WithDType` for i8
     pub qtensor: CandleTensor,
     /// The quantization scheme.
-    pub scheme: QuantizationScheme,
+    pub scheme: QuantScheme,
 }
 
 impl QTensorPrimitive for CandleQTensor {
-    fn scheme(&self) -> &QuantizationScheme {
+    fn scheme(&self) -> &QuantScheme {
         &self.scheme
     }
 }

crates/burn-core/src/module/quantize.rs

Lines changed: 2 additions & 2 deletions
@@ -1,7 +1,7 @@
 use burn_tensor::{
     Tensor,
     backend::Backend,
-    quantization::{Calibration, QuantizationScheme},
+    quantization::{Calibration, QuantScheme},
 };
 
 use crate::module::{ModuleMapper, ParamId};
@@ -11,7 +11,7 @@ pub struct Quantizer {
     /// The calibration method used in quantization.
     pub calibration: Calibration,
     /// The quantization scheme.
-    pub scheme: QuantizationScheme,
+    pub scheme: QuantScheme,
 }
 
 impl<B: Backend> ModuleMapper<B> for Quantizer {

crates/burn-cubecl-fusion/src/shared/builder.rs

Lines changed: 33 additions & 1 deletion
@@ -8,7 +8,7 @@ use burn_ir::{
     BaseOperationIr, BinaryOpIr, FloatOperationIr, NumericOperationIr, OperationIr, ScalarOpIr,
     TensorIr, UnaryOpIr,
 };
-use burn_tensor::Element;
+use burn_tensor::{DType, Element};
 use cubecl::ir::Elem;
 
 /// The base optimization builder that can be used to fuse all elemwise operations.
@@ -212,6 +212,10 @@ impl FuseOptimizationBuilder {
             return false;
         }
 
+        if self.input_is_quantized(&desc.input) {
+            return false;
+        }
+
         if self.builder.register(|build| {
             build.input_swap_dims(
                 &desc.input,
@@ -243,6 +247,10 @@ impl FuseOptimizationBuilder {
             return false;
         }
 
+        if self.input_is_quantized(&desc.input) {
+            return false;
+        }
+
         if self.builder.register(|build| {
             build.input_reshaped(&desc.input, &desc.out)?;
             Some(())
@@ -447,6 +455,10 @@ impl FuseOptimizationBuilder {
             return false;
         }
 
+        if self.input_is_quantized(&desc.tensor) {
+            return false;
+        }
+
         self.builder.register(|build| {
             let input = build.input_indexed(&desc.tensor)?;
             let indices = build.input(&desc.indices)?;
@@ -467,6 +479,10 @@ impl FuseOptimizationBuilder {
             return false;
         }
 
+        if self.input_is_quantized(&desc.tensor) {
+            return false;
+        }
+
         self.builder.register(|build| {
             let input = build.input_indexed(&desc.tensor)?;
             let indices = build.input_indexed(&desc.indices)?;
@@ -494,6 +510,10 @@ impl FuseOptimizationBuilder {
             return false;
         }
 
+        if self.input_is_quantized(&desc.lhs) {
+            return false;
+        }
+
         self.builder.register(|build| {
             let lhs = build.input(&desc.lhs)?;
             let rhs = build.input(&desc.rhs)?;
@@ -513,6 +533,10 @@ impl FuseOptimizationBuilder {
             return false;
         }
 
+        if self.input_is_quantized(&desc.input) {
+            return false;
+        }
+
         self.builder.register(|build| {
             let input = build.input(&desc.input)?;
             let out = build.output(&desc.out)?;
@@ -529,6 +553,10 @@ impl FuseOptimizationBuilder {
             return false;
         }
 
+        if self.input_is_quantized(&desc.lhs) {
+            return false;
+        }
+
         self.builder.register(|build| {
             let elem = desc.lhs.dtype;
             let lhs = build.input(&desc.lhs)?;
@@ -541,6 +569,10 @@ impl FuseOptimizationBuilder {
         })
     }
 
+    fn input_is_quantized(&self, input: &TensorIr) -> bool {
+        matches!(input.dtype, DType::QFloat(_scheme))
+    }
+
     fn output_is_compatible(&mut self, out: &TensorIr) -> bool {
         if self.current_output_shape.is_empty() {
             self.current_output_shape.clone_from(&out.shape);
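
Every guard added to `FuseOptimizationBuilder` above is the same check: fusion bails out when an input's dtype is quantized. A standalone sketch of that check (the free function here is illustrative; in the diff it is the `input_is_quantized` method on the builder):

```rust
use burn_tensor::DType;

// Sketch: reject quantized tensors from elemwise fusion by checking for the
// QFloat dtype variant, the same pattern used by input_is_quantized above.
fn is_quantized(dtype: &DType) -> bool {
    matches!(dtype, DType::QFloat(_))
}
```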

crates/burn-cubecl-fusion/src/shared/ir.rs

Lines changed: 1 addition & 1 deletion
@@ -415,7 +415,7 @@ impl From<DType> for FusePrecision {
             DType::U16 => Self::U16,
             DType::U8 => Self::U8,
             DType::Bool => Self::Bool,
-            _ => panic!("Unsupported"),
+            _ => panic!("Unsupported precision for fusion: {value:?}"),
         }
     }
 }

crates/burn-cubecl/src/kernel/matmul/base.rs

Lines changed: 29 additions & 8 deletions
@@ -1,6 +1,9 @@
 use super::init_matmul_output;
 use crate::{CubeRuntime, FloatElement, tensor::CubeTensor};
-use burn_tensor::DType;
+use burn_tensor::{
+    DType,
+    quantization::{QTensorPrimitive, QuantAccPrecision},
+};
 use cubecl::linalg::matmul::{components::Quantized, kernels::MatmulLaunchError};
 
 #[cfg(feature = "autotune")]
@@ -65,16 +68,34 @@ pub fn q_matmul<R: CubeRuntime>(
 
     let client = &lhs.client;
 
+    let scheme = *lhs.scheme();
+
     lhs.dtype = DType::I8;
     rhs.dtype = DType::I8;
 
-    cubecl::linalg::matmul::launch_ref::<R, (i8, half::f16, half::f16, half::f16, Quantized)>(
-        &Default::default(),
-        client,
-        &lhs.as_handle_ref(),
-        &rhs.as_handle_ref(),
-        &out.as_handle_ref(),
-    )?;
+    match scheme.acc_precision {
+        QuantAccPrecision::Full => {
+            cubecl::linalg::matmul::launch_ref::<R, (i8, half::f16, f32, half::f16, Quantized)>(
+                &Default::default(),
+                client,
+                &lhs.as_handle_ref(),
+                &rhs.as_handle_ref(),
+                &out.as_handle_ref(),
+            )?;
+        }
+        QuantAccPrecision::Half => {
+            cubecl::linalg::matmul::launch_ref::<
+                R,
+                (i8, half::f16, half::f16, half::f16, Quantized),
+            >(
+                &Default::default(),
+                client,
+                &lhs.as_handle_ref(),
+                &rhs.as_handle_ref(),
+                &out.as_handle_ref(),
+            )?;
+        }
+    }
 
     Ok(out)
 }
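
The two `launch_ref` calls differ only in the third type of the element tuple: `f32` under `QuantAccPrecision::Full` versus `half::f16` under `QuantAccPrecision::Half`, which is presumably the accumulator element. A rough CPU-only illustration of what that choice means, not the kernel's actual code; it assumes the `half` crate and ignores quantization scales entirely:

```rust
use burn_tensor::quantization::QuantAccPrecision;
use half::f16;

// Sketch: int8 dot product with the accumulation precision chosen by the scheme,
// mirroring the f32-vs-f16 accumulator in the element tuples above.
fn dot_i8(a: &[i8], b: &[i8], acc: QuantAccPrecision) -> f32 {
    match acc {
        // Full precision: accumulate the products in f32.
        QuantAccPrecision::Full => a
            .iter()
            .zip(b)
            .map(|(&x, &y)| f32::from(x) * f32::from(y))
            .sum(),
        // Half precision: accumulate in f16, then widen the final sum.
        QuantAccPrecision::Half => {
            let mut sum = f16::ZERO;
            for (&x, &y) in a.iter().zip(b) {
                sum = sum + f16::from_f32(f32::from(x) * f32::from(y));
            }
            f32::from(sum)
        }
    }
}
```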

crates/burn-cubecl/src/kernel/quantization/dequantize.rs

Lines changed: 8 additions & 3 deletions
@@ -1,7 +1,7 @@
 use crate::tensor::CubeTensor;
 use crate::{CubeElement, CubeRuntime};
 use burn_tensor::DType;
-use burn_tensor::quantization::{QuantizationMode, QuantizationScheme, QuantizationType};
+use burn_tensor::quantization::{QuantInputType, QuantLevel, QuantMode, QuantScheme};
 use cubecl::calculate_cube_count_elemwise;
 use cubecl::prelude::*;
 
@@ -39,7 +39,7 @@ fn unpack_i8s(value: u32) -> Line<i32> {
 fn dequantize_per_tensor_symmetric_int8_kernel(
     input: &QTensor,
     output: &mut Tensor<Line<f32>>,
-    #[comptime] scheme: QuantizationScheme,
+    #[comptime] scheme: QuantScheme,
 ) {
     // Last position contains the qparam
     if ABSOLUTE_POS >= input.len() - 1 {
@@ -93,7 +93,12 @@ where
 
     if let DType::QFloat(scheme) = tensor.dtype {
         match scheme {
-            QuantizationScheme::PerTensor(QuantizationMode::Symmetric, QuantizationType::QInt8) => {
+            QuantScheme {
+                level: QuantLevel::Tensor,
+                mode: QuantMode::Symmetric,
+                q_type: QuantInputType::QInt8,
+                ..
+            } => {
                 unsafe {
                     dequantize_per_tensor_symmetric_int8_kernel::launch_unchecked::<R>(
                         &client,
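
For reference, per-tensor symmetric int8 dequantization, which the kernel above applies to packed `u32` words, reduces to multiplying each quantized value by the tensor-wide scale; symmetric schemes have no zero-point. A host-side sketch that leaves out the four-values-per-`u32` packing and the scale stored at the end of the buffer:

```rust
// Sketch: symmetric per-tensor int8 dequantization on the host, with the scale
// applied uniformly to every element (zero-point is implicitly 0 in this case).
fn dequantize_symmetric_int8(values: &[i8], scale: f32) -> Vec<f32> {
    values.iter().map(|&q| scale * f32::from(q)).collect()
}
```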
Lines changed: 9 additions & 6 deletions
@@ -1,13 +1,13 @@
 #![allow(missing_docs)] // cube derive macros
 
-use burn_tensor::quantization::{QuantizationMode, QuantizationScheme};
+use burn_tensor::quantization::{QuantInputType, QuantLevel, QuantMode, QuantScheme};
 use cubecl::prelude::*;
 
 /// Quantization parameters.
 #[derive(CubeLaunch, CubeType)]
 pub struct QParams {
     #[cube(comptime)]
-    scheme: QuantizationScheme,
+    scheme: QuantScheme,
 }
 
 /// Quantized tensor representation.
@@ -16,7 +16,7 @@ pub type QTensor = Array<Line<u32>>;
 #[cube]
 impl QParams {
     /// Create a new quantization parameters instance.
-    pub fn new(scheme: QuantizationScheme) -> Self {
+    pub fn new(#[comptime] scheme: QuantScheme) -> Self {
         QParams { scheme }
     }
 
@@ -25,9 +25,12 @@ impl QParams {
         let len = tensor.len();
         match comptime!(self.scheme) {
            // Symmetric quantization only contains the scaling factor as the last element
-            QuantizationScheme::PerTensor(QuantizationMode::Symmetric, _) => {
-                (f32::reinterpret(tensor[len - 1][tensor.line_size() - 1]), 0)
-            }
+            QuantScheme {
+                level: QuantLevel::Tensor,
+                mode: QuantMode::Symmetric,
+                q_type: QuantInputType::QInt8,
+                ..
+            } => (f32::reinterpret(tensor[len - 1][tensor.line_size() - 1]), 0),
         }
     }
 }
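
As the comment in the hunk notes, a per-tensor symmetric scheme stores only the scaling factor, kept in the last position of the packed buffer, and the zero-point offset is implicitly 0. A hedged host-side equivalent of that lookup, with `f32::from_bits` standing in for the kernel's `f32::reinterpret` and the buffer layout assumed:

```rust
// Sketch: read the per-tensor scale from the last u32 of a packed quantized
// buffer and pair it with the implicit zero offset, mirroring QParams above.
fn read_qparams(packed: &[u32]) -> (f32, i32) {
    let scale = f32::from_bits(*packed.last().expect("non-empty quantized buffer"));
    (scale, 0)
}
```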
