Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

POC: double dispatch for ggplot_add() #5537

Draft
wants to merge 13 commits into
base: main
Choose a base branch
from
4 changes: 3 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ Imports:
MASS,
mgcv,
rlang (>= 1.1.0),
S7,
scales (>= 1.2.0),
stats,
tibble,
Expand Down Expand Up @@ -90,6 +91,7 @@ Collate:
'compat-plyr.R'
'utilities.R'
'aes.R'
'all-classes.R'
'utilities-checks.R'
'legend-draw.R'
'geom-.R'
Expand Down Expand Up @@ -196,9 +198,9 @@ Collate:
'margins.R'
'performance.R'
'plot-build.R'
'plot.R'
'plot-construction.R'
'plot-last.R'
'plot.R'
'position-.R'
'position-collide.R'
'position-dodge.R'
Expand Down
15 changes: 1 addition & 14 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -51,20 +51,6 @@ S3method(get_alt_text,ggplot_built)
S3method(get_alt_text,gtable)
S3method(ggplot,"function")
S3method(ggplot,default)
S3method(ggplot_add,"NULL")
S3method(ggplot_add,"function")
S3method(ggplot_add,Coord)
S3method(ggplot_add,Facet)
S3method(ggplot_add,Guides)
S3method(ggplot_add,Layer)
S3method(ggplot_add,Scale)
S3method(ggplot_add,by)
S3method(ggplot_add,data.frame)
S3method(ggplot_add,default)
S3method(ggplot_add,labels)
S3method(ggplot_add,list)
S3method(ggplot_add,theme)
S3method(ggplot_add,uneval)
S3method(ggplot_build,ggplot)
S3method(ggplot_gtable,ggplot_built)
S3method(grid.draw,absoluteGrob)
Expand Down Expand Up @@ -228,6 +214,7 @@ export(PositionJitter)
export(PositionJitterdodge)
export(PositionNudge)
export(PositionStack)
export(S7_ggplot)
export(Scale)
export(ScaleBinned)
export(ScaleBinnedPosition)
Expand Down
11 changes: 11 additions & 0 deletions R/all-classes.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Class declarations for S7 dispatch.
class_theme <- S7::new_S3_class("theme")
class_scale <- S7::new_S3_class("Scale")
class_labels <- S7::new_S3_class("labels")
class_guides <- S7::new_S3_class("Guides")
class_aes <- S7::new_S3_class("uneval")
class_coord <- S7::new_S3_class("Coord")
class_facet <- S7::new_S3_class("Facet")
class_by <- S7::new_S3_class("by")
class_layer <- S7::new_S3_class("Layer")
class_scales_list <- S7::new_S3_class("ScalesList")
210 changes: 116 additions & 94 deletions R/plot-construction.R
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
#' @include plot.R
NULL

