diff --git a/c/include/cuvs/cluster/kmeans.h b/c/include/cuvs/cluster/kmeans.h index 0bb9591f63..e39b9c8523 100644 --- a/c/include/cuvs/cluster/kmeans.h +++ b/c/include/cuvs/cluster/kmeans.h @@ -36,6 +36,22 @@ typedef enum { Array = 2 } cuvsKMeansInitMethod; +/** + * @brief Centroid update mode for k-means algorithm + */ +typedef enum { + /** + * Standard k-means (Lloyd's algorithm): accumulate assignments over the + * entire dataset, then update centroids once per iteration. + */ + CUVS_KMEANS_UPDATE_FULL_BATCH = 0, + + /** + * Mini-batch k-means: update centroids after each randomly sampled batch. + */ + CUVS_KMEANS_UPDATE_MINI_BATCH = 1 +} cuvsKMeansCentroidUpdateMode; + /** * @brief Hyper-parameters for the kmeans algorithm */ @@ -90,6 +106,13 @@ struct cuvsKMeansParams { */ int batch_centroids; + /** + * Centroid update mode: + * - CUVS_KMEANS_UPDATE_FULL_BATCH: Standard Lloyd's algorithm, update after full dataset pass + * - CUVS_KMEANS_UPDATE_MINI_BATCH: Mini-batch k-means, update after each batch + */ + cuvsKMeansCentroidUpdateMode update_mode; + bool inertia_check; /** diff --git a/cpp/include/cuvs/cluster/kmeans.hpp b/cpp/include/cuvs/cluster/kmeans.hpp index 22c8e056ec..7ce492b136 100644 --- a/cpp/include/cuvs/cluster/kmeans.hpp +++ b/cpp/include/cuvs/cluster/kmeans.hpp @@ -48,6 +48,22 @@ struct params : base_params { Array }; + /** + * Centroid update mode determines when centroids are updated during training. + */ + enum CentroidUpdateMode { + /** + * Standard k-means (Lloyd's algorithm): accumulate assignments over the + * entire dataset, then update centroids once per iteration. + */ + FullBatch, + + /** + * Mini-batch k-means: update centroids after each randomly sampled batch. + */ + MiniBatch + }; + /** * The number of clusters to form as well as the number of centroids to generate (default:8). 
*/ @@ -104,7 +120,14 @@ struct params : base_params { /** * if 0 then batch_centroids = n_clusters */ - int batch_centroids = 0; // + int batch_centroids = 0; + + /** + * Centroid update mode: + * - FullBatch: Standard Lloyd's algorithm, update centroids after full dataset pass + * - MiniBatch: Mini-batch k-means, update centroids after each batch + */ + CentroidUpdateMode update_mode = FullBatch; bool inertia_check = false; }; @@ -139,6 +162,147 @@ enum class kmeans_type { KMeans = 0, KMeansBalanced = 1 }; * @{ */ +/** + * @defgroup kmeans_batched Batched k-means for out-of-core / host data + * @{ + */ + +/** + * @brief Find clusters with k-means algorithm using batched processing. + * + * This version supports out-of-core computation where the dataset resides + * on the host. In FullBatch update mode, partial sums are accumulated across + * batches and centroids are finalized once per iteration (equivalent to a + * standard kmeans update); in MiniBatch mode centroids are updated after each batch. + * + * @code{.cpp} + * #include + * #include + * using namespace cuvs::cluster; + * ... + * raft::resources handle; + * cuvs::cluster::kmeans::params params; + * int n_features = 15; + * float inertia; + * int n_iter; + * + * // Data on host + * std::vector h_X(n_samples * n_features); + * auto X = raft::make_host_matrix_view(h_X.data(), n_samples, n_features); + * + * // Centroids on device + * auto centroids = raft::make_device_matrix(handle, params.n_clusters, n_features); + * + * kmeans::fit_batched(handle, + * params, + * X, + * 100000, // batch_size + * std::nullopt, + * centroids.view(), + * raft::make_host_scalar_view(&inertia), + * raft::make_host_scalar_view(&n_iter)); + * @endcode + * + * @param[in] handle The raft handle. + * @param[in] params Parameters for KMeans model. + * @param[in] X Training instances on HOST memory. The data must + * be in row-major format. + * [dim = n_samples x n_features] + * @param[in] batch_size Number of samples to process per batch. 
+ * @param[in] sample_weight Optional weights for each observation in X (on host). + * [len = n_samples] + * @param[inout] centroids [in] When init is InitMethod::Array, use + * centroids as the initial cluster centers. + * [out] The generated centroids from the + * kmeans algorithm are stored at the address + * pointed by 'centroids'. + * [dim = n_clusters x n_features] + * @param[out] inertia Sum of squared distances of samples to their + * closest cluster center. + * @param[out] n_iter Number of iterations run. + */ +void fit_batched(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::host_matrix_view X, + int batch_size, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter); + +/** + * @brief Find clusters with k-means algorithm using batched processing. + * + * @param[in] handle The raft handle. + * @param[in] params Parameters for KMeans model. + * @param[in] X Training instances on HOST memory. + * [dim = n_samples x n_features] + * @param[in] batch_size Number of samples to process per batch. + * @param[in] sample_weight Optional weights for each observation in X (on host). + * @param[inout] centroids Cluster centers on device. + * [dim = n_clusters x n_features] + * @param[out] inertia Sum of squared distances to nearest centroid. + * @param[out] n_iter Number of iterations run. + */ +void fit_batched(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::host_matrix_view X, + int64_t batch_size, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter); + +/** + * @brief Find clusters with k-means algorithm using batched processing. + * + * @param[in] handle The raft handle. + * @param[in] params Parameters for KMeans model. + * @param[in] X Training instances on HOST memory. 
+ * [dim = n_samples x n_features] + * @param[in] batch_size Number of samples to process per batch. + * @param[in] sample_weight Optional weights for each observation in X (on host). + * @param[inout] centroids Cluster centers on device. + * [dim = n_clusters x n_features] + * @param[out] inertia Sum of squared distances to nearest centroid. + * @param[out] n_iter Number of iterations run. + */ +void fit_batched(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::host_matrix_view X, + int batch_size, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter); + +/** + * @brief Find clusters with k-means algorithm using batched processing. + * + * @param[in] handle The raft handle. + * @param[in] params Parameters for KMeans model. + * @param[in] X Training instances on HOST memory. + * [dim = n_samples x n_features] + * @param[in] batch_size Number of samples to process per batch. + * @param[in] sample_weight Optional weights for each observation in X (on host). + * @param[inout] centroids Cluster centers on device. + * [dim = n_clusters x n_features] + * @param[out] inertia Sum of squared distances to nearest centroid. + * @param[out] n_iter Number of iterations run. + */ +void fit_batched(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::host_matrix_view X, + int64_t batch_size, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter); + +/** + * @} + */ + /** * @brief Find clusters with k-means algorithm. * Initial centroids are chosen with k-means++ algorithm. 
Empty diff --git a/cpp/src/cluster/detail/kmeans_batched.cuh b/cpp/src/cluster/detail/kmeans_batched.cuh new file mode 100644 index 0000000000..4db84b12c0 --- /dev/null +++ b/cpp/src/cluster/detail/kmeans_batched.cuh @@ -0,0 +1,655 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ +#pragma once + +#include "kmeans.cuh" +#include "kmeans_common.cuh" + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +namespace cuvs::cluster::kmeans::batched::detail { + +/** + * @brief Sample data from host to device for initialization + * + * Gathers `X_sample.extent(0)` rows from host data and copies them to device. + * Rows are chosen by a `seed`-driven random shuffle of all row indices (not strided). 
+ */ +template +void prepare_init_sample(raft::resources const& handle, + raft::host_matrix_view X, + raft::device_matrix_view X_sample, + uint64_t seed) +{ + cudaStream_t stream = raft::resource::get_cuda_stream(handle); + auto n_samples = X.extent(0); + auto n_features = X.extent(1); + auto n_samples_out = X_sample.extent(0); + + std::mt19937 gen(seed); + std::vector indices(n_samples); + std::iota(indices.begin(), indices.end(), 0); + std::shuffle(indices.begin(), indices.end(), gen); + + std::vector host_sample(n_samples_out * n_features); + +#pragma omp parallel for + for (IdxT i = 0; i < static_cast(n_samples_out); i++) { + IdxT src_idx = indices[i]; + std::memcpy(host_sample.data() + i * n_features, + X.data_handle() + src_idx * n_features, + n_features * sizeof(DataT)); + } + + raft::copy(X_sample.data_handle(), host_sample.data(), host_sample.size(), stream); +} + +/** + * @brief Initialize centroids (k-means++ or random rows) from a sample of the host data + */ +template +void init_centroids_from_host_sample(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::host_matrix_view X, + raft::device_matrix_view centroids, + rmm::device_uvector& workspace) +{ + cudaStream_t stream = raft::resource::get_cuda_stream(handle); + auto n_samples = X.extent(0); + auto n_features = X.extent(1); + auto n_clusters = params.n_clusters; + + // Sample size for initialization: max(3 * n_clusters, 10000), capped at n_samples + size_t init_sample_size = std::min(static_cast(n_samples), + std::max(static_cast(3 * n_clusters), + static_cast(10000))); + + RAFT_LOG_DEBUG("KMeans batched: sampling %zu points for initialization", init_sample_size); + + // Sample data from host to device + auto init_sample = raft::make_device_matrix(handle, init_sample_size, n_features); + prepare_init_sample(handle, X, init_sample.view(), params.rng_state.seed); + + // Run k-means++ on the sample + if (params.init == cuvs::cluster::kmeans::params::InitMethod::KMeansPlusPlus) 
{ + cuvs::cluster::kmeans::detail::kmeansPlusPlus( + handle, + params, + raft::make_device_matrix_view( + init_sample.data_handle(), init_sample_size, n_features), + centroids, + workspace); + } else if (params.init == cuvs::cluster::kmeans::params::InitMethod::Random) { + // Just use the first n_clusters samples + raft::copy(centroids.data_handle(), + init_sample.data_handle(), + n_clusters * n_features, + stream); + } else if (params.init == cuvs::cluster::kmeans::params::InitMethod::Array) { + // Centroids already provided, nothing to do + } else { + RAFT_FAIL("Unknown initialization method"); + } +} + +/** + * @brief Accumulate partial centroid sums and counts from a batch + * + * This function adds the partial sums from a batch to the running accumulators. + * It does NOT divide - that happens once at the end of all batches. + */ +template +void accumulate_batch_centroids( + raft::resources const& handle, + raft::device_matrix_view batch_data, + raft::device_vector_view, IdxT> minClusterAndDistance, + raft::device_vector_view sample_weights, + raft::device_matrix_view centroid_sums, + raft::device_vector_view cluster_counts) +{ + cudaStream_t stream = raft::resource::get_cuda_stream(handle); + auto n_samples = batch_data.extent(0); + auto n_features = batch_data.extent(1); + auto n_clusters = centroid_sums.extent(0); + + // Get workspace from handle + auto* workspace_resource = raft::resource::get_workspace_resource(handle); + auto workspace = rmm::device_uvector(n_samples, stream, workspace_resource); + + // Temporary buffers for this batch's partial results + auto batch_sums = raft::make_device_matrix(handle, n_clusters, n_features); + auto batch_counts = raft::make_device_vector(handle, n_clusters); + + // Zero the batch temporaries + raft::matrix::fill(handle, batch_sums.view(), DataT{0}); + raft::matrix::fill(handle, batch_counts.view(), DataT{0}); + + // Extract cluster labels from KeyValuePair + cuvs::cluster::kmeans::detail::KeyValueIndexOp 
conversion_op; + thrust::transform_iterator, + const raft::KeyValuePair*> + labels_itr(minClusterAndDistance.data_handle(), conversion_op); + + // Compute weighted sum of samples per cluster for this batch + raft::linalg::reduce_rows_by_key(const_cast(batch_data.data_handle()), + batch_data.extent(1), + labels_itr, + sample_weights.data_handle(), + workspace.data(), + batch_data.extent(0), + batch_data.extent(1), + n_clusters, + batch_sums.data_handle(), + stream); + + // Compute sum of weights per cluster for this batch + raft::linalg::reduce_cols_by_key(sample_weights.data_handle(), + labels_itr, + batch_counts.data_handle(), + static_cast(1), + static_cast(n_samples), + static_cast(n_clusters), + stream); + + // Add batch results to running accumulators + raft::linalg::add(centroid_sums.data_handle(), + centroid_sums.data_handle(), + batch_sums.data_handle(), + centroid_sums.size(), + stream); + + raft::linalg::add(cluster_counts.data_handle(), + cluster_counts.data_handle(), + batch_counts.data_handle(), + cluster_counts.size(), + stream); +} + +/** + * @brief Update centroids using mini-batch online learning + * + * Uses the online update formula: + * learning_rate[k] = batch_count[k] / (total_count[k] + batch_count[k]) + * centroid[k] = centroid[k] + learning_rate[k] * (batch_mean[k] - centroid[k]) + * + * This is equivalent to a weighted average where total_count tracks cumulative weight. 
+ */ +template +void minibatch_update_centroids(raft::resources const& handle, + raft::device_matrix_view centroids, + raft::device_matrix_view batch_sums, + raft::device_vector_view batch_counts, + raft::device_vector_view total_counts) +{ + auto n_clusters = centroids.extent(0); + auto n_features = centroids.extent(1); + + // Compute batch means: batch_mean = batch_sums / batch_counts + auto batch_means = raft::make_device_matrix(handle, n_clusters, n_features); + raft::copy(batch_means.data_handle(), + batch_sums.data_handle(), + batch_sums.size(), + raft::resource::get_cuda_stream(handle)); + + raft::linalg::matrix_vector_op( + handle, + raft::make_const_mdspan(batch_means.view()), + batch_counts, + batch_means.view(), + raft::div_checkzero_op{}); + + // Step 1: Update total_counts = total_counts + batch_counts + raft::linalg::add(handle, + raft::make_const_mdspan(total_counts), + batch_counts, + total_counts); + + // Step 2: Compute learning rates: lr = batch_count / total_count (after update) + auto learning_rates = raft::make_device_vector(handle, n_clusters); + raft::linalg::map(handle, + learning_rates.view(), + raft::div_checkzero_op{}, + batch_counts, + raft::make_const_mdspan(total_counts)); + + // Update centroids: centroid = centroid + lr * (batch_mean - centroid) + // = (1 - lr) * centroid + lr * batch_mean + // Using matrix_vector_op to scale each row by (1 - lr), then add lr * batch_mean + raft::linalg::matrix_vector_op( + handle, + raft::make_const_mdspan(centroids), + raft::make_const_mdspan(learning_rates.view()), + centroids, + [] __device__(DataT centroid_val, DataT lr) { return (DataT{1} - lr) * centroid_val; }); + + // Add lr * batch_mean to centroids + raft::linalg::matrix_vector_op( + handle, + raft::make_const_mdspan(batch_means.view()), + raft::make_const_mdspan(learning_rates.view()), + batch_means.view(), + [] __device__(DataT mean_val, DataT lr) { return lr * mean_val; }); + + // centroids += lr * batch_means + 
raft::linalg::add(handle, + raft::make_const_mdspan(centroids), + raft::make_const_mdspan(batch_means.view()), + centroids); +} + +/** + * @brief Finalize centroids by dividing accumulated sums by counts + */ +template +void finalize_centroids(raft::resources const& handle, + raft::device_matrix_view centroid_sums, + raft::device_vector_view cluster_counts, + raft::device_matrix_view old_centroids, + raft::device_matrix_view new_centroids) +{ + cudaStream_t stream = raft::resource::get_cuda_stream(handle); + auto n_clusters = new_centroids.extent(0); + auto n_features = new_centroids.extent(1); + + // Copy sums to new_centroids first + raft::copy( + new_centroids.data_handle(), centroid_sums.data_handle(), centroid_sums.size(), stream); + + // Divide by counts: new_centroids[i] = centroid_sums[i] / cluster_counts[i] + // When count is 0, set to 0 (will be fixed below) + raft::linalg::matrix_vector_op( + handle, + raft::make_const_mdspan(new_centroids), + cluster_counts, + new_centroids, + raft::div_checkzero_op{}); + + // Copy old centroids to new centroids where cluster_counts[i] == 0 + cub::ArgIndexInputIterator itr_wt(cluster_counts.data_handle()); + raft::matrix::gather_if( + old_centroids.data_handle(), + static_cast(old_centroids.extent(1)), + static_cast(old_centroids.extent(0)), + itr_wt, + itr_wt, + static_cast(cluster_counts.size()), + new_centroids.data_handle(), + [=] __device__(raft::KeyValuePair map) { + return map.value == DataT{0}; // predicate: copy when count is 0 + }, + raft::key_op{}, + stream); +} + +/** + * @brief Main fit function for batched k-means with host data + * + * @tparam DataT Data type (float, double) + * @tparam IdxT Index type (int, int64_t) + * + * @param[in] handle RAFT resources handle + * @param[in] params K-means parameters + * @param[in] X Input data on HOST [n_samples x n_features] + * @param[in] batch_size Number of samples to process per batch + * @param[in] sample_weight Optional weights per sample (on host) + * 
@param[inout] centroids Initial/output cluster centers [n_clusters x n_features] + * @param[out] inertia Sum of squared distances to nearest centroid + * @param[out] n_iter Number of iterations run + */ +template +void fit(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::host_matrix_view X, + IdxT batch_size, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter) +{ + cudaStream_t stream = raft::resource::get_cuda_stream(handle); + auto n_samples = X.extent(0); + auto n_features = X.extent(1); + auto n_clusters = params.n_clusters; + auto metric = params.metric; + + RAFT_EXPECTS(batch_size > 0, "batch_size must be positive"); + RAFT_EXPECTS(n_clusters > 0, "n_clusters must be positive"); + RAFT_EXPECTS(static_cast(centroids.extent(0)) == n_clusters, + "centroids.extent(0) must equal n_clusters"); + RAFT_EXPECTS(centroids.extent(1) == n_features, + "centroids.extent(1) must equal n_features"); + + raft::default_logger().set_level(params.verbosity); + + RAFT_LOG_DEBUG( + "KMeans batched fit: n_samples=%zu, n_features=%zu, n_clusters=%d, batch_size=%zu", + static_cast(n_samples), + static_cast(n_features), + n_clusters, + static_cast(batch_size)); + + rmm::device_uvector workspace(0, stream); + + // Initialize centroids from a sample of host data + if (params.init != cuvs::cluster::kmeans::params::InitMethod::Array) { + init_centroids_from_host_sample(handle, params, X, centroids, workspace); + } + + // Allocate device buffers + // Batch buffer for data + auto batch_data = raft::make_device_matrix(handle, batch_size, n_features); + // Batch buffer for weights + auto batch_weights = raft::make_device_vector(handle, batch_size); + // Cluster assignment for batch + auto minClusterAndDistance = + raft::make_device_vector, IdxT>(handle, batch_size); + // L2 norms of batch data + auto L2NormBatch = raft::make_device_vector(handle, batch_size); + // 
Temporary buffer for distance computation + rmm::device_uvector L2NormBuf_OR_DistBuf(0, stream); + + // Accumulators for centroid computation (new_centroids holds the previous iteration's centroids) + auto centroid_sums = raft::make_device_matrix(handle, n_clusters, n_features); + auto cluster_counts = raft::make_device_vector(handle, n_clusters); + auto new_centroids = raft::make_device_matrix(handle, n_clusters, n_features); + + // For mini-batch mode: track total counts for learning rate calculation + auto total_counts = raft::make_device_vector(handle, n_clusters); + + // Host staging buffers for mini-batch row/weight gathers (NOTE(review): plain std::vector, so presumably pageable rather than pinned - confirm) + std::vector host_batch_buffer(batch_size * n_features); + std::vector host_weight_buffer(batch_size); + + // Cluster cost for convergence check + rmm::device_scalar clusterCostD(stream); + DataT priorClusteringCost = 0; + + // Check update mode + bool use_minibatch = + (params.update_mode == cuvs::cluster::kmeans::params::CentroidUpdateMode::MiniBatch); + + RAFT_LOG_DEBUG("KMeans batched: update_mode=%s", use_minibatch ? 
"MiniBatch" : "FullBatch"); + + // For mini-batch mode with random sampling, create index shuffle + std::vector sample_indices(n_samples); + std::iota(sample_indices.begin(), sample_indices.end(), 0); + std::mt19937 rng(params.rng_state.seed); + + // Main iteration loop + for (n_iter[0] = 1; n_iter[0] <= params.max_iter; ++n_iter[0]) { + RAFT_LOG_DEBUG("KMeans batched: Iteration %d", n_iter[0]); + + // For full-batch mode: zero accumulators at start of each iteration + // For mini-batch mode: zero total_counts at start of each iteration + if (!use_minibatch) { + raft::matrix::fill(handle, centroid_sums.view(), DataT{0}); + raft::matrix::fill(handle, cluster_counts.view(), DataT{0}); + } else { + // Mini-batch mode: zero total counts for learning rate calculation + raft::matrix::fill(handle, total_counts.view(), DataT{0}); + // Shuffle sample indices for random batch selection + std::shuffle(sample_indices.begin(), sample_indices.end(), rng); + } + + // Save old centroids for convergence check + raft::copy(new_centroids.data_handle(), centroids.data_handle(), centroids.size(), stream); + + DataT total_cost = 0; + + // Process all data in batches + for (IdxT batch_idx = 0; batch_idx < n_samples; batch_idx += batch_size) { + IdxT current_batch_size = std::min(batch_size, n_samples - batch_idx); + + // Copy batch data from host to device + if (use_minibatch) { + // Mini-batch: use shuffled indices for random sampling + for (IdxT i = 0; i < current_batch_size; ++i) { + IdxT sample_idx = sample_indices[batch_idx + i]; + std::memcpy(host_batch_buffer.data() + i * n_features, + X.data_handle() + sample_idx * n_features, + n_features * sizeof(DataT)); + } + raft::copy( + batch_data.data_handle(), host_batch_buffer.data(), current_batch_size * n_features, stream); + } else { + // Full-batch: sequential access + raft::copy(batch_data.data_handle(), + X.data_handle() + batch_idx * n_features, + current_batch_size * n_features, + stream); + } + + // Copy or set weights for this 
batch + if (sample_weight) { + if (use_minibatch) { + for (IdxT i = 0; i < current_batch_size; ++i) { + host_weight_buffer[i] = sample_weight->data_handle()[sample_indices[batch_idx + i]]; + } + raft::copy(batch_weights.data_handle(), host_weight_buffer.data(), current_batch_size, stream); + } else { + raft::copy(batch_weights.data_handle(), + sample_weight->data_handle() + batch_idx, + current_batch_size, + stream); + } + } else { + auto batch_weights_fill_view = raft::make_device_vector_view( + batch_weights.data_handle(), current_batch_size); + raft::matrix::fill(handle, batch_weights_fill_view, DataT{1}); + } + + // Create views for current batch size + auto batch_data_view = raft::make_device_matrix_view( + batch_data.data_handle(), current_batch_size, n_features); + auto batch_weights_view = raft::make_device_vector_view( + batch_weights.data_handle(), current_batch_size); + auto minClusterAndDistance_view = + raft::make_device_vector_view, IdxT>( + minClusterAndDistance.data_handle(), current_batch_size); + + // Compute L2 norms for batch if needed + if (metric == cuvs::distance::DistanceType::L2Expanded || + metric == cuvs::distance::DistanceType::L2SqrtExpanded) { + raft::linalg::rowNorm( + L2NormBatch.data_handle(), + batch_data.data_handle(), + n_features, + current_batch_size, + stream); + } + + // Find nearest centroid for each sample in batch + auto centroids_const = raft::make_device_matrix_view( + centroids.data_handle(), n_clusters, n_features); + auto L2NormBatch_const = raft::make_device_vector_view( + L2NormBatch.data_handle(), current_batch_size); + + cuvs::cluster::kmeans::detail::minClusterAndDistanceCompute( + handle, + batch_data_view, + centroids_const, + minClusterAndDistance_view, + L2NormBatch_const, + L2NormBuf_OR_DistBuf, + metric, + params.batch_samples, + params.batch_centroids, + workspace); + + // Accumulate partial sums for this batch + auto minClusterAndDistance_const = + raft::make_device_vector_view, IdxT>( + 
minClusterAndDistance.data_handle(), current_batch_size); + + if (use_minibatch) { + // Mini-batch mode: zero batch accumulators before each batch + raft::matrix::fill(handle, centroid_sums.view(), DataT{0}); + raft::matrix::fill(handle, cluster_counts.view(), DataT{0}); + } + + accumulate_batch_centroids(handle, + batch_data_view, + minClusterAndDistance_const, + batch_weights_view, + centroid_sums.view(), + cluster_counts.view()); + + if (use_minibatch) { + // Mini-batch mode: update centroids immediately after each batch + auto centroid_sums_const = raft::make_device_matrix_view( + centroid_sums.data_handle(), n_clusters, n_features); + auto cluster_counts_const = raft::make_device_vector_view( + cluster_counts.data_handle(), n_clusters); + + minibatch_update_centroids( + handle, centroids, centroid_sums_const, cluster_counts_const, total_counts.view()); + } + + // Accumulate cluster cost if checking convergence + if (params.inertia_check) { + cuvs::cluster::kmeans::detail::computeClusterCost( + handle, + minClusterAndDistance_view, + workspace, + raft::make_device_scalar_view(clusterCostD.data()), + raft::value_op{}, + raft::add_op{}); + DataT batch_cost = clusterCostD.value(stream); + total_cost += batch_cost; + } + } // end batch loop + + if (!use_minibatch) { + // Full-batch mode: finalize centroids after processing all batches + auto centroids_const = raft::make_device_matrix_view( + centroids.data_handle(), n_clusters, n_features); + auto centroid_sums_const = raft::make_device_matrix_view( + centroid_sums.data_handle(), n_clusters, n_features); + auto cluster_counts_const = raft::make_device_vector_view( + cluster_counts.data_handle(), n_clusters); + + finalize_centroids( + handle, centroid_sums_const, cluster_counts_const, centroids_const, centroids); + } + + // Compute squared norm of change in centroids (compare to saved old centroids) + auto sqrdNorm = raft::make_device_scalar(handle, DataT{0}); + raft::linalg::mapThenSumReduce(sqrdNorm.data_handle(), 
+ centroids.size(), + raft::sqdiff_op{}, + stream, + new_centroids.data_handle(), // old centroids + centroids.data_handle()); // new centroids + + DataT sqrdNormError = 0; + raft::copy(&sqrdNormError, sqrdNorm.data_handle(), 1, stream); + + // Check convergence + bool done = false; + if (params.inertia_check) { + if (n_iter[0] > 1) { + DataT delta = total_cost / priorClusteringCost; + if (delta > 1 - params.tol) done = true; + } + priorClusteringCost = total_cost; + } + + raft::resource::sync_stream(handle, stream); + if (sqrdNormError < params.tol) done = true; + + if (done) { + RAFT_LOG_DEBUG("KMeans batched: Converged after %d iterations", n_iter[0]); + break; + } + } // end iteration loop + + // Compute final inertia by processing all data once more + inertia[0] = 0; + for (IdxT offset = 0; offset < n_samples; offset += batch_size) { + IdxT current_batch_size = std::min(batch_size, n_samples - offset); + + raft::copy(batch_data.data_handle(), + X.data_handle() + offset * n_features, + current_batch_size * n_features, + stream); + + auto batch_data_view = raft::make_device_matrix_view( + batch_data.data_handle(), current_batch_size, n_features); + auto minClusterAndDistance_view = + raft::make_device_vector_view, IdxT>( + minClusterAndDistance.data_handle(), current_batch_size); + + if (metric == cuvs::distance::DistanceType::L2Expanded || + metric == cuvs::distance::DistanceType::L2SqrtExpanded) { + raft::linalg::rowNorm( + L2NormBatch.data_handle(), batch_data.data_handle(), n_features, current_batch_size, stream); + } + + auto centroids_const = raft::make_device_matrix_view( + centroids.data_handle(), n_clusters, n_features); + auto L2NormBatch_const = raft::make_device_vector_view( + L2NormBatch.data_handle(), current_batch_size); + + cuvs::cluster::kmeans::detail::minClusterAndDistanceCompute( + handle, + batch_data_view, + centroids_const, + minClusterAndDistance_view, + L2NormBatch_const, + L2NormBuf_OR_DistBuf, + metric, + params.batch_samples, + 
params.batch_centroids, + workspace); + + cuvs::cluster::kmeans::detail::computeClusterCost( + handle, + minClusterAndDistance_view, + workspace, + raft::make_device_scalar_view(clusterCostD.data()), + raft::value_op{}, + raft::add_op{}); + + inertia[0] += clusterCostD.value(stream); + } + + RAFT_LOG_DEBUG("KMeans batched: Completed with inertia=%f", inertia[0]); +} + +} // namespace cuvs::cluster::kmeans::batched::detail + diff --git a/cpp/src/cluster/kmeans_fit_double.cu b/cpp/src/cluster/kmeans_fit_double.cu index 43f457a29a..0962c87890 100644 --- a/cpp/src/cluster/kmeans_fit_double.cu +++ b/cpp/src/cluster/kmeans_fit_double.cu @@ -1,8 +1,9 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ +#include "detail/kmeans_batched.cuh" #include "kmeans.cuh" #include "kmeans_impl.cuh" #include @@ -30,14 +31,55 @@ namespace cuvs::cluster::kmeans { raft::host_scalar_view inertia, \ raft::host_scalar_view n_iter); +#define INSTANTIATE_FIT_BATCHED(DataT, IndexT) \ + template void batched::detail::fit( \ + raft::resources const& handle, \ + const kmeans::params& params, \ + raft::host_matrix_view X, \ + IndexT batch_size, \ + std::optional> sample_weight, \ + raft::device_matrix_view centroids, \ + raft::host_scalar_view inertia, \ + raft::host_scalar_view n_iter); + INSTANTIATE_FIT_MAIN(double, int) INSTANTIATE_FIT_MAIN(double, int64_t) INSTANTIATE_FIT(double, int) INSTANTIATE_FIT(double, int64_t) +INSTANTIATE_FIT_BATCHED(double, int) +INSTANTIATE_FIT_BATCHED(double, int64_t) + #undef INSTANTIATE_FIT_MAIN #undef INSTANTIATE_FIT +#undef INSTANTIATE_FIT_BATCHED + +void fit_batched(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::host_matrix_view X, + int batch_size, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter) 
+{ + cuvs::cluster::kmeans::batched::detail::fit( + handle, params, X, batch_size, sample_weight, centroids, inertia, n_iter); +} + +void fit_batched(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::host_matrix_view X, + int64_t batch_size, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter) +{ + cuvs::cluster::kmeans::batched::detail::fit( + handle, params, X, batch_size, sample_weight, centroids, inertia, n_iter); +} void fit(raft::resources const& handle, const cuvs::cluster::kmeans::params& params, diff --git a/cpp/src/cluster/kmeans_fit_float.cu b/cpp/src/cluster/kmeans_fit_float.cu index 5624151943..39cc074d9e 100644 --- a/cpp/src/cluster/kmeans_fit_float.cu +++ b/cpp/src/cluster/kmeans_fit_float.cu @@ -1,8 +1,9 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION. 
* SPDX-License-Identifier: Apache-2.0 */ +#include "detail/kmeans_batched.cuh" #include "kmeans.cuh" #include "kmeans_impl.cuh" #include @@ -30,14 +31,55 @@ namespace cuvs::cluster::kmeans { raft::host_scalar_view inertia, \ raft::host_scalar_view n_iter); +#define INSTANTIATE_FIT_BATCHED(DataT, IndexT) \ + template void batched::detail::fit( \ + raft::resources const& handle, \ + const kmeans::params& params, \ + raft::host_matrix_view X, \ + IndexT batch_size, \ + std::optional> sample_weight, \ + raft::device_matrix_view centroids, \ + raft::host_scalar_view inertia, \ + raft::host_scalar_view n_iter); + INSTANTIATE_FIT_MAIN(float, int) INSTANTIATE_FIT_MAIN(float, int64_t) INSTANTIATE_FIT(float, int) INSTANTIATE_FIT(float, int64_t) +INSTANTIATE_FIT_BATCHED(float, int) +INSTANTIATE_FIT_BATCHED(float, int64_t) + #undef INSTANTIATE_FIT_MAIN #undef INSTANTIATE_FIT +#undef INSTANTIATE_FIT_BATCHED + +void fit_batched(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::host_matrix_view X, + int batch_size, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter) +{ + cuvs::cluster::kmeans::batched::detail::fit( + handle, params, X, batch_size, sample_weight, centroids, inertia, n_iter); +} + +void fit_batched(raft::resources const& handle, + const cuvs::cluster::kmeans::params& params, + raft::host_matrix_view X, + int64_t batch_size, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter) +{ + cuvs::cluster::kmeans::batched::detail::fit( + handle, params, X, batch_size, sample_weight, centroids, inertia, n_iter); +} void fit(raft::resources const& handle, const cuvs::cluster::kmeans::params& params, diff --git a/python/cuvs/cuvs/cluster/kmeans/kmeans.pxd b/python/cuvs/cuvs/cluster/kmeans/kmeans.pxd index 9f16d46c4d..d219f4e903 100644 --- 
a/python/cuvs/cuvs/cluster/kmeans/kmeans.pxd +++ b/python/cuvs/cuvs/cluster/kmeans/kmeans.pxd @@ -18,6 +18,10 @@ cdef extern from "cuvs/cluster/kmeans.h" nogil: Random Array + ctypedef enum cuvsKMeansCentroidUpdateMode: + CUVS_KMEANS_UPDATE_FULL_BATCH + CUVS_KMEANS_UPDATE_MINI_BATCH + ctypedef enum cuvsKMeansType: CUVS_KMEANS_TYPE_KMEANS CUVS_KMEANS_TYPE_KMEANS_BALANCED @@ -32,6 +36,7 @@ cdef extern from "cuvs/cluster/kmeans.h" nogil: double oversampling_factor, int batch_samples, int batch_centroids, + cuvsKMeansCentroidUpdateMode update_mode, bool inertia_check, bool hierarchical, int hierarchical_n_iters diff --git a/python/cuvs/cuvs/cluster/kmeans/kmeans.pyx b/python/cuvs/cuvs/cluster/kmeans/kmeans.pyx index 489d983ac7..7b0cb4b3a2 100644 --- a/python/cuvs/cuvs/cluster/kmeans/kmeans.pyx +++ b/python/cuvs/cuvs/cluster/kmeans/kmeans.pyx @@ -44,6 +44,12 @@ INIT_METHOD_TYPES = { INIT_METHOD_NAMES = {v: k for k, v in INIT_METHOD_TYPES.items()} +UPDATE_MODE_TYPES = { + "full_batch": cuvsKMeansCentroidUpdateMode.CUVS_KMEANS_UPDATE_FULL_BATCH, + "mini_batch": cuvsKMeansCentroidUpdateMode.CUVS_KMEANS_UPDATE_MINI_BATCH} + +UPDATE_MODE_NAMES = {v: k for k, v in UPDATE_MODE_TYPES.items()} + cdef class KMeansParams: """ Hyper-parameters for the kmeans algorithm @@ -70,6 +76,17 @@ cdef class KMeansParams: Number of instance k-means algorithm will be run with different seeds oversampling_factor : double Oversampling factor for use in the k-means|| algorithm + batch_samples : int + Number of samples to process in each batch for tiled 1NN computation. + Useful to optimize/control memory footprint. Default tile is + [batch_samples x n_clusters]. + batch_centroids : int + Number of centroids to process in each batch. If 0, uses n_clusters. + update_mode : str + Centroid update strategy. One of: + "full_batch" : Standard Lloyd's algorithm - accumulate assignments over + the entire dataset, then update centroids once per iteration. 
+ "mini_batch" : Mini-batch k-means - update centroids after each batch. hierarchical : bool Whether to use hierarchical (balanced) kmeans or not hierarchical_n_iters : int @@ -92,6 +109,9 @@ cdef class KMeansParams: tol=None, n_init=None, oversampling_factor=None, + batch_samples=None, + batch_centroids=None, + update_mode=None, hierarchical=None, hierarchical_n_iters=None): if metric is not None: @@ -109,6 +129,13 @@ cdef class KMeansParams: self.params.n_init = n_init if oversampling_factor is not None: self.params.oversampling_factor = oversampling_factor + if batch_samples is not None: + self.params.batch_samples = batch_samples + if batch_centroids is not None: + self.params.batch_centroids = batch_centroids + if update_mode is not None: + c_mode = UPDATE_MODE_TYPES[update_mode] + self.params.update_mode = c_mode if hierarchical is not None: self.params.hierarchical = hierarchical if hierarchical_n_iters is not None: @@ -145,6 +172,18 @@ cdef class KMeansParams: def oversampling_factor(self): return self.params.oversampling_factor + @property + def batch_samples(self): + return self.params.batch_samples + + @property + def batch_centroids(self): + return self.params.batch_centroids + + @property + def update_mode(self): + return UPDATE_MODE_NAMES[self.params.update_mode] + @property def hierarchical(self): return self.params.hierarchical diff --git a/python/cuvs/cuvs/tests/test_kmeans.py b/python/cuvs/cuvs/tests/test_kmeans.py index 6f18137b13..cc3b1cf4a4 100644 --- a/python/cuvs/cuvs/tests/test_kmeans.py +++ b/python/cuvs/cuvs/tests/test_kmeans.py @@ -69,3 +69,83 @@ def test_cluster_cost(n_rows, n_cols, n_clusters, dtype): # need reduced tolerance for float32 tol = 1e-3 if dtype == np.float32 else 1e-6 assert np.allclose(inertia, sum(cluster_distances), rtol=tol, atol=tol) + + +@pytest.mark.parametrize("n_rows", [1000]) +@pytest.mark.parametrize("n_cols", [10]) +@pytest.mark.parametrize("n_clusters", [8]) +@pytest.mark.parametrize("dtype", [np.float32]) 
+@pytest.mark.parametrize( + "batch_samples_list", + [ + [32, 64, 128, 256, 512], # various batch sizes + ], +) +def test_kmeans_batch_size_determinism( + n_rows, n_cols, n_clusters, dtype, batch_samples_list +): + """ + Test that different batch sizes produce identical centroids. + + When starting from the same initial centroids, the k-means algorithm + should produce identical final centroids regardless of the batch_samples + parameter. This is because the accumulated adjustments to centroids after + the entire dataset pass should be the same. + """ + # Use fixed seed for reproducibility + rng = np.random.default_rng(42) + + # Generate random data + X_host = rng.random((n_rows, n_cols)).astype(dtype) + X = device_ndarray(X_host) + + # Generate fixed initial centroids (using first n_clusters rows) + initial_centroids_host = X_host[:n_clusters].copy() + + # Store results from each batch size + results = [] + + for batch_samples in batch_samples_list: + # Create fresh copy of initial centroids for each run + centroids = device_ndarray(initial_centroids_host.copy()) + + params = KMeansParams( + n_clusters=n_clusters, + init_method="Array", # Use provided centroids + max_iter=100, + tol=1e-10, # Very small tolerance to ensure convergence + batch_samples=batch_samples, + ) + + centroids_out, inertia, n_iter = fit(params, X, centroids) + results.append( + { + "batch_samples": batch_samples, + "centroids": centroids_out.copy_to_host(), + "inertia": inertia, + "n_iter": n_iter, + } + ) + + # Compare all results against the first one + reference = results[0] + for result in results[1:]: + # Centroids should be identical (or very close due to float precision) + assert np.allclose( + reference["centroids"], + result["centroids"], + rtol=1e-5, + atol=1e-5, + ), ( + f"Centroids differ between batch_samples=" + f"{reference['batch_samples']} and {result['batch_samples']}" + ) + + # Inertia should also be identical + assert np.allclose( + reference["inertia"], result["inertia"], 
rtol=1e-5, atol=1e-5 + ), ( + f"Inertia differs between batch_samples=" + f"{reference['batch_samples']} and {result['batch_samples']}: " + f"{reference['inertia']} vs {result['inertia']}" + )