Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions LocalPreferences.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
[PosteriorStats]
## which credible interval to use (options: "eti", "hdi")
# ci_kind = "eti"

## the probability mass to be contained in the credible interval (options: any number in (0, 1))
# ci_prob = 0.89

## the default point estimate to use for summarizing distributions (options: "mean", "median", "mode")
# point_estimate = "mean"

## which method to use to compare models (options: "Stacking", "BootstrappedPseudoBMA", "PseudoBMA")
# weights_method = "Stacking"

## precision settings for custom show methods. Note: these are considered internal and can be changed at any time
## without a breaking release.
# show_html_max_rows = 25 # options: integer >= 0
# show_printf = "" # options: a valid Printf format string; if provided, supercedes all settings below
# show_sigdigits_using_se = true # options: Bool
# show_sigdigits_default = 3 # options: integer >= 0
# show_sigdigits_se = 2 # options: integer >= 0
# show_sigdigits_rhat = 2 # options: integer >= 0
# show_ess_int = true # options: Bool
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ Optim = "429524aa-4258-5aef-a3af-852621145aeb"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
PSIS = "ce719bf2-d5d0-4fb9-925d-10a81b42ad04"
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand Down Expand Up @@ -52,6 +53,7 @@ Optim = "1.7.2"
OrderedCollections = "1.3.0"
PDMats = "0.11.11"
PSIS = "0.9.1"
Preferences = "1"
PrettyTables = "2.1"
Printf = "1"
Random = "1"
Expand Down
3 changes: 2 additions & 1 deletion src/PosteriorStats.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ using MCMCDiagnosticTools: MCMCDiagnosticTools
using Optim: Optim
using OrderedCollections: OrderedCollections
using PrettyTables: PrettyTables
using Preferences: Preferences
using Printf: Printf
using PDMats: PDMats
using PSIS: PSIS, PSISResult, psis, psis!
Expand Down Expand Up @@ -48,10 +49,10 @@ export eti, eti!, hdi, hdi!
# Others
export loo_pit, r2_score

const DEFAULT_CI_PROB = 0.89f0
const INFORMATION_CRITERION_SCALES = (deviance=-2, log=1, negative_log=-1)

include("utils.jl")
include("preferences.jl")
include("show_prettytable.jl")
include("density_estimation.jl")
include("kde.jl")
Expand Down
6 changes: 3 additions & 3 deletions src/compare.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ The ELPD is estimated by Pareto smoothed importance sampling leave-one-out cross

# Keywords

- `weights_method::AbstractModelWeightsMethod=Stacking()`: the method to be used to weight
the models. See [`model_weights`](@ref) for details
- `weights_method::AbstractModelWeightsMethod=$(default_weights_method())()`: the method
to be used to weight the models. See [`model_weights`](@ref) for details
- `sort::Bool=true`: Whether to sort models by decreasing ELPD.

# Returns
Expand Down Expand Up @@ -67,7 +67,7 @@ ModelComparisonResult with BootstrappedPseudoBMA weights
"""
function compare(
inputs;
weights_method::AbstractModelWeightsMethod=Stacking(),
weights_method::AbstractModelWeightsMethod=default_weights_method()(),
model_names=_indices(inputs),
sort::Bool=true,
)
Expand Down
13 changes: 6 additions & 7 deletions src/eti.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ See also: [`eti!`](@ref), [`hdi`](@ref), [`hdi!`](@ref).
present

# Keywords
- `prob`: the probability mass to be contained in the ETI. Default is `$(DEFAULT_CI_PROB)`.
- `prob`: the probability mass to be contained in the ETI. Default is
`$(default_ci_prob())`.
- `kwargs`: remaining keywords are passed to [`Statistics.quantile`](@extref).

# Returns
Expand All @@ -26,7 +27,7 @@ See also: [`eti!`](@ref), [`hdi`](@ref), [`hdi!`](@ref).

!!! note
Any default value of `prob` is arbitrary. The default value of
`prob=$(DEFAULT_CI_PROB)` instead of a more common default like `prob=0.95` is
`prob=$(default_ci_prob())` instead of a more common default like `prob=0.95` is
chosen to reminder the user of this arbitrariness.

# Examples
Expand All @@ -52,10 +53,8 @@ julia> eti(x)
8.389954310154792 .. 11.631846719859958
```
"""
function eti(
x::AbstractArray{<:Real}; prob::Real=DEFAULT_CI_PROB, sorted::Bool=false, kwargs...
)
return eti!(sorted ? x : _copymutable(x); prob, sorted, kwargs...)
function eti(x::AbstractArray{<:Real}; sorted::Bool=false, kwargs...)
return eti!(sorted ? x : _copymutable(x); sorted, kwargs...)
end

