mlr-org / mlr3spatiotempcv

@@ -469,6 +469,364 @@
Loading
469 469
  print(autoplot(x, ...)) # nocov
470 470
}
471 471
472 +
# SpCVDisc ---------------------------------------------------------------------
473 +
474 +
#' @title Visualization Functions for SpCV Disc Method.
475 +
#'
476 +
#' @description Generic S3 `plot()` and `autoplot()` (ggplot2) methods to
477 +
#'   visualize mlr3 spatiotemporal resampling objects.
478 +
#'
479 +
#' @importFrom stats na.omit
480 +
#'
481 +
#' @details
482 +
#' This method requires to set argument `fold_id` and no plot containing all
483 +
#' partitions can be created. This is because the method does not make use of
484 +
#' all observations but only a subset of them (many observations are left out).
485 +
#' Hence, train and test sets of one fold are not re-used in other folds as in
486 +
#' other methods and plotting these without a train/test indicator would not
487 +
#' make sense.
488 +
#'
489 +
#' @section 2D vs 3D plotting:
490 +
#' This method has both a 2D and a 3D plotting method.
491 +
#' The 2D method returns a \pkg{ggplot} with x and y axes representing the spatial
492 +
#' coordinates.
493 +
#' The 3D method uses \pkg{plotly} to create an interactive 3D plot.
494 +
#' Set `plot3D = TRUE` to use the 3D method.
495 +
#'
496 +
#' Note that spatiotemporal datasets usually suffer from overplotting in 2D
497 +
#' mode.
498 +
#'
499 +
#' @name autoplot.ResamplingSpCVDisc
500 +
#' @inheritParams autoplot.ResamplingSpCVBlock
501 +
#'
502 +
#' @param show_omitted `[logical]`\cr
503 +
#'   Whether to show points not used in train or test set for the current fold.
504 +
#' @export
505 +
#' @seealso
506 +
#'   - mlr3book chapter on on ["Spatiotemporal Visualization"](https://mlr3book.mlr-org.com/spatiotemporal.html#vis-spt-partitions).
507 +
#'   - Vignette [Spatiotemporal Visualization](https://mlr3spatiotempcv.mlr-org.com/articles/spatiotemp-viz.html).
508 +
#'   - [autoplot.ResamplingSpCVBlock()]
509 +
#'   - [autoplot.ResamplingSpCVBuffer()]
510 +
#'   - [autoplot.ResamplingSpCVCoords()]
511 +
#'   - [autoplot.ResamplingSpCVEnv()]
512 +
#'   - [autoplot.ResamplingCV()]
513 +
#'   - [autoplot.ResamplingSptCVCluto()]
514 +
#' @examples
515 +
#' \donttest{
516 +
#' if (mlr3misc::require_namespaces("sf", quietly = TRUE)) {
517 +
#'   library(mlr3)
518 +
#'   library(mlr3spatiotempcv)
519 +
#'   task = tsk("ecuador")
520 +
#'   resampling = rsmp("spcv_disc",
521 +
#'     folds = 5, radius = 200L, buffer = 200L)
522 +
#'   resampling$instantiate(task)
523 +
#'
524 +
#'   autoplot(resampling, task,
525 +
#'     fold_id = 1, crs = 4326,
526 +
#'     show_omitted = TRUE, size = 0.7) *
527 +
#'     ggplot2::scale_x_continuous(breaks = seq(-79.085, -79.055, 0.01))
528 +
#' }
529 +
#' }
530 +
autoplot.ResamplingSpCVDisc = function( # nolint
531 +
  object,
532 +
  task,
533 +
  fold_id = NULL,
534 +
  plot_as_grid = TRUE,
535 +
  train_color = "#0072B5",
536 +
  test_color = "#E18727",
537 +
  crs = NULL,
538 +
  repeats_id = NULL,
539 +
  show_omitted = FALSE,
540 +
  ...) {
541 +
542 +
  resampling = object
543 +
  coords = task$coordinates()
544 +
  coords$row_id = task$row_ids
545 +
  mlr3misc::require_namespaces(c("sf", "patchwork", "ggtext"))
546 +
547 +
  # set fallback crs if missing
548 +
  if (is.null(crs)) {
549 +
    # use 4326 (WGS84) as fallback
550 +
    crs = 4326
551 +
    messagef("CRS not set, transforming to WGS84 (EPSG: 4326).")
552 +
  }
553 +
554 +
  resampling = assert_autoplot(resampling, fold_id, task)
555 +
556 +
  if (is.null(repeats_id)) {
557 +
    repeats_id = 1
558 +
  } else {
559 +
    repeats_id = repeats_id
560 +
    # otherwise it gets passed to geom_sf() down below
561 +
    # dots$repeats_id = NULL
562 +
  }
563 +
564 +
  resampling_sub = resampling$clone()
565 +
566 +
  if (grepl("Repeated", class(resampling)[1])) {
567 +
    resampling_sub$instance = resampling_sub$instance[[repeats_id]]
568 +
  }
569 +
570 +
  if (!is.null(fold_id)) {
571 +
572 +
    if (length(fold_id) == 1) {
573 +
      ### only one fold
574 +
575 +
      data_coords = prepare_autoplot_cstf(task, resampling_sub)
576 +
577 +
      # suppress undefined global variables note
578 +
      data_coords$indicator = ""
579 +
580 +
      row_id_test = resampling_sub$instance$test[[fold_id]]
581 +
      row_id_train = resampling_sub$instance$train[[fold_id]]
582 +
583 +
      data_coords[row_id %in% row_id_test, indicator := "Test"]
584 +
      data_coords[row_id %in% row_id_train, indicator := "Train"]
585 +
586 +
      # should omitted points be shown?
587 +
      if (show_omitted && nrow(data_coords[indicator == ""]) > 0) {
588 +
        data_coords[indicator == "", indicator := "Omitted"]
589 +
590 +
        sf_df = sf::st_transform(
591 +
          sf::st_as_sf(data_coords,
592 +
            coords = task$extra_args$coordinate_names,
593 +
            crs = task$extra_args$crs),
594 +
          crs = crs)
595 +
        sf_df = reorder_levels(sf_df)
596 +
597 +
        ggplot() +
598 +
          geom_sf(data = sf_df, aes(color = indicator), ...) +
599 +
          scale_color_manual(values = c(
600 +
            "Omitted" = "grey",
601 +
            "Train" = "#0072B5",
602 +
            "Test" = "#E18727"
603 +
          )) +
604 +
          labs(color = "Set", title = sprintf(
605 +
            "Fold %s, Repetition %s", fold_id,
606 +
            repeats_id)) +
607 +
          theme(plot.title = ggtext::element_textbox(
608 +
            size = 10,
609 +
            color = "black", fill = "#ebebeb", box.color = "black",
610 +
            height = unit(0.33, "inch"), width = unit(1, "npc"),
611 +
            linetype = 1, r = unit(5, "pt"),
612 +
            valign = 0.5, halign = 0.5,
613 +
            padding = margin(2, 2, 2, 2), margin = margin(3, 3, 3, 3))
614 +
          )
615 +
616 +
      } else {
617 +
        data_coords = data_coords[indicator != ""]
618 +
619 +
        sf_df = sf::st_transform(
620 +
          sf::st_as_sf(data_coords,
621 +
            coords = task$extra_args$coordinate_names,
622 +
            crs = task$extra_args$crs),
623 +
          crs = crs)
624 +
        sf_df = reorder_levels(sf_df)
625 +
626 +
        ggplot() +
627 +
          geom_sf(data = sf_df, aes(color = indicator), ...) +
628 +
          scale_color_manual(values = c(
629 +
            "Train" = "#0072B5",
630 +
            "Test" = "#E18727"
631 +
          )) +
632 +
          labs(color = "Set", title = sprintf(
633 +
            "Fold %s, Repetition %s", fold_id,
634 +
            repeats_id)) +
635 +
          theme(plot.title = ggtext::element_textbox(
636 +
            size = 10,
637 +
            color = "black", fill = "#ebebeb", box.color = "black",
638 +
            height = unit(0.33, "inch"), width = unit(1, "npc"),
639 +
            linetype = 1, r = unit(5, "pt"),
640 +
            valign = 0.5, halign = 0.5,
641 +
            padding = margin(2, 2, 2, 2), margin = margin(3, 3, 3, 3))
642 +
          )
643 +
      }
644 +
    }
645 +
    else {
646 +
      ### Multiplot of multiple partitions with train and test set
647 +
648 +
      # FIXME: redundant code - call function from single plots?
649 +
      plot_list = mlr3misc::map(fold_id, function(.x) {
650 +
651 +
        data_coords = prepare_autoplot_cstf(task, resampling_sub)
652 +
653 +
        # suppress undefined global variables note
654 +
        data_coords$indicator = ""
655 +
656 +
        row_id_test = resampling_sub$instance$test[[.x]]
657 +
        row_id_train = resampling_sub$instance$train[[.x]]
658 +
659 +
        data_coords[row_id %in% row_id_test, indicator := "Test"]
660 +
        data_coords[row_id %in% row_id_train, indicator := "Train"]
661 +
662 +
        # should omitted points be shown?
663 +
        if (show_omitted && nrow(data_coords[indicator == ""]) > 0) {
664 +
          data_coords[indicator == "", indicator := "Omitted"]
665 +
666 +
          sf_df = sf::st_transform(
667 +
            sf::st_as_sf(data_coords,
668 +
              coords = task$extra_args$coordinate_names,
669 +
              crs = task$extra_args$crs),
670 +
            crs = crs)
671 +
          sf_df = reorder_levels(sf_df)
672 +
673 +
          ggplot() +
674 +
            geom_sf(data = sf_df, aes(color = indicator), ...) +
675 +
            scale_color_manual(values = c(
676 +
              "Omitted" = "grey",
677 +
              "Train" = "#0072B5",
678 +
              "Test" = "#E18727"
679 +
            )) +
680 +
            labs(color = "Set", title = sprintf(
681 +
              "Fold %s, Repetition %s", .x,
682 +
              repeats_id)) +
683 +
            theme(plot.title = ggtext::element_textbox(
684 +
              size = 10,
685 +
              color = "black", fill = "#ebebeb", box.color = "black",
686 +
              height = unit(0.33, "inch"), width = unit(1, "npc"),
687 +
              linetype = 1, r = unit(5, "pt"),
688 +
              valign = 0.5, halign = 0.5,
689 +
              padding = margin(2, 2, 2, 2), margin = margin(3, 3, 3, 3))
690 +
            )
691 +
692 +
        } else {
693 +
          data_coords = data_coords[indicator != ""]
694 +
695 +
          sf_df = sf::st_transform(
696 +
            sf::st_as_sf(data_coords,
697 +
              coords = task$extra_args$coordinate_names,
698 +
              crs = task$extra_args$crs),
699 +
            crs = crs)
700 +
          sf_df = reorder_levels(sf_df)
701 +
702 +
          ggplot() +
703 +
            geom_sf(data = sf_df, aes(color = indicator), ...) +
704 +
            scale_color_manual(values = c(
705 +
              "Train" = "#0072B5",
706 +
              "Test" = "#E18727"
707 +
            )) +
708 +
            labs(color = "Set", title = sprintf(
709 +
              "Fold %s, Repetition %s", .x,
710 +
              repeats_id)) +
711 +
            theme(plot.title = ggtext::element_textbox(
712 +
              size = 10,
713 +
              color = "black", fill = "#ebebeb", box.color = "black",
714 +
              height = unit(0.33, "inch"), width = unit(1, "npc"),
715 +
              linetype = 1, r = unit(5, "pt"),
716 +
              valign = 0.5, halign = 0.5,
717 +
              padding = margin(2, 2, 2, 2), margin = margin(3, 3, 3, 3))
718 +
            )
719 +
        }
720 +
      })
721 +
722 +
      # Return a plot grid via patchwork?
723 +
724 +
      if (!plot_as_grid) {
725 +
        return(invisible(plot_list))
726 +
      } else {
727 +
        # for repeated cv we also print out the rep number
728 +
        if (is.null(repeats_id)) {
729 +
          repeats_id = 1 # nocov
730 +
        }
731 +
732 +
        plot_list_pw = patchwork::wrap_plots(plot_list) +
733 +
          patchwork::plot_layout(guides = "collect")
734 +
        return(plot_list_pw)
735 +
      }
736 +
    }
737 +
  } else {
738 +
739 +
    ### Create one plot colored by all test folds
740 +
741 +
    # set fallback crs if missing
742 +
    if (is.null(crs)) {
743 +
      # use 4326 (WGS84) as fallback
744 +
      crs = 4326
745 +
      messagef("CRS not set, transforming to WGS84 (EPSG: 4326).")
746 +
    }
747 +
748 +
    data_coords = prepare_autoplot_cstf(task, resampling_sub)
749 +
750 +
    # extract test ids from lists
751 +
    row_ids_test = data.table::rbindlist(
752 +
      lapply(resampling_sub$instance$test, as.data.table),
753 +
      idcol = "fold")
754 +
    setnames(row_ids_test, c("fold", "row_id"))
755 +
756 +
    test_folds = merge(data_coords, row_ids_test, by = "row_id", all = TRUE)
757 +
758 +
    sf_df = sf::st_transform(
759 +
      sf::st_as_sf(test_folds,
760 +
        coords = task$extra_args$coordinate_names,
761 +
        crs = task$extra_args$crs),
762 +
      crs = crs)
763 +
764 +
    # only keep test ids
765 +
    sf_df = stats::na.omit(sf_df, cols = "fold")
766 +
767 +
    # order fold ids
768 +
    sf_df = sf_df[order(sf_df$fold, decreasing = FALSE), ]
769 +
    sf_df$fold = as.factor(as.character(sf_df$fold))
770 +
    sf_df$fold = factor(sf_df$fold, levels = unique(as.character(sf_df$fold)))
771 +
772 +
    # for all non-repeated rsmp cases
773 +
    if (is.null(repeats_id)) {
774 +
      repeats_id = 1 # nocov
775 +
    }
776 +
777 +
    plot = ggplot() +
778 +
      geom_sf(
779 +
        data = sf_df["fold"], show.legend = "point",
780 +
        aes(color = fold)
781 +
      ) +
782 +
      ggsci::scale_color_ucscgb() +
783 +
      labs(color = sprintf("Partition #, Rep %s", repeats_id))
784 +
    return(plot)
785 +
  }
786 +
}
787 +
788 +
#' @rdname autoplot.ResamplingSptCVCstf
789 +
#' @export
790 +
autoplot.ResamplingRepeatedSpCVDisc = function( # nolint
791 +
  object,
792 +
  task,
793 +
  fold_id = NULL,
794 +
  repeats_id = 1,
795 +
  plot_as_grid = TRUE,
796 +
  train_color = "#0072B5",
797 +
  test_color = "#E18727",
798 +
  crs = NULL,
799 +
  show_omitted = FALSE,
800 +
  ...) {
801 +
802 +
  autoplot.ResamplingSpCVDisc(
803 +
    object = object,
804 +
    task = task,
805 +
    fold_id = fold_id,
806 +
    plot_as_grid = plot_as_grid,
807 +
    train_color = train_color,
808 +
    test_color = test_color,
809 +
    crs = crs,
810 +
    show_omitted = show_omitted,
811 +
    ... = ...,
812 +
    # ellipsis
813 +
    repeats_id = repeats_id
814 +
  )
815 +
}
816 +
817 +
#' @importFrom graphics plot
818 +
#' @rdname autoplot.ResamplingSpCVDisc
819 +
#' @export
820 +
plot.ResamplingSpCVDisc = function(x, ...) {
821 +
  print(autoplot(x, ...)) # nocov
822 +
}
823 +
824 +
#' @rdname autoplot.ResamplingSpCVDisc
825 +
#' @export
826 +
plot.ResamplingRepeatedSpCVDisc = function(x, ...) {
827 +
  print(autoplot(x, ...)) # nocov
828 +
}
829 +
472 830
# CV ---------------------------------------------------------------------------
473 831
474 832
#' @title Visualization Functions for Non-Spatial CV Methods.

