Skip to content

Refactor quantization scheme #3042

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 28 commits into from
May 6, 2025
Merged
Show file tree
Hide file tree
Changes from 25 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
cc95328
refactor quantization scheme
maxtremblay Apr 17, 2025
fdf3845
impl acc precision and output mode
maxtremblay Apr 17, 2025
476072a
clean test
maxtremblay Apr 17, 2025
eb2e82f
cargo fmt
maxtremblay Apr 17, 2025
bff1824
remove unused import
maxtremblay Apr 17, 2025
34e7227
wip
maxtremblay Apr 17, 2025
b4113d5
Merge branch 'main' into refactor-quantization-scheme
laggui Apr 28, 2025
829c2bc
Cargo fmt
laggui Apr 28, 2025
b89d19c
Make it work
laggui Apr 28, 2025
beb896f
Merge branch 'main' into refactor-quantization-scheme
laggui Apr 29, 2025
5849d8e
Narrow, chunk and split are all high-level slice-based methods
laggui Apr 30, 2025
a586dcd
Add argwhere empty test
laggui Apr 30, 2025
9c76292
Cleanup qtensor ops
laggui Apr 30, 2025
77e48ed
Better docstrings
laggui Apr 30, 2025
237f921
Remove unused
laggui Apr 30, 2025
0a22259
Add propagate test example
laggui Apr 30, 2025
3494fdd
Merge branch 'main' into refactor-quantization-scheme
laggui May 2, 2025
d265f29
Add return type semantics description
laggui May 2, 2025
e73baa7
Fusion ops passthrough
laggui May 2, 2025
d295d82
Merge branch 'main' into refactor-quantization-scheme
laggui May 2, 2025
1f61966
Cleanup
laggui May 2, 2025
c9de3bd
Handle lhs float rhs quant for practical use cases
laggui May 2, 2025
f5645b4
Fix clippy
laggui May 2, 2025
239cb81
Use matches
laggui May 2, 2025
3e338c0
Remove comment
laggui May 2, 2025
e52a486
Cleaner
laggui May 2, 2025
4dc5597
Merge branch 'main' into refactor-quantization-scheme
laggui May 6, 2025
bf2b20a
Fix merged conflicts
laggui May 6, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions crates/burn-autodiff/src/ops/qtensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use burn_tensor::{
Device, Shape, TensorData,
backend::Backend,
ops::{FloatTensor, IntTensor, QTensorOps, QuantizedTensor},
quantization::{QuantizationParametersPrimitive, QuantizationScheme},
quantization::{QuantScheme, QuantizationParametersPrimitive},
};

