
Commit 9e6948e

Authored by Red-Portal, github-actions[bot], and penelopeysm
Batch-and-Match algorithm for minimizing the covariance-weighted Fisher divergence (#218)
* add batch and match
* update HISTORY
* run formatter
  Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
* fix add missing test file
* add documentation and update docstring for batch-and-match
* run formatter
  Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
* run formatter
  Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
* run formatter
  Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
* run formatter
  Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
* run formatter
  Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
* fix docs
  Co-authored-by: Penelope Yong <penelopeysm@gmail.com>
* fix docs
  Co-authored-by: Penelope Yong <penelopeysm@gmail.com>
* fix docs
  Co-authored-by: Penelope Yong <penelopeysm@gmail.com>
* fix remove dead code
* fix compute average outside of loop for batch-and-match
* fix remove reference in docstring
* fix capitalization in docstring
* refactor move duplicate code in batch match to a common function

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Penelope Yong <penelopeysm@gmail.com>
1 parent f407d88 commit 9e6948e

File tree: 8 files changed, +405 -1 lines changed

HISTORY.md

Lines changed: 1 addition & 0 deletions
@@ -8,6 +8,7 @@ Specifically, the following measure-space optimization algorithms have been added
 - `KLMinWassFwdBwd`
 - `KLMinNaturalGradDescent`
 - `KLMinSqrtNaturalGradDescent`
+- `FisherMinBatchMatch`

 ## Interface Change

docs/make.jl

Lines changed: 1 addition & 0 deletions
@@ -29,6 +29,7 @@ makedocs(;
         "`KLMinWassFwdBwd`" => "klminwassfwdbwd.md",
         "`KLMinNaturalGradDescent`" => "klminnaturalgraddescent.md",
         "`KLMinSqrtNaturalGradDescent`" => "klminsqrtnaturalgraddescent.md",
+        "`FisherMinBatchMatch`" => "fisherminbatchmatch.md",
     ],
     "Variational Families" => "families.md",
     "Optimization" => "optimization.md",

docs/src/fisherminbatchmatch.md

Lines changed: 61 additions & 0 deletions
@@ -0,0 +1,61 @@
# [`FisherMinBatchMatch`](@id fisherminbatchmatch)

## Description

This algorithm, known as batch-and-match (BaM), aims to minimize the covariance-weighted 2nd-order Fisher divergence by running a proximal point-type method[^CMPMGBS24].
On certain low-dimensional problems, BaM can converge very quickly without any tuning.
Since `FisherMinBatchMatch` is a measure-space algorithm, its use is restricted to full-rank Gaussian variational families (`FullRankGaussian`) that make the measure-valued operations tractable.

```@docs
FisherMinBatchMatch
```
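
To make the setup concrete, here is a minimal usage sketch. The toy target `StdGaussTarget`, the initial `FullRankGaussian`, and the `optimize` call are illustrative assumptions based on the general `AdvancedVI` interface rather than part of this change; only `FisherMinBatchMatch` itself is introduced here.

```julia
using AdvancedVI, LinearAlgebra, LogDensityProblems, Random

# Hypothetical toy target: a standard multivariate Gaussian implementing the
# LogDensityProblems interface with first-order (gradient) capability.
struct StdGaussTarget
    dims::Int
end
LogDensityProblems.logdensity(::StdGaussTarget, x) = -sum(abs2, x)/2
LogDensityProblems.dimension(p::StdGaussTarget) = p.dims
function LogDensityProblems.capabilities(::Type{StdGaussTarget})
    return LogDensityProblems.LogDensityOrder{1}()
end
function LogDensityProblems.logdensity_and_gradient(::StdGaussTarget, x)
    return -sum(abs2, x)/2, -x
end

d    = 5
prob = StdGaussTarget(d)
q0   = FullRankGaussian(zeros(d), LowerTriangular(Matrix{Float64}(I, d, d)))
alg  = FisherMinBatchMatch(; n_samples=64)

# The entry point and return values below are assumed from the general
# `AdvancedVI` interface; consult the `optimize` docstring for the exact signature.
q_opt, info, state = optimize(Random.default_rng(), alg, 1_000, prob, q0)
```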

The associated objective value can be estimated through the following:

```@docs; canonical=false
estimate_objective(
    ::Random.AbstractRNG,
    ::FisherMinBatchMatch,
    ::MvLocationScale,
    ::Any;
    ::Int,
)
```
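
Continuing the hypothetical setup from the sketch above, the estimator can be called directly:

```julia
rng = Random.default_rng()

# Monte Carlo estimate of the covariance-weighted Fisher divergence at the initial `q0`.
fisher_est = estimate_objective(rng, alg, q0, prob; n_samples=1024)
```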

[^CMPMGBS24]: Cai, D., Modi, C., Pillaud-Vivien, L., Margossian, C. C., Gower, R. M., Blei, D. M., & Saul, L. K. (2024). Batch and match: black-box variational inference with a score-based divergence. In *Proceedings of the International Conference on Machine Learning*.

## [Methodology](@id fisherminbatchmatch_method)

This algorithm aims to solve the problem

```math
\mathrm{minimize}_{q \in \mathcal{Q}}\quad \mathrm{F}_{\mathrm{cov}}(q, \pi),
```

where $\mathcal{Q}$ is some family of distributions, often called the variational family, and $\mathrm{F}_{\mathrm{cov}}$ is a divergence defined as

```math
\mathrm{F}_{\mathrm{cov}}(q, \pi) = \mathbb{E}_{z \sim q} {\left\lVert \nabla \log \frac{q}{\pi} (z) \right\rVert}_{\mathrm{Cov}(q)}^2 ,
```

where ${\lVert x \rVert}_{A}^2 = x^{\top} A x$ is a weighted norm.
$\mathrm{F}_{\mathrm{cov}}$ can be viewed as a variant of the canonical 2nd-order Fisher divergence defined as

```math
\mathrm{F}_{2}(q, \pi) = \sqrt{ \mathbb{E}_{z \sim q} {\left\lVert \nabla \log \frac{q}{\pi} (z) \right\rVert}^2 }.
```

The use of the weighted norm ${\lVert \cdot \rVert}_{\mathrm{Cov}(q)}^2$ facilitates the use of a proximal point-type method for minimizing $\mathrm{F}_{2}(q, \pi)$.
In particular, BaM iterates the update

```math
q_{t+1} = \argmin_{q \in \mathcal{Q}} \left\{ \mathrm{F}_{\mathrm{cov}}(q, \pi) + \frac{2}{\lambda_t} \mathrm{KL}\left(q_t, q\right) \right\} .
```

Since $\mathrm{F}_{\mathrm{cov}}(q, \pi)$ is intractable, it is replaced with a Monte Carlo approximation with a number of samples `n_samples`.
Furthermore, by restricting $\mathcal{Q}$ to a Gaussian variational family, the update rule admits a closed-form solution[^CMPMGBS24].
Notice that the update does not involve the parameterization of $q_t$, which makes `FisherMinBatchMatch` a measure-space algorithm.
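
For reference, denoting by $\bar{z}_t, \bar{g}_t$ the batch means of the samples and of the scores $\nabla \log \pi$, by $C_t, \Gamma_t$ their batch covariances, and by $\mu_t, \Sigma_t$ the mean and covariance of $q_t$, the closed-form update implemented in this commit can be summarized (up to notational differences with the reference) as

```math
\begin{aligned}
U_t &= \lambda_t \Gamma_t + \frac{\lambda_t}{1 + \lambda_t} \bar{g}_t \bar{g}_t^{\top}, \\
V_t &= \Sigma_t + \lambda_t C_t + \frac{\lambda_t}{1 + \lambda_t} (\mu_t - \bar{z}_t) {(\mu_t - \bar{z}_t)}^{\top}, \\
\Sigma_{t+1} &= 2 V_t {\left( I + {\left( I + 4 U_t V_t \right)}^{1/2} \right)}^{-1}, \\
\mu_{t+1} &= \frac{1}{1 + \lambda_t} \mu_t + \frac{\lambda_t}{1 + \lambda_t} \left( \Sigma_{t+1} \bar{g}_t + \bar{z}_t \right).
\end{aligned}
```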

Historically, the idea of using a proximal point-type update for minimizing a Fisher divergence-like objective was first introduced as Gaussian score matching[^MGMYBS23].
BaM can be viewed as a successor to this algorithm.

[^MGMYBS23]: Modi, C., Gower, R., Margossian, C., Yao, Y., Blei, D., & Saul, L. (2023). Variational inference with Gaussian score matching. In *Advances in Neural Information Processing Systems*, 36.

docs/src/index.md

Lines changed: 1 addition & 0 deletions
@@ -20,3 +20,4 @@ For using the algorithms implemented in `AdvancedVI`, refer to the corresponding
 - [KLMinNaturalGradDescent](@ref klminnaturalgraddescent)
 - [KLMinSqrtNaturalGradDescent](@ref klminsqrtnaturalgraddescent)
 - [KLMinWassFwdBwd](@ref klminwassfwdbwd)
+- [FisherMinBatchMatch](@ref fisherminbatchmatch)

src/AdvancedVI.jl

Lines changed: 3 additions & 1 deletion
@@ -358,7 +358,9 @@ include("algorithms/gauss_expected_grad_hess.jl")
 include("algorithms/klminwassfwdbwd.jl")
 include("algorithms/klminsqrtnaturalgraddescent.jl")
 include("algorithms/klminnaturalgraddescent.jl")
+include("algorithms/fisherminbatchmatch.jl")

-export KLMinWassFwdBwd, KLMinSqrtNaturalGradDescent, KLMinNaturalGradDescent
+export KLMinWassFwdBwd,
+    KLMinSqrtNaturalGradDescent, KLMinNaturalGradDescent, FisherMinBatchMatch

 end
src/algorithms/fisherminbatchmatch.jl

Lines changed: 195 additions & 0 deletions
@@ -0,0 +1,195 @@
"""
    FisherMinBatchMatch(n_samples, subsampling)
    FisherMinBatchMatch(; n_samples, subsampling)

Covariance-weighted Fisher divergence minimization via the batch-and-match algorithm, which is a proximal point-type optimization scheme.

# (Keyword) Arguments
- `n_samples::Int`: Number of samples (batch size) used to compute the moments required for the batch-and-match update. (default: `32`)
- `subsampling::Union{Nothing,<:AbstractSubsampling}`: Optional subsampling strategy. (default: `nothing`)

!!! warning
    `FisherMinBatchMatch` with subsampling enabled results in a biased algorithm and may not properly optimize the covariance-weighted Fisher divergence.

!!! note
    `FisherMinBatchMatch` requires a sufficiently large `n_samples` to converge quickly.

!!! note
    The `subsampling` strategy is only applied to the target `LogDensityProblem` but not to the variational approximation `q`. That is, `FisherMinBatchMatch` does not support amortization or structured variational families.

# Output
- `q`: The last iterate of the algorithm.

# Callback Signature
The `callback` function supplied to `optimize` needs to have the following signature:

    callback(; rng, iteration, q, info)

The keyword arguments are as follows:
- `rng`: Random number generator internally used by the algorithm.
- `iteration`: The index of the current iteration.
- `q`: Current variational approximation.
- `info`: `NamedTuple` containing the information generated during the current iteration.

# Requirements
- The variational family is [`FullRankGaussian`](@ref FullRankGaussian).
- The target distribution has unconstrained support.
- The target `LogDensityProblems.logdensity(prob, x)` has at least first-order differentiation capability.
"""
@kwdef struct FisherMinBatchMatch{Sub<:Union{Nothing,<:AbstractSubsampling}} <:
              AbstractVariationalAlgorithm
    n_samples::Int = 32
    subsampling::Sub = nothing
end

struct BatchMatchState{Q,P,Sigma,Sub,UBuf,GradBuf}
    q::Q
    prob::P
    sigma::Sigma
    iteration::Int
    sub_st::Sub
    u_buf::UBuf
    grad_buf::GradBuf
end

function init(
    rng::Random.AbstractRNG,
    alg::FisherMinBatchMatch,
    q::MvLocationScale{<:LowerTriangular,<:Normal,L},
    prob,
) where {L}
    (; n_samples, subsampling) = alg
    capability = LogDensityProblems.capabilities(typeof(prob))
    if capability < LogDensityProblems.LogDensityOrder{1}()
        throw(
            ArgumentError(
                "`FisherMinBatchMatch` requires at least first-order differentiation capability. The capability of the supplied `LogDensityProblem` is $(capability).",
            ),
        )
    end
    sub_st = isnothing(subsampling) ? nothing : init(rng, subsampling)
    params, _ = Optimisers.destructure(q)
    n_dims = LogDensityProblems.dimension(prob)
    u_buf = Matrix{eltype(params)}(undef, n_dims, n_samples)
    grad_buf = Matrix{eltype(params)}(undef, n_dims, n_samples)
    return BatchMatchState(q, prob, cov(q), 0, sub_st, u_buf, grad_buf)
end

output(::FisherMinBatchMatch, state) = state.q

function rand_batch_match_samples_with_objective!(
    rng::Random.AbstractRNG,
    q::MvLocationScale,
    n_samples::Int,
    prob,
    u_buf=Matrix{eltype(q)}(undef, LogDensityProblems.dimension(prob), n_samples),
    grad_buf=Matrix{eltype(q)}(undef, LogDensityProblems.dimension(prob), n_samples),
)
    μ = q.location
    C = q.scale
    u = Random.randn!(rng, u_buf)
    z = C*u .+ μ
    logπ_sum = zero(eltype(μ))
    for b in 1:n_samples
        logπb, gb = LogDensityProblems.logdensity_and_gradient(prob, view(z, :, b))
        grad_buf[:, b] = gb
        logπ_sum += logπb
    end
    logπ_avg = logπ_sum/n_samples

    # Estimate objective values
    #
    # F = E[| ∇log(q/π) (z) |_{CC'}^2]                      (definition)
    #   = E[| C' (∇logq(z) - ∇logπ(z)) |^2]                 (Σ = CC')
    #   = E[| C' ( -(CC')\((Cu + μ) - μ) - ∇logπ(z)) |^2]   (z = Cu + μ)
    #   = E[| C' ( -(CC')\(Cu) - ∇logπ(z)) |^2]
    #   = E[| -u - C'∇logπ(z) |^2]
    fisher = sum(abs2, -u_buf - (C'*grad_buf))/n_samples

    return u_buf, z, grad_buf, fisher, logπ_avg
end

function step(
    rng::Random.AbstractRNG,
    alg::FisherMinBatchMatch,
    state,
    callback,
    objargs...;
    kwargs...,
)
    (; n_samples, subsampling) = alg
    (; q, prob, sigma, iteration, sub_st, u_buf, grad_buf) = state

    d = LogDensityProblems.dimension(prob)
    μ = q.location
    C = q.scale
    Σ = sigma
    iteration += 1

    # Maybe apply subsampling
    prob_sub, sub_st′, sub_inf = if isnothing(subsampling)
        prob, sub_st, NamedTuple()
    else
        batch, sub_st′, sub_inf = step(rng, subsampling, sub_st)
        prob_sub = subsample(prob, batch)
        prob_sub, sub_st′, sub_inf
    end

    u_buf, z, grad_buf, fisher, logπ_avg = rand_batch_match_samples_with_objective!(
        rng, q, n_samples, prob_sub, u_buf, grad_buf
    )

    # BaM updates
    zbar, C = mean_and_cov(z, 2)
    gbar, Γ = mean_and_cov(grad_buf, 2)

    μmz = μ - zbar
    λ = convert(eltype(μ), d*n_samples / iteration)

    U = Symmetric(λ*Γ + (λ/(1 + λ)*gbar)*gbar')
    V = Symmetric(Σ + λ*C + (λ/(1 + λ)*μmz)*μmz')

    Σ′ = Hermitian(2*V/(I + real(sqrt(I + 4*U*V))))
    μ′ = 1/(1 + λ)*μ + λ/(1 + λ)*(Σ′*gbar + zbar)
    q′ = MvLocationScale(μ′[:, 1], cholesky(Σ′).L, q.dist)

    elbo = logπ_avg + entropy(q)
    info = (iteration=iteration, covweighted_fisher=fisher, elbo=elbo)

    state = BatchMatchState(q′, prob, Σ′, iteration, sub_st′, u_buf, grad_buf)

    if !isnothing(callback)
        info′ = callback(; rng, iteration, q, state)
        info = !isnothing(info′) ? merge(info′, info) : info
    end
    state, false, info
end

"""
    estimate_objective([rng,] alg, q, prob; n_samples)

Estimate the covariance-weighted Fisher divergence of the variational approximation `q` against the target log-density `prob`.

# Arguments
- `rng::Random.AbstractRNG`: Random number generator.
- `alg::FisherMinBatchMatch`: Variational inference algorithm.
- `q::MvLocationScale{<:Any,<:Normal,<:Any}`: Gaussian variational approximation.
- `prob`: The target log-joint likelihood implementing the `LogDensityProblem` interface.

# Keyword Arguments
- `n_samples::Int`: Number of Monte Carlo samples for estimating the objective. (default: Same as the number of samples used per batch-and-match update during optimization, `alg.n_samples`.)

# Returns
- `obj_est`: Estimate of the objective value.
"""
function estimate_objective(
    rng::Random.AbstractRNG,
    alg::FisherMinBatchMatch,
    q::MvLocationScale{S,<:Normal,L},
    prob;
    n_samples::Int=alg.n_samples,
) where {S,L}
    _, _, _, fisher, _ = rand_batch_match_samples_with_objective!(rng, q, n_samples, prob)
    return fisher
end
