Skip to content

[WIP] Clarify dataset ownership and allocation semantics#1738

Draft
seunghwak wants to merge 5 commits intorapidsai:mainfrom
seunghwak:enh_datasets_api
Draft

[WIP] Clarify dataset ownership and allocation semantics#1738
seunghwak wants to merge 5 commits intorapidsai:mainfrom
seunghwak:enh_datasets_api

Conversation

@seunghwak
Copy link
Contributor

@seunghwak seunghwak commented Jan 27, 2026

Initial attempt to address #1574 and #1571.

Currently, index constructors accept an mdspan and internally invoke make_strided|aligned_dataset to construct a strided_dataset from input data with potentially unknown alignment. When the input is properly aligned, a zero-copy view is used; otherwise, an aligned copy is created. This implicit copy can unexpectedly double GPU memory usage, which may be surprising to API users.

Current API

   auto dataset = raft::make_device_matrix<float, int64_t>(res, n_rows, n_cols);
   auto knn_graph = raft::make_device_matrix<uint32_t, int64_t>(res, n_rows, graph_degree);

   // custom loading and graph creation
   // load_dataset(dataset.view());
   // create_knn_graph(knn_graph.view());
   // Wrap the existing device arrays into an index structure
   cagra::index<T, IdxT> index(res, metric, raft::make_const_mdspan(dataset.view()),
                               raft::make_const_mdspan(knn_graph.view()));

New API (explicitly create strided_dataset)

   auto dataset_matrix = raft::make_device_matrix<float, int64_t>(res, n_rows, n_cols);
   auto knn_graph = raft::make_device_matrix<uint32_t, int64_t>(res, n_rows, graph_degree);
   // custom loading and graph creation
   // load_dataset(dataset_matrix.view());
   // create_knn_graph(knn_graph.view());
   // create strided dataset (zero-copy if aligned, throws if not)
   auto dataset = make_strided_dataset_zerocopy(dataset_matrix.view());
   or
   // create strided_dataset (create an aligned copy)
   auto dataset = make_strided_dataset_owning(dataset_matrix.view());
   // Wrap the existing strided dataset into an index structure
   cagra::index<T, IdxT> index(res, metric, *dataset,
                               raft::make_const_mdspan(knn_graph.view()));

We deprecate the index constructors that take device mdspan (device_matrix_view) to avoid surprising implicit copy.

In addition, this PR makes an initial attempt to update the build functions to accept strided_dataset, in addition to host_matrix_view. At present, the build overload that takes strided_dataset supports only iterative-CAGRA for initial graph construction. Supporting NN-descent or IVF-PQ is more involved, as these implementations assume a row-major layout; extending them to handle strided layouts would require more substantial code changes.

Closes #1571
Closes #1574

@copy-pr-bot
Copy link

copy-pr-bot bot commented Jan 27, 2026

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

@seunghwak
Copy link
Contributor Author

@cjnolet Let me know whether this PR aligns with what you are thinking.

@cjnolet
Copy link
Member

cjnolet commented Jan 29, 2026

Hey @seunghwak, thank you for giving this some initial thought and sharing here.

I think this is a good start. It might also be worth me mentioning some of the other dataset types that this design is going to prepare us for- in addition to the padded_dataset, we also intend to provide a vpq_dataset and a pq_dataset.

I wonder if maybe we should make the dataset memory type explicit too? (device_padded_dataset, device_padded_dataset_view, host_padded_dataset_view)

The nice thing about the dataset abstraction is that it'll allow us to have the pq_dataset contain additoinal state like trained centroids and codebooks, which need to be passed along w/ the pq-encoded vectors because they are used to compute the distances.

It's nice to have these things abstracted behind the dataset class so that, for example, we aren't having to provide additional additional overloads for each different combination we support- we can support a dataset_view as input to the build() functions and keep supporting new dataset types in behind the scenes, while not having to keep changing the APIs.

   // create strided dataset (zero-copy if aligned, throws if not)
   auto dataset = make_strided_dataset_zerocopy(dataset_matrix.view());

I wonder if maybe instead of zerocopy and owning maybe we could use make_strided_dataset_view and make_strided_dataset? We could just always accept the view. Or do you not think view conveys the proper message to the user?

@cjnolet cjnolet added improvement Improves an existing functionality non-breaking Introduces a non-breaking change breaking Introduces a breaking change and removed non-breaking Introduces a non-breaking change breaking Introduces a breaking change labels Jan 29, 2026
@seunghwak
Copy link
Contributor Author

I wonder if maybe we should make the dataset memory type explicit too? (device_padded_dataset, device_padded_dataset_view, host_padded_dataset_view)

Yes, if there is a need to support padded data set for host as well. I though the current strided_dataset is implicitly device-only, considering

using view_type = raft::device_matrix_view<const value_type, index_type, raft::layout_stride>;

const bool device_accessible = device_ptr != nullptr;

but if you see a need to support padded dataset for host data as well, I will explicitly create a class (or an alias like host_matrix_view and device_matrix_view) for host data as well. For consistency, we may better copy how we handle mdarray and mdspan in designing the dataset API.

It's nice to have these things abstracted behind the dataset class so that, for example, we aren't having to provide additional additional overloads for each different combination we support- we can support a dataset_view as input to the build() functions and keep supporting new dataset types in behind the scenes, while not having to keep changing the APIs.

You mean that we want to take a dataset class (https://github.com/rapidsai/cuvs/blob/main/cpp/include/cuvs/neighbors/common.hpp#L140) object and internally dynamic cast to support different dataset types? I will update the PR in this direction but please correct me if I mis-interpreted your intention.

I wonder if maybe instead of zerocopy and owning maybe we could use make_strided_dataset_view and make_strided_dataset? We could just always accept the view. Or do you not think view conveys the proper message to the user?

No problem, I will update the function names.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

improvement Improves an existing functionality non-breaking Introduces a non-breaking change

Projects

Development

Successfully merging this pull request may close these issues.

[FEA] Standardize Datasets API for public end-user use [FEA] Indexes to accept Dataset directly

2 participants