"""
Expand All @@ -65,7 +64,7 @@ A version of [`eti`](@ref) that partially sorts `samples` in-place while computi

See also: [`eti`](@ref), [`hdi`](@ref), [`hdi!`](@ref).
"""
function eti!(x::AbstractArray{<:Real}; prob::Real=DEFAULT_CI_PROB, kwargs...)
function eti!(x::AbstractArray{<:Real}; prob::Real=default_ci_prob(eltype(x)), kwargs...)
ndims(x) > 0 ||
throw(ArgumentError("ETI cannot be computed for a 0-dimensional array."))
0 < prob < 1 || throw(DomainError(prob, "ETI `prob` must be in the range `(0, 1)`."))
Expand Down
11 changes: 6 additions & 5 deletions src/hdi.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,8 @@ See also: [`hdi!`](@ref), [`eti`](@ref), [`eti!`](@ref).
present, a marginal HDI is computed for each.

# Keywords
- `prob`: the probability mass to be contained in the HDI. Default is `$(DEFAULT_CI_PROB)`.
- `prob`: the probability mass to be contained in the HDI. Default is
`$(default_ci_prob())`.
- `sorted=false`: if `true`, the input samples are assumed to be sorted.
- `method::Symbol`: the method used to estimate the HDI. Available options are:
- `:unimodal`: Assumes a unimodal distribution (default). Bounds are entries in `samples`.
Expand All @@ -115,9 +116,9 @@ See also: [`hdi!`](@ref), [`eti`](@ref), [`eti!`](@ref).
the shape `(params...,)` is returned, containing marginal HDIs for each parameter.

!!! note
Any default value of `prob` is arbitrary. The default value of `prob=$(DEFAULT_CI_PROB)`
instead of a more common default like `prob=0.95` is chosen to remind the user of this
arbitrariness.
Any default value of `prob` is arbitrary. The default value of
`prob=$(default_ci_prob())` instead of a more common default like `prob=0.95` is chosen
to remind the user of this arbitrariness.

# Examples

Expand Down Expand Up @@ -172,7 +173,7 @@ See also: [`hdi`](@ref), [`eti`](@ref), [`eti!`](@ref).
"""
Base.@constprop :aggressive function hdi!(
x::AbstractArray{<:Real};
prob::Real=DEFAULT_CI_PROB,
prob::Real=default_ci_prob(eltype(x)),
is_discrete::Union{Bool,Nothing}=nothing,
method::Union{Symbol,HDIEstimationMethod}=UnimodalHDI(),
sorted::Bool=false,
Expand Down
118 changes: 60 additions & 58 deletions src/model_weights.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,64 +9,6 @@ Subtypes implement [`model_weights`](@ref)`(method, elpd_results)`.
"""
abstract type AbstractModelWeightsMethod end

