Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions iree/turbine/kernel/ops/wave_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,20 @@ def select(
) -> "Register": ...


def gather_to_lds(
    src: Memory,
    src_idx: dict[IndexSymbol, IndexSequence],
    src_type: DataType,
    dst: Memory,
    dst_idx: dict[IndexSymbol, IndexSequence],
    dst_type: DataType,
    src_mapping: Optional[IndexMapping] = None,
    dst_mapping: Optional[IndexMapping] = None,
    elements_per_thread: Optional[IndexExpr | int] = None,
):
    """
    Declaration-only stub for the "gather_to_lds" op.

    The body is intentionally empty; the function only carries the
    signature. The op class registered under the same name in this file
    (`@define_op("gather_to_lds")`) declares the same fields in the same
    order: source memory, source index and source type, followed by the
    destination memory, index and type, then the optional source and
    destination mappings and the optional number of elements per thread.
    """


def define_op(op_name: str) -> Callable[[T], T]:
def decorator(cls: T) -> T:
cls.tkw_op_name = op_name
Expand Down Expand Up @@ -2482,3 +2496,23 @@ def indexing_dims(self) -> list[IndexExpr]:

def infer_type(self):
    """Set this node's type to the type of its first argument."""
    # `_to_sequence` is applied before indexing, presumably so that
    # `self.args` can be indexed whether it holds one value or several.
    first_arg = _to_sequence(self.args)[0]
    self.type = get_custom(first_arg).type


@define_op("gather_to_lds")
@dataclass
class GatherToLDS(CustomOp):
    """
    Represents an instruction that performs a direct load from global
    memory to LDS. The source node points to the global memory to load
    from and the destination node points to shared memory.

    The three trailing fields default to None so that the class accepts
    the same calls as the `gather_to_lds` signature stub declared earlier
    in this file, which gives those parameters a default of None.
    """

    # Memory to load from.
    src: Memory
    # Per-dimension index used to address `src`.
    src_idx: dict[IndexSymbol, IndexSequence]
    # Type associated with the source.
    src_type: DataType
    # Memory to store into.
    dst: Memory
    # Per-dimension index used to address `dst`.
    dst_idx: dict[IndexSymbol, IndexSequence]
    # Type associated with the destination.
    dst_type: DataType
    # Optional index mapping applied to the source index.
    src_mapping: Optional[IndexMapping] = None
    # Optional index mapping applied to the destination index.
    dst_mapping: Optional[IndexMapping] = None
    # Number of elements each thread transfers.
    elements_per_thread: Optional[IndexExpr | int] = None
81 changes: 81 additions & 0 deletions iree/turbine/kernel/wave/codegen/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
exp2,
extract,
extract_slice,
gather_to_lds,
ge,
get_custom,
get_result,
Expand Down Expand Up @@ -1707,3 +1708,83 @@ def handle_reshape(emitter: WaveEmitter, node: fx.Node):
[1],
)
emitter.bind_node_proxy(node, IRProxyValue(slice))


