arviz-devs / ArviZ.jl
1
# Turing sampler-statistic names, translated to the names ArviZ uses in `sample_stats`.
const turing_key_map = Dict{String,String}(
    "acceptance_rate" => "mean_tree_accept",
    "hamiltonian_energy" => "energy",
    "hamiltonian_energy_error" => "energy_error",
    "is_adapt" => "tune",
    "max_hamiltonian_energy_error" => "max_energy_error",
    "n_steps" => "tree_size",
    "numerical_error" => "diverging",
    "tree_depth" => "depth",
)

# Stan sampler-statistic names (trailing `__`), translated to ArviZ `sample_stats` names.
const stan_key_map = Dict{String,String}(
    "accept_stat__" => "accept_stat",
    "divergent__" => "diverging",
    "energy__" => "energy",
    "lp__" => "lp",
    "n_leapfrog__" => "n_leapfrog",
    "stepsize__" => "stepsize",
    "treedepth__" => "treedepth",
)

# Combined lookup table used when rekeying the `:internals` section of a chain.
const stats_key_map = merge(turing_key_map, stan_key_map)
21

22
"""
    reshape_values(x::AbstractArray) -> AbstractArray

Permute the axes of an `MCMCChains` value array laid out as `(ndraw, size..., nchain)` into
the `(nchain, ndraw, size...)` layout that ArviZ expects.
"""
function reshape_values(x::AbstractArray{T,N}) where {T,N}
    # Chain axis (last) goes first, the draw axis (first) second, and the variable's own
    # axes (2 through N - 1) keep their relative order.
    perm = (N, 1, ntuple(d -> d + 1, N - 2)...)
    return permutedims(x, perm)
end
29

30 20
"""
    headtail(x)

Return the first element of `x` together with a copy of all remaining elements.
"""
function headtail(x)
    head = x[1]
    tail = x[2:end]
    return head, tail
end
31

32
"""
    split_locname(name) -> (var_name, loc)

Split an `MCMCChains` location name such as `"x[1,2]"` (or `"x[1][2]"`) into the variable
name `"x"` and its integer location tuple `(1, 2)`. A name without any location returns an
empty tuple.

If any location component is not an integer, the name cannot be decomposed. In that case the
whole, unmodified `name` is returned with an empty tuple. Distinct entries such as `"a.b"` and
`"a.c"` therefore stay separate variables instead of being merged under the prefix `"a"`.
"""
function split_locname(name)
    # Normalize every index delimiter to '.' so all locations split uniformly.
    normalized = replace(name, r"[\[,]" => '.')
    normalized = replace(normalized, ']' => "")
    parts = split(normalized, '.')
    var_name, loc_strs = parts[1], parts[2:end]
    isempty(loc_strs) && return var_name, ()
    loc = tryparse.(Int, loc_strs)
    # A non-integer component means this is not an indexed name: keep it whole so that
    # different variables sharing a prefix do not collide.
    any(isnothing, loc) && return name, ()
    return var_name, Tuple(loc)
end
41

42
"""
    varnames_locs_dict(loc_names, loc_str_to_old) -> Dict

Group location names by the variable name that `split_locname` extracts from them.

Each value is a tuple `(old_names, locs)`:
- `old_names` holds, for every location of the variable, the original name looked up in
  `loc_str_to_old`.
- `locs` holds the matching location tuples.

Both vectors are in the order the names appear in `loc_names`.
"""
function varnames_locs_dict(loc_names, loc_str_to_old)
    vars_to_locs = Dict()
    for loc_name in loc_names
        var_name, loc = split_locname(loc_name)
        old_name = loc_str_to_old[loc_name]
        if haskey(vars_to_locs, var_name)
            old_names, locs = vars_to_locs[var_name]
            push!(old_names, old_name)
            push!(locs, loc)
        else
            # Seed single-element vectors so their eltypes follow the first entry.
            vars_to_locs[var_name] = ([old_name], [loc])
        end
    end
    return vars_to_locs
end
55

56
"""
    attributes_dict(chns::Chains) -> Dict{String,Any}

Collect the entries of `chns.info`, except `:hashedsummary`, into a dictionary keyed by the
stringified entry names.
"""
function attributes_dict(chns::Chains)
    info = delete(chns.info, :hashedsummary)
    attrs = Dict{String,Any}()
    for (key, value) in pairs(info)
        attrs[string(key)] = value
    end
    return attrs
end
60 20
function attributes_dict(::Nothing)
    # Without chains there is no metadata to report.
    return Dict{Any,Any}()
end
61

