ModelOriented / DALEX
Showing 2 of 6 files from the diff.

@@ -47,3 +47,12 @@
Loading
47 47
48 48
  TRUE
49 49
}
50 +
51 +
52 +
cut_data_to_n <- function(data, N) {
53 +
  if (!is.null(N) && N < nrow(data)) {
54 +
    ndata <- data[sample(1:nrow(data), N),]
55 +
  } else {
56 +
    ndata <- data
57 +
  }
58 +
}

@@ -14,7 +14,7 @@
Loading
14 14
#' @param ... other parameters that will be passed to \code{iBreakDown::break_down}
15 15
#' @param variable_splits named list of splits for variables. It is used by oscillations based measures. Will be passed to \code{\link[ingredients]{ceteris_paribus}}.
16 16
#' @param variables names of variables for which splits shall be calculated. Will be passed to \code{\link[ingredients]{ceteris_paribus}}.
17 -
#' @param N number of observations used for calculation of oscillations. By default 500.
17 +
#' @param N number of observations used for calculations. By default all observations are taken with an exception of \code{oscillations_emp} type when 500 is taken as default.
18 18
#' @param variable_splits_type how variable grids shall be calculated? Will be passed to \code{\link[ingredients]{ceteris_paribus}}.
19 19
#' @param type the type of variable attributions. Either \code{shap}, \code{oscillations}, \code{oscillations_uni},
20 20
#' \code{oscillations_emp}, \code{break_down} or \code{break_down_interactions}.
@@ -69,24 +69,30 @@
Loading
69 69
#'
70 70
#' @name predict_parts
71 71
#' @export
72 -
predict_parts <- function(explainer, new_observation, ..., type = "break_down") {
72 +
predict_parts <- function(explainer, new_observation, N = NULL, ..., type = "break_down") {
73 +
74 +
  # Sample the data according to N
75 +
73 76
  switch (type,
74 -
          "break_down"              = predict_parts_break_down(explainer, new_observation, ...),
75 -
          "break_down_interactions" = predict_parts_break_down_interactions(explainer, new_observation, ...),
76 -
          "shap"                    = predict_parts_shap(explainer, new_observation, ...),
77 -
          "oscillations"            = predict_parts_oscillations(explainer, new_observation, ...),
78 -
          "oscillations_uni"        = predict_parts_oscillations_uni(explainer, new_observation, ...),
79 -
          "oscillations_emp"        = predict_parts_oscillations_emp(explainer, new_observation, ...),
77 +
          "break_down"              = predict_parts_break_down(explainer, new_observation, N = N, ...),
78 +
          "break_down_interactions" = predict_parts_break_down_interactions(explainer, new_observation, N = N, ...),
79 +
          "shap"                    = predict_parts_shap(explainer, new_observation, N = N, ...),
80 +
          "oscillations"            = predict_parts_oscillations(explainer, new_observation, N = N, ...),
81 +
          "oscillations_uni"        = predict_parts_oscillations_uni(explainer, new_observation, N = N, ...),
82 +
          "oscillations_emp"        = predict_parts_oscillations_emp(explainer, new_observation, N = N, ...),
80 83
          stop("The type argument shall be either 'shap' or 'break_down' or 'break_down_interactions' or 'oscillations' or 'oscillations_uni' or 'oscillations_emp'")
81 84
  )
82 85
}
83 86
84 87
#' @name predict_parts
85 88
#' @export
86 -
predict_parts_oscillations <- function(explainer, new_observation, ...) {
89 +
predict_parts_oscillations <- function(explainer, new_observation, N = NULL, ...) {
87 90
  # run checks against the explainer objects
88 91
  test_explainer(explainer, has_data = TRUE, function_name = "predict_parts_oscillations")
89 92
93 +
  # Cut data according to N
94 +
  explainer$data <- cut_data_to_n(explainer$data, N)
95 +
90 96
  # call the ceteris_paribus
91 97
  cp <- ingredients::ceteris_paribus(explainer,
92 98
                                     new_observation = new_observation,
@@ -98,10 +104,13 @@
Loading
98 104
99 105
#' @name predict_parts
100 106
#' @export
101 -
predict_parts_oscillations_uni <- function(explainer, new_observation, variable_splits_type = "uniform", ...) {
107 +
predict_parts_oscillations_uni <- function(explainer, new_observation, variable_splits_type = "uniform", N = NULL, ...) {
102 108
  # run checks against the explainer objects
103 109
  test_explainer(explainer, has_data = TRUE, function_name = "predict_parts_oscillations_uni")
104 110
111 +
  # Cut data according to N
112 +
  explainer$data <- cut_data_to_n(explainer$data, N)
113 +
105 114
  # call the ceteris_paribus
106 115
  cp <- ingredients::ceteris_paribus(explainer,
107 116
                                     new_observation = new_observation,
@@ -118,11 +127,18 @@
Loading
118 127
  # run checks against the explainer objects
119 128
  test_explainer(explainer, has_data = TRUE, function_name = "predict_parts_oscillations_emp")
120 129
  variables <- intersect(variables, colnames(new_observation))
121 -
  N <- min(N, nrow(explainer$data))
122 -
  data_sample <- explainer$data[sample(1:nrow(explainer$data), N),]
130 +
131 +
132 +
  if (is.null(N)) {
133 +
    # Default value, set 500
134 +
    N <- 500
135 +
  }
136 +
137 +
  # Cut data according to N
138 +
  explainer$data <- cut_data_to_n(explainer$data, N)
123 139
124 140
  variable_splits <- lapply(variables, function(var) {
125 -
    data_sample[,var]
141 +
    explainer$data[,var]
126 142
  })
127 143
  names(variable_splits) <- variables
128 144
@@ -139,10 +155,13 @@
Loading
139 155
140 156
#' @name predict_parts
141 157
#' @export
142 -
predict_parts_break_down <- function(explainer, new_observation, ...) {
158 +
predict_parts_break_down <- function(explainer, new_observation, N = NULL, ...) {
143 159
  # run checks against the explainer objects
144 160
  test_explainer(explainer, has_data = TRUE, function_name = "predict_parts_break_down")
145 161
162 +
  # Cut data according to N
163 +
  explainer$data <- cut_data_to_n(explainer$data, N)
164 +
146 165
  # call the break_down
147 166
  res <- iBreakDown::break_down(explainer,
148 167
                                new_observation = new_observation,
@@ -153,10 +172,13 @@
Loading
153 172
154 173
#' @name predict_parts
155 174
#' @export
156 -
predict_parts_break_down_interactions <- function(explainer, new_observation, ...) {
175 +
predict_parts_break_down_interactions <- function(explainer, new_observation, N = NULL, ...) {
157 176
  # run checks against the explainer objects
158 177
  test_explainer(explainer, has_data = TRUE, function_name = "predict_parts_break_down_interactions")
159 178
179 +
  # Cut data according to N
180 +
  explainer$data <- cut_data_to_n(explainer$data, N)
181 +
160 182
  # call the break_down
161 183
  res <- iBreakDown::break_down(explainer,
162 184
                                new_observation = new_observation,
@@ -168,10 +190,13 @@
Loading
168 190
169 191
#' @name predict_parts
170 192
#' @export
171 -
predict_parts_shap <- function(explainer, new_observation, ...) {
193 +
predict_parts_shap <- function(explainer, new_observation, N = NULL, ...) {
172 194
  # run checks against the explainer objects
173 195
  test_explainer(explainer, has_data = TRUE, function_name = "predict_parts_shap")
174 196
197 +
  # Cut data according to N
198 +
  explainer$data <- cut_data_to_n(explainer$data, N)
199 +
175 200
  # call the shap from iBreakDown
176 201
  res <- iBreakDown::shap(explainer,
177 202
                          new_observation = new_observation,
Files Coverage
R 86.89%
Project Totals (31 files) 86.89%
Notifications are pending CI completion. Waiting for GitHub's status webhook to queue notifications. Push notifications now.
1
comment: false
2

3
coverage:
4
  status:
5
    project:
6
      default:
7
        target: auto
8
        threshold: 1%
9
        informational: true
10
    patch:
11
      default:
12
        target: auto
13
        threshold: 1%
14
        informational: true
Sunburst
The inner-most circle is the entire project, moving away from the center are folders then, finally, a single file. The size and color of each slice is representing the number of statements and the coverage, respectively.
Icicle
The top section represents the entire project. Proceeding with folders and finally individual files. The size and color of each slice is representing the number of statements and the coverage, respectively.
Grid
Each block represents a single file in the project. The size and color of each block is represented by the number of statements and the coverage, respectively.
Loading