Commit 8e15c23

[Wave] Add support for direct global load to lds

Signed-off-by: nithinsubbiah <nithinsubbiah@gmail.com>
1 parent a2c1a75

File tree

9 files changed: +126 -28 lines changed

iree/turbine/kernel/ops/wave_ops.py

Lines changed: 19 additions & 0 deletions

```diff
@@ -287,6 +287,16 @@ def select(cond: "Register", if_true: "Register", if_false: "Register") -> "Register":
     ...


+def gather_to_lds(
+    src: "Memory",
+    src_idx: dict[IndexSymbol, IndexSequence],
+    dst: "Memory",
+    dst_idx: dict[IndexSymbol, IndexSequence],
+    dtype: DataType,
+):
+    ...
+
+
 def define_op(op_name: str) -> Callable[[T], T]:
     def decorator(cls: T) -> T:
         cls.tkw_op_name = op_name
@@ -2323,3 +2333,12 @@ def indexing_dims(self) -> list[IndexExpr]:

     def infer_type(self):
         self.type = get_custom(_to_sequence(self.args)[0]).type
+
+
+@define_op("gather_to_lds")
+@dataclass
+class GatherToLDS(CustomOp):
+    """
+    Represents an instruction that performs direct load from global
+    to lds.
+    """
```

iree/turbine/kernel/wave/codegen/handlers.py

Lines changed: 13 additions & 0 deletions

```diff
@@ -60,6 +60,7 @@
     exp2,
     extract,
     extract_slice,
+    gather_to_lds,
     ge,
     get_custom,
     get_result,
@@ -1584,3 +1585,15 @@ def handle_reshape(emitter: WaveEmitter, node: fx.Node):
         [1],
     )
     emitter.bind_node_proxy(node, IRProxyValue(slice))
+
+
+@handle_op(gather_to_lds)
+def handle_gather_to_lds(emitter: WaveEmitter, node: fx.Node):
+    try:
+        src, src_idx, dst, dst_idx, dtype = node.args
+    except ValueError as e:
+        raise ValidationError("Malformed arguments") from e
+
+    amdgpu_d.gather_to_lds(
+        transfer_type=dtype, src=src, src_indices=src_idx, dst=dst, dst_indices=dst_idx
+    )
```
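The handler lowers the traced op straight to what is presumably the MLIR `amdgpu` dialect binding (`amdgpu_d`); the keyword names `transfer_type`, `src_indices`, and `dst_indices` are taken directly from the hunk. Its validation pattern, a fixed-arity unpack of `node.args` with the arity error rewrapped as a kernel-level error, looks like this in isolation (stand-in `ValidationError`, not the one handlers.py imports):

```python
# Hedged sketch of the handler's argument validation: unpacking a
# fixed-arity tuple, converting the arity error into a domain error.
class ValidationError(Exception):
    pass


def unpack_gather_args(args: tuple):
    try:
        src, src_idx, dst, dst_idx, dtype = args
    except ValueError as e:  # raised when len(args) != 5
        raise ValidationError("Malformed arguments") from e
    return src, src_idx, dst, dst_idx, dtype


print(unpack_gather_args(("src", {}, "dst", {}, "f16")))
```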
iree/turbine/kernel/wave/gather_to_shared.py (new file)

Lines changed: 60 additions & 0 deletions

```diff
@@ -0,0 +1,60 @@
+# Copyright 2025 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from .._support.tracing import CapturedTrace
+from ..lang.global_symbols import *
+from ..ops.wave_ops import GatherToLDS, Write, get_custom
+from ..wave.constraints import (
+    Constraint,
+)
+from ..wave.utils.run_utils import get_default_arch
+from .utils.graph_utils import DCE, is_valid_global_read
+from .utils.symbol_utils import (
+    subs_idxc,
+)
+
+
+gather_to_shared_supported_arch = ["gfx950"]
+
+
+def get_write_node_info(read_node):
+    read_custom = get_custom(read_node)
+    write_node = None
+    for user in read_custom.users:
+        if (
+            isinstance(user, Write)
+            and subs_idxc(user.memory_type.address_space) == SHARED_ADDRESS_SPACE
+        ):
+            # Get memory location and idx.
+            dst = user.memory
+            dst_idx = user.get_derived_indices
+            write_node = user
+
+    return write_node, dst, dst_idx
+
+
+def gather_to_shared(trace: CapturedTrace, constraints: list[Constraint]):
+    """
+    This function enables direct memory load from global to lds without
+    passing through registers, reducing data movement. This instruction
+    is supported only on specific architectures (gfx950).
+    """
+
+    if get_default_arch() not in gather_to_shared_supported_arch:
+        return
+
+    global_read_nodes = trace.walk(is_valid_global_read)
+    for read_node in global_read_nodes:
+        custom = get_custom(read_node)
+        src = custom.memory
+        src_idx = custom.get_derived_indices
+        element_type = custom.type.dtype
+        write_node, dst, dst_idx = get_write_node_info(read_node)
+        write_node.replace_all_uses_with(
+            GatherToLDS(src, src_idx, dst, dst_idx, element_type)
+        )
+
+    DCE(trace)
```
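The core move of the pass: for each global `Read` whose user `Write`s to shared memory, the write is replaced with a single fused `GatherToLDS` node, and `DCE` then removes the now-dead read/write pair. Note that `get_write_node_info` relies on `is_valid_global_read` having already guaranteed a shared-memory `Write` user; otherwise `dst` and `dst_idx` would be unbound. A toy sketch of the rewrite on a stand-in graph (not Wave's fx machinery):

```python
# Toy graph rewrite mirroring gather_to_shared; Node and the string
# address spaces are illustrative stand-ins, not Wave types.
from dataclasses import dataclass, field

GLOBAL, SHARED = "global", "shared"


@dataclass
class Node:
    kind: str                 # "read" | "write" | "gather_to_lds"
    space: str                # address space the node touches
    users: list["Node"] = field(default_factory=list)


def fuse_read(read: Node) -> Node | None:
    """Return the fused node if `read` is a global read feeding a shared write."""
    if read.kind != "read" or read.space != GLOBAL:
        return None
    for user in read.users:
        if user.kind == "write" and user.space == SHARED:
            # One gather_to_lds instruction replaces the read+write pair.
            return Node("gather_to_lds", SHARED)
    return None


r = Node("read", GLOBAL, users=[Node("write", SHARED)])
assert fuse_read(r).kind == "gather_to_lds"
```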

iree/turbine/kernel/wave/global_to_shared_gathers.py

Lines changed: 3 additions & 1 deletion

```diff
@@ -22,13 +22,15 @@
 from .utils.symbol_utils import subs_idxc
 from .utils.general_utils import is_gather
 from .minimize_global_loads import (
-    has_write_shared_user,
     construct_min_global_access_pattern,
     materialize_shape,
     identify_optimizable_loads,
     update_write_dependencies,
     SharedReadMetadata,
 )
+from .utils.graph_utils import (
+    has_write_shared_user,
+)

 """
 We are given N global gathers that are promoted to shared memory. This function
```

iree/turbine/kernel/wave/minimize_global_loads.py

Lines changed: 4 additions & 18 deletions

```diff
@@ -11,7 +11,7 @@
     TilingConstraint,
 )
 from .._support.tracing import CapturedTrace
-from .._support.indexing import IndexingContext, IndexSequence, IndexSymbol, IndexExpr
+from .._support.indexing import IndexSequence, IndexSymbol, IndexExpr
 from ..ops.wave_ops import Read, Write, get_custom
 from ..lang.global_symbols import *
 from .utils.general_utils import (
@@ -20,6 +20,9 @@
     is_shared_read,
     get_fastest_index,
 )
+from .utils.graph_utils import (
+    is_valid_global_read,
+)
 from .utils.graph_utils import (
     DCE,
 )
@@ -41,23 +44,6 @@ class SharedReadMetadata:
     memory_shape: tuple[int | IndexExpr]


-def has_write_shared_user(node: Read) -> bool:
-    return any(
-        isinstance(user, Write)
-        and subs_idxc(user.memory_type.address_space) == SHARED_ADDRESS_SPACE
-        for user in node.users
-    )
-
-
-def is_valid_global_read(node: fx.Node) -> bool:
-    custom = get_custom(node)
-    return (
-        isinstance(custom, Read)
-        and subs_idxc(custom.memory_type.address_space) == GLOBAL_ADDRESS_SPACE
-        and has_write_shared_user(custom)
-    )
-
-
 def is_transposed_read(custom: Read) -> bool:
     """
     Checks whether or not we are doing a transposed read.
```

iree/turbine/kernel/wave/promotion.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -67,8 +67,8 @@ def apply_promotion_pattern(
     read_from_global lhs
     write_to_shared lhs
-    read_from_global lhs
-    write_to_shared lhs
+    read_from_global rhs
+    write_to_shared rhs
     shared_barrier
     read_from_shared lhs
     read_from_shared rhs
```

iree/turbine/kernel/wave/utils/graph_utils.py

Lines changed: 21 additions & 0 deletions

```diff
@@ -9,6 +9,7 @@
 import iree.turbine.kernel.lang as tkl
 from ...ops.wave_ops import (
     get_custom,
+    Read,
     Write,
     NestedRegionOp,
     Output,
@@ -438,3 +439,23 @@ def is_barrier_between(src: fx.Node, dst: fx.Node) -> bool:
         return dst_check or root_check

     assert False, "Unhandled case when src and dst are in different graphs"
+
+
+def has_write_shared_user(node: Read) -> bool:
+    return any(
+        isinstance(user, Write)
+        and subs_idxc(user.memory_type.address_space) == SHARED_ADDRESS_SPACE
+        for user in node.users
+    )
+
+
+def is_valid_global_read(node: fx.Node) -> bool:
+    """
+    Check if a read node is global and if its user writes to shared memory.
+    """
+    custom = get_custom(node)
+    return (
+        isinstance(custom, Read)
+        and subs_idxc(custom.memory_type.address_space) == GLOBAL_ADDRESS_SPACE
+        and has_write_shared_user(custom)
+    )
```
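The two relocated predicates compose: `is_valid_global_read` accepts a node only when it is a global-address-space `Read` and at least one user is a shared-memory `Write`. A compact stand-in illustration (toy types, with plain strings in place of Wave's symbolic address spaces and `subs_idxc`):

```python
# Stand-in types for illustration only; not Wave's fx-node wrappers.
from dataclasses import dataclass, field


@dataclass
class Memory:
    address_space: str


@dataclass
class Op:
    memory_type: Memory
    users: list = field(default_factory=list)


class Read(Op): ...
class Write(Op): ...


def has_write_shared_user(node: Read) -> bool:
    return any(
        isinstance(u, Write) and u.memory_type.address_space == "shared"
        for u in node.users
    )


def is_valid_global_read(custom: Op) -> bool:
    return (
        isinstance(custom, Read)
        and custom.memory_type.address_space == "global"
        and has_write_shared_user(custom)
    )


assert is_valid_global_read(Read(Memory("global"), [Write(Memory("shared"))]))
assert not is_valid_global_read(Read(Memory("global")))  # no shared write user
```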

iree/turbine/kernel/wave/utils/run_utils.py

Lines changed: 2 additions & 1 deletion

```diff
@@ -235,5 +235,6 @@ def set_default_run_config(options: WaveCompileOptions) -> WaveCompileOptions:
     """Return default config for running."""
     options.backend = "rocm"
     options.device = "hip"
-    options.target = get_default_arch()
+    options.target = "gfx950"
+    # options.target = get_default_arch()
     return options
```

iree/turbine/kernel/wave/wave.py

Lines changed: 2 additions & 6 deletions

```diff
@@ -50,6 +50,7 @@
 from .decompose_scan_ops import decompose_scan_ops
 from .decompose_dot_mma import decompose_dot_mma
 from .expansion.expansion import expand_graph, add_get_results
+from .gather_to_shared import gather_to_shared
 from .global_to_shared_gathers import global_to_shared_gathers
 from .hoisting import hoist_loop_invariant_ops
 from .minimize_global_loads import minimize_global_loads
@@ -86,12 +87,6 @@
 import inspect
 import sympy
 import warnings
-from pathlib import Path
-import sys
-import subprocess
-import os
-import shutil
-import glob

 __all__ = ["wave", "wave_trace_only"]

@@ -537,6 +532,7 @@ def _trace_and_get_kernel_signature(
             partial(hoist_loop_invariant_ops, trace, self.constraints),
             partial(global_to_shared_gathers, trace, self.constraints),
             partial(minimize_global_loads, trace, self.constraints),
+            partial(gather_to_shared, trace, self.constraints),
             partial(apply_shared_memory_indexing_corrections, trace, self.constraints),
         ]
```
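The new pass is appended directly after `minimize_global_loads`, so it sees the already-minimized global reads before shared-memory indexing corrections run. The pipeline is a list of `functools.partial` thunks executed in order; a stub sketch of that ordering (pass bodies are placeholders, not Wave's real implementations):

```python
# Stub pipeline mirroring the ordering in the hunk above.
from functools import partial


def minimize_global_loads(trace, constraints):
    print("minimize_global_loads")


def gather_to_shared(trace, constraints):
    print("gather_to_shared")  # inserted by this commit


def apply_shared_memory_indexing_corrections(trace, constraints):
    print("apply_shared_memory_indexing_corrections")


trace, constraints = object(), []
graph_passes = [
    partial(minimize_global_loads, trace, constraints),
    partial(gather_to_shared, trace, constraints),
    partial(apply_shared_memory_indexing_corrections, trace, constraints),
]
for p in graph_passes:
    p()
```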