62
"""
    section_dict(chns::Chains, section) -> Dict{String,Array}

Collect the variables in `section` of `chns` into a dictionary that maps each variable name
to an array of shape `(nchain, ndraw, size...)`.

Location names such as `"x[1,2]"` are grouped under their variable name `"x"`. The trailing
dimensions `size...` are the per-dimension maxima of the parsed locations. Positions for
which the chains contain no entry are filled with `NaN`.

Throws `DimensionMismatch` if locations of one variable have differing numbers of indices.
"""
function section_dict(chns::Chains, section)
    ndraws, _, nchains = size(chns)
    # Names may be stored as Symbols or Strings; keep the originals for indexing `chns.value`.
    loc_names_old = getfield(chns.name_map, section)
    loc_names = string.(loc_names_old)
    loc_str_to_old = Dict(zip(loc_names, loc_names_old))
    vars_to_locs = varnames_locs_dict(loc_names, loc_str_to_old)
    vars_to_arrays = Dict{String,Array}()
    for (var_name, (var_loc_names, locs)) in vars_to_locs
        ndim = length(first(locs))
        if any(loc -> length(loc) != ndim, locs)
            throw(DimensionMismatch("inconsistent location dimensions for variable $var_name"))
        end
        sizes = ntuple(d -> maximum(loc -> loc[d], locs), ndim)

        oldarr = reshape_values(replacemissing(Array(chns.value[:, var_loc_names, :])))
        if ndim == 0
            # Scalar variable: drop the singleton "location" axis.
            arr = dropdims(oldarr; dims=3)
        else
            # Locations absent from the chains (e.g. sparse indexing) stay NaN.
            arr = Array{Union{typeof(NaN),eltype(oldarr)}}(undef, nchains, ndraws, sizes...)
            fill!(arr, NaN)
            for (i, loc) in enumerate(locs)
                arr[:, :, loc...] = oldarr[:, :, i]
            end
        end
        vars_to_arrays[var_name] = arr
    end
    return vars_to_arrays
end
91

92
"""
    chains_to_dict(chns::Chains; ignore, section, rekey_fun) -> Dict

Convert `section` of `chns` to a variable-name-keyed dictionary, drop the keys listed in
`ignore`, and pass the result through `rekey_fun`. If `chns` has no such section, an empty
`Dict()` is returned and `rekey_fun` is not applied.
"""
function chains_to_dict(
    chns::Chains; ignore=String[], section=:parameters, rekey_fun=identity
)
    if !(section in sections(chns))
        return Dict()
    end
    var_dict = section_dict(chns, section)
    removekeys!(var_dict, ignore)
    return rekey_fun(var_dict)
end
100 20
function chains_to_dict(::Nothing; kwargs...)
    # No chains: nothing to convert, whatever options were requested.
    return nothing
end
101

102
"""
    convert_to_inference_data(obj::Chains; group = :posterior, kwargs...) -> InferenceData

Convert the chains `obj` to an [`InferenceData`](@ref) with the specified `group`.

Remaining `kwargs` are forwarded to [`from_mcmcchains`](@ref), whatever the `group`.
"""
function convert_to_inference_data(chns::Chains; group=:posterior, kwargs...)
    group = Symbol(group)
    group === :posterior && return from_mcmcchains(chns; kwargs...)
    # Any other group is passed as the keyword of `from_mcmcchains` with that name; the
    # remaining options (library, coords, dims, ...) must still be forwarded.
    return from_mcmcchains(; group => chns, kwargs...)
end
114

115
@doc doc"""
    from_mcmcchains(posterior::Chains; kwargs...) -> InferenceData
    from_mcmcchains(; kwargs...) -> InferenceData
    from_mcmcchains(
        posterior::Chains,
        posterior_predictive::Any,
        predictions::Any,
        log_likelihood::Any;
        kwargs...
    ) -> InferenceData

Convert data in an `MCMCChains.Chains` format into an [`InferenceData`](@ref).

Any keyword argument below without an explicitly annotated type above is allowed, so long
as it can be passed to [`convert_to_inference_data`](@ref).

# Arguments

- `posterior::Chains`: Draws from the posterior

# Keywords

- `posterior_predictive::Any=nothing`: Draws from the posterior predictive distribution or
    name(s) of predictive variables in `posterior`
- `predictions::Any=nothing`: Out-of-sample predictions for the posterior.
- `prior::Any=nothing`: Draws from the prior
- `prior_predictive::Any=nothing`: Draws from the prior predictive distribution or name(s)
    of predictive variables in `prior`
- `observed_data::Dict{String,Array}=nothing`: Observed data on which the `posterior` is
    conditional. It should only contain data which is modeled as a random variable. Keys are
    parameter names and values are the observed arrays.
- `constant_data::Dict{String,Array}=nothing`: Model constants, data included in the model
    which is not modeled as a random variable. Keys are parameter names and values are the
    constant arrays.
- `predictions_constant_data::Dict{String,Array}=nothing`: Constants relevant to the model
    predictions (e.g. new `x` values in a linear regression).
- `log_likelihood::Any=nothing`: Pointwise log-likelihood for the data. It is recommended
    to use this argument as a dictionary whose keys are observed variable names and whose
    values are log likelihood arrays.
- `log_likelihood::String=nothing`: Name of variable in `posterior` with log likelihoods
- `library=MCMCChains`: Name of library that generated the chains
- `coords::Dict{String,Vector}=nothing`: Map from named dimension to named indices
- `dims::Dict{String,Vector{String}}=nothing`: Map from variable name to names of its
    dimensions

# Returns

- `InferenceData`: The data with groups corresponding to the provided data
"""
from_mcmcchains
164