@handle_op(gather_to_lds)
def handle_gather_to_lds(emitter: WaveEmitter, node: fx.Node):
    """
    Emit the IR for a `gather_to_lds` node.

    The node's nine arguments are unpacked, both memory operands are checked
    to be memrefs with the same element type, and the optional source and
    destination mappings are applied to their indices. One
    `amdgpu_d.gather_to_lds` op is then emitted per element (for
    `elements_per_thread` elements), each time with the start of the
    fastest dimension advanced by the element number. A single
    `amdgpu_d.lds_barrier` is emitted after all loads.

    Raises:
        ValidationError: if the node does not carry exactly nine arguments,
            if `elements_per_thread` is None, if either operand is not a
            memref, or if the source and destination element types differ.
    """
    try:
        (
            src,
            src_idx,
            src_type,
            dst,
            dst_idx,
            dst_type,
            src_mapping,
            dst_mapping,
            elements_per_thread,
        ) = node.args
    except ValueError as e:
        raise ValidationError("Malformed arguments") from e

    # The op declares elements_per_thread as optional, but the emission loop
    # below needs a concrete count. Fail with a clear message instead of the
    # opaque TypeError that range(None) would produce.
    if elements_per_thread is None:
        raise ValidationError(
            "Expected elements_per_thread to be set for gather_to_lds, got None"
        )

    element_type = IrType.parse(src_type.dtype.ir_type_asm())

    src = cast_py_value(emitter, src)
    dst = cast_py_value(emitter, dst)
    src_data_type = get_type_or_element_type(src.ir_value.type)
    dst_data_type = get_type_or_element_type(dst.ir_value.type)

    if not (
        MemRefType.isinstance(src.ir_value.type)
        and MemRefType.isinstance(dst.ir_value.type)
    ):
        op = get_custom(node)
        raise ValidationError(
            f"Expected src and dst to be of Memref type for\n"
            f"{op}\nGot\n"
            f"src: {src.ir_value.type}\n"
            f"dst: {dst.ir_value.type}\n"
        )

    if src_data_type != dst_data_type:
        op = get_custom(node)
        raise ValidationError(
            f"Expected src and dst to have same data type for\n"
            f"{op}\nGot\n"
            f"src: {src_data_type} vs dst: {dst_data_type}\n"
        )

    src = src.ir_value
    dst = dst.ir_value

    src_index_transformed, dst_index_transformed = src_idx, dst_idx
    if src_mapping:
        src_index_transformed = transform_index_on_mapping(
            src_mapping, src_type.symbolic_shape, src_idx
        )
    if dst_mapping:
        dst_index_transformed = transform_index_on_mapping(
            dst_mapping, dst_type.symbolic_shape, dst_idx
        )

    # NOTE(review): the fastest dimension is located on the untransformed
    # index but used to pick a key from the transformed index. This assumes
    # the mapping keeps the dimension order - TODO confirm.
    src_key = list(src_index_transformed.keys())[get_fastest_index(src_idx)]
    dst_key = list(dst_index_transformed.keys())[get_fastest_index(dst_idx)]

    for i in range(elements_per_thread):
        # Work on copies so the offsets do not accumulate across iterations
        # or leak into the indices stored on the node.
        new_src_index = copy.deepcopy(src_index_transformed)
        new_src_index[src_key].start += i
        src_start_indices = _build_start_indices(emitter, new_src_index)

        new_dst_index = copy.deepcopy(dst_index_transformed)
        new_dst_index[dst_key].start += i
        dst_start_indices = _build_start_indices(emitter, new_dst_index)

        amdgpu_d.gather_to_lds(
            src=src,
            src_indices=src_start_indices,
            dst=dst,
            dst_indices=dst_start_indices,
            transfer_type=element_type,
        )

    amdgpu_d.lds_barrier()
83 changes: 83 additions & 0 deletions iree/turbine/kernel/wave/gather_to_shared.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# Copyright 2025 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from .._support.tracing import CapturedTrace
from ..lang.global_symbols import *
from ..ops.wave_ops import GatherToLDS, Write, get_custom
from ..wave.constraints import (
Constraint,
)
from ..wave.utils.run_utils import get_default_arch
from .utils.general_utils import get_fastest_index, is_valid_global_read
from .utils.graph_utils import DCE
from .utils.mapping_utils import transform_index_on_mapping
from .utils.symbol_utils import (
subs_idxc,
)


# Architectures on which the direct global-to-LDS load performed by the
# gather_to_shared pass is documented to be supported.
gather_to_shared_supported_arch = ["gfx950"]


def get_write_node_consumers(read_custom):
    """
    Return the users of `read_custom` that are `Write` ops whose memory
    lives in the shared address space, in the order they appear in
    `read_custom.users`.
    """
    return [
        consumer
        for consumer in read_custom.users
        if isinstance(consumer, Write)
        and subs_idxc(consumer.memory_type.address_space) == SHARED_ADDRESS_SPACE
    ]