use crate::{Autodiff, checkpoint::strategy::CheckpointStrategy};
Expand All @@ -16,15 +16,15 @@ impl<B: Backend, C: CheckpointStrategy> QTensorOps<Self> for Autodiff<B, C> {

fn quantize(
_tensor: FloatTensor<Self>,
_scheme: &QuantizationScheme,
_scheme: &QuantScheme,
_qparams: QuantizationParametersPrimitive<Self>,
) -> QuantizedTensor<Self> {
todo!() // required for QAT
}

fn quantize_dynamic(
_tensor: FloatTensor<Self>,
_scheme: &QuantizationScheme,
_scheme: &QuantScheme,
) -> QuantizedTensor<Self> {
todo!()
}
Expand Down
4 changes: 2 additions & 2 deletions crates/burn-candle/src/ops/qtensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use burn_tensor::{
DType, Device, Shape, TensorData,
backend::Backend,
ops::{FloatTensor, IntTensor, QTensorOps, QuantizedTensor},
quantization::{QuantizationParametersPrimitive, QuantizationScheme},
quantization::{QuantScheme, QuantizationParametersPrimitive},
};

use crate::{
Expand All @@ -19,7 +19,7 @@ impl<F: FloatCandleElement, I: IntCandleElement> QTensorOps<Self> for Candle<F,

fn quantize(
_tensor: FloatTensor<Self>,
_scheme: &QuantizationScheme,
_scheme: &QuantScheme,
_qparams: QuantizationParametersPrimitive<Self>,
) -> QuantizedTensor<Self> {
unimplemented!()
Expand Down
6 changes: 3 additions & 3 deletions crates/burn-candle/src/tensor.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use burn_tensor::{
DType, Element, Shape, TensorData, TensorMetadata,
quantization::{QTensorPrimitive, QuantizationScheme},
quantization::{QTensorPrimitive, QuantScheme},
};

use crate::{CandleDevice, element::CandleElement};
Expand Down Expand Up @@ -63,11 +63,11 @@ pub struct CandleQTensor {
// NOTE: candle does not implement `WithDType` for i8
pub qtensor: CandleTensor,
/// The quantization scheme.
pub scheme: QuantizationScheme,
pub scheme: QuantScheme,
}

impl QTensorPrimitive for CandleQTensor {
    /// Returns a reference to the quantization scheme attached to this tensor.
    // The scheme is stored directly on the struct (see the `scheme` field above),
    // so this is a plain borrow with no computation.
    fn scheme(&self) -> &QuantScheme {
        &self.scheme
    }
}
Expand Down
4 changes: 2 additions & 2 deletions crates/burn-core/src/module/quantize.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use burn_tensor::{
Tensor,
backend::Backend,
quantization::{Calibration, QuantizationScheme},
quantization::{Calibration, QuantScheme},
};

use crate::module::{ModuleMapper, ParamId};
Expand All @@ -11,7 +11,7 @@ pub struct Quantizer {
/// The calibration method used in quantization.
pub calibration: Calibration,
/// The quantization scheme.
pub scheme: QuantizationScheme,
pub scheme: QuantScheme,
}

impl<B: Backend> ModuleMapper<B> for Quantizer {
Expand Down
34 changes: 33 additions & 1 deletion crates/burn-cubecl-fusion/src/shared/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use burn_ir::{
BaseOperationIr, BinaryOpIr, FloatOperationIr, NumericOperationIr, OperationIr, ScalarOpIr,
TensorIr, UnaryOpIr,
};
use burn_tensor::Element;
use burn_tensor::{DType, Element};
use cubecl::ir::Elem;

/// The base optimization builder that can be used to fuse all elemwise operations.
Expand Down Expand Up @@ -212,6 +212,10 @@ impl FuseOptimizationBuilder {
return false;
}

if self.input_is_quantized(&desc.input) {
return false;
}

if self.builder.register(|build| {
build.input_swap_dims(
&desc.input,
Expand Down Expand Up @@ -243,6 +247,10 @@ impl FuseOptimizationBuilder {
return false;
}

if self.input_is_quantized(&desc.input) {
return false;
}

if self.builder.register(|build| {
build.input_reshaped(&desc.input, &desc.out)?;
Some(())
Expand Down Expand Up @@ -447,6 +455,10 @@ impl FuseOptimizationBuilder {
return false;
}

if self.input_is_quantized(&desc.tensor) {
return false;
}

self.builder.register(|build| {
let input = build.input_indexed(&desc.tensor)?;
let indices = build.input(&desc.indices)?;
Expand All @@ -467,6 +479,10 @@ impl FuseOptimizationBuilder {
return false;
}

if self.input_is_quantized(&desc.tensor) {
return false;
}

self.builder.register(|build| {
let input = build.input_indexed(&desc.tensor)?;
let indices = build.input_indexed(&desc.indices)?;
Expand Down Expand Up @@ -494,6 +510,10 @@ impl FuseOptimizationBuilder {
return false;
}

if self.input_is_quantized(&desc.lhs) {
return false;
}

self.builder.register(|build| {
let lhs = build.input(&desc.lhs)?;
let rhs = build.input(&desc.rhs)?;
Expand All @@ -513,6 +533,10 @@ impl FuseOptimizationBuilder {
return false;
}

if self.input_is_quantized(&desc.input) {
return false;
}

self.builder.register(|build| {
let input = build.input(&desc.input)?;
let out = build.output(&desc.out)?;
Expand All @@ -529,6 +553,10 @@ impl FuseOptimizationBuilder {
return false;
}

if self.input_is_quantized(&desc.lhs) {
return false;
}

self.builder.register(|build| {
let elem = desc.lhs.dtype;
let lhs = build.input(&desc.lhs)?;
Expand All @@ -541,6 +569,10 @@ impl FuseOptimizationBuilder {
})
}

/// Returns `true` when the input tensor's dtype is a quantized float
/// (`DType::QFloat`). Callers use this as a guard to bail out of fusion
/// registration for quantized inputs.
fn input_is_quantized(&self, input: &TensorIr) -> bool {
    // Use the wildcard `_` instead of binding an unused `_scheme`:
    // the scheme payload is irrelevant here, only the variant matters.
    matches!(input.dtype, DType::QFloat(_))
}

fn output_is_compatible(&mut self, out: &TensorIr) -> bool {
if self.current_output_shape.is_empty() {
self.current_output_shape.clone_from(&out.shape);
Expand Down
2 changes: 1 addition & 1 deletion crates/burn-cubecl-fusion/src/shared/ir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -415,7 +415,7 @@ impl From<DType> for FusePrecision {
DType::U16 => Self::U16,
DType::U8 => Self::U8,
DType::Bool => Self::Bool,
_ => panic!("Unsupported"),
_ => panic!("Unsupported precision for fusion: {value:?}"),
}
}
}
Expand Down
37 changes: 29 additions & 8 deletions crates/burn-cubecl/src/kernel/matmul/base.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
use super::init_matmul_output;
use crate::{CubeRuntime, FloatElement, tensor::CubeTensor};
use burn_tensor::DType;
use burn_tensor::{
DType,
quantization::{QTensorPrimitive, QuantAccPrecision},
};
use cubecl::linalg::matmul::{components::Quantized, kernels::MatmulLaunchError};

#[cfg(feature = "autotune")]
Expand Down Expand Up @@ -65,16 +68,34 @@ pub fn q_matmul<R: CubeRuntime>(

let client = &lhs.client;

let scheme = *lhs.scheme();

lhs.dtype = DType::I8;
rhs.dtype = DType::I8;

cubecl::linalg::matmul::launch_ref::<R, (i8, half::f16, half::f16, half::f16, Quantized)>(
&Default::default(),
client,
&lhs.as_handle_ref(),
&rhs.as_handle_ref(),
&out.as_handle_ref(),
)?;
match scheme.acc_precision {
QuantAccPrecision::Full => {
cubecl::linalg::matmul::launch_ref::<R, (i8, half::f16, f32, half::f16, Quantized)>(
&Default::default(),
client,
&lhs.as_handle_ref(),
&rhs.as_handle_ref(),
&out.as_handle_ref(),
)?;
}
QuantAccPrecision::Half => {
cubecl::linalg::matmul::launch_ref::<
R,
(i8, half::f16, half::f16, half::f16, Quantized),
>(
&Default::default(),
client,
&lhs.as_handle_ref(),
&rhs.as_handle_ref(),
&out.as_handle_ref(),
)?;
}
}

Ok(out)
}
11 changes: 8 additions & 3 deletions crates/burn-cubecl/src/kernel/quantization/dequantize.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::tensor::CubeTensor;
use crate::{CubeElement, CubeRuntime};
use burn_tensor::DType;
use burn_tensor::quantization::{QuantizationMode, QuantizationScheme, QuantizationType};
use burn_tensor::quantization::{QuantInputType, QuantLevel, QuantMode, QuantScheme};
use cubecl::calculate_cube_count_elemwise;
use cubecl::prelude::*;

Expand Down Expand Up @@ -39,7 +39,7 @@ fn unpack_i8s(value: u32) -> Line<i32> {
fn dequantize_per_tensor_symmetric_int8_kernel(
input: &QTensor,
output: &mut Tensor<Line<f32>>,
#[comptime] scheme: QuantizationScheme,
#[comptime] scheme: QuantScheme,
) {
// Last position contains the qparam
if ABSOLUTE_POS >= input.len() - 1 {
Expand Down Expand Up @@ -93,7 +93,12 @@ where

if let DType::QFloat(scheme) = tensor.dtype {
match scheme {
QuantizationScheme::PerTensor(QuantizationMode::Symmetric, QuantizationType::QInt8) => {
QuantScheme {
level: QuantLevel::Tensor,
mode: QuantMode::Symmetric,
q_type: QuantInputType::QInt8,
..
} => {
unsafe {
dequantize_per_tensor_symmetric_int8_kernel::launch_unchecked::<R>(
&client,
Expand Down
15 changes: 9 additions & 6 deletions crates/burn-cubecl/src/kernel/quantization/qtensor.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
#![allow(missing_docs)] // cube derive macros

use burn_tensor::quantization::{QuantizationMode, QuantizationScheme};
use burn_tensor::quantization::{QuantInputType, QuantLevel, QuantMode, QuantScheme};
use cubecl::prelude::*;

/// Quantization parameters.
#[derive(CubeLaunch, CubeType)]
pub struct QParams {
#[cube(comptime)]
scheme: QuantizationScheme,
scheme: QuantScheme,
}

/// Quantized tensor representation.
Expand All @@ -16,7 +16,7 @@ pub type QTensor = Array<Line<u32>>;
#[cube]
impl QParams {
/// Create a new quantization parameters instance.
pub fn new(scheme: QuantizationScheme) -> Self {
pub fn new(#[comptime] scheme: QuantScheme) -> Self {
QParams { scheme }
}

Expand All @@ -25,9 +25,12 @@ impl QParams {
let len = tensor.len();
match comptime!(self.scheme) {
// Symmetric quantization only contains the scaling factor as the last element
QuantizationScheme::PerTensor(QuantizationMode::Symmetric, _) => {
(f32::reinterpret(tensor[len - 1][tensor.line_size() - 1]), 0)
}
QuantScheme {
level: QuantLevel::Tensor,
mode: QuantMode::Symmetric,
q_type: QuantInputType::QInt8,
..
} => (f32::reinterpret(tensor[len - 1][tensor.line_size() - 1]), 0),
}
}
}
Loading
Loading