Skip to content

C++ Disjoint sampling implementation#5414

Open
ChuckHastings wants to merge 7 commits intorapidsai:mainfrom
ChuckHastings:disjoint_sampling_implementation
Open

C++ Disjoint sampling implementation#5414
ChuckHastings wants to merge 7 commits intorapidsai:mainfrom
ChuckHastings:disjoint_sampling_implementation

Conversation

@ChuckHastings
Copy link
Collaborator

This PR adds the disjoint sampling feature to sampling in C++. C++ tests exist for homogeneous uniform and biased sampling, both for SG and MG.

This should get us started on the disjoint feature, we can add tests for heterogeneous and temporal variations as well, I plan to do that in a follow on activity. We also need to test that the C API level, which I will add to a later pull request.

@copy-pr-bot
Copy link

copy-pr-bot bot commented Jan 29, 2026

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

@ChuckHastings ChuckHastings self-assigned this Jan 29, 2026
@ChuckHastings ChuckHastings added improvement Improvement / enhancement to an existing function non-breaking Non-breaking change labels Jan 29, 2026
@ChuckHastings ChuckHastings marked this pull request as ready for review January 29, 2026 22:18
@ChuckHastings ChuckHastings requested review from a team as code owners January 29, 2026 22:18
Copy link
Contributor

@seunghwak seunghwak left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Review part 1.

rmm::device_uvector<vertex_t>,
std::vector<arithmetic_device_uvector_t>,
std::optional<rmm::device_uvector<int32_t>>>
gather_one_hop_edgelist_with_visited(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure about this name, this function finds all one hop edges and filters out all the edges with visited destinations. This function name some what implies that it keeps the edges with visited destinations. Can we rename this function to easily find what this function is doing?

What about something like gather_edgelist_to_unvisited_neighbors or gather_one_hop_edgelist_to_unvisited_neighbors?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated to gather_one_hop_edgelist_to_unvisited_neighbors in next push.

Comment on lines 349 to 350
std::optional<rmm::device_uvector<vertex_t>>& visited_vertices,
std::optional<rmm::device_uvector<int32_t>>& visited_vertex_labels,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Better take visited_vertices and visited_vertex_lables as R-value references and return the new ones to be more functional.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will change in the next push.

std::optional<rmm::device_uvector<int32_t>>& visited_vertex_labels,
bool do_expensive_check)
{
CUGRAPH_EXPECTS(visited_vertices, "Visited vertices must be provided");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this case, why are we taking std::optional? Better to detect this in compile time than run-time.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I started with this as a change to gather_one_hop_vertices, and it would be optional... if specified it added the new logic. Then I realized I needed to restructure the main flow, seemed too complex, so I made it a separate function.

I'll change it in the next push.

std::optional<edge_property_view_t<edge_t, int32_t const*>> edge_type_view,
raft::device_span<vertex_t const> active_majors,
std::optional<raft::device_span<int32_t const>> active_major_labels,
std::optional<raft::device_span<uint8_t const>> gather_flags,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know we are using uint8_t in gather_one_hop_edgelist as well, but should we better use bool instead of uint8_t here?

Comment on lines 377 to 378
visited_vertices = visited_vertices->data(),
visited_labels = visited_vertex_labels->data(),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any assumptions about how visited_vertices and visited_vertex_labels will be partitioned in multi-GPU? Will this code work in multi-GPU? (especially in extreme-scale?)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is managed by update_dst_visited_vertices_and_labels. They are replicated (allgatherv) across the minor communicator. This is necessary for the logic to work.

Size should be reasonably manageable. Number of hops * fanout / p_row (or something like that) would be the expected number of entries per GPU.

keep_count);
} else {
rmm::device_uvector<vertex_t> remove_srcs(result_srcs.size(), handle.get_stream());
rmm::device_uvector<vertex_t> remove_dsts(result_dsts.size(), handle.get_stream());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need srcs here? Aren't we just removing destination vertices that appear more than once?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And this code won't work across multi-GPUs, won't this?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I use srcs in the sort so that I select the edge with the lowest source as the one that is selected. Perhaps not important in the gather_one_hop path, but the other path needs to break the ties in a way that guarantees that at least one source is fully sampled in each iteration of the loop. I kept that in place here more for consistency, but I can drop it if you think we shouldn't worry about that consistency.

The multi-GPU issue is a defect, I'll fix that in the next push. I have an extra shuffle and check in the sample edges path that I should replicate after this to check for duplicates across GPUs.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can't we break ties with positions?

but the other path needs to break the ties in a way that guarantees that at least one source is fully sampled in each iteration of the loop.

I am having hard time interpreting this. You mean in the sampling path, you want each active major to have at least one sampled edges (so, if there are multiple (different active_major, same minor) pairs, you prefer to select the one with no currently sampled edges?).

Comment on lines +450 to +454
positions = detail::keep_marked_entries(
handle,
std::move(positions),
raft::device_span<uint32_t const>{keep_flags.data(), keep_flags.size()},
keep_count);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I assume this code is same for both the if and else cases? Should we replicate the code?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I could pull it out, but right now the scope of keep_flags and keep_count is inside the code block and they are automatically freed when we exit.

If I pull it out I'll need to define them explicitly and then resize and shrink to fit. Not sure which is better.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, I see, yeah... those are comparable... so not worth the additional work.

Comment on lines 522 to 523
rmm::device_uvector<vertex_t> new_visited_vertices(visited_vertices->size() + result_dsts.size(),
handle.get_stream());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Better sort result_dsts first and call thrust::merge. Sorting the entire array again and again is expensive.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be calling the update_dst_visited_vertices_and_labels function which handles the MG support as well.

Now your thrust::merge comment might be relevant there.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Next push will call the function instead.

Comment on lines +90 to +91
raft::device_span<vertex_t const> visited_vertices{};
cuda::std::optional<raft::device_span<int32_t const>> visited_vertex_labels{};
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any assumption about partitioning here? will this follow edge partitioning or vertex partitioning? If this follows edge partitioning and if this is for edge destinations, we may better use minors instead of vertices in the naming.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

visited_vertices/visited_vertex_labels is an allgatherv across the minor communicator, so all elements are replicated across the minor communicator. This allows any GPU that might include a vertex as a destination to have that information.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Then, should we better name this as visited_minors & visited_minor_labels?

Copy link
Contributor

@seunghwak seunghwak left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Review part 2

}

