1
#' Simple Training/Test Set Splitting
2
#'
3
#' `initial_split` creates a single binary split of the data into a training
4
#'  set and testing set. `initial_time_split` does the same, but takes the
5
#'  _first_ `prop` samples for training, instead of a random selection.
6
#'  `training` and `testing` are used to extract the resulting data.
7
#' @details The `strata` argument causes the random sampling to be conducted
8
#'  *within the stratification variable*. This can help ensure that the strata
9
#'  appear in the training data in roughly the same proportions as in the
10
#'  original data set. (Strata below 10% of the total are pooled together.)
11
#' @inheritParams vfold_cv
12
#' @param prop The proportion of data to be retained for modeling/analysis.
13
#' @param strata A variable that is used to conduct stratified sampling to
14
#'  create the resamples. This could be a single character value or a variable
15
#'  name that corresponds to a variable that exists in the data frame.
16
#' @param breaks A single number giving the number of bins desired to stratify
17
#'  a numeric stratification variable.
18
#' @export
19
#' @return An `rsplit` object that can be used with the `training` and `testing`
20
#'  functions to extract the data in each split.
21
#' @examples
22
#' set.seed(1353)
23
#' car_split <- initial_split(mtcars)
24
#' train_data <- training(car_split)
25
#' test_data <- testing(car_split)
26
#'
27
#' data(drinks, package = "modeldata")
28
#' drinks_split <- initial_time_split(drinks)
29
#' train_data <- training(drinks_split)
30
#' test_data <- testing(drinks_split)
31
#' c(max(train_data$date), min(test_data$date))  # no lag
32
#'
33
#' # With 12 period lag
34
#' drinks_lag_split <- initial_time_split(drinks, lag = 12)
35
#' train_data <- training(drinks_lag_split)
36
#' test_data <- testing(drinks_lag_split)
37
#' c(max(train_data$date), min(test_data$date))  # 12 period lag
38
#'
39
#' @export
40
#'
41
initial_split <- function(data, prop = 3/4, strata = NULL, breaks = 4, ...) {
  # Resolve `strata` to column name(s) of `data` only when the caller supplied
  # it; an empty selection is treated the same as no stratification.
  if (!missing(strata)) {
    chosen <- tidyselect::vars_select(names(data), !!enquo(strata))
    strata <- if (length(chosen) > 0) chosen else NULL
  }

  # Ask mc_cv() for exactly one resample and hand back its first split.
  mc_cv(
    data = data,
    prop = prop,
    strata = strata,
    breaks = breaks,
    times = 1,
    ...
  )$splits[[1]]
}
61

62
#' @rdname initial_split
63
#' @param lag A value to include a lag between the assessment
64
#'  and analysis set. This is useful if lagged predictors will be used
65
#'  during training and testing.
66
#' @export
67
initial_time_split <- function(data, prop = 3/4, lag = 0, ...) {
  # Validate with short-circuiting `||` (not elementwise `|`) so that a
  # non-numeric, non-scalar or missing value is reported with the intended
  # message instead of failing inside the comparison itself.
  if (!is.numeric(prop) || length(prop) != 1L || is.na(prop) ||
      prop >= 1 || prop <= 0) {
    rlang::abort("`prop` must be a number on (0, 1).")
  }

  # `lag %% 1` is only evaluated once `lag` is known to be a finite numeric
  # scalar; otherwise it would error (character) or yield NA (NA/Inf).
  if (!is.numeric(lag) || length(lag) != 1L || !is.finite(lag) ||
      lag %% 1 != 0) {
    stop("`lag` must be a whole number.", call. = FALSE)
  }

  # The first `n_train` rows of `data` make up the training set.
  n_train <- floor(nrow(data) * prop)

  if (lag > n_train) {
    stop("`lag` must be less than or equal to the number of training observations.", call. = FALSE)
  }

  # The second index set starts `lag` rows before the end of the first one, so
  # the last `lag` training rows are also included in it.
  rsplit(data, 1:n_train, (n_train + 1 - lag):nrow(data))
}
85

86
#' @rdname initial_split
87
#' @export
88
#' @param x An `rsplit` object produced by `initial_split`
89 1
training <- function(x) {
  # Delegates to analysis() on the given object.
  analysis(x)
}
90
#' @rdname initial_split
91
#' @export
92 1
testing <- function(x) {
  # Delegates to assessment() on the given object.
  assessment(x)
}

Read our documentation on viewing source code.

Loading