Skip to content
Open
Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
ba7028b
Add initial cache modifiers code and docs
mawad-amd Sep 10, 2025
276713b
Add test
mawad-amd Sep 10, 2025
6f6818f
Apply Ruff auto-fixes
github-actions[bot] Sep 10, 2025
9ad63a0
Use `None` for default value
mawad-amd Sep 13, 2025
677c966
Apply Ruff auto-fixes
github-actions[bot] Sep 13, 2025
af3592d
Cleanup the test
mawad-amd Sep 14, 2025
8a411a2
Apply Ruff auto-fixes
github-actions[bot] Sep 14, 2025
ff26f96
Merge branch 'main' into muhaawad/cache-modifiers
mawad-amd Sep 14, 2025
8f76d95
Check return value
mawad-amd Sep 14, 2025
87fb74a
Remove volatile from store
mawad-amd Sep 14, 2025
162ec39
Add test store
mawad-amd Sep 14, 2025
01da6ca
Apply Ruff auto-fixes
github-actions[bot] Sep 14, 2025
9a27ead
Add put/get modifiers
mawad-amd Sep 14, 2025
99ee66c
Add tests for put and get cache modifiers
mawad-amd Sep 14, 2025
74d0133
Apply Ruff auto-fixes
github-actions[bot] Sep 14, 2025
e76d4c5
Test default values
mawad-amd Sep 14, 2025
451ee99
Fix default value docstring
mawad-amd Sep 14, 2025
b524f40
Fix tests
mawad-amd Sep 14, 2025
b8bd8a7
Apply Ruff auto-fixes
github-actions[bot] Sep 14, 2025
0a157b5
Sync cache modifiers branch with main and add cache modifiers to copy…
Copilot Oct 11, 2025
b127b91
Merge branch 'main' into muhaawad/cache-modifiers
mawad-amd Oct 11, 2025
c9f314f
Merge branch 'main' into muhaawad/cache-modifiers
mawad-amd Oct 24, 2025
e74aacd
Fix device mismatch in test_copy_cache_modifiers assertions (#271)
Copilot Oct 29, 2025
88970ee
Fix pointer arithmetic in test_copy_cache_modifiers (#273)
Copilot Oct 30, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 68 additions & 10 deletions iris/iris.py
Original file line number Diff line number Diff line change
Expand Up @@ -1464,7 +1464,7 @@ def __translate(ptr, from_rank, to_rank, heap_bases):


@triton.jit
def load(pointer, to_rank, from_rank, heap_bases, mask=None):
def load(pointer, to_rank, from_rank, heap_bases, mask=None, cache_modifier=None, volatile=False):
"""
Loads a value from the specified rank's memory location.

Expand All @@ -1473,12 +1473,28 @@ def load(pointer, to_rank, from_rank, heap_bases, mask=None):
data from the target memory location. If the from_rank and to_rank are the same,
this function performs a local load operation.

The `cache_modifier` parameter controls instruction-level cache behavior
by setting the appropriate scope (`SC0`, `SC1`) and non-temporal (`NT`) bits
in the global load instruction. These affect cache usage across the CU,
L2, and last-level caches.

Args:
pointer (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the from_rank's address space that will be translated to the to_rank's address space. Must be the current rank where the pointer is local.
to_rank (int): The rank ID to which the pointer will be translated. Must be the current rank where the pointer is local.
from_rank (int): The rank ID from which to read the data.
heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks.
mask (Block of triton.int1, optional): If mask[idx] is false, do not load the data at address pointer[idx]. Defaults to None.
cache_modifier (str, optional): Controls cache behavior of the load.

Supported values:
- None: *(default)* — Same as ".ca". Uses cache at all levels (CU, L2, LLC) with LRU policy.
- ".ca": Cache at all levels (CU, L2, LLC) with LRU policy
- ".cg": Bypasses the CU (L1) cache, streams through L2, and may hit in LLC but the line is not retained or inserted.
- ".cv": Bypasses all GPU caches (CU and L2) and fetches directly from system memory. If data exists in the LLC, it may hit, but is not retained or inserted.
Ensures global coherence by invalidating stale GPU cache lines.

volatile (bool, optional): If True, disables compiler optimizations that
could reorder or eliminate the load.

Returns:
Block: The loaded value from the target memory location.
Expand All @@ -1493,12 +1509,12 @@ def load(pointer, to_rank, from_rank, heap_bases, mask=None):
>>> return data
"""
translated_ptr = __translate(pointer, to_rank, from_rank, heap_bases)
result = tl.load(translated_ptr, mask=mask)
result = tl.load(translated_ptr, mask=mask, cache_modifier=cache_modifier, volatile=volatile)
return result


@triton.jit
def store(pointer, value, from_rank, to_rank, heap_bases, mask=None):
def store(pointer, value, from_rank, to_rank, heap_bases, mask=None, cache_modifier=None):
"""
Writes data to the specified rank's memory location.

Expand All @@ -1507,13 +1523,25 @@ def store(pointer, value, from_rank, to_rank, heap_bases, mask=None):
the provided data to the target memory location. If the from_rank and to_rank are the same,
this function performs a local store operation.

The `cache_modifier` parameter controls instruction-level cache behavior
by setting the appropriate scope (`SC0`, `SC1`) and non-temporal (`NT`) bits
in the global store instruction. These affect cache usage across the CU (L1),
L2, and last-level cache (LLC), following the CDNA ISA.

Args:
pointer (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the from_rank's address space that will be translated to the to_rank's address space. Must be the current rank where the pointer is local.
value (Block): The tensor of elements to be stored.
from_rank (int): The rank ID from which the pointer originates. Must be the current rank where the pointer is local.
to_rank (int): The rank ID to which the data will be written.
heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks.
mask (Block of triton.int1, optional): If mask[idx] is false, do not store the data at address pointer[idx]. Defaults to None.
cache_modifier (str, optional): Controls cache behavior of the store. Supported values are:

- None: *(default)* — Same as ".wb". Uses write-back caching at all levels (CU, L2, LLC) with LRU policy.
- ".wb": Write-back. Write-allocate on L1 miss, inserted into caches and written back later.
- ".cg": Cache Global. Equivalent to ".wb" — stored through L1 → L2 → LLC under LRU.
- ".cs": Cache Streaming. Bypasses L1, streamed through L2, not retained in LLC.
- ".wt": Write-Through. Bypasses L1 and L2 (coherent cache bypass), may hit in LLC with LRU.

Returns:
None
Expand All @@ -1528,11 +1556,13 @@ def store(pointer, value, from_rank, to_rank, heap_bases, mask=None):
>>> iris.store(ptr, value, cur_rank, remote_rank, heap_bases)
"""
translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases)
tl.store(translated_ptr, value, mask=mask)
tl.store(translated_ptr, value, mask=mask, cache_modifier=cache_modifier)


@triton.jit
def get(from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None):
def get(
from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None, load_cache_modifier=None, store_cache_modifier=None
):
"""
Copies data from the specified rank's memory to the current rank's local memory.

Expand All @@ -1549,6 +1579,19 @@ def get(from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None):
heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks.
mask (Block of triton.int1, optional): If mask[idx] is false, do not load the data at address from_ptr[idx] and do not store to to_ptr[idx]. Defaults to None.

load_cache_modifier (str, optional): Controls cache behavior of the load. Supported values are:
- None: *(default)* — Same as ".ca". Uses cache at all levels (CU, L2, LLC) with LRU policy.
- ".ca": Cache at all levels (CU, L2, LLC) with LRU policy.
- ".cg": Bypasses the CU (L1) cache, streams through L2, and may hit in LLC but the line is not retained or inserted.
- ".cv": Bypasses all GPU caches (CU and L2) and fetches directly from system memory. If data exists in the LLC, it may hit, but is not retained or inserted.

store_cache_modifier (str, optional): Controls cache behavior of the store. Supported values are:
- None: *(default)* — Same as ".wb". Uses write-back caching at all levels (CU, L2, LLC) with LRU policy.
- ".wb": Write-back. Write-allocate on L1 miss, inserted into caches and written back later.
- ".cg": Cache Global. Equivalent to ".wb" — stored through L1 → L2 → LLC under LRU.
- ".cs": Cache Streaming. Bypasses L1, streamed through L2, not retained in LLC.
- ".wt": Write-Through. Bypasses L1 and L2 (coherent cache bypass), may hit in LLC with LRU.

Returns:
None

Expand All @@ -1561,13 +1604,15 @@ def get(from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None):
"""
translated_from_ptr = __translate(from_ptr, from_rank, to_rank, heap_bases)

data = tl.load(translated_from_ptr, mask=mask)
data = tl.load(translated_from_ptr, mask=mask, cache_modifier=load_cache_modifier)

tl.store(to_ptr, data, mask=mask)
tl.store(to_ptr, data, mask=mask, cache_modifier=store_cache_modifier)


@triton.jit
def put(from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None):
def put(
from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None, load_cache_modifier=None, store_cache_modifier=None
):
"""
Copies data from the current rank's local memory to the specified rank's memory.
This function performs a memory write operation by loading data from the current
Expand All @@ -1583,6 +1628,19 @@ def put(from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None):
heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks.
mask (Block of triton.int1, optional): If mask[idx] is false, do not load the data at address from_ptr[idx] and do not store to to_ptr[idx]. Defaults to None.

load_cache_modifier (str, optional): Controls cache behavior of the load. Supported values are:
- None: *(default)* — Same as ".ca". Uses cache at all levels (CU, L2, LLC) with LRU policy.
- ".ca": Cache at all levels (CU, L2, LLC) with LRU policy.
- ".cg": Bypasses the CU (L1) cache, streams through L2, and may hit in LLC but the line is not retained or inserted.
- ".cv": Bypasses all GPU caches (CU and L2) and fetches directly from system memory. If data exists in the LLC, it may hit, but is not retained or inserted.

store_cache_modifier (str, optional): Controls cache behavior of the store. Supported values are:
- None: *(default)* — Same as ".wb". Uses write-back caching at all levels (CU, L2, LLC) with LRU policy.
- ".wb": Write-back. Write-allocate on L1 miss, inserted into caches and written back later.
- ".cg": Cache Global. Equivalent to ".wb" — stored through L1 → L2 → LLC under LRU.
- ".cs": Cache Streaming. Bypasses L1, streamed through L2, not retained in LLC.
- ".wt": Write-Through. Bypasses L1 and L2 (coherent cache bypass), may hit in LLC with LRU.

Returns:
None

Expand All @@ -1595,9 +1653,9 @@ def put(from_ptr, to_ptr, from_rank, to_rank, heap_bases, mask=None):
"""
translated_to_ptr = __translate(to_ptr, from_rank, to_rank, heap_bases)

data = tl.load(from_ptr, mask=mask)
data = tl.load(from_ptr, mask=mask, cache_modifier=load_cache_modifier)

tl.store(translated_to_ptr, data, mask=mask)
tl.store(translated_to_ptr, data, mask=mask, cache_modifier=store_cache_modifier)


@triton.jit
Expand Down
17 changes: 16 additions & 1 deletion tests/run_tests_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ def _distributed_worker(rank, world_size, test_file, pytest_args):
try:
# Run pytest directly in this process
exit_code = pytest.main([test_file] + pytest_args)
# If tests failed, exit with the failure code
if exit_code != 0:
sys.exit(exit_code)
return exit_code
finally:
# Restore original argv
Expand Down Expand Up @@ -82,7 +85,19 @@ def main():
print(f"args={args}, test_file={test_file}, pytest_args={pytest_args}")

# Run all tests within a single distributed process group
mp.spawn(_distributed_worker, args=(num_ranks, test_file, pytest_args), nprocs=num_ranks, join=True)
try:
mp.spawn(
_distributed_worker,
args=(num_ranks, test_file, pytest_args),
nprocs=num_ranks,
join=True,
)
except SystemExit as e:
# Catch sys.exit() from worker and return same exit code
sys.exit(e.code if isinstance(e.code, int) else 1)
except Exception:
# Any other unhandled exception = failure
sys.exit(1)


if __name__ == "__main__":
Expand Down
111 changes: 111 additions & 0 deletions tests/unittests/test_get_cache_modifiers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
# SPDX-License-Identifier: MIT
# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.

import torch
import triton
import triton.language as tl
import pytest
import iris
from itertools import product


@triton.jit
def get_kernel(
    data,
    results,
    cur_rank: tl.constexpr,
    num_ranks: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    heap_bases: tl.tensor,
    load_cache_modifier: tl.constexpr,
    store_cache_modifier: tl.constexpr,
):
    # Test kernel: for every rank, iris.get() copies that rank's `data`
    # buffer into local `results`, and the values are accumulated so the
    # host can check that all num_ranks transfers happened.
    #
    # load_cache_modifier / store_cache_modifier are constexpr strings (or
    # None); the if/elif ladder below exists so that a None parameter is
    # simply omitted from the iris.get() call, exercising iris.get's own
    # default for that argument rather than passing None through.
    pid = tl.program_id(0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # NOTE(review): bounded by BLOCK_SIZE, not a total element count, so the
    # mask is all-true only for pid == 0 — the host launches a grid of (1,),
    # which makes this equivalent to no masking. Confirm if grid ever grows.
    mask = offsets < BLOCK_SIZE

    # Accumulator in the element type of `data` (float32 in the host test).
    acc = tl.zeros([BLOCK_SIZE], dtype=data.type.element_ty)

    # Loop over all ranks, get the stored data with cache modifiers
    # We test default values set by the function when parameters are None
    for target_rank in range(num_ranks):
        if load_cache_modifier is None and store_cache_modifier is None:
            # Both modifiers omitted: iris.get uses its defaults for both.
            iris.get(data + offsets, results + offsets, cur_rank, target_rank, heap_bases, mask=mask)
        elif load_cache_modifier is None:
            # Only the store modifier is forwarded; load uses the default.
            iris.get(
                data + offsets,
                results + offsets,
                cur_rank,
                target_rank,
                heap_bases,
                mask=mask,
                store_cache_modifier=store_cache_modifier,
            )
        elif store_cache_modifier is None:
            # Only the load modifier is forwarded; store uses the default.
            iris.get(
                data + offsets,
                results + offsets,
                cur_rank,
                target_rank,
                heap_bases,
                mask=mask,
                load_cache_modifier=load_cache_modifier,
            )
        else:
            # Both modifiers explicitly forwarded.
            iris.get(
                data + offsets,
                results + offsets,
                cur_rank,
                target_rank,
                heap_bases,
                mask=mask,
                load_cache_modifier=load_cache_modifier,
            store_cache_modifier=store_cache_modifier,
            )
        # Re-read what iris.get just stored into `results` and accumulate.
        acc += tl.load(results + offsets, mask=mask)

    # Store the accumulated value back to the output
    tl.store(results + offsets, acc, mask=mask)


# Define cache modifiers for load and store operations.
# None exercises the iris.get default-argument path (the kwarg is omitted);
# "" passes an explicit empty modifier; the dotted strings are the
# Triton cache-modifier spellings accepted by tl.load / tl.store.
LOAD_CACHE_MODIFIERS = [None, "", ".ca", ".cg", ".cv"]
STORE_CACHE_MODIFIERS = [None, "", ".wb", ".cg", ".cs", ".wt"]


@pytest.mark.parametrize(
    "load_cache_modifier,store_cache_modifier", list(product(LOAD_CACHE_MODIFIERS, STORE_CACHE_MODIFIERS))
)
def test_get_cache_modifiers(load_cache_modifier, store_cache_modifier):
    """Test get (copy from other rank) with various cache modifiers.

    Each rank's `data` buffer holds all ones; the kernel gets every rank's
    buffer (including its own) and accumulates, so every element of
    `results` must equal num_ranks. A None modifier exercises the default
    path inside ``iris.get``.
    """
    shmem = iris.iris(1 << 20)
    num_ranks = shmem.get_num_ranks()
    heap_bases = shmem.get_heap_bases()
    cur_rank = shmem.get_rank()

    BLOCK_SIZE = 16
    data = shmem.ones(BLOCK_SIZE, dtype=torch.float32)
    results = shmem.zeros_like(data)

    # All ranks must have initialized their buffers before any remote read.
    shmem.barrier()

    grid = lambda meta: (1,)
    get_kernel[grid](
        data, results, cur_rank, num_ranks, BLOCK_SIZE, heap_bases, load_cache_modifier, store_cache_modifier
    )
    # All remote operations must complete before validation.
    shmem.barrier()

    # Verify the result - should get data from all ranks (including self).
    # Build `expected` on results.device rather than a hard-coded "cuda":
    # in multi-GPU runs each rank's buffers live on its own device (cuda:N),
    # and assert_close rejects cross-device comparisons.
    expected = torch.ones(BLOCK_SIZE, dtype=torch.float32, device=results.device) * num_ranks

    try:
        torch.testing.assert_close(results, expected, rtol=0, atol=0)
    except AssertionError as e:
        # Echo the failing parameter combination before re-raising so the
        # parametrized failure is easy to identify in distributed logs.
        print(
            f"GET test failed with load_cache_modifier={load_cache_modifier}, store_cache_modifier={store_cache_modifier}"
        )
        print(e)
        print("Expected:", expected)
        print("Actual:", results)
        raise
82 changes: 82 additions & 0 deletions tests/unittests/test_load_cache_modifiers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# SPDX-License-Identifier: MIT
# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.

import torch
import triton
import triton.language as tl
import pytest
import iris
from itertools import product


@triton.jit
def kernel(
    data,
    results,
    source_rank: tl.constexpr,
    num_ranks: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    heap_bases: tl.tensor,
    cache_modifier: tl.constexpr,
    volatile: tl.constexpr,
):
    # Test kernel: load one block from the partner rank's `data` buffer
    # via iris.load (optionally with a cache modifier / volatile flag) and
    # store it into the local `results` buffer for host-side verification.
    pid = tl.program_id(0)

    # Pair ranks half a ring apart; must match the host-side `partner`
    # computation in test_load_cache_modifiers so the expected value agrees.
    partner = int((source_rank + num_ranks // 2) % num_ranks)
    # Compute start index of this block
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)

    # Guard for out-of-bounds accesses
    # NOTE(review): bound is BLOCK_SIZE, not a total element count, so this
    # is all-true only for pid == 0 — the host launches grid (1,).
    mask = offsets < BLOCK_SIZE

    # A None constexpr means "omit the kwarg" so iris.load's own default
    # cache modifier is exercised rather than an explicit None being passed.
    if cache_modifier is None:
        result = iris.load(data + offsets, source_rank, partner, heap_bases, mask=mask, volatile=volatile)
    else:
        result = iris.load(
            data + offsets,
            source_rank,
            partner,
            heap_bases,
            mask=mask,
            cache_modifier=cache_modifier,
            volatile=volatile,
        )

    tl.store(results + offsets, result, mask=mask)


# Define cache modifiers and volatile options.
# None exercises iris.load's default-argument path (kwarg omitted); "" is an
# explicit empty modifier; the dotted strings are Triton load modifiers.
CACHE_MODIFIERS = [None, "", ".ca", ".cg", ".cv"]
VOLATILE_OPTIONS = [False, True]


@pytest.mark.parametrize("cache_modifier,volatile", list(product(CACHE_MODIFIERS, VOLATILE_OPTIONS)))
def test_load_cache_modifiers(cache_modifier, volatile):
    """Test load with various cache modifiers and volatile settings.

    Each rank fills `data` with its own rank id; the kernel loads the
    partner rank's buffer, so every element of `results` must equal the
    partner's rank id. A None modifier exercises the default path inside
    ``iris.load``.
    """
    shmem = iris.iris(1 << 20)
    num_ranks = shmem.get_num_ranks()
    heap_bases = shmem.get_heap_bases()
    source_rank = shmem.get_rank()
    # Must match the partner computation inside the kernel.
    partner = int((source_rank + num_ranks // 2) % num_ranks)

    BLOCK_SIZE = 16
    data = shmem.full((BLOCK_SIZE,), source_rank, dtype=torch.float32)
    results = shmem.zeros_like(data)

    # All ranks must have initialized their buffers before any remote read.
    shmem.barrier()

    grid = lambda meta: (1,)
    kernel[grid](data, results, source_rank, num_ranks, BLOCK_SIZE, heap_bases, cache_modifier, volatile)
    # All remote loads must complete before validation.
    shmem.barrier()

    # Verify the result.
    # Build `expected` on results.device rather than a hard-coded "cuda":
    # in multi-GPU runs each rank's buffers live on its own device (cuda:N),
    # and assert_close rejects cross-device comparisons.
    expected = torch.ones(BLOCK_SIZE, dtype=torch.float32, device=results.device) * partner

    try:
        torch.testing.assert_close(results, expected, rtol=0, atol=0)
    except AssertionError as e:
        print(e)
        print("Expected:", expected)
        print("Actual:", results)
        raise
Loading