Open
Changes from 3 commits
Commits
24 commits
ba7028b
Add initial cache modifiers code and docs
mawad-amd Sep 10, 2025
276713b
Add test
mawad-amd Sep 10, 2025
6f6818f
Apply Ruff auto-fixes
github-actions[bot] Sep 10, 2025
9ad63a0
Use `None` for default value
mawad-amd Sep 13, 2025
677c966
Apply Ruff auto-fixes
github-actions[bot] Sep 13, 2025
af3592d
Cleanup the test
mawad-amd Sep 14, 2025
8a411a2
Apply Ruff auto-fixes
github-actions[bot] Sep 14, 2025
ff26f96
Merge branch 'main' into muhaawad/cache-modifiers
mawad-amd Sep 14, 2025
8f76d95
Check return value
mawad-amd Sep 14, 2025
87fb74a
Remove volatile from store
mawad-amd Sep 14, 2025
162ec39
Add test store
mawad-amd Sep 14, 2025
01da6ca
Apply Ruff auto-fixes
github-actions[bot] Sep 14, 2025
9a27ead
Add put/get modifiers
mawad-amd Sep 14, 2025
99ee66c
Add tests for put and get cache modifiers
mawad-amd Sep 14, 2025
74d0133
Apply Ruff auto-fixes
github-actions[bot] Sep 14, 2025
e76d4c5
Test default values
mawad-amd Sep 14, 2025
451ee99
Fix default value docstring
mawad-amd Sep 14, 2025
b524f40
Fix tests
mawad-amd Sep 14, 2025
b8bd8a7
Apply Ruff auto-fixes
github-actions[bot] Sep 14, 2025
0a157b5
Sync cache modifiers branch with main and add cache modifiers to copy…
Copilot Oct 11, 2025
b127b91
Merge branch 'main' into muhaawad/cache-modifiers
mawad-amd Oct 11, 2025
c9f314f
Merge branch 'main' into muhaawad/cache-modifiers
mawad-amd Oct 24, 2025
e74aacd
Fix device mismatch in test_copy_cache_modifiers assertions (#271)
Copilot Oct 29, 2025
88970ee
Fix pointer arithmetic in test_copy_cache_modifiers (#273)
Copilot Oct 30, 2025
39 changes: 35 additions & 4 deletions iris/iris.py
@@ -1361,7 +1361,7 @@ def __translate(ptr, from_rank, to_rank, heap_bases):


@triton.jit
-def load(pointer, to_rank, from_rank, heap_bases, mask=None):
+def load(pointer, to_rank, from_rank, heap_bases, mask=None, cache_modifier="", volatile=False):
"""
Loads a value from the specified rank's memory location.

@@ -1370,23 +1370,39 @@ def load(pointer, to_rank, from_rank, heap_bases, mask=None):
data from the target memory location. If the from_rank and to_rank are the same,
this function performs a local load operation.

The `cache_modifier` parameter controls instruction-level cache behavior
by setting the appropriate scope (`SC0`, `SC1`) and non-temporal (`NT`) bits
in the global load instruction. These bits affect cache usage across the CU (L1),
L2, and last-level cache (LLC), following the CDNA ISA.

Args:
pointer (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the to_rank's address space that will be translated to the from_rank's address space. The pointer must be local to the current rank.
to_rank (int): The rank ID in whose address space the pointer is valid. Must be the current rank, where the pointer is local.
from_rank (int): The rank ID from which to read the data.
heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks.
mask (Block of triton.int1, optional): If mask[idx] is false, do not load the data at address pointer[idx]. Defaults to None.
cache_modifier (str, optional): Controls cache behavior of the load. Defaults to "".

Supported values:
- "": Default. Same as ".ca".
- ".ca": Caches at all levels (CU, L2, LLC) with LRU policy.
- ".cg": Bypasses the CU (L1) cache and streams through L2. May hit in the LLC, but the line is not retained or inserted.
- ".cv": Bypasses all GPU caches (CU and L2) and fetches directly from system memory. May hit in the LLC if the data is present, but the line is not retained or inserted. Ensures global coherence by invalidating stale GPU cache lines.

volatile (bool, optional): If True, disables compiler optimizations that
could reorder or eliminate the load. Defaults to False.

Returns:
Block: The loaded value from the target memory location.
"""
translated_ptr = __translate(pointer, to_rank, from_rank, heap_bases)
-result = tl.load(translated_ptr, mask=mask)
+result = tl.load(translated_ptr, mask=mask, cache_modifier=cache_modifier, volatile=volatile)
return result
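
The argument order passed to `__translate` is easy to misread, so here is a minimal pure-Python sketch of what `load` and `store` appear to do with their rank arguments. It assumes a symmetric heap in which translation keeps the offset and swaps the base address; the body of `__translate` is not shown in this diff, so `translate`, `model_load`, `model_store`, the base addresses, and the `memory` dict are all hypothetical stand-ins rather than the library's code.

```python
# Pure-Python model of symmetric-heap pointer translation.
# ASSUMPTION: translation keeps the offset and swaps the heap base; the real
# `__translate` body is not part of this diff.


def translate(ptr, from_rank, to_rank, heap_bases):
    """Map an address in from_rank's heap to the same offset in to_rank's heap."""
    offset = ptr - heap_bases[from_rank]
    return heap_bases[to_rank] + offset


def model_load(ptr, to_rank, from_rank, heap_bases, memory):
    # Mirrors load(): the pointer is local to to_rank, the data is read
    # from from_rank, so the ranks are passed to translate() swapped.
    return memory[translate(ptr, to_rank, from_rank, heap_bases)]


def model_store(ptr, value, from_rank, to_rank, heap_bases, memory):
    # Mirrors store(): the pointer is local to from_rank, the data is
    # written into to_rank's heap.
    memory[translate(ptr, from_rank, to_rank, heap_bases)] = value


heap_bases = [0x1000, 0x8000]  # made-up heap base addresses for ranks 0 and 1
memory = {}  # flat address -> value map standing in for device memory

local_ptr = heap_bases[0] + 0x10  # rank 0's view of a symmetric buffer

# Rank 0 writes into rank 1's copy, then reads it back from rank 1.
model_store(local_ptr, 42, from_rank=0, to_rank=1, heap_bases=heap_bases, memory=memory)
value = model_load(local_ptr, to_rank=0, from_rank=1, heap_bases=heap_bases, memory=memory)
```

In this model both calls use the same local pointer and resolve to the same remote address, which is why `load` passes `to_rank` before `from_rank` when it calls `__translate`.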


@triton.jit
-def store(pointer, value, from_rank, to_rank, heap_bases, mask=None):
+def store(pointer, value, from_rank, to_rank, heap_bases, mask=None, cache_modifier="", volatile=False):
"""
Writes data to the specified rank's memory location.

@@ -1395,19 +1411,34 @@ def store(pointer, value, from_rank, to_rank, heap_bases, mask=None):
the provided data to the target memory location. If the from_rank and to_rank are the same,
this function performs a local store operation.

The `cache_modifier` parameter controls instruction-level cache behavior
by setting the appropriate scope (`SC0`, `SC1`) and non-temporal (`NT`) bits
in the global store instruction. These affect cache usage across the CU (L1),
L2, and last-level cache (LLC), following the CDNA ISA.

Args:
pointer (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the from_rank's address space that will be translated to the to_rank's address space. Must be the current rank where the pointer is local.
value (Block): The tensor of elements to be stored.
from_rank (int): The rank ID from which the pointer originates. Must be the current rank where the pointer is local.
to_rank (int): The rank ID to which the data will be written.
heap_bases (triton.PointerType): Array containing the heap base addresses for all ranks.
mask (Block of triton.int1, optional): If mask[idx] is false, do not store the data at address pointer[idx]. Defaults to None.
cache_modifier (str, optional): Controls cache behavior of the store. Defaults to "".

Supported values:
- "": Default. Same as ".wb".
- ".wb": Write-back. Write-allocates on an L1 miss; the line is inserted into the caches (CU, L2, LLC) under LRU and written back later.
- ".cg": Cache Global. Equivalent to ".wb": stored through L1 → L2 → LLC under LRU.
- ".cs": Cache Streaming. Bypasses L1, streams through L2, and is not retained in the LLC.
- ".wt": Write-Through. Bypasses L1 and L2 (coherent cache bypass). May hit in the LLC with LRU.

volatile (bool, optional): If True, disables compiler optimizations that
could reorder or eliminate the store. Defaults to False.

Returns:
None
"""
translated_ptr = __translate(pointer, from_rank, to_rank, heap_bases)
-tl.store(translated_ptr, value, mask=mask)
+tl.store(translated_ptr, value, mask=mask, cache_modifier=cache_modifier, volatile=volatile)
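
The load and store docstrings accept different sets of modifier strings (for example, ".cv" is listed only for loads and ".wt" only for stores). As a rough illustration, the sketch below transcribes the two lists into lookup tables and adds a host-side check that could run before a kernel launch. `LOAD_MODIFIERS`, `STORE_MODIFIERS`, and `check_modifier` are hypothetical names that are not part of this PR, and the descriptions are abbreviated from the docstrings above.

```python
# Modifier tables transcribed from the load/store docstrings in this diff.
# The helper below is a hypothetical host-side check, not an API from this PR.

LOAD_MODIFIERS = {
    "": "default, same as .ca",
    ".ca": "cache at all levels (CU, L2, LLC) with LRU",
    ".cg": "bypass CU (L1), stream through L2",
    ".cv": "bypass CU and L2, fetch from system memory",
}

STORE_MODIFIERS = {
    "": "default, same as .wb",
    ".wb": "write-back, cached at all levels with LRU",
    ".cg": "cache global, equivalent to .wb",
    ".cs": "cache streaming, bypass L1, not retained in LLC",
    ".wt": "write-through, bypass L1 and L2",
}


def check_modifier(op, cache_modifier):
    """Return cache_modifier unchanged if it is documented for op ('load' or 'store')."""
    tables = {"load": LOAD_MODIFIERS, "store": STORE_MODIFIERS}
    if op not in tables:
        raise ValueError(f"unknown operation {op!r}; expected 'load' or 'store'")
    if cache_modifier not in tables[op]:
        allowed = ", ".join(repr(m) for m in tables[op])
        raise ValueError(
            f"{cache_modifier!r} is not a documented {op} modifier; expected one of {allowed}"
        )
    return cache_modifier


# A modifier that is valid for loads is not necessarily valid for stores.
load_mod = check_modifier("load", ".cv")
store_mod = check_modifier("store", ".wt")
```

Checking on the host keeps a typo such as passing ".cv" to a store from reaching the kernel; whether the library itself should validate this way is a design choice this diff does not address.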


@triton.jit