#' Convert Resampling Objects to Other Formats
#'
#' These functions convert resampling objects between the
#'  \pkg{rsample} and \pkg{caret} packages.
#'
#' @param object An `rset` object. Currently, `nested_cv` objects
#'  are not supported.
#' @return `rsample2caret` returns a list with elements `index` and
#'  `indexOut` that mimic the elements of the same names in a
#'  `trainControl` object. `caret2rsample` returns an `rset` object
#'  of the appropriate class.
#' @export
rsample2caret <- function(object, data = c("analysis", "assessment")) {
  # Reject anything that is not an `rset` before touching `object$splits`.
  if (!inherits(object, "rset")) {
    stop("`object` must be an `rset`", call. = FALSE)
  }
  # `data` is checked against its allowed values, but the result always
  # contains both the analysis (`index`) and assessment (`indexOut`) rows.
  data <- match.arg(data)

  # Integer row numbers for one role ("analysis" or "assessment") of every
  # split, in the same order as `object$splits`.
  rows_for <- function(role) {
    purrr::map(object$splits, as.integer, data = role)
  }

  index <- rows_for("analysis")
  names(index) <- labels(object)
  index_out <- rows_for("assessment")
  # Both lists share the same resample labels.
  names(index_out) <- names(index)

  list(index = index, indexOut = index_out)
}

#' @rdname rsample2caret
#' @param ctrl An object produced by `trainControl` that has
#'  had the `index` and `indexOut` elements populated by
#'  integers. One method of getting this is to extract the
#'  `control` objects from an object produced by `train`.
#' @param data The data that was originally used to produce the
#'  `ctrl` object.
#' @export
caret2rsample <- function(ctrl, data = NULL) {
  # Validate inputs before doing any work.
  if (is.null(data)) {
    stop("Must supply original data", call. = FALSE)
  }
  if (!("index" %in% names(ctrl))) {
    stop("`ctrl` should have an element `index`", call. = FALSE)
  }
  if (!("indexOut" %in% names(ctrl))) {
    stop("`ctrl` should have an element `indexOut`", call. = FALSE)
  }
  if (is.null(ctrl$index)) {
    stop("`ctrl$index` should be populated with integers", call. = FALSE)
  }
  if (is.null(ctrl$indexOut)) {
    stop("`ctrl$indexOut` should be populated with integers", call. = FALSE)
  }

  # Resolve the split class once, up front, so an unsupported `ctrl$method`
  # fails here rather than partway through building the splits.
  split_class <- map_rsplit_method(ctrl$method)

  # One split per resample: pair the in/out row indices, prepend the
  # original data, then set the split class.
  indices <- purrr::map2(ctrl$index, ctrl$indexOut, extract_int)
  id_data <- names(indices)
  indices <- unname(indices)
  indices <- purrr::map(indices, add_data, y = data)
  indices <- purrr::map(indices, add_rsplit_class, cl = split_class)
  indices <- tibble::tibble(splits = indices)

  if (ctrl$method %in% c("repeatedcv", "adaptive_cv")) {
    # Each resample name is split at ".": the second piece becomes `id`
    # and the first piece becomes `id2` (presumably caret names such as
    # "Fold01.Rep1" -- confirm against caret).
    id_data <- strsplit(id_data, split = ".", fixed = TRUE)
    id_data <- tibble::tibble(
      id  = vapply(id_data, function(x) x[2], character(1)),
      id2 = vapply(id_data, function(x) x[1], character(1))
    )
  } else {
    id_data <- tibble::tibble(id = id_data)
  }
  out <- dplyr::bind_cols(indices, id_data)

  # Copy the resampling settings returned by map_attr() onto the result
  # as attributes.
  attrib <- map_attr(ctrl)
  for (i in names(attrib)) {
    attr(out, i) <- attrib[[i]]
  }
  add_rset_class(out, map_rset_method(ctrl$method))
}

# Bundle one resample's row indices into a two-element list:
# `in_id` holds `x` and `out_id` holds `y`.
extract_int <- function(x, y) {
  ids <- list(x, y)
  names(ids) <- c("in_id", "out_id")
  ids
}

