#' V-Fold Cross-Validation
#'
#' V-fold cross-validation randomly splits the data into V groups of roughly
#'  equal size (called "folds"). A resample of the analysis data consists of
#'  V-1 of the folds while the assessment set contains the final fold. In basic
#'  V-fold cross-validation (i.e. no repeats), the number of resamples is equal
#'  to V.
#' @details
#' The `strata` argument causes the random sampling to be conducted *within
#'  the stratification variable*. This can help ensure that the proportions of
#'  the stratification variable in the analysis data are similar to those in
#'  the original data set. (Strata below 10% of the total are pooled together.)
#' When more than one repeat is requested, the basic V-fold cross-validation
#'  is conducted each time. For example, if three repeats are used with `v =
#'  10`, there are a total of 30 splits: three groups of 10 that are
#'  generated separately.
#' @param data A data frame.
#' @param v The number of partitions of the data set.
#' @param repeats The number of times to repeat the V-fold partitioning. This
#'  must be a single positive integer.
#' @param strata A variable that is used to conduct stratified sampling to
#'  create the folds. This could be a single character value or a variable name
#'  that corresponds to a variable that exists in the data frame.
#' @param breaks A single number giving the number of bins desired to stratify
#'  a numeric stratification variable. The same value is used for every repeat.
#' @param ... Not currently used.
#' @return A tibble with classes `vfold_cv`, `rset`, `tbl_df`, `tbl`, and
#'  `data.frame`. The results include a column for the data split objects and
#'  one or more identification variables. For a single repeat, there will be
#'  one column called `id` that has a character string with the fold identifier.
#'  For repeats, `id` is the repeat number and an additional column called `id2`
#'  that contains the fold information (within repeat).
#' @examples
#' vfold_cv(mtcars, v = 10)
#' vfold_cv(mtcars, v = 10, repeats = 2)
#'
#' library(purrr)
#' data(wa_churn, package = "modeldata")
#'
#' set.seed(13)
#' folds1 <- vfold_cv(wa_churn, v = 5)
#' map_dbl(folds1$splits,
#'         function(x) {
#'           dat <- as.data.frame(x)$churn
#'           mean(dat == "Yes")
#'         })
#'
#' set.seed(13)
#' folds2 <- vfold_cv(wa_churn, strata = "churn", v = 5)
#' map_dbl(folds2$splits,
#'         function(x) {
#'           dat <- as.data.frame(x)$churn
#'           mean(dat == "Yes")
#'         })
#'
#' set.seed(13)
#' folds3 <- vfold_cv(wa_churn, strata = "tenure", breaks = 6, v = 5)
#' map_dbl(folds3$splits,
#'         function(x) {
#'           dat <- as.data.frame(x)$churn
#'           mean(dat == "Yes")
#'         })
#' @export
vfold_cv <- function(data, v = 10, repeats = 1, strata = NULL, breaks = 4, ...) {
  # Validate `repeats` up front: a vector, NA, zero or fractional value would
  # otherwise reach `if (repeats == 1)` or the repeat loop and fail obscurely.
  if (!is.numeric(repeats) || length(repeats) != 1 || is.na(repeats) ||
      repeats < 1 || repeats != round(repeats)) {
    stop("`repeats` must be a single positive integer.", call. = FALSE)
  }

  if (!missing(strata)) {
    # Resolve the supplied bare name or character value against the columns
    # of `data`; an empty selection means "no stratification".
    strata <- tidyselect::vars_select(names(data), !!enquo(strata))
    if (length(strata) == 0) {
      strata <- NULL
    }
  }

  strata_check(strata, names(data))

  if (repeats == 1) {
    split_objs <- vfold_splits(data = data, v = v, strata = strata,
                               breaks = breaks)
  } else {
    repeat_ids <- names0(repeats, "Repeat")
    split_objs <- lapply(seq_len(repeats), function(i) {
      # `breaks` is forwarded here too so a numeric stratification variable
      # is binned with the requested number of bins in every repeat.
      tmp <- vfold_splits(data = data, v = v, strata = strata, breaks = breaks)
      # The per-repeat fold id moves to `id2`; `id` becomes the repeat label.
      tmp$id2 <- tmp$id
      tmp$id <- repeat_ids[i]
      tmp
    })
    # Bind once at the end instead of growing a data frame inside the loop.
    split_objs <- dplyr::bind_rows(split_objs)
  }

  ## We remove the holdout indices since it will save space and we can
  ## derive them later when they are needed.
  split_objs$splits <- map(split_objs$splits, rm_out)

  ## Save some overall information
  cv_att <- list(v = v, repeats = repeats, strata = !is.null(strata))

  new_rset(splits = split_objs$splits,
           ids = split_objs[, grepl("^id", names(split_objs))],
           attrib = cv_att,
           subclass = c("vfold_cv", "rset"))
}

