|
"""
    FisherMinBatchMatch(n_samples, subsampling)
    FisherMinBatchMatch(; n_samples, subsampling)

Covariance-weighted Fisher divergence minimization via the batch-and-match algorithm, which is a proximal point-type optimization scheme.

# (Keyword) Arguments
- `n_samples::Int`: Number of samples (batchsize) used to compute the moments required for the batch-and-match update. (default: `32`)
- `subsampling::Union{Nothing,<:AbstractSubsampling}`: Optional subsampling strategy. (default: `nothing`)

!!! warning
    `FisherMinBatchMatch` with subsampling enabled results in a biased algorithm and may not properly optimize the covariance-weighted Fisher divergence.

!!! note
    `FisherMinBatchMatch` requires a sufficiently large `n_samples` to converge quickly.

!!! note
    The `subsampling` strategy is only applied to the target `LogDensityProblem` but not to the variational approximation `q`. That is, `FisherMinBatchMatch` does not support amortization or structured variational families.

# Output
- `q`: The last iterate of the algorithm.

# Callback Signature
The `callback` function supplied to `optimize` needs to have the following signature:

    callback(; rng, iteration, q, info)

The keyword arguments are as follows:
- `rng`: Random number generator internally used by the algorithm.
- `iteration`: The index of the current iteration.
- `q`: Current variational approximation.
- `info`: `NamedTuple` containing the information generated during the current iteration.

# Requirements
- The variational family is [`FullRankGaussian`](@ref FullRankGaussian).
- The target distribution has unconstrained support.
- The target `LogDensityProblems.logdensity(prob, x)` has at least first-order differentiation capability.
"""
@kwdef struct FisherMinBatchMatch{Sub<:Union{Nothing,<:AbstractSubsampling}} <:
              AbstractVariationalAlgorithm
    n_samples::Int = 32
    subsampling::Sub = nothing
end
| 45 | + |
"""
    BatchMatchState

Internal state of the [`FisherMinBatchMatch`](@ref) algorithm, threaded through `step`.
"""
struct BatchMatchState{Q,P,Sigma,Sub,UBuf,GradBuf}
    q::Q                 # Current variational approximation (full-rank Gaussian iterate).
    prob::P              # Target `LogDensityProblem`.
    sigma::Sigma         # Dense covariance of `q` (`cov(q)` at init, `Σ′` after each step).
    iteration::Int       # Iteration counter; initialized to 0 and incremented in `step`.
    sub_st::Sub          # Subsampling state, or `nothing` when subsampling is disabled.
    u_buf::UBuf          # Preallocated d×n_samples buffer for standard normal draws.
    grad_buf::GradBuf    # Preallocated d×n_samples buffer for target gradients ∇logπ(z).
end
| 55 | + |
"""
    init(rng, alg::FisherMinBatchMatch, q, prob)

Initialize the state of `FisherMinBatchMatch` for the full-rank Gaussian
approximation `q` and the target `prob`. Throws an `ArgumentError` when `prob`
does not provide at least first-order differentiation capability.
"""
function init(
    rng::Random.AbstractRNG,
    alg::FisherMinBatchMatch,
    q::MvLocationScale{<:LowerTriangular,<:Normal,L},
    prob,
) where {L}
    (; n_samples, subsampling) = alg

    # The batch-and-match update needs gradients of the target log-density.
    capability = LogDensityProblems.capabilities(typeof(prob))
    if capability < LogDensityProblems.LogDensityOrder{1}()
        msg = "`FisherMinBatchMatch` requires at least first-order differentiation capability. The capability of the supplied `LogDensityProblem` is $(capability)."
        throw(ArgumentError(msg))
    end

    # Allocate the sample/gradient buffers once; `step` reuses them every iteration.
    dims = LogDensityProblems.dimension(prob)
    flat_params, _ = Optimisers.destructure(q)
    T = eltype(flat_params)
    unit_samples = Matrix{T}(undef, dims, n_samples)
    gradients = Matrix{T}(undef, dims, n_samples)

    sub_state = isnothing(subsampling) ? nothing : init(rng, subsampling)
    return BatchMatchState(q, prob, cov(q), 0, sub_state, unit_samples, gradients)
