|
| 1 | +# Copyright 2025 The IREE Authors |
| 2 | +# |
| 3 | +# Licensed under the Apache License v2.0 with LLVM Exceptions. |
| 4 | +# See https://llvm.org/LICENSE.txt for license information. |
| 5 | +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | + |
| 7 | +from .._support.tracing import CapturedTrace |
| 8 | +from ..lang.global_symbols import * |
| 9 | +from ..ops.wave_ops import GatherToLDS, Write, get_custom |
| 10 | +from ..wave.constraints import ( |
| 11 | + Constraint, |
| 12 | +) |
| 13 | +from ..wave.utils.run_utils import get_default_arch |
| 14 | +from .utils.graph_utils import DCE, is_valid_global_read |
| 15 | +from .utils.symbol_utils import ( |
| 16 | + subs_idxc, |
| 17 | +) |
| 18 | + |
| 19 | + |
| 20 | +gather_to_shared_supported_arch = ["gfx950"] |
| 21 | + |
| 22 | + |
| 23 | +def get_write_node_info(read_node): |
| 24 | + read_custom = get_custom(read_node) |
| 25 | + write_node = None |
| 26 | + for user in read_custom.users: |
| 27 | + if ( |
| 28 | + isinstance(user, Write) |
| 29 | + and subs_idxc(user.memory_type.address_space) == SHARED_ADDRESS_SPACE |
| 30 | + ): |
| 31 | + # get memory location and idx |
| 32 | + dst = user.memory |
| 33 | + dst_idx = user.get_derived_indices |
| 34 | + write_node = user |
| 35 | + |
| 36 | + return write_node, dst, dst_idx |
| 37 | + |
| 38 | + |
| 39 | +def gather_to_shared(trace: CapturedTrace, constraints: list[Constraint]): |
| 40 | + """ |
| 41 | + This function enables direct memory load from global to lds without |
| 42 | + passing through register reducing the data movement. This instruction |
| 43 | + is supported only on specific architectures (gfx950). |
| 44 | + """ |
| 45 | + |
| 46 | + if get_default_arch() not in gather_to_shared_supported_arch: |
| 47 | + return |
| 48 | + |
| 49 | + global_read_nodes = trace.walk(is_valid_global_read) |
| 50 | + for read_node in global_read_nodes: |
| 51 | + custom = get_custom(read_node) |
| 52 | + src = custom.memory |
| 53 | + src_idx = custom.get_derived_indices |
| 54 | + element_type = custom.type.dtype |
| 55 | + write_node, dst, dst_idx = get_write_node_info(read_node) |
| 56 | + write_node.replace_all_uses_with( |
| 57 | + GatherToLDS(src, src_idx, dst, dst_idx, element_type) |
| 58 | + ) |
| 59 | + |
| 60 | + DCE(trace) |
0 commit comments