@@ -27,7 +27,7 @@
Loading
27 27
28 28
  public = list(
29 29
    #' @description
30 -
    #' Create an "Environmental Block" resampling instance.
30 +
    #' Create an "coordinate-based" repeated resampling instance.
31 31
    #' @param id `character(1)`\cr
32 32
    #'   Identifier for the resampling strategy.
33 33
    initialize = function(id = "spcv_coords") {

@@ -0,0 +1,160 @@
Loading
1 +
#' @title Spatial "Disc" resampling with optional buffer zone
2 +
#'
3 +
#' @template rox_spcv_disc
4 +
#'
5 +
#' @references
6 +
#' `r format_bib("brenning2012")`
7 +
#'
8 +
#' @export
9 +
#' @examples
10 +
#' library(mlr3)
11 +
#' task = tsk("ecuador")
12 +
#'
13 +
#' # Instantiate Resampling
14 +
#' rcv = rsmp("spcv_disc", folds = 3L, radius = 200L, buffer = 200L)
15 +
#' rcv$instantiate(task)
16 +
#'
17 +
#' # Individual sets:
18 +
#' rcv$train_set(1)
19 +
#' rcv$test_set(1)
20 +
#' # check that no obs are in both sets
21 +
#' intersect(rcv$train_set(1), rcv$test_set(1)) # good!
22 +
#'
23 +
#' # Internal storage:
24 +
#' rcv$instance # table
25 +
ResamplingSpCVDisc = R6Class("ResamplingSpCVDisc",
26 +
  inherit = mlr3::Resampling,
27 +
28 +
  public = list(
29 +
    #' @description
30 +
    #' Create a "Spatial 'Disc' resampling" resampling instance.
31 +
    #' @param id `character(1)`\cr
32 +
    #'   Identifier for the resampling strategy.
33 +
    initialize = function(id = "spcv_disc") {
34 +
      ps = ParamSet$new(params = list(
35 +
        ParamInt$new("folds", lower = 1L, default = 10L, tags = "required"),
36 +
        ParamInt$new("radius",
37 +
          lower = 0L, tags = "required",
38 +
          special_vals = list(0L)),
39 +
        ParamInt$new("buffer",
40 +
          lower = 0L, default = NULL,
41 +
          special_vals = list(NULL)),
42 +
        ParamUty$new("prob",
43 +
          default = NULL),
44 +
        ParamLgl$new("replace", default = FALSE)
45 +
      ))
46 +
      ps$values = list(folds = 10L)
47 +
      super$initialize(
48 +
        id = id,
49 +
        param_set = ps
50 +
      )
51 +
    },
52 +
53 +
    #' @description
54 +
    #'  Materializes fixed training and test splits for a given task.
55 +
    #' @param task [Task]\cr
56 +
    #'  A task to instantiate.
57 +
    instantiate = function(task) {
58 +
59 +
      mlr3::assert_task(task)
60 +
      checkmate::assert_multi_class(task, c("TaskClassifST", "TaskRegrST"))
61 +
      groups = task$groups
62 +
63 +
      # Set values to default if missing
64 +
      mlr3misc::map(
65 +
        c("buffer", "radius", "prob", "replace"),
66 +
        function(x) private$.set_default_param_values(x)
67 +
      )
68 +
69 +
      if (!is.null(groups)) {
70 +
        stopf("Grouping is not supported for spatial resampling methods")
71 +
      }
72 +
73 +
      private$.sample(task$row_ids, task$coordinates())
74 +
75 +
      self$task_hash = task$hash
76 +
      self$task_nrow = task$nrow
77 +
      invisible(self)
78 +
    }
79 +
  ),
80 +
81 +
  active = list(
82 +
    #' @field iters `integer(1)`\cr
83 +
    #'   Returns the number of resampling iterations, depending on the
84 +
    #'   values stored in the `param_set`.
85 +
    iters = function() {
86 +
      self$param_set$values$folds
87 +
    }
88 +
  ),
89 +
90 +
  private = list(
91 +
    .sample = function(ids, coords) {
92 +
93 +
      index = sample.int(nrow(coords),
94 +
        size = self$param_set$values$folds,
95 +
        replace = self$param_set$values$replace,
96 +
        prob = self$param_set$values$prob
97 +
      )
98 +
99 +
      # we need to set a custom index starting at 1 because the index from
100 +
      # sperrorest does not start at 1 and does not increase by 1
101 +
      # the index is required for assigning the train/test sets to their
102 +
      # respective folds
103 +
      mlr3_index = 1
104 +
105 +
      for (i in index) {
106 +
107 +
        if (!is.null(self$param_set$values$buffer) |
108 +
          self$param_set$values$radius >= 0) {
109 +
          di = sqrt((coords[[1]] - as.numeric(coords[i, 1]))^2 + # nolint
110 +
            (coords[[2]] - as.numeric(coords[i, 2]))^2) # nolint
111 +
        }
112 +
        train_sel = numeric()
113 +
        if (self$param_set$values$radius >= 0) {
114 +
          # leave-disc-out with buffer:
115 +
          test_sel = which(di <= self$param_set$values$radius)
116 +
          train_sel <- which(di > (self$param_set$values$radius + self$param_set$values$buffer))
117 +
        } else {
118 +
          # leave-one-out with buffer:
119 +
          test_sel = i
120 +
          if (is.null(self$param_set$values$buffer)) {
121 +
            train_sel = seq_len(nrow(coords))[-i] # nocov
122 +
          } else {
123 +
            train_sel = which(di > self$param_set$values$buffer)
124 +
          }
125 +
        }
126 +
        if (length(train_sel) == 0) {
127 +
          warningf(
128 +
            "Empty training set in 'partition_disc': 'buffer'
129 +
            and/or 'radius' too large?",
130 +
            wrap = TRUE
131 +
          )
132 +
        }
133 +
134 +
        # similar result structure as in sptcv_cstf
135 +
        self$instance$test[[mlr3_index]] = test_sel
136 +
        self$instance$train[[mlr3_index]] = train_sel
137 +
138 +
        mlr3_index = mlr3_index + 1
139 +
      }
140 +
141 +
      invisible(self)
142 +
    },
143 +
144 +
    .set_default_param_values = function(param) {
145 +
      if (is.null(self$param_set$values[[param]])) {
146 +
        self$param_set$values[[param]] = self$param_set$default[[param]]
147 +
      }
148 +
    },
149 +
150 +
    # private get funs for train and test which are used by
151 +
    # Resampling$.get_set()
152 +
    .get_train = function(i) {
153 +
      self$instance$train[[i]]
154 +
    },
155 +
156 +
    .get_test = function(i) {
157 +
      self$instance$test[[i]]
158 +
    }
159 +
  )
160 +
)

@@ -0,0 +1,197 @@
Loading
1 +
#' @title Repeated Spatial "Disc" resampling with optional buffer zone
2 +
#'
3 +
#' @references
4 +
#' `r format_bib("brenning2012")`
5 +
#'
6 +
#' @export
7 +
#' @examples
8 +
#' library(mlr3)
9 +
#' task = tsk("ecuador")
10 +
#'
11 +
#' # Instantiate Resampling
12 +
#' rrcv = rsmp("repeated_spcv_disc",
13 +
#'   folds = 3L, repeats = 2,
14 +
#'   radius = 200L, buffer = 200L)
15 +
#' rrcv$instantiate(task)
16 +
#'
17 +
#' # Individual sets:
18 +
#' rrcv$iters
19 +
#' rrcv$folds(1:6)
20 +
#' rrcv$repeats(1:6)
21 +
#'
22 +
#' # Individual sets:
23 +
#' rrcv$train_set(1)
24 +
#' rrcv$test_set(1)
25 +
#' intersect(rrcv$train_set(1), rrcv$test_set(1))
26 +
#'
27 +
#' # Internal storage:
28 +
#' rrcv$instance # table
29 +
ResamplingRepeatedSpCVDisc = R6Class("ResamplingRepeatedSpCVDisc",
30 +
  inherit = mlr3::Resampling,
31 +
32 +
  public = list(
33 +
    #' @description
34 +
    #' Create a "Spatial 'Disc' resampling" resampling instance.
35 +
    #' @param id `character(1)`\cr
36 +
    #'   Identifier for the resampling strategy.
37 +
    initialize = function(id = "repeated_spcv_disc") {
38 +
      ps = ParamSet$new(params = list(
39 +
        ParamInt$new("folds", lower = 1L, default = 10L, tags = "required"),
40 +
        ParamInt$new("radius",
41 +
          lower = 0L, tags = "required",
42 +
          special_vals = list(0L)),
43 +
        ParamInt$new("buffer",
44 +
          lower = 0L, default = NULL,
45 +
          special_vals = list(NULL)),
46 +
        ParamUty$new("prob",
47 +
          default = NULL),
48 +
        ParamLgl$new("replace", default = FALSE),
49 +
        ParamInt$new("repeats", lower = 1, default = 1L, tags = "required")
50 +
      ))
51 +
      ps$values = list(folds = 10L, repeats = 1)
52 +
      super$initialize(
53 +
        id = id,
54 +
        param_set = ps,
55 +
        man = "mlr3spatiotempcv::mlr_resamplings_repeated_spcv_disc"
56 +
      )
57 +
58 +
    },
59 +
60 +
    #' @description Translates iteration numbers to fold number.
61 +
    #' @param iters `integer()`\cr
62 +
    #'   Iteration number.
63 +
    folds = function(iters) {
64 +
      iters = assert_integerish(iters, any.missing = FALSE, coerce = TRUE)
65 +
      ((iters - 1L) %% as.integer(self$param_set$values$repeats)) + 1L
66 +
    },
67 +
68 +
    #' @description Translates iteration numbers to repetition number.
69 +
    #' @param iters `integer()`\cr
70 +
    #'   Iteration number.
71 +
    repeats = function(iters) {
72 +
      iters = assert_integerish(iters, any.missing = FALSE, coerce = TRUE)
73 +
      ((iters - 1L) %/% as.integer(self$param_set$values$folds)) + 1L
74 +
    },
75 +
76 +
    #' @description
77 +
    #'  Materializes fixed training and test splits for a given task.
78 +
    #' @param task [Task]\cr
79 +
    #'  A task to instantiate.
80 +
    instantiate = function(task) {
81 +
82 +
      mlr3::assert_task(task)
83 +
      checkmate::assert_multi_class(task, c("TaskClassifST", "TaskRegrST"))
84 +
      groups = task$groups
85 +
      if (!is.null(groups)) {
86 +
        stopf("Grouping is not supported for spatial resampling methods.") # nocov
87 +
      }
88 +
89 +
      # Set values to default if missing
90 +
      mlr3misc::map(
91 +
        c("buffer", "radius", "prob", "replace"),
92 +
        function(x) private$.set_default_param_values(x)
93 +
      )
94 +
95 +
      private$.sample(task$row_ids, task$coordinates())
96 +
97 +
      self$task_hash = task$hash
98 +
      self$task_nrow = task$nrow
99 +
      invisible(self)
100 +
    }
101 +
  ),
102 +
103 +
  active = list(
104 +
105 +
    #' @field iters `integer(1)`\cr
106 +
    #'   Returns the number of resampling iterations, depending on the
107 +
    #'   values stored in the `param_set`.
108 +
    iters = function() {
109 +
      pv = self$param_set$values
110 +
      as.integer(pv$repeats) * as.integer(pv$folds)
111 +
    }
112 +
  ),
113 +
114 +
  private = list(
115 +
    .sample = function(ids, coords) {
116 +
      reps = self$param_set$values$repeats
117 +
      # declare empty list so the for-loop can write to its fields
118 +
      self$instance = vector("list", length = reps)
119 +
120 +
      # k = self$param_set$values$folds
121 +
122 +
      for (rep in seq_len(reps)) {
123 +
124 +
        index = sample.int(nrow(coords),
125 +
          size = self$param_set$values$folds,
126 +
          replace = self$param_set$values$replace,
127 +
          prob = self$param_set$values$prob
128 +
        )
129 +
130 +
        # we need to set a custom index starting at 1 because the index from
131 +
        # sperrorest does not start at 1 and does not increase by 1
132 +
        # the index is required for assigning the train/test sets to their
133 +
        # respective folds
134 +
        mlr3_index = 1
135 +
136 +
        for (i in index) {
137 +
138 +
          if (!is.null(self$param_set$values$buffer) |
139 +
            self$param_set$values$radius >= 0) {
140 +
            di = sqrt((coords[[1]] - as.numeric(coords[i, 1]))^2 + # nolint
141 +
              (coords[[2]] - as.numeric(coords[i, 2]))^2) # nolint
142 +
          }
143 +
          train_sel = numeric()
144 +
          if (self$param_set$values$radius >= 0) {
145 +
            # leave-disc-out with buffer:
146 +
            test_sel = which(di <= self$param_set$values$radius)
147 +
            train_sel <- which(di > (self$param_set$values$radius + self$param_set$values$buffer))
148 +
          } else {
149 +
            # leave-one-out with buffer:
150 +
            test_sel = i
151 +
            if (is.null(self$param_set$values$buffer)) {
152 +
              train_sel = seq_len(nrow(coords))[-i] # nocov
153 +
            } else {
154 +
              train_sel = which(di > self$param_set$values$buffer)
155 +
            }
156 +
          }
157 +
          if (length(train_sel) == 0) {
158 +
            warningf(
159 +
              "Empty training set in 'partition_disc': 'buffer'
160 +
            and/or 'radius' too large?",
161 +
              wrap = TRUE
162 +
            )
163 +
          }
164 +
165 +
          # similar result structure as in sptcv_cstf
166 +
          self$instance[[rep]]$test[[mlr3_index]] = test_sel
167 +
          self$instance[[rep]]$train[[mlr3_index]] = train_sel
168 +
169 +
          mlr3_index = mlr3_index + 1
170 +
        }
171 +
        invisible(self)
172 +
      }
173 +
    },
174 +
175 +
    .set_default_param_values = function(param) {
176 +
      if (is.null(self$param_set$values[[param]])) {
177 +
        self$param_set$values[[param]] = self$param_set$default[[param]]
178 +
      }
179 +
    },
180 +
181 +
    .get_train = function(i) {
182 +
      i = as.integer(i) - 1L
183 +
      folds = as.integer(self$param_set$values$folds)
184 +
      rep = i %/% folds + 1L
185 +
      fold = i %% folds + 1L
186 +
      self$instance[[rep]]$train[[fold]]
187 +
    },
188 +
189 +
    .get_test = function(i) {
190 +
      i = as.integer(i) - 1L
191 +
      folds = as.integer(self$param_set$values$folds)
192 +
      rep = i %/% folds + 1L
193 +
      fold = i %% folds + 1L
194 +
      self$instance[[rep]]$test[[fold]]
195 +
    }
196 +
  )
197 +
)

@@ -43,7 +43,7 @@
Loading
43 43
      super$initialize(
44 44
        id = id,
45 45
        param_set = ps,
46 -
        man = "mlr3spatiotempcv::mlr_resamplings_repeated_spcvcoords"
46 +
        man = "mlr3spatiotempcv::mlr_resamplings_repeated_spcv_coords"
47 47
      )
48 48
49 49
    },

@@ -54,12 +54,16 @@
Loading
54 54
55 55
  # 'fold' needs to be a factor, otherwise `show.legend = "points" has no
56 56
  # effect
57 -
58 57
  object$indicator = as.factor(as.character(object$indicator))
59 58
60 -
  # reorder factor levels so that "train" comes first
61 -
  object$indicator = ordered(object$indicator, levels = c("Train", "Test"))
59 +
  if ("Omitted" %in% levels(object$indicator)) {
60 +
    # reorder factor levels so that "train" comes first
61 +
    object$indicator = ordered(object$indicator,
62 +
      levels = c("Train", "Test", "Omitted"))
63 +
  } else {
62 64
65 +
    # reorder factor levels so that "train" comes first
66 +
    object$indicator = ordered(object$indicator, levels = c("Train", "Test"))
67 +
  }
63 68
  return(object)
64 -
65 69
}
Files Coverage
R 95.48%
Project Totals (24 files) 95.48%
Sunburst
The inner-most circle is the entire project, moving away from the center are folders then, finally, a single file. The size and color of each slice is representing the number of statements and the coverage, respectively.
Icicle
The top section represents the entire project. Proceeding with folders and finally individual files. The size and color of each slice is representing the number of statements and the coverage, respectively.
Grid
Each block represents a single file in the project. The size and color of each block is represented by the number of statements and the coverage, respectively.
Loading