
ComponentArrays break #1230

@penelopeysm


This is fine:

```julia
julia> using DynamicPPL, ComponentArrays, Distributions

julia> @model function f()
           x = ComponentArray(a = 1.0)
           x[1] ~ Normal()
       end
f (generic function with 2 methods)

julia> VarInfo(f())
VarInfo {linked=false}
 ├─ values
 │  VarNamedTuple
 │  └─ x => PartialArray size=(1,) data::ComponentVector{VectorValue{Vector{Float64}, DynamicPPL.UnwrapSingletonTransform{Tuple{Int64}}, Tuple{}}, Vector{VectorValue{Vector{Float64}, DynamicPPL.UnwrapSingletonTransform{Tuple{Int64}}, Tuple{}}}, Tuple{Axis{(a = 1,)}}}
 │          └─ (1,) => VectorValue{Vector{Float64}, DynamicPPL.UnwrapSingletonTransform{Tuple{Int64}}, Tuple{}}([1.7586049368044987], DynamicPPL.UnwrapSingletonTransform{Tuple{Int64}}((1,)), ())
 └─ accs
    AccumulatorTuple with 3 accumulators
    ├─ LogPrior => LogPriorAccumulator(-2.4652841950812503)
    ├─ LogJacobian => LogJacobianAccumulator(0.0)
    └─ LogLikelihood => LogLikelihoodAccumulator(0.0)
```

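This runs, but for the wrong reason: it ignores the fact that `x` is a ComponentArray and treats it as a struct/NamedTuple, so `x.a` is stored as a field of a nested `VarNamedTuple` rather than as an element of an array: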

```julia
julia> @model function g()
           x = ComponentArray(a = 1.0)
           x.a ~ Normal()
       end
g (generic function with 2 methods)

julia> VarInfo(g())
VarInfo {linked=false}
 ├─ values
 │  VarNamedTuple
 │  └─ x => VarNamedTuple
 │          └─ a => VectorValue{Vector{Float64}, DynamicPPL.UnwrapSingletonTransform{Tuple{Int64}}, Tuple{}}([0.831876572070276], DynamicPPL.UnwrapSingletonTransform{Tuple{Int64}}((1,)), ())
 └─ accs
    AccumulatorTuple with 3 accumulators
    ├─ LogPrior => LogPriorAccumulator(-1.2649478487843693)
    ├─ LogJacobian => LogJacobianAccumulator(0.0)
    └─ LogLikelihood => LogLikelihoodAccumulator(0.0)
```
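The distinction matters because, in a ComponentArray, `x.a` and `x[1]` are two names for the same storage slot. A quick check using only ComponentArrays (no DynamicPPL involved):

```julia
using ComponentArrays

x = ComponentArray(a = 1.0, b = 2.0)

# Property access and linear indexing read the same underlying elements.
@show x.a == x[1]   # true
@show x.b == x[2]   # true

# Writing through the index is visible through the property name.
x[1] = 5.0
@show x.a           # 5.0
```

In the `g()` output, by contrast, `x.a` ends up as a field of a nested `VarNamedTuple`, with no record that it corresponds to `x[1]`.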

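Mixing property access and indexing on the same ComponentArray fails. Here `x[1]` causes `x` to be stored as a `PartialArray`, and there is no `_setindex_optic!!` method for setting a `Property` optic (`x.b`) on a `PartialArray`: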

```julia
julia> @model function h()
           x = ComponentArray(a = 1.0, b = 2.0)
           x[1] ~ Normal()
           x.b ~ Normal()
       end
h (generic function with 2 methods)

julia> VarInfo(h())
ERROR: MethodError: no method matching _setindex_optic!!(::DynamicPPL.VarNamedTuples.PartialArray{…}, ::VectorValue{…}, ::AbstractPPL.Property{…}, ::ComponentVector{…}; allow_new::Val{…})
The function `_setindex_optic!!` exists, but no method is defined for this combination of argument types.

Closest candidates are:
  _setindex_optic!!(::Any, ::Any, ::AbstractPPL.Iden, ::Any; allow_new)
   @ DynamicPPL ~/ppl/dppl/src/varnamedtuple/getset.jl:97
  _setindex_optic!!(::VarNamedTuple{names}, ::Any, ::AbstractPPL.Property{S}, ::Any; allow_new) where {names, S}
   @ DynamicPPL ~/ppl/dppl/src/varnamedtuple/getset.jl:227
  _setindex_optic!!(::DynamicPPL.VarNamedTuples.PartialArray, ::Any, ::AbstractPPL.Index, ::Any; allow_new)
   @ DynamicPPL ~/ppl/dppl/src/varnamedtuple/getset.jl:139
```

The good news is that I think this is not difficult to fix: we just need to overload a few methods of `make_leaf` and `_setindex_optic!!` in a ComponentArrays package extension (`ComponentArraysExt`).
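As a rough illustration of the translation such an extension would need, the sketch below maps a property name on a ComponentVector to the linear index it occupies, so that `x.b` and `x[2]` could be routed to the same `PartialArray` slot. The helper `property_index` is hypothetical (it is not part of DynamicPPL or ComponentArrays), it assumes `ComponentArrays.labels` returns one string label per element, and it only handles scalar components. The real `make_leaf`/`_setindex_optic!!` overloads would sit on DynamicPPL internals that are not reproduced here.

```julia
using ComponentArrays

# Hypothetical helper (not part of DynamicPPL or ComponentArrays):
# return the linear index occupied by the scalar component `name` of `x`.
# Assumes `ComponentArrays.labels(x)` gives one String per element,
# e.g. ["a", "b"] for ComponentArray(a = 1.0, b = 2.0).
function property_index(x::ComponentVector, name::Symbol)
    idx = findfirst(==(String(name)), ComponentArrays.labels(x))
    idx === nothing && throw(ArgumentError("no scalar component named $name"))
    return idx
end

x = ComponentArray(a = 1.0, b = 2.0)
i = property_index(x, :b)
@show i
@show x[i] == x.b
```

Array-valued or nested components produce labels such as `"b[1]"` or `"c.d"`, so a real implementation would need to return a range of indices rather than a single one.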
