Compare f41cf7e ... +1 ... 4914774

Coverage Reach
tree-helpers.R mbar-prep-data.R utils.R

No flags found

Use flags to group coverage reports by test type, project and/or folders.
Then setup custom commit statuses and notifications for each flag.

e.g., #unittest #integration

#production #enterprise

#frontend #backend

Learn more about Codecov Flags here.


@@ -0,0 +1,157 @@
Loading
1 +
#' Optimal complexity parameter
2 +
#'
3 +
#' Extract optimal complexity parameter for tree pruning.
4 +
#'
5 +
#' @param model An object of class \code{rpart}.
6 +
#'
7 +
#' @examples
8 +
#' model <- rpart::rpart(Species ~ ., data = iris)
9 +
#' optimal_cp(model)
10 +
#'
11 +
#' @export
12 +
#'
13 +
optimal_cp <- function(model) {
14 +
15 +
  check_rpart(model)
16 +
17 +
  cp_table <-
18 +
    model %>%
19 +
    magrittr::use_series(cptable) %>%
20 +
    as.data.frame()
21 +
22 +
  xerror_min_index <-
23 +
    cp_table %>%
24 +
    dplyr::pull(xerror) %>%
25 +
    which.min()
26 +
27 +
  cp_table %>%
28 +
    dplyr::pull(CP) %>%
29 +
    magrittr::extract(xerror_min_index)
30 +
}
31 +
32 +
#' Class probability prediction
33 +
#'
34 +
#' Predict class probability on a data set.
35 +
#'
36 +
#' @param model An object of class \code{rpart}.
37 +
#' @param data A \code{data.frame} or \code{tibble}.
38 +
#' @param response Response variable.
39 +
#'
40 +
#' @examples
41 +
#' model <- rpart::rpart(Attrition ~ ., data = hr_train)
42 +
#' tree_prediction(model, hr_test, Attrition)
43 +
#'
44 +
#' @importFrom rlang !!
45 +
#'
46 +
#' @export
47 +
#'
48 +
tree_prediction <- function(model, data, response) {
49 +
50 +
  resp_var <- rlang::enquo(response)
51 +
  resp <- dplyr::pull(data, !!resp_var)
52 +
53 +
  model %>%
54 +
    stats::predict(newdata = data, type = "prob") %>%
55 +
    as.data.frame() %>%
56 +
    dplyr::select(2) %>%
57 +
    ROCR::prediction(resp)
58 +
}
59 +
60 +
#' AUC
61 +
#'
62 +
#' Area under the curve.
63 +
#'
64 +
#' @param perform An object of class \code{prediction}.
65 +
#'
66 +
#' @examples
67 +
#' model   <- rpart::rpart(Attrition ~ ., data = hr_train)
68 +
#' perform <- tree_prediction(model, hr_test, Attrition)
69 +
#' tree_auc(perform)
70 +
#'
71 +
#' @export
72 +
#'
73 +
tree_auc <- function(perform) {
74 +
  perform %>%
75 +
    ROCR::performance(measure = "auc") %>%
76 +
    methods::slot("y.values") %>%
77 +
    magrittr::extract2(1)
78 +
}
79 +
80 +
#' Plot validation curves
81 +
#'
82 +
#' Plot curves to validate decision tree model.
83 +
#'
84 +
#' @param perform An object of class \code{prediction}.
85 +
#' @param line_color Color of lines in the plot.
86 +
#'
87 +
#' @examples
88 +
#' model   <- rpart::rpart(Attrition ~ ., data = hr_train)
89 +
#' perform <- tree_prediction(model, hr_test, Attrition)
90 +
#' plot_roc(perform)
91 +
#' plot_prec_rec(perform)
92 +
#' plot_sens_spec(perform)
93 +
#' plot_lift_curve(perform)
94 +
#'
95 +
#' @export
96 +
#'
97 +
plot_roc <- function(perform, line_color = "blue") {
98 +
99 +
  plot_perform(perform, "tpr", "fpr") +
100 +
    ggplot2::xlab("False Positive Rate") +
101 +
    ggplot2::ylab("True Positive Rate") +
102 +
    ggplot2::ggtitle("ROC Curve")
103 +
}
104 +
105 +
#' @rdname plot_roc
106 +
#' @export
107 +
#'
108 +
plot_prec_rec <- function(perform, line_color = "blue") {
109 +
110 +
  plot_perform(perform, "prec", "rec") +
111 +
    ggplot2::xlab("Recall") + ggplot2::ylab("Precision") +
112 +
    ggplot2::ggtitle("Precision Recall Curve")
113 +
114 +
}
115 +
116 +
#' @rdname plot_roc
117 +
#' @export
118 +
#'
119 +
plot_sens_spec <- function(perform, line_color = "blue") {
120 +
121 +
  plot_perform(perform, "sens", "spec") +
122 +
    ggplot2::xlab("Specificity") + ggplot2::ylab("Sensitivity") +
123 +
    ggplot2::ggtitle("Sensitivity Specificity Curve")
124 +
125 +
}
126 +
127 +
#' @rdname plot_roc
128 +
#' @export
129 +
#'
130 +
plot_lift_curve <- function(perform, line_color = "blue") {
131 +
132 +
  plot_perform(perform, "lift", "rpp") +
133 +
    ggplot2::xlab("Rate of Positive Predictions") +
134 +
    ggplot2::ylab("Lift Value") +
135 +
    ggplot2::ggtitle("Lift Curve")
136 +
137 +
}
138 +
139 +
plot_perform <- function(perform, y, x, line_color = "blue") {
140 +
141 +
	measures <- ROCR::performance(perform, measure = y, x.measure = x)
142 +
143 +
  yval <-
144 +
    measures %>%
145 +
    methods::slot("y.values") %>%
146 +
    magrittr::extract2(1)
147 +
148 +
  xval <-
149 +
    measures %>%
150 +
    methods::slot("x.values") %>%
151 +
    magrittr::extract2(1)
152 +
153 +
  data.frame(yval, xval) %>%
154 +
    ggplot2::ggplot() +
155 +
    ggplot2::geom_line(ggplot2::aes(x = xval, y = yval), color = line_color)
156 +
157 +
}

@@ -0,0 +1,6 @@
Loading
1 +
check_rpart <- function(model) {
2 +
  model_class <- class(model)
3 +
  if (model_class != "rpart") {
4 +
    stop("model must be an object of class rpart")
5 +
  }
6 +
}

Learn more Showing 2 files with coverage changes found.

New file R/utils.R
New
Loading file...
New file R/tree-helpers.R
New
Loading file...
Files Coverage
R 100.00%
Project Totals (3 files) 100.00%
Loading