def gather_to_shared(trace: CapturedTrace, constraints: list[Constraint]):
    """
    Replace global-read / shared-write pairs with direct global-to-LDS loads.

    Every global read that feeds a write to shared memory is rewritten so
    that each such write becomes a single `GatherToLDS` op, avoiding the
    round trip through registers. The original write nodes and the read
    node are erased afterwards.

    The instruction is only supported on the architectures listed in
    `gather_to_shared_supported_arch`; on any other default architecture the
    pass returns without touching the graph.

    Args:
        trace: Captured trace whose nodes are walked and rewritten in place.
        constraints: Constraints of the kernel (currently unused by this
            pass).
    """
    # The direct-to-LDS load does not exist on other architectures, so the
    # rewrite must not run there.
    if get_default_arch() not in gather_to_shared_supported_arch:
        return

    for read_node in trace.walk(is_valid_global_read):
        read_custom = get_custom(read_node)
        write_consumers = get_write_node_consumers(read_custom)
        if not write_consumers:
            continue

        for write_custom in write_consumers:
            with write_custom.graph.inserting_before(write_custom.fx_node):
                write_custom.replace_all_uses_with(
                    GatherToLDS(
                        read_custom.memory,
                        read_custom.index,
                        read_custom.type,
                        write_custom.memory,
                        # The destination index belongs to the write op
                        # itself, mirroring `read_custom.index` on the source
                        # side, not to the memory it writes into.
                        write_custom.index,
                        write_custom.type,
                        read_custom.mapping,
                        write_custom.mapping,
                        read_custom.elements_per_thread,
                    ).add_to_graph(write_custom.graph)
                )
            write_custom.erase()

        # Erase the read only once, after every shared write that used it
        # has been replaced.
        read_custom.erase()
4 changes: 3 additions & 1 deletion iree/turbine/kernel/wave/global_to_shared_gathers.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,15 @@
from .utils.symbol_utils import subs_idxc
from .utils.general_utils import is_gather
from .minimize_global_loads import (
has_write_shared_user,
construct_min_global_access_pattern,
materialize_shape,
identify_optimizable_loads,
update_write_dependencies,
SharedReadMetadata,
)
from .utils.general_utils import (
has_write_shared_user,
)