end
| 78 | + |
| 79 | +output(::FisherMinBatchMatch, state) = state.q |
| 80 | + |
"""
    rand_batch_match_samples_with_objective!(rng, q, n_samples, prob, u_buf, grad_buf)

Draw `n_samples` samples from `q` via the reparameterization `z = C u + μ`,
evaluate the target log-density and its gradient at each sample, and estimate
the covariance-weighted Fisher divergence. Overwrites `u_buf` (standard normal
draws) and `grad_buf` (columns of `∇logπ(z)`) in place.

Returns `(u_buf, z, grad_buf, fisher, logπ_avg)`.
"""
function rand_batch_match_samples_with_objective!(
    rng::Random.AbstractRNG,
    q::MvLocationScale,
    n_samples::Int,
    prob,
    u_buf=Matrix{eltype(q)}(undef, LogDensityProblems.dimension(prob), n_samples),
    grad_buf=Matrix{eltype(q)}(undef, LogDensityProblems.dimension(prob), n_samples),
)
    location = q.location
    scale = q.scale

    # Reparameterized draws: z = C u + μ with u ~ N(0, I), filled into u_buf.
    Random.randn!(rng, u_buf)
    z = scale*u_buf .+ location

    logdensity_sum = zero(eltype(location))
    for col in 1:n_samples
        logπ_col, grad_col = LogDensityProblems.logdensity_and_gradient(
            prob, view(z, :, col)
        )
        grad_buf[:, col] = grad_col
        logdensity_sum += logπ_col
    end
    logπ_avg = logdensity_sum/n_samples

    # Estimate of the covariance-weighted Fisher divergence:
    #
    #   F = E[| ∇log(q/π)(z) |_{CC'}^2]                            (definition)
    #     = E[| C' (∇logq(z) - ∇logπ(z)) |^2]                      (Σ = CC')
    #     = E[| C' ( -(CC')\((Cu + μ) - μ) - ∇logπ(z) ) |^2]       (z = Cu + μ)
    #     = E[| C' ( -(CC')\(Cu) - ∇logπ(z) ) |^2]
    #     = E[| -u - C'∇logπ(z) |^2]
    fisher = sum(abs2, -u_buf - (scale'*grad_buf))/n_samples

    return u_buf, z, grad_buf, fisher, logπ_avg
end
| 112 | + |
"""
    step(rng, alg::FisherMinBatchMatch, state, callback, objargs...; kwargs...)

Perform one batch-and-match (BaM) iteration: draw a batch of samples from the
current iterate `q`, match the batch moments, and return the updated state, a
termination flag (always `false`), and an info `NamedTuple` containing the
iteration index, the covariance-weighted Fisher estimate, and an ELBO estimate.
"""
function step(
    rng::Random.AbstractRNG,
    alg::FisherMinBatchMatch,
    state,
    callback,
    objargs...;
    kwargs...,
)
    (; n_samples, subsampling) = alg
    (; q, prob, sigma, iteration, sub_st, u_buf, grad_buf) = state

    d = LogDensityProblems.dimension(prob)
    μ = q.location
    Σ = sigma
    iteration += 1

    # Maybe apply subsampling to the target.
    prob_sub, sub_st′, sub_inf = if isnothing(subsampling)
        prob, sub_st, NamedTuple()
    else
        batch, sub_st′, sub_inf = step(rng, subsampling, sub_st)
        prob_sub = subsample(prob, batch)
        prob_sub, sub_st′, sub_inf
    end
    # NOTE(review): `sub_inf` is computed but never merged into `info` below —
    # confirm whether subsampling diagnostics should be surfaced to the caller.

    u_buf, z, grad_buf, fisher, logπ_avg = rand_batch_match_samples_with_objective!(
        rng, q, n_samples, prob_sub, u_buf, grad_buf
    )

    # Batch moments. (Previously the sample covariance shadowed `C = q.scale`;
    # the scale binding was dead, so the sample covariance now has its own name.)
    zbar, Czz = mean_and_cov(z, 2)        # d×1 mean and d×d covariance of z
    gbar, Γ = mean_and_cov(grad_buf, 2)   # d×1 mean and d×d covariance of ∇logπ(z)

    μmz = μ - zbar
    # Proximal (regularization) weight; grows the trust in the batch as 1/t decays.
    λ = convert(eltype(μ), d*n_samples / iteration)

    U = Symmetric(λ*Γ + (λ/(1 + λ)*gbar)*gbar')
    V = Symmetric(Σ + λ*Czz + (λ/(1 + λ)*μmz)*μmz')

    # Closed-form solution of the BaM quadratic matrix equation Σ′ U Σ′ + Σ′ = V.
    Σ′ = Hermitian(2*V/(I + real(sqrt(I + 4*U*V))))
    μ′ = 1/(1 + λ)*μ + λ/(1 + λ)*(Σ′*gbar + zbar)   # d×1 matrix; column extracted below
    q′ = MvLocationScale(μ′[:, 1], cholesky(Σ′).L, q.dist)

    # ELBO estimated at the pre-update iterate `q` (the samples were drawn from it).
    elbo = logπ_avg + entropy(q)
    info = (iteration=iteration, covweighted_fisher=fisher, elbo=elbo)

    state = BatchMatchState(q′, prob, Σ′, iteration, sub_st′, u_buf, grad_buf)

    if !isnothing(callback)
        # Bug fix: the documented callback signature is
        # `callback(; rng, iteration, q, info)`, but the keyword `state` was
        # passed instead of `info`, breaking callbacks written to the documented
        # interface. Also pass the updated iterate `q′` as the current `q`,
        # consistent with the state returned from this step.
        info′ = callback(; rng, iteration, q=q′, info)
        info = !isnothing(info′) ? merge(info′, info) : info
    end
    return state, false, info
end
| 168 | + |
"""
    estimate_objective([rng,] alg, q, prob; n_samples)

Estimate the covariance-weighted Fisher divergence of the variational approximation `q` against the target log-density `prob`.

# Arguments
- `rng::Random.AbstractRNG`: Random number generator.
- `alg::FisherMinBatchMatch`: Variational inference algorithm.
- `q::MvLocationScale{<:Any,<:Normal,<:Any}`: Gaussian variational approximation.
- `prob`: The target log-joint likelihood implementing the `LogDensityProblem` interface.

# Keyword Arguments
- `n_samples::Int`: Number of Monte Carlo samples for estimating the objective. (default: Same as the number of samples used for estimating the gradient during optimization.)

# Returns
- `obj_est`: Estimate of the objective value.
"""
function estimate_objective(
    rng::Random.AbstractRNG,
    alg::FisherMinBatchMatch,
    q::MvLocationScale{S,<:Normal,L},
    prob;
    n_samples::Int=alg.n_samples,
) where {S,L}
    # Only the Fisher estimate is needed; discard samples, gradients, and logπ.
    _, _, _, fisher, _ = rand_batch_match_samples_with_objective!(rng, q, n_samples, prob)
    return fisher
end
0 commit comments