@@ -1659,6 +1659,9 @@ def handle_gather_to_lds(emitter: WaveEmitter, node: fx.Node):
1659
1659
f"src: { src_data_type } vs dst: { dst_data_type } \n "
1660
1660
)
1661
1661
1662
+ src = src .ir_value
1663
+ dst = dst .ir_value
1664
+
1662
1665
src_index_transformed , dst_index_transformed = src_idx , dst_idx
1663
1666
if src_mapping :
1664
1667
src_index_transformed = transform_index_on_mapping (
@@ -1668,28 +1671,23 @@ def handle_gather_to_lds(emitter: WaveEmitter, node: fx.Node):
1668
1671
dst_index_transformed = transform_index_on_mapping (
1669
1672
dst_mapping , dst_type .symbolic_shape , dst_idx
1670
1673
)
1671
-
1672
1674
src_keys = list (src_index_transformed .keys ())
1673
1675
src_fastest_dim = get_fastest_index (src_idx )
1674
1676
dst_keys = list (dst_index_transformed .keys ())
1675
1677
dst_fastest_dim = get_fastest_index (dst_idx )
1676
- src_idx , dst_idx = [], []
1677
1678
for i in range (elements_per_thread ):
1678
1679
new_src_index = copy .deepcopy (src_index_transformed )
1679
1680
src_key = src_keys [src_fastest_dim ]
1680
1681
new_src_index [src_key ].start += i
1681
- src_index_transformed = _build_start_indices (emitter , new_src_index )
1682
+ src_index_transformed_ = _build_start_indices (emitter , new_src_index )
1682
1683
new_dst_index = copy .deepcopy (dst_index_transformed )
1683
1684
dst_key = dst_keys [dst_fastest_dim ]
1684
1685
new_dst_index [dst_key ].start += i
1685
- dst_index_transformed = _build_start_indices (emitter , new_dst_index )
1686
- src_idx .append (src_index_transformed )
1687
- dst_idx .append (dst_index_transformed )
1688
-
1689
- return amdgpu_d .gather_to_lds (
1690
- src = src ,
1691
- src_indices = src_idx ,
1692
- dst = dst ,
1693
- dst_indices = dst_idx ,
1694
- transfer_type = element_type ,
1695
- )
1686
+ dst_index_transformed_ = _build_start_indices (emitter , new_dst_index )
1687
+ amdgpu_d .gather_to_lds (
1688
+ src = src ,
1689
+ src_indices = src_index_transformed_ ,
1690
+ dst = dst ,
1691
+ dst_indices = dst_index_transformed_ ,
1692
+ transfer_type = element_type ,
1693
+ )
0 commit comments