diff --git a/crates/burn-autodiff/src/ops/qtensor.rs b/crates/burn-autodiff/src/ops/qtensor.rs index c4ec73dd98..ed307f442d 100644 --- a/crates/burn-autodiff/src/ops/qtensor.rs +++ b/crates/burn-autodiff/src/ops/qtensor.rs @@ -4,7 +4,7 @@ use burn_tensor::{ Device, Shape, TensorData, backend::Backend, ops::{FloatTensor, IntTensor, QTensorOps, QuantizedTensor}, - quantization::{QuantizationParametersPrimitive, QuantizationScheme}, + quantization::{QuantScheme, QuantizationParametersPrimitive}, }; use crate::{Autodiff, checkpoint::strategy::CheckpointStrategy}; @@ -16,7 +16,7 @@ impl QTensorOps for Autodiff { fn quantize( _tensor: FloatTensor, - _scheme: &QuantizationScheme, + _scheme: &QuantScheme, _qparams: QuantizationParametersPrimitive, ) -> QuantizedTensor { todo!() // required for QAT @@ -24,7 +24,7 @@ impl QTensorOps for Autodiff { fn quantize_dynamic( _tensor: FloatTensor, - _scheme: &QuantizationScheme, + _scheme: &QuantScheme, ) -> QuantizedTensor { todo!() } diff --git a/crates/burn-candle/src/ops/qtensor.rs b/crates/burn-candle/src/ops/qtensor.rs index a2aaaaf873..f6da3762c2 100644 --- a/crates/burn-candle/src/ops/qtensor.rs +++ b/crates/burn-candle/src/ops/qtensor.rs @@ -4,7 +4,7 @@ use burn_tensor::{ DType, Device, Shape, TensorData, backend::Backend, ops::{FloatTensor, IntTensor, QTensorOps, QuantizedTensor}, - quantization::{QuantizationParametersPrimitive, QuantizationScheme}, + quantization::{QuantScheme, QuantizationParametersPrimitive}, }; use crate::{ @@ -19,7 +19,7 @@ impl QTensorOps for Candle, - _scheme: &QuantizationScheme, + _scheme: &QuantScheme, _qparams: QuantizationParametersPrimitive, ) -> QuantizedTensor { unimplemented!() diff --git a/crates/burn-candle/src/tensor.rs b/crates/burn-candle/src/tensor.rs index 24c7db0086..aabc88f2a7 100644 --- a/crates/burn-candle/src/tensor.rs +++ b/crates/burn-candle/src/tensor.rs @@ -1,6 +1,6 @@ use burn_tensor::{ DType, Element, Shape, TensorData, TensorMetadata, - quantization::{QTensorPrimitive, QuantizationScheme}, + quantization::{QTensorPrimitive, QuantScheme}, }; use crate::{CandleDevice, element::CandleElement}; @@ -63,11 +63,11 @@ pub struct CandleQTensor { // NOTE: candle does not implement `WithDType` for i8 pub qtensor: CandleTensor, /// The quantization scheme. - pub scheme: QuantizationScheme, + pub scheme: QuantScheme, } impl QTensorPrimitive for CandleQTensor { - fn scheme(&self) -> &QuantizationScheme { + fn scheme(&self) -> &QuantScheme { &self.scheme } } diff --git a/crates/burn-core/src/module/quantize.rs b/crates/burn-core/src/module/quantize.rs index 2e9f6b0bc6..bd063acc17 100644 --- a/crates/burn-core/src/module/quantize.rs +++ b/crates/burn-core/src/module/quantize.rs @@ -1,7 +1,7 @@ use burn_tensor::{ Tensor, backend::Backend, - quantization::{Calibration, QuantizationScheme}, + quantization::{Calibration, QuantScheme}, }; use crate::module::{ModuleMapper, ParamId}; @@ -11,7 +11,7 @@ pub struct Quantizer { /// The calibration method used in quantization. pub calibration: Calibration, /// The quantization scheme. - pub scheme: QuantizationScheme, + pub scheme: QuantScheme, } impl ModuleMapper for Quantizer { diff --git a/crates/burn-cubecl-fusion/src/shared/builder.rs b/crates/burn-cubecl-fusion/src/shared/builder.rs index 7d01d917a3..983e0662af 100644 --- a/crates/burn-cubecl-fusion/src/shared/builder.rs +++ b/crates/burn-cubecl-fusion/src/shared/builder.rs @@ -8,7 +8,7 @@ use burn_ir::{ BaseOperationIr, BinaryOpIr, FloatOperationIr, NumericOperationIr, OperationIr, ScalarOpIr, TensorIr, UnaryOpIr, }; -use burn_tensor::Element; +use burn_tensor::{DType, Element}; use cubecl::ir::Elem; /// The base optimization builder that can be used to fuse all elemwise operations. @@ -212,6 +212,10 @@ impl FuseOptimizationBuilder { return false; } + if self.input_is_quantized(&desc.input) { + return false; + } + if self.builder.register(|build| { build.input_swap_dims( &desc.input, @@ -243,6 +247,10 @@ impl FuseOptimizationBuilder { return false; } + if self.input_is_quantized(&desc.input) { + return false; + } + if self.builder.register(|build| { build.input_reshaped(&desc.input, &desc.out)?; Some(()) @@ -447,6 +455,10 @@ impl FuseOptimizationBuilder { return false; } + if self.input_is_quantized(&desc.tensor) { + return false; + } + self.builder.register(|build| { let input = build.input_indexed(&desc.tensor)?; let indices = build.input(&desc.indices)?; @@ -467,6 +479,10 @@ impl FuseOptimizationBuilder { return false; } + if self.input_is_quantized(&desc.tensor) { + return false; + } + self.builder.register(|build| { let input = build.input_indexed(&desc.tensor)?; let indices = build.input_indexed(&desc.indices)?; @@ -494,6 +510,10 @@ impl FuseOptimizationBuilder { return false; } + if self.input_is_quantized(&desc.lhs) { + return false; + } + self.builder.register(|build| { let lhs = build.input(&desc.lhs)?; let rhs = build.input(&desc.rhs)?; @@ -513,6 +533,10 @@ impl FuseOptimizationBuilder { return false; } + if self.input_is_quantized(&desc.input) { + return false; + } + self.builder.register(|build| { let input = build.input(&desc.input)?; let out = build.output(&desc.out)?; @@ -529,6 +553,10 @@ impl FuseOptimizationBuilder { return false; } + if self.input_is_quantized(&desc.lhs) { + return false; + } + self.builder.register(|build| { let elem = desc.lhs.dtype; let lhs = build.input(&desc.lhs)?; @@ -541,6 +569,10 @@ impl FuseOptimizationBuilder { }) } + fn input_is_quantized(&self, input: &TensorIr) -> bool { + matches!(input.dtype, DType::QFloat(_scheme)) + } + fn output_is_compatible(&mut self, out: &TensorIr) -> bool { if self.current_output_shape.is_empty() { self.current_output_shape.clone_from(&out.shape); diff --git a/crates/burn-cubecl-fusion/src/shared/ir.rs b/crates/burn-cubecl-fusion/src/shared/ir.rs index 4dbfe34e1f..e9b9e8a9f0 100644 --- a/crates/burn-cubecl-fusion/src/shared/ir.rs +++ b/crates/burn-cubecl-fusion/src/shared/ir.rs @@ -415,7 +415,7 @@ impl From for FusePrecision { DType::U16 => Self::U16, DType::U8 => Self::U8, DType::Bool => Self::Bool, - _ => panic!("Unsupported"), + _ => panic!("Unsupported precision for fusion: {value:?}"), } } } diff --git a/crates/burn-cubecl/src/kernel/matmul/base.rs b/crates/burn-cubecl/src/kernel/matmul/base.rs index cfcb970d9d..240e05eba1 100644 --- a/crates/burn-cubecl/src/kernel/matmul/base.rs +++ b/crates/burn-cubecl/src/kernel/matmul/base.rs @@ -1,6 +1,9 @@ use super::init_matmul_output; use crate::{CubeRuntime, FloatElement, tensor::CubeTensor}; -use burn_tensor::DType; +use burn_tensor::{ + DType, + quantization::{QTensorPrimitive, QuantAccPrecision}, +}; use cubecl::linalg::matmul::{components::Quantized, kernels::MatmulLaunchError}; #[cfg(feature = "autotune")] @@ -65,16 +68,34 @@ pub fn q_matmul( let client = &lhs.client; + let scheme = *lhs.scheme(); + lhs.dtype = DType::I8; rhs.dtype = DType::I8; - cubecl::linalg::matmul::launch_ref::( - &Default::default(), - client, - &lhs.as_handle_ref(), - &rhs.as_handle_ref(), - &out.as_handle_ref(), - )?; + match scheme.acc_precision { + QuantAccPrecision::Full => { + cubecl::linalg::matmul::launch_ref::( + &Default::default(), + client, + &lhs.as_handle_ref(), + &rhs.as_handle_ref(), + &out.as_handle_ref(), + )?; + } + QuantAccPrecision::Half => { + cubecl::linalg::matmul::launch_ref::< + R, + (i8, half::f16, half::f16, half::f16, Quantized), + >( + &Default::default(), + client, + &lhs.as_handle_ref(), + &rhs.as_handle_ref(), + &out.as_handle_ref(), + )?; + } + } Ok(out) } diff --git a/crates/burn-cubecl/src/kernel/quantization/dequantize.rs b/crates/burn-cubecl/src/kernel/quantization/dequantize.rs index a4e2e09e76..7b7fe5d602 100644 --- a/crates/burn-cubecl/src/kernel/quantization/dequantize.rs +++ b/crates/burn-cubecl/src/kernel/quantization/dequantize.rs @@ -1,7 +1,7 @@ use crate::tensor::CubeTensor; use crate::{CubeElement, CubeRuntime}; use burn_tensor::DType; -use burn_tensor::quantization::{QuantizationMode, QuantizationScheme, QuantizationType}; +use burn_tensor::quantization::{QuantInputType, QuantLevel, QuantMode, QuantScheme}; use cubecl::calculate_cube_count_elemwise; use cubecl::prelude::*; @@ -39,7 +39,7 @@ fn unpack_i8s(value: u32) -> Line { fn dequantize_per_tensor_symmetric_int8_kernel( input: &QTensor, output: &mut Tensor>, - #[comptime] scheme: QuantizationScheme, + #[comptime] scheme: QuantScheme, ) { // Last position contains the qparam if ABSOLUTE_POS >= input.len() - 1 { @@ -93,7 +93,12 @@ where if let DType::QFloat(scheme) = tensor.dtype { match scheme { - QuantizationScheme::PerTensor(QuantizationMode::Symmetric, QuantizationType::QInt8) => { + QuantScheme { + level: QuantLevel::Tensor, + mode: QuantMode::Symmetric, + q_type: QuantInputType::QInt8, + .. + } => { unsafe { dequantize_per_tensor_symmetric_int8_kernel::launch_unchecked::( &client, diff --git a/crates/burn-cubecl/src/kernel/quantization/qtensor.rs b/crates/burn-cubecl/src/kernel/quantization/qtensor.rs index fa80295f7c..e040bc5910 100644 --- a/crates/burn-cubecl/src/kernel/quantization/qtensor.rs +++ b/crates/burn-cubecl/src/kernel/quantization/qtensor.rs @@ -1,13 +1,13 @@ #![allow(missing_docs)] // cube derive macros -use burn_tensor::quantization::{QuantizationMode, QuantizationScheme}; +use burn_tensor::quantization::{QuantInputType, QuantLevel, QuantMode, QuantScheme}; use cubecl::prelude::*; /// Quantization parameters. #[derive(CubeLaunch, CubeType)] pub struct QParams { #[cube(comptime)] - scheme: QuantizationScheme, + scheme: QuantScheme, } /// Quantized tensor representation. @@ -16,7 +16,7 @@ pub type QTensor = Array>; #[cube] impl QParams { /// Create a new quantization parameters instance. - pub fn new(scheme: QuantizationScheme) -> Self { + pub fn new(#[comptime] scheme: QuantScheme) -> Self { QParams { scheme } } @@ -25,9 +25,12 @@ impl QParams { let len = tensor.len(); match comptime!(self.scheme) { // Symmetric quantization only contains the scaling factor as the last element - QuantizationScheme::PerTensor(QuantizationMode::Symmetric, _) => { - (f32::reinterpret(tensor[len - 1][tensor.line_size() - 1]), 0) - } + QuantScheme { + level: QuantLevel::Tensor, + mode: QuantMode::Symmetric, + q_type: QuantInputType::QInt8, + .. + } => (f32::reinterpret(tensor[len - 1][tensor.line_size() - 1]), 0), } } } diff --git a/crates/burn-cubecl/src/kernel/quantization/quantize.rs b/crates/burn-cubecl/src/kernel/quantization/quantize.rs index d8c5b8faae..c1420cb43b 100644 --- a/crates/burn-cubecl/src/kernel/quantization/quantize.rs +++ b/crates/burn-cubecl/src/kernel/quantization/quantize.rs @@ -1,7 +1,7 @@ use crate::tensor::CubeTensor; use crate::{CubeElement, CubeRuntime, IntElement}; use burn_tensor::Shape; -use burn_tensor::quantization::{QuantizationMode, QuantizationScheme, QuantizationType}; +use burn_tensor::quantization::{QuantInputType, QuantLevel, QuantMode, QuantScheme}; use cubecl::calculate_cube_count_elemwise; use cubecl::prelude::*; @@ -90,16 +90,19 @@ fn create_quantized_output( num_input_elems: usize, device: R::Device, shape: Shape, - scheme: QuantizationScheme, + scheme: QuantScheme, ) -> CubeTensor { // Output tensor contains 4x less elements (four int8 values packed in a single u32) let output_elems_size = usize::div_ceil(num_input_elems, 4) * core::mem::size_of::(); // Scale and offset (optional) qparams are also packed in the tensor data let qparams_size = match &scheme { - QuantizationScheme::PerTensor(mode, ..) => match mode { - QuantizationMode::Symmetric => core::mem::size_of::(), - }, + QuantScheme { + level: QuantLevel::Tensor, + mode: QuantMode::Symmetric, + q_type: QuantInputType::QInt8, + .. + } => core::mem::size_of::(), }; let handle = client.empty(output_elems_size + qparams_size); @@ -115,7 +118,7 @@ fn create_quantized_output( /// Convert the tensor to a lower precision data type based on the quantization scheme and parameters. pub fn quantize( tensor: CubeTensor, - scheme: &QuantizationScheme, + scheme: &QuantScheme, scale: CubeTensor, ) -> CubeTensor where @@ -142,32 +145,28 @@ where ); match scheme { - QuantizationScheme::PerTensor(mode, QuantizationType::QInt8) => { + QuantScheme { + level: QuantLevel::Tensor, + mode: QuantMode::Symmetric, + q_type: QuantInputType::QInt8, + .. + } => { let ndims = tensor.shape.num_dims(); let dummy_array = vec![1; ndims]; - match mode { - QuantizationMode::Symmetric => { - unsafe { - quantize_per_tensor_symmetric_int8_kernel::launch_unchecked::( - &client, - cube_count, - cube_dim, - tensor.as_tensor_arg::(line_size), - // Ignore shape and stride - TensorArg::from_raw_parts::( - &scale.handle, - &dummy_array, - &dummy_array, - 1, - ), - ScalarArg::new(-i8::MAX as f32), - ScalarArg::new(i8::MAX as f32), - output.as_array_arg::(1), - ) - }; - } - } + unsafe { + quantize_per_tensor_symmetric_int8_kernel::launch_unchecked::( + &client, + cube_count, + cube_dim, + tensor.as_tensor_arg::(line_size), + // Ignore shape and stride + TensorArg::from_raw_parts::(&scale.handle, &dummy_array, &dummy_array, 1), + ScalarArg::new(-i8::MAX as f32), + ScalarArg::new(i8::MAX as f32), + output.as_array_arg::(1), + ) + }; } } diff --git a/crates/burn-cubecl/src/ops/qtensor.rs b/crates/burn-cubecl/src/ops/qtensor.rs index e471acbb8b..3d00bee79f 100644 --- a/crates/burn-cubecl/src/ops/qtensor.rs +++ b/crates/burn-cubecl/src/ops/qtensor.rs @@ -1,11 +1,11 @@ use std::ops::Range; use burn_tensor::{ - DType, Device, Shape, TensorData, + DType, Device, Shape, TensorData, TensorPrimitive, ops::{FloatTensor, FloatTensorOps, IntTensor, QTensorOps, QuantizedTensor}, quantization::{ - QTensorPrimitive, QuantizationMode, QuantizationParametersPrimitive, QuantizationScheme, - QuantizationType, + QTensorPrimitive, QuantInputType, QuantLevel, QuantMode, QuantPropagation, QuantScheme, + QuantizationParametersPrimitive, }, }; use cubecl::{ @@ -27,7 +27,7 @@ use super::{permute, swap_dims}; fn new_qtensor>( data: &[u8], shape: S, - scheme: QuantizationScheme, + scheme: QuantScheme, device: &R::Device, ) -> CubeTensor { let client = R::client(device); @@ -52,7 +52,12 @@ where fn q_from_data(data: TensorData, device: &Device) -> QuantizedTensor { match data.dtype { DType::QFloat(scheme) => match scheme { - QuantizationScheme::PerTensor(_mode, QuantizationType::QInt8) => { + QuantScheme { + level: QuantLevel::Tensor, + mode: QuantMode::Symmetric, + q_type: QuantInputType::QInt8, + .. + } => { // TensorData quantized representation is the same, with multiple quantized values // packed into u32 and quantization parameters appended to the bytes new_qtensor(data.as_bytes(), data.shape.clone(), scheme, device) @@ -69,7 +74,7 @@ where fn quantize( tensor: FloatTensor, - scheme: &QuantizationScheme, + scheme: &QuantScheme, qparams: QuantizationParametersPrimitive, ) -> QuantizedTensor { kernel::quantization::quantize::(tensor, scheme, qparams.scale) @@ -143,29 +148,47 @@ where unimplemented!() } - fn q_matmul(lhs: QuantizedTensor, rhs: QuantizedTensor) -> QuantizedTensor { + fn q_matmul(lhs: QuantizedTensor, rhs: QuantizedTensor) -> TensorPrimitive { if features_enabled::(&lhs.client) && both_matches_symmetric_qint8(lhs.scheme(), rhs.scheme()) { let out = kernel::matmul::q_matmul(lhs.clone(), rhs.clone(), None, MatmulStrategy::default()); if let Ok(out) = out { - return out; + return match lhs.scheme().propagation { + QuantPropagation::Propagate => { + TensorPrimitive::QFloat(Self::quantize_dynamic(out, lhs.scheme())) + } + QuantPropagation::Inhibit => TensorPrimitive::Float(out), + }; } } // If the above quantized matmul fail, we fallback to the dequantize-then-matmul pattern. + let scheme = *lhs.scheme(); let t1_f = ::dequantize(lhs); let t2_f = ::dequantize(rhs); - Self::float_matmul(t1_f, t2_f) + let out = Self::float_matmul(t1_f, t2_f); + + match scheme.propagation { + QuantPropagation::Propagate => { + TensorPrimitive::QFloat(Self::quantize_dynamic(out, &scheme)) + } + QuantPropagation::Inhibit => TensorPrimitive::Float(out), + } } } -fn both_matches_symmetric_qint8(lhs: &QuantizationScheme, rhs: &QuantizationScheme) -> bool { +fn both_matches_symmetric_qint8(lhs: &QuantScheme, rhs: &QuantScheme) -> bool { [lhs, rhs].iter().all(|scheme| { matches!( scheme, - QuantizationScheme::PerTensor(QuantizationMode::Symmetric, QuantizationType::QInt8), + QuantScheme { + level: QuantLevel::Tensor, + mode: QuantMode::Symmetric, + q_type: QuantInputType::QInt8, + .. + } ) }) } diff --git a/crates/burn-cubecl/src/tensor/base.rs b/crates/burn-cubecl/src/tensor/base.rs index 923fefd3bf..715776ec04 100644 --- a/crates/burn-cubecl/src/tensor/base.rs +++ b/crates/burn-cubecl/src/tensor/base.rs @@ -151,7 +151,7 @@ impl TensorMetadata for CubeTensor { } impl QTensorPrimitive for CubeTensor { - fn scheme(&self) -> &burn_tensor::quantization::QuantizationScheme { + fn scheme(&self) -> &burn_tensor::quantization::QuantScheme { if let DType::QFloat(scheme) = &self.dtype { scheme } else { diff --git a/crates/burn-cubecl/src/tests/quantization.rs b/crates/burn-cubecl/src/tests/quantization.rs index 2dba7c275f..0818415a7c 100644 --- a/crates/burn-cubecl/src/tests/quantization.rs +++ b/crates/burn-cubecl/src/tests/quantization.rs @@ -1,17 +1,13 @@ #[burn_tensor_testgen::testgen(quantization)] mod tests { use super::*; - use burn_tensor::{ - Tensor, - quantization::{QuantizationMode, QuantizationScheme, QuantizationType}, - }; + use burn_tensor::{Tensor, quantization::QuantScheme}; use burn_tensor::{Tolerance, ops::FloatElem}; type FT = FloatElem; #[test] fn should_quantize_dequantize_symmetric_single() { - let scheme = - QuantizationScheme::PerTensor(QuantizationMode::Symmetric, QuantizationType::QInt8); + let scheme = QuantScheme::default(); let input = Tensor::::from_floats([-1.8], &Default::default()); let input_ref = Tensor::::from_data(input.to_data(), &Default::default()); @@ -31,8 +27,7 @@ mod tests { #[test] fn should_quantize_dequantize_symmetric_multiple() { - let scheme = - QuantizationScheme::PerTensor(QuantizationMode::Symmetric, QuantizationType::QInt8); + let scheme = QuantScheme::default(); let input = Tensor::::from_floats([-1.8, -1.0, 0.0, 0.5, 0.0], &Default::default()); let input_ref = diff --git a/crates/burn-fusion/src/ops/qtensor.rs b/crates/burn-fusion/src/ops/qtensor.rs index 38339bfc47..62f163fc47 100644 --- a/crates/burn-fusion/src/ops/qtensor.rs +++ b/crates/burn-fusion/src/ops/qtensor.rs @@ -1,13 +1,14 @@ use std::{marker::PhantomData, ops::Range}; use burn_ir::{ - DequantizeOpIr, FloatOperationIr, HandleContainer, InitOperationIr, OperationIr, - QuantizationParametersIr, QuantizeOpIr, + BaseOperationIr, DequantizeOpIr, ExpandOpIr, FlipOpIr, FloatOperationIr, GatherOpIr, + HandleContainer, InitOperationIr, NumericOperationIr, OperationIr, PermuteOpIr, + QuantizationParametersIr, QuantizeOpIr, SelectOpIr, SliceOpIr, SwapDimsOpIr, UnaryOpIr, }; use burn_tensor::{ DType, Device, Element, Shape, TensorData, TensorMetadata, - ops::{FloatElem, FloatTensor, IntTensor, QTensorOps, QuantizedTensor}, - quantization::{QuantizationParametersPrimitive, QuantizationScheme}, + ops::{FloatTensor, IntTensor, QTensorOps, QuantizedTensor}, + quantization::{QuantScheme, QuantizationParametersPrimitive}, }; use crate::{ @@ -42,7 +43,7 @@ impl QTensorOps for Fusion { fn quantize( tensor: FloatTensor, - scheme: &QuantizationScheme, + scheme: &QuantScheme, qparams: QuantizationParametersPrimitive, ) -> QuantizedTensor { #[derive(new)] @@ -69,6 +70,7 @@ impl QTensorOps for Fusion { } let shape: Vec = tensor.shape.clone(); + let dtype = tensor.dtype; let out = tensor .client .tensor_uninitialized(shape, DType::QFloat(*scheme)); @@ -91,10 +93,7 @@ impl QTensorOps for Fusion { out.client.register( streams, - OperationIr::Float( - FloatElem::::dtype(), - FloatOperationIr::Quantize(desc.clone()), - ), + OperationIr::Float(dtype, FloatOperationIr::Quantize(desc.clone())), QuantizeOp::::new(desc), ); @@ -119,9 +118,8 @@ impl QTensorOps for Fusion { let stream = tensor.stream; let shape: Vec = tensor.shape.clone(); - let out = tensor - .client - .tensor_uninitialized(shape, B::FloatElem::dtype()); + let dtype = B::FloatElem::dtype(); + let out = tensor.client.tensor_uninitialized(shape, dtype); let desc = DequantizeOpIr { input: tensor.into_ir(), @@ -130,10 +128,7 @@ impl QTensorOps for Fusion { out.client.register( vec![stream], - OperationIr::Float( - FloatElem::::dtype(), - FloatOperationIr::Dequantize(desc.clone()), - ), + OperationIr::Float(dtype, FloatOperationIr::Dequantize(desc.clone())), DequantizeOp::::new(desc), ); @@ -159,8 +154,36 @@ impl QTensorOps for Fusion { client_original.change_client_quantized::(tensor.into_ir(), client_target, id) } - fn q_reshape(_tensor: QuantizedTensor, _shape: Shape) -> QuantizedTensor { - unimplemented!() + fn q_reshape(tensor: QuantizedTensor, shape: Shape) -> QuantizedTensor { + #[derive(new)] + struct ReshapeDimsOps { + desc: UnaryOpIr, + _b: PhantomData, + } + + impl Operation for ReshapeDimsOps { + fn execute(&self, handles: &mut HandleContainer) { + let input = handles.get_quantized_tensor::(&self.desc.input); + let output = B::q_reshape(input, Shape::from(&self.desc.out.shape)); + handles.register_quantized_tensor::(&self.desc.out.id, output); + } + } + + let stream = tensor.stream; + let dtype = tensor.dtype; + let out = tensor.client.tensor_uninitialized(shape.dims, dtype); + + let desc = UnaryOpIr { + input: tensor.into_ir(), + out: out.to_ir_out(), + }; + out.client.register( + vec![stream], + OperationIr::BaseFloat(BaseOperationIr::Reshape(desc.clone())), + ReshapeDimsOps::::new(desc), + ); + + out } async fn q_into_data(tensor: QuantizedTensor) -> TensorData { @@ -168,42 +191,280 @@ impl QTensorOps for Fusion { } fn q_swap_dims( - _tensor: QuantizedTensor, - _dim1: usize, - _dim2: usize, + tensor: QuantizedTensor, + dim1: usize, + dim2: usize, ) -> QuantizedTensor { - unimplemented!() + #[derive(new)] + struct SwapDimsOps { + desc: SwapDimsOpIr, + _b: PhantomData, + } + + impl Operation for SwapDimsOps { + fn execute(&self, handles: &mut HandleContainer) { + let input = handles.get_quantized_tensor::(&self.desc.input); + let output = B::q_swap_dims(input, self.desc.dim1, self.desc.dim2); + handles.register_quantized_tensor::(&self.desc.out.id, output); + } + } + + let stream = tensor.stream; + let dtype = tensor.dtype; + let mut shape = tensor.shape.clone(); + shape[dim1] = tensor.shape[dim2]; + shape[dim2] = tensor.shape[dim1]; + + let mut out = tensor.client.tensor_uninitialized(shape, dtype); + + let desc = SwapDimsOpIr { + input: tensor.into_ir(), + dim1, + dim2, + out: out.to_ir_out(), + }; + out.client.register( + vec![stream], + OperationIr::BaseFloat(BaseOperationIr::SwapDims(desc.clone())), + SwapDimsOps::::new(desc), + ); + out.stream = stream; + + out } - fn q_permute(_tensor: QuantizedTensor, _axes: &[usize]) -> QuantizedTensor { - unimplemented!() + fn q_permute(tensor: QuantizedTensor, axes: &[usize]) -> QuantizedTensor { + #[derive(new)] + struct PermuteDimsOps { + desc: PermuteOpIr, + _b: PhantomData, + } + + impl Operation for PermuteDimsOps { + fn execute(&self, handles: &mut HandleContainer) { + let input = handles.get_quantized_tensor::(&self.desc.input); + let output = B::q_permute(input, self.desc.axes.as_slice()); + handles.register_quantized_tensor::(&self.desc.out.id, output); + } + } + + let stream = tensor.stream; + + // Change the shape of the tensor to match the new axes + let shape = axes.iter().map(|x| tensor.shape[*x]).collect(); + + let out = tensor.client.tensor_uninitialized(shape, tensor.dtype); + + let desc = PermuteOpIr { + input: tensor.into_ir(), + axes: axes.to_vec(), + out: out.to_ir_out(), + }; + + out.client.register( + vec![stream], + OperationIr::BaseInt(BaseOperationIr::Permute(desc.clone())), + PermuteDimsOps::::new(desc), + ); + + out } - fn q_flip(_tensor: QuantizedTensor, _axes: &[usize]) -> QuantizedTensor { - unimplemented!() + fn q_flip(tensor: QuantizedTensor, axes: &[usize]) -> QuantizedTensor { + #[derive(new)] + struct FlipOps { + desc: FlipOpIr, + _b: PhantomData, + } + + impl Operation for FlipOps { + fn execute(&self, handles: &mut HandleContainer) { + let input = handles.get_quantized_tensor::(&self.desc.input); + let output = B::q_flip(input, &self.desc.axes); + handles.register_quantized_tensor::(&self.desc.out.id, output); + } + } + + let stream = tensor.stream; + let out = tensor + .client + .tensor_uninitialized(tensor.shape.clone(), tensor.dtype); + + let desc = FlipOpIr { + input: tensor.into_ir(), + axes: axes.to_vec(), + out: out.to_ir_out(), + }; + + out.client.register( + vec![stream], + OperationIr::BaseInt(BaseOperationIr::Flip(desc.clone())), + FlipOps::::new(desc), + ); + + out } fn q_gather( - _dim: usize, - _tensor: QuantizedTensor, - _indices: IntTensor, + dim: usize, + tensor: QuantizedTensor, + indices: IntTensor, ) -> QuantizedTensor { - unimplemented!() + #[derive(new)] + struct GatherOps { + desc: GatherOpIr, + _b: PhantomData, + } + + impl Operation for GatherOps { + fn execute(&self, handles: &mut HandleContainer) { + let tensor = handles.get_quantized_tensor::(&self.desc.tensor); + let indices = handles.get_int_tensor::(&self.desc.indices); + + let output = B::q_gather(self.desc.dim, tensor, indices); + handles.register_quantized_tensor::(&self.desc.out.id, output); + } + } + + let stream_1 = tensor.stream; + let stream_2 = indices.stream; + let dtype = tensor.dtype; + let shape: Vec = indices.shape.clone(); + let out = tensor.client.tensor_uninitialized(shape, dtype); + + let desc = GatherOpIr { + tensor: tensor.into_ir(), + dim, + indices: indices.into_ir(), + out: out.to_ir_out(), + }; + out.client.register( + vec![stream_1, stream_2], + OperationIr::NumericFloat(dtype, NumericOperationIr::Gather(desc.clone())), + GatherOps::::new(desc), + ); + + out } fn q_select( - _tensor: QuantizedTensor, - _dim: usize, - _indices: IntTensor, + tensor: QuantizedTensor, + dim: usize, + indices: IntTensor, ) -> QuantizedTensor { - unimplemented!() + #[derive(new)] + struct SelectOps { + desc: SelectOpIr, + _b: PhantomData, + } + + impl Operation for SelectOps { + fn execute(&self, handles: &mut HandleContainer) { + let tensor = handles.get_quantized_tensor::(&self.desc.tensor); + let indices = handles.get_int_tensor::(&self.desc.indices); + + let output = B::q_select(tensor, self.desc.dim, indices); + + handles.register_quantized_tensor::(&self.desc.out.id, output); + } + } + + let stream_1 = tensor.stream; + let stream_2 = indices.stream; + let dtype = tensor.dtype; + let mut shape: Vec = tensor.shape.clone(); + shape[dim] = indices.shape[0]; + let out = tensor.client.tensor_uninitialized(shape, dtype); + let desc = SelectOpIr { + tensor: tensor.into_ir(), + dim, + indices: indices.into_ir(), + out: out.to_ir_out(), + }; + out.client.register( + vec![stream_1, stream_2], + OperationIr::NumericFloat(dtype, NumericOperationIr::Select(desc.clone())), + SelectOps::::new(desc), + ); + + out } - fn q_slice(_tensor: QuantizedTensor, _ranges: &[Range]) -> QuantizedTensor { - unimplemented!() + fn q_slice(tensor: QuantizedTensor, ranges: &[Range]) -> QuantizedTensor { + #[derive(new)] + struct SliceOps { + desc: SliceOpIr, + _b: PhantomData, + } + + impl Operation for SliceOps { + fn execute(&self, handles: &mut HandleContainer) { + let tensor = handles.get_quantized_tensor::(&self.desc.tensor); + + let output = B::q_slice(tensor, self.desc.ranges.as_slice()); + + handles.register_quantized_tensor::(&self.desc.out.id, output); + } + } + let stream = tensor.stream; + let dtype = tensor.dtype; + let ndims = tensor.shape().num_dims(); + let mut shape: Vec = ranges.iter().map(|range| range.end - range.start).collect(); + + for i in shape.len()..ndims { + shape.push(tensor.shape[i]); + } + + let out = tensor.client.tensor_uninitialized(shape, dtype); + + let desc = SliceOpIr { + tensor: tensor.into_ir(), + ranges: ranges.into(), + out: out.to_ir_out(), + }; + out.client.register( + vec![stream], + OperationIr::BaseFloat(BaseOperationIr::Slice(desc.clone())), + SliceOps::::new(desc), + ); + + out } - fn q_expand(_tensor: QuantizedTensor, _shape: Shape) -> QuantizedTensor { - unimplemented!() + fn q_expand(tensor: QuantizedTensor, shape: Shape) -> QuantizedTensor { + #[derive(new)] + struct ExpandOps { + desc: ExpandOpIr, + _b: PhantomData, + } + + impl Operation for ExpandOps { + fn execute(&self, handles: &mut HandleContainer) { + let input = handles.get_quantized_tensor::(&self.desc.input); + let output = B::q_expand(input, self.desc.shape.clone().into()); + + handles.register_quantized_tensor::(&self.desc.out.id, output); + } + } + + let stream = tensor.stream; + + let out = tensor + .client + .tensor_uninitialized(shape.dims.clone(), tensor.dtype); + + let desc = ExpandOpIr { + input: tensor.into_ir(), + shape: shape.dims, + out: out.to_ir_out(), + }; + + out.client.register( + vec![stream], + OperationIr::BaseFloat(BaseOperationIr::Expand(desc.clone())), + ExpandOps::::new(desc), + ); + + out } } diff --git a/crates/burn-fusion/src/tensor.rs b/crates/burn-fusion/src/tensor.rs index eb7c958492..17a025e569 100644 --- a/crates/burn-fusion/src/tensor.rs +++ b/crates/burn-fusion/src/tensor.rs @@ -2,7 +2,7 @@ use crate::{Client, FusionBackend, FusionRuntime, client::FusionClient, stream:: use burn_ir::{TensorId, TensorIr, TensorStatus}; use burn_tensor::{ DType, Shape, TensorData, TensorMetadata, - quantization::{QTensorPrimitive, QuantizationScheme}, + quantization::{QTensorPrimitive, QuantScheme}, }; use std::sync::Arc; @@ -179,7 +179,7 @@ impl Drop for FusionTensor { } impl QTensorPrimitive for FusionTensor { - fn scheme(&self) -> &QuantizationScheme { + fn scheme(&self) -> &QuantScheme { if let DType::QFloat(scheme) = &self.dtype { scheme } else { diff --git a/crates/burn-ir/src/operation.rs b/crates/burn-ir/src/operation.rs index 561a9b134c..4ef5fb3b64 100644 --- a/crates/burn-ir/src/operation.rs +++ b/crates/burn-ir/src/operation.rs @@ -11,7 +11,7 @@ use burn_tensor::{ ops::{ ConvOptions, ConvTransposeOptions, DeformConvOptions, InterpolateMode, InterpolateOptions, }, - quantization::QuantizationScheme, + quantization::QuantScheme, }; use crate::TensorIr; @@ -978,7 +978,7 @@ pub struct QuantizationParametersIr { pub struct QuantizeOpIr { pub tensor: TensorIr, pub qparams: QuantizationParametersIr, - pub scheme: QuantizationScheme, + pub scheme: QuantScheme, pub out: TensorIr, } diff --git a/crates/burn-ndarray/src/ops/base.rs b/crates/burn-ndarray/src/ops/base.rs index 728a67a227..6988ed8a57 100644 --- a/crates/burn-ndarray/src/ops/base.rs +++ b/crates/burn-ndarray/src/ops/base.rs @@ -3,7 +3,7 @@ use burn_tensor::ElementConversion; use burn_tensor::TensorData; use burn_tensor::TensorMetadata; #[cfg(feature = "simd")] -use burn_tensor::{DType, quantization::QuantizationType}; +use burn_tensor::{DType, quantization::QuantInputType}; use core::fmt::Debug; use core::{marker::PhantomData, ops::Range}; use ndarray::Array2; @@ -198,8 +198,8 @@ macro_rules! dispatch_binary_simd { paste! { let simd = match $elem::dtype() { $(DType::[<$ty:upper>] => try_binary_simd::<$elem, $elem, $ty, $ty, $op>($lhs, $rhs),)* - DType::QFloat(strategy) => match strategy.q_type() { - QuantizationType::QInt8 => try_binary_simd::<$elem, $elem, i8, i8, $op>($lhs, $rhs), + DType::QFloat(strategy) => match strategy.q_type { + QuantInputType::QInt8 => try_binary_simd::<$elem, $elem, i8, i8, $op>($lhs, $rhs), }, _ => Err(($lhs, $rhs)), }; @@ -235,8 +235,8 @@ macro_rules! dispatch_binary_scalar_simd { paste! { let simd = match $elem::dtype() { $(DType::[<$ty:upper>] => try_binary_scalar_simd::<$elem, $elem, $ty, $ty, $op>($lhs, $rhs),)* - DType::QFloat(strategy) => match strategy.q_type() { - QuantizationType::QInt8 => try_binary_scalar_simd::<$elem, $elem, i8, i8, $op>($lhs, $rhs), + DType::QFloat(strategy) => match strategy.q_type { + QuantInputType::QInt8 => try_binary_scalar_simd::<$elem, $elem, i8, i8, $op>($lhs, $rhs), }, _ => Err($lhs), }; @@ -260,8 +260,8 @@ macro_rules! dispatch_cmp_simd { paste! { let simd = match $elem::dtype() { $(DType::[<$ty:upper>] => try_cmp_simd::<$elem, $ty, $op>($lhs, $rhs),)* - DType::QFloat(strategy) => match strategy.q_type() { - QuantizationType::QInt8 => try_cmp_simd::<$elem, i8, $op>($lhs, $rhs), + DType::QFloat(strategy) => match strategy.q_type { + QuantInputType::QInt8 => try_cmp_simd::<$elem, i8, $op>($lhs, $rhs), }, _ => Err(($lhs, $rhs)), }; @@ -284,8 +284,8 @@ macro_rules! dispatch_cmp_scalar_simd { paste! { let simd = match $elem::dtype() { $(DType::[<$ty:upper>] => try_cmp_scalar_simd::<$elem, $ty, $op>($lhs, $rhs),)* - DType::QFloat(strategy) => match strategy.q_type() { - QuantizationType::QInt8 => try_cmp_scalar_simd::<$elem, i8, $op>($lhs, $rhs), + DType::QFloat(strategy) => match strategy.q_type { + QuantInputType::QInt8 => try_cmp_scalar_simd::<$elem, i8, $op>($lhs, $rhs), }, _ => Err($lhs), }; diff --git a/crates/burn-ndarray/src/ops/qtensor.rs b/crates/burn-ndarray/src/ops/qtensor.rs index b834854b6b..e524806521 100644 --- a/crates/burn-ndarray/src/ops/qtensor.rs +++ b/crates/burn-ndarray/src/ops/qtensor.rs @@ -5,8 +5,9 @@ use burn_tensor::{ DType, Shape, TensorData, TensorMetadata, ops::{FloatTensor, IntTensor, QTensorOps, QuantizedTensor}, quantization::{ - QParams, QuantizationMode, QuantizationParametersPrimitive, QuantizationScheme, - QuantizationStrategy, QuantizationType, QuantizedBytes, SymmetricQuantization, + QParams, QuantInputType, QuantLevel, QuantMode, QuantScheme, + QuantizationParametersPrimitive, QuantizationStrategy, QuantizedBytes, + SymmetricQuantization, }, }; @@ -46,21 +47,24 @@ impl QTensorOps { + QuantScheme { + level: QuantLevel::Tensor, + mode: QuantMode::Symmetric, + q_type: QuantInputType::QInt8, + .. + } => { // We should probably check that `Q` matches i8.. but it's the only valid type now let (values, qparams) = q_bytes.into_vec_i8(); let data = TensorData::new(values, shape); - let qparams = match mode { - QuantizationMode::Symmetric => qparams - .scale - .into_iter() - .map(|scale| QParams { - scale, - offset: None, - }) - .collect(), - }; + let qparams = qparams + .scale + .into_iter() + .map(|scale| QParams { + scale, + offset: None, + }) + .collect(); NdArrayQTensor { qtensor: NdArrayTensor::::from_data(data), @@ -79,12 +83,17 @@ impl QTensorOps, - scheme: &QuantizationScheme, + scheme: &QuantScheme, qparams: QuantizationParametersPrimitive, ) -> QuantizedTensor { // Implement with ndarray instead of QuantizationStrategy? let (strategy, qparams) = match scheme { - QuantizationScheme::PerTensor(QuantizationMode::Symmetric, QuantizationType::QInt8) => { + QuantScheme { + level: QuantLevel::Tensor, + mode: QuantMode::Symmetric, + q_type: QuantInputType::QInt8, + .. + } => { let scale = into_data_f(qparams.scale).iter().next().unwrap(); ( QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init( diff --git a/crates/burn-ndarray/src/ops/simd/maxpool.rs b/crates/burn-ndarray/src/ops/simd/maxpool.rs index 7367b83a16..7490ebeab5 100644 --- a/crates/burn-ndarray/src/ops/simd/maxpool.rs +++ b/crates/burn-ndarray/src/ops/simd/maxpool.rs @@ -3,7 +3,7 @@ use core::{marker::PhantomData, mem::transmute}; use crate::{sharing::UnsafeSharedRef, tensor::NdArrayTensor}; use burn_common::{iter_range_par, run_par}; -use burn_tensor::{DType, Element, TensorMetadata, quantization::QuantizationType}; +use burn_tensor::{DType, Element, TensorMetadata, quantization::QuantInputType}; use macerator::{Simd, VOrd}; use ndarray::{Array4, s}; use nhwc::max_pool2d_nhwc; @@ -33,8 +33,8 @@ macro_rules! launch_kernel { DType::U16 if is_accelerated::() => Ok(cast($func::(cast($x), $($arg),*))), DType::U8 if is_accelerated::() => Ok(cast($func::(cast($x), $($arg),*))), DType::Bool if is_accelerated::() => Ok(cast($func::(cast($x), $($arg),*))), - DType::QFloat(scheme) => match scheme.q_type() { - QuantizationType::QInt8 if is_accelerated::() => Ok(cast($func::(cast($x), $($arg),*))), + DType::QFloat(scheme) => match scheme.q_type { + QuantInputType::QInt8 if is_accelerated::() => Ok(cast($func::(cast($x), $($arg),*))), _ => Err($x) }, _ => Err($x), diff --git a/crates/burn-ndarray/src/tensor.rs b/crates/burn-ndarray/src/tensor.rs index 598495ab36..1282f20075 100644 --- a/crates/burn-ndarray/src/tensor.rs +++ b/crates/burn-ndarray/src/tensor.rs @@ -3,8 +3,8 @@ use core::mem; use burn_tensor::{ DType, Element, Shape, TensorData, TensorMetadata, quantization::{ - QParams, QTensorPrimitive, QuantizationMode, QuantizationScheme, QuantizationStrategy, - QuantizationType, SymmetricQuantization, + QParams, QTensorPrimitive, QuantInputType, QuantLevel, QuantMode, QuantScheme, + QuantizationStrategy, SymmetricQuantization, }, }; @@ -318,7 +318,7 @@ pub struct NdArrayQTensor { /// The quantized tensor. pub qtensor: NdArrayTensor, /// The quantization scheme. - pub scheme: QuantizationScheme, + pub scheme: QuantScheme, /// The quantization parameters. pub qparams: Vec>, } @@ -327,17 +327,20 @@ impl NdArrayQTensor { /// Returns the quantization strategy, including quantization parameters, for the given tensor. pub fn strategy(&self) -> QuantizationStrategy { match self.scheme { - QuantizationScheme::PerTensor(QuantizationMode::Symmetric, QuantizationType::QInt8) => { - QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init( - self.qparams[0].scale, - )) - } + QuantScheme { + level: QuantLevel::Tensor, + mode: QuantMode::Symmetric, + q_type: QuantInputType::QInt8, + .. + } => QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init( + self.qparams[0].scale, + )), } } } impl QTensorPrimitive for NdArrayQTensor { - fn scheme(&self) -> &QuantizationScheme { + fn scheme(&self) -> &QuantScheme { &self.scheme } } @@ -361,7 +364,7 @@ mod tests { use burn_tensor::{ Distribution, ops::{FloatTensorOps, QTensorOps}, - quantization::{QuantizationParametersPrimitive, QuantizationType}, + quantization::QuantizationParametersPrimitive, }; #[test] @@ -427,8 +430,7 @@ mod tests { let device = Default::default(); let tensor = B::float_from_data(TensorData::from([-1.8f32, -1.0, 0.0, 0.5]), &device); - let scheme = - QuantizationScheme::PerTensor(QuantizationMode::Symmetric, QuantizationType::QInt8); + let scheme = QuantScheme::default(); let qparams = QuantizationParametersPrimitive { scale: B::float_from_data(TensorData::from([scale]), &device), offset: None, diff --git a/crates/burn-router/src/backend.rs b/crates/burn-router/src/backend.rs index 9675d9cb5b..473f608107 100644 --- a/crates/burn-router/src/backend.rs +++ b/crates/burn-router/src/backend.rs @@ -3,7 +3,7 @@ use core::marker::PhantomData; use burn_tensor::{ backend::Backend, - quantization::{QTensorPrimitive, QuantizationScheme}, + quantization::{QTensorPrimitive, QuantScheme}, }; use super::{RouterTensor, RunnerChannel, RunnerClient, get_client, set_seed}; @@ -33,7 +33,7 @@ impl Default for BackendRouter { // TODO: quantization tensor primitive (w/ qparams) impl QTensorPrimitive for RouterTensor { - fn scheme(&self) -> &QuantizationScheme { + fn scheme(&self) -> &QuantScheme { todo!() } } diff --git a/crates/burn-router/src/ops/op_qfloat.rs b/crates/burn-router/src/ops/op_qfloat.rs index 763a22c17d..5af668337b 100644 --- a/crates/burn-router/src/ops/op_qfloat.rs +++ b/crates/burn-router/src/ops/op_qfloat.rs @@ -3,7 +3,7 @@ use core::ops::Range; use burn_tensor::{ Device, Shape, TensorData, ops::{FloatTensor, IntTensor, QTensorOps, QuantizedTensor}, - quantization::{QuantizationParametersPrimitive, QuantizationScheme}, + quantization::{QuantScheme, QuantizationParametersPrimitive}, }; use crate::{BackendRouter, RunnerChannel}; @@ -15,7 +15,7 @@ impl QTensorOps for BackendRouter { fn quantize( _tensor: FloatTensor, - _scheme: &QuantizationScheme, + _scheme: &QuantScheme, _qparams: QuantizationParametersPrimitive, ) -> QuantizedTensor { unimplemented!() @@ -23,7 +23,7 @@ impl QTensorOps for BackendRouter { fn quantize_dynamic( _tensor: FloatTensor, - _scheme: &QuantizationScheme, + _scheme: &QuantScheme, ) -> QuantizedTensor { unimplemented!() } diff --git a/crates/burn-tch/src/ops/qtensor.rs b/crates/burn-tch/src/ops/qtensor.rs index 074e00c857..fc0128495a 100644 --- a/crates/burn-tch/src/ops/qtensor.rs +++ b/crates/burn-tch/src/ops/qtensor.rs @@ -4,8 +4,8 @@ use burn_tensor::{ DType, Shape, TensorData, TensorMetadata, ops::{FloatTensor, IntTensor, QTensorOps, QuantizedTensor}, quantization::{ - QParams, QuantizationMode, QuantizationParametersPrimitive, QuantizationScheme, - QuantizationType, QuantizedBytes, + QParams, QuantInputType, QuantLevel, QuantMode, QuantScheme, + QuantizationParametersPrimitive, QuantizedBytes, }, }; @@ -15,7 +15,7 @@ use super::TchOps; fn quantize( tensor: tch::Tensor, - scheme: &QuantizationScheme, + scheme: &QuantScheme, qparams: &QParams, ) -> tch::Tensor { let mut tensor = tensor; @@ -25,9 +25,12 @@ fn quantize( } match scheme { - QuantizationScheme::PerTensor(QuantizationMode::Symmetric, QuantizationType::QInt8) => { - tensor.quantize_per_tensor(qparams.scale.elem(), 0, tch::Kind::QInt8) - } + QuantScheme { + level: QuantLevel::Tensor, + mode: QuantMode::Symmetric, + q_type: QuantInputType::QInt8, + .. + } => tensor.quantize_per_tensor(qparams.scale.elem(), 0, tch::Kind::QInt8), } } @@ -41,8 +44,8 @@ impl QTensorOps for LibTorch { // So for now we have to load the dequantized values to quantize them back since the dequantization // methods take the values provided when quantizing. match data.dtype { - DType::QFloat(scheme) => match scheme { - QuantizationScheme::PerTensor(_, _) => { + DType::QFloat(scheme) => match scheme.level { + QuantLevel::Tensor => { let num_elements = data.num_elements(); let q_bytes = QuantizedBytes { bytes: data.into_bytes(), @@ -73,7 +76,7 @@ impl QTensorOps for LibTorch { fn quantize( tensor: FloatTensor, - scheme: &QuantizationScheme, + scheme: &QuantScheme, qparams: QuantizationParametersPrimitive, ) -> QuantizedTensor { let mut tensor = tensor; @@ -83,13 +86,16 @@ impl QTensorOps for LibTorch { } let qtensor = match scheme { - QuantizationScheme::PerTensor(QuantizationMode::Symmetric, QuantizationType::QInt8) => { - tensor.tensor.quantize_per_tensor_tensor_qparams( - &qparams.scale.tensor, - &tch::Tensor::zeros_like(&qparams.scale.tensor), - tch::Kind::QInt8, - ) - } + QuantScheme { + level: QuantLevel::Tensor, + mode: QuantMode::Symmetric, + q_type: QuantInputType::QInt8, + .. + } => tensor.tensor.quantize_per_tensor_tensor_qparams( + &qparams.scale.tensor, + &tch::Tensor::zeros_like(&qparams.scale.tensor), + tch::Kind::QInt8, + ), }; TchQTensor { @@ -98,12 +104,14 @@ impl QTensorOps for LibTorch { } } - fn quantize_dynamic( - tensor: FloatTensor, - scheme: &QuantizationScheme, - ) -> QuantizedTensor { + fn quantize_dynamic(tensor: FloatTensor, scheme: &QuantScheme) -> QuantizedTensor { let qtensor = match &scheme { - QuantizationScheme::PerTensor(QuantizationMode::Symmetric, QuantizationType::QInt8) => { + QuantScheme { + level: QuantLevel::Tensor, + mode: QuantMode::Symmetric, + q_type: QuantInputType::QInt8, + .. + } => { log::warn!( "LibTorch backend does not support symmetric per-tensor scheme for dynamic quantization, reverting to the default per-tensor affine quantization" ); diff --git a/crates/burn-tch/src/tensor.rs b/crates/burn-tch/src/tensor.rs index c49cf384a0..f0e03debd0 100644 --- a/crates/burn-tch/src/tensor.rs +++ b/crates/burn-tch/src/tensor.rs @@ -2,8 +2,8 @@ use crate::{LibTorchDevice, TchElement}; use burn_tensor::{ DType, Shape, TensorData, TensorMetadata, quantization::{ - QTensorPrimitive, QuantizationMode, QuantizationScheme, QuantizationStrategy, - QuantizationType, SymmetricQuantization, + QTensorPrimitive, QuantInputType, QuantLevel, QuantMode, QuantScheme, QuantizationStrategy, + SymmetricQuantization, }, }; use libc::c_void; @@ -324,14 +324,19 @@ pub struct TchQTensor { /// The quantized tensor. pub qtensor: TchTensor, /// The quantization scheme. - pub scheme: QuantizationScheme, + pub scheme: QuantScheme, } impl TchQTensor { /// Returns the quantization strategy, including quantization parameters, for the given tensor. pub fn strategy(&self) -> QuantizationStrategy { match &self.scheme { - QuantizationScheme::PerTensor(QuantizationMode::Symmetric, QuantizationType::QInt8) => { + QuantScheme { + level: QuantLevel::Tensor, + mode: QuantMode::Symmetric, + q_type: QuantInputType::QInt8, + .. + } => { let scale = self.qtensor.tensor.q_scale(); QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init( scale as f32, @@ -352,7 +357,7 @@ impl TensorMetadata for TchQTensor { } impl QTensorPrimitive for TchQTensor { - fn scheme(&self) -> &QuantizationScheme { + fn scheme(&self) -> &QuantScheme { &self.scheme } } @@ -363,7 +368,7 @@ mod tests { use super::*; use burn_tensor::ops::QTensorOps; - use burn_tensor::quantization::{QuantizationMode, QuantizationParametersPrimitive}; + use burn_tensor::quantization::QuantizationParametersPrimitive; use burn_tensor::{Distribution, Tensor, TensorPrimitive}; use rand::SeedableRng; use rand::prelude::StdRng; @@ -428,8 +433,7 @@ mod tests { fn should_support_qtensor_strategy() { let tensor = TchTensor::from_data::(TensorData::from([-1.8, -1.0, 0.0, 0.5]), tch::Device::Cpu); - let scheme = - QuantizationScheme::PerTensor(QuantizationMode::Symmetric, QuantizationType::QInt8); + let scheme = QuantScheme::default(); let qparams = QuantizationParametersPrimitive::> { scale: TchTensor::from_data::(TensorData::from([0.009_019_608]), tch::Device::Cpu), offset: Some(TchTensor::from_data::( diff --git a/crates/burn-tensor/src/tensor/api/float.rs b/crates/burn-tensor/src/tensor/api/float.rs index a44a00dc56..6916d4ce32 100644 --- a/crates/burn-tensor/src/tensor/api/float.rs +++ b/crates/burn-tensor/src/tensor/api/float.rs @@ -1,6 +1,6 @@ use crate::Tensor; use crate::check::TensorCheck; -use crate::quantization::{QuantizationParameters, QuantizationScheme}; +use crate::quantization::{QTensorPrimitive, QuantScheme, QuantizationParameters}; use crate::tensor::backend::Backend; use crate::tensor::stats; use crate::tensor::{Distribution, TensorData}; @@ -206,12 +206,24 @@ where check!(TensorCheck::matmul(&self, &other)); match (self.primitive, other.primitive) { (TensorPrimitive::QFloat(lhs), TensorPrimitive::QFloat(rhs)) => { - Self::new(TensorPrimitive::QFloat(B::q_matmul(lhs, rhs))) + Self::new(B::q_matmul(lhs, rhs)) + } + (TensorPrimitive::QFloat(lhs), TensorPrimitive::Float(rhs)) => Self::new( + TensorPrimitive::Float(B::float_matmul(B::dequantize(lhs), rhs)), + ), + (TensorPrimitive::Float(lhs), TensorPrimitive::QFloat(rhs)) => { + // NOTE: in a typical workflow with linear layers (e.g., transformers), the rhs + // represents the weights. + // + // Since `q_matmul(lhs_f16, rhs_quant)` isn't currently supported, in practice it makes + // more sense to re-quantize the input back. Better usability. + // + // This might change in the future (dequantize on read in fusion?). + Self::new(B::q_matmul(B::quantize_dynamic(lhs, rhs.scheme()), rhs)) + } + (TensorPrimitive::Float(lhs), TensorPrimitive::Float(rhs)) => { + Self::new(TensorPrimitive::Float(B::float_matmul(lhs, rhs))) } - (lhs, rhs) => Self::new(TensorPrimitive::Float(B::float_matmul( - lhs.tensor(), - rhs.tensor(), - ))), } } @@ -326,7 +338,7 @@ where /// The quantized tensor. pub fn quantize( self, - scheme: &QuantizationScheme, + scheme: &QuantScheme, qparams: QuantizationParameters, ) -> Tensor { Tensor::new(TensorPrimitive::QFloat(B::quantize( @@ -348,7 +360,7 @@ where /// /// # Notes /// This uses [min-max calibration](crate::quantization::Calibration::MinMax). - pub fn quantize_dynamic(self, scheme: &QuantizationScheme) -> Tensor { + pub fn quantize_dynamic(self, scheme: &QuantScheme) -> Tensor { Tensor::new(TensorPrimitive::QFloat(B::quantize_dynamic( self.primitive.tensor(), scheme, diff --git a/crates/burn-tensor/src/tensor/api/numeric.rs b/crates/burn-tensor/src/tensor/api/numeric.rs index 36ec5984b1..0074d31939 100644 --- a/crates/burn-tensor/src/tensor/api/numeric.rs +++ b/crates/burn-tensor/src/tensor/api/numeric.rs @@ -3839,9 +3839,7 @@ impl Numeric for Float { (TensorPrimitive::Float(lhs), TensorPrimitive::Float(rhs)) => { TensorPrimitive::Float(B::float_add(lhs, rhs)) } - (TensorPrimitive::QFloat(lhs), TensorPrimitive::QFloat(rhs)) => { - TensorPrimitive::QFloat(B::q_add(lhs, rhs)) - } + (TensorPrimitive::QFloat(lhs), TensorPrimitive::QFloat(rhs)) => B::q_add(lhs, rhs), _ => panic!("Primitive type mismatch for lhs and rhs"), } } @@ -3850,9 +3848,7 @@ impl Numeric for Float { TensorPrimitive::Float(lhs) => { TensorPrimitive::Float(B::float_add_scalar(lhs, rhs.elem())) } - TensorPrimitive::QFloat(lhs) => { - TensorPrimitive::QFloat(B::q_add_scalar(lhs, rhs.elem())) - } + TensorPrimitive::QFloat(lhs) => B::q_add_scalar(lhs, rhs.elem()), } } fn sub(lhs: Self::Primitive, rhs: Self::Primitive) -> >::Primitive { @@ -3860,9 +3856,7 @@ impl Numeric for Float { (TensorPrimitive::Float(lhs), TensorPrimitive::Float(rhs)) => { TensorPrimitive::Float(B::float_sub(lhs, rhs)) } - (TensorPrimitive::QFloat(lhs), TensorPrimitive::QFloat(rhs)) => { - TensorPrimitive::QFloat(B::q_sub(lhs, rhs)) - } + (TensorPrimitive::QFloat(lhs), TensorPrimitive::QFloat(rhs)) => B::q_sub(lhs, rhs), _ => panic!("Primitive type mismatch for lhs and rhs"), } } @@ -3871,9 +3865,7 @@ impl Numeric for Float { TensorPrimitive::Float(lhs) => { TensorPrimitive::Float(B::float_sub_scalar(lhs, rhs.elem())) } - TensorPrimitive::QFloat(lhs) => { - TensorPrimitive::QFloat(B::q_sub_scalar(lhs, rhs.elem())) - } + TensorPrimitive::QFloat(lhs) => B::q_sub_scalar(lhs, rhs.elem()), } } fn div(lhs: Self::Primitive, rhs: Self::Primitive) -> >::Primitive { @@ -3881,9 +3873,7 @@ impl Numeric for Float { (TensorPrimitive::Float(lhs), TensorPrimitive::Float(rhs)) => { TensorPrimitive::Float(B::float_div(lhs, rhs)) } - (TensorPrimitive::QFloat(lhs), TensorPrimitive::QFloat(rhs)) => { - TensorPrimitive::QFloat(B::q_div(lhs, rhs)) - } + (TensorPrimitive::QFloat(lhs), TensorPrimitive::QFloat(rhs)) => B::q_div(lhs, rhs), _ => panic!("Primitive type mismatch for lhs and rhs"), } } @@ -3892,43 +3882,24 @@ impl Numeric for Float { TensorPrimitive::Float(lhs) => { TensorPrimitive::Float(B::float_div_scalar(lhs, rhs.elem())) } - TensorPrimitive::QFloat(lhs) => { - TensorPrimitive::QFloat(B::q_div_scalar(lhs, rhs.elem())) - } + TensorPrimitive::QFloat(lhs) => B::q_div_scalar(lhs, rhs.elem()), } } fn remainder( lhs: Self::Primitive, rhs: Self::Primitive, ) -> >::Primitive { - match (lhs, rhs) { - (TensorPrimitive::Float(lhs), TensorPrimitive::Float(rhs)) => { - TensorPrimitive::Float(B::float_remainder(lhs, rhs)) - } - (TensorPrimitive::QFloat(lhs), TensorPrimitive::QFloat(rhs)) => { - TensorPrimitive::QFloat(B::q_remainder(lhs, rhs)) - } - _ => panic!("Primitive type mismatch for lhs and rhs"), - } + TensorPrimitive::Float(B::float_remainder(lhs.tensor(), rhs.tensor())) } fn remainder_scalar(lhs: Self::Primitive, rhs: E) -> Self::Primitive { - match lhs { - TensorPrimitive::Float(lhs) => { - TensorPrimitive::Float(B::float_remainder_scalar(lhs, rhs.elem())) - } - TensorPrimitive::QFloat(lhs) => { - TensorPrimitive::QFloat(B::q_remainder_scalar(lhs, rhs.elem())) - } - } + TensorPrimitive::Float(B::float_remainder_scalar(lhs.tensor(), rhs.elem())) } fn mul(lhs: Self::Primitive, rhs: Self::Primitive) -> >::Primitive { match (lhs, rhs) { (TensorPrimitive::Float(lhs), TensorPrimitive::Float(rhs)) => { TensorPrimitive::Float(B::float_mul(lhs, rhs)) } - (TensorPrimitive::QFloat(lhs), TensorPrimitive::QFloat(rhs)) => { - TensorPrimitive::QFloat(B::q_mul(lhs, rhs)) - } + (TensorPrimitive::QFloat(lhs), TensorPrimitive::QFloat(rhs)) => B::q_mul(lhs, rhs), _ => panic!("Primitive type mismatch for lhs and rhs"), } } @@ -3937,15 +3908,13 @@ impl Numeric for Float { TensorPrimitive::Float(lhs) => { TensorPrimitive::Float(B::float_mul_scalar(lhs, rhs.elem())) } - TensorPrimitive::QFloat(lhs) => { - TensorPrimitive::QFloat(B::q_mul_scalar(lhs, rhs.elem())) - } + TensorPrimitive::QFloat(lhs) => B::q_mul_scalar(lhs, rhs.elem()), } } fn neg(tensor: Self::Primitive) -> Self::Primitive { match tensor { TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_neg(tensor)), - TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_neg(tensor)), + TensorPrimitive::QFloat(tensor) => B::q_neg(tensor), } } fn zeros(shape: Shape, device: &B::Device) -> Self::Primitive { @@ -3966,21 +3935,21 @@ impl Numeric for Float { fn sum(tensor: Self::Primitive) -> Self::Primitive { match tensor { TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_sum(tensor)), - TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_sum(tensor)), + TensorPrimitive::QFloat(tensor) => B::q_sum(tensor), } } fn sum_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive { match tensor { TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_sum_dim(tensor, dim)), - TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_sum_dim(tensor, dim)), + TensorPrimitive::QFloat(tensor) => B::q_sum_dim(tensor, dim), } } fn prod(tensor: Self::Primitive) -> Self::Primitive { match tensor { TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_prod(tensor)), - TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_prod(tensor)), + TensorPrimitive::QFloat(tensor) => B::q_prod(tensor), } } @@ -3989,14 +3958,14 @@ impl Numeric for Float { TensorPrimitive::Float(tensor) => { TensorPrimitive::Float(B::float_prod_dim(tensor, dim)) } - TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_prod_dim(tensor, dim)), + TensorPrimitive::QFloat(tensor) => B::q_prod_dim(tensor, dim), } } fn mean(tensor: Self::Primitive) -> Self::Primitive { match tensor { TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_mean(tensor)), - TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_mean(tensor)), + TensorPrimitive::QFloat(tensor) => B::q_mean(tensor), } } @@ -4005,7 +3974,7 @@ impl Numeric for Float { TensorPrimitive::Float(tensor) => { TensorPrimitive::Float(B::float_mean_dim(tensor, dim)) } - TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_mean_dim(tensor, dim)), + TensorPrimitive::QFloat(tensor) => B::q_mean_dim(tensor, dim), } } @@ -4052,15 +4021,7 @@ impl Numeric for Float { mask: B::BoolTensorPrimitive, source: Self::Primitive, ) -> Self::Primitive { - match (tensor, source) { - (TensorPrimitive::Float(tensor), TensorPrimitive::Float(source)) => { - TensorPrimitive::Float(B::float_mask_where(tensor, mask, source)) - } - (TensorPrimitive::QFloat(tensor), TensorPrimitive::QFloat(source)) => { - TensorPrimitive::QFloat(B::q_mask_where(tensor, mask, source)) - } - _ => panic!("Primitive type mismatch for tensor and source"), - } + TensorPrimitive::Float(B::float_mask_where(tensor.tensor(), mask, source.tensor())) } fn mask_fill( @@ -4068,14 +4029,7 @@ impl Numeric for Float { mask: B::BoolTensorPrimitive, value: Self::Elem, ) -> Self::Primitive { - match tensor { - TensorPrimitive::Float(tensor) => { - TensorPrimitive::Float(B::float_mask_fill(tensor, mask, value)) - } - TensorPrimitive::QFloat(tensor) => { - TensorPrimitive::QFloat(B::q_mask_fill(tensor, mask, value)) - } - } + TensorPrimitive::Float(B::float_mask_fill(tensor.tensor(), mask, value)) } fn select(tensor: Self::Primitive, dim: usize, indices: Tensor) -> Self::Primitive { @@ -4095,20 +4049,13 @@ impl Numeric for Float { indices: Tensor, values: Self::Primitive, ) -> Self::Primitive { - match (tensor, values) { - (TensorPrimitive::Float(tensor), TensorPrimitive::Float(values)) => { - TensorPrimitive::Float(B::float_select_assign( - tensor, - dim, - indices.primitive, - values, - )) - } - (TensorPrimitive::QFloat(tensor), TensorPrimitive::QFloat(values)) => { - TensorPrimitive::QFloat(B::q_select_assign(tensor, dim, indices.primitive, values)) - } - _ => panic!("Primitive type mismatch for tensor and values"), - } + // Select assign is ambiguous for QFloat + TensorPrimitive::Float(B::float_select_assign( + tensor.tensor(), + dim, + indices.primitive, + values.tensor(), + )) } fn gather( @@ -4132,15 +4079,12 @@ impl Numeric for Float { indices: B::IntTensorPrimitive, values: Self::Primitive, ) -> Self::Primitive { - match (tensor, values) { - (TensorPrimitive::Float(tensor), TensorPrimitive::Float(values)) => { - TensorPrimitive::Float(B::float_scatter(dim, tensor, indices, values)) - } - (TensorPrimitive::QFloat(tensor), TensorPrimitive::QFloat(values)) => { - TensorPrimitive::QFloat(B::q_scatter(dim, tensor, indices, values)) - } - _ => panic!("Primitive type mismatch for tensor and values"), - } + TensorPrimitive::Float(B::float_scatter( + dim, + tensor.tensor(), + indices, + values.tensor(), + )) } fn argmax(tensor: Self::Primitive, dim: usize) -> IntTensor { @@ -4222,9 +4166,7 @@ impl Numeric for Float { TensorPrimitive::Float(tensor) => { TensorPrimitive::Float(B::float_clamp(tensor, min, max)) } - TensorPrimitive::QFloat(tensor) => { - TensorPrimitive::QFloat(B::q_clamp(tensor, min, max)) - } + TensorPrimitive::QFloat(tensor) => B::q_clamp(tensor, min, max), } } @@ -4233,7 +4175,7 @@ impl Numeric for Float { TensorPrimitive::Float(tensor) => { TensorPrimitive::Float(B::float_clamp_min(tensor, min)) } - TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_clamp_min(tensor, min)), + TensorPrimitive::QFloat(tensor) => B::q_clamp_min(tensor, min), } } @@ -4242,7 +4184,7 @@ impl Numeric for Float { TensorPrimitive::Float(tensor) => { TensorPrimitive::Float(B::float_clamp_max(tensor, max)) } - TensorPrimitive::QFloat(tensor) => TensorPrimitive::QFloat(B::q_clamp_max(tensor, max)), + TensorPrimitive::QFloat(tensor) => B::q_clamp_max(tensor, max), } } @@ -4258,9 +4200,7 @@ impl Numeric for Float { (TensorPrimitive::Float(lhs), TensorPrimitive::Float(rhs)) => { TensorPrimitive::Float(B::float_powf(lhs, rhs)) } - (TensorPrimitive::QFloat(lhs), TensorPrimitive::QFloat(rhs)) => { - TensorPrimitive::QFloat(B::q_powf(lhs, rhs)) - } + (TensorPrimitive::QFloat(lhs), TensorPrimitive::QFloat(rhs)) => B::q_powf(lhs, rhs), _ => panic!("Primitive type mismatch for lhs and rhs"), } } @@ -4270,9 +4210,7 @@ impl Numeric for Float { TensorPrimitive::Float(lhs) => { TensorPrimitive::Float(B::float_powf_scalar(lhs, rhs.elem())) } - TensorPrimitive::QFloat(lhs) => { - TensorPrimitive::QFloat(B::q_powf_scalar(lhs, rhs.elem())) - } + TensorPrimitive::QFloat(lhs) => B::q_powf_scalar(lhs, rhs.elem()), } } @@ -4281,9 +4219,7 @@ impl Numeric for Float { (TensorPrimitive::Float(lhs), TensorPrimitive::Float(rhs)) => { TensorPrimitive::Float(B::float_powf(lhs, rhs)) } - (TensorPrimitive::QFloat(lhs), TensorPrimitive::QFloat(rhs)) => { - TensorPrimitive::QFloat(B::q_powf(lhs, rhs)) - } + (TensorPrimitive::QFloat(lhs), TensorPrimitive::QFloat(rhs)) => B::q_powf(lhs, rhs), _ => panic!("Primitive type mismatch for lhs and rhs"), } } @@ -4293,9 +4229,7 @@ impl Numeric for Float { TensorPrimitive::Float(lhs) => { TensorPrimitive::Float(B::float_powi_scalar(lhs, rhs.elem())) } - TensorPrimitive::QFloat(lhs) => { - TensorPrimitive::QFloat(B::q_powi_scalar(lhs, rhs.elem())) - } + TensorPrimitive::QFloat(lhs) => B::q_powi_scalar(lhs, rhs.elem()), } } diff --git a/crates/burn-tensor/src/tensor/data.rs b/crates/burn-tensor/src/tensor/data.rs index 9f66a2b5ef..6f8a78cc9a 100644 --- a/crates/burn-tensor/src/tensor/data.rs +++ b/crates/burn-tensor/src/tensor/data.rs @@ -10,12 +10,14 @@ use num_traits::{Float, ToPrimitive}; use crate::{ DType, Distribution, Element, ElementConversion, - quantization::{QuantizationScheme, QuantizationStrategy, QuantizationType, QuantizedBytes}, + quantization::{QuantInputType, QuantScheme, QuantizationStrategy, QuantizedBytes}, tensor::bytes::Bytes, }; use rand::RngCore; +use super::quantization::{QuantLevel, QuantMode}; + /// The things that can go wrong when manipulating tensor data. #[derive(Debug)] pub enum DataError { @@ -250,7 +252,12 @@ impl TensorData { // bool is a byte value equal to either 0 or 1 DType::Bool => Box::new(self.bytes.iter().map(|e| e.elem::())), DType::QFloat(scheme) => match scheme { - QuantizationScheme::PerTensor(_mode, QuantizationType::QInt8) => { + QuantScheme { + level: QuantLevel::Tensor, + mode: QuantMode::Symmetric, + q_type: QuantInputType::QInt8, + .. + } => { // Quantized int8 values let q_bytes = QuantizedBytes { bytes: self.bytes.clone(), @@ -535,12 +542,10 @@ impl TensorData { } else { panic!("Quantized data differs from other not quantized data") }; - match (q, q_other) { - ( - QuantizationScheme::PerTensor(mode, QuantizationType::QInt8), - QuantizationScheme::PerTensor(mode_other, QuantizationType::QInt8), - ) if mode == mode_other => self.assert_eq_elem::(other), - _ => panic!("Quantization schemes differ ({:?} != {:?})", q, q_other), + if q == q_other { + self.assert_eq_elem::(other) + } else { + panic!("Quantization schemes differ ({:?} != {:?})", q, q_other) } } } @@ -808,7 +813,12 @@ impl core::fmt::Display for TensorData { DType::U8 => format!("{:?}", self.as_slice::().unwrap()), DType::Bool => format!("{:?}", self.as_slice::().unwrap()), DType::QFloat(scheme) => match scheme { - QuantizationScheme::PerTensor(_mode, QuantizationType::QInt8) => { + QuantScheme { + level: QuantLevel::Tensor, + mode: QuantMode::Symmetric, + q_type: QuantInputType::QInt8, + .. + } => { format!("{:?} {scheme:?}", self.iter::().collect::>()) } }, diff --git a/crates/burn-tensor/src/tensor/element/base.rs b/crates/burn-tensor/src/tensor/element/base.rs index 122e50ed1e..d83a85d0d6 100644 --- a/crates/burn-tensor/src/tensor/element/base.rs +++ b/crates/burn-tensor/src/tensor/element/base.rs @@ -3,7 +3,7 @@ use core::cmp::Ordering; use crate::{ Distribution, cast::ToElement, - quantization::{QuantizationScheme, QuantizationType}, + quantization::{QuantInputType, QuantScheme}, }; #[cfg(feature = "cubecl")] use cubecl::flex32; @@ -316,7 +316,7 @@ pub enum DType { U16, U8, Bool, - QFloat(QuantizationScheme), + QFloat(QuantScheme), } #[cfg(feature = "cubecl")] @@ -366,10 +366,8 @@ impl DType { DType::U16 => core::mem::size_of::(), DType::U8 => core::mem::size_of::(), DType::Bool => core::mem::size_of::(), - DType::QFloat(scheme) => match scheme { - QuantizationScheme::PerTensor(_mode, QuantizationType::QInt8) => { - core::mem::size_of::() - } + DType::QFloat(scheme) => match scheme.q_type { + QuantInputType::QInt8 => core::mem::size_of::(), }, } } diff --git a/crates/burn-tensor/src/tensor/ops/qtensor.rs b/crates/burn-tensor/src/tensor/ops/qtensor.rs index 8b1b7a27d5..db63870aea 100644 --- a/crates/burn-tensor/src/tensor/ops/qtensor.rs +++ b/crates/burn-tensor/src/tensor/ops/qtensor.rs @@ -2,16 +2,19 @@ use alloc::vec::Vec; use core::ops::Range; use crate::{ - Device, Shape, TensorData, TensorMetadata, + Device, Shape, TensorData, TensorMetadata, TensorPrimitive, backend::Backend, quantization::{ - Calibration, QTensorPrimitive, QuantizationParametersPrimitive, QuantizationScheme, + Calibration, QTensorPrimitive, QuantPropagation, QuantScheme, + QuantizationParametersPrimitive, }, }; use super::{BoolTensor, FloatElem, FloatTensor, IntElem, IntTensor, QuantizedTensor}; -/// Automatically applies dequantization -> float operation -> quantization. +/// Automatically applies `dequantization -> float operation -> quantization`. +/// +/// Used for tensor ops that should always return a quantized output. #[macro_export] macro_rules! dequant_op_quant { // Binary tensor float op w/ lhs & rhs @@ -42,8 +45,73 @@ macro_rules! dequant_op_quant { }}; } -/// Quantized Tensor API for basic operations, see [tensor](crate::Tensor) -/// for documentation on each function. +/// Automatically applies `dequantization -> float operation [-> quantization]`. +/// +/// The output quantization step is optional. +/// It is only performed when the input quantization scheme is propagated. +#[macro_export] +macro_rules! dequant_op_flow { + // Binary tensor float op w/ lhs & rhs + ( + ty $ty:ty, float_op $float_op:expr, $t1:expr, $t2:expr + ) => {{ + // Heuristic: prioritize lhs scheme + let scheme = $t1.scheme().clone(); + + let t1_f = <$ty>::dequantize($t1); + let t2_f = <$ty>::dequantize($t2); + #[allow(clippy::redundant_closure_call)] + let out_f = $float_op(t1_f, t2_f); + + match scheme.propagation { + QuantPropagation::Propagate => { + TensorPrimitive::QFloat(<$ty>::quantize_dynamic(out_f, &scheme)) + } + QuantPropagation::Inhibit => TensorPrimitive::Float(out_f), + } + }}; + // Unary tensor float op + ( + ty $ty:ty, float_op $float_op:expr, $tensor:expr + ) => {{ + let scheme = $tensor.scheme().clone(); + + let tensor_f = <$ty>::dequantize($tensor); + #[allow(clippy::redundant_closure_call)] + let out_f = $float_op(tensor_f); + + match scheme.propagation { + QuantPropagation::Propagate => { + TensorPrimitive::QFloat(<$ty>::quantize_dynamic(out_f, &scheme)) + } + QuantPropagation::Inhibit => TensorPrimitive::Float(out_f), + } + }}; +} + +/// Operations on quantized tensors. +/// +/// # Return Type Semantics +/// +/// The return type of each operation indicates how quantization is handled: +/// +/// ## [`QuantizedTensor`] +/// If the method returns a `QuantizedTensor`, the operation is expected to preserve the quantized +/// representation. Implementations should avoid dequantizing when possible to maintain performance. +/// For example, shape or layout changes such as expand or transpose preserve quantization. +/// +/// *Note: while this currently doesn't affect the quantized tensor parameters (only per-tensor is +/// supported at the time of writing), other quantization levels (e.g., per-block) may require re-ordering +/// the quantization parameters to match the new layout.* +/// +/// +/// ## [`TensorPrimitive`] +/// If the method returns a `TensorPrimitive` enum, the return type should align with propagation +/// strategy specified in the quantization scheme. The output should remain quantized ([`TensorPrimitive::QFloat`]) +/// returned in floating-point form ([`TensorPrimitive::Float`]). +/// +/// This distinction allows for fine-grained control over mixed-precision flows while still operating +/// through a unified API. pub trait QTensorOps { /// Creates a new tensor from the data structure. /// @@ -60,12 +128,12 @@ pub trait QTensorOps { /// Convert the tensor to a lower precision data type based on the quantization scheme and parameters. fn quantize( tensor: FloatTensor, - scheme: &QuantizationScheme, + scheme: &QuantScheme, qparams: QuantizationParametersPrimitive, ) -> QuantizedTensor; /// Dynamically convert the tensor to a lower precision data type based on the quantization scheme. - fn quantize_dynamic(tensor: FloatTensor, scheme: &QuantizationScheme) -> QuantizedTensor { + fn quantize_dynamic(tensor: FloatTensor, scheme: &QuantScheme) -> QuantizedTensor { // Dynamically compute min/max tensor range and qparams before quantizing let (min, max) = scheme.compute_range_primitive::(tensor.clone(), &Calibration::MinMax); let qparams = scheme.compute_q_params_primitive(min, max); @@ -139,6 +207,110 @@ pub trait QTensorOps { false } + /// Broadcasts the `tensor` to the given `shape`. + fn q_expand(tensor: QuantizedTensor, shape: Shape) -> QuantizedTensor; + + /// Transposes a tensor. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to transpose. + /// + /// # Returns + /// + /// The transposed tensor. + fn q_transpose(tensor: QuantizedTensor) -> QuantizedTensor { + let ndims = tensor.shape().num_dims(); + Self::q_swap_dims(tensor, ndims - 2, ndims - 1) + } + + /// Swaps two dimensions of a tensor. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to swap the dimensions of. + /// * `dim1` - The first dimension to swap. + /// * `dim2` - The second dimension to swap. + /// + /// # Returns + /// + /// The tensor with the dimensions swapped. + fn q_swap_dims(tensor: QuantizedTensor, dim1: usize, dim2: usize) -> QuantizedTensor; + + /// Permutes the dimensions of a tensor. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to permute the dimensions of. + /// * `axes` - The new order of the dimensions. + /// # Returns + /// + /// The tensor with the dimensions permuted. + fn q_permute(tensor: QuantizedTensor, axes: &[usize]) -> QuantizedTensor; + + /// Reverse the order of elements in a tensor along the given axes. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to reverse. + /// * `axes` - The axes to reverse. + /// + /// The tensor with the elements reversed. + fn q_flip(tensor: QuantizedTensor, axes: &[usize]) -> QuantizedTensor; + + /// Select tensor elements along the given dimension corresponding for the given indices. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to select from. + /// * `dim` - The dimension to select from. + /// * `indices` - The indices to select. + /// + /// # Returns + /// + /// The selected elements. + fn q_select( + tensor: QuantizedTensor, + dim: usize, + indices: IntTensor, + ) -> QuantizedTensor; + + /// Select tensor elements corresponding for the given ranges. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to select from. + /// * `ranges` - The ranges to select. + /// + /// # Returns + /// + /// The selected elements in a new tensor. + fn q_slice(tensor: QuantizedTensor, ranges: &[Range]) -> QuantizedTensor; + + /// Gather elements from a tensor. + /// + /// # Arguments + /// + /// * `dim` - The dimension to gather from. + /// * `tensor` - The tensor to gather from. + /// * `indices` - The indices to gather. + /// + /// # Returns + /// + /// The gathered elements. + fn q_gather( + dim: usize, + tensor: QuantizedTensor, + indices: IntTensor, + ) -> QuantizedTensor { + // Default implementation. Backends can gather on the quantized values when supported. + dequant_op_quant!( + ty Self, + float_op |tensor| B::float_gather(dim, tensor, indices), + tensor + ) + } + /// Repeat the tensor along the given dimension. /// /// # Arguments @@ -168,8 +340,8 @@ pub trait QTensorOps { /// # Returns /// /// The result of adding the two tensors together. - fn q_add(lhs: QuantizedTensor, rhs: QuantizedTensor) -> QuantizedTensor { - dequant_op_quant!( + fn q_add(lhs: QuantizedTensor, rhs: QuantizedTensor) -> TensorPrimitive { + dequant_op_flow!( ty Self, float_op |lhs, rhs| B::float_add(lhs, rhs), lhs, @@ -187,13 +359,12 @@ pub trait QTensorOps { /// # Returns /// /// The result of adding the scalar to the tensor. - fn q_add_scalar(lhs: QuantizedTensor, rhs: FloatElem) -> QuantizedTensor { - let scheme = *lhs.scheme(); - - let lhs_f = Self::dequantize(lhs); - let out_f = B::float_add_scalar(lhs_f, rhs); - - Self::quantize_dynamic(out_f, &scheme) + fn q_add_scalar(lhs: QuantizedTensor, rhs: FloatElem) -> TensorPrimitive { + dequant_op_flow!( + ty Self, + float_op |tensor| B::float_add_scalar(tensor, rhs), + lhs + ) } /// Clamps a tensor under a minimum value. @@ -206,13 +377,12 @@ pub trait QTensorOps { /// # Returns /// /// The clamped tensor. - fn q_clamp_min(tensor: QuantizedTensor, min: FloatElem) -> QuantizedTensor { - let scheme = *tensor.scheme(); - - let tensor_f = Self::dequantize(tensor); - let out_f = B::float_clamp_min(tensor_f, min); - - Self::quantize_dynamic(out_f, &scheme) + fn q_clamp_min(tensor: QuantizedTensor, min: FloatElem) -> TensorPrimitive { + dequant_op_flow!( + ty Self, + float_op |tensor| B::float_clamp_min(tensor, min), + tensor + ) } /// Clamps a tensor over a maximum value. @@ -225,13 +395,12 @@ pub trait QTensorOps { /// # Returns /// /// The clamped tensor. - fn q_clamp_max(tensor: QuantizedTensor, max: FloatElem) -> QuantizedTensor { - let scheme = *tensor.scheme(); - - let tensor_f = Self::dequantize(tensor); - let out_f = B::float_clamp_max(tensor_f, max); - - Self::quantize_dynamic(out_f, &scheme) + fn q_clamp_max(tensor: QuantizedTensor, max: FloatElem) -> TensorPrimitive { + dequant_op_flow!( + ty Self, + float_op |tensor| B::float_clamp_max(tensor, max), + tensor + ) } /// Clamps a tensor between a minimum and maximum value. @@ -249,13 +418,12 @@ pub trait QTensorOps { tensor: QuantizedTensor, min: FloatElem, max: FloatElem, - ) -> QuantizedTensor { - let scheme = *tensor.scheme(); - - let tensor_f = Self::dequantize(tensor); - let out_f = B::float_clamp(tensor_f, min, max); - - Self::quantize_dynamic(out_f, &scheme) + ) -> TensorPrimitive { + dequant_op_flow!( + ty Self, + float_op |tensor| B::float_clamp(tensor, min, max), + tensor + ) } /// Subtracts two tensors. @@ -268,8 +436,8 @@ pub trait QTensorOps { /// # Returns /// /// The result of subtracting the two tensors. - fn q_sub(lhs: QuantizedTensor, rhs: QuantizedTensor) -> QuantizedTensor { - dequant_op_quant!( + fn q_sub(lhs: QuantizedTensor, rhs: QuantizedTensor) -> TensorPrimitive { + dequant_op_flow!( ty Self, float_op |lhs, rhs| B::float_sub(lhs, rhs), lhs, @@ -287,18 +455,17 @@ pub trait QTensorOps { /// # Returns /// /// The result of subtracting the scalar from the tensor. - fn q_sub_scalar(lhs: QuantizedTensor, rhs: FloatElem) -> QuantizedTensor { - let scheme = *lhs.scheme(); - - let lhs_f = Self::dequantize(lhs); - let out_f = B::float_sub_scalar(lhs_f, rhs); - - Self::quantize_dynamic(out_f, &scheme) + fn q_sub_scalar(lhs: QuantizedTensor, rhs: FloatElem) -> TensorPrimitive { + dequant_op_flow!( + ty Self, + float_op |tensor| B::float_sub_scalar(tensor, rhs), + lhs + ) } /// Multiplies two tensors together element-wise. - fn q_mul(lhs: QuantizedTensor, rhs: QuantizedTensor) -> QuantizedTensor { - dequant_op_quant!( + fn q_mul(lhs: QuantizedTensor, rhs: QuantizedTensor) -> TensorPrimitive { + dequant_op_flow!( ty Self, float_op |lhs, rhs| B::float_mul(lhs, rhs), lhs, @@ -316,13 +483,12 @@ pub trait QTensorOps { /// # Returns /// /// The result of multiplying the tensor by the scalar. - fn q_mul_scalar(lhs: QuantizedTensor, rhs: FloatElem) -> QuantizedTensor { - let scheme = *lhs.scheme(); - - let lhs_f = Self::dequantize(lhs); - let out_f = B::float_mul_scalar(lhs_f, rhs); - - Self::quantize_dynamic(out_f, &scheme) + fn q_mul_scalar(lhs: QuantizedTensor, rhs: FloatElem) -> TensorPrimitive { + dequant_op_flow!( + ty Self, + float_op |tensor| B::float_mul_scalar(tensor, rhs), + lhs + ) } /// Divides two tensors element-wise. @@ -335,8 +501,8 @@ pub trait QTensorOps { /// # Returns /// /// The result of dividing the two tensors. - fn q_div(lhs: QuantizedTensor, rhs: QuantizedTensor) -> QuantizedTensor { - dequant_op_quant!( + fn q_div(lhs: QuantizedTensor, rhs: QuantizedTensor) -> TensorPrimitive { + dequant_op_flow!( ty Self, float_op |lhs, rhs| B::float_div(lhs, rhs), lhs, @@ -354,52 +520,14 @@ pub trait QTensorOps { /// # Returns /// /// The result of dividing the tensor by the scalar. - fn q_div_scalar(lhs: QuantizedTensor, rhs: FloatElem) -> QuantizedTensor { - let scheme = *lhs.scheme(); - - let lhs_f = Self::dequantize(lhs); - let out_f = B::float_div_scalar(lhs_f, rhs); - - Self::quantize_dynamic(out_f, &scheme) - } - - /// Computes the remainder of division between two tensors element-wise. - /// - /// # Arguments - /// - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side tensor. - /// - /// # Returns - /// - /// The element-wise remainder when dividing `lhs` by `rhs`. - fn q_remainder(lhs: QuantizedTensor, rhs: QuantizedTensor) -> QuantizedTensor { - dequant_op_quant!( + fn q_div_scalar(lhs: QuantizedTensor, rhs: FloatElem) -> TensorPrimitive { + dequant_op_flow!( ty Self, - float_op |lhs, rhs| B::float_remainder(lhs, rhs), - lhs, - rhs + float_op |tensor| B::float_div_scalar(tensor, rhs), + lhs ) } - /// Computes the modulus of a tensor given a scalar. - /// - /// # Arguments - /// * `lhs` - The left hand side tensor. - /// * `rhs` - The right hand side scalar. - /// - /// # Returns - /// - /// The result of applying the modulus of the scalar to the tensor. - fn q_remainder_scalar(lhs: QuantizedTensor, rhs: FloatElem) -> QuantizedTensor { - let scheme = *lhs.scheme(); - - let lhs_f = Self::dequantize(lhs); - let out_f = B::float_remainder_scalar(lhs_f, rhs); - - Self::quantize_dynamic(out_f, &scheme) - } - /// Multiplies two tensors together using matrix multiplication. /// /// # Arguments @@ -410,8 +538,8 @@ pub trait QTensorOps { /// # Returns /// /// The result of multiplying the two tensors together using matrix multiplication. - fn q_matmul(lhs: QuantizedTensor, rhs: QuantizedTensor) -> QuantizedTensor { - dequant_op_quant!( + fn q_matmul(lhs: QuantizedTensor, rhs: QuantizedTensor) -> TensorPrimitive { + dequant_op_flow!( ty Self, float_op |lhs, rhs| B::float_matmul(lhs, rhs), lhs, @@ -420,246 +548,19 @@ pub trait QTensorOps { } /// Negates a tensor element-wise. - fn q_neg(tensor: QuantizedTensor) -> QuantizedTensor { - let scheme = *tensor.scheme(); - - let tensor_f = Self::dequantize(tensor); - let out_f = B::float_neg(tensor_f); - - Self::quantize_dynamic(out_f, &scheme) - } - - /// Calculates the reciprocals element-wise - fn q_recip(tensor: QuantizedTensor) -> QuantizedTensor { - let scheme = *tensor.scheme(); - - let tensor_f = Self::dequantize(tensor); - let out_f = B::float_recip(tensor_f); - - Self::quantize_dynamic(out_f, &scheme) - } - - /// Transposes a tensor. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to transpose. - /// - /// # Returns - /// - /// The transposed tensor. - fn q_transpose(tensor: QuantizedTensor) -> QuantizedTensor { - let ndims = tensor.shape().num_dims(); - Self::q_swap_dims(tensor, ndims - 2, ndims - 1) - } - - /// Swaps two dimensions of a tensor. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to swap the dimensions of. - /// * `dim1` - The first dimension to swap. - /// * `dim2` - The second dimension to swap. - /// - /// # Returns - /// - /// The tensor with the dimensions swapped. - fn q_swap_dims(tensor: QuantizedTensor, dim1: usize, dim2: usize) -> QuantizedTensor; - - /// Permutes the dimensions of a tensor. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to permute the dimensions of. - /// * `axes` - The new order of the dimensions. - /// # Returns - /// - /// The tensor with the dimensions permuted. - fn q_permute(tensor: QuantizedTensor, axes: &[usize]) -> QuantizedTensor; - - /// Reverse the order of elements in a tensor along the given axes. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to reverse. - /// * `axes` - The axes to reverse. - /// - /// The tensor with the elements reversed. - fn q_flip(tensor: QuantizedTensor, axes: &[usize]) -> QuantizedTensor; - - /// Gather elements from a tensor. - /// - /// # Arguments - /// - /// * `dim` - The dimension to gather from. - /// * `tensor` - The tensor to gather from. - /// * `indices` - The indices to gather. - /// - /// # Returns - /// - /// The gathered elements. - fn q_gather( - dim: usize, - tensor: QuantizedTensor, - indices: IntTensor, - ) -> QuantizedTensor { - // Default implementation. Backends can gather on the quantized values when supported. - dequant_op_quant!( + fn q_neg(tensor: QuantizedTensor) -> TensorPrimitive { + dequant_op_flow!( ty Self, - float_op |tensor| B::float_gather(dim, tensor, indices), + float_op |tensor| B::float_neg(tensor), tensor ) } - /// Scatter elements into a tensor. - /// - /// # Arguments - /// - /// * `dim` - The dimension to scatter into. - /// * `tensor` - The tensor to scatter into. - /// * `indices` - The indices to scatter into. - /// * `value` - The value to scatter. - /// - /// # Returns - /// - /// The tensor with the scattered elements. - fn q_scatter( - dim: usize, - tensor: QuantizedTensor, - indices: IntTensor, - value: QuantizedTensor, - ) -> QuantizedTensor { - dequant_op_quant!( - ty Self, - float_op |tensor, value| B::float_scatter(dim, tensor, indices, value), - tensor, - value - ) - } - - /// Select tensor elements along the given dimension corresponding for the given indices. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to select from. - /// * `dim` - The dimension to select from. - /// * `indices` - The indices to select. - /// - /// # Returns - /// - /// The selected elements. - fn q_select( - tensor: QuantizedTensor, - dim: usize, - indices: IntTensor, - ) -> QuantizedTensor; - - /// Assign the selected elements along the given dimension corresponding for the given indices - /// to the given value. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to select from. - /// * `dim` - The dimension to select from. - /// * `indices` - The indices to select. - /// * `value` - The value to assign. - /// - /// # Returns - /// - /// The tensor with the selected elements assigned to the given value. - fn q_select_assign( - tensor: QuantizedTensor, - dim: usize, - indices: IntTensor, - value: QuantizedTensor, - ) -> QuantizedTensor { - dequant_op_quant!( - ty Self, - float_op |tensor, value| B::float_select_assign(tensor, dim, indices, value), - tensor, - value - ) - } - - /// Select tensor elements corresponding for the given ranges. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to select from. - /// * `ranges` - The ranges to select. - /// - /// # Returns - /// - /// The selected elements in a new tensor. - fn q_slice(tensor: QuantizedTensor, ranges: &[Range]) -> QuantizedTensor; - - /// Assign the selected elements corresponding for the given ranges to the given value. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to select from. - /// * `ranges` - The ranges to select. - /// * `value` - The value to assign. - /// - /// # Returns - /// - /// The tensor with the selected elements assigned to the given value. - fn q_slice_assign( - tensor: QuantizedTensor, - ranges: &[Range], - value: QuantizedTensor, - ) -> QuantizedTensor { - dequant_op_quant!( - ty Self, - float_op |tensor, value| B::float_slice_assign(tensor, ranges, value), - tensor, - value - ) - } - - /// Update the given tensor with the value tensor where the mask is true. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to select from. - /// * `mask` - The boolean mask to select with. - /// * `value` - The value to assign to the selected elements from the value tensor. - /// - /// # Returns - /// - /// The tensor with the selected elements assigned to the given value. - fn q_mask_where( - tensor: QuantizedTensor, - mask: BoolTensor, - value: QuantizedTensor, - ) -> QuantizedTensor { - dequant_op_quant!( - ty Self, - float_op |tensor, value| B::float_mask_where(tensor, mask, value), - tensor, - value - ) - } - - /// Update the given tensor with the value where the mask is true. - /// - /// # Arguments - /// - /// * `tensor` - The tensor to select from. - /// * `mask` - The boolean mask to select with. - /// * `value` - The value to assign to the selected elements. - /// - /// # Returns - /// - /// The tensor with the selected elements assigned to the given value. - fn q_mask_fill( - tensor: QuantizedTensor, - mask: BoolTensor, - value: FloatElem, - ) -> QuantizedTensor { - dequant_op_quant!( + /// Calculates the reciprocals element-wise + fn q_recip(tensor: QuantizedTensor) -> TensorPrimitive { + dequant_op_flow!( ty Self, - float_op |tensor| B::float_mask_fill(tensor, mask, value), + float_op |tensor| B::float_recip(tensor), tensor ) } @@ -673,8 +574,8 @@ pub trait QTensorOps { /// # Returns /// /// A scalar tensor with the sum of all elements in `tensor`. - fn q_sum(tensor: QuantizedTensor) -> QuantizedTensor { - dequant_op_quant!( + fn q_sum(tensor: QuantizedTensor) -> TensorPrimitive { + dequant_op_flow!( ty Self, float_op |tensor| B::float_sum(tensor), tensor @@ -691,8 +592,8 @@ pub trait QTensorOps { /// # Returns /// /// A tensor with the sum of all elements in `tensor` along `dim`. - fn q_sum_dim(tensor: QuantizedTensor, dim: usize) -> QuantizedTensor { - dequant_op_quant!( + fn q_sum_dim(tensor: QuantizedTensor, dim: usize) -> TensorPrimitive { + dequant_op_flow!( ty Self, float_op |tensor| B::float_sum_dim(tensor, dim), tensor @@ -708,8 +609,8 @@ pub trait QTensorOps { /// # Returns /// /// A scalar tensor with the product of all elements in `tensor`. - fn q_prod(tensor: QuantizedTensor) -> QuantizedTensor { - dequant_op_quant!( + fn q_prod(tensor: QuantizedTensor) -> TensorPrimitive { + dequant_op_flow!( ty Self, float_op |tensor| B::float_prod(tensor), tensor @@ -725,8 +626,8 @@ pub trait QTensorOps { /// # Returns /// /// A tensor with the product of all elements in `tensor` along `dim`. - fn q_prod_dim(tensor: QuantizedTensor, dim: usize) -> QuantizedTensor { - dequant_op_quant!( + fn q_prod_dim(tensor: QuantizedTensor, dim: usize) -> TensorPrimitive { + dequant_op_flow!( ty Self, float_op |tensor| B::float_prod_dim(tensor, dim), tensor @@ -742,8 +643,8 @@ pub trait QTensorOps { /// # Returns /// /// A scalar tensor with the mean of all elements in `tensor`. - fn q_mean(tensor: QuantizedTensor) -> QuantizedTensor { - dequant_op_quant!( + fn q_mean(tensor: QuantizedTensor) -> TensorPrimitive { + dequant_op_flow!( ty Self, float_op |tensor| B::float_mean(tensor), tensor @@ -760,8 +661,8 @@ pub trait QTensorOps { /// # Returns /// /// A tensor with the mean of all elements in `tensor` along `dim`. - fn q_mean_dim(tensor: QuantizedTensor, dim: usize) -> QuantizedTensor { - dequant_op_quant!( + fn q_mean_dim(tensor: QuantizedTensor, dim: usize) -> TensorPrimitive { + dequant_op_flow!( ty Self, float_op |tensor| B::float_mean_dim(tensor, dim), tensor @@ -777,8 +678,8 @@ pub trait QTensorOps { /// # Returns /// /// A tensor with the same shape as `tensor` with exponential values. - fn q_exp(tensor: QuantizedTensor) -> QuantizedTensor { - dequant_op_quant!( + fn q_exp(tensor: QuantizedTensor) -> TensorPrimitive { + dequant_op_flow!( ty Self, float_op |tensor| B::float_exp(tensor), tensor @@ -794,8 +695,8 @@ pub trait QTensorOps { /// # Returns /// /// A tensor with the same shape as `tensor` with natural logarithm values. - fn q_log(tensor: QuantizedTensor) -> QuantizedTensor { - dequant_op_quant!( + fn q_log(tensor: QuantizedTensor) -> TensorPrimitive { + dequant_op_flow!( ty Self, float_op |tensor| B::float_log(tensor), tensor @@ -811,8 +712,8 @@ pub trait QTensorOps { /// # Returns /// /// A tensor with the same shape as `tensor` with logarithm values of (1 + Xi). - fn q_log1p(tensor: QuantizedTensor) -> QuantizedTensor { - dequant_op_quant!( + fn q_log1p(tensor: QuantizedTensor) -> TensorPrimitive { + dequant_op_flow!( ty Self, float_op |tensor| B::float_log1p(tensor), tensor @@ -829,8 +730,8 @@ pub trait QTensorOps { /// # Returns /// /// The elements of `lhs` raised to the power of the elements of `rhs`. - fn q_powf(lhs: QuantizedTensor, rhs: QuantizedTensor) -> QuantizedTensor { - dequant_op_quant!( + fn q_powf(lhs: QuantizedTensor, rhs: QuantizedTensor) -> TensorPrimitive { + dequant_op_flow!( ty Self, float_op |lhs, rhs| B::float_powf(lhs, rhs), lhs, @@ -848,8 +749,8 @@ pub trait QTensorOps { /// # Returns /// /// The elements of `lhs` raised to the value of `rhs`. Result is an IntTensor. - fn q_powi(lhs: QuantizedTensor, rhs: IntTensor) -> QuantizedTensor { - dequant_op_quant!( + fn q_powi(lhs: QuantizedTensor, rhs: IntTensor) -> TensorPrimitive { + dequant_op_flow!( ty Self, float_op |tensor| B::float_powi(tensor, rhs), lhs @@ -866,8 +767,8 @@ pub trait QTensorOps { /// # Returns /// /// The elements of `lhs` raised to the value of `rhs`. - fn q_powi_scalar(lhs: QuantizedTensor, rhs: IntElem) -> QuantizedTensor { - dequant_op_quant!( + fn q_powi_scalar(lhs: QuantizedTensor, rhs: IntElem) -> TensorPrimitive { + dequant_op_flow!( ty Self, float_op |tensor| B::float_powi_scalar(tensor, rhs), lhs @@ -884,8 +785,8 @@ pub trait QTensorOps { /// # Returns /// /// A tensor with the same shape as `tensor` with values raised to the power of `value`. - fn q_powf_scalar(tensor: QuantizedTensor, value: f32) -> QuantizedTensor { - dequant_op_quant!( + fn q_powf_scalar(tensor: QuantizedTensor, value: f32) -> TensorPrimitive { + dequant_op_flow!( ty Self, float_op |tensor| B::float_powf_scalar(tensor, value), tensor @@ -901,8 +802,8 @@ pub trait QTensorOps { /// # Returns /// /// A tensor with the same shape as `tensor` with square root values. - fn q_sqrt(tensor: QuantizedTensor) -> QuantizedTensor { - dequant_op_quant!( + fn q_sqrt(tensor: QuantizedTensor) -> TensorPrimitive { + dequant_op_flow!( ty Self, float_op |tensor| B::float_sqrt(tensor), tensor @@ -935,8 +836,8 @@ pub trait QTensorOps { /// # Returns /// /// A tensor with the same shape as `tensor` with cosine values. - fn q_cos(tensor: QuantizedTensor) -> QuantizedTensor { - dequant_op_quant!( + fn q_cos(tensor: QuantizedTensor) -> TensorPrimitive { + dequant_op_flow!( ty Self, float_op |tensor| B::float_cos(tensor), tensor @@ -952,8 +853,8 @@ pub trait QTensorOps { /// # Returns /// /// A tensor with the same shape as `tensor` with sine values. - fn q_sin(tensor: QuantizedTensor) -> QuantizedTensor { - dequant_op_quant!( + fn q_sin(tensor: QuantizedTensor) -> TensorPrimitive { + dequant_op_flow!( ty Self, float_op |tensor| B::float_sin(tensor), tensor @@ -969,8 +870,8 @@ pub trait QTensorOps { /// # Returns /// /// A tensor with the same shape as `tensor` with tangent values. - fn q_tan(tensor: QuantizedTensor) -> QuantizedTensor { - dequant_op_quant!( + fn q_tan(tensor: QuantizedTensor) -> TensorPrimitive { + dequant_op_flow!( ty Self, float_op |tensor| B::float_tan(tensor), tensor @@ -986,8 +887,8 @@ pub trait QTensorOps { /// # Returns /// /// A tensor with the same shape as `tensor` with hyperbolic cosine values. - fn q_cosh(tensor: QuantizedTensor) -> QuantizedTensor { - dequant_op_quant!( + fn q_cosh(tensor: QuantizedTensor) -> TensorPrimitive { + dequant_op_flow!( ty Self, float_op |tensor| B::float_cosh(tensor), tensor @@ -1003,8 +904,8 @@ pub trait QTensorOps { /// # Returns /// /// A tensor with the same shape as `tensor` with hyperbolic sine values. - fn q_sinh(tensor: QuantizedTensor) -> QuantizedTensor { - dequant_op_quant!( + fn q_sinh(tensor: QuantizedTensor) -> TensorPrimitive { + dequant_op_flow!( ty Self, float_op |tensor| B::float_sinh(tensor), tensor @@ -1020,8 +921,8 @@ pub trait QTensorOps { /// # Returns /// /// A tensor with the same shape as `tensor` with hyperbolic tangent values. - fn q_tanh(tensor: QuantizedTensor) -> QuantizedTensor { - dequant_op_quant!( + fn q_tanh(tensor: QuantizedTensor) -> TensorPrimitive { + dequant_op_flow!( ty Self, float_op |tensor| B::float_tanh(tensor), tensor @@ -1037,8 +938,8 @@ pub trait QTensorOps { /// # Returns /// /// A tensor with the same shape as `tensor` with error function values. - fn q_erf(tensor: QuantizedTensor) -> QuantizedTensor { - dequant_op_quant!( + fn q_erf(tensor: QuantizedTensor) -> TensorPrimitive { + dequant_op_flow!( ty Self, float_op |tensor| B::float_erf(tensor), tensor @@ -1300,9 +1201,6 @@ pub trait QTensorOps { B::float_all_dim(tensor_f, dim) } - /// Broadcasts the `tensor` to the given `shape`. - fn q_expand(tensor: QuantizedTensor, shape: Shape) -> QuantizedTensor; - /// Sort the elements of the input `tensor` by value in along a given dimension. /// /// This sort is unstable (i.e., may reorder equal elements). diff --git a/crates/burn-tensor/src/tensor/quantization/bytes.rs b/crates/burn-tensor/src/tensor/quantization/bytes.rs index 1373009b24..fea90cb1ba 100644 --- a/crates/burn-tensor/src/tensor/quantization/bytes.rs +++ b/crates/burn-tensor/src/tensor/quantization/bytes.rs @@ -4,7 +4,7 @@ use crate::{Bytes, Element}; use alloc::vec::Vec; use super::{ - QParams, QuantizationMode, QuantizationScheme, QuantizationStrategy, QuantizationType, + QParams, QuantInputType, QuantLevel, QuantMode, QuantScheme, QuantizationStrategy, SymmetricQuantization, pack_i8s_to_u32s, unpack_u32s_to_i8s, }; @@ -21,7 +21,7 @@ pub struct QuantizedBytes { /// The quantized values and quantization parameters represented as bytes. pub bytes: Bytes, /// The quantization scheme. - pub scheme: QuantizationScheme, + pub scheme: QuantScheme, /// The number of quantized elements. pub num_elements: usize, } @@ -101,8 +101,8 @@ impl QuantizedBytes { _ => unreachable!(), }; - let num_params = match self.scheme { - QuantizationScheme::PerTensor(..) => 1, + let num_params = match self.scheme.level { + QuantLevel::Tensor => 1, }; let scale_size = num_params; // f32 scale is the same number of bytes as u32 @@ -116,7 +116,12 @@ impl QuantizedBytes { /// Dequantizes the data according to its quantization scheme. pub fn dequantize(self) -> (Vec, QParams, Vec>) { match self.scheme { - QuantizationScheme::PerTensor(QuantizationMode::Symmetric, QuantizationType::QInt8) => { + QuantScheme { + level: QuantLevel::Tensor, + mode: QuantMode::Symmetric, + q_type: QuantInputType::QInt8, + .. + } => { let (values, qparams) = self.into_vec_i8(); let strategy = QuantizationStrategy::PerTensorSymmetricInt8( SymmetricQuantization::init(qparams.scale[0]), diff --git a/crates/burn-tensor/src/tensor/quantization/primitive.rs b/crates/burn-tensor/src/tensor/quantization/primitive.rs index a4f1fba1c8..8e575b34b4 100644 --- a/crates/burn-tensor/src/tensor/quantization/primitive.rs +++ b/crates/burn-tensor/src/tensor/quantization/primitive.rs @@ -1,7 +1,7 @@ -use super::QuantizationScheme; +use super::QuantScheme; /// Quantized tensor primitive. pub trait QTensorPrimitive { /// Returns the quantization scheme for the given tensor. - fn scheme(&self) -> &QuantizationScheme; + fn scheme(&self) -> &QuantScheme; } diff --git a/crates/burn-tensor/src/tensor/quantization/scheme.rs b/crates/burn-tensor/src/tensor/quantization/scheme.rs index 8766c9e4b0..df0864123f 100644 --- a/crates/burn-tensor/src/tensor/quantization/scheme.rs +++ b/crates/burn-tensor/src/tensor/quantization/scheme.rs @@ -1,5 +1,3 @@ -#![allow(missing_docs)] // cube derive macros - use serde::{Deserialize, Serialize}; use crate::{Tensor, TensorPrimitive, backend::Backend}; @@ -8,56 +6,106 @@ use super::{ Calibration, CalibrationRange, QuantizationParameters, QuantizationParametersPrimitive, }; -#[cfg(feature = "cubecl")] -use cubecl::prelude::*; +/// Describes a quantization scheme/configuration. +#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] +pub struct QuantScheme { + /// Granularity level of quantization (e.g., per-tensor). + pub level: QuantLevel, + /// Quantization mode (e.g., symmetric). + pub mode: QuantMode, + /// Data type used for storing quantized values (e.g., QInt8). + pub q_type: QuantInputType, + /// Precision used for accumulating intermediate values (e.g., during matmul). + pub acc_precision: QuantAccPrecision, + /// Whether to propagate quantization to outputs or return unquantized results. + pub propagation: QuantPropagation, +} + +impl Default for QuantScheme { + fn default() -> Self { + Self { + level: QuantLevel::Tensor, + mode: QuantMode::Symmetric, + q_type: QuantInputType::QInt8, + acc_precision: QuantAccPrecision::Full, + propagation: QuantPropagation::Inhibit, + } + } +} + +impl QuantScheme { + /// Set the quantization level. + pub fn set_level(mut self, level: QuantLevel) -> Self { + self.level = level; + self + } -/// Quantization data type. -#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)] -#[cfg_attr(feature = "cubecl", derive(CubeType, PartialOrd, Ord))] -pub enum QuantizationType { + /// Set the quantization mode. + pub fn set_mode(mut self, mode: QuantMode) -> Self { + self.mode = mode; + self + } + + /// Set the data type used for quantized values. + pub fn set_q_type(mut self, q_type: QuantInputType) -> Self { + self.q_type = q_type; + self + } + + /// Set the accumulation precision used during computations. + pub fn set_acc_precision(mut self, acc_precision: QuantAccPrecision) -> Self { + self.acc_precision = acc_precision; + self + } + + /// Set whether quantization is propagated through operations. + pub fn set_propagation(mut self, propagation: QuantPropagation) -> Self { + self.propagation = propagation; + self + } +} +/// Level or granularity of quantization. +#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] +pub enum QuantLevel { + /// Quantize the whole tensor using a single tensor. + Tensor, +} + +/// Data type used to represent quantized values. +#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] +pub enum QuantInputType { /// 8-bit signed integer. QInt8, } -/// Quantization mode. -#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)] -#[cfg_attr(feature = "cubecl", derive(PartialOrd, Ord))] -pub enum QuantizationMode { +/// Strategy used to quantize values. +#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] +pub enum QuantMode { /// Symmetric or scale quantization. Symmetric, } -/// Quantization scheme. -#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)] -#[cfg_attr(feature = "cubecl", derive(PartialOrd, Ord))] -pub enum QuantizationScheme { - /// Per-tensor quantization. - PerTensor(QuantizationMode, QuantizationType), +/// Quantization accumulator precision. This is the precision to used when accumulating values +/// while executing algorithms such as matmul. +#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] +pub enum QuantAccPrecision { + /// Full precision accumulation (f32). + Full, + /// Half precision accumulation (f16). + Half, } -#[cfg(feature = "cubecl")] -impl CubeType for QuantizationScheme { - type ExpandType = Self; -} - -#[cfg(feature = "cubecl")] -impl CubeDebug for QuantizationScheme {} - -#[cfg(feature = "cubecl")] -impl cubecl::frontend::IntoMut for QuantizationScheme { - fn into_mut(self, _scope: &mut cubecl::ir::Scope) -> Self { - self - } +/// Specify if the output of an operation is quantized using the scheme of the input +/// or returned unquantized. +#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] +pub enum QuantPropagation { + /// The output is quantized using the scheme of the input. + Propagate, + /// The output is not quantized. + Inhibit, } -impl QuantizationScheme { - /// Get the [quantization mode](QuantizationMode) - pub fn mode(&self) -> QuantizationMode { - match self { - QuantizationScheme::PerTensor(mode, ..) => *mode, - } - } - +impl QuantScheme { /// Compute the quantization range mapping. pub fn compute_range( &self, @@ -83,10 +131,8 @@ impl QuantizationScheme { calibration: &Calibration, ) -> (B::FloatTensorPrimitive, B::FloatTensorPrimitive) { match calibration { - Calibration::MinMax => match self { - QuantizationScheme::PerTensor(_, _) => { - (B::float_min(tensor.clone()), B::float_max(tensor)) - } + Calibration::MinMax => match self.level { + QuantLevel::Tensor => (B::float_min(tensor.clone()), B::float_max(tensor)), }, } } @@ -97,7 +143,12 @@ impl QuantizationScheme { range: CalibrationRange, ) -> QuantizationParameters { match self { - QuantizationScheme::PerTensor(QuantizationMode::Symmetric, QuantizationType::QInt8) => { + QuantScheme { + level: QuantLevel::Tensor, + mode: QuantMode::Symmetric, + q_type: QuantInputType::QInt8, + .. + } => { // Quantized range `[a, b]` let b = i8::MAX as i32; let a = -b; @@ -125,10 +176,4 @@ impl QuantizationScheme { }; self.compute_q_params(range).into() } - - pub fn q_type(&self) -> QuantizationType { - match self { - QuantizationScheme::PerTensor(_, quantization_type) => *quantization_type, - } - } } diff --git a/crates/burn-tensor/src/tensor/quantization/strategy.rs b/crates/burn-tensor/src/tensor/quantization/strategy.rs index 4450975e59..452c885e4d 100644 --- a/crates/burn-tensor/src/tensor/quantization/strategy.rs +++ b/crates/burn-tensor/src/tensor/quantization/strategy.rs @@ -3,7 +3,9 @@ use core::marker::PhantomData; use num_traits::{Float, PrimInt, Signed}; use serde::{Deserialize, Serialize}; -use super::{QuantizationMode, QuantizationScheme, QuantizationType}; +use super::{ + QuantAccPrecision, QuantInputType, QuantLevel, QuantMode, QuantPropagation, QuantScheme, +}; /// Quantization strategy. #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] @@ -30,11 +32,15 @@ impl QuantizationStrategy { impl QuantizationStrategy { /// Returns the corresponding quantization scheme. - pub fn scheme(&self) -> QuantizationScheme { + pub fn scheme(&self) -> QuantScheme { match self { - QuantizationStrategy::PerTensorSymmetricInt8(_) => { - QuantizationScheme::PerTensor(QuantizationMode::Symmetric, QuantizationType::QInt8) - } + QuantizationStrategy::PerTensorSymmetricInt8(_) => QuantScheme { + level: QuantLevel::Tensor, + mode: QuantMode::Symmetric, + q_type: QuantInputType::QInt8, + acc_precision: QuantAccPrecision::Full, + propagation: QuantPropagation::Inhibit, + }, } } } diff --git a/crates/burn-tensor/src/tests/mod.rs b/crates/burn-tensor/src/tests/mod.rs index 7f2a36a7aa..7bb7a7c427 100644 --- a/crates/burn-tensor/src/tests/mod.rs +++ b/crates/burn-tensor/src/tests/mod.rs @@ -365,11 +365,7 @@ macro_rules! as_type { pub mod qtensor { use core::marker::PhantomData; - use crate::{ - Tensor, TensorData, - backend::Backend, - quantization::{QuantizationMode, QuantizationScheme, QuantizationType}, - }; + use crate::{Tensor, TensorData, backend::Backend, quantization::QuantScheme}; pub struct QTensor { b: PhantomData, @@ -384,12 +380,8 @@ pub mod qtensor { /// Creates a quantized int8 tensor from the floating point data using per-tensor symmetric quantization. pub fn int8_symmetric>(floats: F) -> Tensor { - Tensor::from_floats(floats, &Default::default()).quantize_dynamic( - &QuantizationScheme::PerTensor( - QuantizationMode::Symmetric, - QuantizationType::QInt8, - ), - ) + Tensor::from_floats(floats, &Default::default()) + .quantize_dynamic(&QuantScheme::default()) } } } diff --git a/crates/burn-tensor/src/tests/quantization/calibration.rs b/crates/burn-tensor/src/tests/quantization/calibration.rs index a06b4a3f1c..b1abe55e83 100644 --- a/crates/burn-tensor/src/tests/quantization/calibration.rs +++ b/crates/burn-tensor/src/tests/quantization/calibration.rs @@ -3,15 +3,14 @@ mod tests { use super::*; use burn_tensor::{ Tensor, TensorData, - quantization::{Calibration, QuantizationMode, QuantizationScheme, QuantizationType}, + quantization::{Calibration, QuantScheme}, }; // NOTE: The scheme variant fields are not important for calibration, only the "main" variant (e.g., per-tensor) #[test] fn min_max_calibration_range_per_tensor() { let tensor = TestTensor::<1>::from_floats([-1.8, -1.0, 0.0, 0.5], &Default::default()); - let scheme = - QuantizationScheme::PerTensor(QuantizationMode::Symmetric, QuantizationType::QInt8); + let scheme = QuantScheme::default(); let range = scheme.compute_range(&tensor, &Calibration::MinMax); diff --git a/crates/burn-tensor/src/tests/quantization/data.rs b/crates/burn-tensor/src/tests/quantization/data.rs index af87d8a4a8..39090c523b 100644 --- a/crates/burn-tensor/src/tests/quantization/data.rs +++ b/crates/burn-tensor/src/tests/quantization/data.rs @@ -5,12 +5,6 @@ mod tests { use burn_tensor::quantization::{QuantizationStrategy, SymmetricQuantization}; use burn_tensor::{Tensor, TensorData}; - // NOTE: we mark the per-block tests as `might_panic` since backends are not strictly - // required to support this quantization scheme. - // Also std feature gated (until `catch_unwind` is stable in core). - #[cfg(feature = "std")] - use burn_tensor::might_panic; - #[test] fn should_support_per_tensor_symmetric_int8() { let data = TensorData::quantized( diff --git a/crates/burn-tensor/src/tests/quantization/ops/matmul.rs b/crates/burn-tensor/src/tests/quantization/ops/matmul.rs index b9ac028e5e..9e095a9d67 100644 --- a/crates/burn-tensor/src/tests/quantization/ops/matmul.rs +++ b/crates/burn-tensor/src/tests/quantization/ops/matmul.rs @@ -2,7 +2,7 @@ mod tests { use super::*; use burn_tensor::TensorData; - use burn_tensor::{Tolerance, ops::FloatElem}; + use burn_tensor::{TensorPrimitive, Tolerance, ops::FloatElem, ops::QTensorOps}; type FT = FloatElem; #[test] @@ -11,8 +11,6 @@ mod tests { let tensor_2 = QTensor::::int8([[12.7], [4.0], [5.0], [1.0]]); let tensor_3 = tensor_1.matmul(tensor_2); - let expected = - TensorData::from([[18.0, 28.0, 40.0], [14.0, 23.0, 25.0], [14.0, 22.0, 30.0]]); let expected = TensorData::from([[42.05]]); tensor_3 @@ -92,4 +90,26 @@ mod tests { let _ = tensor_1.matmul(tensor_2); } + + #[test] + fn test_matmul_lhs_float_rhs_quantized() { + // Simulates a typical workflow with linear layers (e.g., transformers), where the rhs + // represents the weights. The lhs might be a float if a previous operation did not propagate + // the quantization. We still want to perform an efficient matmul with quantized weights. + // + // Since `q_matmul(lhs_f16, rhs_quant)` isn't currently supported, in practice it makes + // more sense to re-quantize the input back at this time. Better usability. + // + // This might be handled differently in the future (dequantize on read in fusion?). + let tensor_1 = TestTensor::<2>::from([[1.0, 6.35], [2.0, 3.0], [1.0, 3.0]]); + let tensor_2 = QTensor::::int8([[4.0, 8.0, 12.7], [2.0, 3.0, 6.0]]); + let tensor_3 = tensor_1.matmul(tensor_2); + + let expected = TensorData::from([[16.7, 27.05, 50.8], [14., 25., 43.4], [10., 17., 30.7]]); + let output = tensor_3.into_data(); + output.assert_approx_eq::(&expected, Tolerance::rel_abs(1e-2, 1e-1)); + + // Default quantization scheme does not propagate quantization with matmul + assert!(output.dtype.is_float()); + } } diff --git a/crates/burn-tensor/src/tests/quantization/ops/quantize.rs b/crates/burn-tensor/src/tests/quantization/ops/quantize.rs index 1428aa22a3..b7792c6c10 100644 --- a/crates/burn-tensor/src/tests/quantization/ops/quantize.rs +++ b/crates/burn-tensor/src/tests/quantization/ops/quantize.rs @@ -4,8 +4,8 @@ mod tests { use super::*; use alloc::{vec, vec::Vec}; use burn_tensor::quantization::{ - QParams, QuantizationMode, QuantizationParameters, QuantizationScheme, - QuantizationStrategy, QuantizationType, QuantizedBytes, SymmetricQuantization, + QParams, QuantScheme, QuantizationParameters, QuantizationStrategy, QuantizedBytes, + SymmetricQuantization, }; use burn_tensor::{DType, Tensor, TensorData}; use burn_tensor::{Tolerance, ops::FloatElem}; @@ -30,8 +30,7 @@ mod tests { fn should_support_quantize_symmetric_int8() { let device = Default::default(); let tensor = TestTensor::<1>::from_floats([-1.8, -1.0, 0.0, 0.5], &device); - let scheme = - QuantizationScheme::PerTensor(QuantizationMode::Symmetric, QuantizationType::QInt8); + let scheme = QuantScheme::default(); let qparams = QuantizationParameters { scale: Tensor::from_floats([0.014_173_228], &device), offset: None, @@ -75,8 +74,7 @@ mod tests { // NOTE: we use fully representable values since different backend implementations could differ slightly // due to rounding discrepancies let tensor = TestTensor::<1>::from_floats([5., 0., 4., -12.7], &device); - let scheme = - QuantizationScheme::PerTensor(QuantizationMode::Symmetric, QuantizationType::QInt8); + let scheme = QuantScheme::default(); let x_q = tensor.quantize_dynamic(&scheme); diff --git a/crates/burn-tensor/src/tests/quantization/scheme.rs b/crates/burn-tensor/src/tests/quantization/scheme.rs index a747abc086..5e6a8fbead 100644 --- a/crates/burn-tensor/src/tests/quantization/scheme.rs +++ b/crates/burn-tensor/src/tests/quantization/scheme.rs @@ -2,8 +2,11 @@ mod tests { use super::*; use burn_tensor::{ - Tensor, TensorData, - quantization::{CalibrationRange, QuantizationMode, QuantizationScheme, QuantizationType}, + DType, Element, Tensor, TensorData, + quantization::{ + CalibrationRange, QuantAccPrecision, QuantInputType, QuantLevel, QuantMode, + QuantPropagation, QuantScheme, + }, }; use burn_tensor::{Tolerance, ops::FloatElem}; type FT = FloatElem; @@ -11,8 +14,7 @@ mod tests { #[test] fn per_tensor_symmetric_int8() { let device = Default::default(); - let scheme = - QuantizationScheme::PerTensor(QuantizationMode::Symmetric, QuantizationType::QInt8); + let scheme = QuantScheme::default(); let range = CalibrationRange { min: TestTensor::<1>::from_floats([0.5], &device), max: TestTensor::<1>::from_floats([1.8], &device), @@ -26,4 +28,58 @@ mod tests { .assert_approx_eq::(&TensorData::from([0.014_173_228]), Tolerance::default()); assert!(qparams.offset.is_none()); } + + #[test] + fn quant_scheme_should_propagate() { + let device = Default::default(); + let scheme = QuantScheme { + propagation: QuantPropagation::Propagate, + ..Default::default() + }; + + let tensor_1 = TestTensor::<2>::from_floats([[1.0, 6.35], [2.0, 3.0], [1.0, 3.0]], &device) + .quantize_dynamic(&scheme); + let tensor_2 = TestTensor::<2>::from_floats([[4.0, 8.0, 12.7], [2.0, 3.0, 6.0]], &device) + .quantize_dynamic(&scheme); + + let tensor_3 = tensor_1.matmul(tensor_2); + assert_eq!(tensor_3.to_data().dtype, DType::QFloat(scheme)); + + let tensor_4 = tensor_3.add_scalar(1.); + assert_eq!(tensor_4.to_data().dtype, DType::QFloat(scheme)); + } + + #[test] + fn quant_scheme_should_not_propagate() { + let device = Default::default(); + let scheme = QuantScheme { + propagation: QuantPropagation::Inhibit, + acc_precision: QuantAccPrecision::Full, // f32 + ..Default::default() + }; + + let tensor_1 = TestTensor::<2>::from_floats([[1.0, 6.35], [2.0, 3.0], [1.0, 3.0]], &device) + .quantize_dynamic(&scheme); + let tensor_2 = TestTensor::<2>::from_floats([[4.0, 8.0, 12.7], [2.0, 3.0, 6.0]], &device) + .quantize_dynamic(&scheme); + + // Some ops like reshape, swap_dims, permute, expand, select, slice, etc. do not affect + // the propagation. It mostly applies to compute kernels. + let tensor_1 = tensor_1 + .permute([1, 0]) + .swap_dims(0, 1) + .reshape([1, 6]) + .reshape([3, 2]); + assert_eq!(tensor_1.to_data().dtype, DType::QFloat(scheme)); + + // When propagation is not desired, compute kernels like matmul should return tensor + // in floating point precision + let tensor_3 = tensor_1.matmul(tensor_2); + let dtype = tensor_3.to_data().dtype; + assert!(dtype.is_float()); + + // Subsequent ops will therefore be performed on floats + let tensor_4 = tensor_3.add(TestTensor::<2>::ones([3, 3], &device).cast(dtype)); + assert!(tensor_4.to_data().dtype.is_float()); + } } diff --git a/crates/burn-vision/src/backends/cpu/morphology/mod.rs b/crates/burn-vision/src/backends/cpu/morphology/mod.rs index 696787dc36..8d642b8934 100644 --- a/crates/burn-vision/src/backends/cpu/morphology/mod.rs +++ b/crates/burn-vision/src/backends/cpu/morphology/mod.rs @@ -1,11 +1,8 @@ use std::fmt::Debug; use burn_tensor::{ - BasicOps, Bool, DType, Element, Shape, Tensor, TensorData, - backend::Backend, - cast::ToElement, + BasicOps, Bool, DType, Element, Shape, Tensor, TensorData, backend::Backend, cast::ToElement, ops::BoolTensor, - quantization::{QuantizationScheme, QuantizationType}, }; use filter::{MaxOp, MinOp, MorphOperator, VecMorphOperator}; use filter_engine::{ColFilter, Filter, Filter2D, FilterEngine, RowFilter}; @@ -108,11 +105,7 @@ pub fn morph>( } DType::U8 => morph_typed::(data, shape, kernel, op, iter, btype, bvalue, &device), DType::Bool => morph_bool::(data, shape, kernel, op, iter, btype, bvalue, &device), - DType::QFloat(scheme) => match scheme { - QuantizationScheme::PerTensor(_mode, QuantizationType::QInt8) => { - morph_typed::(data, shape, kernel, op, iter, btype, bvalue, &device) - } - }, + DType::QFloat(_) => unimplemented!(), } }