Skip to content

estimate_contrast(): accept any "predict" + backends divergence #489

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Apr 27, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Type: Package
Package: modelbased
Title: Estimation of Model-Based Predictions, Contrasts and Means
Version: 0.10.0.31
Version: 0.10.0.32
Authors@R:
c(person(given = "Dominique",
family = "Makowski",
Expand Down
8 changes: 1 addition & 7 deletions R/clean_names.R
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,7 @@

if (length(vars) == 1) {
if (type == "contrast") {
if (minfo$is_logit && predict == "response") {
names(means)[names(means) == vars] <- "Odds_ratio"
} else if (minfo$is_poisson && predict == "response") {
names(means)[names(means) == vars] <- "Ratio"
} else {
names(means)[names(means) == vars] <- "Difference"
}
names(means)[names(means) == vars] <- "Difference"
} else if (type == "mean") {
if (minfo$is_logit && predict == "response") {
names(means)[names(means) == vars] <- "Probability"
Expand Down
6 changes: 5 additions & 1 deletion R/estimate_contrast_methods.R
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,11 @@ estimate_contrasts.estimate_predicted <- function(model,
minfo <- insight::model_info(model, response = 1)

# model df
dof <- insight::get_df(model, type = "wald", verbose = FALSE)
if (minfo$is_bayesian) {
dof <- Inf
} else {
dof <- insight::get_df(model, type = "wald", verbose = FALSE)
}
crit_factor <- (1 + ci) / 2

## TODO: For Bayesian models, we always use the returned standard errors
Expand Down
36 changes: 21 additions & 15 deletions R/get_emcontrasts.R
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,22 @@
# extract first focal term
first_focal <- my_args$contrast[1]

# setup arguments
fun_args <- list(model)

# handle distributional parameters
if (predict %in% .brms_aux_elements(model) && inherits(model, "brmsfit")) {
dpars <- TRUE
fun_args$dpar <- predict
} else {
dpars <- FALSE
fun_args$type <- predict
}

# add dots
dots <- list(...)
fun_args <- insight::compact_list(c(fun_args, dots))

# if first focal term is numeric, we contrast slopes
if (is.numeric(model_data[[first_focal]]) &&
!first_focal %in% on_the_fly_factors &&
Expand All @@ -51,23 +67,13 @@
insight::format_error("Please specify the `by` argument to calculate contrasts of slopes.") # nolint
}
# Run emmeans
estimated <- suppressMessages(emmeans::emtrends(
model,
specs = my_args$by,
var = my_args$contrast,
type = predict,
...
))
fun_args <- c(fun_args, list(specs = my_args$by, var = my_args$contrast))
estimated <- suppressMessages(do.call(emmeans::emtrends, fun_args))
emm_by <- NULL
} else {
# Run emmeans
estimated <- suppressMessages(emmeans::emmeans(
model,
specs = my_args$emmeans_specs,
at = my_args$emmeans_at,
type = predict,
...
))
fun_args <- c(fun_args, list(specs = my_args$emmeans_specs, at = my_args$emmeans_at))
estimated <- suppressMessages(do.call(emmeans::emmeans, fun_args))
# Find by variables
emm_by <- my_args$emmeans_specs[!my_args$emmeans_specs %in% my_args$contrast]
if (length(emm_by) == 0) {
Expand All @@ -76,7 +82,7 @@
}

# If means are on the response scale (e.g., probabilities), need to regrid
if (predict == "response") {
if (predict == "response" || dpars) {
estimated <- emmeans::regrid(estimated)
}

Expand Down Expand Up @@ -142,7 +148,7 @@


.format_emmeans_contrasts <- function(model, estimated, ci, p_adjust, ...) {
predict <- attributes(estimated)$predict

Check warning on line 151 in R/get_emcontrasts.R

View workflow job for this annotation

GitHub Actions / lint-changed-files / lint-changed-files

file=R/get_emcontrasts.R,line=151,col=3,[object_overwrite_linter] 'predict' is an exported object from package 'stats'. Avoid re-using such symbols.
m_info <- insight::model_info(model, response = 1)

# Summarize and clean
Expand Down
7 changes: 7 additions & 0 deletions R/get_emmeans.R
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,10 @@

# handle distributional parameters
if (predict %in% .brms_aux_elements(model) && inherits(model, "brmsfit")) {
dpars <- TRUE
fun_args$dpar <- predict
} else {
dpars <- FALSE
fun_args$type <- predict
}

Expand All @@ -70,6 +72,11 @@
# Run emmeans
estimated <- suppressMessages(suppressWarnings(do.call(emmeans::emmeans, fun_args)))

# backtransform to response scale for dpars
if (dpars) {
estimated <- emmeans::regrid(estimated)
}

# Special behaviour for transformations #138 (see below)
if ("retransform" %in% names(my_args) && length(my_args$retransform) > 0) {
for (var in names(my_args$retransform)) {
Expand Down Expand Up @@ -146,7 +153,7 @@
# Table formatting emmeans ----------------------------------------------------

.format_emmeans_means <- function(x, model, ci = 0.95, verbose = TRUE, ...) {
predict <- attributes(x)$predict

Check warning on line 156 in R/get_emmeans.R

View workflow job for this annotation

GitHub Actions / lint-changed-files / lint-changed-files

file=R/get_emmeans.R,line=156,col=3,[object_overwrite_linter] 'predict' is an exported object from package 'stats'. Avoid re-using such symbols.
m_info <- insight::model_info(model, response = 1)

# Summarize and clean
Expand Down
8 changes: 4 additions & 4 deletions R/get_marginalmeans.R
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@
dots[c("by", "conf_level", "type", "digits", "bias_correction", "sigma", "offset")] <- NULL

# model df - can be passed via `...`
if (is.null(dots$df)) {
if (is.null(dots$df) && !model_info$is_bayesian) {
dots$df <- insight::get_df(model, type = "wald", verbose = FALSE)
}

Expand Down Expand Up @@ -225,7 +225,7 @@
# just need to add "hypothesis" argument
means <- .call_marginaleffects(fun_args)

# Fifth step: post-processin marginal means----------------------------------
# Fifth step: post-processing marginal means----------------------------------
# ---------------------------------------------------------------------------

# filter "by" rows when we have "average" marginalization, because we don't
Expand Down Expand Up @@ -299,13 +299,13 @@

.marginaleffects_errors <- function(out, fun_args) {
# what was requested?
if (!is.null(fun_args$hypothesis)) {

Check warning on line 302 in R/get_marginalmeans.R

View workflow job for this annotation

GitHub Actions / lint-changed-files / lint-changed-files

file=R/get_marginalmeans.R,line=302,col=7,[if_not_else_linter] Prefer `if (A) x else y` to the less-readable `if (!A) y else x` in a simple if/else statement.
fun <- "marginal contrasts"
} else {
fun <- "marginal means"
}
# clean original error message
out$message <- gsub("\\s+", " ", gsub("\n", "", out$message))

Check warning on line 308 in R/get_marginalmeans.R

View workflow job for this annotation

GitHub Actions / lint-changed-files / lint-changed-files

file=R/get_marginalmeans.R,line=308,col=41,[fixed_regex_linter] Use "\n" with fixed = TRUE here. This regular expression is static, i.e., its matches can be expressed as a fixed substring expression, which is faster to compute.
# setup clear error message
msg <- c(
paste0("Sorry, calculating ", fun, " failed with following error:"),
Expand All @@ -317,7 +317,7 @@
msg <- c(msg, "\nIt seems that not all required levels of the focal terms are available in the provided data. If you want predictions extrapolated to a hypothetical target population, try setting `estimate=\"population\".") # nolint
}
# we get this error for models with complex random effects structures in glmmTMB
if (grepl("map factor length must equal", out$message, fixed = TRUE)) {
if (grepl("map factor length must equal", out$message, fixed = TRUE) || grepl("cannot allocate", out$message, fixed = TRUE)) { # nolint
msg <- c(
msg,
paste0(
Expand All @@ -335,7 +335,7 @@
}


# filter datagrid foe `estimate = "average"`---------------------------------
# filter datagrid for `estimate = "average"`---------------------------------

.filter_datagrid_average <- function(means, estimate, datagrid, datagrid_info) {
# filter "by" rows when we have "average" marginalization, because we don't
Expand Down Expand Up @@ -472,22 +472,22 @@
# if no offset argument was specified, tell user what this means
msg <- switch(estimate,
specific = ,
typical = "Model contains an offset-term, which is set to its mean value. If you want to average predictions over the distribution of the offset (if appropriate), use `estimate = \"average\"`. If you want to fix the offset to a specific value, for instance `1`, use `offset = 1`.",

Check warning on line 475 in R/get_marginalmeans.R

View workflow job for this annotation

GitHub Actions / lint-changed-files / lint-changed-files

file=R/get_marginalmeans.R,line=475,col=121,[line_length_linter] Lines should not be more than 120 characters. This line is 289 characters.
"Model contains an offset-term and you average predictions over the distribution of that offset. If you want to fix the offset to a specific value, for instance `1`, use `offset = 1` and set `estimate = \"typical\"`."

Check warning on line 476 in R/get_marginalmeans.R

View workflow job for this annotation

GitHub Actions / lint-changed-files / lint-changed-files

file=R/get_marginalmeans.R,line=476,col=121,[line_length_linter] Lines should not be more than 120 characters. This line is 225 characters.
)
# if offset term is log-transformed, tell user. offset should be fixed then
log_offset <- insight::find_transformation(insight::find_offset(model, as_term = TRUE))
if (!is.null(log_offset) && startsWith(log_offset, "log")) {
msg <- c(
msg,
"We also found that the model has a log-transformed offset term. If you use the `offset` argument, the log-transformation will automatically be applied to the provided offset-value. I.e., consider using, for instance, `offset = 10` and not `offset = log(10)`."

Check warning on line 483 in R/get_marginalmeans.R

View workflow job for this annotation

GitHub Actions / lint-changed-files / lint-changed-files

file=R/get_marginalmeans.R,line=483,col=121,[line_length_linter] Lines should not be more than 120 characters. This line is 270 characters.
)
}
} else {
# if offset was specified, and estimate averages over predictions, tell this
msg <- switch(estimate,
average = ,
population = paste0("For `estimate = \"", estimate, "\"`, predictions are averaged over the distribution of the offset and the `offset` argument is ignored. If you want to fix the offset to a specific value, for instance `1`, use `offset = 1` and set `estimate = \"typical\"`.")

Check warning on line 490 in R/get_marginalmeans.R

View workflow job for this annotation

GitHub Actions / lint-changed-files / lint-changed-files

file=R/get_marginalmeans.R,line=490,col=121,[line_length_linter] Lines should not be more than 120 characters. This line is 286 characters.
)
}
if (!is.null(msg)) {
Expand Down
2 changes: 1 addition & 1 deletion R/get_marginaltrends.R
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@
}

# model df - can be passed via `...`
if (is.null(dots$df)) {
if (is.null(dots$df) && !model_info$is_bayesian) {
dots$df <- insight::get_df(model, type = "wald", verbose = FALSE)
}

Expand Down Expand Up @@ -204,10 +204,10 @@
# check if user provided values in `trend`, e.g. `trend=1:10`. We then pass
# this argument to also create a data grid, but we also need to "clean" trend
if (grepl("=", trend, fixed = TRUE)) {
range <- trend

Check warning on line 207 in R/get_marginaltrends.R

View workflow job for this annotation

GitHub Actions / lint-changed-files / lint-changed-files

file=R/get_marginaltrends.R,line=207,col=5,[object_overwrite_linter] 'range' is an exported object from package 'base'. Avoid re-using such symbols.
trend <- gsub("=.*", "\\1", trend)
} else {
range <- NULL

Check warning on line 210 in R/get_marginaltrends.R

View workflow job for this annotation

GitHub Actions / lint-changed-files / lint-changed-files

file=R/get_marginaltrends.R,line=210,col=5,[object_overwrite_linter] 'range' is an exported object from package 'base'. Avoid re-using such symbols.
}

# make sure range in `trend` is not also in `by`
Expand Down
Loading