Help documentation updates
Update slide.R help doc
1 |
#' Group V-Fold Cross-Validation
#'
#' Group V-fold cross-validation creates splits of the data based
#'  on some grouping variable (which may have more than a single row
#'  associated with it). The function can create as many splits as
#'  there are unique values of the grouping variable or it can
#'  create a smaller set of splits where more than one value is left
#'  out at a time.
#'
#' @param data A data frame.
#' @param group This could be a single character value or a variable
#'  name that corresponds to a variable that exists in the data frame.
#' @param v The number of partitions of the data set. If left
#'  `NULL`, `v` will be set to the number of unique values
#'  in the group.
#' @param ... Not currently used.
#' @return A tibble with classes `group_vfold_cv`,
#'  `rset`, `tbl_df`, `tbl`, and `data.frame`.
#'  The results include a column for the data split objects and an
#'  identification variable.
#' @examples
#' set.seed(3527)
#' test_data <- data.frame(id = sort(sample(1:20, size = 80, replace = TRUE)))
#' test_data$dat <- runif(nrow(test_data))
#'
#' set.seed(5144)
#' split_by_id <- group_vfold_cv(test_data, group = "id")
#'
#' get_id_left_out <- function(x)
#'   unique(assessment(x)$id)
#'
#' library(purrr)
#' table(map_int(split_by_id$splits, get_id_left_out))
#'
#' set.seed(5144)
#' split_by_some_id <- group_vfold_cv(test_data, group = "id", v = 7)
#' held_out <- map(split_by_some_id$splits, get_id_left_out)
#' table(unlist(held_out))
#' # number held out per resample:
#' map_int(held_out, length)
#' @export
group_vfold_cv <- function(data, group = NULL, v = NULL, ...) {
  # Resolve `group` (bare column name or string) against the columns of
  # `data`. A selection that matches nothing is treated like "no group".
  if (!missing(group)) {
    group <- tidyselect::vars_select(names(data), !!enquo(group))
    if (length(group) == 0) {
      group <- NULL
    }
  }

  # Exactly one column name must remain after the selection step.
  has_single_name <-
    !is.null(group) && is.character(group) && length(group) == 1
  if (!has_single_name) {
    stop(
      "`group` should be a single character value for the column ",
      "that will be used for splitting.",
      call. = FALSE
    )
  }

  is_column <- any(names(data) == group)
  if (!is_column) {
    stop("`group` should be a column in `data`.", call. = FALSE)
  }

  resamples <- group_vfold_splits(data = data, group = group, v = v)

  # The holdout indices are removed to save space; they can be derived
  # again later when they are needed.
  resamples$splits <- map(resamples$splits, rm_out)

  # When `v` was not supplied, record the number of splits actually made.
  if (is.null(v)) {
    v <- length(resamples$splits)
  }

  # Every column whose name starts with "id" identifies a resample.
  id_columns <- grepl("^id", names(resamples))

  new_rset(
    splits = resamples$splits,
    ids = resamples[, id_columns],
    attrib = list(v = v, group = group),
    subclass = c("group_vfold_cv", "rset")
  )
}
|
|
82 |
|
|
83 |
# Build the split objects for group V-fold cross-validation.
#
# Each unique value of the `group` column is randomly assigned to one of
# `v` folds, so all rows sharing a group value land in the same fold.
#
# Arguments:
#   data  A data frame.
#   group A single string naming the grouping column in `data`.
#   v     Number of folds; `NULL` means one fold per unique group value.
#
# Returns a tibble with a `splits` list column and an `id` column.
# Errors if `v` is not a single positive number or exceeds the number of
# unique group values.
group_vfold_splits <- function(data, group, v = NULL) {
  group_values <- getElement(data, group)
  uni_groups <- unique(group_values)
  max_v <- length(uni_groups)

  if (is.null(v)) {
    v <- max_v
  } else {
    # Without this guard `1:v` would count downwards for v < 1 and
    # silently create meaningless folds.
    if (!is.numeric(v) || length(v) != 1 || is.na(v) || v < 1) {
      stop("`v` should be a single positive number.", call. = FALSE)
    }
    # `v == max_v` is valid (one group value per fold).
    if (v > max_v) {
      stop("`v` should be less than or equal to ", max_v, call. = FALSE)
    }
  }

  # seq_len() keeps the row index empty for zero-row input, where
  # 1:nrow(data) would yield c(1, 0).
  data_ind <- data.frame(
    ..index = seq_len(nrow(data)),
    ..group = group_values
  )
  keys <- data.frame(..group = uni_groups)

  # Recycle the fold labels across the unique group values, then shuffle
  # so that fold membership is random.
  n <- nrow(keys)
  keys$..folds <- sample(rep(1:v, length.out = n))

  # Attach each row's fold via its group value, restoring row order.
  data_ind <- data_ind %>%
    full_join(keys, by = "..group") %>%
    arrange(..index)

  indices <- split_unnamed(data_ind$..index, data_ind$..folds)
  indices <- lapply(indices, vfold_complement, n = nrow(data))

  split_objs <- purrr::map(
    indices,
    make_splits,
    data = data,
    class = "group_vfold_split"
  )

  tibble::tibble(
    splits = split_objs,
    id = names0(length(split_objs), "Resample")
  )
}
|
|
111 |
|
|
112 |
#' @export
print.group_vfold_cv <- function(x, ...) {
  # One-line summary header produced by the `pretty()` method for `x`.
  header <- pretty(x)
  cat("#", header, "\n")

  # Strip the resampling classes so the remaining classes decide how the
  # object itself is printed.
  rset_classes <- c("group_vfold_cv", "rset")
  is_rset_class <- class(x) %in% rset_classes
  class(x) <- class(x)[!is_rset_class]

  print(x, ...)
}
|
Read our documentation on viewing source code.