@@ -150,12 +150,21 @@ def get_node_num_zeros_and_size(
150
150
if weight is None :
151
151
return 0 , 0
152
152
153
- num_zeros = numpy .count_nonzero (weight == zero_point )
153
+ is_channel_wise_quantized = (
154
+ isinstance (zero_point , numpy .ndarray ) and zero_point .ndim > 0
155
+ )
156
+ if is_channel_wise_quantized :
157
+ # Broadcast zero point to match weight shape
158
+ broadcasted_zero_point = numpy .broadcast_to (zero_point , weight .shape )
159
+ num_zeros = numpy .count_nonzero (weight == broadcasted_zero_point )
160
+ del broadcasted_zero_point
161
+ else :
162
+ num_zeros = numpy .count_nonzero (weight == zero_point )
154
163
155
164
return num_zeros , weight .size
156
165
157
166
158
- def group_four_block (array : numpy .ndarray , pad_value : bool = True ) -> numpy .ndarray :
167
+ def group_four_block (array : numpy .ndarray ) -> numpy .ndarray :
159
168
"""
160
169
:param array: array to group into four blocks
161
170
:param pad_value: value to pad remainder block with
@@ -205,16 +214,28 @@ def get_node_num_four_block_zeros_and_size(
205
214
return 0 , 0
206
215
207
216
# Group into blocks
208
- weight_blocks = group_four_block (weight , pad_value = zero_point )
217
+ weight_blocks = group_four_block (weight )
209
218
210
219
# Count non-zero blocks
211
- num_zeros_per_block = numpy .count_nonzero (weight_blocks == zero_point , axis = 1 )
212
- num_zero_blocks = numpy .count_nonzero (num_zeros_per_block == 4 , axis = 0 )
220
+ if isinstance (zero_point , numpy .ndarray ):
221
+ # Channel-wise quantized case
222
+ # Group zero point into blocks like the weight
223
+ zero_point_blocks = group_four_block (
224
+ numpy .broadcast_to (zero_point , weight .shape )
225
+ )
226
+ num_zeros_per_block = numpy .count_nonzero (
227
+ weight_blocks == zero_point_blocks , axis = 1
228
+ )
229
+ else :
230
+ num_zeros_per_block = numpy .count_nonzero (weight_blocks == zero_point , axis = 1 )
213
231
232
+ num_zero_blocks = numpy .count_nonzero (num_zeros_per_block == 4 , axis = 0 )
214
233
return num_zero_blocks , weight_blocks .shape [0 ]
215
234
216
235
217
- def get_zero_point (model_graph : ONNXGraph , node : NodeProto ) -> int :
236
+ def get_zero_point (
237
+ model_graph : ONNXGraph , node : NodeProto
238
+ ) -> Union [int , numpy .ndarray ]:
218
239
"""
219
240
:param model_graph: instance of ONNXGraph that contains the given node
220
241
:param node: node to find zero point of
@@ -240,10 +261,7 @@ def _get_node_zero_point_init_name(node: NodeProto) -> str:
240
261
zero_point = get_initializer_value (
241
262
model_graph , node , zero_point_initializer_name
242
263
)
243
- if zero_point .ndim != 0 :
244
- raise NotImplementedError ("Channel-wise zero points are not supported" )
245
-
246
- return int (zero_point )
264
+ return int (zero_point ) if zero_point .ndim == 0 else zero_point
247
265
else :
248
266
return 0
249
267
0 commit comments