# Return list `x` with a new first element `data` holding `y`; the
# existing elements of `x` follow in their original order.
add_data <- function(x, y) {
  append(x, list(data = y), after = 0)
}

# Replace the class of `x` with "rsplit" followed by the method-specific
# class(es) in `cl`; all other attributes are kept.
add_rsplit_class <- function(x, cl) {
  structure(x, class = c("rsplit", cl))
}

# Replace the class of `x` with the method-specific class(es) in `cl`,
# followed by "rset" and the tibble/data.frame class chain.
add_rset_class <- function(x, cl) {
  structure(x, class = c(cl, "rset", "tbl_df", "tbl", "data.frame"))
}

# Map a caret resampling method name (as stored in `ctrl$method`) to the
# class used for the individual rsample split objects. Errors for any
# method that has no rsample equivalent.
map_rsplit_method <- function(method) {
  switch(
    method,
    # Empty alternatives fall through, so several methods share a class.
    cv = , repeatedcv = , adaptive_cv = "vfold_split",
    boot = , boot_all = , boot632 = , optimism_boot = ,
    adaptive_boot = "boot_split",
    LOOCV = "loo_split",
    LGOCV = , adaptive_LGOCV = "mc_split",
    timeSlice = "rof_split",
    # Default: only evaluated when `method` matches none of the names above.
    stop("Resampling method `",
         method,
         "` cannot be converted into an `rsplit` object",
         call. = FALSE)
  )
}

# Map a caret resampling method name (as stored in `ctrl$method`) to the
# rsample `rset` subclass for the whole collection of resamples. Errors
# for any method that has no rsample equivalent.
map_rset_method <- function(method) {
  switch(
    method,
    # Empty alternatives fall through, so several methods share a class.
    cv = , repeatedcv = , adaptive_cv = "vfold_cv",
    boot = , boot_all = , boot632 = , optimism_boot = ,
    adaptive_boot = "bootstraps",
    LOOCV = "loo_cv",
    LGOCV = , adaptive_LGOCV = "mc_cv",
    timeSlice = "rolling_origin",
    # Default: only evaluated when `method` matches none of the names above.
    stop("Resampling method `",
         method,
         "` cannot be converted into an `rset` object",
         call. = FALSE)
  )
}

# Translate the resampling settings in a caret control list into the
# attribute values attached to an rsample `rset`.
#
# `object` is read for `method` and, depending on the method, for
# `number`, `repeats`, `p`, `initialWindow`, `horizon`, `fixedWindow`
# and `skip`. Returns a named list of attribute values (empty for
# "LOOCV") and errors for any method not handled below.
map_attr <- function(object) {
  method <- object$method
  if (grepl("cv$", method)) {
    # `repeats` may be NA (or absent) for non-repeated CV; record a single
    # repeat in that case. Plain `if` instead of scalar `ifelse()`, which
    # would return logical(0) when `repeats` is NULL.
    repeats <- object$repeats
    if (is.null(repeats) || is.na(repeats)) {
      repeats <- 1
    }
    out <- list(v = object$number,
                repeats = repeats,
                strata = TRUE)
  } else if (grepl("boot", method)) {
    out <- list(times = object$number,
                apparent = FALSE,
                strata = FALSE)
  } else if (grepl("LGOCV$", method)) {
    out <- list(times = object$number,
                prop = object$p,
                strata = FALSE)
  } else if (method == "LOOCV") {
    out <- list()
  } else if (method == "timeSlice") {
    out <- list(
      initial = object$initialWindow,
      assess = object$horizon,
      # caret's fixed window is the opposite of rsample's cumulative window.
      cumulative = !object$fixedWindow,
      skip = object$skip
    )
  } else {
    # Previously the pieces were pasted without spaces
    # ("Methodfoocannot be converted") and the call was included.
    stop("Resampling method `", method, "` cannot be converted",
         call. = FALSE)
  }
  out
}
151

Read our documentation on viewing source code .

Loading