"""
model_weights(elpd_results; method=Stacking())
model_weights(method::AbstractModelWeightsMethod, elpd_results)

Compute weights for each model in `elpd_results` using `method`.

`elpd_results` is a `Tuple`, `NamedTuple`, or `AbstractVector` with
[`AbstractELPDResult`](@ref) entries. The weights are returned in the same type of
collection.

[`Stacking`](@ref) is the recommended approach, as it performs well even when the true data
generating process is not included among the candidate models. See [Yao2018](@citet) for
details.

See also: [`AbstractModelWeightsMethod`](@ref), [`compare`](@ref)

# Examples

Compute [`Stacking`](@ref) weights for two models:

```jldoctest model_weights; filter = [r"└.*", r"(\\d+\\.\\d{3})\\d*" => s"\\1"]
julia> using ArviZExampleData

julia> models = (
centered=load_example_data("centered_eight"),
non_centered=load_example_data("non_centered_eight"),
);

julia> elpd_results = map(models) do idata
log_like = PermutedDimsArray(idata.log_likelihood.obs, (2, 3, 1))
return loo(log_like)
end;
┌ Warning: 1 parameters had Pareto shape values 0.7 < k ≤ 1. Resulting importance sampling estimates are likely to be unstable.
└ @ PSIS ~/.julia/packages/PSIS/...

julia> model_weights(elpd_results; method=Stacking()) |> pairs
pairs(::NamedTuple) with 2 entries:
:centered => 3.50546e-31
:non_centered => 1.0
```

Now we compute [`BootstrappedPseudoBMA`](@ref) weights for the same models:

```jldoctest model_weights; setup = :(using Random; Random.seed!(94))
julia> model_weights(elpd_results; method=BootstrappedPseudoBMA()) |> pairs
pairs(::NamedTuple) with 2 entries:
:centered => 0.492513
:non_centered => 0.507487
```

# References

- [Yao2018](@cite) Yao et al. Bayesian Analysis 13, 3 (2018)
"""
function model_weights(elpd_results; method::AbstractModelWeightsMethod=Stacking())
return model_weights(method, elpd_results)
end

# Akaike-type weights are defined as exp(-AIC/2), normalized to 1, which on the log-score
# IC scale is equivalent to softmax
akaike_weights!(w, elpds) = LogExpFunctions.softmax!(w, elpds)
Expand Down Expand Up @@ -295,3 +237,63 @@ function _∇sphere_to_simplex!(∂x, x)
∂x .*= 2 .* x
return ∂x
end

"""
model_weights(elpd_results; method=$(default_weights_method())())
model_weights(method::AbstractModelWeightsMethod, elpd_results)

Compute weights for each model in `elpd_results` using `method`.

`elpd_results` is a `Tuple`, `NamedTuple`, or `AbstractVector` with
[`AbstractELPDResult`](@ref) entries. The weights are returned in the same type of
collection.

[`Stacking`](@ref) is the recommended approach, as it performs well even when the true data
generating process is not included among the candidate models. See [Yao2018](@citet) for
details.

See also: [`AbstractModelWeightsMethod`](@ref), [`compare`](@ref)

# Examples

Compute [`Stacking`](@ref) weights for two models:

```jldoctest model_weights; filter = [r"└.*", r"(\\d+\\.\\d{3})\\d*" => s"\\1"]
julia> using ArviZExampleData

julia> models = (
centered=load_example_data("centered_eight"),
non_centered=load_example_data("non_centered_eight"),
);

julia> elpd_results = map(models) do idata
log_like = PermutedDimsArray(idata.log_likelihood.obs, (2, 3, 1))
return loo(log_like)
end;
┌ Warning: 1 parameters had Pareto shape values 0.7 < k ≤ 1. Resulting importance sampling estimates are likely to be unstable.
└ @ PSIS ~/.julia/packages/PSIS/...

julia> model_weights(elpd_results; method=Stacking()) |> pairs
pairs(::NamedTuple) with 2 entries:
:centered => 3.50546e-31
:non_centered => 1.0
```

Now we compute [`BootstrappedPseudoBMA`](@ref) weights for the same models:

```jldoctest model_weights; setup = :(using Random; Random.seed!(94))
julia> model_weights(elpd_results; method=BootstrappedPseudoBMA()) |> pairs
pairs(::NamedTuple) with 2 entries:
:centered => 0.492513
:non_centered => 0.507487
```

# References

- [Yao2018](@cite) Yao et al. Bayesian Analysis 13, 3 (2018)
"""
function model_weights(
elpd_results; method::AbstractModelWeightsMethod=default_weights_method()()
)
return model_weights(method, elpd_results)
end
81 changes: 81 additions & 0 deletions src/preferences.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
function default_ci_fun()
ci_kind = Symbol(Preferences.load_preference(PosteriorStats, "ci_kind", "eti"))
ci_kind ∈ (:eti, :hdi) ||
throw(ArgumentError("Invalid ci_kind: $ci_kind. Must be one of (eti, hdi)."))
return ci_kind == :eti ? eti : hdi
end