// Check for duplicates in the sampled minor vertices
rmm::device_uvector<vertex_t> local_majors(sampled_majors.size(), handle.get_stream());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I assume this code to remove duplicates in the sampled minor vertices is same for both this function and the gather-one-hop-edgelist function. Can't we merge the two to a single utility function?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will look at this as I fix the MG portion of gather-one-hop-edgelist

rmm::device_uvector<vertex_t>,
std::vector<arithmetic_device_uvector_t>,
std::optional<rmm::device_uvector<int32_t>>>
sample_edges_with_visited(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similar to the gather one-hop edge list function, should we better rename this function?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Renamed to sample_edges_to_unvisited_neighbors in the latest push.

Comment on lines 1343 to 1344
std::optional<rmm::device_uvector<vertex_t>>& visited_vertices,
std::optional<rmm::device_uvector<int32_t>>& visited_vertex_labels,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similar here, better take R-value references.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in next push

std::optional<rmm::device_uvector<int32_t>>>
update_dst_visited_vertices_and_labels(
raft::handle_t const& handle,
graph_view_t<vertex_t, edge_t, false, multi_gpu> const& graph_view,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any assumptions about MG partitioning here?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If sampled_vertices follow the vertex partitioning and visited_vertices/visited_vertex follows edge partitioning (for minors), we should better rename accordingly.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggestions on names?

Yes, sampled_vertices are partitioned by vertex partitioning,
visited_vertices/visited_vertex_labels are replicated (allgatherv) across the minor communicator.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

visited_minors & visited_minor_labels (if in the detail namespace) or visited_dsts and visited_dst_lables (if in the public namespace)?

raft::device_span<vertex_t const> sampled_vertices,
std::optional<raft::device_span<int32_t const>> sampled_vertex_labels)
{
CUGRAPH_EXPECTS(visited_vertices.has_value(), "Invalid input: visited_vertices must be provided");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Then, why are we taking std::optional here?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in next push.


if constexpr (multi_gpu) {
std::tie(new_samples, props) = cugraph::shuffle_int_vertices(
handle, std::move(new_samples), std::move(props), graph_view.vertex_partition_range_lasts());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about labels here? No need to shuffle labels as well?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Labels are in props. I reorganized in the next push, the props.push_back above and the std::move below are all in the multi_gpu block which makes it clear what's happening.

Copy link
Collaborator Author

@ChuckHastings ChuckHastings left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will address many of these in my next push. A few comments/questions that aren't corrected yet. A few things not commented on I'll follow up with later.

rmm::device_uvector<vertex_t>,
std::vector<arithmetic_device_uvector_t>,
std::optional<rmm::device_uvector<int32_t>>>
gather_one_hop_edgelist_with_visited(
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated to gather_one_hop_edgelist_to_unvisited_neighbors in next push.

std::optional<rmm::device_uvector<int32_t>>& visited_vertex_labels,
bool do_expensive_check)
{
CUGRAPH_EXPECTS(visited_vertices, "Visited vertices must be provided");
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I started with this as a change to gather_one_hop_vertices, and it would be optional... if specified it added the new logic. Then I realized I needed to restructure the main flow, seemed too complex, so I made it a separate function.

I'll change it in the next push.

Comment on lines 349 to 350
std::optional<rmm::device_uvector<vertex_t>>& visited_vertices,
std::optional<rmm::device_uvector<int32_t>>& visited_vertex_labels,
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will change in the next push.

Comment on lines 377 to 378
visited_vertices = visited_vertices->data(),
visited_labels = visited_vertex_labels->data(),
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is managed by update_dst_visited_vertices_and_labels. They are replicated (allgatherv) across the minor communicator. This is necessary for the logic to work.

Size should be reasonably manageable. Number of hops * fanout / p_row (or something like that) would be the expected number of entries per GPU.

Comment on lines +450 to +454
positions = detail::keep_marked_entries(
handle,
std::move(positions),
raft::device_span<uint32_t const>{keep_flags.data(), keep_flags.size()},
keep_count);
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I could pull it out, but right now the scope of keep_flags and keep_count is inside the code block and they are automatically freed when we exit.

If I pull it out I'll need to define them explicitly and then resize and shrink to fit. Not sure which is better.

Comment on lines 1343 to 1344
std::optional<rmm::device_uvector<vertex_t>>& visited_vertices,
std::optional<rmm::device_uvector<int32_t>>& visited_vertex_labels,
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in next push

raft::device_span<vertex_t const> sampled_vertices,
std::optional<raft::device_span<int32_t const>> sampled_vertex_labels)
{
CUGRAPH_EXPECTS(visited_vertices.has_value(), "Invalid input: visited_vertices must be provided");
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in next push.

std::optional<rmm::device_uvector<int32_t>>>
update_dst_visited_vertices_and_labels(
raft::handle_t const& handle,
graph_view_t<vertex_t, edge_t, false, multi_gpu> const& graph_view,
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggestions on names?

Yes, sampled_vertices are partitioned by vertex partitioning,
visited_vertices/visited_vertex_labels are replicated (allgatherv) across the minor communicator.


if constexpr (multi_gpu) {
std::tie(new_samples, props) = cugraph::shuffle_int_vertices(
handle, std::move(new_samples), std::move(props), graph_view.vertex_partition_range_lasts());
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Labels are in props. I reorganized in the next push, the props.push_back above and the std::move below are all in the multi_gpu block which makes it clear what's happening.

// Implement this, should be a little easier than sample_edges_to_unvisited_neighbors, since we
// don't need to compute the probability of sampling for an edge based on the label/tag. We can
// just extract everything and then filter the results based on the visited vertices and vertex
// labels.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"Implement this,"=>Isn't this now an outdated comment? You already implemented this.

raft::device_span<vertex_t const> active_majors,
std::optional<raft::device_span<int32_t const>> active_major_labels,
std::optional<raft::device_span<uint8_t const>> gather_flags,
rmm::device_uvector<vertex_t>&& visited_vertices,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In multi-GPU, is this visited_minors? If this just stores the "visited_(local_)vertices, this code won't work.

// don't need to compute the probability of sampling for an edge based on the label/tag. We can
// just extract everything and then filter the results based on the visited vertices and vertex
// labels.
auto [result_srcs, result_dsts, result_properties, result_labels] =
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And we may consistently use majors & minors in the detail namespace (even though here, store_transposed == false, so minors are always destinations).

Comment on lines +450 to +454
positions = detail::keep_marked_entries(
handle,
std::move(positions),
raft::device_span<uint32_t const>{keep_flags.data(), keep_flags.size()},
keep_count);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, I see, yeah... those are comparable... so not worth the additional work.

keep_count);
} else {
rmm::device_uvector<vertex_t> remove_srcs(result_srcs.size(), handle.get_stream());
rmm::device_uvector<vertex_t> remove_dsts(result_dsts.size(), handle.get_stream());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can't we break ties with positions?

but the other path needs to break the ties in a way that guarantees that at least one source is fully sampled in each iteration of the loop.

I am having hard time interpreting this. You mean in the sampling path, you want each active major to have at least one sampled edges (so, if there are multiple (different active_major, same minor) pairs, you prefer to select the one with no currently sampled edges?).

Comment on lines +90 to +91
raft::device_span<vertex_t const> visited_vertices{};
cuda::std::optional<raft::device_span<int32_t const>> visited_vertex_labels{};
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Then, should we better name this as visited_minors & visited_minor_labels?

std::optional<edge_arithmetic_property_view_t<edge_t>> edge_bias_view,
cugraph::vertex_frontier_t<vertex_t, tag_t, multi_gpu, false>& vertex_frontier,
rmm::device_uvector<vertex_t>&& visited_vertices,
std::optional<rmm::device_uvector<int32_t>>&& visited_vertex_labels,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should better be visited_minors & visited_minor_labels (or visited_dsts and visited_dst_labels if this function is in the public namespace)?

std::optional<raft::device_span<int32_t const>> active_major_labels,
raft::host_span<size_t const> Ks,
rmm::device_uvector<vertex_t>&& visited_vertices,
std::optional<rmm::device_uvector<int32_t>>&& visited_vertex_labels,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

visited_minors & visited_minor_labels (or visited_dsts & visisted_dst_labels if in the public namespace)?

std::optional<rmm::device_uvector<int32_t>>>
update_dst_visited_vertices_and_labels(
raft::handle_t const& handle,
graph_view_t<vertex_t, edge_t, false, multi_gpu> const& graph_view,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

visited_minors & visited_minor_labels (if in the detail namespace) or visited_dsts and visited_dst_lables (if in the public namespace)?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

improvement Improvement / enhancement to an existing function non-breaking Non-breaking change

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants