
Refactor quantization scheme #3042


Merged

laggui merged 28 commits into main from refactor-quantization-scheme on May 6, 2025

Conversation

maxtremblay (Contributor)

Pull Request Template

Checklist

  • Confirmed that the run-checks all script has been executed.
  • Made sure the book is up to date with changes in this PR.

Changes

Refactor QuantizationScheme from an enum to a struct and add two new parameters: the quantization accumulator precision and the output mode.
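
For reviewers, a minimal sketch of how the two new parameters are meant to be set (hypothetical snippet, not code from this PR; the import path and the Default impl are assumptions, mirroring how the new propagation tests use `..Default::default()`):

// Hedged sketch: assumes these types are exported from burn_tensor::quantization
// and that QuantScheme implements Default.
use burn_tensor::quantization::{QuantAccPrecision, QuantPropagation, QuantScheme};

fn scheme_with_new_params() -> QuantScheme {
    QuantScheme {
        // New parameter: accumulate intermediate values (e.g., during matmul) in full f32 precision.
        acc_precision: QuantAccPrecision::Full,
        // New parameter: keep the outputs of compute ops quantized with the input scheme.
        propagation: QuantPropagation::Propagate,
        // The remaining fields (level, mode, q_type) keep their defaults.
        ..Default::default()
    }
}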

Testing

No new tests, but the existing ones are still succeeding.

maxtremblay requested a review from laggui on April 17, 2025

codecov bot commented Apr 17, 2025

Codecov Report

Attention: Patch coverage is 45.30612% with 402 lines in your changes missing coverage. Please review.

Project coverage is 81.40%. Comparing base (eb57d7a) to head (bf2b20a).
Report is 2 commits behind head on main.

Files with missing lines | Patch % | Missing lines
crates/burn-fusion/src/ops/qtensor.rs | 37.55% | 143 ⚠️
crates/burn-tensor/src/tensor/ops/qtensor.rs | 7.82% | 106 ⚠️
crates/burn-tch/src/ops/qtensor.rs | 0.00% | 27 ⚠️
crates/burn-tensor/src/tensor/api/numeric.rs | 46.15% | 21 ⚠️
crates/burn-cubecl/src/kernel/matmul/base.rs | 0.00% | 20 ⚠️
...ates/burn-tensor/src/tensor/quantization/scheme.rs | 45.94% | 20 ⚠️
crates/burn-ndarray/src/ops/qtensor.rs | 33.33% | 14 ⚠️
crates/burn-cubecl/src/ops/qtensor.rs | 60.00% | 10 ⚠️
...tes/burn-cubecl/src/kernel/quantization/qtensor.rs | 0.00% | 7 ⚠️
crates/burn-tch/src/tensor.rs | 0.00% | 7 ⚠️
... and 12 more
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #3042      +/-   ##
==========================================
+ Coverage   81.36%   81.40%   +0.03%     
==========================================
  Files         821      821              
  Lines      117791   118058     +267     
==========================================
+ Hits        95844    96103     +259     
- Misses      21947    21955       +8     

☔ View full report in Codecov by Sentry.

laggui (Member) left a comment

Much more flexible given the different configurations for quantization! So I agree with the change from an enum to a struct for the QuantizationScheme.

I have some comments regarding naming, otherwise LGTM.

laggui marked this pull request as draft on April 18, 2025

laggui (Member) commented on Apr 22, 2025

Can't self-assign to an opened PR, but just stating in the open that I will take over the WIP and complete the refactor.

laggui (Member) left a comment

The change to a new quantization scheme struct impacts a lot of files at a superficial level, but I've left some comments below to highlight the important changes.

Left this as a draft even though it is in a good state to review, mostly because we should avoid merging before the multi-tensor handle PR so we are the ones dealing with the conflicts 😅

Comment on lines 207 to 227
match (self.primitive, other.primitive) {
    (TensorPrimitive::QFloat(lhs), TensorPrimitive::QFloat(rhs)) => {
-       Self::new(TensorPrimitive::QFloat(B::q_matmul(lhs, rhs)))
+       Self::new(B::q_matmul(lhs, rhs))
    }
    (TensorPrimitive::QFloat(lhs), TensorPrimitive::Float(rhs)) => Self::new(
        TensorPrimitive::Float(B::float_matmul(B::dequantize(lhs), rhs)),
    ),
    (TensorPrimitive::Float(lhs), TensorPrimitive::QFloat(rhs)) => {
        // NOTE: in a typical workflow with linear layers (e.g., transformers), the rhs
        // represents the weights.
        //
        // Since `q_matmul(lhs_f16, rhs_quant)` isn't currently supported, in practice it makes
        // more sense to re-quantize the input back. Better usability.
        //
        // This might change in the future (dequantize on read in fusion?).
        Self::new(B::q_matmul(B::quantize_dynamic(lhs, rhs.scheme()), rhs))
    }
    (TensorPrimitive::Float(lhs), TensorPrimitive::Float(rhs)) => {
        Self::new(TensorPrimitive::Float(B::float_matmul(lhs, rhs)))
    }
    (lhs, rhs) => Self::new(TensorPrimitive::Float(B::float_matmul(
        lhs.tensor(),
        rhs.tensor(),
    ))),
}
laggui (Member):

See special note for matmul with lhs float and rhs quantized

Comment on lines +92 to 115
/// Operations on quantized tensors.
///
/// # Return Type Semantics
///
/// The return type of each operation indicates how quantization is handled:
///
/// ## [`QuantizedTensor<B>`]
/// If the method returns a `QuantizedTensor<B>`, the operation is expected to preserve the quantized
/// representation. Implementations should avoid dequantizing when possible to maintain performance.
/// For example, shape or layout changes such as expand or transpose preserve quantization.
///
/// *Note: while this currently doesn't affect the quantized tensor parameters (only per-tensor is
/// supported at the time of writing), other quantization levels (e.g., per-block) may require re-ordering
/// the quantization parameters to match the new layout.*
///
///
/// ## [`TensorPrimitive<B>`]
/// If the method returns a `TensorPrimitive<B>` enum, the return type should align with the propagation
/// strategy specified in the quantization scheme. The output should either remain quantized ([`TensorPrimitive::QFloat`])
/// or be returned in floating-point form ([`TensorPrimitive::Float`]).
///
/// This distinction allows for fine-grained control over mixed-precision flows while still operating
/// through a unified API.
pub trait QTensorOps<B: Backend> {
laggui (Member):

Important specification related to quantization scheme propagation.

This is like a contract.
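
To make the contract concrete from the caller side, here is a small hypothetical helper (not part of this PR; the import paths are assumptions): an op that returns `TensorPrimitive<B>` must be matched to find out whether the result kept the quantized representation or came back as a float tensor.

// Hedged illustration only, using the TensorPrimitive enum from the diff above.
use burn_tensor::{backend::Backend, TensorPrimitive};

fn describe_output<B: Backend>(out: &TensorPrimitive<B>) -> &'static str {
    match out {
        // The op propagated the input quantization scheme.
        TensorPrimitive::QFloat(_) => "output kept the quantized representation",
        // The op returned an unquantized result (propagation inhibited).
        TensorPrimitive::Float(_) => "output was returned in floating-point form",
    }
}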

@@ -140,6 +208,110 @@ pub trait QTensorOps<B: Backend> {
        false
    }

    /// Broadcasts the `tensor` to the given `shape`.
    fn q_expand(tensor: QuantizedTensor<B>, shape: Shape) -> QuantizedTensor<B>;
laggui (Member):

For example, this should always return a quantized tensor

@@ -411,8 +539,8 @@ pub trait QTensorOps<B: Backend> {
    /// # Returns
    ///
    /// The result of multiplying the two tensors together using matrix multiplication.
-   fn q_matmul(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> QuantizedTensor<B> {
-       dequant_op_quant!(
+   fn q_matmul(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> TensorPrimitive<B> {
laggui (Member):

But compute operations like matmul, which affect the qparams, can return a TensorPrimitive::Float or TensorPrimitive::QFloat based on the propagation.
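
As a rough implementation-side sketch of that rule (simplified and hypothetical, not the default impl in this PR; it assumes the `scheme()` accessor and the `quantize_dynamic` call shapes used in the matmul diff earlier, and it ignores the accumulator precision handling):

// Hedged sketch: import paths and type aliases are assumptions.
use burn_tensor::ops::QuantizedTensor;
use burn_tensor::quantization::{QTensorPrimitive, QuantPropagation};
use burn_tensor::{backend::Backend, TensorPrimitive};

fn q_compute_like_op<B: Backend>(
    lhs: QuantizedTensor<B>,
    rhs: QuantizedTensor<B>,
) -> TensorPrimitive<B> {
    // Copy the scheme before the inputs are consumed (QuantScheme is Copy).
    let scheme = *lhs.scheme();
    // Compute in float; a real kernel would honor scheme.acc_precision here.
    let out = B::float_matmul(B::dequantize(lhs), B::dequantize(rhs));
    match scheme.propagation {
        // Re-quantize the result with the input scheme.
        QuantPropagation::Propagate => TensorPrimitive::QFloat(B::quantize_dynamic(out, &scheme)),
        // Return the result unquantized.
        QuantPropagation::Inhibit => TensorPrimitive::Float(out),
    }
}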

Comment on lines -632 to -635
    fn q_mask_where(
        tensor: QuantizedTensor<B>,
        mask: BoolTensor<B>,
        value: QuantizedTensor<B>,
    ) -> QuantizedTensor<B> {
laggui (Member):

I removed some ambiguous ops for now.

Comment on lines +9 to +22
/// Describes a quantization scheme/configuration.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub struct QuantScheme {
    /// Granularity level of quantization (e.g., per-tensor).
    pub level: QuantLevel,
    /// Quantization mode (e.g., symmetric).
    pub mode: QuantMode,
    /// Data type used for storing quantized values (e.g., QInt8).
    pub q_type: QuantInputType,
    /// Precision used for accumulating intermediate values (e.g., during matmul).
    pub acc_precision: QuantAccPrecision,
    /// Whether to propagate quantization to outputs or return unquantized results.
    pub propagation: QuantPropagation,
}
laggui (Member):

New scheme struct

Comment on lines +98 to 106
/// Specify if the output of an operation is quantized using the scheme of the input
/// or returned unquantized.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum QuantPropagation {
    /// The output is quantized using the scheme of the input.
    Propagate,
    /// The output is not quantized.
    Inhibit,
}
laggui (Member):

Still unsure about the naming, especially the name of the variants.

Comment on lines +94 to +114
#[test]
fn test_matmul_lhs_float_rhs_quantized() {
    // Simulates a typical workflow with linear layers (e.g., transformers), where the rhs
    // represents the weights. The lhs might be a float if a previous operation did not propagate
    // the quantization. We still want to perform an efficient matmul with quantized weights.
    //
    // Since `q_matmul(lhs_f16, rhs_quant)` isn't currently supported, in practice it makes
    // more sense to re-quantize the input back at this time. Better usability.
    //
    // This might be handled differently in the future (dequantize on read in fusion?).
    let tensor_1 = TestTensor::<2>::from([[1.0, 6.35], [2.0, 3.0], [1.0, 3.0]]);
    let tensor_2 = QTensor::<TestBackend, 2>::int8([[4.0, 8.0, 12.7], [2.0, 3.0, 6.0]]);
    let tensor_3 = tensor_1.matmul(tensor_2);

    let expected = TensorData::from([[16.7, 27.05, 50.8], [14., 25., 43.4], [10., 17., 30.7]]);
    let output = tensor_3.into_data();
    output.assert_approx_eq::<FT>(&expected, Tolerance::rel_abs(1e-2, 1e-1));

    // Default quantization scheme does not propagate quantization with matmul
    assert!(output.dtype.is_float());
}
laggui (Member):

Detailed example and explanation for the matmul special case

Comment on lines +32 to +84
#[test]
fn quant_scheme_should_propagate() {
    let device = Default::default();
    let scheme = QuantScheme {
        propagation: QuantPropagation::Propagate,
        ..Default::default()
    };

    let tensor_1 = TestTensor::<2>::from_floats([[1.0, 6.35], [2.0, 3.0], [1.0, 3.0]], &device)
        .quantize_dynamic(&scheme);
    let tensor_2 = TestTensor::<2>::from_floats([[4.0, 8.0, 12.7], [2.0, 3.0, 6.0]], &device)
        .quantize_dynamic(&scheme);

    let tensor_3 = tensor_1.matmul(tensor_2);
    assert_eq!(tensor_3.to_data().dtype, DType::QFloat(scheme));

    let tensor_4 = tensor_3.add_scalar(1.);
    assert_eq!(tensor_4.to_data().dtype, DType::QFloat(scheme));
}

#[test]
fn quant_scheme_should_not_propagate() {
    let device = Default::default();
    let scheme = QuantScheme {
        propagation: QuantPropagation::Inhibit,
        acc_precision: QuantAccPrecision::Full, // f32
        ..Default::default()
    };

    let tensor_1 = TestTensor::<2>::from_floats([[1.0, 6.35], [2.0, 3.0], [1.0, 3.0]], &device)
        .quantize_dynamic(&scheme);
    let tensor_2 = TestTensor::<2>::from_floats([[4.0, 8.0, 12.7], [2.0, 3.0, 6.0]], &device)
        .quantize_dynamic(&scheme);

    // Some ops like reshape, swap_dims, permute, expand, select, slice, etc. do not affect
    // the propagation. It mostly applies to compute kernels.
    let tensor_1 = tensor_1
        .permute([1, 0])
        .swap_dims(0, 1)
        .reshape([1, 6])
        .reshape([3, 2]);
    assert_eq!(tensor_1.to_data().dtype, DType::QFloat(scheme));

    // When propagation is not desired, compute kernels like matmul should return tensor
    // in floating point precision
    let tensor_3 = tensor_1.matmul(tensor_2);
    let dtype = tensor_3.to_data().dtype;
    assert!(dtype.is_float());

    // Subsequent ops will therefore be performed on floats
    let tensor_4 = tensor_3.add(TestTensor::<2>::ones([3, 3], &device).cast(dtype));
    assert!(tensor_4.to_data().dtype.is_float());
}
laggui (Member):

Detailed tests that demonstrate the propagation expectations

laggui requested a review from nathanielsimard on May 2, 2025
laggui marked this pull request as ready for review on May 6, 2025
laggui merged commit f77b405 into main on May 6, 2025
10 of 11 checks passed
laggui deleted the refactor-quantization-scheme branch on May 6, 2025
Helios113 pushed a commit to Helios113/burn that referenced this pull request Jul 1, 2025
* refactor quantization scheme

* impl acc precision and output mode

* clean test

* cargo fmt

* remove unused import

* wip

* Cargo fmt

* Make it work

* Narrow, chunk and split are all high-level slice-based methods

* Add argwhere empty test

* Cleanup qtensor ops

* Better docstrings

* Remove unused

* Add propagate test example

* Add return type semantics description

* Fusion ops passthrough

* Cleanup

* Handle lhs float rhs quant for practical use cases

* Fix clippy

* Use matches

* Remove comment

* Cleaner

* Fix merged conflicts

---------

Co-authored-by: Guillaume Lagrange <lagrange.guillaume.1@gmail.com>