# Get the indices of the analysis set from the assessment set.
#
# `ind` holds the row indices of one held-out fold (the assessment set); the
# analysis set is every other row index in `seq_len(n)`. `seq_len()` is used
# instead of `1:n` so that `n == 0` gives an empty analysis set rather than
# the spurious indices `c(1, 0)`.
# Returns a list with elements `analysis` and `assessment`.
vfold_complement <- function(ind, n) {
  list(analysis = setdiff(seq_len(n), ind),
       assessment = ind)
}

# Split the rows of `data` into `v` folds and wrap each fold as a split object.
#
# data:   a data frame.
# v:      the number of folds; must be a single, non-missing number.
# strata: NULL, or the name of a column of `data` to stratify on. When given,
#         rows are grouped by the value returned by make_strata() (which
#         receives `breaks`) and folds are assigned within each group.
# breaks: forwarded to make_strata() for binning a numeric strata column.
# Returns a tibble with a `splits` list column and a fold `id` column.
vfold_splits <- function(data, v = 10, strata = NULL, breaks = 4) {
  # NA passes is.numeric(); reject it here instead of failing inside rep().
  if (!is.numeric(v) || length(v) != 1 || is.na(v)) {
    stop("`v` must be a single integer.", call. = FALSE)
  }

  n <- nrow(data)
  # seq_len() rather than 1:n so a zero-row input yields no indices in
  # either branch.
  idx <- seq_len(n)
  if (is.null(strata)) {
    folds <- sample(rep(1:v, length.out = n))
    indices <- split_unnamed(idx, folds)
  } else {
    stratas <- tibble::tibble(idx = idx,
                              strata = make_strata(getElement(data, strata),
                                                   breaks = breaks))
    # Assign folds separately inside each stratum, then regroup the row
    # indices by fold.
    stratas <- split_unnamed(stratas, stratas$strata)
    stratas <- purrr::map(stratas, add_vfolds, v = v)
    stratas <- dplyr::bind_rows(stratas)
    indices <- split_unnamed(stratas$idx, stratas$folds)
  }

  # Each fold's indices become the assessment set; the rest form analysis.
  indices <- lapply(indices, vfold_complement, n = n)

  split_objs <- purrr::map(indices, make_splits, data = data,
                           class = "vfold_split")
  tibble::tibble(splits = split_objs,
                 id = names0(length(split_objs), "Fold"))
}

# Randomly assign each row of `x` to one of `v` folds.
#
# The fold labels 1..v are recycled to cover every row and then shuffled, so
# (for a positive integer `v`) fold sizes differ by at most one. The labels
# are stored in a new `folds` column and the augmented `x` is returned.
add_vfolds <- function(x, v) {
  n_rows <- nrow(x)
  fold_labels <- rep(1:v, length.out = n_rows)
  x[["folds"]] <- sample(fold_labels)
  x
}

#' @export
print.vfold_cv <- function(x, ...) {
  cat("# ", pretty(x), "\n")
  # Drop the rsample classes from a copy so print() dispatches to the
  # underlying data-frame method instead of calling this method again.
  printable <- x
  keep <- !(class(printable) %in% c("vfold_cv", "rset"))
  class(printable) <- class(printable)[keep]
  print(printable, ...)
  # Print methods conventionally return their (original) argument invisibly.
  invisible(x)
}