"""
We are given N global gathers that are promoted to shared memory. This function
Expand Down
22 changes: 4 additions & 18 deletions iree/turbine/kernel/wave/minimize_global_loads.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
TilingConstraint,
)
from .._support.tracing import CapturedTrace
from .._support.indexing import IndexingContext, IndexSequence, IndexSymbol, IndexExpr
from .._support.indexing import IndexSequence, IndexSymbol, IndexExpr
from ..ops.wave_ops import Read, Write, get_custom
from ..lang.global_symbols import *
from .utils.general_utils import (
Expand All @@ -21,6 +21,9 @@
is_shared_read,
get_fastest_index,
)
from .utils.general_utils import (
is_valid_global_read,
)
from .utils.graph_utils import (
DCE,
)
Expand All @@ -42,23 +45,6 @@ class SharedReadMetadata:
memory_shape: tuple[int | IndexExpr]


def has_write_shared_user(node: Read) -> bool:
    """
    Return True when at least one user of `node` is a `Write` whose memory
    is in the shared address space.
    """
    for consumer in node.users:
        if (
            isinstance(consumer, Write)
            and subs_idxc(consumer.memory_type.address_space)
            == SHARED_ADDRESS_SPACE
        ):
            return True
    return False


def is_valid_global_read(node: fx.Node) -> bool:
    """
    Return True when `node` is a `Read` from the global address space that
    has at least one user writing to shared memory.
    """
    custom = get_custom(node)
    if not isinstance(custom, Read):
        return False
    if not (subs_idxc(custom.memory_type.address_space) == GLOBAL_ADDRESS_SPACE):
        return False
    return has_write_shared_user(custom)


def is_transposed_read(custom: Read) -> bool:
"""
Checks whether or not we are doing a transposed read.
Expand Down
4 changes: 2 additions & 2 deletions iree/turbine/kernel/wave/promotion.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,8 @@ def apply_promotion_pattern(
```
read_from_global lhs
write_to_shared lhs
read_from_global lhs
write_to_shared lhs
read_from_global rhs
write_to_shared rhs
shared_barrier
read_from_shared lhs
read_from_shared rhs
Expand Down
23 changes: 22 additions & 1 deletion iree/turbine/kernel/wave/utils/general_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,13 @@
import os
import sympy
import torch
import torch.fx as fx
from typing import Any, Callable, Optional


from ..._support.indexing import IndexExpr, IndexSequence, IndexSymbol
from ...lang.global_symbols import *
from ...ops.wave_ops import CustomOp, Read, Iterate, Write
from ...ops.wave_ops import CustomOp, Read, Iterate, Write, get_custom
from ..assumptions import Assumption
from ..constraints import (
Constraint,
Expand Down Expand Up @@ -375,6 +376,26 @@ def is_shared_read(node: CustomOp) -> bool:
)


def has_write_shared_user(node: Read) -> bool:
    """
    Tell whether any user of `node` is a `Write` op targeting the shared
    address space.
    """
    for user in node.users:
        if not isinstance(user, Write):
            continue
        if subs_idxc(user.memory_type.address_space) == SHARED_ADDRESS_SPACE:
            return True
    return False


def is_valid_global_read(node: fx.Node) -> bool:
    """
    Decide whether `node` is a read from global memory that feeds a write
    to shared memory.

    The node qualifies only if its custom op is a `Read`, the read's memory
    resolves to the global address space, and at least one of its users
    writes to the shared address space.
    """
    read_op = get_custom(node)
    if not isinstance(read_op, Read):
        return False
    reads_global = (
        subs_idxc(read_op.memory_type.address_space) == GLOBAL_ADDRESS_SPACE
    )
    if not reads_global:
        return False
    return has_write_shared_user(read_op)


def is_gather(custom: CustomOp) -> bool:
if not isinstance(custom, Read):
return False
Expand Down
1 change: 1 addition & 0 deletions iree/turbine/kernel/wave/utils/graph_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import iree.turbine.kernel.lang as tkl
from ...ops.wave_ops import (
get_custom,
Read,
Write,
NestedRegionOp,
Output,
Expand Down
4 changes: 3 additions & 1 deletion iree/turbine/kernel/wave/wave.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from ..lang import Grid, IndexMapping
from ..lang.global_symbols import *
from ..ops import wave_ops
from ..ops.wave_ops import Iterate, CustomOp, get_custom, IterArg
from ..ops.wave_ops import Iterate, CustomOp, get_custom
from .._support.indexing import IndexingContext, IndexExpr, index_symbol
from .symbolic_constraints import SymbolicAlias
from .._support.tracing import (
Expand Down Expand Up @@ -51,6 +51,7 @@
from .decompose_scan_ops import decompose_scan_ops
from .decompose_dot_mma import decompose_dot_mma
from .expansion.expansion import expand_graph, add_get_results
from .gather_to_shared import gather_to_shared
from .global_to_shared_gathers import global_to_shared_gathers
from .hoisting import hoist_loop_invariant_ops
from .minimize_global_loads import minimize_global_loads
Expand Down Expand Up @@ -559,6 +560,7 @@ def _trace_and_get_kernel_signature(
partial(global_to_shared_gathers, trace, self.constraints),
partial(minimize_global_loads, trace, self.constraints),
partial(apply_shared_memory_indexing_corrections, trace, self.constraints),
partial(gather_to_shared, trace, self.constraints),
]

# Partition strided operators.
Expand Down
Loading