tidymodels / rsample
```r
#' V-Fold Cross-Validation
#'
#' V-fold cross-validation randomly splits the data into V groups of roughly
#'  equal size (called "folds"). For each resample, the analysis set consists
#'  of V-1 of the folds and the assessment set contains the remaining fold. In
#'  basic V-fold cross-validation (i.e., no repeats), the number of resamples is
#'  equal to V.
#' @details
#' The `strata` argument causes the random sampling to be conducted *within
#'  the stratification variable*. This can help ensure that the proportions of
#'  the stratification variable in the analysis data are similar to those in
#'  the original data set. (Strata below 10% of the total are pooled together
#'  by default.) When more than one repeat is requested, the basic V-fold
#'  cross-validation is conducted each time. For example, if three repeats are
#'  used with `v = 10`, there are a total of 30 splits: three groups of 10 that
#'  are generated separately.
#' @inheritParams make_strata
#' @param data A data frame.
#' @param v The number of partitions of the data set.
#' @param repeats The number of times to repeat the V-fold partitioning.
#' @param strata A variable that is used to conduct stratified sampling to
#'  create the folds. This could be a single character value or a variable name
#'  that corresponds to a variable that exists in the data frame.
#' @param ... Not currently used.
#' @return A tibble with classes `vfold_cv`, `rset`, `tbl_df`, `tbl`, and
#'  `data.frame`. The results include a column for the data split objects and
#'  one or more identification variables. For a single repeat, there will be
#'  one column called `id` that has a character string with the fold identifier.
#'  For repeats, `id` is the repeat number and an additional column called `id2`
#'  contains the fold information (within repeat).
#'
#' @examples
#' vfold_cv(mtcars, v = 10)
#' vfold_cv(mtcars, v = 10, repeats = 2)
#'
#' library(purrr)
#' data(wa_churn, package = "modeldata")
#'
#' set.seed(13)
#' folds1 <- vfold_cv(wa_churn, v = 5)
#' map_dbl(folds1$splits,
#'         function(x) {
#'           dat <- as.data.frame(x)$churn
#'           mean(dat == "Yes")
#'         })
#'
#' set.seed(13)
#' folds2 <- vfold_cv(wa_churn, strata = churn, v = 5)
#' map_dbl(folds2$splits,
#'         function(x) {
#'           dat <- as.data.frame(x)$churn
#'           mean(dat == "Yes")
#'         })
#'
#' set.seed(13)
#' folds3 <- vfold_cv(wa_churn, strata = tenure, breaks = 6, v = 5)
#' map_dbl(folds3$splits,
#'         function(x) {
#'           dat <- as.data.frame(x)$churn
#'           mean(dat == "Yes")
#'         })
#' @export
vfold_cv <- function(data, v = 10, repeats = 1,
                     strata = NULL, breaks = 4, pool = 0.1, ...) {

  if (!missing(strata)) {
    strata <- tidyselect::vars_select(names(data), !!enquo(strata))
    if (length(strata) == 0) strata <- NULL
  }

  strata_check(strata, data)

  if (repeats == 1) {
    split_objs <- vfold_splits(data = data, v = v,
                               strata = strata, breaks = breaks, pool = pool)
  } else {
    for (i in 1:repeats) {
      # Pass `breaks` as well, so numeric strata are binned the same way as
      # in the single-repeat branch above.
      tmp <- vfold_splits(data = data, v = v, strata = strata,
                          breaks = breaks, pool = pool)
      tmp$id2 <- tmp$id
      tmp$id <- names0(repeats, "Repeat")[i]
      split_objs <- if (i == 1)
        tmp
      else
        rbind(split_objs, tmp)
    }
  }

  ## We remove the holdout indices since it will save space and we can
  ## derive them later when they are needed.

  split_objs$splits <- map(split_objs$splits, rm_out)

  ## Save some overall information

  cv_att <- list(v = v, repeats = repeats, strata = !is.null(strata))

  new_rset(splits = split_objs$splits,
           ids = split_objs[, grepl("^id", names(split_objs))],
           attrib = cv_att,
           subclass = c("vfold_cv", "rset"))
}


vfold_splits <- function(data, v = 10, strata = NULL, breaks = 4, pool = 0.1) {
  if (!is.numeric(v) || length(v) != 1)
    stop("`v` must be a single integer.", call. = FALSE)

  n <- nrow(data)
  if (is.null(strata)) {
    # Recycle the fold labels 1..v over all rows, then shuffle them.
    folds <- sample(rep(1:v, length.out = n))
    idx <- seq_len(n)
    indices <- split_unnamed(idx, folds)
  } else {
    stratas <- tibble::tibble(idx = 1:n,
                              strata = make_strata(getElement(data, strata),
                                                   breaks = breaks,
                                                   pool = pool))
    # Assign folds separately within each stratum, then recombine.
    stratas <- split_unnamed(stratas, stratas$strata)
    stratas <- purrr::map(stratas, add_vfolds, v = v)
    stratas <- dplyr::bind_rows(stratas)
    indices <- split_unnamed(stratas$idx, stratas$folds)
  }

  indices <- lapply(indices, default_complement, n = n)

  split_objs <- purrr::map(indices, make_splits, data = data, class = "vfold_split")
  tibble::tibble(splits = split_objs,
                 id = names0(length(split_objs), "Fold"))
}

add_vfolds <- function(x, v) {
  x$folds <- sample(rep(1:v, length.out = nrow(x)))
  x
}

#' @export
print.vfold_cv <- function(x, ...) {
  cat("# ", pretty(x), "\n")
  class(x) <- class(x)[!(class(x) %in% c("vfold_cv", "rset"))]
  print(x, ...)
}
```
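The unstratified branch of `vfold_splits()` assigns folds with `sample(rep(1:v, length.out = n))`. The base-R sketch below isolates that idea outside the package. `vfold_indices_sketch()` is a hypothetical helper, not part of rsample, and it returns plain index lists rather than `rsplit` objects.

```r
# Hypothetical helper (not part of rsample): base-R sketch of the
# unstratified V-fold assignment used in vfold_splits().
vfold_indices_sketch <- function(n, v = 10) {
  stopifnot(is.numeric(v), length(v) == 1, v >= 2, v <= n)
  # Recycle the labels 1..v over the n rows, then shuffle them, so each
  # fold gets either floor(n / v) or ceiling(n / v) rows.
  folds <- sample(rep(seq_len(v), length.out = n))
  assessment <- split(seq_len(n), folds)
  # The analysis set for each fold is the complement of its assessment set.
  lapply(assessment, function(idx) {
    list(analysis = setdiff(seq_len(n), idx), assessment = idx)
  })
}

set.seed(13)
splits <- vfold_indices_sketch(n = 32, v = 10)
sizes <- vapply(splits, function(s) length(s$assessment), integer(1))
sizes  # folds 1 and 2 hold 4 rows; folds 3 to 10 hold 3
```

Because the labels are recycled before they are shuffled, the fold sizes differ by at most one row whatever the seed; only which rows land in which fold changes.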
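When `strata` is supplied, `vfold_splits()` splits the rows by stratum and `add_vfolds()` applies the same recycle-and-shuffle step inside each one, so every fold receives roughly the same share of each stratum. A minimal base-R sketch of that idea follows, assuming the stratification vector is already discrete. `stratified_folds_sketch()` is hypothetical and skips `make_strata()`'s binning and pooling.

```r
# Hypothetical helper (not part of rsample): per-stratum fold assignment in
# the spirit of add_vfolds(). Assumes `strata` is already discrete (character
# or factor); make_strata()'s binning and pooling are not reproduced.
stratified_folds_sketch <- function(strata, v = 5) {
  stopifnot(!anyNA(strata))
  folds <- integer(length(strata))
  # Row indices grouped by stratum level.
  by_stratum <- split(seq_along(strata), strata)
  for (rows in by_stratum) {
    # Recycle 1..v over this stratum's rows, then shuffle within the stratum.
    folds[rows] <- sample(rep(seq_len(v), length.out = length(rows)))
  }
  folds
}

set.seed(13)
churn <- rep(c("No", "Yes"), times = c(80, 20))
folds <- stratified_folds_sketch(churn, v = 5)
# Share of "Yes" in each fold: 0.2 everywhere, because both stratum sizes
# (80 and 20) are multiples of v = 5.
tapply(churn == "Yes", folds, mean)
```

With uneven stratum sizes the per-fold shares only approximate the overall proportion, but they can no longer drift far from it by chance, as they can when a minority class is sampled without stratification.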
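For a numeric `strata` column such as `tenure`, the `breaks` argument sets how many bins the variable is cut into before stratifying (the third example uses `breaks = 6`). The sketch below roughly approximates that step with quantile-based bins. `quantile_bins_sketch()` is a hypothetical stand-in for rsample's `make_strata()`, which also pools small strata; this sketch does not.

```r
# Hypothetical stand-in (not rsample's make_strata()): cut a numeric vector
# into `breaks` quantile-based bins so it can be used as a discrete stratum.
quantile_bins_sketch <- function(x, breaks = 4) {
  probs <- seq(0, 1, length.out = breaks + 1)
  # unique() drops repeated cut points when x has many ties; in that case
  # fewer than `breaks` bins are returned.
  cuts <- unique(stats::quantile(x, probs = probs, names = FALSE))
  if (length(cuts) < 2) stop("`x` needs at least two distinct values to bin.")
  cut(x, breaks = cuts, include.lowest = TRUE)
}

tenure <- 1:120
bins <- quantile_bins_sketch(tenure, breaks = 6)
table(bins)  # six bins of 20 values each
```

The resulting factor could then be passed to any per-stratum fold assignment in place of a categorical column.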
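When `repeats > 1`, `vfold_cv()` calls `vfold_splits()` once per repeat, copies the fold label into `id2`, and stores the repeat label in `id`, so `v = 10` with three repeats yields 30 rows. The base-R sketch below mimics only that bookkeeping. `repeated_vfold_sketch()` and `pad_ids()` are hypothetical; `pad_ids()` stands in for rsample's internal `names0()`, assuming labels are zero-padded to a common width.

```r
# Hypothetical helper approximating rsample's internal names0(): zero-padded
# labels such as "Fold01" ... "Fold10" that sort correctly as strings.
pad_ids <- function(n, prefix) {
  sprintf(paste0(prefix, "%0", nchar(as.integer(n)), "d"), seq_len(n))
}

# Hypothetical sketch (not rsample's code): the id/id2 bookkeeping of
# repeated V-fold CV. Each row summarizes one split by its assessment size.
# Unlike vfold_cv(), it keeps both id columns even when repeats = 1.
repeated_vfold_sketch <- function(n, v = 10, repeats = 1) {
  per_repeat <- lapply(seq_len(repeats), function(r) {
    # Every repeat draws a fresh, independent fold assignment.
    folds <- sample(rep(seq_len(v), length.out = n))
    data.frame(
      id = pad_ids(repeats, "Repeat")[r],
      id2 = pad_ids(v, "Fold"),
      n_assessment = as.vector(table(factor(folds, levels = seq_len(v)))),
      stringsAsFactors = FALSE
    )
  })
  do.call(rbind, per_repeat)
}

set.seed(13)
ids <- repeated_vfold_sketch(n = 32, v = 10, repeats = 3)
nrow(ids)  # 30 splits: three groups of 10
head(ids, 3)
```

Within each repeat the assessment sizes still sum to `n`, since every row is held out exactly once per repeat.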
