```r
#' V-Fold Cross-Validation
#'
#' V-fold cross-validation randomly splits the data into V groups of roughly
#' equal size (called "folds"). A resample of the analysis data consists of
#' V-1 of the folds while the assessment set contains the final fold. In basic
#' V-fold cross-validation (i.e. no repeats), the number of resamples is equal
#' to V.
#' @details
#' The `strata` argument causes the random sampling to be conducted *within
#' the stratification variable*. This can help ensure that the distribution of
#' the stratification variable in the analysis data is similar to its
#' distribution in the original data set. (Strata below 10% of the total are
#' pooled together.)
#' When more than one repeat is requested, the basic V-fold cross-validation
#' is conducted each time. For example, if three repeats are used with
#' `v = 10`, there are a total of 30 splits: three groups of 10 that are
#' generated separately.
#' @param data A data frame.
#' @param v The number of partitions of the data set.
#' @param repeats The number of times to repeat the V-fold partitioning.
#' @param strata A variable that is used to conduct stratified sampling to
#' create the folds. This could be a single character value or a variable name
#' that corresponds to a variable that exists in the data frame.
#' @param breaks A single number giving the number of bins desired to stratify
#' a numeric stratification variable.
#' @param ... Not currently used.
#' @return A tibble with classes `vfold_cv`, `rset`, `tbl_df`, `tbl`, and
#' `data.frame`. The results include a column for the data split objects and
#' one or more identification variables. For a single repeat, there will be
#' one column called `id` that has a character string with the fold identifier.
#' For repeats, `id` is the repeat number and an additional column called `id2`
#' contains the fold information (within repeat).
#'
#' @examples
#' vfold_cv(mtcars, v = 10)
#' vfold_cv(mtcars, v = 10, repeats = 2)
#'
#' library(purrr)
#' data(wa_churn, package = "modeldata")
#'
#' set.seed(13)
#' folds1 <- vfold_cv(wa_churn, v = 5)
#' map_dbl(folds1$splits,
#'         function(x) {
#'           dat <- as.data.frame(x)$churn
#'           mean(dat == "Yes")
#'         })
#'
#' set.seed(13)
#' folds2 <- vfold_cv(wa_churn, strata = "churn", v = 5)
#' map_dbl(folds2$splits,
#'         function(x) {
#'           dat <- as.data.frame(x)$churn
#'           mean(dat == "Yes")
#'         })
#'
#' set.seed(13)
#' folds3 <- vfold_cv(wa_churn, strata = "tenure", breaks = 6, v = 5)
#' map_dbl(folds3$splits,
#'         function(x) {
#'           dat <- as.data.frame(x)$churn
#'           mean(dat == "Yes")
#'         })
#' @export
vfold_cv <- function(data, v = 10, repeats = 1, strata = NULL, breaks = 4, ...) {

  ## Resolve `strata` to a column name; an empty selection means no strata.
  if (!missing(strata)) {
    strata <- tidyselect::vars_select(names(data), !!enquo(strata))
    if (length(strata) == 0) strata <- NULL
  }

  strata_check(strata, names(data))

  if (repeats == 1) {
    split_objs <- vfold_splits(data = data, v = v, strata = strata, breaks = breaks)
  } else {
    for (i in 1:repeats) {
      ## Pass `breaks` here too so repeated CV bins a numeric stratification
      ## variable the same way as the single-repeat case.
      tmp <- vfold_splits(data = data, v = v, strata = strata, breaks = breaks)
      tmp$id2 <- tmp$id
      tmp$id <- names0(repeats, "Repeat")[i]
      split_objs <- if (i == 1) tmp else rbind(split_objs, tmp)
    }
  }

  ## We remove the holdout indices since it will save space and we can
  ## derive them later when they are needed.

  split_objs$splits <- map(split_objs$splits, rm_out)

  ## Save some overall information

  cv_att <- list(v = v, repeats = repeats, strata = !is.null(strata))

  new_rset(splits = split_objs$splits,
           ids = split_objs[, grepl("^id", names(split_objs))],
           attrib = cv_att,
           subclass = c("vfold_cv", "rset"))
}

# Get the indices of the analysis set from the assessment set
vfold_complement <- function(ind, n) {
  list(analysis = setdiff(1:n, ind),
       assessment = ind)
}

vfold_splits <- function(data, v = 10, strata = NULL, breaks = 4) {
  if (!is.numeric(v) || length(v) != 1)
    stop("`v` must be a single integer.", call. = FALSE)

  n <- nrow(data)
  if (is.null(strata)) {
    ## Unstratified: shuffle a balanced vector of fold labels over all rows.
    folds <- sample(rep(1:v, length.out = n))
    idx <- seq_len(n)
    indices <- split_unnamed(idx, folds)
  } else {
    ## Stratified: assign fold labels separately within each stratum.
    stratas <- tibble::tibble(idx = 1:n,
                              strata = make_strata(getElement(data, strata),
                                                   breaks = breaks))
    stratas <- split_unnamed(stratas, stratas$strata)
    stratas <- purrr::map(stratas, add_vfolds, v = v)
    stratas <- dplyr::bind_rows(stratas)
    indices <- split_unnamed(stratas$idx, stratas$folds)
  }

  indices <- lapply(indices, vfold_complement, n = n)

  split_objs <- purrr::map(indices, make_splits, data = data, class = "vfold_split")
  tibble::tibble(splits = split_objs,
                 id = names0(length(split_objs), "Fold"))
}

# Randomly assign the rows of one stratum to `v` folds
add_vfolds <- function(x, v) {
  x$folds <- sample(rep(1:v, length.out = nrow(x)))
  x
}

#' @export
print.vfold_cv <- function(x, ...) {
  cat("# ", pretty(x), "\n")
  class(x) <- class(x)[!(class(x) %in% c("vfold_cv", "rset"))]
  print(x, ...)
}
```
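The fold-assignment idea used in `vfold_splits()` and `add_vfolds()` can be illustrated with a small, self-contained sketch in base R. It is not the package's implementation: the helper names `assign_folds()` and `assign_folds_stratified()` and the made-up `outcome` vector are assumptions introduced here for illustration only. The sketch shuffles a balanced vector of fold labels, derives analysis and assessment indices in the same way as `vfold_complement()`, and then repeats the assignment within each stratum.

```r
# Hypothetical helper (not part of the package): assign each of n rows to one
# of v folds so that fold sizes differ by at most one.
assign_folds <- function(n, v) {
  sample(rep(seq_len(v), length.out = n))
}

# Hypothetical helper: assign folds separately within each level of `strata`.
assign_folds_stratified <- function(strata, v) {
  folds <- integer(length(strata))
  for (idx in split(seq_along(strata), strata)) {
    folds[idx] <- assign_folds(length(idx), v)
  }
  folds
}

set.seed(13)
n <- 103
v <- 5

folds <- assign_folds(n, v)

# Each fold's rows form the assessment set; the remaining rows form the
# analysis set, mirroring what vfold_complement() returns.
indices <- lapply(split(seq_len(n), folds), function(ind) {
  list(analysis = setdiff(seq_len(n), ind), assessment = ind)
})

# Assessment set size per fold.
fold_sizes <- vapply(indices, function(x) length(x$assessment), integer(1))
print(fold_sizes)

# Stratified version on a made-up, imbalanced two-class outcome
# (23 "Yes" rows and 80 "No" rows; illustrative data only).
outcome <- rep(c("Yes", "No"), times = c(23, 80))
strat_folds <- assign_folds_stratified(outcome, v)

# Number of "Yes" rows that land in each fold.
yes_per_fold <- tapply(outcome == "Yes", strat_folds, sum)
print(yes_per_fold)
```

Because the fold labels are balanced before shuffling, the fold sizes depend only on `n` and `v`, not on the seed. In the stratified case the same holds within each stratum, which is why the minority class ends up spread almost evenly across the folds.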
