From ca07c081744c559858a4b460fd9243e51b05e5b3 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Fri, 9 Jan 2026 13:59:57 -0800 Subject: [PATCH 1/7] first commit (unclean) --- cpp/include/cuvs/cluster/kmeans.hpp | 143 +++++- cpp/src/cluster/detail/kmeans_batched.cuh | 536 ++++++++++++++++++++++ cpp/src/cluster/kmeans_fit_double.cu | 44 +- cpp/src/cluster/kmeans_fit_float.cu | 44 +- 4 files changed, 764 insertions(+), 3 deletions(-) create mode 100644 cpp/src/cluster/detail/kmeans_batched.cuh diff --git a/cpp/include/cuvs/cluster/kmeans.hpp b/cpp/include/cuvs/cluster/kmeans.hpp index a8aa6b9807..3b827630ab 100644 --- a/cpp/include/cuvs/cluster/kmeans.hpp +++ b/cpp/include/cuvs/cluster/kmeans.hpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ #pragma once @@ -134,6 +134,147 @@ struct balanced_params : base_params { * @{ */ +/** + * @defgroup kmeans_batched Batched k-means for out-of-core / host data + * @{ + */ + +/** + * @brief Find clusters with k-means algorithm using batched processing. + * + * This version supports out-of-core computation where the dataset resides + * on the host. Data is processed in batches, with partial sums accumulated + * across batches and centroids finalized at the end of each iteration. + * This is mathematically equivalent to standard kmeans. + * + * @code{.cpp} + * #include + * #include + * using namespace cuvs::cluster; + * ... + * raft::resources handle; + * cuvs::cluster::kmeans::params params; + * int n_features = 15; + * float inertia; + * int n_iter; + * + * // Data on host + * std::vector h_X(n_samples * n_features); + * auto X = raft::make_host_matrix_view(h_X.data(), n_samples, n_features); + * + * // Centroids on device + * auto centroids = raft::make_device_matrix(handle, params.n_clusters, n_features); + * + * kmeans::fit_batched(handle, + * params, + * X, + * 100000, // batch_size + * std::nullopt, + * centroids.view(), + * raft::make_host_scalar_view(&inertia), + * raft::make_host_scalar_view(&n_iter)); + * @endcode + * + * @param[in] handle The raft handle. + * @param[in] params Parameters for KMeans model. + * @param[in] X Training instances on HOST memory. The data must + * be in row-major format. + * [dim = n_samples x n_features] + * @param[in] batch_size Number of samples to process per batch. + * @param[in] sample_weight Optional weights for each observation in X (on host). + * [len = n_samples] + * @param[inout] centroids [in] When init is InitMethod::Array, use + * centroids as the initial cluster centers. + * [out] The generated centroids from the + * kmeans algorithm are stored at the address + * pointed by 'centroids'. + * [dim = n_clusters x n_features] + * @param[out] inertia Sum of squared distances of samples to their + * closest cluster center. + * @param[out] n_iter Number of iterations run. + */ +void fit_batched(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::host_matrix_view X, + int batch_size, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter); + +/** + * @brief Find clusters with k-means algorithm using batched processing. + * + * @param[in] handle The raft handle. + * @param[in] params Parameters for KMeans model. + * @param[in] X Training instances on HOST memory. + * [dim = n_samples x n_features] + * @param[in] batch_size Number of samples to process per batch. + * @param[in] sample_weight Optional weights for each observation in X (on host). + * @param[inout] centroids Cluster centers on device. + * [dim = n_clusters x n_features] + * @param[out] inertia Sum of squared distances to nearest centroid. + * @param[out] n_iter Number of iterations run. + */ +void fit_batched(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::host_matrix_view X, + int64_t batch_size, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter); + +/** + * @brief Find clusters with k-means algorithm using batched processing. + * + * @param[in] handle The raft handle. + * @param[in] params Parameters for KMeans model. + * @param[in] X Training instances on HOST memory. + * [dim = n_samples x n_features] + * @param[in] batch_size Number of samples to process per batch. + * @param[in] sample_weight Optional weights for each observation in X (on host). + * @param[inout] centroids Cluster centers on device. + * [dim = n_clusters x n_features] + * @param[out] inertia Sum of squared distances to nearest centroid. + * @param[out] n_iter Number of iterations run. + */ +void fit_batched(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::host_matrix_view X, + int batch_size, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter); + +/** + * @brief Find clusters with k-means algorithm using batched processing. + * + * @param[in] handle The raft handle. + * @param[in] params Parameters for KMeans model. + * @param[in] X Training instances on HOST memory. + * [dim = n_samples x n_features] + * @param[in] batch_size Number of samples to process per batch. + * @param[in] sample_weight Optional weights for each observation in X (on host). + * @param[inout] centroids Cluster centers on device. + * [dim = n_clusters x n_features] + * @param[out] inertia Sum of squared distances to nearest centroid. + * @param[out] n_iter Number of iterations run. + */ +void fit_batched(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::host_matrix_view X, + int64_t batch_size, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter); + +/** + * @} + */ + /** * @brief Find clusters with k-means algorithm. * Initial centroids are chosen with k-means++ algorithm. Empty diff --git a/cpp/src/cluster/detail/kmeans_batched.cuh b/cpp/src/cluster/detail/kmeans_batched.cuh new file mode 100644 index 0000000000..63e36f2f5b --- /dev/null +++ b/cpp/src/cluster/detail/kmeans_batched.cuh @@ -0,0 +1,536 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ +#pragma once + +#include "kmeans_common.cuh" + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include +#include +#include + +#include +#include +#include +#include + +namespace cuvs::cluster::kmeans::batched::detail { + +/** + * @brief Sample data from host to device for initialization + * + * Samples `n_samples_to_gather` rows from host data and copies to device. + * Uses uniform strided sampling for simplicity and cache-friendliness. + */ +template +void prepare_init_sample(raft::resources const& handle, + raft::host_matrix_view X, + raft::device_matrix_view X_sample, + uint64_t seed) +{ + cudaStream_t stream = raft::resource::get_cuda_stream(handle); + auto n_samples = X.extent(0); + auto n_features = X.extent(1); + auto n_samples_out = X_sample.extent(0); + + // Use strided sampling for cache-friendliness + // For truly random, could use std::shuffle on indices first + std::mt19937 gen(seed); + std::vector indices(n_samples); + std::iota(indices.begin(), indices.end(), 0); + std::shuffle(indices.begin(), indices.end(), gen); + + std::vector host_sample(n_samples_out * n_features); + +#pragma omp parallel for + for (IndexT i = 0; i < static_cast(n_samples_out); i++) { + IndexT src_idx = indices[i]; + std::memcpy(host_sample.data() + i * n_features, + X.data_handle() + src_idx * n_features, + n_features * sizeof(DataT)); + } + + raft::copy(X_sample.data_handle(), host_sample.data(), host_sample.size(), stream); +} + +/** + * @brief Initialize centroids using k-means++ on a sample of the host data + */ +template +void init_centroids_from_host_sample(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::host_matrix_view X, + raft::device_matrix_view centroids, + rmm::device_uvector& workspace) +{ + cudaStream_t stream = raft::resource::get_cuda_stream(handle); + auto n_samples = X.extent(0); + auto n_features = X.extent(1); + auto n_clusters = params.n_clusters; + + // Sample size for initialization: at least 3 * n_clusters, but not more than n_samples + size_t init_sample_size = std::min(static_cast(n_samples), + std::max(static_cast(3 * n_clusters), + static_cast(10000))); + + RAFT_LOG_DEBUG("KMeans batched: sampling %zu points for initialization", init_sample_size); + + // Sample data from host to device + auto init_sample = raft::make_device_matrix(handle, init_sample_size, n_features); + prepare_init_sample(handle, X, init_sample.view(), params.rng_state.seed); + + // Run k-means++ on the sample + if (params.init == cuvs::cluster::kmeans::params::InitMethod::KMeansPlusPlus) { + cuvs::cluster::kmeans::detail::kmeansPlusPlus( + handle, + params, + raft::make_device_matrix_view( + init_sample.data_handle(), init_sample_size, n_features), + centroids, + workspace); + } else if (params.init == cuvs::cluster::kmeans::params::InitMethod::Random) { + // Just use the first n_clusters samples + raft::copy(centroids.data_handle(), + init_sample.data_handle(), + n_clusters * n_features, + stream); + } else if (params.init == cuvs::cluster::kmeans::params::InitMethod::Array) { + // Centroids already provided, nothing to do + } else { + RAFT_FAIL("Unknown initialization method"); + } +} + +/** + * @brief Accumulate partial centroid sums and counts from a batch + * + * This function adds the partial sums from a batch to the running accumulators. + * It does NOT divide - that happens once at the end of all batches. + */ +template +void accumulate_batch_centroids( + raft::resources const& handle, + raft::device_matrix_view batch_data, + raft::device_vector_view, IndexT> minClusterAndDistance, + raft::device_vector_view sample_weights, + raft::device_matrix_view centroid_sums, + raft::device_vector_view cluster_counts, + rmm::device_uvector& workspace) +{ + cudaStream_t stream = raft::resource::get_cuda_stream(handle); + auto n_samples = batch_data.extent(0); + auto n_features = batch_data.extent(1); + auto n_clusters = centroid_sums.extent(0); + + // Temporary buffers for this batch's partial results + auto batch_sums = raft::make_device_matrix(handle, n_clusters, n_features); + auto batch_counts = raft::make_device_vector(handle, n_clusters); + + // Zero the batch temporaries + thrust::fill(raft::resource::get_thrust_policy(handle), + batch_sums.data_handle(), + batch_sums.data_handle() + batch_sums.size(), + DataT{0}); + thrust::fill(raft::resource::get_thrust_policy(handle), + batch_counts.data_handle(), + batch_counts.data_handle() + batch_counts.size(), + DataT{0}); + + // Extract cluster labels from KeyValuePair + cuvs::cluster::kmeans::detail::KeyValueIndexOp conversion_op; + thrust::transform_iterator, + const raft::KeyValuePair*> + labels_itr(minClusterAndDistance.data_handle(), conversion_op); + + workspace.resize(n_samples, stream); + + // Compute weighted sum of samples per cluster for this batch + raft::linalg::reduce_rows_by_key(const_cast(batch_data.data_handle()), + batch_data.extent(1), + labels_itr, + sample_weights.data_handle(), + workspace.data(), + batch_data.extent(0), + batch_data.extent(1), + n_clusters, + batch_sums.data_handle(), + stream); + + // Compute sum of weights per cluster for this batch + raft::linalg::reduce_cols_by_key(sample_weights.data_handle(), + labels_itr, + batch_counts.data_handle(), + static_cast(1), + static_cast(n_samples), + static_cast(n_clusters), + stream); + + // Add batch results to running accumulators + raft::linalg::add(centroid_sums.data_handle(), + centroid_sums.data_handle(), + batch_sums.data_handle(), + centroid_sums.size(), + stream); + + raft::linalg::add(cluster_counts.data_handle(), + cluster_counts.data_handle(), + batch_counts.data_handle(), + cluster_counts.size(), + stream); +} + +/** + * @brief Finalize centroids by dividing accumulated sums by counts + */ +template +void finalize_centroids(raft::resources const& handle, + raft::device_matrix_view centroid_sums, + raft::device_vector_view cluster_counts, + raft::device_matrix_view old_centroids, + raft::device_matrix_view new_centroids) +{ + cudaStream_t stream = raft::resource::get_cuda_stream(handle); + auto n_clusters = new_centroids.extent(0); + auto n_features = new_centroids.extent(1); + + // Copy sums to new_centroids first + raft::copy( + new_centroids.data_handle(), centroid_sums.data_handle(), centroid_sums.size(), stream); + + // Divide by counts: new_centroids[i] = centroid_sums[i] / cluster_counts[i] + // When count is 0, set to 0 (will be fixed below) + raft::linalg::matrix_vector_op( + handle, + raft::make_const_mdspan(new_centroids), + cluster_counts, + new_centroids, + raft::div_checkzero_op{}); + + // Copy old centroids to new centroids where cluster_counts[i] == 0 + cub::ArgIndexInputIterator itr_wt(cluster_counts.data_handle()); + raft::matrix::gather_if( + old_centroids.data_handle(), + static_cast(old_centroids.extent(1)), + static_cast(old_centroids.extent(0)), + itr_wt, + itr_wt, + static_cast(cluster_counts.size()), + new_centroids.data_handle(), + [=] __device__(raft::KeyValuePair map) { + return map.value == DataT{0}; // predicate: copy when count is 0 + }, + raft::key_op{}, + stream); +} + +/** + * @brief Main fit function for batched k-means with host data + * + * @tparam DataT Data type (float, double) + * @tparam IndexT Index type (int, int64_t) + * + * @param[in] handle RAFT resources handle + * @param[in] params K-means parameters + * @param[in] X Input data on HOST [n_samples x n_features] + * @param[in] batch_size Number of samples to process per batch + * @param[in] sample_weight Optional weights per sample (on host) + * @param[inout] centroids Initial/output cluster centers [n_clusters x n_features] + * @param[out] inertia Sum of squared distances to nearest centroid + * @param[out] n_iter Number of iterations run + */ +template +void fit(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::host_matrix_view X, + IndexT batch_size, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter) +{ + cudaStream_t stream = raft::resource::get_cuda_stream(handle); + auto n_samples = X.extent(0); + auto n_features = X.extent(1); + auto n_clusters = params.n_clusters; + auto metric = params.metric; + + RAFT_EXPECTS(batch_size > 0, "batch_size must be positive"); + RAFT_EXPECTS(n_clusters > 0, "n_clusters must be positive"); + RAFT_EXPECTS(static_cast(centroids.extent(0)) == n_clusters, + "centroids.extent(0) must equal n_clusters"); + RAFT_EXPECTS(centroids.extent(1) == n_features, + "centroids.extent(1) must equal n_features"); + + raft::default_logger().set_level(params.verbosity); + + RAFT_LOG_DEBUG( + "KMeans batched fit: n_samples=%zu, n_features=%zu, n_clusters=%d, batch_size=%zu", + static_cast(n_samples), + static_cast(n_features), + n_clusters, + static_cast(batch_size)); + + rmm::device_uvector workspace(0, stream); + + // Initialize centroids from a sample of host data + if (params.init != cuvs::cluster::kmeans::params::InitMethod::Array) { + init_centroids_from_host_sample(handle, params, X, centroids, workspace); + } + + // Allocate device buffers + // Batch buffer for data + auto batch_data = raft::make_device_matrix(handle, batch_size, n_features); + // Batch buffer for weights + auto batch_weights = raft::make_device_vector(handle, batch_size); + // Cluster assignment for batch + auto minClusterAndDistance = + raft::make_device_vector, IndexT>(handle, batch_size); + // L2 norms of batch data + auto L2NormBatch = raft::make_device_vector(handle, batch_size); + // Temporary buffer for distance computation + rmm::device_uvector L2NormBuf_OR_DistBuf(0, stream); + + // Accumulators for centroid computation (persist across batches within an iteration) + auto centroid_sums = raft::make_device_matrix(handle, n_clusters, n_features); + auto cluster_counts = raft::make_device_vector(handle, n_clusters); + auto new_centroids = raft::make_device_matrix(handle, n_clusters, n_features); + + // Host buffer for batch data (pinned memory for faster H2D transfer) + std::vector host_batch_buffer(batch_size * n_features); + std::vector host_weight_buffer(batch_size); + + // Cluster cost for convergence check + rmm::device_scalar clusterCostD(stream); + DataT priorClusteringCost = 0; + + // Main iteration loop + for (n_iter[0] = 1; n_iter[0] <= params.max_iter; ++n_iter[0]) { + RAFT_LOG_DEBUG("KMeans batched: Iteration %d", n_iter[0]); + + // Zero accumulators at start of each iteration + thrust::fill(raft::resource::get_thrust_policy(handle), + centroid_sums.data_handle(), + centroid_sums.data_handle() + centroid_sums.size(), + DataT{0}); + thrust::fill(raft::resource::get_thrust_policy(handle), + cluster_counts.data_handle(), + cluster_counts.data_handle() + cluster_counts.size(), + DataT{0}); + + DataT total_cost = 0; + + // Process all data in batches + for (IndexT offset = 0; offset < n_samples; offset += batch_size) { + IndexT current_batch_size = std::min(batch_size, n_samples - offset); + + // Copy batch data from host to device + raft::copy(batch_data.data_handle(), + X.data_handle() + offset * n_features, + current_batch_size * n_features, + stream); + + // Copy or set weights for this batch + if (sample_weight) { + raft::copy(batch_weights.data_handle(), + sample_weight->data_handle() + offset, + current_batch_size, + stream); + } else { + thrust::fill(raft::resource::get_thrust_policy(handle), + batch_weights.data_handle(), + batch_weights.data_handle() + current_batch_size, + DataT{1}); + } + + // Create views for current batch size + auto batch_data_view = raft::make_device_matrix_view( + batch_data.data_handle(), current_batch_size, n_features); + auto batch_weights_view = raft::make_device_vector_view( + batch_weights.data_handle(), current_batch_size); + auto minClusterAndDistance_view = + raft::make_device_vector_view, IndexT>( + minClusterAndDistance.data_handle(), current_batch_size); + auto L2NormBatch_view = raft::make_device_vector_view( + L2NormBatch.data_handle(), current_batch_size); + + // Compute L2 norms for batch if needed + if (metric == cuvs::distance::DistanceType::L2Expanded || + metric == cuvs::distance::DistanceType::L2SqrtExpanded) { + raft::linalg::rowNorm( + L2NormBatch.data_handle(), + batch_data.data_handle(), + n_features, + current_batch_size, + stream); + } + + // Find nearest centroid for each sample in batch + auto centroids_const = raft::make_device_matrix_view( + centroids.data_handle(), n_clusters, n_features); + auto L2NormBatch_const = raft::make_device_vector_view( + L2NormBatch.data_handle(), current_batch_size); + + cuvs::cluster::kmeans::detail::minClusterAndDistanceCompute( + handle, + batch_data_view, + centroids_const, + minClusterAndDistance_view, + L2NormBatch_const, + L2NormBuf_OR_DistBuf, + metric, + params.batch_samples, + params.batch_centroids, + workspace); + + // Accumulate partial sums for this batch + auto minClusterAndDistance_const = + raft::make_device_vector_view, IndexT>( + minClusterAndDistance.data_handle(), current_batch_size); + + accumulate_batch_centroids(handle, + batch_data_view, + minClusterAndDistance_const, + batch_weights_view, + centroid_sums.view(), + cluster_counts.view(), + workspace); + + // Accumulate cluster cost if checking convergence + if (params.inertia_check) { + cuvs::cluster::kmeans::detail::computeClusterCost( + handle, + minClusterAndDistance_view, + workspace, + raft::make_device_scalar_view(clusterCostD.data()), + raft::value_op{}, + raft::add_op{}); + DataT batch_cost = clusterCostD.value(stream); + total_cost += batch_cost; + } + } // end batch loop + + // Finalize centroids: divide sums by counts + auto centroids_const = raft::make_device_matrix_view( + centroids.data_handle(), n_clusters, n_features); + auto centroid_sums_const = raft::make_device_matrix_view( + centroid_sums.data_handle(), n_clusters, n_features); + auto cluster_counts_const = raft::make_device_vector_view( + cluster_counts.data_handle(), n_clusters); + + finalize_centroids( + handle, centroid_sums_const, cluster_counts_const, centroids_const, new_centroids.view()); + + // Compute squared norm of change in centroids + auto sqrdNorm = raft::make_device_scalar(handle, DataT{0}); + raft::linalg::mapThenSumReduce(sqrdNorm.data_handle(), + new_centroids.size(), + raft::sqdiff_op{}, + stream, + centroids.data_handle(), + new_centroids.data_handle()); + + DataT sqrdNormError = 0; + raft::copy(&sqrdNormError, sqrdNorm.data_handle(), 1, stream); + + // Update centroids + raft::copy(centroids.data_handle(), new_centroids.data_handle(), new_centroids.size(), stream); + + // Check convergence + bool done = false; + if (params.inertia_check) { + if (n_iter[0] > 1) { + DataT delta = total_cost / priorClusteringCost; + if (delta > 1 - params.tol) done = true; + } + priorClusteringCost = total_cost; + } + + raft::resource::sync_stream(handle, stream); + if (sqrdNormError < params.tol) done = true; + + if (done) { + RAFT_LOG_DEBUG("KMeans batched: Converged after %d iterations", n_iter[0]); + break; + } + } // end iteration loop + + // Compute final inertia by processing all data once more + inertia[0] = 0; + for (IndexT offset = 0; offset < n_samples; offset += batch_size) { + IndexT current_batch_size = std::min(batch_size, n_samples - offset); + + raft::copy(batch_data.data_handle(), + X.data_handle() + offset * n_features, + current_batch_size * n_features, + stream); + + auto batch_data_view = raft::make_device_matrix_view( + batch_data.data_handle(), current_batch_size, n_features); + auto minClusterAndDistance_view = + raft::make_device_vector_view, IndexT>( + minClusterAndDistance.data_handle(), current_batch_size); + + if (metric == cuvs::distance::DistanceType::L2Expanded || + metric == cuvs::distance::DistanceType::L2SqrtExpanded) { + raft::linalg::rowNorm( + L2NormBatch.data_handle(), batch_data.data_handle(), n_features, current_batch_size, stream); + } + + auto centroids_const = raft::make_device_matrix_view( + centroids.data_handle(), n_clusters, n_features); + auto L2NormBatch_const = raft::make_device_vector_view( + L2NormBatch.data_handle(), current_batch_size); + + cuvs::cluster::kmeans::detail::minClusterAndDistanceCompute( + handle, + batch_data_view, + centroids_const, + minClusterAndDistance_view, + L2NormBatch_const, + L2NormBuf_OR_DistBuf, + metric, + params.batch_samples, + params.batch_centroids, + workspace); + + cuvs::cluster::kmeans::detail::computeClusterCost( + handle, + minClusterAndDistance_view, + workspace, + raft::make_device_scalar_view(clusterCostD.data()), + raft::value_op{}, + raft::add_op{}); + + inertia[0] += clusterCostD.value(stream); + } + + RAFT_LOG_DEBUG("KMeans batched: Completed with inertia=%f", inertia[0]); +} + +} // namespace cuvs::cluster::kmeans::batched::detail + diff --git a/cpp/src/cluster/kmeans_fit_double.cu b/cpp/src/cluster/kmeans_fit_double.cu index 43f457a29a..0962c87890 100644 --- a/cpp/src/cluster/kmeans_fit_double.cu +++ b/cpp/src/cluster/kmeans_fit_double.cu @@ -1,8 +1,9 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ +#include "detail/kmeans_batched.cuh" #include "kmeans.cuh" #include "kmeans_impl.cuh" #include @@ -30,14 +31,55 @@ namespace cuvs::cluster::kmeans { raft::host_scalar_view inertia, \ raft::host_scalar_view n_iter); +#define INSTANTIATE_FIT_BATCHED(DataT, IndexT) \ + template void batched::detail::fit( \ + raft::resources const& handle, \ + const kmeans::params& params, \ + raft::host_matrix_view X, \ + IndexT batch_size, \ + std::optional> sample_weight, \ + raft::device_matrix_view centroids, \ + raft::host_scalar_view inertia, \ + raft::host_scalar_view n_iter); + INSTANTIATE_FIT_MAIN(double, int) INSTANTIATE_FIT_MAIN(double, int64_t) INSTANTIATE_FIT(double, int) INSTANTIATE_FIT(double, int64_t) +INSTANTIATE_FIT_BATCHED(double, int) +INSTANTIATE_FIT_BATCHED(double, int64_t) + #undef INSTANTIATE_FIT_MAIN #undef INSTANTIATE_FIT +#undef INSTANTIATE_FIT_BATCHED + +void fit_batched(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::host_matrix_view X, + int batch_size, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter) +{ + cuvs::cluster::kmeans::batched::detail::fit( + handle, params, X, batch_size, sample_weight, centroids, inertia, n_iter); +} + +void fit_batched(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::host_matrix_view X, + int64_t batch_size, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter) +{ + cuvs::cluster::kmeans::batched::detail::fit( + handle, params, X, batch_size, sample_weight, centroids, inertia, n_iter); +} void fit(raft::resources const& handle, const cuvs::cluster::kmeans::params& params, diff --git a/cpp/src/cluster/kmeans_fit_float.cu b/cpp/src/cluster/kmeans_fit_float.cu index 5624151943..39cc074d9e 100644 --- a/cpp/src/cluster/kmeans_fit_float.cu +++ b/cpp/src/cluster/kmeans_fit_float.cu @@ -1,8 +1,9 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ +#include "detail/kmeans_batched.cuh" #include "kmeans.cuh" #include "kmeans_impl.cuh" #include @@ -30,14 +31,55 @@ namespace cuvs::cluster::kmeans { raft::host_scalar_view inertia, \ raft::host_scalar_view n_iter); +#define INSTANTIATE_FIT_BATCHED(DataT, IndexT) \ + template void batched::detail::fit( \ + raft::resources const& handle, \ + const kmeans::params& params, \ + raft::host_matrix_view X, \ + IndexT batch_size, \ + std::optional> sample_weight, \ + raft::device_matrix_view centroids, \ + raft::host_scalar_view inertia, \ + raft::host_scalar_view n_iter); + INSTANTIATE_FIT_MAIN(float, int) INSTANTIATE_FIT_MAIN(float, int64_t) INSTANTIATE_FIT(float, int) INSTANTIATE_FIT(float, int64_t) +INSTANTIATE_FIT_BATCHED(float, int) +INSTANTIATE_FIT_BATCHED(float, int64_t) + #undef INSTANTIATE_FIT_MAIN #undef INSTANTIATE_FIT +#undef INSTANTIATE_FIT_BATCHED + +void fit_batched(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::host_matrix_view X, + int batch_size, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter) +{ + cuvs::cluster::kmeans::batched::detail::fit( + handle, params, X, batch_size, sample_weight, centroids, inertia, n_iter); +} + +void fit_batched(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::host_matrix_view X, + int64_t batch_size, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter) +{ + cuvs::cluster::kmeans::batched::detail::fit( + handle, params, X, batch_size, sample_weight, centroids, inertia, n_iter); +} void fit(raft::resources const& handle, const cuvs::cluster::kmeans::params& params, From f1a19dfd29ea1f5ee32c47814c580628fc3cc2ec Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Fri, 9 Jan 2026 17:59:35 -0800 Subject: [PATCH 2/7] style --- cpp/src/cluster/detail/kmeans_batched.cuh | 140 +++++++++++----------- 1 file changed, 69 insertions(+), 71 deletions(-) diff --git a/cpp/src/cluster/detail/kmeans_batched.cuh b/cpp/src/cluster/detail/kmeans_batched.cuh index 63e36f2f5b..1999d80c69 100644 --- a/cpp/src/cluster/detail/kmeans_batched.cuh +++ b/cpp/src/cluster/detail/kmeans_batched.cuh @@ -47,29 +47,27 @@ namespace cuvs::cluster::kmeans::batched::detail { * Samples `n_samples_to_gather` rows from host data and copies to device. * Uses uniform strided sampling for simplicity and cache-friendliness. */ -template +template void prepare_init_sample(raft::resources const& handle, - raft::host_matrix_view X, - raft::device_matrix_view X_sample, - uint64_t seed) + raft::host_matrix_view X, + raft::device_matrix_view X_sample, + uint64_t seed) { cudaStream_t stream = raft::resource::get_cuda_stream(handle); auto n_samples = X.extent(0); auto n_features = X.extent(1); auto n_samples_out = X_sample.extent(0); - // Use strided sampling for cache-friendliness - // For truly random, could use std::shuffle on indices first std::mt19937 gen(seed); - std::vector indices(n_samples); + std::vector indices(n_samples); std::iota(indices.begin(), indices.end(), 0); std::shuffle(indices.begin(), indices.end(), gen); std::vector host_sample(n_samples_out * n_features); #pragma omp parallel for - for (IndexT i = 0; i < static_cast(n_samples_out); i++) { - IndexT src_idx = indices[i]; + for (IdxT i = 0; i < static_cast(n_samples_out); i++) { + IdxT src_idx = indices[i]; std::memcpy(host_sample.data() + i * n_features, X.data_handle() + src_idx * n_features, n_features * sizeof(DataT)); @@ -81,11 +79,11 @@ void prepare_init_sample(raft::resources const& handle, /** * @brief Initialize centroids using k-means++ on a sample of the host data */ -template +template void init_centroids_from_host_sample(raft::resources const& handle, const cuvs::cluster::kmeans::params& params, - raft::host_matrix_view X, - raft::device_matrix_view centroids, + raft::host_matrix_view X, + raft::device_matrix_view centroids, rmm::device_uvector& workspace) { cudaStream_t stream = raft::resource::get_cuda_stream(handle); @@ -101,15 +99,15 @@ void init_centroids_from_host_sample(raft::resources const& handle, RAFT_LOG_DEBUG("KMeans batched: sampling %zu points for initialization", init_sample_size); // Sample data from host to device - auto init_sample = raft::make_device_matrix(handle, init_sample_size, n_features); + auto init_sample = raft::make_device_matrix(handle, init_sample_size, n_features); prepare_init_sample(handle, X, init_sample.view(), params.rng_state.seed); // Run k-means++ on the sample if (params.init == cuvs::cluster::kmeans::params::InitMethod::KMeansPlusPlus) { - cuvs::cluster::kmeans::detail::kmeansPlusPlus( + cuvs::cluster::kmeans::detail::kmeansPlusPlus( handle, params, - raft::make_device_matrix_view( + raft::make_device_matrix_view( init_sample.data_handle(), init_sample_size, n_features), centroids, workspace); @@ -132,14 +130,14 @@ void init_centroids_from_host_sample(raft::resources const& handle, * This function adds the partial sums from a batch to the running accumulators. * It does NOT divide - that happens once at the end of all batches. */ -template +template void accumulate_batch_centroids( raft::resources const& handle, - raft::device_matrix_view batch_data, - raft::device_vector_view, IndexT> minClusterAndDistance, - raft::device_vector_view sample_weights, - raft::device_matrix_view centroid_sums, - raft::device_vector_view cluster_counts, + raft::device_matrix_view batch_data, + raft::device_vector_view, IdxT> minClusterAndDistance, + raft::device_vector_view sample_weights, + raft::device_matrix_view centroid_sums, + raft::device_vector_view cluster_counts, rmm::device_uvector& workspace) { cudaStream_t stream = raft::resource::get_cuda_stream(handle); @@ -148,8 +146,8 @@ void accumulate_batch_centroids( auto n_clusters = centroid_sums.extent(0); // Temporary buffers for this batch's partial results - auto batch_sums = raft::make_device_matrix(handle, n_clusters, n_features); - auto batch_counts = raft::make_device_vector(handle, n_clusters); + auto batch_sums = raft::make_device_matrix(handle, n_clusters, n_features); + auto batch_counts = raft::make_device_vector(handle, n_clusters); // Zero the batch temporaries thrust::fill(raft::resource::get_thrust_policy(handle), @@ -162,9 +160,9 @@ void accumulate_batch_centroids( DataT{0}); // Extract cluster labels from KeyValuePair - cuvs::cluster::kmeans::detail::KeyValueIndexOp conversion_op; - thrust::transform_iterator, - const raft::KeyValuePair*> + cuvs::cluster::kmeans::detail::KeyValueIndexOp conversion_op; + thrust::transform_iterator, + const raft::KeyValuePair*> labels_itr(minClusterAndDistance.data_handle(), conversion_op); workspace.resize(n_samples, stream); @@ -185,9 +183,9 @@ void accumulate_batch_centroids( raft::linalg::reduce_cols_by_key(sample_weights.data_handle(), labels_itr, batch_counts.data_handle(), - static_cast(1), - static_cast(n_samples), - static_cast(n_clusters), + static_cast(1), + static_cast(n_samples), + static_cast(n_clusters), stream); // Add batch results to running accumulators @@ -207,12 +205,12 @@ void accumulate_batch_centroids( /** * @brief Finalize centroids by dividing accumulated sums by counts */ -template +template void finalize_centroids(raft::resources const& handle, - raft::device_matrix_view centroid_sums, - raft::device_vector_view cluster_counts, - raft::device_matrix_view old_centroids, - raft::device_matrix_view new_centroids) + raft::device_matrix_view centroid_sums, + raft::device_vector_view cluster_counts, + raft::device_matrix_view old_centroids, + raft::device_matrix_view new_centroids) { cudaStream_t stream = raft::resource::get_cuda_stream(handle); auto n_clusters = new_centroids.extent(0); @@ -252,7 +250,7 @@ void finalize_centroids(raft::resources const& handle, * @brief Main fit function for batched k-means with host data * * @tparam DataT Data type (float, double) - * @tparam IndexT Index type (int, int64_t) + * @tparam IdxT Index type (int, int64_t) * * @param[in] handle RAFT resources handle * @param[in] params K-means parameters @@ -263,15 +261,15 @@ void finalize_centroids(raft::resources const& handle, * @param[out] inertia Sum of squared distances to nearest centroid * @param[out] n_iter Number of iterations run */ -template +template void fit(raft::resources const& handle, const cuvs::cluster::kmeans::params& params, - raft::host_matrix_view X, - IndexT batch_size, - std::optional> sample_weight, - raft::device_matrix_view centroids, + raft::host_matrix_view X, + IdxT batch_size, + std::optional> sample_weight, + raft::device_matrix_view centroids, raft::host_scalar_view inertia, - raft::host_scalar_view n_iter) + raft::host_scalar_view n_iter) { cudaStream_t stream = raft::resource::get_cuda_stream(handle); auto n_samples = X.extent(0); @@ -281,7 +279,7 @@ void fit(raft::resources const& handle, RAFT_EXPECTS(batch_size > 0, "batch_size must be positive"); RAFT_EXPECTS(n_clusters > 0, "n_clusters must be positive"); - RAFT_EXPECTS(static_cast(centroids.extent(0)) == n_clusters, + RAFT_EXPECTS(static_cast(centroids.extent(0)) == n_clusters, "centroids.extent(0) must equal n_clusters"); RAFT_EXPECTS(centroids.extent(1) == n_features, "centroids.extent(1) must equal n_features"); @@ -304,21 +302,21 @@ void fit(raft::resources const& handle, // Allocate device buffers // Batch buffer for data - auto batch_data = raft::make_device_matrix(handle, batch_size, n_features); + auto batch_data = raft::make_device_matrix(handle, batch_size, n_features); // Batch buffer for weights - auto batch_weights = raft::make_device_vector(handle, batch_size); + auto batch_weights = raft::make_device_vector(handle, batch_size); // Cluster assignment for batch auto minClusterAndDistance = - raft::make_device_vector, IndexT>(handle, batch_size); + raft::make_device_vector, IdxT>(handle, batch_size); // L2 norms of batch data - auto L2NormBatch = raft::make_device_vector(handle, batch_size); + auto L2NormBatch = raft::make_device_vector(handle, batch_size); // Temporary buffer for distance computation rmm::device_uvector L2NormBuf_OR_DistBuf(0, stream); // Accumulators for centroid computation (persist across batches within an iteration) - auto centroid_sums = raft::make_device_matrix(handle, n_clusters, n_features); - auto cluster_counts = raft::make_device_vector(handle, n_clusters); - auto new_centroids = raft::make_device_matrix(handle, n_clusters, n_features); + auto centroid_sums = raft::make_device_matrix(handle, n_clusters, n_features); + auto cluster_counts = raft::make_device_vector(handle, n_clusters); + auto new_centroids = raft::make_device_matrix(handle, n_clusters, n_features); // Host buffer for batch data (pinned memory for faster H2D transfer) std::vector host_batch_buffer(batch_size * n_features); @@ -345,8 +343,8 @@ void fit(raft::resources const& handle, DataT total_cost = 0; // Process all data in batches - for (IndexT offset = 0; offset < n_samples; offset += batch_size) { - IndexT current_batch_size = std::min(batch_size, n_samples - offset); + for (IdxT offset = 0; offset < n_samples; offset += batch_size) { + IdxT current_batch_size = std::min(batch_size, n_samples - offset); // Copy batch data from host to device raft::copy(batch_data.data_handle(), @@ -368,14 +366,14 @@ void fit(raft::resources const& handle, } // Create views for current batch size - auto batch_data_view = raft::make_device_matrix_view( + auto batch_data_view = raft::make_device_matrix_view( batch_data.data_handle(), current_batch_size, n_features); - auto batch_weights_view = raft::make_device_vector_view( + auto batch_weights_view = raft::make_device_vector_view( batch_weights.data_handle(), current_batch_size); auto minClusterAndDistance_view = - raft::make_device_vector_view, IndexT>( + raft::make_device_vector_view, IdxT>( minClusterAndDistance.data_handle(), current_batch_size); - auto L2NormBatch_view = raft::make_device_vector_view( + auto L2NormBatch_view = raft::make_device_vector_view( L2NormBatch.data_handle(), current_batch_size); // Compute L2 norms for batch if needed @@ -390,12 +388,12 @@ void fit(raft::resources const& handle, } // Find nearest centroid for each sample in batch - auto centroids_const = raft::make_device_matrix_view( + auto centroids_const = raft::make_device_matrix_view( centroids.data_handle(), n_clusters, n_features); - auto L2NormBatch_const = raft::make_device_vector_view( + auto L2NormBatch_const = raft::make_device_vector_view( L2NormBatch.data_handle(), current_batch_size); - cuvs::cluster::kmeans::detail::minClusterAndDistanceCompute( + cuvs::cluster::kmeans::detail::minClusterAndDistanceCompute( handle, batch_data_view, centroids_const, @@ -409,10 +407,10 @@ void fit(raft::resources const& handle, // Accumulate partial sums for this batch auto minClusterAndDistance_const = - raft::make_device_vector_view, IndexT>( + raft::make_device_vector_view, IdxT>( minClusterAndDistance.data_handle(), current_batch_size); - accumulate_batch_centroids(handle, + accumulate_batch_centroids(handle, batch_data_view, minClusterAndDistance_const, batch_weights_view, @@ -435,14 +433,14 @@ void fit(raft::resources const& handle, } // end batch loop // Finalize centroids: divide sums by counts - auto centroids_const = raft::make_device_matrix_view( + auto centroids_const = raft::make_device_matrix_view( centroids.data_handle(), n_clusters, n_features); - auto centroid_sums_const = raft::make_device_matrix_view( + auto centroid_sums_const = raft::make_device_matrix_view( centroid_sums.data_handle(), n_clusters, n_features); - auto cluster_counts_const = raft::make_device_vector_view( + auto cluster_counts_const = raft::make_device_vector_view( cluster_counts.data_handle(), n_clusters); - finalize_centroids( + finalize_centroids( handle, centroid_sums_const, cluster_counts_const, centroids_const, new_centroids.view()); // Compute squared norm of change in centroids @@ -481,18 +479,18 @@ void fit(raft::resources const& handle, // Compute final inertia by processing all data once more inertia[0] = 0; - for (IndexT offset = 0; offset < n_samples; offset += batch_size) { - IndexT current_batch_size = std::min(batch_size, n_samples - offset); + for (IdxT offset = 0; offset < n_samples; offset += batch_size) { + IdxT current_batch_size = std::min(batch_size, n_samples - offset); raft::copy(batch_data.data_handle(), X.data_handle() + offset * n_features, current_batch_size * n_features, stream); - auto batch_data_view = raft::make_device_matrix_view( + auto batch_data_view = raft::make_device_matrix_view( batch_data.data_handle(), current_batch_size, n_features); auto minClusterAndDistance_view = - raft::make_device_vector_view, IndexT>( + raft::make_device_vector_view, IdxT>( minClusterAndDistance.data_handle(), current_batch_size); if (metric == cuvs::distance::DistanceType::L2Expanded || @@ -501,12 +499,12 @@ void fit(raft::resources const& handle, L2NormBatch.data_handle(), batch_data.data_handle(), n_features, current_batch_size, stream); } - auto centroids_const = raft::make_device_matrix_view( + auto centroids_const = raft::make_device_matrix_view( centroids.data_handle(), n_clusters, n_features); - auto L2NormBatch_const = raft::make_device_vector_view( + auto L2NormBatch_const = raft::make_device_vector_view( L2NormBatch.data_handle(), current_batch_size); - cuvs::cluster::kmeans::detail::minClusterAndDistanceCompute( + cuvs::cluster::kmeans::detail::minClusterAndDistanceCompute( handle, batch_data_view, centroids_const, From 0fa00b052dd2d5372f703fb5c46d0fe7f87cd484 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Fri, 9 Jan 2026 19:09:38 -0800 Subject: [PATCH 3/7] copyright --- cpp/src/cluster/detail/kmeans_batched.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/src/cluster/detail/kmeans_batched.cuh b/cpp/src/cluster/detail/kmeans_batched.cuh index 1999d80c69..de44b309ba 100644 --- a/cpp/src/cluster/detail/kmeans_batched.cuh +++ b/cpp/src/cluster/detail/kmeans_batched.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ #pragma once From fcbdda59917eaaf8249355ffbd93dc91e917bcd8 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Tue, 3 Feb 2026 02:11:01 -0800 Subject: [PATCH 4/7] python test --- python/cuvs/cuvs/tests/test_kmeans.py | 80 +++++++++++++++++++++++++++ 1 file changed, 80 insertions(+) diff --git a/python/cuvs/cuvs/tests/test_kmeans.py b/python/cuvs/cuvs/tests/test_kmeans.py index 6f18137b13..cc3b1cf4a4 100644 --- a/python/cuvs/cuvs/tests/test_kmeans.py +++ b/python/cuvs/cuvs/tests/test_kmeans.py @@ -69,3 +69,83 @@ def test_cluster_cost(n_rows, n_cols, n_clusters, dtype): # need reduced tolerance for float32 tol = 1e-3 if dtype == np.float32 else 1e-6 assert np.allclose(inertia, sum(cluster_distances), rtol=tol, atol=tol) + + +@pytest.mark.parametrize("n_rows", [1000]) +@pytest.mark.parametrize("n_cols", [10]) +@pytest.mark.parametrize("n_clusters", [8]) +@pytest.mark.parametrize("dtype", [np.float32]) +@pytest.mark.parametrize( + "batch_samples_list", + [ + [32, 64, 128, 256, 512], # various batch sizes + ], +) +def test_kmeans_batch_size_determinism( + n_rows, n_cols, n_clusters, dtype, batch_samples_list +): + """ + Test that different batch sizes produce identical centroids. + + When starting from the same initial centroids, the k-means algorithm + should produce identical final centroids regardless of the batch_samples + parameter. This is because the accumulated adjustments to centroids after + the entire dataset pass should be the same. + """ + # Use fixed seed for reproducibility + rng = np.random.default_rng(42) + + # Generate random data + X_host = rng.random((n_rows, n_cols)).astype(dtype) + X = device_ndarray(X_host) + + # Generate fixed initial centroids (using first n_clusters rows) + initial_centroids_host = X_host[:n_clusters].copy() + + # Store results from each batch size + results = [] + + for batch_samples in batch_samples_list: + # Create fresh copy of initial centroids for each run + centroids = device_ndarray(initial_centroids_host.copy()) + + params = KMeansParams( + n_clusters=n_clusters, + init_method="Array", # Use provided centroids + max_iter=100, + tol=1e-10, # Very small tolerance to ensure convergence + batch_samples=batch_samples, + ) + + centroids_out, inertia, n_iter = fit(params, X, centroids) + results.append( + { + "batch_samples": batch_samples, + "centroids": centroids_out.copy_to_host(), + "inertia": inertia, + "n_iter": n_iter, + } + ) + + # Compare all results against the first one + reference = results[0] + for result in results[1:]: + # Centroids should be identical (or very close due to float precision) + assert np.allclose( + reference["centroids"], + result["centroids"], + rtol=1e-5, + atol=1e-5, + ), ( + f"Centroids differ between batch_samples=" + f"{reference['batch_samples']} and {result['batch_samples']}" + ) + + # Inertia should also be identical + assert np.allclose( + reference["inertia"], result["inertia"], rtol=1e-5, atol=1e-5 + ), ( + f"Inertia differs between batch_samples=" + f"{reference['batch_samples']} and {result['batch_samples']}: " + f"{reference['inertia']} vs {result['inertia']}" + ) From d6ed934577fffc6d24ed07053746d2411b3770ba Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Tue, 3 Feb 2026 02:31:38 -0800 Subject: [PATCH 5/7] minibatch first commit --- c/include/cuvs/cluster/kmeans.h | 26 +++ cpp/include/cuvs/cluster/kmeans.hpp | 28 ++- cpp/src/cluster/detail/kmeans_batched.cuh | 219 +++++++++++++++++---- python/cuvs/cuvs/cluster/kmeans/kmeans.pxd | 5 + python/cuvs/cuvs/cluster/kmeans/kmeans.pyx | 42 ++++ 5 files changed, 281 insertions(+), 39 deletions(-) diff --git a/c/include/cuvs/cluster/kmeans.h b/c/include/cuvs/cluster/kmeans.h index 0bb9591f63..79448af26f 100644 --- a/c/include/cuvs/cluster/kmeans.h +++ b/c/include/cuvs/cluster/kmeans.h @@ -36,6 +36,25 @@ typedef enum { Array = 2 } cuvsKMeansInitMethod; +/** + * @brief Centroid update mode for k-means algorithm + */ +typedef enum { + /** + * Standard k-means (Lloyd's algorithm): accumulate assignments over the + * entire dataset, then update centroids once per iteration. + * More accurate but requires full pass over data before each update. + */ + CUVS_KMEANS_UPDATE_FULL_BATCH = 0, + + /** + * Mini-batch k-means: update centroids after each randomly sampled batch. + * Faster convergence for large datasets, but may have slightly lower accuracy. + * Uses streaming/online centroid updates with learning rate decay. + */ + CUVS_KMEANS_UPDATE_MINI_BATCH = 1 +} cuvsKMeansCentroidUpdateMode; + /** * @brief Hyper-parameters for the kmeans algorithm */ @@ -90,6 +109,13 @@ struct cuvsKMeansParams { */ int batch_centroids; + /** + * Centroid update mode: + * - CUVS_KMEANS_UPDATE_FULL_BATCH: Standard Lloyd's algorithm, update after full dataset pass + * - CUVS_KMEANS_UPDATE_MINI_BATCH: Mini-batch k-means, update after each batch + */ + cuvsKMeansCentroidUpdateMode update_mode; + bool inertia_check; /** diff --git a/cpp/include/cuvs/cluster/kmeans.hpp b/cpp/include/cuvs/cluster/kmeans.hpp index e141947dad..4f702d74ca 100644 --- a/cpp/include/cuvs/cluster/kmeans.hpp +++ b/cpp/include/cuvs/cluster/kmeans.hpp @@ -48,6 +48,25 @@ struct params : base_params { Array }; + /** + * Centroid update mode determines when centroids are updated during training. + */ + enum CentroidUpdateMode { + /** + * Standard k-means (Lloyd's algorithm): accumulate assignments over the + * entire dataset, then update centroids once per iteration. + * More accurate but requires full pass over data before each update. + */ + FullBatch, + + /** + * Mini-batch k-means: update centroids after each randomly sampled batch. + * Faster convergence for large datasets, but may have slightly lower accuracy. + * Uses streaming/online centroid updates with learning rate decay. + */ + MiniBatch + }; + /** * The number of clusters to form as well as the number of centroids to generate (default:8). */ @@ -104,7 +123,14 @@ struct params : base_params { /** * if 0 then batch_centroids = n_clusters */ - int batch_centroids = 0; // + int batch_centroids = 0; + + /** + * Centroid update mode: + * - FullBatch: Standard Lloyd's algorithm, update centroids after full dataset pass + * - MiniBatch: Mini-batch k-means, update centroids after each batch + */ + CentroidUpdateMode update_mode = FullBatch; bool inertia_check = false; }; diff --git a/cpp/src/cluster/detail/kmeans_batched.cuh b/cpp/src/cluster/detail/kmeans_batched.cuh index de44b309ba..c622c9b5a3 100644 --- a/cpp/src/cluster/detail/kmeans_batched.cuh +++ b/cpp/src/cluster/detail/kmeans_batched.cuh @@ -4,6 +4,7 @@ */ #pragma once +#include "kmeans.cuh" #include "kmeans_common.cuh" #include @@ -36,6 +37,8 @@ #include #include +#include +#include #include #include @@ -202,6 +205,78 @@ void accumulate_batch_centroids( stream); } +/** + * @brief Update centroids using mini-batch online learning + * + * Uses the online update formula: + * learning_rate[k] = batch_count[k] / (total_count[k] + batch_count[k]) + * centroid[k] = centroid[k] + learning_rate[k] * (batch_mean[k] - centroid[k]) + * + * This is equivalent to a weighted average where total_count tracks cumulative weight. + */ +template +void minibatch_update_centroids(raft::resources const& handle, + raft::device_matrix_view centroids, + raft::device_matrix_view batch_sums, + raft::device_vector_view batch_counts, + raft::device_vector_view total_counts) +{ + auto n_clusters = centroids.extent(0); + auto n_features = centroids.extent(1); + + // Compute batch means: batch_mean = batch_sums / batch_counts + auto batch_means = raft::make_device_matrix(handle, n_clusters, n_features); + raft::copy(batch_means.data_handle(), + batch_sums.data_handle(), + batch_sums.size(), + raft::resource::get_cuda_stream(handle)); + + raft::linalg::matrix_vector_op( + handle, + raft::make_const_mdspan(batch_means.view()), + batch_counts, + batch_means.view(), + raft::div_checkzero_op{}); + + // Step 1: Update total_counts = total_counts + batch_counts + raft::linalg::add(handle, + raft::make_const_mdspan(total_counts), + batch_counts, + total_counts); + + // Step 2: Compute learning rates: lr = batch_count / total_count (after update) + auto learning_rates = raft::make_device_vector(handle, n_clusters); + raft::linalg::map(handle, + learning_rates.view(), + raft::div_checkzero_op{}, + batch_counts, + raft::make_const_mdspan(total_counts)); + + // Update centroids: centroid = centroid + lr * (batch_mean - centroid) + // = (1 - lr) * centroid + lr * batch_mean + // Using matrix_vector_op to scale each row by (1 - lr), then add lr * batch_mean + raft::linalg::matrix_vector_op( + handle, + raft::make_const_mdspan(centroids), + raft::make_const_mdspan(learning_rates.view()), + centroids, + [] __device__(DataT centroid_val, DataT lr) { return (DataT{1} - lr) * centroid_val; }); + + // Add lr * batch_mean to centroids + raft::linalg::matrix_vector_op( + handle, + raft::make_const_mdspan(batch_means.view()), + raft::make_const_mdspan(learning_rates.view()), + batch_means.view(), + [] __device__(DataT mean_val, DataT lr) { return lr * mean_val; }); + + // centroids += lr * batch_means + raft::linalg::add(handle, + raft::make_const_mdspan(centroids), + raft::make_const_mdspan(batch_means.view()), + centroids); +} + /** * @brief Finalize centroids by dividing accumulated sums by counts */ @@ -313,11 +388,14 @@ void fit(raft::resources const& handle, // Temporary buffer for distance computation rmm::device_uvector L2NormBuf_OR_DistBuf(0, stream); - // Accumulators for centroid computation (persist across batches within an iteration) + // Accumulators for centroid computation auto centroid_sums = raft::make_device_matrix(handle, n_clusters, n_features); auto cluster_counts = raft::make_device_vector(handle, n_clusters); auto new_centroids = raft::make_device_matrix(handle, n_clusters, n_features); + // For mini-batch mode: track total counts for learning rate calculation + auto total_counts = raft::make_device_vector(handle, n_clusters); + // Host buffer for batch data (pinned memory for faster H2D transfer) std::vector host_batch_buffer(batch_size * n_features); std::vector host_weight_buffer(batch_size); @@ -326,38 +404,83 @@ void fit(raft::resources const& handle, rmm::device_scalar clusterCostD(stream); DataT priorClusteringCost = 0; + // Check update mode + bool use_minibatch = + (params.update_mode == cuvs::cluster::kmeans::params::CentroidUpdateMode::MiniBatch); + + RAFT_LOG_DEBUG("KMeans batched: update_mode=%s", use_minibatch ? "MiniBatch" : "FullBatch"); + + // For mini-batch mode with random sampling, create index shuffle + std::vector sample_indices(n_samples); + std::iota(sample_indices.begin(), sample_indices.end(), 0); + std::mt19937 rng(params.rng_state.seed); + // Main iteration loop for (n_iter[0] = 1; n_iter[0] <= params.max_iter; ++n_iter[0]) { RAFT_LOG_DEBUG("KMeans batched: Iteration %d", n_iter[0]); - // Zero accumulators at start of each iteration - thrust::fill(raft::resource::get_thrust_policy(handle), - centroid_sums.data_handle(), - centroid_sums.data_handle() + centroid_sums.size(), - DataT{0}); - thrust::fill(raft::resource::get_thrust_policy(handle), - cluster_counts.data_handle(), - cluster_counts.data_handle() + cluster_counts.size(), - DataT{0}); + // For full-batch mode: zero accumulators at start of each iteration + // For mini-batch mode: zero total_counts at start of each iteration + if (!use_minibatch) { + thrust::fill(raft::resource::get_thrust_policy(handle), + centroid_sums.data_handle(), + centroid_sums.data_handle() + centroid_sums.size(), + DataT{0}); + thrust::fill(raft::resource::get_thrust_policy(handle), + cluster_counts.data_handle(), + cluster_counts.data_handle() + cluster_counts.size(), + DataT{0}); + } else { + // Mini-batch mode: zero total counts for learning rate calculation + thrust::fill(raft::resource::get_thrust_policy(handle), + total_counts.data_handle(), + total_counts.data_handle() + total_counts.size(), + DataT{0}); + // Shuffle sample indices for random batch selection + std::shuffle(sample_indices.begin(), sample_indices.end(), rng); + } + + // Save old centroids for convergence check + raft::copy(new_centroids.data_handle(), centroids.data_handle(), centroids.size(), stream); DataT total_cost = 0; // Process all data in batches - for (IdxT offset = 0; offset < n_samples; offset += batch_size) { - IdxT current_batch_size = std::min(batch_size, n_samples - offset); + for (IdxT batch_idx = 0; batch_idx < n_samples; batch_idx += batch_size) { + IdxT current_batch_size = std::min(batch_size, n_samples - batch_idx); // Copy batch data from host to device - raft::copy(batch_data.data_handle(), - X.data_handle() + offset * n_features, - current_batch_size * n_features, - stream); + if (use_minibatch) { + // Mini-batch: use shuffled indices for random sampling + for (IdxT i = 0; i < current_batch_size; ++i) { + IdxT sample_idx = sample_indices[batch_idx + i]; + std::memcpy(host_batch_buffer.data() + i * n_features, + X.data_handle() + sample_idx * n_features, + n_features * sizeof(DataT)); + } + raft::copy( + batch_data.data_handle(), host_batch_buffer.data(), current_batch_size * n_features, stream); + } else { + // Full-batch: sequential access + raft::copy(batch_data.data_handle(), + X.data_handle() + batch_idx * n_features, + current_batch_size * n_features, + stream); + } // Copy or set weights for this batch if (sample_weight) { - raft::copy(batch_weights.data_handle(), - sample_weight->data_handle() + offset, - current_batch_size, - stream); + if (use_minibatch) { + for (IdxT i = 0; i < current_batch_size; ++i) { + host_weight_buffer[i] = sample_weight->data_handle()[sample_indices[batch_idx + i]]; + } + raft::copy(batch_weights.data_handle(), host_weight_buffer.data(), current_batch_size, stream); + } else { + raft::copy(batch_weights.data_handle(), + sample_weight->data_handle() + batch_idx, + current_batch_size, + stream); + } } else { thrust::fill(raft::resource::get_thrust_policy(handle), batch_weights.data_handle(), @@ -373,8 +496,6 @@ void fit(raft::resources const& handle, auto minClusterAndDistance_view = raft::make_device_vector_view, IdxT>( minClusterAndDistance.data_handle(), current_batch_size); - auto L2NormBatch_view = raft::make_device_vector_view( - L2NormBatch.data_handle(), current_batch_size); // Compute L2 norms for batch if needed if (metric == cuvs::distance::DistanceType::L2Expanded || @@ -410,6 +531,18 @@ void fit(raft::resources const& handle, raft::make_device_vector_view, IdxT>( minClusterAndDistance.data_handle(), current_batch_size); + if (use_minibatch) { + // Mini-batch mode: zero batch accumulators before each batch + thrust::fill(raft::resource::get_thrust_policy(handle), + centroid_sums.data_handle(), + centroid_sums.data_handle() + centroid_sums.size(), + DataT{0}); + thrust::fill(raft::resource::get_thrust_policy(handle), + cluster_counts.data_handle(), + cluster_counts.data_handle() + cluster_counts.size(), + DataT{0}); + } + accumulate_batch_centroids(handle, batch_data_view, minClusterAndDistance_const, @@ -418,6 +551,17 @@ void fit(raft::resources const& handle, cluster_counts.view(), workspace); + if (use_minibatch) { + // Mini-batch mode: update centroids immediately after each batch + auto centroid_sums_const = raft::make_device_matrix_view( + centroid_sums.data_handle(), n_clusters, n_features); + auto cluster_counts_const = raft::make_device_vector_view( + cluster_counts.data_handle(), n_clusters); + + minibatch_update_centroids( + handle, centroids, centroid_sums_const, cluster_counts_const, total_counts.view()); + } + // Accumulate cluster cost if checking convergence if (params.inertia_check) { cuvs::cluster::kmeans::detail::computeClusterCost( @@ -432,32 +576,31 @@ void fit(raft::resources const& handle, } } // end batch loop - // Finalize centroids: divide sums by counts - auto centroids_const = raft::make_device_matrix_view( - centroids.data_handle(), n_clusters, n_features); - auto centroid_sums_const = raft::make_device_matrix_view( - centroid_sums.data_handle(), n_clusters, n_features); - auto cluster_counts_const = raft::make_device_vector_view( - cluster_counts.data_handle(), n_clusters); + if (!use_minibatch) { + // Full-batch mode: finalize centroids after processing all batches + auto centroids_const = raft::make_device_matrix_view( + centroids.data_handle(), n_clusters, n_features); + auto centroid_sums_const = raft::make_device_matrix_view( + centroid_sums.data_handle(), n_clusters, n_features); + auto cluster_counts_const = raft::make_device_vector_view( + cluster_counts.data_handle(), n_clusters); - finalize_centroids( - handle, centroid_sums_const, cluster_counts_const, centroids_const, new_centroids.view()); + finalize_centroids( + handle, centroid_sums_const, cluster_counts_const, centroids_const, centroids); + } - // Compute squared norm of change in centroids + // Compute squared norm of change in centroids (compare to saved old centroids) auto sqrdNorm = raft::make_device_scalar(handle, DataT{0}); raft::linalg::mapThenSumReduce(sqrdNorm.data_handle(), - new_centroids.size(), + centroids.size(), raft::sqdiff_op{}, stream, - centroids.data_handle(), - new_centroids.data_handle()); + new_centroids.data_handle(), // old centroids + centroids.data_handle()); // new centroids DataT sqrdNormError = 0; raft::copy(&sqrdNormError, sqrdNorm.data_handle(), 1, stream); - // Update centroids - raft::copy(centroids.data_handle(), new_centroids.data_handle(), new_centroids.size(), stream); - // Check convergence bool done = false; if (params.inertia_check) { diff --git a/python/cuvs/cuvs/cluster/kmeans/kmeans.pxd b/python/cuvs/cuvs/cluster/kmeans/kmeans.pxd index 9f16d46c4d..d219f4e903 100644 --- a/python/cuvs/cuvs/cluster/kmeans/kmeans.pxd +++ b/python/cuvs/cuvs/cluster/kmeans/kmeans.pxd @@ -18,6 +18,10 @@ cdef extern from "cuvs/cluster/kmeans.h" nogil: Random Array + ctypedef enum cuvsKMeansCentroidUpdateMode: + CUVS_KMEANS_UPDATE_FULL_BATCH + CUVS_KMEANS_UPDATE_MINI_BATCH + ctypedef enum cuvsKMeansType: CUVS_KMEANS_TYPE_KMEANS CUVS_KMEANS_TYPE_KMEANS_BALANCED @@ -32,6 +36,7 @@ cdef extern from "cuvs/cluster/kmeans.h" nogil: double oversampling_factor, int batch_samples, int batch_centroids, + cuvsKMeansCentroidUpdateMode update_mode, bool inertia_check, bool hierarchical, int hierarchical_n_iters diff --git a/python/cuvs/cuvs/cluster/kmeans/kmeans.pyx b/python/cuvs/cuvs/cluster/kmeans/kmeans.pyx index 489d983ac7..b8ee467fb5 100644 --- a/python/cuvs/cuvs/cluster/kmeans/kmeans.pyx +++ b/python/cuvs/cuvs/cluster/kmeans/kmeans.pyx @@ -44,6 +44,12 @@ INIT_METHOD_TYPES = { INIT_METHOD_NAMES = {v: k for k, v in INIT_METHOD_TYPES.items()} +UPDATE_MODE_TYPES = { + "full_batch": cuvsKMeansCentroidUpdateMode.CUVS_KMEANS_UPDATE_FULL_BATCH, + "mini_batch": cuvsKMeansCentroidUpdateMode.CUVS_KMEANS_UPDATE_MINI_BATCH} + +UPDATE_MODE_NAMES = {v: k for k, v in UPDATE_MODE_TYPES.items()} + cdef class KMeansParams: """ Hyper-parameters for the kmeans algorithm @@ -70,6 +76,20 @@ cdef class KMeansParams: Number of instance k-means algorithm will be run with different seeds oversampling_factor : double Oversampling factor for use in the k-means|| algorithm + batch_samples : int + Number of samples to process in each batch for tiled 1NN computation. + Useful to optimize/control memory footprint. Default tile is + [batch_samples x n_clusters]. + batch_centroids : int + Number of centroids to process in each batch. If 0, uses n_clusters. + update_mode : str + Centroid update strategy. One of: + "full_batch" : Standard Lloyd's algorithm - accumulate assignments over + the entire dataset, then update centroids once per iteration. + More accurate but requires full pass over data before each update. + "mini_batch" : Mini-batch k-means - update centroids after each batch. + Faster convergence for large datasets, but may have slightly lower + accuracy. Uses online centroid updates with learning rate decay. hierarchical : bool Whether to use hierarchical (balanced) kmeans or not hierarchical_n_iters : int @@ -92,6 +112,9 @@ cdef class KMeansParams: tol=None, n_init=None, oversampling_factor=None, + batch_samples=None, + batch_centroids=None, + update_mode=None, hierarchical=None, hierarchical_n_iters=None): if metric is not None: @@ -109,6 +132,13 @@ cdef class KMeansParams: self.params.n_init = n_init if oversampling_factor is not None: self.params.oversampling_factor = oversampling_factor + if batch_samples is not None: + self.params.batch_samples = batch_samples + if batch_centroids is not None: + self.params.batch_centroids = batch_centroids + if update_mode is not None: + c_mode = UPDATE_MODE_TYPES[update_mode] + self.params.update_mode = c_mode if hierarchical is not None: self.params.hierarchical = hierarchical if hierarchical_n_iters is not None: @@ -145,6 +175,18 @@ cdef class KMeansParams: def oversampling_factor(self): return self.params.oversampling_factor + @property + def batch_samples(self): + return self.params.batch_samples + + @property + def batch_centroids(self): + return self.params.batch_centroids + + @property + def update_mode(self): + return UPDATE_MODE_NAMES[self.params.update_mode] + @property def hierarchical(self): return self.params.hierarchical From 5d4b4985b645640cbcfed514687d53ab57978216 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Tue, 3 Feb 2026 02:42:28 -0800 Subject: [PATCH 6/7] fix docs --- c/include/cuvs/cluster/kmeans.h | 3 --- cpp/include/cuvs/cluster/kmeans.hpp | 3 --- python/cuvs/cuvs/cluster/kmeans/kmeans.pyx | 3 --- 3 files changed, 9 deletions(-) diff --git a/c/include/cuvs/cluster/kmeans.h b/c/include/cuvs/cluster/kmeans.h index 79448af26f..e39b9c8523 100644 --- a/c/include/cuvs/cluster/kmeans.h +++ b/c/include/cuvs/cluster/kmeans.h @@ -43,14 +43,11 @@ typedef enum { /** * Standard k-means (Lloyd's algorithm): accumulate assignments over the * entire dataset, then update centroids once per iteration. - * More accurate but requires full pass over data before each update. */ CUVS_KMEANS_UPDATE_FULL_BATCH = 0, /** * Mini-batch k-means: update centroids after each randomly sampled batch. - * Faster convergence for large datasets, but may have slightly lower accuracy. - * Uses streaming/online centroid updates with learning rate decay. */ CUVS_KMEANS_UPDATE_MINI_BATCH = 1 } cuvsKMeansCentroidUpdateMode; diff --git a/cpp/include/cuvs/cluster/kmeans.hpp b/cpp/include/cuvs/cluster/kmeans.hpp index 4f702d74ca..7ce492b136 100644 --- a/cpp/include/cuvs/cluster/kmeans.hpp +++ b/cpp/include/cuvs/cluster/kmeans.hpp @@ -55,14 +55,11 @@ struct params : base_params { /** * Standard k-means (Lloyd's algorithm): accumulate assignments over the * entire dataset, then update centroids once per iteration. - * More accurate but requires full pass over data before each update. */ FullBatch, /** * Mini-batch k-means: update centroids after each randomly sampled batch. - * Faster convergence for large datasets, but may have slightly lower accuracy. - * Uses streaming/online centroid updates with learning rate decay. */ MiniBatch }; diff --git a/python/cuvs/cuvs/cluster/kmeans/kmeans.pyx b/python/cuvs/cuvs/cluster/kmeans/kmeans.pyx index b8ee467fb5..7b0cb4b3a2 100644 --- a/python/cuvs/cuvs/cluster/kmeans/kmeans.pyx +++ b/python/cuvs/cuvs/cluster/kmeans/kmeans.pyx @@ -86,10 +86,7 @@ cdef class KMeansParams: Centroid update strategy. One of: "full_batch" : Standard Lloyd's algorithm - accumulate assignments over the entire dataset, then update centroids once per iteration. - More accurate but requires full pass over data before each update. "mini_batch" : Mini-batch k-means - update centroids after each batch. - Faster convergence for large datasets, but may have slightly lower - accuracy. Uses online centroid updates with learning rate decay. hierarchical : bool Whether to use hierarchical (balanced) kmeans or not hierarchical_n_iters : int From 72fe7892609bf6faa279381aa901034d163dcb4e Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Tue, 3 Feb 2026 03:36:16 -0800 Subject: [PATCH 7/7] replace thrust calls: --- cpp/src/cluster/detail/kmeans_batched.cuh | 54 +++++++---------------- 1 file changed, 16 insertions(+), 38 deletions(-) diff --git a/cpp/src/cluster/detail/kmeans_batched.cuh b/cpp/src/cluster/detail/kmeans_batched.cuh index c622c9b5a3..4db84b12c0 100644 --- a/cpp/src/cluster/detail/kmeans_batched.cuh +++ b/cpp/src/cluster/detail/kmeans_batched.cuh @@ -140,27 +140,24 @@ void accumulate_batch_centroids( raft::device_vector_view, IdxT> minClusterAndDistance, raft::device_vector_view sample_weights, raft::device_matrix_view centroid_sums, - raft::device_vector_view cluster_counts, - rmm::device_uvector& workspace) + raft::device_vector_view cluster_counts) { cudaStream_t stream = raft::resource::get_cuda_stream(handle); auto n_samples = batch_data.extent(0); auto n_features = batch_data.extent(1); auto n_clusters = centroid_sums.extent(0); + // Get workspace from handle + auto* workspace_resource = raft::resource::get_workspace_resource(handle); + auto workspace = rmm::device_uvector(n_samples, stream, workspace_resource); + // Temporary buffers for this batch's partial results auto batch_sums = raft::make_device_matrix(handle, n_clusters, n_features); auto batch_counts = raft::make_device_vector(handle, n_clusters); // Zero the batch temporaries - thrust::fill(raft::resource::get_thrust_policy(handle), - batch_sums.data_handle(), - batch_sums.data_handle() + batch_sums.size(), - DataT{0}); - thrust::fill(raft::resource::get_thrust_policy(handle), - batch_counts.data_handle(), - batch_counts.data_handle() + batch_counts.size(), - DataT{0}); + raft::matrix::fill(handle, batch_sums.view(), DataT{0}); + raft::matrix::fill(handle, batch_counts.view(), DataT{0}); // Extract cluster labels from KeyValuePair cuvs::cluster::kmeans::detail::KeyValueIndexOp conversion_op; @@ -168,8 +165,6 @@ void accumulate_batch_centroids( const raft::KeyValuePair*> labels_itr(minClusterAndDistance.data_handle(), conversion_op); - workspace.resize(n_samples, stream); - // Compute weighted sum of samples per cluster for this batch raft::linalg::reduce_rows_by_key(const_cast(batch_data.data_handle()), batch_data.extent(1), @@ -422,20 +417,11 @@ void fit(raft::resources const& handle, // For full-batch mode: zero accumulators at start of each iteration // For mini-batch mode: zero total_counts at start of each iteration if (!use_minibatch) { - thrust::fill(raft::resource::get_thrust_policy(handle), - centroid_sums.data_handle(), - centroid_sums.data_handle() + centroid_sums.size(), - DataT{0}); - thrust::fill(raft::resource::get_thrust_policy(handle), - cluster_counts.data_handle(), - cluster_counts.data_handle() + cluster_counts.size(), - DataT{0}); + raft::matrix::fill(handle, centroid_sums.view(), DataT{0}); + raft::matrix::fill(handle, cluster_counts.view(), DataT{0}); } else { // Mini-batch mode: zero total counts for learning rate calculation - thrust::fill(raft::resource::get_thrust_policy(handle), - total_counts.data_handle(), - total_counts.data_handle() + total_counts.size(), - DataT{0}); + raft::matrix::fill(handle, total_counts.view(), DataT{0}); // Shuffle sample indices for random batch selection std::shuffle(sample_indices.begin(), sample_indices.end(), rng); } @@ -482,10 +468,9 @@ void fit(raft::resources const& handle, stream); } } else { - thrust::fill(raft::resource::get_thrust_policy(handle), - batch_weights.data_handle(), - batch_weights.data_handle() + current_batch_size, - DataT{1}); + auto batch_weights_fill_view = raft::make_device_vector_view( + batch_weights.data_handle(), current_batch_size); + raft::matrix::fill(handle, batch_weights_fill_view, DataT{1}); } // Create views for current batch size @@ -533,14 +518,8 @@ void fit(raft::resources const& handle, if (use_minibatch) { // Mini-batch mode: zero batch accumulators before each batch - thrust::fill(raft::resource::get_thrust_policy(handle), - centroid_sums.data_handle(), - centroid_sums.data_handle() + centroid_sums.size(), - DataT{0}); - thrust::fill(raft::resource::get_thrust_policy(handle), - cluster_counts.data_handle(), - cluster_counts.data_handle() + cluster_counts.size(), - DataT{0}); + raft::matrix::fill(handle, centroid_sums.view(), DataT{0}); + raft::matrix::fill(handle, cluster_counts.view(), DataT{0}); } accumulate_batch_centroids(handle, @@ -548,8 +527,7 @@ void fit(raft::resources const& handle, minClusterAndDistance_const, batch_weights_view, centroid_sums.view(), - cluster_counts.view(), - workspace); + cluster_counts.view()); if (use_minibatch) { // Mini-batch mode: update centroids immediately after each batch