1
#' Group V-Fold Cross-Validation
#'
#' Group V-fold cross-validation creates splits of the data based
#'  on some grouping variable (which may have more than a single row
#'  associated with it). The function can create as many splits as
#'  there are unique values of the grouping variable or it can
#'  create a smaller set of splits where more than one value is left
#'  out at a time.
#'
#' @param data A data frame.
#' @param group This could be a single character value or a variable
#'  name that corresponds to a variable that exists in the data frame.
#' @param v The number of partitions of the data set. If left as
#'  `NULL`, `v` will be set to the number of unique values
#'  in the group. It may not exceed the number of unique values.
#' @param ... Not currently used.
#' @return A tibble with classes `group_vfold_cv`,
#'  `rset`, `tbl_df`, `tbl`, and `data.frame`.
#'  The results include a column for the data split objects and an
#'  identification variable.
#' @examples
#' set.seed(3527)
#' test_data <- data.frame(id = sort(sample(1:20, size = 80, replace = TRUE)))
#' test_data$dat <- runif(nrow(test_data))
#'
#' set.seed(5144)
#' split_by_id <- group_vfold_cv(test_data, group = "id")
#'
#' get_id_left_out <- function(x)
#'   unique(assessment(x)$id)
#'
#' library(purrr)
#' table(map_int(split_by_id$splits, get_id_left_out))
#'
#' set.seed(5144)
#' split_by_some_id <- group_vfold_cv(test_data, group = "id", v = 7)
#' held_out <- map(split_by_some_id$splits, get_id_left_out)
#' table(unlist(held_out))
#' # number held out per resample:
#' map_int(held_out, length)
#' @export
group_vfold_cv <- function(data, group = NULL, v = NULL, ...) {
  # Resolve `group` through tidyselect so that both a string ("id") and a
  # bare column name (id) are accepted. An empty selection is treated the
  # same as not supplying `group` at all.
  if (!missing(group)) {
    group <- tidyselect::vars_select(names(data), !!enquo(group))
    if (length(group) == 0) {
      group <- NULL
    }
  }

  if (is.null(group) || !is.character(group) || length(group) != 1) {
    stop(
      "`group` should be a single character value for the column ",
      "that will be used for splitting.",
      call. = FALSE
    )
  }
  # `%in%` never yields NA, so this condition is always TRUE/FALSE.
  if (!(group %in% names(data))) {
    stop("`group` should be a column in `data`.", call. = FALSE)
  }

  split_objs <- group_vfold_splits(data = data, group = group, v = v)

  ## We remove the holdout indices since it will save space and we can
  ## derive them later when they are needed.
  split_objs$splits <- map(split_objs$splits, rm_out)

  # Update `v` if not supplied directly: one resample per split produced.
  if (is.null(v)) {
    v <- length(split_objs$splits)
  }

  ## Save some overall information
  cv_att <- list(v = v, group = group)

  new_rset(splits = split_objs$splits,
           ids = split_objs[, grepl("^id", names(split_objs))],
           attrib = cv_att,
           subclass = c("group_vfold_cv", "rset"))
}
82

83
# Assign every unique value of the grouping column to one of `v` folds and
# build one split object per fold.
#
# data:  a data frame.
# group: a single column name of `data` (validated by the caller).
# v:     number of folds; `NULL` means one fold per unique group value.
#        Must be a single positive number no larger than the number of
#        unique group values.
#
# Returns a tibble with a `splits` list column and an `id` column.
group_vfold_splits <- function(data, group, v = NULL) {
  group_values <- data[[group]]
  uni_groups <- unique(group_values)
  max_v <- length(uni_groups)

  if (is.null(v)) {
    v <- max_v
  } else {
    # Guard against `v = 0`, `NA`, or vectors: `1:0` used to silently
    # create folds labelled 1 and 0.
    if (!is.numeric(v) || length(v) != 1 || is.na(v) || v < 1) {
      stop("`v` should be a single positive number.", call. = FALSE)
    }
    if (v > max_v) {
      stop("`v` should be less than or equal to ", max_v, call. = FALSE)
    }
  }

  # Randomly spread the unique groups over the folds, keeping fold sizes
  # (in groups) within one of each other. Every row then inherits the fold
  # of its group, so a group never straddles two folds.
  group_folds <- sample(rep(seq_len(v), length.out = max_v))
  row_folds <- group_folds[match(group_values, uni_groups)]

  n_rows <- nrow(data)
  indices <- split_unnamed(seq_len(n_rows), row_folds)
  indices <- lapply(indices, vfold_complement, n = n_rows)
  split_objs <-
    purrr::map(indices,
               make_splits,
               data = data,
               class = "group_vfold_split")
  tibble::tibble(splits = split_objs,
                 id = names0(length(split_objs), "Resample"))
}
111

112
#' @export
print.group_vfold_cv <- function(x, ...) {
  # Header line: "#" followed by the object's `pretty()` description
  # (presumably supplied by a `pretty()` method defined elsewhere in the
  # package — TODO confirm).
  cat("#", pretty(x), "\n")
  # Strip the resampling classes from a copy so the inner print() dispatches
  # to the underlying tibble/data frame method instead of back to this one.
  printable <- x
  class(printable) <- class(x)[!(class(x) %in% c("group_vfold_cv", "rset"))]
  print(printable, ...)
  # Print methods return their input unchanged and invisibly.
  invisible(x)
}

Read our documentation on viewing source code .

Loading