SMC Multiprocessing and Progress Bar Refactor #8047
jessegrabowski wants to merge 14 commits into pymc-devs:main
@tvwenger it would be nice if you could try your SMC models that have been giving you trouble on this PR branch and report back, since you've been the one doing the heavy lifting bug-hunting on SMC lately.
Pull request overview
This PR refactors SMC (Sequential Monte Carlo) sampling to harmonize its API with pm.sample, bringing multiprocessing capabilities, progress bars, and performance improvements to pm.sample_smc. The refactor moves PyTensor compilation to the main process before distributing work to child processes, addresses thread safety concerns, and adds comprehensive progress reporting similar to MCMC sampling.
Changes:
- Implemented multiprocessing support for SMC using a pattern similar to MCMC parallel sampling
- Added custom SMC progress bars that track beta (inverse temperature) progression from 0 to 1
- Moved kernel compilation to main process to avoid thread-safety issues with PyTensor compilation in worker processes
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 9 comments.
| File | Description |
|---|---|
| pymc/smc/parallel.py | New file implementing parallel SMC sampling infrastructure with process management, message passing, and result collection |
| pymc/smc/sampling.py | Major refactor of main SMC sampling function to support both parallel and sequential execution with shared kernel compilation |
| pymc/smc/kernels.py | Moved kernel compilation to init and added progress bar configuration methods |
| pymc/progress_bar.py | Added SMCProgressBarManager class for beta-based progress tracking and modified table styling |
| tests/smc/test_smc.py | Added test for sequential sampling with cores=1 |
Codecov Report
❌ Patch coverage is
Additional details and impacted files
@@ Coverage Diff @@
## main #8047 +/- ##
==========================================
+ Coverage 90.89% 91.49% +0.59%
==========================================
Files 123 123
Lines 19501 19821 +320
==========================================
+ Hits 17726 18135 +409
+ Misses 1775 1686 -89
bf91046 to 7567f0e
Check #8044
7567f0e to 8f6f2ad
I addressed this in the "Extract and share Parallel setup code between MCMC and SMC" commit, but now the two PRs are overlapping.
15c4e5f to ae2d582
So there's an edge case with the pickle function -> send to process approach. If the pickled functions have random number generators, these need to be changed so as to have independent streams. Usually this isn't a problem in MCMC because we never wanted to use functions with randomness in our step samplers, but this is not the case for SMC, and especially SMC-ABC with Simulator, which is definitely supposed to be random. When you call The approach may require something like the
Could we just make the rng an explicit input to the function we pickle up and send out, to avoid the copy?
Not without many more changes in the codebase
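The concern about pickled functions sharing one random stream can be illustrated with plain NumPy generators. This is only a sketch of the general problem, not of how PyTensor functions store or swap their RNGs; the two helper names are hypothetical, and NumPy's `SeedSequence.spawn` stands in for whatever mechanism the PR ends up using.

```python
import pickle

import numpy as np


def draws_from_pickled_copies(seed, n_workers, size=3):
    """Simulate sending the same pickled generator to every worker."""
    rng = np.random.default_rng(seed)
    payload = pickle.dumps(rng)
    # Every "worker" unpickles an identical copy of the generator state.
    return [pickle.loads(payload).normal(size=size) for _ in range(n_workers)]


def draws_from_spawned_streams(seed, n_workers, size=3):
    """Give each worker its own child seed derived from one parent seed."""
    children = np.random.SeedSequence(seed).spawn(n_workers)
    return [np.random.default_rng(child).normal(size=size) for child in children]


copied = draws_from_pickled_copies(seed=123, n_workers=2)
spawned = draws_from_spawned_streams(seed=123, n_workers=2)

# Identical pickled copies produce identical draws; spawned streams do not.
print(np.array_equal(copied[0], copied[1]))
print(np.array_equal(spawned[0], spawned[1]))
```

With pickled copies, every worker would simulate the same data, which is why a function that is supposed to be random needs its generator replaced per process.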
Extract mp_ctx initialization function
Extract blas_core setup function
Don't use threadpool limits when mp_ctx is ForkContext
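As a rough illustration of what these commit titles describe, the following standard-library sketch resolves a multiprocessing context and checks whether it is a fork context. The helper names `resolve_mp_ctx` and `uses_fork` are hypothetical and are not taken from the PR; how the PR actually decides when to apply thread-pool limits may differ.

```python
import multiprocessing
import multiprocessing.context as mp_context


def resolve_mp_ctx(mp_ctx=None):
    """Return a multiprocessing context from a name, a context object, or None."""
    if mp_ctx is None or isinstance(mp_ctx, str):
        # None selects the platform default start method.
        return multiprocessing.get_context(mp_ctx)
    return mp_ctx


def uses_fork(ctx):
    """Report whether ctx is a fork context (the class is absent on Windows)."""
    fork_cls = getattr(mp_context, "ForkContext", None)
    return fork_cls is not None and isinstance(ctx, fork_cls)


ctx = resolve_mp_ctx("spawn")
# Hypothetical decision point: apply thread-pool limits only when not forking.
apply_threadpool_limits = not uses_fork(ctx)
```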
56cc72d to cd2a687
There were some major conflicts with the recent round of progress bar PRs, so I went ahead and did a bit of refactoring:
I renamed everything in the rich/marimo progress bar backends to have more agnostic argument names (chains -> n_bars, total_draws -> total, draw -> completed, etc.). The marimo backend won't work with SMC because it has a ton of hard-coded logic for MCMC and I don't care enough to handle it. If a future dev cares about both SMC and Marimo, they can handle it then.
@jessegrabowski is this ready for final review?
yep
            "stats": {},
        }
    ]
    self._start_times = [perf_counter()]
Perhaps initialize only when the task / bar advances for the first time, so that sequential sampling does not measure speed relative to the start of the first chain (otherwise it seems each chain is slower than the previous one...). Does not need to be done in this PR.
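The suggestion above could look something like the following sketch, which records a bar's start time only on its first advance. `LazyBarTimer` and its methods are hypothetical names for illustration and do not correspond to the progress bar classes in the PR.

```python
from time import perf_counter


class LazyBarTimer:
    """Track one start time per bar, set on the first advance."""

    def __init__(self, n_bars):
        # None marks a bar that has not started yet.
        self._start_times = [None] * n_bars

    def advance(self, bar_id):
        # Only the first advance of a bar starts its clock.
        if self._start_times[bar_id] is None:
            self._start_times[bar_id] = perf_counter()

    def elapsed(self, bar_id):
        start = self._start_times[bar_id]
        return 0.0 if start is None else perf_counter() - start


timer = LazyBarTimer(n_bars=2)
timer.advance(0)  # chain 0 starts sampling; chain 1 is still waiting
```

With sequential sampling, a chain that has not started reports zero elapsed time instead of inheriting the time spent on earlier chains.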
    shared_rngs = [
        var for var in fn.get_shared() if isinstance(var.type, RandomGeneratorType)
    ]
    n_shared_rngs = len(shared_rngs)
Raise NotImplementedError if there are shared_rngs and isinstance(fn.maker.linker, JAXLinker), as the RNGs have a different format after compilation and are no longer retrieved by fn.get_shared() either.
    )
    self.likelihood_logp_func = self.likelihood_logp_func.copy(
        swap=make_rng_swaps(self.likelihood_logp_func, rng)
    )
For safety, even if not used, do self.rng = rng
        return custom_methods
    def _sample_smc(
Inline this in _sample_smc_many?
            pm.sample_smc(draws=6, cores=2)
    @pytest.mark.parametrize("chains", [1, 2], ids=["1_chain", "2_chains"])
    def test_sequential(self, chains, caplog):
Parametrize the first sample_test to cover both sequential and parallel sampling.
    chains=chains,
    cores=1,
    return_inferencedata=False,
    progressbar=not _IS_WINDOWS,
We shouldn't need this not _IS_WINDOWS logic anywhere anymore with the new progressbar (it existed in previous tests)
ricardoV94 left a comment
My last requests (97% confidence)
Description
I've always liked SMC as a gradient-free option for my big silly models with few parameters, but it always gave me trouble because of the API break between it and pm.sample. This PR aims to harmonize the two by bringing over a bunch of functionality from pm.sample to pm.sample_smc.
This PR is intended to be reviewed commit by commit. I verified that the test suite runs in all intermediate forms. Here is a summary of each commit:
- Use multiprocessing for SMC sampling: The multiprocessing library is now used to handle parallel SMC sampling. This commit was heavily Claude-assisted, so it should receive special scrutiny. The objective was to make SMC multiprocessing look exactly like MCMC multiprocessing. It also exposes an mp_ctx argument to pm.sample_smc, which can allow compiling with e.g. JAX (using mp_ctx='forkserver').
- Sample SMC sequentially when cores=1: Adds separate logic for sequential sampling on one core. Again, this copies the relevant MCMC functions.
- Initialize SMC Kernels on main process: A major performance change, intended to address e.g. "BUG: sample_smc stalls on final stage" #8030. PyTensor compilation is not thread-safe, so we shouldn't be doing it on the workers. In this PR, the kernel is compiled once on the main process, then serialized and sent to the workers. This matches what we do with step functions in MCMC. Importantly, this commit eliminates the need for serialization of many auxiliary objects, including the pymc model itself, and some special logic for custom distributions. To do this, a couple of ancillary changes had to be made. For example, transformation of the chains from numpy to NDArray objects happens on the main process now, after all sampling is done.
- Add blas_cores argument to sample_smc: Again, this copies over multiprocessing machinery from pm.sample to pm.sample_smc by adding a blas_cores argument to pm.sample_smc, for the same reasons it exists over there.
- Add custom progress bar for SMC: Adds a progress bar style to sample_smc that matches that of pm.sample. The bars fill from 0-1 following the value of beta, and we provide an estimated time to completion by measuring the speed per step. It looks like this:
I observed big speed gains using sample_smc after this PR. I timed this simple hierarchical model:
Timings went from 6.1 s to 1.44 s using the C backend, and 1.46 s to 1.09 s using Numba mode (with cache). Running test_smc.py locally goes from 1m4s to 6.264 seconds.
I could run more formal benchmarks if someone asks, but I don't really want to.
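To make the beta-based progress bar from the commit summary more concrete, here is a minimal sketch of a bar that fills as beta moves from 0 to 1, together with a naive time estimate. Both function names are hypothetical, and the estimate assumes beta advances at a roughly constant rate in wall-clock time, which is a simplification and not necessarily how the PR measures speed per step.

```python
def render_beta_bar(beta, width=20):
    """Render a text bar whose filled fraction equals beta (0 to 1)."""
    if not 0.0 <= beta <= 1.0:
        raise ValueError("beta must lie in [0, 1]")
    filled = int(round(beta * width))
    return "#" * filled + "-" * (width - filled)


def estimate_remaining_seconds(beta, elapsed):
    """Extrapolate remaining time, assuming beta grows linearly with time."""
    if not 0.0 <= beta <= 1.0:
        raise ValueError("beta must lie in [0, 1]")
    if beta == 0.0:
        return None  # no progress yet, so no estimate is possible
    return elapsed * (1.0 - beta) / beta


# Example: a quarter of the way through tempering after 10 seconds.
bar = render_beta_bar(0.25, width=20)
remaining = estimate_remaining_seconds(0.25, elapsed=10.0)
```

Because SMC chooses its beta increments adaptively, a real estimate would likely need to be updated stage by stage rather than extrapolated once.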
Related Issue
- sample_smc stalls on final stage #8030
- sample_smc can lead to compilation halting #8022
Checklist
Type of change