165
function from_mcmcchains(
    posterior,
    posterior_predictive,
    predictions,
    log_likelihood;
    library=MCMCChains,
    kwargs...,
)
    # Downstream converters expect a `dims` entry even when the caller gave none.
    opts = convert(Dict, merge((; dims=nothing), kwargs))
    lib = string(library)

    # Draws and sampler statistics (renamed to ArviZ conventions) from the posterior chains.
    post_dict = chains_to_dict(posterior)
    raw_stats = chains_to_dict(
        posterior; section=:internals, rekey_fun=(d -> rekey(d, stats_key_map))
    )
    stats_dict = enforce_stat_types(raw_stats)

    all_idata = InferenceData()
    extra_groups = (
        :posterior_predictive => posterior_predictive,
        :predictions => predictions,
        :log_likelihood => log_likelihood,
    )
    for (group, data) in extra_groups
        data === nothing && continue
        # A single variable name becomes a one-element list of names ...
        data isa Union{Symbol,String} && (data = [string(data)])
        # ... Symbol names are normalized to Strings ...
        if data isa Union{AbstractVector{Symbol},NTuple{N,Symbol} where {N}}
            data = map(string, data)
        end
        # ... and the named variables are moved out of the posterior into this group.
        if data isa Union{AbstractVector{String},NTuple{N,String} where {N}}
            data = popsubdict!(post_dict, data)
        end
        dataset = convert_to_dataset(data; library=lib, opts...)
        setattribute!(dataset, "inference_library", lib)
        concat!(all_idata, InferenceData(; group => dataset))
    end

    # Whatever is still in `post_dict` forms the posterior group itself.
    attrs = merge(attributes_dict(posterior), Dict("inference_library" => lib))
    post_opts = convert(Dict, merge((; attrs=attrs, dims=nothing), opts))
    post_idata = _from_dict(post_dict; sample_stats=stats_dict, post_opts...)
    concat!(all_idata, post_idata)
    return all_idata
end
210
function from_mcmcchains(
    posterior=nothing;
    posterior_predictive=nothing,
    predictions=nothing,
    prior=nothing,
    prior_predictive=nothing,
    observed_data=nothing,
    constant_data=nothing,
    predictions_constant_data=nothing,
    log_likelihood=nothing,
    library=MCMCChains,
    kwargs...,
)
    # Ensure `dims` and `coords` are always present for the downstream converters.
    kwargs = convert(Dict, merge((; dims=nothing, coords=nothing), kwargs))

    all_idata = from_mcmcchains(
        posterior,
        posterior_predictive,
        predictions,
        log_likelihood;
        library=library,
        kwargs...,
    )

    if prior !== nothing
        # Convert the prior as if it were a posterior, then rename its groups.
        pre_prior_idata = convert_to_inference_data(
            prior; posterior_predictive=prior_predictive, library=library, kwargs...
        )
        prior_idata = rekey(
            pre_prior_idata,
            Dict(
                :posterior => :prior,
                :posterior_predictive => :prior_predictive,
                :sample_stats => :sample_stats_prior,
            ),
        )
        concat!(all_idata, prior_idata)
    elseif prior_predictive !== nothing
        # Without `prior`, `prior_predictive` cannot name variables of the prior, so it must
        # be draws in its own right; previously it was silently dropped here.
        pre_prior_predictive_idata = convert_to_inference_data(
            prior_predictive; library=library, kwargs...
        )
        prior_predictive_idata = rekey(
            pre_prior_predictive_idata,
            Dict(:posterior => :prior_predictive, :sample_stats => :sample_stats_prior),
        )
        concat!(all_idata, prior_predictive_idata)
    end

    # Data groups are not draws: convert them as constant datasets.
    for (group, group_data) in [
        :observed_data => observed_data,
        :constant_data => constant_data,
        :predictions_constant_data => predictions_constant_data,
    ]
        group_data === nothing && continue
        group_dataset = convert_to_constant_dataset(group_data; library=library, kwargs...)
        concat!(all_idata, InferenceData(; group => group_dataset))
    end

    return all_idata
end
261

262
"""
    from_cmdstan(posterior::Chains; kwargs...) -> InferenceData

Convert `CmdStan` output held in `posterior` by delegating to [`from_mcmcchains`](@ref)
with `library="CmdStan"`, forwarding all other `kwargs` unchanged.
"""
from_cmdstan(posterior::Chains; kwargs...) =
    from_mcmcchains(posterior; library="CmdStan", kwargs...)

Read our documentation on viewing source code .

Loading