function default_ci_prob((::Type{T})=Float32) where {T<:Real}
prob = T(Preferences.load_preference(PosteriorStats, "ci_prob", 0.89))
0 < prob < 1 || throw(DomainError(prob, "ci_prob must be in the range (0, 1)."))
return prob
end

function default_point_estimate()
point_estimate = Symbol(
Preferences.load_preference(PosteriorStats, "point_estimate", "mean")
)
if point_estimate === :mean
return Statistics.mean
elseif point_estimate === :median
return Statistics.median
elseif point_estimate === :mode
return StatsBase.mode
else
throw(
ArgumentError(
"Invalid point_estimate: $point_estimate. Must be one of (mean, median, mode).",
),
)
end
end

function default_weights_method()
method = Preferences.load_preference(PosteriorStats, "weights_method", "Stacking")
method == "Stacking" && return Stacking
method == "PseudoBMA" && return PseudoBMA
method == "BootstrappedPseudoBMA" && return BootstrappedPseudoBMA
throw(
ArgumentError(
"Invalid weights_method: $method. Must be one of " *
"(Stacking, PseudoBMA, BootstrappedPseudoBMA).",
),
)
end

@kwdef struct PrecisionSettings
show_printf::String = ""
show_sigdigits_default::Int = 3
show_sigdigits_se::Int = 2
show_sigdigits_rhat::Int = 2
show_sigdigits_using_se::Bool = true
show_ess_int::Bool = true
show_html_max_rows::Int = 25
end

function default_precision_settings()
default_settings = PrecisionSettings()
# load settings from preferences
settings = PrecisionSettings(
(
_parse(
T,
Preferences.load_preference(
PosteriorStats, string(k), getproperty(default_settings, k)
),
) for
(k, T) in zip(fieldnames(PrecisionSettings), fieldtypes(PrecisionSettings))
)...,
)
# validate settings
for (k, T) in zip(fieldnames(PrecisionSettings), fieldtypes(PrecisionSettings))
if T <: Int
v = getfield(settings, k)
v ≥ 0 || throw(DomainError(v, "Setting `$k` must be non-negative"))
end
end
return settings
end

_parse(::Type{T}, v::T) where {T} = v
_parse(::Type{T}, v) where {T} = parse(T, v)
14 changes: 7 additions & 7 deletions src/r2_score.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,14 @@ credible interval (CI).
- `summary::Bool=true`: Whether to return a summary or an array of ``R²`` scores. The
summary is a named tuple with the point estimate `:r2` and the credible interval
`:<ci_fun>`.
- `point_estimate=Statistics.mean`: The function used to compute the point estimate of the
``R²`` scores if `summary` is `true`. Supported options are:
+ [`Statistics.mean`](@extref) (default)
- `point_estimate=$(default_point_estimate())`: The function used to compute the point
estimate of the ``R²`` scores if `summary` is `true`. Supported options are:
+ [`Statistics.mean`](@extref)
+ [`Statistics.median`](@extref)
+ [`StatsBase.mode`](@extref)
- `ci_fun=eti`: The function used to compute the credible interval if `summary` is
`true`. Supported options are [`eti`](@ref) and [`hdi`](@ref).
- `ci_prob=$(DEFAULT_CI_PROB)`: The probability mass to be contained in the credible
- `ci_prob=$(default_ci_prob())`: The probability mass to be contained in the credible
interval.

# Examples
Expand All @@ -54,9 +54,9 @@ function r2_score(
y_true,
y_pred;
summary=true,
point_estimate=Statistics.mean,
ci_fun=eti,
ci_prob=DEFAULT_CI_PROB,
point_estimate=default_point_estimate(),
ci_fun=default_ci_fun(),
ci_prob=default_ci_prob(float(Base.promote_eltype(y_true, y_pred))),
)
r_squared = _r2_samples(y_true, y_pred)
summary || return r_squared
Expand Down
Loading
Loading