Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions python/cudf_polars/cudf_polars/experimental/rapidsmpf/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,13 @@

import asyncio
import operator
import struct
from contextlib import asynccontextmanager, contextmanager
from functools import reduce
from typing import TYPE_CHECKING, Any

from rapidsmpf.memory.packed_data import PackedData
from rapidsmpf.streaming.coll.allgather import AllGather
from rapidsmpf.streaming.core.message import Message
from rapidsmpf.streaming.cudf.channel_metadata import (
ChannelMetadata,
Expand Down Expand Up @@ -398,3 +401,46 @@ def opaque_reservation(
yield context.br().reserve_device_memory_and_spill(
estimated_bytes, allow_overbooking=True
)


async def allgather_reduce(
    context: Context,
    op_id: int,
    *local_values: int,
) -> tuple[int, ...]:
    """
    Sum per-rank integer scalars element-wise using an allgather.

    Each rank packs its values with the ``struct`` ``"q"`` format code
    (one C ``long long`` per value), contributes that buffer to an
    ``AllGather`` collective, and then adds up the unpacked buffers it
    gets back, position by position.

    Parameters
    ----------
    context
        The rapidsmpf context.
    op_id
        The collective operation ID for this allgather.
    *local_values
        One or more local scalar values to contribute.

    Returns
    -------
    tuple[int, ...]
        One total per input position: the sum of the value at that
        position over every gathered buffer.
    """
    count = len(local_values)
    codec = struct.Struct("q" * count)

    # Serialize this rank's contribution into a single host buffer.
    payload = PackedData.from_host_bytes(codec.pack(*local_values), context.br())

    gatherer = AllGather(context, op_id)
    gatherer.insert(0, payload)
    gatherer.insert_finished()

    # Addition does not depend on arrival order, so an unordered
    # extraction is sufficient here.
    gathered = await gatherer.extract_all(context, ordered=False)

    # One tuple of ``count`` integers per gathered buffer.
    rows = [codec.unpack(item.to_host_bytes()) for item in gathered]

    # Summing by index (rather than zipping the rows) keeps the result at
    # ``count`` zeros when nothing was gathered.
    return tuple(sum(row[i] for row in rows) for i in range(count))
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: Apache-2.0

"""Tests for RapidsMPF AllGather functionality."""
Expand All @@ -14,6 +14,7 @@
import pylibcudf as plc

from cudf_polars.experimental.rapidsmpf.collectives.allgather import AllGatherManager
from cudf_polars.experimental.rapidsmpf.utils import allgather_reduce

if TYPE_CHECKING:
from rapidsmpf.streaming.core.context import Context
Expand Down Expand Up @@ -52,3 +53,18 @@ async def _test_allgather(context: Context) -> None:

def test_allgather(local_context: Context) -> None:
    """Drive the async ``_test_allgather`` body to completion."""
    coroutine = _test_allgather(local_context)
    asyncio.run(coroutine)


async def _test_allgather_reduce(context: Context) -> None:
    """Check allgather_reduce with one value and with several values."""
    # One contributed value: the test expects the sum to equal the input,
    # i.e. it assumes the context has a single rank.
    single = await allgather_reduce(context, 0, 42)
    (total,) = single
    assert total == 42

    # Several contributed values, using a different op ID: under the same
    # single-rank assumption each sum equals its own input.
    expected = (10, 20, 30)
    summed = await allgather_reduce(context, 1, *expected)
    assert summed == expected


def test_allgather_reduce(local_context: Context) -> None:
    """Drive the async ``_test_allgather_reduce`` body to completion."""
    coroutine = _test_allgather_reduce(local_context)
    asyncio.run(coroutine)