Refactor quantization scheme #3042
Conversation
Codecov Report
Attention: Patch coverage is

Additional details and impacted files

@@            Coverage Diff             @@
##             main    #3042      +/-   ##
==========================================
+ Coverage   81.36%   81.40%    +0.03%
==========================================
  Files         821      821
  Lines      117791   118058      +267
==========================================
+ Hits        95844    96103      +259
- Misses      21947    21955        +8
Much more flexible given the different configurations for quantization! So I agree with the change from an enum to a struct for the QuantizationScheme.
I have some comments regarding naming, otherwise LGTM.
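For illustration, a rough sketch of the flexibility argument; the enum below paraphrases the old shape for contrast rather than reproducing the previous API:

// Sketch only: with an enum, every supported combination needs its own variant, and each
// new axis (accumulator precision, propagation, ...) multiplies the number of variants.
enum SchemeAsEnum {
    PerTensorSymmetricInt8,
    PerTensorAffineInt8,
    // ... one variant per (level, mode, dtype, ...) combination
}

// With the struct from this PR, each axis varies independently and new configurations
// compose via struct update syntax (as in the tests further down):
fn propagating_scheme() -> QuantScheme {
    QuantScheme {
        propagation: QuantPropagation::Propagate,
        ..Default::default()
    }
}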
Can't self-assign to an open PR, but just stating in the open that I will take over the WIP and complete the refactor.
The change to a new quantization scheme struct impacts a lot of files at a superficial level, but I've left some comments below to highlight the important changes.
Left this as a draft even though it is in a good state to review, mostly because we should avoid merging before the multi-tensor handle PR so we are the ones dealing with the conflicts 😅
match (self.primitive, other.primitive) {
    (TensorPrimitive::QFloat(lhs), TensorPrimitive::QFloat(rhs)) => {
-       Self::new(TensorPrimitive::QFloat(B::q_matmul(lhs, rhs)))
+       Self::new(B::q_matmul(lhs, rhs))
    }
    (TensorPrimitive::QFloat(lhs), TensorPrimitive::Float(rhs)) => Self::new(
        TensorPrimitive::Float(B::float_matmul(B::dequantize(lhs), rhs)),
    ),
    (TensorPrimitive::Float(lhs), TensorPrimitive::QFloat(rhs)) => {
        // NOTE: in a typical workflow with linear layers (e.g., transformers), the rhs
        // represents the weights.
        //
        // Since `q_matmul(lhs_f16, rhs_quant)` isn't currently supported, in practice it makes
        // more sense to re-quantize the input back. Better usability.
        //
        // This might change in the future (dequantize on read in fusion?).
        Self::new(B::q_matmul(B::quantize_dynamic(lhs, rhs.scheme()), rhs))
    }
    (TensorPrimitive::Float(lhs), TensorPrimitive::Float(rhs)) => {
        Self::new(TensorPrimitive::Float(B::float_matmul(lhs, rhs)))
    }
    (lhs, rhs) => Self::new(TensorPrimitive::Float(B::float_matmul(
        lhs.tensor(),
        rhs.tensor(),
    ))),
}
See the special note for matmul with a float lhs and quantized rhs.
/// Operations on quantized tensors.
///
/// # Return Type Semantics
///
/// The return type of each operation indicates how quantization is handled:
///
/// ## [`QuantizedTensor<B>`]
/// If the method returns a `QuantizedTensor<B>`, the operation is expected to preserve the quantized
/// representation. Implementations should avoid dequantizing when possible to maintain performance.
/// For example, shape or layout changes such as expand or transpose preserve quantization.
///
/// *Note: while this currently doesn't affect the quantized tensor parameters (only per-tensor is
/// supported at the time of writing), other quantization levels (e.g., per-block) may require re-ordering
/// the quantization parameters to match the new layout.*
///
/// ## [`TensorPrimitive<B>`]
/// If the method returns a `TensorPrimitive<B>` enum, the return type should align with the propagation
/// strategy specified in the quantization scheme. The output should either remain quantized
/// ([`TensorPrimitive::QFloat`]) or be returned in floating-point form ([`TensorPrimitive::Float`]).
///
/// This distinction allows for fine-grained control over mixed-precision flows while still operating
/// through a unified API.
pub trait QTensorOps<B: Backend> {
Important specification related to quantization scheme propagation.
This is like a contract.
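To make the contract concrete from a caller's perspective, here is a minimal sketch (not code from this PR) of consuming an op that returns a TensorPrimitive<B>. The helper name is made up, the import paths are approximate, and the float_sum call assumes the usual float ops on the backend:

use burn_tensor::TensorPrimitive; // import paths approximate
use burn_tensor::backend::Backend;
use burn_tensor::ops::{FloatTensor, QuantizedTensor};

// Sketch: handling both possible return forms of a scheme-propagating op.
fn reduce_matmul<B: Backend>(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> FloatTensor<B> {
    match B::q_matmul(lhs, rhs) {
        // Propagation requested: the result is still quantized, so dequantize explicitly
        // before continuing in float precision.
        TensorPrimitive::QFloat(out) => B::float_sum(B::dequantize(out)),
        // Propagation inhibited: the result already comes back as a float tensor.
        TensorPrimitive::Float(out) => B::float_sum(out),
    }
}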
@@ -140,6 +208,110 @@ pub trait QTensorOps<B: Backend> {
        false
    }

    /// Broadcasts the `tensor` to the given `shape`.
    fn q_expand(tensor: QuantizedTensor<B>, shape: Shape) -> QuantizedTensor<B>;
For example, this should always return a quantized tensor
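As an illustration in the style of the tests further down (a hypothetical test, assuming the high-level expand call dispatches to q_expand for QFloat tensors):

#[test]
fn q_expand_keeps_quantization() {
    // Hypothetical test, written in the style of the propagation tests below.
    let device = Default::default();
    let scheme = QuantScheme::default();

    let tensor = TestTensor::<2>::from_floats([[1.0, 2.0]], &device).quantize_dynamic(&scheme);

    // expand is a layout-only op, so the result should stay quantized regardless of propagation.
    let expanded = tensor.expand([3, 2]);
    assert_eq!(expanded.to_data().dtype, DType::QFloat(scheme));
}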
@@ -411,8 +539,8 @@ pub trait QTensorOps<B: Backend> {
    /// # Returns
    ///
    /// The result of multiplying the two tensors together using matrix multiplication.
-   fn q_matmul(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> QuantizedTensor<B> {
-       dequant_op_quant!(
+   fn q_matmul(lhs: QuantizedTensor<B>, rhs: QuantizedTensor<B>) -> TensorPrimitive<B> {
But compute operations like matmul, which affect the qparams, can return a TensorPrimitive::Float or TensorPrimitive::QFloat based on the propagation.
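For illustration only, a minimal sketch of how a fallback could honor the propagation setting; this is not the implementation merged here, and the scheme() accessor and quantize_dynamic signature are assumed to match their usage in the matmul diff above:

// Sketch of a propagation-aware fallback; not the implementation from this PR.
fn q_matmul_fallback<B: Backend>(
    lhs: QuantizedTensor<B>,
    rhs: QuantizedTensor<B>,
) -> TensorPrimitive<B> {
    // Assumes a `scheme()` accessor like the `rhs.scheme()` used in the matmul diff above.
    let scheme = lhs.scheme().clone();

    // Compute in float precision (acc_precision could further refine the accumulator dtype).
    let out = B::float_matmul(B::dequantize(lhs), B::dequantize(rhs));

    match scheme.propagation {
        // Re-quantize the output with the input scheme.
        QuantPropagation::Propagate => {
            TensorPrimitive::QFloat(B::quantize_dynamic(out, &scheme))
        }
        // Return the floating-point result as-is.
        QuantPropagation::Inhibit => TensorPrimitive::Float(out),
    }
}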
fn q_mask_where(
    tensor: QuantizedTensor<B>,
    mask: BoolTensor<B>,
    value: QuantizedTensor<B>,
) -> QuantizedTensor<B> {
I removed some ambiguous ops for now.
/// Describes a quantization scheme/configuration.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub struct QuantScheme {
    /// Granularity level of quantization (e.g., per-tensor).
    pub level: QuantLevel,
    /// Quantization mode (e.g., symmetric).
    pub mode: QuantMode,
    /// Data type used for storing quantized values (e.g., QInt8).
    pub q_type: QuantInputType,
    /// Precision used for accumulating intermediate values (e.g., during matmul).
    pub acc_precision: QuantAccPrecision,
    /// Whether to propagate quantization to outputs or return unquantized results.
    pub propagation: QuantPropagation,
}
New scheme struct
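As a usage sketch of the new struct: the QuantLevel, QuantMode, and QuantInputType variant names below are assumptions for illustration; only QuantAccPrecision::Full and the QuantPropagation variants appear in this PR's excerpts.

// Usage sketch: the two new axes (acc_precision, propagation) sit alongside the existing ones.
fn example_scheme() -> QuantScheme {
    QuantScheme {
        level: QuantLevel::Tensor,       // per-tensor granularity (assumed variant name)
        mode: QuantMode::Symmetric,      // symmetric quantization (assumed variant name)
        q_type: QuantInputType::QInt8,   // int8 storage (assumed variant name)
        acc_precision: QuantAccPrecision::Full, // accumulate in f32 (used in the tests below)
        propagation: QuantPropagation::Inhibit, // compute ops return float outputs
    }
}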
/// Specify if the output of an operation is quantized using the scheme of the input
/// or returned unquantized.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum QuantPropagation {
    /// The output is quantized using the scheme of the input.
    Propagate,
    /// The output is not quantized.
    Inhibit,
}
Still unsure about the naming, especially the name of the variants.
#[test]
fn test_matmul_lhs_float_rhs_quantized() {
    // Simulates a typical workflow with linear layers (e.g., transformers), where the rhs
    // represents the weights. The lhs might be a float if a previous operation did not propagate
    // the quantization. We still want to perform an efficient matmul with quantized weights.
    //
    // Since `q_matmul(lhs_f16, rhs_quant)` isn't currently supported, in practice it makes
    // more sense to re-quantize the input back at this time. Better usability.
    //
    // This might be handled differently in the future (dequantize on read in fusion?).
    let tensor_1 = TestTensor::<2>::from([[1.0, 6.35], [2.0, 3.0], [1.0, 3.0]]);
    let tensor_2 = QTensor::<TestBackend, 2>::int8([[4.0, 8.0, 12.7], [2.0, 3.0, 6.0]]);
    let tensor_3 = tensor_1.matmul(tensor_2);

    let expected = TensorData::from([[16.7, 27.05, 50.8], [14., 25., 43.4], [10., 17., 30.7]]);
    let output = tensor_3.into_data();
    output.assert_approx_eq::<FT>(&expected, Tolerance::rel_abs(1e-2, 1e-1));

    // Default quantization scheme does not propagate quantization with matmul
    assert!(output.dtype.is_float());
}
Detailed example and explanation for the matmul special case
#[test]
fn quant_scheme_should_propagate() {
    let device = Default::default();
    let scheme = QuantScheme {
        propagation: QuantPropagation::Propagate,
        ..Default::default()
    };

    let tensor_1 = TestTensor::<2>::from_floats([[1.0, 6.35], [2.0, 3.0], [1.0, 3.0]], &device)
        .quantize_dynamic(&scheme);
    let tensor_2 = TestTensor::<2>::from_floats([[4.0, 8.0, 12.7], [2.0, 3.0, 6.0]], &device)
        .quantize_dynamic(&scheme);

    let tensor_3 = tensor_1.matmul(tensor_2);
    assert_eq!(tensor_3.to_data().dtype, DType::QFloat(scheme));

    let tensor_4 = tensor_3.add_scalar(1.);
    assert_eq!(tensor_4.to_data().dtype, DType::QFloat(scheme));
}

#[test]
fn quant_scheme_should_not_propagate() {
    let device = Default::default();
    let scheme = QuantScheme {
        propagation: QuantPropagation::Inhibit,
        acc_precision: QuantAccPrecision::Full, // f32
        ..Default::default()
    };

    let tensor_1 = TestTensor::<2>::from_floats([[1.0, 6.35], [2.0, 3.0], [1.0, 3.0]], &device)
        .quantize_dynamic(&scheme);
    let tensor_2 = TestTensor::<2>::from_floats([[4.0, 8.0, 12.7], [2.0, 3.0, 6.0]], &device)
        .quantize_dynamic(&scheme);

    // Some ops like reshape, swap_dims, permute, expand, select, slice, etc. do not affect
    // the propagation. It mostly applies to compute kernels.
    let tensor_1 = tensor_1
        .permute([1, 0])
        .swap_dims(0, 1)
        .reshape([1, 6])
        .reshape([3, 2]);
    assert_eq!(tensor_1.to_data().dtype, DType::QFloat(scheme));

    // When propagation is not desired, compute kernels like matmul should return tensor
    // in floating point precision
    let tensor_3 = tensor_1.matmul(tensor_2);
    let dtype = tensor_3.to_data().dtype;
    assert!(dtype.is_float());

    // Subsequent ops will therefore be performed on floats
    let tensor_4 = tensor_3.add(TestTensor::<2>::ones([3, 3], &device).cast(dtype));
    assert!(tensor_4.to_data().dtype.is_float());
}
Detailed tests that demonstrate the propagation expectations
* refactor quantization scheme
* impl acc precision and output mode
* clean test
* cargo fmt
* remove unused import
* wip
* Cargo fmt
* Make it work
* Narrow, chunk and split are all high-level slice-based methods
* Add argwhere empty test
* Cleanup qtensor ops
* Better docstrings
* Remove unused
* Add propagate test example
* Add return type semantics description
* Fusion ops passthrough
* Cleanup
* Handle lhs float rhs quant for practical use cases
* Fix clippy
* Use matches
* Remove comment
* Cleaner
* Fix merged conflicts

---------

Co-authored-by: Guillaume Lagrange <lagrange.guillaume.1@gmail.com>
Pull Request Template

Checklist

- The run-checks all script has been executed.

Changes

Refactor QuantizationScheme to a struct and add two new parameters: the quantization accumulator precision and the output mode.

Testing

No new tests, but the existing ones are still succeeding.