#' Add components to a plot
#'
#' `+` is the key to constructing sophisticated ggplot2 graphics. It
Expand Down Expand Up @@ -70,7 +73,7 @@ add_ggplot <- function(p, object, objectname) {
if (is.null(object)) return(p)

p <- plot_clone(p)
p <- ggplot_add(object, p, objectname)
p <- ggplot_add(object, p, object_name = objectname)
set_last_plot(p)
p
}
Expand All @@ -81,105 +84,124 @@ add_ggplot <- function(p, object, objectname) {
#'
#' @param object An object to add to the plot
#' @param plot The ggplot object to add `object` to
#' @param object_name The name of the object to add
#' @param ... Additional arguments to pass to the methods. Typically, an
#' `object_name` argument that gives a display name for `object` to use
#' in error messages.
#'
#' @return A modified ggplot object
#'
#' @keywords internal
#' @export
ggplot_add <- function(object, plot, object_name) {
UseMethod("ggplot_add")
}
#' @export
ggplot_add.default <- function(object, plot, object_name) {
cli::cli_abort("Can't add {.var {object_name}} to a {.cls ggplot} object.")
}
#' @export
ggplot_add.NULL <- function(object, plot, object_name) {
plot
}
#' @export
ggplot_add.data.frame <- function(object, plot, object_name) {
plot$data <- object
plot
}
#' @export
ggplot_add.function <- function(object, plot, object_name) {
cli::cli_abort(c(
"Can't add {.var {object_name}} to a {.cls ggplot} object",
"i" = "Did you forget to add parentheses, as in {.fn {object_name}}?"
))
}
#' @export
ggplot_add.theme <- function(object, plot, object_name) {
plot$theme <- add_theme(plot$theme, object)
plot
}
#' @export
ggplot_add.Scale <- function(object, plot, object_name) {
plot$scales$add(object)
plot
}
#' @export
ggplot_add.labels <- function(object, plot, object_name) {
update_labels(plot, object)
}
#' @export
ggplot_add.Guides <- function(object, plot, object_name) {
update_guides(plot, object)
}
#' @export
ggplot_add.uneval <- function(object, plot, object_name) {
plot$mapping <- defaults(object, plot$mapping)
# defaults() doesn't copy class, so copy it.
class(plot$mapping) <- class(object)

labels <- make_labels(object)
names(labels) <- names(object)
update_labels(plot, labels)
}
#' @export
ggplot_add.Coord <- function(object, plot, object_name) {
if (!isTRUE(plot$coordinates$default)) {
cli::cli_inform("Coordinate system already present. Adding new coordinate system, which will replace the existing one.")
ggplot_add <- S7::new_generic("ggplot_add", c("object", "plot"))

S7::method(ggplot_add, list(S7::class_any, S7_ggplot)) <-
function(object, plot, object_name) {
cli::cli_abort("Can't add {.var {object_name}} to a {.cls ggplot} object.")
}

plot$coordinates <- object
plot
}
#' @export
ggplot_add.Facet <- function(object, plot, object_name) {
plot$facet <- object
plot
}
#' @export
ggplot_add.list <- function(object, plot, object_name) {
for (o in object) {
plot <- plot %+% o
# Cannot currently double dispatch on NULL directly
# replace `S7::new_S3_class("NULL")` with `NULL` when S7 version > 0.1.1
S7::method(ggplot_add, list(S7::new_S3_class("NULL"), S7_ggplot)) <-
function(object, plot, object_name) {
plot
}
plot
}
#' @export
ggplot_add.by <- function(object, plot, object_name) {
ggplot_add.list(object, plot, object_name)
}

#' @export
ggplot_add.Layer <- function(object, plot, object_name) {
plot$layers <- append(plot$layers, object)

# Add any new labels
mapping <- make_labels(object$mapping)
default <- lapply(make_labels(object$stat$default_aes), function(l) {
attr(l, "fallback") <- TRUE
l
})
new_labels <- defaults(mapping, default)
current_labels <- plot$labels
current_fallbacks <- vapply(current_labels, function(l) isTRUE(attr(l, "fallback")), logical(1))
plot$labels <- defaults(current_labels[!current_fallbacks], new_labels)
if (any(current_fallbacks)) {
plot$labels <- defaults(plot$labels, current_labels)
}
plot
}
S7::method(ggplot_add, list(S7::class_data.frame, S7_ggplot)) <-
function(object, plot, object_name) {
plot$data <- object
plot
}

S7::method(ggplot_add, list(S7::class_function, S7_ggplot)) <-
function(object, plot, object_name) {
cli::cli_abort(c(
"Can't add {.var {object_name}} to a {.cls ggplot} object",
"i" = "Did you forget to add parentheses, as in {.fn {object_name}}?"
))
}

S7::method(ggplot_add, list(class_theme, S7_ggplot)) <-
function(object, plot, object_name) {
plot$theme <- add_theme(plot$theme, object)
plot
}

S7::method(ggplot_add, list(class_scale, S7_ggplot)) <-
function(object, plot, object_name) {
plot$scales$add(object)
plot
}

S7::method(ggplot_add, list(class_labels, S7_ggplot)) <-
function(object, plot, object_name) {
update_labels(plot, object)
}

S7::method(ggplot_add, list(class_guides, S7_ggplot)) <-
function(object, plot, object_name) {
update_guides(plot, object)
}

S7::method(ggplot_add, list(class_aes, S7_ggplot)) <-
function(object, plot, object_name) {
mapping <- defaults(object, plot$mapping)
# defaults() doesn't copy class, so copy it.
class(mapping) <- class(object)
S7::prop(plot, "mapping") <- mapping


labels <- make_labels(object)
names(labels) <- names(object)
update_labels(plot, labels)
}

S7::method(ggplot_add, list(class_coord, S7_ggplot)) <-
function(object, plot, object_name) {
if (!isTRUE(plot$coordinates$default)) {
cli::cli_inform("Coordinate system already present. Adding new coordinate system, which will replace the existing one.")
}

plot$coordinates <- object
plot
}

S7::method(ggplot_add, list(class_facet, S7_ggplot)) <-
function(object, plot, object_name) {
plot$facet <- object
plot
}

S7::method(ggplot_add, list(S7::class_list, S7_ggplot)) <-
function(object, plot, object_name) {
for (o in object) {
plot <- plot %+% o
}
plot
}

S7::method(ggplot_add, list(class_by, S7_ggplot)) <-
function(object, plot, object_name) {
S7::method(ggplot_add, list(class_list, ggplot))(
object, plot, object_name
)
}

S7::method(ggplot_add, list(class_layer, S7_ggplot)) <-
function(object, plot, object_name) {
plot$layers <- append(plot$layers, object)

# Add any new labels
mapping <- make_labels(object$mapping)
default <- lapply(make_labels(object$stat$default_aes), function(l) {
attr(l, "fallback") <- TRUE
l
})
new_labels <- defaults(mapping, default)
current_labels <- plot$labels
current_fallbacks <- vapply(current_labels, function(l) isTRUE(attr(l, "fallback")), logical(1))
plot$labels <- defaults(current_labels[!current_fallbacks], new_labels)
if (any(current_fallbacks)) {
plot$labels <- defaults(plot$labels, current_labels)
}
plot
}
Loading