Skip to content
This repository was archived by the owner on Jun 29, 2019. It is now read-only.

Commit 7c93f90

Browse files
Robert M. Horton, PhD and deguhath
Robert M. Horton, PhD
authored and committed
Create learning_curve_lib.R (#33)
1 parent 5ca1397 commit 7c93f90

File tree

1 file changed

+224
-0
lines changed

1 file changed

+224
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,224 @@
# Use random number seed to select the rows to be used for training or testing.
# Collect error stats for training set from the model when possible.

# execObjects = c("data_table", "SALT")
# Train a model of class `model_class` on a random fraction of the non-test-fold
# rows of the global `data_table`, then score it on both the training rows and
# the held-out k-fold test rows.
#
# Args:
#   model_class:       name of a learner function (e.g. "rxLinMod"); resolved
#                      with get() and called as `learner`.
#   training_fraction: fraction of the available (non-test-fold) rows to use
#                      for training.
#   with_formula:      model formula, given as a string.
#   test_set_kfold_id: which of the KFOLDS folds is held out as the test set.
#   KFOLDS:            number of cross-validation folds (default 3).
#   ...:               forwarded to the learner and echoed into the returned
#                      data frame.
#
# Returns: a one-row data frame with the training set size, training and test
# error (RMSE for regression models, approximate AUC for classifiers), and
# timing columns for training and both error computations.
#
# NOTE(review): relies on `data_table` and `SALT` existing in the evaluation
# environment (see the execObjects comment above) — confirm the caller /
# compute context supplies them.
run_training_fraction <- function(model_class, training_fraction,
                                  with_formula, test_set_kfold_id, KFOLDS=3, ...){
  learner <- get(model_class)

  NUM_BUCKETS <- 1000 # resolution of the score histogram used for approximate AUC

  # Tag each row of a chunk with in_training_set / in_test_set flags.
  # Seeding with chunk number + salt makes the assignment deterministic per
  # chunk, so the same rows land in the same fold on every pass over the data.
  row_tagger <- function(data_list, start_row, num_rows,
                         chunk_num, prob, kfolds, kfold_id, salt){
    rowNums <- seq(from=start_row, length.out=num_rows)
    set.seed(chunk_num + salt)
    kfold <- sample(1:kfolds, size=num_rows, replace=TRUE)
    in_test_set <- kfold == kfold_id
    candidates <- rowNums[!in_test_set]
    # BUG FIX: sample(x, n) treats a length-1 numeric x as 1:x, so a chunk with
    # exactly one training candidate would sample from the wrong range.
    # Sampling positions with sample.int() avoids that footgun; floor() keeps
    # sample()'s original truncating semantics for a non-integer size.
    n_keep <- floor(prob * length(candidates))
    keepers <- candidates[sample.int(length(candidates), size=n_keep)]
    data_list$in_training_set <- rowNums %in% keepers
    data_list$in_test_set <- in_test_set
    data_list
  }

  # Transform used at training time: just tags rows; the learner's
  # rowSelection does the actual filtering.
  row_selection_transform <- function(data_list){
    row_tagger(data_list, .rxStartRow, .rxNumRows, .rxChunkNum,
               prob, kfolds, kfold_id, salt)
  }

  # Calculate RMSE (root mean squared error) for predictions made with a given
  # model on a dataset. Only rows selected by SET_SELECTOR are counted.
  # SSE and rowCount are accumulated across chunks via .rxSet/.rxGet.
  # NOTE(review): cross-chunk accumulation assumes chunks are processed
  # sequentially in a single process — confirm for the compute context in use.
  RMSE_transform <- function(data_list){
    if (.rxChunkNum == 1){
      .rxSet("SSE", 0)
      .rxSet("rowCount", 0)
    }
    SSE <- .rxGet("SSE")
    rowCount <- .rxGet("rowCount")

    data_list <- row_tagger(data_list, .rxStartRow, .rxNumRows, .rxChunkNum,
                            prob, kfolds, kfold_id, salt)

    # rxPredict returns a dataframe if you give it one.
    if (inherits(model, "SDCAR")){
      # rxFastLinear regression models go through the manual-residual path;
      # compute residuals directly from predictions and the outcome column.
      test_chunk <- as.data.frame(data_list)[data_list[[SET_SELECTOR]],]
      outcome_var <- model$params$formulaVars[1]
      residual <- rxPredict(model, test_chunk)[[1]] - test_chunk[[outcome_var]]
    } else {
      residual <- rxPredict(model, as.data.frame(data_list)[data_list[[SET_SELECTOR]],],
                            computeResiduals=TRUE, residVarNames="residual")$residual
    }

    SSE <- SSE + sum(residual^2, na.rm=TRUE)
    rowCount <- rowCount + sum(!is.na(residual))
    .rxSet("SSE", SSE)
    .rxSet("rowCount", rowCount)
    return(data_list)
  }

  # Accumulate per-bucket true/false positive counts for approximate AUC.
  # Scores are bucketed on a fixed grid of NUM_BUCKETS probability intervals;
  # calculate_AUC() turns the accumulated counts into TPR/FPR curves.
  AUC_transform <- function(data_list){
    if (.rxChunkNum == 1){
      # scores must be in range of probabilities (between 0 and 1)
      .rxSet("BREAKS", (0:NUM_BUCKETS)/NUM_BUCKETS)
      .rxSet("TP", numeric(NUM_BUCKETS))
      .rxSet("FP", numeric(NUM_BUCKETS))
    }
    TP <- .rxGet("TP")
    FP <- .rxGet("FP")
    BREAKS <- .rxGet("BREAKS")

    data_list <- row_tagger(data_list, .rxStartRow, .rxNumRows, .rxChunkNum,
                            prob, kfolds, kfold_id, salt)

    data_set <- as.data.frame(data_list)[data_list[[SET_SELECTOR]],]
    # NOTE(review): `model$param` (singular) differs from the `model$params`
    # used in RMSE_transform — confirm this is correct for the model classes
    # that reach this path.
    labels <- data_set[[model$param$formulaVars$original$depVars]]
    scores <- rxPredict(model, data_set)[[1]] # rxPredict returns a dataframe if you give it one.
    bucket <- cut(scores, breaks=BREAKS, include.lowest=TRUE)

    chunk_TP <- rev(as.vector(xtabs(labels ~ bucket))) # positive cases in each bucket, top scores first
    chunk_N <- rev(as.vector(xtabs( ~ bucket)))        # total cases in each bucket

    # BUG FIX: accumulate counts across chunks (mirroring the SSE/rowCount
    # accumulation in RMSE_transform). The original read the stored totals
    # into unused TPR/FPR variables and then overwrote them with the current
    # chunk's counts, so only the final chunk contributed to the AUC.
    .rxSet("TP", TP + chunk_TP)
    .rxSet("FP", FP + (chunk_N - chunk_TP))
    return(data_list)
  }

  # Trapezoidal AUC from cumulative TPR/FPR curves.
  simple_auc <- function(TPR, FPR){
    dFPR <- c(0, diff(FPR))
    sum(TPR * dFPR) - sum(diff(TPR) * diff(FPR))/2
  }

  # Run RMSE_transform over `xdfdata` and return the overall RMSE for the
  # rows selected by `set_selector` ("in_training_set" or "in_test_set").
  calculate_RMSE <- function(with_model, xdfdata, set_selector){
    xformObjs <- rxDataStep(inData=xdfdata,
                            transformFunc=RMSE_transform,
                            transformVars=rxGetVarNames(xdfdata),
                            transformObjects=list(SSE=0, rowCount=0, SET_SELECTOR=set_selector,
                                                  model=with_model, row_tagger=row_tagger,
                                                  prob=training_fraction, kfolds=KFOLDS,
                                                  kfold_id=test_set_kfold_id,
                                                  salt=SALT),
                            returnTransformObjects=TRUE)
    with(xformObjs, sqrt(SSE/rowCount))
  }

  # Run AUC_transform over `xdfdata` and return the approximate AUC for the
  # rows selected by `set_selector`.
  calculate_AUC <- function(with_model, xdfdata, set_selector){
    xformObjs <- rxDataStep(inData=xdfdata,
                            transformFunc=AUC_transform,
                            transformVars=rxGetVarNames(xdfdata),
                            transformObjects=list(TP=numeric(NUM_BUCKETS), FP=numeric(NUM_BUCKETS),
                                                  SET_SELECTOR=set_selector,
                                                  model=with_model, row_tagger=row_tagger,
                                                  prob=training_fraction, kfolds=KFOLDS,
                                                  kfold_id=test_set_kfold_id,
                                                  salt=SALT),
                            returnTransformObjects=TRUE)
    with(xformObjs, {
      TPR <- cumsum(TP)/sum(TP)
      FPR <- cumsum(FP)/sum(FP)
      simple_auc(TPR, FPR)
    })
  }

  # Dispatch on model class to compute the appropriate error metric (RMSE for
  # regression, approximate AUC for classification) over the rows picked out
  # by `set_selector`. Shared by get_training_error() and get_test_error(),
  # which previously duplicated this whole switch.
  get_model_error <- function(fit, set_selector) {
    switch( class(fit)[[1]],
      rxLinMod = calculate_RMSE(fit, data_table, set_selector),
      rxBTrees =,
      rxDForest = if (!is.null(fit$type) && "anova" == fit$type){
          calculate_RMSE(fit, data_table, set_selector)
        } else { # fit$type == "class"
          calculate_AUC(fit, data_table, set_selector)
        },
      rxDTree = if ("anova" == fit$method){
          calculate_RMSE(fit, data_table, set_selector)
        } else { # "class"
          calculate_AUC(fit, data_table, set_selector)
        },
      rxLogit = calculate_AUC(fit, data_table, set_selector),
      SDCA = calculate_AUC(fit, data_table, set_selector),
      #rxFastLinear, class = SDCA (BinaryClassifierTrainer)
      SDCAR = calculate_RMSE(fit, data_table, set_selector)
      # rxFastLinear, class = SDCAR (RegressorTrainer)
    )
  }

  # Training error: rxLinMod exposes its training residuals in its summary,
  # so read them directly instead of re-scoring the training rows; every
  # other class is recomputed over the rows tagged in_training_set.
  get_training_error <- function(fit) {
    if (class(fit)[[1]] == "rxLinMod"){
      with(summary(fit)[[1]], sqrt(residual.squares/nValidObs))
    } else {
      get_model_error(fit, "in_training_set")
    }
  }

  # Test error: always recomputed over the rows tagged in_test_set.
  get_test_error <- function(fit) {
    get_model_error(fit, "in_test_set")
  }

  # Effective training set size: some classes report the number of valid
  # observations on the fit object; for the rest, estimate it as the expected
  # number of sampled training rows.
  get_tss <- function(fit){
    switch( class(fit)[[1]],
      rxLinMod = ,
      rxLogit = fit$nValidObs,
      rxDTree = fit$valid.obs,
      rxBTrees =,
      rxDForest =,
      SDCA =,
      SDCAR = training_fraction * (1 - 1/KFOLDS) * rxGetInfo(data_table)$numRows
    )
  }

  train_time <- system.time(
    fit <- learner(as.formula(with_formula), data_table,
                   rowSelection=(in_training_set == TRUE),
                   transformFunc=row_selection_transform,
                   transformObjects=list(row_tagger=row_tagger, prob=training_fraction,
                                         kfold_id=test_set_kfold_id, kfolds=KFOLDS,
                                         salt=SALT),
                   ...)
  )[['elapsed']]

  e1_time <- system.time(
    training_error <- get_training_error(fit)
  )[['elapsed']]

  e2_time <- system.time(
    test_error <- get_test_error(fit)
  )[['elapsed']]

  data.frame(tss=get_tss(fit), model_class=model_class, training=training_error, test=test_error,
             train_time=train_time, train_error_time=e1_time, test_error_time=e2_time,
             formula=with_formula, kfold=test_set_kfold_id, ...)
}
208+
209+
210+
# Build a model formula string of the form "outcome ~ v1 + v2 + ...".
#
# outcome:         name of the dependent variable.
# varnames:        character vector of candidate variable names; `outcome` is
#                  removed from this set before building the right-hand side.
# interaction_pow: when > 1, wrap the right-hand side as "(...)^p" so the
#                  formula expands to interactions up to that order.
create_formula <- function(outcome, varnames, interaction_pow=1){
  predictors <- setdiff(varnames, outcome)
  rhs <- paste(predictors, collapse=" + ")
  if (interaction_pow > 1){
    rhs <- sprintf("(%s)^%d", rhs, interaction_pow)
  }
  sprintf("%s ~ %s", outcome, rhs)
}
215+
216+
#' get_training_set_fractions
#' Create a vector of fractions of available training data to be used at the evaluation
#' points of a learning curve. The fractions are log-spaced from min_tss/max_tss up to 1,
#' so the resulting training set sizes are evenly spaced on a logarithmic axis.
#' @param min_tss target minimum training set size.
#' @param max_tss approximate maximum training set size. This is used to calculate the
#' fraction used for the smallest point.
#' @param num_tss number of training set sizes.
get_training_set_fractions <- function(min_tss, max_tss, num_tss){
  exp(seq(log(min_tss/max_tss), log(1), length.out=num_tss))
}

0 commit comments

Comments (0)