Skip to content

Commit a08b5fa

Browse files
committed
Unroll gather_to_lds calls
Signed-off-by: nithinsubbiah <nithinsubbiah@gmail.com>
1 parent 1089d44 commit a08b5fa

File tree

1 file changed

+12
-14
lines changed

1 file changed

+12
-14
lines changed

iree/turbine/kernel/wave/codegen/handlers.py

Lines changed: 12 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -1659,6 +1659,9 @@ def handle_gather_to_lds(emitter: WaveEmitter, node: fx.Node):
1659 1659
f"src: {src_data_type} vs dst: {dst_data_type}\n"
1660 1660
)
1661 1661

1662+
src = src.ir_value
1663+
dst = dst.ir_value
1664+
1662 1665
src_index_transformed, dst_index_transformed = src_idx, dst_idx
1663 1666
if src_mapping:
1664 1667
src_index_transformed = transform_index_on_mapping(
@@ -1668,28 +1671,23 @@ def handle_gather_to_lds(emitter: WaveEmitter, node: fx.Node):
1668 1671
dst_index_transformed = transform_index_on_mapping(
1669 1672
dst_mapping, dst_type.symbolic_shape, dst_idx
1670 1673
)
1671-
1672 1674
src_keys = list(src_index_transformed.keys())
1673 1675
src_fastest_dim = get_fastest_index(src_idx)
1674 1676
dst_keys = list(dst_index_transformed.keys())
1675 1677
dst_fastest_dim = get_fastest_index(dst_idx)
1676-
src_idx, dst_idx = [], []
1677 1678
for i in range(elements_per_thread):
1678 1679
new_src_index = copy.deepcopy(src_index_transformed)
1679 1680
src_key = src_keys[src_fastest_dim]
1680 1681
new_src_index[src_key].start += i
1681-
src_index_transformed = _build_start_indices(emitter, new_src_index)
1682+
src_index_transformed_ = _build_start_indices(emitter, new_src_index)
1682 1683
new_dst_index = copy.deepcopy(dst_index_transformed)
1683 1684
dst_key = dst_keys[dst_fastest_dim]
1684 1685
new_dst_index[dst_key].start += i
1685-
dst_index_transformed = _build_start_indices(emitter, new_dst_index)
1686-
src_idx.append(src_index_transformed)
1687-
dst_idx.append(dst_index_transformed)
1688-
1689-
return amdgpu_d.gather_to_lds(
1690-
src=src,
1691-
src_indices=src_idx,
1692-
dst=dst,
1693-
dst_indices=dst_idx,
1694-
transfer_type=element_type,
1695-
)
1686+
dst_index_transformed_ = _build_start_indices(emitter, new_dst_index)
1687+
amdgpu_d.gather_to_lds(
1688+
src=src,
1689+
src_indices=src_index_transformed_,
1690+
dst=dst,
1691+
dst_indices=dst_index_transformed_,
1692+
transfer_type=element_type,
1693+
)

0 commit comments

Comments
 (0)