From 8701c06067a7103da4786f34fbb40c70ac2f7f34 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 22 Sep 2025 11:40:46 +0200 Subject: [PATCH 01/10] Add Preferences as dependency --- Project.toml | 2 ++ src/PosteriorStats.jl | 2 ++ src/preferences.jl | 1 + 3 files changed, 5 insertions(+) create mode 100644 src/preferences.jl diff --git a/Project.toml b/Project.toml index 6eb4005..6b92085 100644 --- a/Project.toml +++ b/Project.toml @@ -19,6 +19,7 @@ Optim = "429524aa-4258-5aef-a3af-852621145aeb" OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" PSIS = "ce719bf2-d5d0-4fb9-925d-10a81b42ad04" +Preferences = "21216c6a-2e73-6563-6e65-726566657250" PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" @@ -52,6 +53,7 @@ Optim = "1.7.2" OrderedCollections = "1.3.0" PDMats = "0.11.11" PSIS = "0.9.1" +Preferences = "1" PrettyTables = "2.1" Printf = "1" Random = "1" diff --git a/src/PosteriorStats.jl b/src/PosteriorStats.jl index 648fcfd..d972e4d 100644 --- a/src/PosteriorStats.jl +++ b/src/PosteriorStats.jl @@ -13,6 +13,7 @@ using MCMCDiagnosticTools: MCMCDiagnosticTools using Optim: Optim using OrderedCollections: OrderedCollections using PrettyTables: PrettyTables +using Preferences: Preferences using Printf: Printf using PDMats: PDMats using PSIS: PSIS, PSISResult, psis, psis! @@ -52,6 +53,7 @@ const DEFAULT_CI_PROB = 0.94 const INFORMATION_CRITERION_SCALES = (deviance=-2, log=1, negative_log=-1) include("utils.jl") +include("preferences.jl") include("show_prettytable.jl") include("density_estimation.jl") include("kde.jl") diff --git a/src/preferences.jl b/src/preferences.jl new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/src/preferences.jl @@ -0,0 +1 @@ + From 9b7d179211b80def9ababc668776d67e50518e80 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 22 Sep 2025 11:43:01 +0200 Subject: [PATCH 02/10] Add ci_kind preference --- LocalPreferences.toml | 4 ++++ src/preferences.jl | 6 ++++++ src/summarize.jl | 6 +++--- 3 files changed, 13 insertions(+), 3 deletions(-) create mode 100644 LocalPreferences.toml diff --git a/LocalPreferences.toml b/LocalPreferences.toml new file mode 100644 index 0000000..89bd2ff --- /dev/null +++ b/LocalPreferences.toml @@ -0,0 +1,4 @@ +[PosteriorStats] +## which credible interval to use (options: "eti", "hdi") +# ci_kind = "eti" + diff --git a/src/preferences.jl b/src/preferences.jl index 8b13789..30100d4 100644 --- a/src/preferences.jl +++ b/src/preferences.jl @@ -1 +1,7 @@ +function default_ci_fun() + ci_kind = Symbol(Preferences.load_preference(PosteriorStats, "ci_kind", "eti")) + ci_kind ∈ (:eti, :hdi) || + throw(ArgumentError("Invalid ci_kind: $ci_kind. Must be one of (eti, hdi).")) + return ci_kind == :eti ? eti : hdi +end diff --git a/src/summarize.jl b/src/summarize.jl index 0e33753..d2f4fe0 100644 --- a/src/summarize.jl +++ b/src/summarize.jl @@ -10,8 +10,8 @@ const _DEFAULT_SUMMARY_STATS_KIND_DOCSTRING = """ + `:diagnostics_median`: `ess_median`, `ess_tail`, `rhat`, `mcse_median` """ const _DEFAULT_SUMMARY_STATS_CI_DOCSTRING = """ -- `ci_fun=eti`: The function to compute the credible interval ``, if any. Supported - options are [`eti`](@ref) and [`hdi`](@ref). CI column name is +- `ci_fun=$(default_ci_fun())`: The function to compute the credible interval ``, if + any. Supported options are [`eti`](@ref) and [`hdi`](@ref). CI column name is `<100*ci_prob>`. - `ci_prob=$(DEFAULT_CI_PROB)`: The probability mass to be contained in the credible interval ``. @@ -340,7 +340,7 @@ function _default_stats(::typeof(Statistics.median); kwargs...) ) end -function _interval_stat(; ci_fun=eti, ci_prob=DEFAULT_CI_PROB, kwargs...) +function _interval_stat(; ci_fun=default_ci_fun(), ci_prob=DEFAULT_CI_PROB, kwargs...) ci_name = Symbol(_fname(ci_fun), _prob_to_string(ci_prob)) return ci_name => FixKeywords(ci_fun; prob=ci_prob) ∘ _cskipmissing end From 2540ffeef8943e60279bbac4bfe19bf8b95707c6 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 22 Sep 2025 11:43:56 +0200 Subject: [PATCH 03/10] Add ci_prob preference --- LocalPreferences.toml | 3 +++ src/PosteriorStats.jl | 1 - src/eti.jl | 13 ++++++------- src/hdi.jl | 11 ++++++----- src/preferences.jl | 6 ++++++ src/summarize.jl | 4 ++-- 6 files changed, 23 insertions(+), 15 deletions(-) diff --git a/LocalPreferences.toml b/LocalPreferences.toml index 89bd2ff..836a390 100644 --- a/LocalPreferences.toml +++ b/LocalPreferences.toml @@ -2,3 +2,6 @@ ## which credible interval to use (options: "eti", "hdi") # ci_kind = "eti" +## the probability mass to be contained in the credible interval (options: any number in (0, 1)) +# ci_prob = 0.94 + diff --git a/src/PosteriorStats.jl b/src/PosteriorStats.jl index d972e4d..af09055 100644 --- a/src/PosteriorStats.jl +++ b/src/PosteriorStats.jl @@ -49,7 +49,6 @@ export eti, eti!, hdi, hdi! # Others export loo_pit, r2_score -const DEFAULT_CI_PROB = 0.94 const INFORMATION_CRITERION_SCALES = (deviance=-2, log=1, negative_log=-1) include("utils.jl") diff --git a/src/eti.jl b/src/eti.jl index 086c125..2dced1e 100644 --- a/src/eti.jl +++ b/src/eti.jl @@ -16,7 +16,8 @@ See also: [`eti!`](@ref), [`hdi`](@ref), [`hdi!`](@ref). present # Keywords -- `prob`: the probability mass to be contained in the ETI. Default is `$(DEFAULT_CI_PROB)`. +- `prob`: the probability mass to be contained in the ETI. Default is + `$(default_ci_prob())`. - `kwargs`: remaining keywords are passed to [`Statistics.quantile`](@extref). # Returns @@ -26,7 +27,7 @@ See also: [`eti!`](@ref), [`hdi`](@ref), [`hdi!`](@ref). !!! note Any default value of `prob` is arbitrary. The default value of - `prob=$(DEFAULT_CI_PROB)` instead of a more common default like `prob=0.95` is + `prob=$(default_ci_prob())` instead of a more common default like `prob=0.95` is chosen to reminder the user of this arbitrariness. # Examples @@ -52,10 +53,8 @@ julia> eti(x) 8.048993174980314 .. 11.90116662171538 ``` """ -function eti( - x::AbstractArray{<:Real}; prob::Real=DEFAULT_CI_PROB, sorted::Bool=false, kwargs... -) - return eti!(sorted ? x : _copymutable(x); prob, sorted, kwargs...) +function eti(x::AbstractArray{<:Real}; sorted::Bool=false, kwargs...) + return eti!(sorted ? x : _copymutable(x); sorted, kwargs...) end """ @@ -65,7 +64,7 @@ A version of [`eti`](@ref) that partially sorts `samples` in-place while computi See also: [`eti`](@ref), [`hdi`](@ref), [`hdi!`](@ref). """ -function eti!(x::AbstractArray{<:Real}; prob::Real=DEFAULT_CI_PROB, kwargs...) +function eti!(x::AbstractArray{<:Real}; prob::Real=default_ci_prob(eltype(x)), kwargs...) ndims(x) > 0 || throw(ArgumentError("ETI cannot be computed for a 0-dimensional array.")) 0 < prob < 1 || throw(DomainError(prob, "ETI `prob` must be in the range `(0, 1)`.")) diff --git a/src/hdi.jl b/src/hdi.jl index 93fb5fd..ed51d6d 100644 --- a/src/hdi.jl +++ b/src/hdi.jl @@ -92,7 +92,8 @@ See also: [`hdi!`](@ref), [`eti`](@ref), [`eti!`](@ref). present, a marginal HDI is computed for each. # Keywords -- `prob`: the probability mass to be contained in the HDI. Default is `$(DEFAULT_CI_PROB)`. +- `prob`: the probability mass to be contained in the HDI. Default is + `$(default_ci_prob())`. - `sorted=false`: if `true`, the input samples are assumed to be sorted. - `method::Symbol`: the method used to estimate the HDI. Available options are: - `:unimodal`: Assumes a unimodal distribution (default). Bounds are entries in `samples`. @@ -115,9 +116,9 @@ See also: [`hdi!`](@ref), [`eti`](@ref), [`eti!`](@ref). the shape `(params...,)` is returned, containing marginal HDIs for each parameter. !!! note - Any default value of `prob` is arbitrary. The default value of `prob=$(DEFAULT_CI_PROB)` - instead of a more common default like `prob=0.95` is chosen to remind the user of this - arbitrariness. + Any default value of `prob` is arbitrary. The default value of + `prob=$(default_ci_prob())` instead of a more common default like `prob=0.95` is chosen + to remind the user of this arbitrariness. # Examples @@ -172,7 +173,7 @@ See also: [`hdi`](@ref), [`eti`](@ref), [`eti!`](@ref). """ Base.@constprop :aggressive function hdi!( x::AbstractArray{<:Real}; - prob::Real=DEFAULT_CI_PROB, + prob::Real=default_ci_prob(eltype(x)), is_discrete::Union{Bool,Nothing}=nothing, method::Union{Symbol,HDIEstimationMethod}=UnimodalHDI(), sorted::Bool=false, diff --git a/src/preferences.jl b/src/preferences.jl index 30100d4..0198c30 100644 --- a/src/preferences.jl +++ b/src/preferences.jl @@ -5,3 +5,9 @@ function default_ci_fun() return ci_kind == :eti ? eti : hdi end +function default_ci_prob((::Type{T})=Float64) where {T<:Real} + prob = T(Preferences.load_preference(PosteriorStats, "ci_prob", 0.94)) + 0 < prob < 1 || throw(DomainError(prob, "ci_prob must be in the range (0, 1).")) + return prob +end + diff --git a/src/summarize.jl b/src/summarize.jl index d2f4fe0..5775cd1 100644 --- a/src/summarize.jl +++ b/src/summarize.jl @@ -13,7 +13,7 @@ const _DEFAULT_SUMMARY_STATS_CI_DOCSTRING = """ - `ci_fun=$(default_ci_fun())`: The function to compute the credible interval ``, if any. Supported options are [`eti`](@ref) and [`hdi`](@ref). CI column name is `<100*ci_prob>`. -- `ci_prob=$(DEFAULT_CI_PROB)`: The probability mass to be contained in the credible +- `ci_prob=$(default_ci_prob())`: The probability mass to be contained in the credible interval ``. """ @@ -340,7 +340,7 @@ function _default_stats(::typeof(Statistics.median); kwargs...) ) end -function _interval_stat(; ci_fun=default_ci_fun(), ci_prob=DEFAULT_CI_PROB, kwargs...) +function _interval_stat(; ci_fun=default_ci_fun(), ci_prob=default_ci_prob(), kwargs...) ci_name = Symbol(_fname(ci_fun), _prob_to_string(ci_prob)) return ci_name => FixKeywords(ci_fun; prob=ci_prob) ∘ _cskipmissing end From 0a2e647af8a636b378bda27027b9c04b1d27f6e2 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 22 Sep 2025 11:44:33 +0200 Subject: [PATCH 04/10] Add weights method preference --- LocalPreferences.toml | 3 ++ src/compare.jl | 10 ++-- src/model_weights.jl | 118 +++++++++++++++++++++--------------------- src/preferences.jl | 13 +++++ 4 files changed, 81 insertions(+), 63 deletions(-) diff --git a/LocalPreferences.toml b/LocalPreferences.toml index 836a390..9f84549 100644 --- a/LocalPreferences.toml +++ b/LocalPreferences.toml @@ -5,3 +5,6 @@ ## the probability mass to be contained in the credible interval (options: any number in (0, 1)) # ci_prob = 0.94 +## which method to use to compare models (options: "Stacking", "BootstrappedPseudoBMA", "PseudoBMA") +# weights_method = "Stacking" + diff --git a/src/compare.jl b/src/compare.jl index ce2aa09..ac8564f 100644 --- a/src/compare.jl +++ b/src/compare.jl @@ -15,8 +15,8 @@ see [Spiegelhalter2002](@citet). # Keywords - - `weights_method::AbstractModelWeightsMethod=Stacking()`: the method to be used to weight - the models. See [`model_weights`](@ref) for details + - `weights_method::AbstractModelWeightsMethod=$(default_weights_method())()`: the method + to be used to weight the models. See [`model_weights`](@ref) for details - `elpd_method=loo`: a method that computes an `AbstractELPDResult` from an argument in `models`. - `sort::Bool=true`: Whether to sort models by decreasing ELPD. @@ -29,8 +29,8 @@ see [Spiegelhalter2002](@citet). # Examples Compare the centered and non centered models of the eight school problem using the defaults: -[`loo`](@ref) and [`Stacking`](@ref) weights. A custom `myloo` method formates the inputs -as expected by [`loo`](@ref). +[`loo`](@ref) and [`$(default_weights_method())()`](@ref) weights. A custom `myloo` method +formates the inputs as expected by [`loo`](@ref). ```jldoctest compare; filter = [r"└.*", r"(\\d+\\.\\d{3})\\d*" => s"\\1"] julia> using ArviZExampleData @@ -77,7 +77,7 @@ ModelComparisonResult with BootstrappedPseudoBMA weights """ function compare( inputs; - weights_method::AbstractModelWeightsMethod=Stacking(), + weights_method::AbstractModelWeightsMethod=default_weights_method()(), elpd_method=loo, model_names=_indices(inputs), sort::Bool=true, diff --git a/src/model_weights.jl b/src/model_weights.jl index 84acb4c..be03008 100644 --- a/src/model_weights.jl +++ b/src/model_weights.jl @@ -9,64 +9,6 @@ Subtypes implement [`model_weights`](@ref)`(method, elpd_results)`. """ abstract type AbstractModelWeightsMethod end -""" - model_weights(elpd_results; method=Stacking()) - model_weights(method::AbstractModelWeightsMethod, elpd_results) - -Compute weights for each model in `elpd_results` using `method`. - -`elpd_results` is a `Tuple`, `NamedTuple`, or `AbstractVector` with -[`AbstractELPDResult`](@ref) entries. The weights are returned in the same type of -collection. - -[`Stacking`](@ref) is the recommended approach, as it performs well even when the true data -generating process is not included among the candidate models. See [Yao2018](@citet) for -details. - -See also: [`AbstractModelWeightsMethod`](@ref), [`compare`](@ref) - -# Examples - -Compute [`Stacking`](@ref) weights for two models: - -```jldoctest model_weights; filter = [r"└.*", r"(\\d+\\.\\d{3})\\d*" => s"\\1"] -julia> using ArviZExampleData - -julia> models = ( - centered=load_example_data("centered_eight"), - non_centered=load_example_data("non_centered_eight"), - ); - -julia> elpd_results = map(models) do idata - log_like = PermutedDimsArray(idata.log_likelihood.obs, (2, 3, 1)) - return loo(log_like) - end; -┌ Warning: 1 parameters had Pareto shape values 0.7 < k ≤ 1. Resulting importance sampling estimates are likely to be unstable. -└ @ PSIS ~/.julia/packages/PSIS/... - -julia> model_weights(elpd_results; method=Stacking()) |> pairs -pairs(::NamedTuple) with 2 entries: - :centered => 3.50546e-31 - :non_centered => 1.0 -``` - -Now we compute [`BootstrappedPseudoBMA`](@ref) weights for the same models: - -```jldoctest model_weights; setup = :(using Random; Random.seed!(94)) -julia> model_weights(elpd_results; method=BootstrappedPseudoBMA()) |> pairs -pairs(::NamedTuple) with 2 entries: - :centered => 0.492513 - :non_centered => 0.507487 -``` - -# References - -- [Yao2018](@cite) Yao et al. Bayesian Analysis 13, 3 (2018) -""" -function model_weights(elpd_results; method::AbstractModelWeightsMethod=Stacking()) - return model_weights(method, elpd_results) -end - # Akaike-type weights are defined as exp(-AIC/2), normalized to 1, which on the log-score # IC scale is equivalent to softmax akaike_weights!(w, elpds) = LogExpFunctions.softmax!(w, elpds) @@ -295,3 +237,63 @@ function _∇sphere_to_simplex!(∂x, x) ∂x .*= 2 .* x return ∂x end + +""" + model_weights(elpd_results; method=$(default_weights_method())()) + model_weights(method::AbstractModelWeightsMethod, elpd_results) + +Compute weights for each model in `elpd_results` using `method`. + +`elpd_results` is a `Tuple`, `NamedTuple`, or `AbstractVector` with +[`AbstractELPDResult`](@ref) entries. The weights are returned in the same type of +collection. + +[`Stacking`](@ref) is the recommended approach, as it performs well even when the true data +generating process is not included among the candidate models. See [Yao2018](@citet) for +details. + +See also: [`AbstractModelWeightsMethod`](@ref), [`compare`](@ref) + +# Examples + +Compute [`Stacking`](@ref) weights for two models: + +```jldoctest model_weights; filter = [r"└.*", r"(\\d+\\.\\d{3})\\d*" => s"\\1"] +julia> using ArviZExampleData + +julia> models = ( + centered=load_example_data("centered_eight"), + non_centered=load_example_data("non_centered_eight"), + ); + +julia> elpd_results = map(models) do idata + log_like = PermutedDimsArray(idata.log_likelihood.obs, (2, 3, 1)) + return loo(log_like) + end; +┌ Warning: 1 parameters had Pareto shape values 0.7 < k ≤ 1. Resulting importance sampling estimates are likely to be unstable. +└ @ PSIS ~/.julia/packages/PSIS/... + +julia> model_weights(elpd_results; method=Stacking()) |> pairs +pairs(::NamedTuple) with 2 entries: + :centered => 3.50546e-31 + :non_centered => 1.0 +``` + +Now we compute [`BootstrappedPseudoBMA`](@ref) weights for the same models: + +```jldoctest model_weights; setup = :(using Random; Random.seed!(94)) +julia> model_weights(elpd_results; method=BootstrappedPseudoBMA()) |> pairs +pairs(::NamedTuple) with 2 entries: + :centered => 0.492513 + :non_centered => 0.507487 +``` + +# References + +- [Yao2018](@cite) Yao et al. Bayesian Analysis 13, 3 (2018) +""" +function model_weights( + elpd_results; method::AbstractModelWeightsMethod=default_weights_method()() +) + return model_weights(method, elpd_results) +end diff --git a/src/preferences.jl b/src/preferences.jl index 0198c30..c11cc5a 100644 --- a/src/preferences.jl +++ b/src/preferences.jl @@ -11,3 +11,16 @@ function default_ci_prob((::Type{T})=Float64) where {T<:Real} return prob end +function default_weights_method() + method = Preferences.load_preference(PosteriorStats, "weights_method", "Stacking") + method == "Stacking" && return Stacking + method == "PseudoBMA" && return PseudoBMA + method == "BootstrappedPseudoBMA" && return BootstrappedPseudoBMA + throw( + ArgumentError( + "Invalid weights_method: $method. Must be one of " * + "(Stacking, PseudoBMA, BootstrappedPseudoBMA).", + ), + ) +end + From b03fd51b2f303771b3b89f82de8a75a27d236058 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 22 Sep 2025 11:44:51 +0200 Subject: [PATCH 05/10] Add preferences to control pretty printing --- LocalPreferences.toml | 9 +++++++++ src/preferences.jl | 36 ++++++++++++++++++++++++++++++++++++ src/show_prettytable.jl | 36 +++++++++++++++++++++++++++--------- src/summarize.jl | 8 ++++++-- 4 files changed, 78 insertions(+), 11 deletions(-) diff --git a/LocalPreferences.toml b/LocalPreferences.toml index 9f84549..5ce5064 100644 --- a/LocalPreferences.toml +++ b/LocalPreferences.toml @@ -8,3 +8,12 @@ ## which method to use to compare models (options: "Stacking", "BootstrappedPseudoBMA", "PseudoBMA") # weights_method = "Stacking" +## precision settings for custom show methods. Note: these are considered internal and can be changed at any time +## without a breaking release. +# show_html_max_rows = 25 # options: integer >= 0 +# show_printf = "" # options: a valid Printf format string; if provided, supercedes all settings below +# show_sigdigits_using_se = true # options: Bool +# show_sigdigits_default = 3 # options: integer >= 0 +# show_sigdigits_se = 2 # options: integer >= 0 +# show_sigdigits_rhat = 2 # options: integer >= 0 +# show_ess_int = true # options: Bool diff --git a/src/preferences.jl b/src/preferences.jl index c11cc5a..06ec1dd 100644 --- a/src/preferences.jl +++ b/src/preferences.jl @@ -24,3 +24,39 @@ function default_weights_method() ) end +@kwdef struct PrecisionSettings + show_printf::String = "" + show_sigdigits_default::Int = 3 + show_sigdigits_se::Int = 2 + show_sigdigits_rhat::Int = 2 + show_sigdigits_using_se::Bool = true + show_ess_int::Bool = true + show_html_max_rows::Int = 25 +end + +function default_precision_settings() + default_settings = PrecisionSettings() + # load settings from preferences + settings = PrecisionSettings( + ( + _parse( + T, + Preferences.load_preference( + PosteriorStats, string(k), getproperty(default_settings, k) + ), + ) for + (k, T) in zip(fieldnames(PrecisionSettings), fieldtypes(PrecisionSettings)) + )..., + ) + # validate settings + for (k, T) in zip(fieldnames(PrecisionSettings), fieldtypes(PrecisionSettings)) + if T <: Int + v = getfield(settings, k) + v ≥ 0 || throw(DomainError(v, "Setting `$k` must be non-negative")) + end + end + return settings +end + +_parse(::Type{T}, v::T) where {T} = v +_parse(::Type{T}, v) where {T} = parse(T, v) diff --git a/src/show_prettytable.jl b/src/show_prettytable.jl index e8b358b..5a8e9a3 100644 --- a/src/show_prettytable.jl +++ b/src/show_prettytable.jl @@ -48,7 +48,9 @@ function _prettytables_integer_formatter(data) end end -function _prettytables_se_formatters(data; sigdigits_se=2) +function _prettytables_se_formatters( + data; sigdigits_se=default_precision_settings().show_sigdigits_se +) col_names = Tables.columnnames(data) pattern = r"^(?:mcse_|se_)(.*)|^(.*?)(?:_mcse|_se)$" formatters = Function[] @@ -75,18 +77,30 @@ function _prettytables_ess_formatter(data) return PrettyTables.ft_printf("%d", cols) end -function _prettytables_rhat_formatter(data) +function _prettytables_rhat_formatter( + data, sigdigits_rhat::Int=default_precision_settings().show_sigdigits_rhat +) col_names = Tables.columnnames(data) cols = findall(x -> (x === :rhat || startswith(string(x), "rhat_")), col_names) isempty(cols) && return nothing - return PrettyTables.ft_printf("%.2f", cols) + return PrettyTables.ft_printf("%.$(sigdigits_rhat)f", cols) end -function _default_prettytables_formatters(data; sigdigits_se=2, sigdigits_default=3) +function _default_prettytables_formatters( + data; + show_printf=default_precision_settings().show_printf, + sigdigits_se=default_precision_settings().show_sigdigits_se, + sigdigits_default=default_precision_settings().show_sigdigits_default, + show_sigdigits_using_se=default_precision_settings().show_sigdigits_using_se, + show_ess_int=default_precision_settings().show_ess_int, +) + isempty(show_printf) || return [PrettyTables.ft_printf(show_printf)] formatters = Union{Function,Nothing}[] push!(formatters, _prettytables_integer_formatter(data)) - append!(formatters, _prettytables_se_formatters(data; sigdigits_se)) - push!(formatters, _prettytables_ess_formatter(data)) + if show_sigdigits_using_se + append!(formatters, _prettytables_se_formatters(data; sigdigits_se)) + end + show_ess_int && push!(formatters, _prettytables_ess_formatter(data)) push!(formatters, ft_printf_sigdigits(sigdigits_default)) push!(formatters, ft_printf_sigdigits_interval(sigdigits_default)) return filter(!isnothing, formatters) @@ -119,8 +133,10 @@ end function _show_prettytable( io::IO, data; - sigdigits_se=2, - sigdigits_default=3, + sigdigits_se=default_precision_settings().show_sigdigits_se, + sigdigits_default=default_precision_settings().show_sigdigits_default, + show_sigdigits_using_se=default_precision_settings().show_sigdigits_using_se, + show_ess_int=default_precision_settings().show_ess_int, extra_formatters=(), alignment=_text_alignment(data), show_subheader=false, @@ -131,7 +147,9 @@ function _show_prettytable( ) formatters = ( extra_formatters..., - _default_prettytables_formatters(data; sigdigits_se, sigdigits_default)..., + _default_prettytables_formatters( + data; sigdigits_se, sigdigits_default, show_sigdigits_using_se, show_ess_int + )..., ) PrettyTables.pretty_table( io, diff --git a/src/summarize.jl b/src/summarize.jl index 5775cd1..f35f5f0 100644 --- a/src/summarize.jl +++ b/src/summarize.jl @@ -118,8 +118,12 @@ end function _show(io::IO, mime::MIME, stats::SummaryStats; kwargs...) nt = parent(stats) data = nt[keys(nt)[2:end]] - rhat_formatter = _prettytables_rhat_formatter(data) - extra_formatters = rhat_formatter === nothing ? () : (rhat_formatter,) + if isempty(default_precision_settings().show_printf) + rhat_formatter = _prettytables_rhat_formatter(data) + extra_formatters = rhat_formatter === nothing ? () : (rhat_formatter,) + else + extra_formatters = () + end return _show_prettytable( io, mime, From 0e8d2efc3ad0650897f43f5381f32a6d07405066 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 22 Sep 2025 20:57:21 +0200 Subject: [PATCH 06/10] Default to Float32 eltype for prob --- src/preferences.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/preferences.jl b/src/preferences.jl index 06ec1dd..cc8979e 100644 --- a/src/preferences.jl +++ b/src/preferences.jl @@ -5,7 +5,7 @@ function default_ci_fun() return ci_kind == :eti ? eti : hdi end -function default_ci_prob((::Type{T})=Float64) where {T<:Real} +function default_ci_prob(::Type{T}=Float32) where {T<:Real} prob = T(Preferences.load_preference(PosteriorStats, "ci_prob", 0.94)) 0 < prob < 1 || throw(DomainError(prob, "ci_prob must be in the range (0, 1).")) return prob From 80eea866943793739371cb16152468a522e114d6 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 22 Sep 2025 20:59:32 +0200 Subject: [PATCH 07/10] Use CI defaults --- src/r2_score.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/r2_score.jl b/src/r2_score.jl index 69800fb..d3c40a8 100644 --- a/src/r2_score.jl +++ b/src/r2_score.jl @@ -28,7 +28,7 @@ credible interval (CI). + [`StatsBase.mode`](@extref) - `ci_fun=eti`: The function used to compute the credible interval if `summary` is `true`. Supported options are [`eti`](@ref) and [`hdi`](@ref). - - `ci_prob=$(DEFAULT_CI_PROB)`: The probability mass to be contained in the credible + - `ci_prob=$(default_ci_prob())`: The probability mass to be contained in the credible interval. # Examples @@ -55,8 +55,8 @@ function r2_score( y_pred; summary=true, point_estimate=Statistics.mean, - ci_fun=eti, - ci_prob=DEFAULT_CI_PROB, + ci_fun=default_ci_fun(), + ci_prob=default_ci_prob(float(Base.promote_eltype(y_true, y_pred))), ) r_squared = _r2_samples(y_true, y_pred) summary || return r_squared From d2380585cf5c4afec7c6a1ce03b28ea9eb013380 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 22 Sep 2025 21:05:06 +0200 Subject: [PATCH 08/10] Add preference for point estimate --- LocalPreferences.toml | 3 +++ src/preferences.jl | 21 ++++++++++++++++++++- src/r2_score.jl | 8 ++++---- 3 files changed, 27 insertions(+), 5 deletions(-) diff --git a/LocalPreferences.toml b/LocalPreferences.toml index 5ce5064..f253412 100644 --- a/LocalPreferences.toml +++ b/LocalPreferences.toml @@ -5,6 +5,9 @@ ## the probability mass to be contained in the credible interval (options: any number in (0, 1)) # ci_prob = 0.94 +## the default point estimate to use for summarizing distributions (options: "mean", "median", "mode") +# point_estimate = "mean" + ## which method to use to compare models (options: "Stacking", "BootstrappedPseudoBMA", "PseudoBMA") # weights_method = "Stacking" diff --git a/src/preferences.jl b/src/preferences.jl index cc8979e..d21f2b5 100644 --- a/src/preferences.jl +++ b/src/preferences.jl @@ -5,12 +5,31 @@ function default_ci_fun() return ci_kind == :eti ? eti : hdi end -function default_ci_prob(::Type{T}=Float32) where {T<:Real} +function default_ci_prob((::Type{T})=Float32) where {T<:Real} prob = T(Preferences.load_preference(PosteriorStats, "ci_prob", 0.94)) 0 < prob < 1 || throw(DomainError(prob, "ci_prob must be in the range (0, 1).")) return prob end +function default_point_estimate() + point_estimate = Symbol( + Preferences.load_preference(PosteriorStats, "point_estimate", "mean") + ) + if point_estimate === :mean + return Statistics.mean + elseif point_estimate === :median + return Statistics.median + elseif point_estimate === :mode + return StatsBase.mode + else + throw( + ArgumentError( + "Invalid point_estimate: $point_estimate. Must be one of (mean, median, mode).", + ), + ) + end +end + function default_weights_method() method = Preferences.load_preference(PosteriorStats, "weights_method", "Stacking") method == "Stacking" && return Stacking diff --git a/src/r2_score.jl b/src/r2_score.jl index d3c40a8..d30f25e 100644 --- a/src/r2_score.jl +++ b/src/r2_score.jl @@ -21,9 +21,9 @@ credible interval (CI). - `summary::Bool=true`: Whether to return a summary or an array of ``R²`` scores. The summary is a named tuple with the point estimate `:r2` and the credible interval `:`. - - `point_estimate=Statistics.mean`: The function used to compute the point estimate of the - ``R²`` scores if `summary` is `true`. Supported options are: - + [`Statistics.mean`](@extref) (default) + - `point_estimate=$(default_point_estimate())`: The function used to compute the point + estimate of the ``R²`` scores if `summary` is `true`. Supported options are: + + [`Statistics.mean`](@extref) + [`Statistics.median`](@extref) + [`StatsBase.mode`](@extref) - `ci_fun=eti`: The function used to compute the credible interval if `summary` is @@ -54,7 +54,7 @@ function r2_score( y_true, y_pred; summary=true, - point_estimate=Statistics.mean, + point_estimate=default_point_estimate(), ci_fun=default_ci_fun(), ci_prob=default_ci_prob(float(Base.promote_eltype(y_true, y_pred))), ) From 3a93216354f13c8a6304bb56259f3751ebeca9d1 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 22 Sep 2025 21:05:34 +0200 Subject: [PATCH 09/10] Update preference for ci_prob --- LocalPreferences.toml | 2 +- src/preferences.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/LocalPreferences.toml b/LocalPreferences.toml index f253412..07812dc 100644 --- a/LocalPreferences.toml +++ b/LocalPreferences.toml @@ -3,7 +3,7 @@ # ci_kind = "eti" ## the probability mass to be contained in the credible interval (options: any number in (0, 1)) -# ci_prob = 0.94 +# ci_prob = 0.89 ## the default point estimate to use for summarizing distributions (options: "mean", "median", "mode") # point_estimate = "mean" diff --git a/src/preferences.jl b/src/preferences.jl index d21f2b5..472e9b5 100644 --- a/src/preferences.jl +++ b/src/preferences.jl @@ -6,7 +6,7 @@ function default_ci_fun() end function default_ci_prob((::Type{T})=Float32) where {T<:Real} - prob = T(Preferences.load_preference(PosteriorStats, "ci_prob", 0.94)) + prob = T(Preferences.load_preference(PosteriorStats, "ci_prob", 0.89)) 0 < prob < 1 || throw(DomainError(prob, "ci_prob must be in the range (0, 1).")) return prob end From 4c9e681d7d0e5b45d4e64f54861d1600afca426e Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 22 Sep 2025 21:22:10 +0200 Subject: [PATCH 10/10] Update tests --- test/r2_score.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/test/r2_score.jl b/test/r2_score.jl index f9d069e..50f1426 100644 --- a/test/r2_score.jl +++ b/test/r2_score.jl @@ -17,11 +17,13 @@ using Test x_reshape = length(sz) == 1 ? x' : reshape(x, 1, 1, :) y_pred = slope .* x_reshape .+ intercept .+ randn(T, sz..., n) .* σ - r2_val = @inferred r2_score(y, y_pred; ci_prob=PosteriorStats.DEFAULT_CI_PROB) + r2_val = @inferred r2_score( + y, y_pred; ci_prob=PosteriorStats.default_ci_prob(T) + ) @test r2_val isa @NamedTuple{r2::T, eti::ClosedInterval{T}} r2_draws = @inferred PosteriorStats._r2_samples(y, y_pred) @test r2_val.r2 == mean(r2_draws) - @test r2_val.eti == eti(r2_draws; prob=PosteriorStats.DEFAULT_CI_PROB) + @test r2_val.eti == eti(r2_draws; prob=PosteriorStats.default_ci_prob(T)) @test r2_val == r2_score(y, y_pred) r2_val2 = r2_score(