From ffa8a79a830fe74676eb3a4a6a95f19741ccef2c Mon Sep 17 00:00:00 2001 From: Yuichi Tateno Date: Thu, 8 Jan 2026 18:46:03 +0900 Subject: [PATCH 01/12] Add fast no-duplicates batch sampler --- sentence_transformers/sampler.py | 254 ++++++++++++++++++++++++++++--- 1 file changed, 234 insertions(+), 20 deletions(-) diff --git a/sentence_transformers/sampler.py b/sentence_transformers/sampler.py index 7c0b1baf7..1c310418c 100644 --- a/sentence_transformers/sampler.py +++ b/sentence_transformers/sampler.py @@ -1,15 +1,22 @@ from __future__ import annotations import logging +import os from abc import ABC, abstractmethod from collections import defaultdict from collections.abc import Iterator from itertools import accumulate, cycle from typing import Any +import numpy as np import torch from torch.utils.data import BatchSampler, ConcatDataset, SubsetRandomSampler +try: + import xxhash +except ImportError: # pragma: no cover - optional dependency + xxhash = None + from sentence_transformers.util import is_datasets_available if is_datasets_available(): @@ -17,6 +24,87 @@ logger = logging.getLogger(__name__) +_XXHASH_INT64_MAX = 1 << 63 +_XXHASH_UINT64_MAX = 1 << 64 + + +def _xxhash_int64(value: str) -> int: + # Convert uint64 -> int64 to keep values compatible with Arrow int64 storage. + hashed = xxhash.xxh64_intdigest(value) + if hashed >= _XXHASH_INT64_MAX: + hashed -= _XXHASH_UINT64_MAX + return hashed + + +def _hash_batch( + batch: dict[str, list[Any]], + columns: list[str], + exclude_columns: set[str], +) -> dict[str, list[list[int]]]: + # Must be defined at module scope because datasets.map with num_proc pickles this function. + # Build per-row hash lists so we can later do fast overlap checks without re-reading the dataset. + active_columns = [column for column in columns if column not in exclude_columns] + batch_size = len(batch[active_columns[0]]) if active_columns else len(next(iter(batch.values()), [])) + if not active_columns: + return {"__hashes": [[] for _ in range(batch_size)]} + hashes: list[list[int]] = [] + for row_idx in range(batch_size): + row_hashes: list[int] = [] + for column in active_columns: + value = batch[column][row_idx] + if isinstance(value, list): + for item in value: + row_hashes.append(_xxhash_int64(str(item))) + else: + row_hashes.append(_xxhash_int64(str(value))) + hashes.append(row_hashes) + return {"__hashes": hashes} + + +def _remove_label_columns(dataset: Dataset, valid_label_columns: list[str] | None) -> Dataset: + # Drop label columns so they don't participate in duplicate checks. + if label_columns := set(dataset.column_names) & set(valid_label_columns or []): + return dataset.remove_columns(list(label_columns)) + return dataset + + +def _has_overlap(sample_values, batch_values: set[Any]) -> bool: + # Avoid materializing a set if we already have one. + if isinstance(sample_values, set): + return bool(sample_values & batch_values) + return any(value in batch_values for value in sample_values) + + +def _iter_no_duplicate_batches( + remaining_indices: dict[int, None], + get_sample_values, + batch_size: int, + drop_last: bool, +) -> Iterator[list[int]]: + # Shared batch construction loop for both samplers; keeps behavior consistent. 
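+    # Greedy construction: walk the shuffled indices in order and add a sample only if none of
+    # its values are already in the current batch; skipped indices stay in `remaining_indices`
+    # and are reconsidered for later batches. The for/else below relies on `break` firing only
+    # when a batch reaches `batch_size`.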
+ while remaining_indices: + batch_values: set[Any] = set() + batch_indices: list[int] = [] + for index in remaining_indices: + sample_values = get_sample_values(index) + if _has_overlap(sample_values, batch_values): + continue + + batch_indices.append(index) + if len(batch_indices) == batch_size: + yield batch_indices + break + + batch_values.update(sample_values) + + else: + # NOTE: some indices might still have been ignored here + if not drop_last: + yield batch_indices + + for index in batch_indices: + del remaining_indices[index] + class SetEpochMixin: """ @@ -197,9 +285,7 @@ def __init__( generator=generator, seed=seed, ) - if label_columns := set(dataset.column_names) & set(self.valid_label_columns or []): - dataset = dataset.remove_columns(list(label_columns)) - self.dataset = dataset + self.dataset = _remove_label_columns(dataset, self.valid_label_columns) def __iter__(self) -> Iterator[list[int]]: """ @@ -214,28 +300,156 @@ def __iter__(self) -> Iterator[list[int]]: # 1. Allows for cheap removal of elements # 2. Preserves the order of elements, i.e. remains random remaining_indices = dict.fromkeys(torch.randperm(len(self.dataset), generator=self.generator).tolist()) - while remaining_indices: - batch_values = set() - batch_indices = [] - for index in remaining_indices: - sample_values = {str(value) for key, value in self.dataset[index].items() if key != "dataset_name"} - if sample_values & batch_values: - continue - batch_indices.append(index) - if len(batch_indices) == self.batch_size: - yield batch_indices - break + def get_sample_values(index: int) -> set[str]: + return {str(value) for key, value in self.dataset[index].items() if key != "dataset_name"} + + yield from _iter_no_duplicate_batches( + remaining_indices, + get_sample_values, + self.batch_size, + self.drop_last, + ) + + def __len__(self) -> int: + if self.drop_last: + return len(self.dataset) // self.batch_size + else: + return (len(self.dataset) + self.batch_size - 1) // self.batch_size - batch_values.update(sample_values) +class NoDuplicatesFastBatchSampler(DefaultBatchSampler): + def __init__( + self, + dataset: Dataset, + batch_size: int, + drop_last: bool, + valid_label_columns: list[str] | None = None, + generator: torch.Generator | None = None, + seed: int = 0, + hash_num_proc: int | None = None, + ds_map_batch_size: int = 1000, + ) -> None: + """ + This sampler creates batches such that each batch contains samples where the values are unique, + even across columns. It uses the same batch construction approach as NoDuplicatesBatchSampler, + but speeds up duplicate checks by caching per-row xxhash 64-bit values computed with datasets.map. + + Recommended for: + - :class:`~sentence_transformers.losses.MultipleNegativesRankingLoss` + - :class:`~sentence_transformers.losses.CachedMultipleNegativesRankingLoss` + - :class:`~sentence_transformers.losses.MultipleNegativesSymmetricRankingLoss` + - :class:`~sentence_transformers.losses.CachedMultipleNegativesSymmetricRankingLoss` + - :class:`~sentence_transformers.losses.MegaBatchMarginLoss` + - :class:`~sentence_transformers.losses.GISTEmbedLoss` + - :class:`~sentence_transformers.losses.CachedGISTEmbedLoss` + + Args: + dataset (Dataset): The dataset to sample from. + batch_size (int): Number of samples per batch. + drop_last (bool): If True, drop the last incomplete batch if the dataset size + is not divisible by the batch size. + valid_label_columns (List[str], optional): List of column names to check for labels. 
+ The first column name from ``valid_label_columns`` found in the dataset will + be used as the label column. + generator (torch.Generator, optional): Optional random number generator for shuffling + the indices. + seed (int): Seed for the random number generator to ensure reproducibility. Defaults to 0. + hash_num_proc (int, optional): Number of processes for hashing with datasets.map. Defaults to min(8, cpu-1). + ds_map_batch_size (int, optional): Batch size for datasets.map hashing. Defaults to 1000. + """ + super().__init__( + dataset, + batch_size=batch_size, + drop_last=drop_last, + valid_label_columns=valid_label_columns, + generator=generator, + seed=seed, + ) + if xxhash is None: + raise ImportError( + "NoDuplicatesFastBatchSampler requires `xxhash`. Install it via `pip install xxhash` " + "or use the `train` extra." + ) + self.dataset = _remove_label_columns(dataset, self.valid_label_columns) + cpu_count = os.cpu_count() or 1 + # Leave one core free to avoid saturating the system when hashing. + default_workers = max(1, min(8, cpu_count - 1)) + self.hash_num_proc = hash_num_proc or default_workers + self.ds_map_batch_size = ds_map_batch_size + self._row_hashes: np.ndarray | list[list[int]] | None = None + + def _build_hashes(self) -> None: + if self._row_hashes is not None: + return + exclude_columns = set(self.valid_label_columns or []) | {"dataset_name"} + columns = list(self.dataset.column_names) + # Precompute hash values once to avoid repeated string processing per batch. + # Use num_proc to parallelize hashing across CPU cores. + hash_ds = self.dataset.map( + _hash_batch, + batched=True, + batch_size=self.ds_map_batch_size, + num_proc=self.hash_num_proc, + remove_columns=columns, + fn_kwargs={"columns": columns, "exclude_columns": exclude_columns}, + desc="Hashing dataset values", + ) + try: + import pyarrow as pa + + column = hash_ds.data.column("__hashes") + if isinstance(column, pa.ChunkedArray): + column = column.combine_chunks() + if not isinstance(column, (pa.ListArray, pa.LargeListArray)): + raise ValueError("Expected a list column for hashed values.") + + row_count = len(column) + if row_count == 0: + row_hashes = np.zeros((0, 0), dtype=np.int64) else: - # NOTE: some indices might still have been ignored here - if not self.drop_last: - yield batch_indices + offsets = column.offsets.to_numpy(zero_copy_only=False) + row_size = int(offsets[1] - offsets[0]) + if row_size < 0 or not np.all(np.diff(offsets) == row_size): + raise ValueError("Hashed rows have varying lengths.") + # If every row has the same length, store as a dense ndarray to reduce overhead. + values = column.values.to_numpy(zero_copy_only=False).astype(np.int64, copy=False) + if values.size != row_count * row_size: + raise ValueError("Unexpected hashed value buffer size.") + row_hashes = values.reshape((row_count, row_size)) + except Exception as exc: + # Surface failures explicitly; the fast sampler expects fixed-length rows. + del hash_ds + raise ValueError( + "NoDuplicatesFastBatchSampler requires fixed-length hash rows. " + "Ensure each sample has the same number of values across columns." + ) from exc + + self._row_hashes = row_hashes + # Drop the temporary dataset to release Arrow buffers promptly. 
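+        # (the hash matrix now lives in `self._row_hashes`, so the mapped dataset is no longer needed)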
+ del hash_ds + + def __iter__(self) -> Iterator[list[int]]: + if self.generator and self.seed is not None: + self.generator.manual_seed(self.seed + self.epoch) + + self._build_hashes() + row_hashes = self._row_hashes if self._row_hashes is not None else [] - for index in batch_indices: - del remaining_indices[index] + # We create a dictionary to None because we need a data structure that: + # 1. Allows for cheap removal of elements + # 2. Preserves the order of elements, i.e. remains random + remaining_indices = dict.fromkeys(torch.randperm(len(self.dataset), generator=self.generator).tolist()) + + def get_sample_values(index: int): + return row_hashes[index] + + yield from _iter_no_duplicate_batches( + remaining_indices, + get_sample_values, + self.batch_size, + self.drop_last, + ) def __len__(self) -> int: if self.drop_last: From 1b55966b1d3157c472d06fb4196b9d789add0612 Mon Sep 17 00:00:00 2001 From: Yuichi Tateno Date: Thu, 8 Jan 2026 18:46:10 +0900 Subject: [PATCH 02/12] Wire NO_DUPLICATES_FAST option --- .../sentence_transformer/sampler.md | 5 +++++ sentence_transformers/trainer.py | 4 ++++ sentence_transformers/training_args.py | 7 ++++++- .../test_no_duplicates_batch_sampler.py | 18 ++++++++++++------ tests/test_training_args.py | 14 ++++++++++++++ 5 files changed, 41 insertions(+), 7 deletions(-) diff --git a/docs/package_reference/sentence_transformer/sampler.md b/docs/package_reference/sentence_transformer/sampler.md index 76f9f338b..c4b4fe856 100644 --- a/docs/package_reference/sentence_transformer/sampler.md +++ b/docs/package_reference/sentence_transformer/sampler.md @@ -17,6 +17,11 @@ :members: ``` +```{eval-rst} +.. autoclass:: sentence_transformers.sampler.NoDuplicatesFastBatchSampler + :members: +``` + ```{eval-rst} .. autoclass:: sentence_transformers.sampler.GroupByLabelBatchSampler :members: diff --git a/sentence_transformers/trainer.py b/sentence_transformers/trainer.py index c90d0fb49..21bfa18b3 100644 --- a/sentence_transformers/trainer.py +++ b/sentence_transformers/trainer.py @@ -31,6 +31,7 @@ GroupByLabelBatchSampler, MultiDatasetDefaultBatchSampler, NoDuplicatesBatchSampler, + NoDuplicatesFastBatchSampler, ProportionalBatchSampler, RoundRobinBatchSampler, ) @@ -679,6 +680,9 @@ def get_batch_sampler( if self.args.batch_sampler == BatchSamplers.NO_DUPLICATES: return NoDuplicatesBatchSampler(dataset, **batch_sampler_kwargs) + if self.args.batch_sampler == BatchSamplers.NO_DUPLICATES_FAST: + return NoDuplicatesFastBatchSampler(dataset, **batch_sampler_kwargs) + if self.args.batch_sampler == BatchSamplers.GROUP_BY_LABEL: return GroupByLabelBatchSampler(dataset, **batch_sampler_kwargs) diff --git a/sentence_transformers/training_args.py b/sentence_transformers/training_args.py index 09f93f4fd..c176d5561 100644 --- a/sentence_transformers/training_args.py +++ b/sentence_transformers/training_args.py @@ -24,7 +24,11 @@ class BatchSamplers(ExplicitEnum): - ``BatchSamplers.BATCH_SAMPLER``: **[default]** Uses :class:`~sentence_transformers.sampler.DefaultBatchSampler`, the default PyTorch batch sampler. - ``BatchSamplers.NO_DUPLICATES``: Uses :class:`~sentence_transformers.sampler.NoDuplicatesBatchSampler`, - ensuring no duplicate samples in a batch. Recommended for losses that use in-batch negatives, such as: + ensuring no duplicate samples in a batch. + - ``BatchSamplers.NO_DUPLICATES_FAST``: Uses :class:`~sentence_transformers.sampler.NoDuplicatesFastBatchSampler`, + a faster sampler that also ensures no duplicate samples in a batch. 
+ + Both are recommended for losses that use in-batch negatives, such as: - :class:`~sentence_transformers.losses.MultipleNegativesRankingLoss` - :class:`~sentence_transformers.losses.CachedMultipleNegativesRankingLoss` @@ -79,6 +83,7 @@ class BatchSamplers(ExplicitEnum): BATCH_SAMPLER = "batch_sampler" NO_DUPLICATES = "no_duplicates" + NO_DUPLICATES_FAST = "no_duplicates_fast" GROUP_BY_LABEL = "group_by_label" diff --git a/tests/samplers/test_no_duplicates_batch_sampler.py b/tests/samplers/test_no_duplicates_batch_sampler.py index 8a82d6503..87ef336a5 100644 --- a/tests/samplers/test_no_duplicates_batch_sampler.py +++ b/tests/samplers/test_no_duplicates_batch_sampler.py @@ -6,7 +6,11 @@ import torch from torch.utils.data import ConcatDataset -from sentence_transformers.sampler import NoDuplicatesBatchSampler, ProportionalBatchSampler +from sentence_transformers.sampler import ( + NoDuplicatesBatchSampler, + NoDuplicatesFastBatchSampler, + ProportionalBatchSampler, +) from sentence_transformers.util import is_datasets_available if is_datasets_available(): @@ -50,10 +54,11 @@ def dummy_duplicates_dataset() -> Dataset: return Dataset.from_list(values) -def test_group_by_label_batch_sampler_label_a(dummy_dataset: Dataset) -> None: +@pytest.mark.parametrize("sampler_cls", [NoDuplicatesBatchSampler, NoDuplicatesFastBatchSampler]) +def test_group_by_label_batch_sampler_label_a(dummy_dataset: Dataset, sampler_cls) -> None: batch_size = 10 - sampler = NoDuplicatesBatchSampler( + sampler = sampler_cls( dataset=dummy_dataset, batch_size=batch_size, drop_last=True, valid_label_columns=["label"] ) @@ -68,13 +73,14 @@ def test_group_by_label_batch_sampler_label_a(dummy_dataset: Dataset) -> None: assert len(batch_values) == len(set(batch_values)), f"Batch {batch} contains duplicate values: {batch_values}" +@pytest.mark.parametrize("sampler_cls", [NoDuplicatesBatchSampler, NoDuplicatesFastBatchSampler]) @pytest.mark.parametrize("drop_last", [True, False]) -def test_proportional_no_duplicates(dummy_duplicates_dataset: Dataset, drop_last: bool) -> None: +def test_proportional_no_duplicates(dummy_duplicates_dataset: Dataset, drop_last: bool, sampler_cls) -> None: batch_size = 2 - sampler_1 = NoDuplicatesBatchSampler( + sampler_1 = sampler_cls( dataset=dummy_duplicates_dataset, batch_size=batch_size, drop_last=drop_last, valid_label_columns=["anchor"] ) - sampler_2 = NoDuplicatesBatchSampler( + sampler_2 = sampler_cls( dataset=dummy_duplicates_dataset, batch_size=batch_size, drop_last=drop_last, valid_label_columns=["positive"] ) diff --git a/tests/test_training_args.py b/tests/test_training_args.py index fe424d06d..3668b6823 100644 --- a/tests/test_training_args.py +++ b/tests/test_training_args.py @@ -46,6 +46,20 @@ def test_hf_argument_parser(): assert args.learning_rate == 0.0005 +def test_hf_argument_parser_no_duplicates_fast(): + parser = HfArgumentParser((SentenceTransformerTrainingArguments,)) + dataclasses = parser.parse_args_into_dataclasses( + args=[ + "--output_dir", + "test_output_dir", + "--batch_sampler", + "no_duplicates_fast", + ] + ) + args = dataclasses[0] + assert args.batch_sampler == BatchSamplers.NO_DUPLICATES_FAST + + @pytest.mark.parametrize("argument_name", ["router_mapping", "learning_rate_mapping"]) def test_hf_argument_parser_incorrect_string_arguments(argument_name): parser = HfArgumentParser((SentenceTransformerTrainingArguments,)) From 829e45e7c0929224e6635f82d3a58c1e266ded42 Mon Sep 17 00:00:00 2001 From: Yuichi Tateno Date: Fri, 9 Jan 2026 15:44:11 +0900 Subject: [PATCH 
03/12] Guard hash dataset cleanup --- sentence_transformers/sampler.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/sentence_transformers/sampler.py b/sentence_transformers/sampler.py index 1c310418c..af8869ec9 100644 --- a/sentence_transformers/sampler.py +++ b/sentence_transformers/sampler.py @@ -386,6 +386,7 @@ def _build_hashes(self) -> None: columns = list(self.dataset.column_names) # Precompute hash values once to avoid repeated string processing per batch. # Use num_proc to parallelize hashing across CPU cores. + hash_ds: Dataset | None = None hash_ds = self.dataset.map( _hash_batch, batched=True, @@ -419,7 +420,8 @@ def _build_hashes(self) -> None: row_hashes = values.reshape((row_count, row_size)) except Exception as exc: # Surface failures explicitly; the fast sampler expects fixed-length rows. - del hash_ds + if hash_ds is not None: + del hash_ds raise ValueError( "NoDuplicatesFastBatchSampler requires fixed-length hash rows. " "Ensure each sample has the same number of values across columns." @@ -427,7 +429,8 @@ def _build_hashes(self) -> None: self._row_hashes = row_hashes # Drop the temporary dataset to release Arrow buffers promptly. - del hash_ds + if hash_ds is not None: + del hash_ds def __iter__(self) -> Iterator[list[int]]: if self.generator and self.seed is not None: From c103b390e31ea567b693de0c9d0d160f62f5f469 Mon Sep 17 00:00:00 2001 From: Yuichi Tateno Date: Fri, 9 Jan 2026 16:17:42 +0900 Subject: [PATCH 04/12] Add no-duplicate batch sampler benchmark script --- .../evaluation_no_dup_batch_sampler_speed.py | 476 ++++++++++++++++++ 1 file changed, 476 insertions(+) create mode 100644 examples/sentence_transformer/evaluation/evaluation_no_dup_batch_sampler_speed.py diff --git a/examples/sentence_transformer/evaluation/evaluation_no_dup_batch_sampler_speed.py b/examples/sentence_transformer/evaluation/evaluation_no_dup_batch_sampler_speed.py new file mode 100644 index 000000000..cc210453c --- /dev/null +++ b/examples/sentence_transformer/evaluation/evaluation_no_dup_batch_sampler_speed.py @@ -0,0 +1,476 @@ +from __future__ import annotations + +"""Benchmark NoDuplicates batch samplers on Hugging Face datasets. 
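+
+Compares NoDuplicatesBatchSampler with NoDuplicatesFastBatchSampler and can optionally report
+the memory cost of hash precomputation (tracemalloc, RSS, or USS).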
+ +Quick run: + python examples/sentence_transformer/evaluation/evaluation_no_dup_batch_sampler_speed.py --target fast + +Run examples: + python examples/sentence_transformer/evaluation/evaluation_no_dup_batch_sampler_speed.py \ + --dataset-name sentence-transformers/msmarco-co-condenser-margin-mse-sym-mnrl-mean-v1 \ + --dataset-subset triplet-50 --dataset-split train --batch-size 128 --measure-hash-uss --no-progress-bar + python examples/sentence_transformer/evaluation/evaluation_no_dup_batch_sampler_speed.py \ + --dataset-name sentence-transformers/msmarco-co-condenser-margin-mse-sym-mnrl-mean-v1 \ + --dataset-subset triplet-50 --dataset-split train --batch-size 8192 --measure-hash-uss --no-progress-bar + python examples/sentence_transformer/evaluation/evaluation_no_dup_batch_sampler_speed.py \ + --dataset-name sentence-transformers/msmarco-co-condenser-margin-mse-sym-mnrl-mean-v1 \ + --dataset-subset triplet-hard --dataset-split train --batch-size 128 --measure-hash-uss --no-progress-bar + python examples/sentence_transformer/evaluation/evaluation_no_dup_batch_sampler_speed.py \ + --dataset-name sentence-transformers/msmarco-co-condenser-margin-mse-sym-mnrl-mean-v1 \ + --dataset-subset triplet-hard --dataset-split train --batch-size 8192 --measure-hash-uss --no-progress-bar +""" + +import argparse +import asyncio +import gc +import os +import threading +import time +import tracemalloc + +import datasets +import torch +from datasets import Dataset, load_dataset + +from sentence_transformers.sampler import NoDuplicatesBatchSampler, NoDuplicatesFastBatchSampler + +try: + from tqdm import tqdm +except ImportError: # pragma: no cover - optional dependency + tqdm = None + +try: + import psutil +except ImportError: # pragma: no cover - optional dependency + psutil = None + +datasets.disable_caching() + +DEFAULT_DATASET_NAME = "sentence-transformers/msmarco-co-condenser-margin-mse-sym-mnrl-mean-v1" +DEFAULT_DATASET_SUBSET = "triplet-hard" +DEFAULT_DATASET_SPLIT = "train" + +BATCH_SIZE = 8192 +DROP_LAST = True +SEED = 42 + + +def run_sampler( + name: str, + sampler_cls, + dataset: Dataset, + batch_size: int, + drop_last: bool, + seed: int, + warmup: bool, + show_progress: bool, + measure_hash_mem: bool, + measure_hash_rss: bool, + measure_hash_uss: bool, + sampler_kwargs: dict[str, object] | None = None, +) -> tuple[float, int]: + """Run one sampler and print timing + batch count.""" + generator = torch.Generator() + generator.manual_seed(seed) + sampler = sampler_cls( + dataset=dataset, + batch_size=batch_size, + drop_last=drop_last, + generator=generator, + seed=seed, + **(sampler_kwargs or {}), + ) + + uss_sampler = None + # Optionally precompute hashes and measure their memory cost. 
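+    # The hash build is timed on its own (build_time) so its one-off cost is reported
+    # separately from the batch-iteration timing further below.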
+ if measure_hash_mem or measure_hash_rss or measure_hash_uss: + if hasattr(sampler, "_build_hashes"): + gc.collect() + if measure_hash_mem: + tracemalloc.start() + start_current, start_peak = tracemalloc.get_traced_memory() + rss_sampler = None + if measure_hash_rss and psutil is not None: + rss_sampler = _RssSampler() + rss_sampler.start() + elif measure_hash_rss and psutil is None: + print(f"{name} hash_rss: n/a (psutil not available)") + if measure_hash_uss and psutil is not None: + uss_sampler = _UssSampler() + uss_sampler.start() + if measure_hash_uss and psutil is None: + print(f"{name} hash_uss: n/a (psutil not available)") + + start = time.perf_counter() + sampler._build_hashes() + build_time = time.perf_counter() - start + gc.collect() + + if rss_sampler is not None: + rss_sampler.stop() + rss_report = rss_sampler.report() + print( + f"{name} hash_rss: current_delta={_format_bytes(rss_report.current_delta)}, " + f"peak_delta={_format_bytes(rss_report.peak_delta)}, build_time={build_time:.3f}s" + ) + if uss_sampler is not None: + uss_sampler.stop() + uss_report = uss_sampler.report() + print( + f"{name} hash_uss: current_delta={_format_bytes(uss_report.current_delta)}, " + f"peak_delta={_format_bytes(uss_report.peak_delta)}, build_time={build_time:.3f}s" + ) + + if measure_hash_mem: + current, peak = tracemalloc.get_traced_memory() + tracemalloc.stop() + current_delta = current - start_current + peak_delta = peak - start_peak + print( + f"{name} hash_mem: current_delta={_format_bytes(current_delta)}, " + f"peak_delta={_format_bytes(peak_delta)}, build_time={build_time:.3f}s" + ) + else: + if measure_hash_rss: + print(f"{name} hash_rss: n/a (no precomputed hashes)") + if measure_hash_uss: + print(f"{name} hash_uss: n/a (no precomputed hashes)") + if measure_hash_mem: + print(f"{name} hash_mem: n/a (no precomputed hashes)") + + # Warm up to reduce first-iteration overhead if requested. + if warmup: + warmup_iter = sampler + if show_progress and tqdm is not None: + warmup_iter = tqdm(warmup_iter, desc=f"{name} warmup", unit="batch") + for _ in warmup_iter: + pass + + # Timed pass. + start = time.perf_counter() + batch_count = 0 + timed_iter = sampler + if show_progress and tqdm is not None: + timed_iter = tqdm(timed_iter, desc=f"{name} timed", unit="batch") + for _ in timed_iter: + batch_count += 1 + elapsed = time.perf_counter() - start + total_rows = len(dataset) + ideal_batches = total_rows // batch_size if drop_last else (total_rows + batch_size - 1) // batch_size + batch_delta = ideal_batches - batch_count + print( + f"{name}: {elapsed:.3f}s ({batch_count} batches; " + f"ideal={ideal_batches}; delta={batch_delta}; batch_size={batch_size})" + ) + return elapsed, batch_count + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Benchmark NoDuplicates batch samplers on Hugging Face datasets.") + parser.add_argument("--dataset-name", type=str, default=DEFAULT_DATASET_NAME, help="Hugging Face dataset ID.") + parser.add_argument( + "--dataset-subset", + type=str, + default=DEFAULT_DATASET_SUBSET, + help="Dataset subset/config name (if applicable).", + ) + parser.add_argument( + "--dataset-split", + type=str, + default=DEFAULT_DATASET_SPLIT, + help="Dataset split to load (e.g. 
train/validation/test).", + ) + parser.add_argument("--batch-size", type=int, default=BATCH_SIZE, help="Mini-batch size.") + parser.add_argument("--seed", type=int, default=SEED, help="Random seed for sampling order.") + parser.add_argument("--warmup", action="store_true", help="Run a warmup pass before timing.") + parser.add_argument("--no-progress-bar", action="store_true", help="Disable tqdm progress bars.") + parser.add_argument("--show-uniqueness", action="store_true", help="Compute and display uniqueness stats.") + parser.add_argument("--uniqueness-workers", type=int, default=8, help="Max worker threads for uniqueness stats.") + parser.add_argument("--measure-hash-mem", action="store_true", help="Measure hash memory via tracemalloc.") + parser.add_argument("--measure-hash-rss", action="store_true", help="Measure hash RSS via psutil.") + parser.add_argument("--measure-hash-uss", action="store_true", help="Measure hash USS via psutil.") + parser.add_argument( + "--hash-num-proc", + type=int, + help="Processes used for hashing (NoDuplicatesFastBatchSampler only).", + ) + parser.add_argument( + "--ds-map-batch-size", + type=int, + help="datasets.map batch size for hashing (NoDuplicatesFastBatchSampler only).", + ) + parser.add_argument( + "--target", + action="append", + choices=["default", "fast"], + help="Which sampler to run (can be passed multiple times).", + ) + return parser.parse_args() + + +def _iter_texts(value: object) -> list[str]: + """Normalize a value into a list of strings for counting.""" + if isinstance(value, list): + return [str(item) for item in value] + return [str(value)] + + +def _format_bytes(value: int) -> str: + """Human-readable byte formatting.""" + units = ["B", "KiB", "MiB", "GiB", "TiB"] + size = float(value) + for unit in units: + if abs(size) < 1024.0 or unit == units[-1]: + return f"{size:.2f}{unit}" + size /= 1024.0 + return f"{size:.2f}TiB" + + +class _RssReport: + def __init__(self, start_rss: int, end_rss: int, peak_rss: int) -> None: + self.start_rss = start_rss + self.end_rss = end_rss + self.peak_rss = peak_rss + self.current_delta = end_rss - start_rss + self.peak_delta = peak_rss - start_rss + + +class _RssSampler: + """Sample RSS (including child processes) during hashing.""" + def __init__(self, interval: float = 0.1) -> None: + self.interval = interval + self._stop_event = threading.Event() + self._thread: threading.Thread | None = None + self._start_rss = 0 + self._end_rss = 0 + self._peak_rss = 0 + + def _total_rss(self) -> int: + if psutil is None: + return 0 + proc = psutil.Process(os.getpid()) + total = 0 + try: + total += proc.memory_info().rss + except psutil.NoSuchProcess: + return 0 + for child in proc.children(recursive=True): + try: + total += child.memory_info().rss + except psutil.NoSuchProcess: + continue + return total + + def _run(self) -> None: + while not self._stop_event.is_set(): + rss = self._total_rss() + if rss > self._peak_rss: + self._peak_rss = rss + time.sleep(self.interval) + + def start(self) -> None: + self._start_rss = self._total_rss() + self._peak_rss = self._start_rss + self._thread = threading.Thread(target=self._run, daemon=True) + self._thread.start() + + def stop(self) -> None: + self._stop_event.set() + if self._thread is not None: + self._thread.join() + self._end_rss = self._total_rss() + if self._end_rss > self._peak_rss: + self._peak_rss = self._end_rss + + def report(self) -> _RssReport: + return _RssReport(self._start_rss, self._end_rss, self._peak_rss) + + +class _UssReport: + def __init__(self, 
start_uss: int, end_uss: int, peak_uss: int) -> None: + self.start_uss = start_uss + self.end_uss = end_uss + self.peak_uss = peak_uss + self.current_delta = end_uss - start_uss + self.peak_delta = peak_uss - start_uss + + +class _UssSampler: + """Sample USS (including child processes) during hashing.""" + def __init__(self, interval: float = 0.1) -> None: + self.interval = interval + self._stop_event = threading.Event() + self._thread: threading.Thread | None = None + self._start_uss = 0 + self._end_uss = 0 + self._peak_uss = 0 + + def _total_uss(self) -> int: + if psutil is None: + return 0 + proc = psutil.Process(os.getpid()) + total = 0 + try: + total += proc.memory_full_info().uss + except (psutil.NoSuchProcess, AttributeError): + return 0 + for child in proc.children(recursive=True): + try: + total += child.memory_full_info().uss + except (psutil.NoSuchProcess, AttributeError): + continue + return total + + def _run(self) -> None: + while not self._stop_event.is_set(): + uss = self._total_uss() + if uss > self._peak_uss: + self._peak_uss = uss + time.sleep(self.interval) + + def start(self) -> None: + self._start_uss = self._total_uss() + self._peak_uss = self._start_uss + self._thread = threading.Thread(target=self._run, daemon=True) + self._thread.start() + + def stop(self) -> None: + self._stop_event.set() + if self._thread is not None: + self._thread.join() + self._end_uss = self._total_uss() + if self._end_uss > self._peak_uss: + self._peak_uss = self._end_uss + + def report(self) -> _UssReport: + return _UssReport(self._start_uss, self._end_uss, self._peak_uss) + +def _dup_stats(dataset: Dataset, show_progress: bool, desc: str) -> tuple[int, int, int]: + """Compute total/unique/dup counts across query/doc columns.""" + column_names = list(dataset.column_names) + + query_column = None + if "query" in column_names: + query_column = "query" + elif "anchor" in column_names: + query_column = "anchor" + + doc_columns = [] + doc_candidates = ["text", "positive", "pos", "doc", "document", "negative"] + for name in doc_candidates: + if name in column_names: + doc_columns.append(name) + for name in column_names: + if name.startswith("neg_") or name.startswith("negative_"): + doc_columns.append(name) + + if query_column is None and not doc_columns and len(column_names) >= 2: + query_column = column_names[0] + doc_columns = [column_names[1]] + + counts: dict[str, int] = {} + row_iter = dataset + if show_progress and tqdm is not None: + row_iter = tqdm(row_iter, desc=f"uniqueness:{desc}", unit="row") + for row in row_iter: + if query_column is not None: + for text in _iter_texts(row.get(query_column)): + counts[text] = counts.get(text, 0) + 1 + for doc_column in doc_columns: + for text in _iter_texts(row.get(doc_column)): + counts[text] = counts.get(text, 0) + 1 + + total = sum(counts.values()) + unique = len(counts) + dup = total - unique + return total, unique, dup + + +async def compute_uniqueness( + datasets_map: dict[str, Dataset], + workers: int, + show_progress: bool, +) -> dict[str, tuple[int, int, int] | None]: + """Run uniqueness checks concurrently with a bounded thread pool.""" + semaphore = asyncio.Semaphore(workers) + results: dict[str, tuple[int, int, int] | None] = {} + + async def run_one(name: str, dataset: Dataset) -> tuple[str, tuple[int, int, int] | None]: + async with semaphore: + try: + stats = await asyncio.to_thread(_dup_stats, dataset, show_progress, name) + except Exception: + return name, None + return name, stats + + tasks = [asyncio.create_task(run_one(name, 
dataset)) for name, dataset in datasets_map.items()] + for name, stats in await asyncio.gather(*tasks): + results[name] = stats + + return results + + +def _load_hf_dataset(name: str, subset: str | None, split: str) -> Dataset: + """Load a HF dataset split with an optional subset/config.""" + if subset: + return load_dataset(name, subset, split=split) + return load_dataset(name, split=split) + + +def main() -> None: + args = parse_args() + dataset_subset = args.dataset_subset or None + dataset = _load_hf_dataset(args.dataset_name, dataset_subset, args.dataset_split) + dataset_key = f"hf_{args.dataset_name}_{dataset_subset or 'default'}_{args.dataset_split}" + + print("Benchmark settings:") + print(f" batch_size={args.batch_size}, drop_last={DROP_LAST}, seed={args.seed}") + print(f" hf_dataset={args.dataset_name} subset={dataset_subset or 'default'} split={args.dataset_split}") + print(f" rows={len(dataset)}") + + if args.show_uniqueness: + results = asyncio.run( + compute_uniqueness( + {dataset_key: dataset}, + workers=args.uniqueness_workers, + show_progress=not args.no_progress_bar, + ) + ) + stats = results.get(dataset_key) + if stats is None: + print(" uniqueness: failed") + else: + total, unique, dup = stats + dup_rate = dup / total if total else 0.0 + print(f" uniqueness: total={total} unique={unique} dup={dup} dup_rate={dup_rate:.6f}") + + targets = args.target or ["default", "fast"] + fast_kwargs = {} + if args.hash_num_proc is not None: + fast_kwargs["hash_num_proc"] = args.hash_num_proc + if args.ds_map_batch_size is not None: + fast_kwargs["ds_map_batch_size"] = args.ds_map_batch_size + target_map = { + "default": ("NoDuplicatesBatchSampler", NoDuplicatesBatchSampler, {}), + "fast": ("NoDuplicatesFastBatchSampler", NoDuplicatesFastBatchSampler, fast_kwargs), + } + for target in targets: + name, sampler_cls, sampler_kwargs = target_map[target] + run_sampler( + name, + sampler_cls, + dataset, + args.batch_size, + DROP_LAST, + args.seed, + warmup=args.warmup, + show_progress=not args.no_progress_bar, + measure_hash_mem=args.measure_hash_mem, + measure_hash_rss=args.measure_hash_rss, + measure_hash_uss=args.measure_hash_uss, + sampler_kwargs=sampler_kwargs, + ) + + +if __name__ == "__main__": + main() From 3de2e0c04a6c11a8f709ba3d9d68ca3115829f0a Mon Sep 17 00:00:00 2001 From: Yuichi Tateno Date: Fri, 9 Jan 2026 16:38:50 +0900 Subject: [PATCH 05/12] Rename hash num_proc parameter --- .../evaluation/evaluation_no_dup_batch_sampler_speed.py | 6 +++--- sentence_transformers/sampler.py | 8 ++++---- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/examples/sentence_transformer/evaluation/evaluation_no_dup_batch_sampler_speed.py b/examples/sentence_transformer/evaluation/evaluation_no_dup_batch_sampler_speed.py index cc210453c..ee53a0fc1 100644 --- a/examples/sentence_transformer/evaluation/evaluation_no_dup_batch_sampler_speed.py +++ b/examples/sentence_transformer/evaluation/evaluation_no_dup_batch_sampler_speed.py @@ -190,7 +190,7 @@ def parse_args() -> argparse.Namespace: parser.add_argument("--measure-hash-rss", action="store_true", help="Measure hash RSS via psutil.") parser.add_argument("--measure-hash-uss", action="store_true", help="Measure hash USS via psutil.") parser.add_argument( - "--hash-num-proc", + "--num-proc", type=int, help="Processes used for hashing (NoDuplicatesFastBatchSampler only).", ) @@ -446,8 +446,8 @@ def main() -> None: targets = args.target or ["default", "fast"] fast_kwargs = {} - if args.hash_num_proc is not None: - 
fast_kwargs["hash_num_proc"] = args.hash_num_proc + if args.num_proc is not None: + fast_kwargs["num_proc"] = args.num_proc if args.ds_map_batch_size is not None: fast_kwargs["ds_map_batch_size"] = args.ds_map_batch_size target_map = { diff --git a/sentence_transformers/sampler.py b/sentence_transformers/sampler.py index af8869ec9..cd3a33f87 100644 --- a/sentence_transformers/sampler.py +++ b/sentence_transformers/sampler.py @@ -327,7 +327,7 @@ def __init__( valid_label_columns: list[str] | None = None, generator: torch.Generator | None = None, seed: int = 0, - hash_num_proc: int | None = None, + num_proc: int | None = None, ds_map_batch_size: int = 1000, ) -> None: """ @@ -355,7 +355,7 @@ def __init__( generator (torch.Generator, optional): Optional random number generator for shuffling the indices. seed (int): Seed for the random number generator to ensure reproducibility. Defaults to 0. - hash_num_proc (int, optional): Number of processes for hashing with datasets.map. Defaults to min(8, cpu-1). + num_proc (int, optional): Number of processes for hashing with datasets.map. Defaults to min(8, cpu-1). ds_map_batch_size (int, optional): Batch size for datasets.map hashing. Defaults to 1000. """ super().__init__( @@ -375,7 +375,7 @@ def __init__( cpu_count = os.cpu_count() or 1 # Leave one core free to avoid saturating the system when hashing. default_workers = max(1, min(8, cpu_count - 1)) - self.hash_num_proc = hash_num_proc or default_workers + self.num_proc = num_proc or default_workers self.ds_map_batch_size = ds_map_batch_size self._row_hashes: np.ndarray | list[list[int]] | None = None @@ -391,7 +391,7 @@ def _build_hashes(self) -> None: _hash_batch, batched=True, batch_size=self.ds_map_batch_size, - num_proc=self.hash_num_proc, + num_proc=self.num_proc, remove_columns=columns, fn_kwargs={"columns": columns, "exclude_columns": exclude_columns}, desc="Hashing dataset values", From c257dd3c6bd653f439c0bdf022dd23ac664ea85d Mon Sep 17 00:00:00 2001 From: Yuichi Tateno Date: Fri, 9 Jan 2026 16:40:39 +0900 Subject: [PATCH 06/12] Simplify xxhash requirement message --- sentence_transformers/sampler.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sentence_transformers/sampler.py b/sentence_transformers/sampler.py index cd3a33f87..78d3900b6 100644 --- a/sentence_transformers/sampler.py +++ b/sentence_transformers/sampler.py @@ -368,8 +368,7 @@ def __init__( ) if xxhash is None: raise ImportError( - "NoDuplicatesFastBatchSampler requires `xxhash`. Install it via `pip install xxhash` " - "or use the `train` extra." + "NoDuplicatesFastBatchSampler requires `xxhash`. Install `xxhash` to use this sampler." 
) self.dataset = _remove_label_columns(dataset, self.valid_label_columns) cpu_count = os.cpu_count() or 1 From 2f65cbfbe46dd1f9bc2e7e019fde22e43c277310 Mon Sep 17 00:00:00 2001 From: Yuichi Tateno Date: Sun, 11 Jan 2026 13:23:25 +0900 Subject: [PATCH 07/12] Run ruff format --- .../evaluation/evaluation_no_dup_batch_sampler_speed.py | 3 +++ sentence_transformers/sampler.py | 4 +--- tests/samplers/test_no_duplicates_batch_sampler.py | 4 +--- 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/examples/sentence_transformer/evaluation/evaluation_no_dup_batch_sampler_speed.py b/examples/sentence_transformer/evaluation/evaluation_no_dup_batch_sampler_speed.py index ee53a0fc1..30fd1494f 100644 --- a/examples/sentence_transformer/evaluation/evaluation_no_dup_batch_sampler_speed.py +++ b/examples/sentence_transformer/evaluation/evaluation_no_dup_batch_sampler_speed.py @@ -237,6 +237,7 @@ def __init__(self, start_rss: int, end_rss: int, peak_rss: int) -> None: class _RssSampler: """Sample RSS (including child processes) during hashing.""" + def __init__(self, interval: float = 0.1) -> None: self.interval = interval self._stop_event = threading.Event() @@ -297,6 +298,7 @@ def __init__(self, start_uss: int, end_uss: int, peak_uss: int) -> None: class _UssSampler: """Sample USS (including child processes) during hashing.""" + def __init__(self, interval: float = 0.1) -> None: self.interval = interval self._stop_event = threading.Event() @@ -345,6 +347,7 @@ def stop(self) -> None: def report(self) -> _UssReport: return _UssReport(self._start_uss, self._end_uss, self._peak_uss) + def _dup_stats(dataset: Dataset, show_progress: bool, desc: str) -> tuple[int, int, int]: """Compute total/unique/dup counts across query/doc columns.""" column_names = list(dataset.column_names) diff --git a/sentence_transformers/sampler.py b/sentence_transformers/sampler.py index 78d3900b6..23a82dd04 100644 --- a/sentence_transformers/sampler.py +++ b/sentence_transformers/sampler.py @@ -367,9 +367,7 @@ def __init__( seed=seed, ) if xxhash is None: - raise ImportError( - "NoDuplicatesFastBatchSampler requires `xxhash`. Install `xxhash` to use this sampler." - ) + raise ImportError("NoDuplicatesFastBatchSampler requires `xxhash`. Install `xxhash` to use this sampler.") self.dataset = _remove_label_columns(dataset, self.valid_label_columns) cpu_count = os.cpu_count() or 1 # Leave one core free to avoid saturating the system when hashing. 
diff --git a/tests/samplers/test_no_duplicates_batch_sampler.py b/tests/samplers/test_no_duplicates_batch_sampler.py index 87ef336a5..f7d1c21c3 100644 --- a/tests/samplers/test_no_duplicates_batch_sampler.py +++ b/tests/samplers/test_no_duplicates_batch_sampler.py @@ -58,9 +58,7 @@ def dummy_duplicates_dataset() -> Dataset: def test_group_by_label_batch_sampler_label_a(dummy_dataset: Dataset, sampler_cls) -> None: batch_size = 10 - sampler = sampler_cls( - dataset=dummy_dataset, batch_size=batch_size, drop_last=True, valid_label_columns=["label"] - ) + sampler = sampler_cls(dataset=dummy_dataset, batch_size=batch_size, drop_last=True, valid_label_columns=["label"]) batches = list(iter(sampler)) From 8e747f6068946bf39eb7b018f3d66e142bcc2221 Mon Sep 17 00:00:00 2001 From: Yuichi Tateno Date: Wed, 28 Jan 2026 13:49:10 +0900 Subject: [PATCH 08/12] Merge fast no-duplicates into NoDuplicatesBatchSampler --- .../evaluation_no_dup_batch_sampler_speed.py | 32 ++-- sentence_transformers/sampler.py | 143 ++++++------------ sentence_transformers/trainer.py | 3 +- sentence_transformers/training_args.py | 4 +- .../test_no_duplicates_batch_sampler.py | 60 ++++++-- 5 files changed, 109 insertions(+), 133 deletions(-) diff --git a/examples/sentence_transformer/evaluation/evaluation_no_dup_batch_sampler_speed.py b/examples/sentence_transformer/evaluation/evaluation_no_dup_batch_sampler_speed.py index 30fd1494f..b5baa5d05 100644 --- a/examples/sentence_transformer/evaluation/evaluation_no_dup_batch_sampler_speed.py +++ b/examples/sentence_transformer/evaluation/evaluation_no_dup_batch_sampler_speed.py @@ -32,7 +32,7 @@ import torch from datasets import Dataset, load_dataset -from sentence_transformers.sampler import NoDuplicatesBatchSampler, NoDuplicatesFastBatchSampler +from sentence_transformers.sampler import NoDuplicatesBatchSampler try: from tqdm import tqdm @@ -84,7 +84,7 @@ def run_sampler( uss_sampler = None # Optionally precompute hashes and measure their memory cost. if measure_hash_mem or measure_hash_rss or measure_hash_uss: - if hasattr(sampler, "_build_hashes"): + if getattr(sampler, "precompute_hashes", False): gc.collect() if measure_hash_mem: tracemalloc.start() @@ -132,11 +132,11 @@ def run_sampler( ) else: if measure_hash_rss: - print(f"{name} hash_rss: n/a (no precomputed hashes)") + print(f"{name} hash_rss: n/a (precompute_hashes disabled)") if measure_hash_uss: - print(f"{name} hash_uss: n/a (no precomputed hashes)") + print(f"{name} hash_uss: n/a (precompute_hashes disabled)") if measure_hash_mem: - print(f"{name} hash_mem: n/a (no precomputed hashes)") + print(f"{name} hash_mem: n/a (precompute_hashes disabled)") # Warm up to reduce first-iteration overhead if requested. 
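+    # When hashes are built lazily (no memory measurement above), an enabled warmup pass also
+    # absorbs the one-off hash precomputation, so the timed pass reflects iteration cost only.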
if warmup: @@ -189,15 +189,11 @@ def parse_args() -> argparse.Namespace: parser.add_argument("--measure-hash-mem", action="store_true", help="Measure hash memory via tracemalloc.") parser.add_argument("--measure-hash-rss", action="store_true", help="Measure hash RSS via psutil.") parser.add_argument("--measure-hash-uss", action="store_true", help="Measure hash USS via psutil.") + parser.add_argument("--precompute-num-proc", type=int, help="Processes used for hashing (fast target only).") parser.add_argument( - "--num-proc", + "--precompute-batch-size", type=int, - help="Processes used for hashing (NoDuplicatesFastBatchSampler only).", - ) - parser.add_argument( - "--ds-map-batch-size", - type=int, - help="datasets.map batch size for hashing (NoDuplicatesFastBatchSampler only).", + help="datasets.map batch size for hashing (fast target only).", ) parser.add_argument( "--target", @@ -448,14 +444,14 @@ def main() -> None: print(f" uniqueness: total={total} unique={unique} dup={dup} dup_rate={dup_rate:.6f}") targets = args.target or ["default", "fast"] - fast_kwargs = {} - if args.num_proc is not None: - fast_kwargs["num_proc"] = args.num_proc - if args.ds_map_batch_size is not None: - fast_kwargs["ds_map_batch_size"] = args.ds_map_batch_size + fast_kwargs = {"precompute_hashes": True} + if args.precompute_num_proc is not None: + fast_kwargs["precompute_num_proc"] = args.precompute_num_proc + if args.precompute_batch_size is not None: + fast_kwargs["precompute_batch_size"] = args.precompute_batch_size target_map = { "default": ("NoDuplicatesBatchSampler", NoDuplicatesBatchSampler, {}), - "fast": ("NoDuplicatesFastBatchSampler", NoDuplicatesFastBatchSampler, fast_kwargs), + "fast": ("NoDuplicatesBatchSampler (precompute)", NoDuplicatesBatchSampler, fast_kwargs), } for target in targets: name, sampler_cls, sampler_kwargs = target_map[target] diff --git a/sentence_transformers/sampler.py b/sentence_transformers/sampler.py index 23a82dd04..2197d0bc6 100644 --- a/sentence_transformers/sampler.py +++ b/sentence_transformers/sampler.py @@ -250,6 +250,9 @@ def __init__( valid_label_columns: list[str] | None = None, generator: torch.Generator | None = None, seed: int = 0, + precompute_hashes: bool = False, + precompute_num_proc: int | None = None, + precompute_batch_size: int = 1000, ) -> None: """ This sampler creates batches such that each batch contains samples where the values are unique, @@ -276,6 +279,13 @@ def __init__( generator (torch.Generator, optional): Optional random number generator for shuffling the indices. seed (int): Seed for the random number generator to ensure reproducibility. Defaults to 0. + precompute_hashes (bool, optional): If True, precompute xxhash 64-bit values for dataset + fields using ``datasets.map`` to speed up duplicate checks. Requires ``xxhash`` and + uses additional memory. Defaults to False. + precompute_num_proc (int, optional): Number of processes for hashing with ``datasets.map``. + Defaults to ``min(8, cpu_count - 1)`` when precomputing. + precompute_batch_size (int, optional): Batch size for ``datasets.map`` hashing. + Defaults to 1000. """ super().__init__( dataset, @@ -286,98 +296,24 @@ def __init__( seed=seed, ) self.dataset = _remove_label_columns(dataset, self.valid_label_columns) - - def __iter__(self) -> Iterator[list[int]]: - """ - Iterate over the remaining non-yielded indices. For each index, check if the sample values are already in the - batch. If not, add the sample values to the batch keep going until the batch is full. 
If the batch is full, yield - the batch indices and continue with the next batch. - """ - if self.generator and self.seed is not None: - self.generator.manual_seed(self.seed + self.epoch) - - # We create a dictionary to None because we need a data structure that: - # 1. Allows for cheap removal of elements - # 2. Preserves the order of elements, i.e. remains random - remaining_indices = dict.fromkeys(torch.randperm(len(self.dataset), generator=self.generator).tolist()) - - def get_sample_values(index: int) -> set[str]: - return {str(value) for key, value in self.dataset[index].items() if key != "dataset_name"} - - yield from _iter_no_duplicate_batches( - remaining_indices, - get_sample_values, - self.batch_size, - self.drop_last, - ) - - def __len__(self) -> int: - if self.drop_last: - return len(self.dataset) // self.batch_size - else: - return (len(self.dataset) + self.batch_size - 1) // self.batch_size - - -class NoDuplicatesFastBatchSampler(DefaultBatchSampler): - def __init__( - self, - dataset: Dataset, - batch_size: int, - drop_last: bool, - valid_label_columns: list[str] | None = None, - generator: torch.Generator | None = None, - seed: int = 0, - num_proc: int | None = None, - ds_map_batch_size: int = 1000, - ) -> None: - """ - This sampler creates batches such that each batch contains samples where the values are unique, - even across columns. It uses the same batch construction approach as NoDuplicatesBatchSampler, - but speeds up duplicate checks by caching per-row xxhash 64-bit values computed with datasets.map. - - Recommended for: - - :class:`~sentence_transformers.losses.MultipleNegativesRankingLoss` - - :class:`~sentence_transformers.losses.CachedMultipleNegativesRankingLoss` - - :class:`~sentence_transformers.losses.MultipleNegativesSymmetricRankingLoss` - - :class:`~sentence_transformers.losses.CachedMultipleNegativesSymmetricRankingLoss` - - :class:`~sentence_transformers.losses.MegaBatchMarginLoss` - - :class:`~sentence_transformers.losses.GISTEmbedLoss` - - :class:`~sentence_transformers.losses.CachedGISTEmbedLoss` - - Args: - dataset (Dataset): The dataset to sample from. - batch_size (int): Number of samples per batch. - drop_last (bool): If True, drop the last incomplete batch if the dataset size - is not divisible by the batch size. - valid_label_columns (List[str], optional): List of column names to check for labels. - The first column name from ``valid_label_columns`` found in the dataset will - be used as the label column. - generator (torch.Generator, optional): Optional random number generator for shuffling - the indices. - seed (int): Seed for the random number generator to ensure reproducibility. Defaults to 0. - num_proc (int, optional): Number of processes for hashing with datasets.map. Defaults to min(8, cpu-1). - ds_map_batch_size (int, optional): Batch size for datasets.map hashing. Defaults to 1000. - """ - super().__init__( - dataset, - batch_size=batch_size, - drop_last=drop_last, - valid_label_columns=valid_label_columns, - generator=generator, - seed=seed, - ) - if xxhash is None: - raise ImportError("NoDuplicatesFastBatchSampler requires `xxhash`. Install `xxhash` to use this sampler.") - self.dataset = _remove_label_columns(dataset, self.valid_label_columns) - cpu_count = os.cpu_count() or 1 - # Leave one core free to avoid saturating the system when hashing. 
- default_workers = max(1, min(8, cpu_count - 1)) - self.num_proc = num_proc or default_workers - self.ds_map_batch_size = ds_map_batch_size + self.precompute_hashes = precompute_hashes + self.precompute_num_proc = precompute_num_proc + self.precompute_batch_size = precompute_batch_size self._row_hashes: np.ndarray | list[list[int]] | None = None + if self.precompute_hashes: + if xxhash is None: + raise ImportError( + "NoDuplicatesBatchSampler with precompute_hashes=True requires `xxhash`. " + "Install `xxhash` to use this option." + ) + cpu_count = os.cpu_count() or 1 + # Leave one core free to avoid saturating the system when hashing. + default_workers = max(1, min(8, cpu_count - 1)) + if self.precompute_num_proc is None: + self.precompute_num_proc = default_workers def _build_hashes(self) -> None: - if self._row_hashes is not None: + if not self.precompute_hashes or self._row_hashes is not None: return exclude_columns = set(self.valid_label_columns or []) | {"dataset_name"} columns = list(self.dataset.column_names) @@ -387,8 +323,8 @@ def _build_hashes(self) -> None: hash_ds = self.dataset.map( _hash_batch, batched=True, - batch_size=self.ds_map_batch_size, - num_proc=self.num_proc, + batch_size=self.precompute_batch_size, + num_proc=self.precompute_num_proc, remove_columns=columns, fn_kwargs={"columns": columns, "exclude_columns": exclude_columns}, desc="Hashing dataset values", @@ -416,11 +352,11 @@ def _build_hashes(self) -> None: raise ValueError("Unexpected hashed value buffer size.") row_hashes = values.reshape((row_count, row_size)) except Exception as exc: - # Surface failures explicitly; the fast sampler expects fixed-length rows. + # Surface failures explicitly; the precompute option expects fixed-length rows. if hash_ds is not None: del hash_ds raise ValueError( - "NoDuplicatesFastBatchSampler requires fixed-length hash rows. " + "NoDuplicatesBatchSampler with precompute_hashes=True requires fixed-length hash rows. " "Ensure each sample has the same number of values across columns." ) from exc @@ -430,19 +366,32 @@ def _build_hashes(self) -> None: del hash_ds def __iter__(self) -> Iterator[list[int]]: + """ + Iterate over the remaining non-yielded indices. For each index, check if the sample values are already in the + batch. If not, add the sample values to the batch keep going until the batch is full. If the batch is full, yield + the batch indices and continue with the next batch. + """ if self.generator and self.seed is not None: self.generator.manual_seed(self.seed + self.epoch) - self._build_hashes() - row_hashes = self._row_hashes if self._row_hashes is not None else [] + if self.precompute_hashes: + self._build_hashes() + row_hashes = self._row_hashes if self._row_hashes is not None else [] # We create a dictionary to None because we need a data structure that: # 1. Allows for cheap removal of elements # 2. Preserves the order of elements, i.e. 
remains random remaining_indices = dict.fromkeys(torch.randperm(len(self.dataset), generator=self.generator).tolist()) - def get_sample_values(index: int): - return row_hashes[index] + if self.precompute_hashes: + + def get_sample_values(index: int): + return row_hashes[index] + + else: + + def get_sample_values(index: int) -> set[str]: + return {str(value) for key, value in self.dataset[index].items() if key != "dataset_name"} yield from _iter_no_duplicate_batches( remaining_indices, diff --git a/sentence_transformers/trainer.py b/sentence_transformers/trainer.py index 21bfa18b3..4814d907f 100644 --- a/sentence_transformers/trainer.py +++ b/sentence_transformers/trainer.py @@ -31,7 +31,6 @@ GroupByLabelBatchSampler, MultiDatasetDefaultBatchSampler, NoDuplicatesBatchSampler, - NoDuplicatesFastBatchSampler, ProportionalBatchSampler, RoundRobinBatchSampler, ) @@ -681,7 +680,7 @@ def get_batch_sampler( return NoDuplicatesBatchSampler(dataset, **batch_sampler_kwargs) if self.args.batch_sampler == BatchSamplers.NO_DUPLICATES_FAST: - return NoDuplicatesFastBatchSampler(dataset, **batch_sampler_kwargs) + return NoDuplicatesBatchSampler(dataset, precompute_hashes=True, **batch_sampler_kwargs) if self.args.batch_sampler == BatchSamplers.GROUP_BY_LABEL: return GroupByLabelBatchSampler(dataset, **batch_sampler_kwargs) diff --git a/sentence_transformers/training_args.py b/sentence_transformers/training_args.py index c176d5561..c48e458b1 100644 --- a/sentence_transformers/training_args.py +++ b/sentence_transformers/training_args.py @@ -25,8 +25,8 @@ class BatchSamplers(ExplicitEnum): PyTorch batch sampler. - ``BatchSamplers.NO_DUPLICATES``: Uses :class:`~sentence_transformers.sampler.NoDuplicatesBatchSampler`, ensuring no duplicate samples in a batch. - - ``BatchSamplers.NO_DUPLICATES_FAST``: Uses :class:`~sentence_transformers.sampler.NoDuplicatesFastBatchSampler`, - a faster sampler that also ensures no duplicate samples in a batch. + - ``BatchSamplers.NO_DUPLICATES_FAST``: Uses :class:`~sentence_transformers.sampler.NoDuplicatesBatchSampler` + with ``precompute_hashes=True``, a faster option that also ensures no duplicate samples in a batch. 
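+      Requires the optional ``xxhash`` dependency and additional memory for the precomputed hashes.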
Both are recommended for losses that use in-batch negatives, such as: diff --git a/tests/samplers/test_no_duplicates_batch_sampler.py b/tests/samplers/test_no_duplicates_batch_sampler.py index f7d1c21c3..253bd3603 100644 --- a/tests/samplers/test_no_duplicates_batch_sampler.py +++ b/tests/samplers/test_no_duplicates_batch_sampler.py @@ -6,11 +6,8 @@ import torch from torch.utils.data import ConcatDataset -from sentence_transformers.sampler import ( - NoDuplicatesBatchSampler, - NoDuplicatesFastBatchSampler, - ProportionalBatchSampler, -) +from sentence_transformers import sampler as sampler_module +from sentence_transformers.sampler import NoDuplicatesBatchSampler, ProportionalBatchSampler from sentence_transformers.util import is_datasets_available if is_datasets_available(): @@ -54,11 +51,27 @@ def dummy_duplicates_dataset() -> Dataset: return Dataset.from_list(values) -@pytest.mark.parametrize("sampler_cls", [NoDuplicatesBatchSampler, NoDuplicatesFastBatchSampler]) -def test_group_by_label_batch_sampler_label_a(dummy_dataset: Dataset, sampler_cls) -> None: +@pytest.mark.parametrize("precompute_hashes", [False, True]) +def test_group_by_label_batch_sampler_label_a(dummy_dataset: Dataset, precompute_hashes: bool) -> None: batch_size = 10 - sampler = sampler_cls(dataset=dummy_dataset, batch_size=batch_size, drop_last=True, valid_label_columns=["label"]) + sampler_kwargs = {} + if precompute_hashes: + if sampler_module.xxhash is None: + pytest.skip("xxhash not installed") + sampler_kwargs = { + "precompute_hashes": True, + "precompute_num_proc": 1, + "precompute_batch_size": 10, + } + + sampler = NoDuplicatesBatchSampler( + dataset=dummy_dataset, + batch_size=batch_size, + drop_last=True, + valid_label_columns=["label"], + **sampler_kwargs, + ) batches = list(iter(sampler)) @@ -71,15 +84,34 @@ def test_group_by_label_batch_sampler_label_a(dummy_dataset: Dataset, sampler_cl assert len(batch_values) == len(set(batch_values)), f"Batch {batch} contains duplicate values: {batch_values}" -@pytest.mark.parametrize("sampler_cls", [NoDuplicatesBatchSampler, NoDuplicatesFastBatchSampler]) @pytest.mark.parametrize("drop_last", [True, False]) -def test_proportional_no_duplicates(dummy_duplicates_dataset: Dataset, drop_last: bool, sampler_cls) -> None: +@pytest.mark.parametrize("precompute_hashes", [False, True]) +def test_proportional_no_duplicates( + dummy_duplicates_dataset: Dataset, drop_last: bool, precompute_hashes: bool +) -> None: batch_size = 2 - sampler_1 = sampler_cls( - dataset=dummy_duplicates_dataset, batch_size=batch_size, drop_last=drop_last, valid_label_columns=["anchor"] + sampler_kwargs = {} + if precompute_hashes: + if sampler_module.xxhash is None: + pytest.skip("xxhash not installed") + sampler_kwargs = { + "precompute_hashes": True, + "precompute_num_proc": 1, + "precompute_batch_size": 10, + } + sampler_1 = NoDuplicatesBatchSampler( + dataset=dummy_duplicates_dataset, + batch_size=batch_size, + drop_last=drop_last, + valid_label_columns=["anchor"], + **sampler_kwargs, ) - sampler_2 = sampler_cls( - dataset=dummy_duplicates_dataset, batch_size=batch_size, drop_last=drop_last, valid_label_columns=["positive"] + sampler_2 = NoDuplicatesBatchSampler( + dataset=dummy_duplicates_dataset, + batch_size=batch_size, + drop_last=drop_last, + valid_label_columns=["positive"], + **sampler_kwargs, ) concat_dataset = ConcatDataset([dummy_duplicates_dataset, dummy_duplicates_dataset]) From 2b534a2461db9bba8cdabf08f1b43408ac7ab062 Mon Sep 17 00:00:00 2001 From: Yuichi Tateno Date: Wed, 28 Jan 
2026 13:49:55 +0900 Subject: [PATCH 09/12] Update sampler docs and batch sampler args test --- docs/package_reference/sentence_transformer/sampler.md | 5 ----- 1 file changed, 5 deletions(-) diff --git a/docs/package_reference/sentence_transformer/sampler.md b/docs/package_reference/sentence_transformer/sampler.md index c4b4fe856..76f9f338b 100644 --- a/docs/package_reference/sentence_transformer/sampler.md +++ b/docs/package_reference/sentence_transformer/sampler.md @@ -17,11 +17,6 @@ :members: ``` -```{eval-rst} -.. autoclass:: sentence_transformers.sampler.NoDuplicatesFastBatchSampler - :members: -``` - ```{eval-rst} .. autoclass:: sentence_transformers.sampler.GroupByLabelBatchSampler :members: From 132b16b29cdb36285d68f58d4d0a669fcbcd8010 Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Wed, 28 Jan 2026 12:56:04 +0100 Subject: [PATCH 10/12] Rename to NO_DUPLICATES_HASHED; update some docs Tiny code improvements --- .../evaluation_no_dup_batch_sampler_speed.py | 18 +++++++++--------- sentence_transformers/sampler.py | 18 ++++++++++-------- sentence_transformers/trainer.py | 2 +- sentence_transformers/training_args.py | 7 ++++--- tests/test_training_args.py | 6 +++--- 5 files changed, 27 insertions(+), 24 deletions(-) diff --git a/examples/sentence_transformer/evaluation/evaluation_no_dup_batch_sampler_speed.py b/examples/sentence_transformer/evaluation/evaluation_no_dup_batch_sampler_speed.py index b5baa5d05..9c0b1d0d2 100644 --- a/examples/sentence_transformer/evaluation/evaluation_no_dup_batch_sampler_speed.py +++ b/examples/sentence_transformer/evaluation/evaluation_no_dup_batch_sampler_speed.py @@ -3,7 +3,7 @@ """Benchmark NoDuplicates batch samplers on Hugging Face datasets. Quick run: - python examples/sentence_transformer/evaluation/evaluation_no_dup_batch_sampler_speed.py --target fast + python examples/sentence_transformer/evaluation/evaluation_no_dup_batch_sampler_speed.py --target hashed Run examples: python examples/sentence_transformer/evaluation/evaluation_no_dup_batch_sampler_speed.py \ @@ -189,16 +189,16 @@ def parse_args() -> argparse.Namespace: parser.add_argument("--measure-hash-mem", action="store_true", help="Measure hash memory via tracemalloc.") parser.add_argument("--measure-hash-rss", action="store_true", help="Measure hash RSS via psutil.") parser.add_argument("--measure-hash-uss", action="store_true", help="Measure hash USS via psutil.") - parser.add_argument("--precompute-num-proc", type=int, help="Processes used for hashing (fast target only).") + parser.add_argument("--precompute-num-proc", type=int, help="Processes used for hashing (hashed target only).") parser.add_argument( "--precompute-batch-size", type=int, - help="datasets.map batch size for hashing (fast target only).", + help="datasets.map batch size for hashing (hashed target only).", ) parser.add_argument( "--target", action="append", - choices=["default", "fast"], + choices=["default", "hashed"], help="Which sampler to run (can be passed multiple times).", ) return parser.parse_args() @@ -443,15 +443,15 @@ def main() -> None: dup_rate = dup / total if total else 0.0 print(f" uniqueness: total={total} unique={unique} dup={dup} dup_rate={dup_rate:.6f}") - targets = args.target or ["default", "fast"] - fast_kwargs = {"precompute_hashes": True} + targets = args.target or ["default", "hashed"] + hashed_kwargs = {"precompute_hashes": True} if args.precompute_num_proc is not None: - fast_kwargs["precompute_num_proc"] = args.precompute_num_proc + hashed_kwargs["precompute_num_proc"] = 
args.precompute_num_proc if args.precompute_batch_size is not None: - fast_kwargs["precompute_batch_size"] = args.precompute_batch_size + hashed_kwargs["precompute_batch_size"] = args.precompute_batch_size target_map = { "default": ("NoDuplicatesBatchSampler", NoDuplicatesBatchSampler, {}), - "fast": ("NoDuplicatesBatchSampler (precompute)", NoDuplicatesBatchSampler, fast_kwargs), + "hashed": ("NoDuplicatesBatchSampler (hashed)", NoDuplicatesBatchSampler, hashed_kwargs), } for target in targets: name, sampler_cls, sampler_kwargs = target_map[target] diff --git a/sentence_transformers/sampler.py b/sentence_transformers/sampler.py index 2197d0bc6..432a312a2 100644 --- a/sentence_transformers/sampler.py +++ b/sentence_transformers/sampler.py @@ -53,8 +53,7 @@ def _hash_batch( for column in active_columns: value = batch[column][row_idx] if isinstance(value, list): - for item in value: - row_hashes.append(_xxhash_int64(str(item))) + row_hashes.extend(_xxhash_int64(str(item)) for item in value) else: row_hashes.append(_xxhash_int64(str(value))) hashes.append(row_hashes) @@ -71,7 +70,7 @@ def _remove_label_columns(dataset: Dataset, valid_label_columns: list[str] | Non def _has_overlap(sample_values, batch_values: set[Any]) -> bool: # Avoid materializing a set if we already have one. if isinstance(sample_values, set): - return bool(sample_values & batch_values) + return not sample_values.isdisjoint(batch_values) return any(value in batch_values for value in sample_values) @@ -280,10 +279,13 @@ def __init__( the indices. seed (int): Seed for the random number generator to ensure reproducibility. Defaults to 0. precompute_hashes (bool, optional): If True, precompute xxhash 64-bit values for dataset - fields using ``datasets.map`` to speed up duplicate checks. Requires ``xxhash`` and - uses additional memory. Defaults to False. + fields using ``datasets.map`` to speed up duplicate checks. Requires ``xxhash`` to + be installed and uses additional memory: in theory roughly + ``len(dataset) * num_columns * 8`` bytes for the dense int64 hash matrix, + although actual memory usage may differ in practice. Defaults to False. precompute_num_proc (int, optional): Number of processes for hashing with ``datasets.map``. - Defaults to ``min(8, cpu_count - 1)`` when precomputing. + If set to ``None``, defaults to ``min(8, cpu_count - 1)`` when ``precompute_hashes`` + is True. precompute_batch_size (int, optional): Batch size for ``datasets.map`` hashing. Defaults to 1000. """ @@ -299,7 +301,7 @@ def __init__( self.precompute_hashes = precompute_hashes self.precompute_num_proc = precompute_num_proc self.precompute_batch_size = precompute_batch_size - self._row_hashes: np.ndarray | list[list[int]] | None = None + self._row_hashes: np.ndarray | None = None if self.precompute_hashes: if xxhash is None: raise ImportError( @@ -376,7 +378,7 @@ def __iter__(self) -> Iterator[list[int]]: if self.precompute_hashes: self._build_hashes() - row_hashes = self._row_hashes if self._row_hashes is not None else [] + row_hashes: np.ndarray = self._row_hashes # We create a dictionary to None because we need a data structure that: # 1.
Allows for cheap removal of elements diff --git a/sentence_transformers/trainer.py b/sentence_transformers/trainer.py index 4814d907f..52404e0df 100644 --- a/sentence_transformers/trainer.py +++ b/sentence_transformers/trainer.py @@ -679,7 +679,7 @@ def get_batch_sampler( if self.args.batch_sampler == BatchSamplers.NO_DUPLICATES: return NoDuplicatesBatchSampler(dataset, **batch_sampler_kwargs) - if self.args.batch_sampler == BatchSamplers.NO_DUPLICATES_FAST: + if self.args.batch_sampler == BatchSamplers.NO_DUPLICATES_HASHED: return NoDuplicatesBatchSampler(dataset, precompute_hashes=True, **batch_sampler_kwargs) if self.args.batch_sampler == BatchSamplers.GROUP_BY_LABEL: diff --git a/sentence_transformers/training_args.py b/sentence_transformers/training_args.py index c48e458b1..dd296448a 100644 --- a/sentence_transformers/training_args.py +++ b/sentence_transformers/training_args.py @@ -25,8 +25,9 @@ class BatchSamplers(ExplicitEnum): PyTorch batch sampler. - ``BatchSamplers.NO_DUPLICATES``: Uses :class:`~sentence_transformers.sampler.NoDuplicatesBatchSampler`, ensuring no duplicate samples in a batch. - - ``BatchSamplers.NO_DUPLICATES_FAST``: Uses :class:`~sentence_transformers.sampler.NoDuplicatesBatchSampler` - with ``precompute_hashes=True``, a faster option that also ensures no duplicate samples in a batch. + - ``BatchSamplers.NO_DUPLICATES_HASHED``: Uses :class:`~sentence_transformers.sampler.NoDuplicatesBatchSampler` + with ``precompute_hashes=True``, a variant that precomputes hashes for faster duplicate checks at a small memory cost. + Requires the ``xxhash`` library to be installed. Both are recommended for losses that use in-batch negatives, such as: @@ -83,7 +84,7 @@ class BatchSamplers(ExplicitEnum): BATCH_SAMPLER = "batch_sampler" NO_DUPLICATES = "no_duplicates" - NO_DUPLICATES_FAST = "no_duplicates_fast" + NO_DUPLICATES_HASHED = "no_duplicates_hashed" GROUP_BY_LABEL = "group_by_label" diff --git a/tests/test_training_args.py b/tests/test_training_args.py index 3668b6823..47177a7f6 100644 --- a/tests/test_training_args.py +++ b/tests/test_training_args.py @@ -46,18 +46,18 @@ def test_hf_argument_parser(): assert args.learning_rate == 0.0005 -def test_hf_argument_parser_no_duplicates_fast(): +def test_hf_argument_parser_no_duplicates_hashed(): parser = HfArgumentParser((SentenceTransformerTrainingArguments,)) dataclasses = parser.parse_args_into_dataclasses( args=[ "--output_dir", "test_output_dir", "--batch_sampler", - "no_duplicates_fast", + "no_duplicates_hashed", ] ) args = dataclasses[0] - assert args.batch_sampler == BatchSamplers.NO_DUPLICATES_FAST + assert args.batch_sampler == BatchSamplers.NO_DUPLICATES_HASHED @pytest.mark.parametrize("argument_name", ["router_mapping", "learning_rate_mapping"]) From 3bbb0b57d6199ee65ece7eebabdaebe92b2e7f0f Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Wed, 28 Jan 2026 13:14:26 +0100 Subject: [PATCH 11/12] Move methods around a bit The _iter_no_duplicate_batches and _remove_label_columns can be placed back in the class itself, as there's now just 1 again. --- sentence_transformers/sampler.py | 160 +++++++++++++------------------ 1 file changed, 68 insertions(+), 92 deletions(-) diff --git a/sentence_transformers/sampler.py b/sentence_transformers/sampler.py index 432a312a2..ca34a5485 100644 --- a/sentence_transformers/sampler.py +++ b/sentence_transformers/sampler.py @@ -28,83 +28,6 @@ _XXHASH_UINT64_MAX = 1 << 64 -def _xxhash_int64(value: str) -> int: - # Convert uint64 -> int64 to keep values compatible with Arrow int64 storage. 
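For context on the `NO_DUPLICATES_HASHED` value introduced above, a hedged usage sketch (placeholder output directory and batch size; only the `batch_sampler` line matters) of how a user opts in through the training arguments; per the `get_batch_sampler` change, the trainer then constructs `NoDuplicatesBatchSampler` with `precompute_hashes=True`, so `xxhash` must be installed.

```python
from sentence_transformers import SentenceTransformerTrainingArguments
from sentence_transformers.training_args import BatchSamplers

args = SentenceTransformerTrainingArguments(
    output_dir="output/no-dup-hashed-demo",  # placeholder path
    per_device_train_batch_size=64,
    # Same batches as BatchSamplers.NO_DUPLICATES, but duplicate checks run
    # against precomputed xxhash values instead of re-stringifying each row.
    batch_sampler=BatchSamplers.NO_DUPLICATES_HASHED,
)
```

When configuring via `HfArgumentParser` or a CLI, the equivalent string value is `no_duplicates_hashed`, as exercised in the test above.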
- hashed = xxhash.xxh64_intdigest(value) - if hashed >= _XXHASH_INT64_MAX: - hashed -= _XXHASH_UINT64_MAX - return hashed - - -def _hash_batch( - batch: dict[str, list[Any]], - columns: list[str], - exclude_columns: set[str], -) -> dict[str, list[list[int]]]: - # Must be defined at module scope because datasets.map with num_proc pickles this function. - # Build per-row hash lists so we can later do fast overlap checks without re-reading the dataset. - active_columns = [column for column in columns if column not in exclude_columns] - batch_size = len(batch[active_columns[0]]) if active_columns else len(next(iter(batch.values()), [])) - if not active_columns: - return {"__hashes": [[] for _ in range(batch_size)]} - hashes: list[list[int]] = [] - for row_idx in range(batch_size): - row_hashes: list[int] = [] - for column in active_columns: - value = batch[column][row_idx] - if isinstance(value, list): - row_hashes.extend(_xxhash_int64(str(item)) for item in value) - else: - row_hashes.append(_xxhash_int64(str(value))) - hashes.append(row_hashes) - return {"__hashes": hashes} - - -def _remove_label_columns(dataset: Dataset, valid_label_columns: list[str] | None) -> Dataset: - # Drop label columns so they don't participate in duplicate checks. - if label_columns := set(dataset.column_names) & set(valid_label_columns or []): - return dataset.remove_columns(list(label_columns)) - return dataset - - -def _has_overlap(sample_values, batch_values: set[Any]) -> bool: - # Avoid materializing a set if we already have one. - if isinstance(sample_values, set): - return not sample_values.isdisjoint(batch_values) - return any(value in batch_values for value in sample_values) - - -def _iter_no_duplicate_batches( - remaining_indices: dict[int, None], - get_sample_values, - batch_size: int, - drop_last: bool, -) -> Iterator[list[int]]: - # Shared batch construction loop for both samplers; keeps behavior consistent. - while remaining_indices: - batch_values: set[Any] = set() - batch_indices: list[int] = [] - for index in remaining_indices: - sample_values = get_sample_values(index) - if _has_overlap(sample_values, batch_values): - continue - - batch_indices.append(index) - if len(batch_indices) == batch_size: - yield batch_indices - break - - batch_values.update(sample_values) - - else: - # NOTE: some indices might still have been ignored here - if not drop_last: - yield batch_indices - - for index in batch_indices: - del remaining_indices[index] - - class SetEpochMixin: """ Required for a BatchSampler as the Trainer will call set_epoch on the BatchSampler at the beginning of each epoch. @@ -240,6 +163,38 @@ def __iter__(self) -> Iterator[list[int]]: yield partial_batch +def _xxhash_int64(value: str) -> int: + # Convert uint64 -> int64 to keep values compatible with Arrow int64 storage. + hashed = xxhash.xxh64_intdigest(value) + if hashed >= _XXHASH_INT64_MAX: + hashed -= _XXHASH_UINT64_MAX + return hashed + + +def _hash_batch( + batch: dict[str, list[Any]], + columns: list[str], + exclude_columns: set[str], +) -> dict[str, list[list[int]]]: + # Must be defined at module scope because datasets.map with num_proc pickles this function. + # Build per-row hash lists so we can later do fast overlap checks without re-reading the dataset. 
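A small sanity-check sketch (hypothetical, standalone) of the uint64-to-int64 wrap-around that `_xxhash_int64` performs so digests fit into an Arrow int64 column; it matches a plain byte-level reinterpretation. Storing one signed 64-bit hash per value is also where the rough `len(dataset) * num_columns * 8` bytes estimate quoted in the docstring earlier comes from (for example, one million rows with three single-value columns is on the order of 24 MB).

```python
import struct

import xxhash

_INT64_MAX = 1 << 63
_UINT64_MAX = 1 << 64


def xxhash_int64(value: str) -> int:
    # Same wrap-around as _xxhash_int64: map the unsigned 64-bit digest onto
    # the signed int64 range so it can be stored in an Arrow int64 column.
    hashed = xxhash.xxh64_intdigest(value)
    if hashed >= _INT64_MAX:
        hashed -= _UINT64_MAX
    return hashed


digest = xxhash.xxh64_intdigest("some value")
# Reinterpreting the same 8 bytes as a signed integer gives the same result.
reinterpreted = struct.unpack("<q", struct.pack("<Q", digest))[0]
assert xxhash_int64("some value") == reinterpreted
assert -(1 << 63) <= xxhash_int64("some value") < (1 << 63)
```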
+ active_columns = [column for column in columns if column not in exclude_columns] + batch_size = len(batch[active_columns[0]]) if active_columns else len(next(iter(batch.values()), [])) + if not active_columns: + return {"__hashes": [[] for _ in range(batch_size)]} + hashes: list[list[int]] = [] + for row_idx in range(batch_size): + row_hashes: list[int] = [] + for column in active_columns: + value = batch[column][row_idx] + if isinstance(value, list): + row_hashes.extend(_xxhash_int64(str(item)) for item in value) + else: + row_hashes.append(_xxhash_int64(str(value))) + hashes.append(row_hashes) + return {"__hashes": hashes} + + class NoDuplicatesBatchSampler(DefaultBatchSampler): def __init__( self, @@ -297,7 +252,9 @@ def __init__( generator=generator, seed=seed, ) - self.dataset = _remove_label_columns(dataset, self.valid_label_columns) + if label_columns := set(dataset.column_names) & set(self.valid_label_columns or []): + dataset = dataset.remove_columns(list(label_columns)) + self.dataset = dataset self.precompute_hashes = precompute_hashes self.precompute_num_proc = precompute_num_proc self.precompute_batch_size = precompute_batch_size @@ -317,7 +274,7 @@ def __init__( def _build_hashes(self) -> None: if not self.precompute_hashes or self._row_hashes is not None: return - exclude_columns = set(self.valid_label_columns or []) | {"dataset_name"} + exclude_columns = {"dataset_name"} columns = list(self.dataset.column_names) # Precompute hash values once to avoid repeated string processing per batch. # Use num_proc to parallelize hashing across CPU cores. @@ -380,13 +337,6 @@ def __iter__(self) -> Iterator[list[int]]: self._build_hashes() row_hashes: np.ndarray = self._row_hashes - # We create a dictionary to None because we need a data structure that: - # 1. Allows for cheap removal of elements - # 2. Preserves the order of elements, i.e. remains random - remaining_indices = dict.fromkeys(torch.randperm(len(self.dataset), generator=self.generator).tolist()) - - if self.precompute_hashes: - def get_sample_values(index: int): return row_hashes[index] @@ -395,12 +345,38 @@ def get_sample_values(index: int): def get_sample_values(index: int) -> set[str]: return {str(value) for key, value in self.dataset[index].items() if key != "dataset_name"} - yield from _iter_no_duplicate_batches( - remaining_indices, - get_sample_values, - self.batch_size, - self.drop_last, - ) + def _has_overlap(sample_values, batch_values: set[Any]) -> bool: + # Avoid materializing a set if we already have one. + if isinstance(sample_values, set): + return not sample_values.isdisjoint(batch_values) + return any(value in batch_values for value in sample_values) + + # We create a dictionary mapping indices to None because we need a data structure that: + # 1. Allows for cheap removal of elements + # 2. Preserves the order of elements, i.e. 
remains random + remaining_indices = dict.fromkeys(torch.randperm(len(self.dataset), generator=self.generator).tolist()) + while remaining_indices: + batch_values: set[Any] = set() + batch_indices: list[int] = [] + for index in remaining_indices: + sample_values = get_sample_values(index) + if _has_overlap(sample_values, batch_values): + continue + + batch_indices.append(index) + if len(batch_indices) == self.batch_size: + yield batch_indices + break + + batch_values.update(sample_values) + + else: + # NOTE: some indices might still have been ignored here + if not self.drop_last: + yield batch_indices + + for index in batch_indices: + del remaining_indices[index] def __len__(self) -> int: if self.drop_last: From 13b949b9482dd35e260e66f119f519e48cd7206e Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Wed, 28 Jan 2026 13:26:34 +0100 Subject: [PATCH 12/12] Only compute cpu_count/default_workers if precompute_num_proc is None --- sentence_transformers/sampler.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sentence_transformers/sampler.py b/sentence_transformers/sampler.py index ca34a5485..81c58691c 100644 --- a/sentence_transformers/sampler.py +++ b/sentence_transformers/sampler.py @@ -265,10 +265,10 @@ def __init__( "NoDuplicatesBatchSampler with precompute_hashes=True requires `xxhash`. " "Install `xxhash` to use this option." ) - cpu_count = os.cpu_count() or 1 - # Leave one core free to avoid saturating the system when hashing. - default_workers = max(1, min(8, cpu_count - 1)) if self.precompute_num_proc is None: + cpu_count = os.cpu_count() or 1 + # Leave one core free to avoid saturating the system when hashing. + default_workers = max(1, min(8, cpu_count - 1)) self.precompute_num_proc = default_workers def _build_hashes(self) -> None:
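Finally, a tiny sketch (hypothetical helper, not library code) of the worker-count default that this last patch computes lazily, i.e. only when `precompute_num_proc` is left as `None`: leave one core free, cap at eight workers, and never drop below one.

```python
import os


def default_precompute_num_proc(cpu_count: int | None = None) -> int:
    # Mirrors the default above: leave one core free to avoid saturating the
    # machine, cap at 8 processes, and always keep at least one.
    cpu_count = cpu_count if cpu_count is not None else (os.cpu_count() or 1)
    return max(1, min(8, cpu_count - 1))


assert default_precompute_num_proc(1) == 1   # single-core machine: one worker
assert default_precompute_num_proc(4) == 3   # leave a core free
assert default_precompute_num_proc(64) == 8  # capped at 8 workers
```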