[perf] Speed up NoDuplicatesBatchSampler iteration (NO_DUPLICATES and NO_DUPLICATES_HASHED)
#3658
Hello!
### Pull request overview

This PR speeds up iteration of `NoDuplicatesBatchSampler`: roughly a 2% speedup for `NO_DUPLICATES` and around a 20% speedup for `NO_DUPLICATES_HASHED`.

### Details
The old implementation kept remaining indices in a Python `dict` and deleted accepted indices batch by batch. That works, but it carries a lot of Python object overhead on large datasets.
This PR switches that part to:
- the shuffled order from `randperm(..., dtype=int32|int64)`, kept as a NumPy-backed index array
- an auxiliary position array (`next_positions`), used as a linked list for O(1) removals

My intent here is to improve performance without changing semantics.
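Roughly, the mechanism looks like the sketch below. This is a minimal illustration of the approach only (placeholder names like `iter_batches` and `conflicts`, simplified `drop_last=True` behavior, `int64` throughout), not the code added by this PR:

```python
import numpy as np
import torch


def iter_batches(num_samples, batch_size, conflicts, seed=12):
    """Yield full batches of dataset indices, skipping any index for which
    `conflicts(idx, batch)` is True (a stand-in for the duplicate-text check)."""
    generator = torch.Generator().manual_seed(seed)
    # Shuffled dataset indices as a NumPy-backed array (the PR selects int32/int64
    # based on dataset size; int64 keeps this sketch simple).
    order = torch.randperm(num_samples, generator=generator, dtype=torch.int64).numpy()

    # next_positions threads a singly linked list over positions 1..num_samples,
    # with position 0 as a sentinel head and num_samples + 1 as the end marker.
    next_positions = np.arange(1, num_samples + 2, dtype=np.int64)

    remaining = num_samples
    while remaining >= batch_size:
        batch = []
        prev = 0  # sentinel: the last position we decided to keep linked
        pos = next_positions[0]
        while pos <= num_samples and len(batch) < batch_size:
            nxt = next_positions[pos]
            sample_idx = int(order[pos - 1])
            if conflicts(sample_idx, batch):
                prev = pos  # skipped: stays linked and is revisited for a later batch
            else:
                batch.append(sample_idx)
                next_positions[prev] = nxt  # O(1) unlink, no Python dict deletion
                remaining -= 1
            pos = nxt
        if len(batch) < batch_size:  # remaining samples all conflict; stop (drop_last-style)
            return
        yield batch
```

The key point is that skipped samples stay in the linked structure and are revisited for later batches, while accepted samples are removed with a single array write instead of a `dict` deletion.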
`drop_last` behavior is unchanged.

About dtypes:
- Most datasets fit in `int32` row indices, so index/position NumPy arrays use `int32` in that case to reduce memory usage.
- `int64` is used only when the row count exceeds the `int32` range.
- Precomputed hashes stay `int64`.
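As a back-of-the-envelope illustration of the dtype choice (made-up row count, not a number from this PR):

```python
import numpy as np

num_rows = 200_000_000  # hypothetical dataset size
index_dtype = np.int32 if num_rows <= np.iinfo(np.int32).max else np.int64
bytes_per_array = num_rows * np.dtype(index_dtype).itemsize
# int32 -> ~0.8 GB per index/position array; int64 would need ~1.6 GB
print(np.dtype(index_dtype).name, f"{bytes_per_array / 1e9:.1f} GB per array")
```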
### Benchmark

Setup:
- Script: `examples/sentence_transformer/evaluation/evaluation_no_dup_batch_sampler_speed.py`
- Data: `sentence-transformers/msmarco-co-condenser-margin-mse-sym-mnrl-mean-v1`, `triplet-hard/train`
- Flags: `--batch-size 8192 --no-progress-bar --precompute-num-proc 8`
- Compared: `d0df0ccc` (baseline) vs this branch (`no_dup_hashed_improve`)

`NO_DUPLICATES` is reported with 3 runs because each run is long; `NO_DUPLICATES_HASHED` also has an extra 10-run check for stability.

`precompute_hashes=False` (`NO_DUPLICATES`), 3-run mean

Runs:
`precompute_hashes=True` (`NO_DUPLICATES_HASHED`), 3-run mean

Runs:
Extra check (`NO_DUPLICATES_HASHED`, 10 runs): `22.768s -> 17.842s` (-21.64%)

In short:
- `NO_DUPLICATES`: -2.25% (3-run mean)
- `NO_DUPLICATES_HASHED`: -22.13% (3-run mean), -21.64% (10-run check)

Memory note:
For `NO_DUPLICATES_HASHED`, memory stays roughly flat in these runs, with a slight average decrease in both `hash_uss current_delta` and `hash_uss peak_delta`.

### Testing
- `uv run pytest tests/samplers/test_no_duplicates_batch_sampler.py -q`
- `uv run pytest tests/test_trainer.py -k "get_batch_sampler" -q`

Added tests:
- sampler behavior across `drop_last` and `precompute_hashes` combinations
- `precompute_hashes=True` stores precomputed hashes as `np.int64`
- `int64` is used when the `int32` range is exceeded
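As a rough illustration of the kind of property these tests exercise (reusing the `iter_batches` sketch from the Details section above, not the actual test code in `tests/samplers/test_no_duplicates_batch_sampler.py`):

```python
def test_no_duplicates_and_full_batches():
    texts = ["a", "b", "a", "c", "b", "a", "d", "e", "f", "c"]

    def conflicts(idx, batch):
        # Duplicate check on raw texts; the real sampler can also compare precomputed hashes.
        return texts[idx] in {texts[j] for j in batch}

    for batch in iter_batches(len(texts), batch_size=3, conflicts=conflicts):
        batch_texts = [texts[i] for i in batch]
        assert len(batch_texts) == len(set(batch_texts))  # no duplicate texts within a batch
        assert len(batch) == 3  # drop_last-style: only full batches are yielded
```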
### Related PR

This implementation was developed in collaboration with Codex, and all code has been reviewed by @hotchpotch.

If you have any questions or see anything that should be improved, I would really appreciate your feedback.