Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit 78a0f3c

Browse files
committed
Add Channel Wise Quantization Support (#441)
1 parent 61caebe commit 78a0f3c

File tree

3 files changed

+48
-14
lines changed

3 files changed

+48
-14
lines changed

src/sparsezoo/analyze/analysis.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -304,10 +304,13 @@ class NodeAnalysis(YAMLSerializableBaseModel):
304304
)
305305
sparse_node: bool = Field(description="Does the node have sparse weights")
306306
quantized_node: bool = Field(description="Does the node have quantized weights")
307-
zero_point: int = Field(
307+
zero_point: Union[int, numpy.ndarray] = Field(
308308
description="Node zero point for quantization, default zero"
309309
)
310310

311+
class Config:
312+
arbitrary_types_allowed = True
313+
311314
@classmethod
312315
def from_node(
313316
cls,

src/sparsezoo/utils/calculate_ops.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -195,10 +195,23 @@ def _get_gemm_dense_sparse_ops(
195195
:param is_four_block_sparse: true if the weight is four block sparse
196196
:return: number of dense and sparse operations performed
197197
"""
198-
if is_four_block_sparse:
199-
weight_blocks = group_four_block(weight, pad_value=zero_point)
200-
num_zeros_per_block = numpy.count_nonzero(weight_blocks == zero_point, axis=1)
198+
is_channel_wise_quantized = isinstance(zero_point, numpy.ndarray)
199+
if is_channel_wise_quantized:
200+
# Broadcast zero_point per channel
201+
zero_point = numpy.broadcast_to(zero_point, weight.shape)
201202

203+
if is_four_block_sparse:
204+
weight_blocks = group_four_block(weight)
205+
if is_channel_wise_quantized:
206+
# Group zero_point in the same way as weight
207+
zero_point_blocks = group_four_block(zero_point)
208+
num_zeros_per_block = numpy.count_nonzero(
209+
weight_blocks == zero_point_blocks, axis=1
210+
)
211+
else:
212+
num_zeros_per_block = numpy.count_nonzero(
213+
weight_blocks == zero_point, axis=1
214+
)
202215
num_zero_blocks = numpy.count_nonzero(num_zeros_per_block == 4, axis=0)
203216
num_non_zero_blocks = numpy.count_nonzero(num_zeros_per_block != 4, axis=0)
204217

src/sparsezoo/utils/onnx/analysis.py

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -150,12 +150,21 @@ def get_node_num_zeros_and_size(
150150
if weight is None:
151151
return 0, 0
152152

153-
num_zeros = numpy.count_nonzero(weight == zero_point)
153+
is_channel_wise_quantized = (
154+
isinstance(zero_point, numpy.ndarray) and zero_point.ndim > 0
155+
)
156+
if is_channel_wise_quantized:
157+
# Broadcast zero point to match weight shape
158+
broadcasted_zero_point = numpy.broadcast_to(zero_point, weight.shape)
159+
num_zeros = numpy.count_nonzero(weight == broadcasted_zero_point)
160+
del broadcasted_zero_point
161+
else:
162+
num_zeros = numpy.count_nonzero(weight == zero_point)
154163

155164
return num_zeros, weight.size
156165

157166

158-
def group_four_block(array: numpy.ndarray, pad_value: bool = True) -> numpy.ndarray:
167+
def group_four_block(array: numpy.ndarray) -> numpy.ndarray:
159168
"""
160169
:param array: array to group into four blocks
161170
:param pad_value: value to pad remainder block with
@@ -205,16 +214,28 @@ def get_node_num_four_block_zeros_and_size(
205214
return 0, 0
206215

207216
# Group into blocks
208-
weight_blocks = group_four_block(weight, pad_value=zero_point)
217+
weight_blocks = group_four_block(weight)
209218

210219
# Count non-zero blocks
211-
num_zeros_per_block = numpy.count_nonzero(weight_blocks == zero_point, axis=1)
212-
num_zero_blocks = numpy.count_nonzero(num_zeros_per_block == 4, axis=0)
220+
if isinstance(zero_point, numpy.ndarray):
221+
# Channel-wise quantized case
222+
# Group zero point into blocks like the weight
223+
zero_point_blocks = group_four_block(
224+
numpy.broadcast_to(zero_point, weight.shape)
225+
)
226+
num_zeros_per_block = numpy.count_nonzero(
227+
weight_blocks == zero_point_blocks, axis=1
228+
)
229+
else:
230+
num_zeros_per_block = numpy.count_nonzero(weight_blocks == zero_point, axis=1)
213231

232+
num_zero_blocks = numpy.count_nonzero(num_zeros_per_block == 4, axis=0)
214233
return num_zero_blocks, weight_blocks.shape[0]
215234

216235

217-
def get_zero_point(model_graph: ONNXGraph, node: NodeProto) -> int:
236+
def get_zero_point(
237+
model_graph: ONNXGraph, node: NodeProto
238+
) -> Union[int, numpy.ndarray]:
218239
"""
219240
:param model_graph: instance of ONNXGraph that contains the given node
220241
:param node: node to find zero point of
@@ -240,10 +261,7 @@ def _get_node_zero_point_init_name(node: NodeProto) -> str:
240261
zero_point = get_initializer_value(
241262
model_graph, node, zero_point_initializer_name
242263
)
243-
if zero_point.ndim != 0:
244-
raise NotImplementedError("Channel-wise zero points are not supported")
245-
246-
return int(zero_point)
264+
return int(zero_point) if zero_point.ndim == 0 else zero_point
247265
else:
248266
return 0
249267

0 commit comments

Comments
 (0)