ModelOriented / DALEX

Compare 0bd9bec ... +2 ... a79737f

Showing 2 of 6 files from the diff.

@@ -47,3 +47,12 @@
Loading
47 47
48 48
  TRUE
49 49
}
50 +
51 +
52 +
cut_data_to_n <- function(data, N) {
53 +
  if (!is.null(N) && N < nrow(data)) {
54 +
    ndata <- data[sample(1:nrow(data), N),]
55 +
  } else {
56 +
    ndata <- data
57 +
  }
58 +
}

@@ -14,7 +14,7 @@
Loading
14 14
#' @param ... other parameters that will be passed to \code{iBreakDown::break_down}
15 15
#' @param variable_splits named list of splits for variables. It is used by oscillations based measures. Will be passed to \code{\link[ingredients]{ceteris_paribus}}.
16 16
#' @param variables names of variables for which splits shall be calculated. Will be passed to \code{\link[ingredients]{ceteris_paribus}}.
17 -
#' @param N number of observations used for calculation of oscillations. By default 500.
17 +
#' @param N number of observations used for calculations. By default all observations are taken with an exception of \code{oscillations_emp} type when 500 is taken as default.
18 18
#' @param variable_splits_type how variable grids shall be calculated? Will be passed to \code{\link[ingredients]{ceteris_paribus}}.
19 19
#' @param type the type of variable attributions. Either \code{shap}, \code{oscillations}, \code{oscillations_uni},
20 20
#' \code{oscillations_emp}, \code{break_down} or \code{break_down_interactions}.
@@ -69,24 +69,30 @@
Loading
69 69
#'
70 70
#' @name predict_parts
71 71
#' @export
72 -
predict_parts <- function(explainer, new_observation, ..., type = "break_down") {
72 +
predict_parts <- function(explainer, new_observation, N = NULL, ..., type = "break_down") {
73 +
74 +
  # Sample the data according to N
75 +
73 76
  switch (type,
74 -
          "break_down"              = predict_parts_break_down(explainer, new_observation, ...),
75 -
          "break_down_interactions" = predict_parts_break_down_interactions(explainer, new_observation, ...),
76 -
          "shap"                    = predict_parts_shap(explainer, new_observation, ...),
77 -
          "oscillations"            = predict_parts_oscillations(explainer, new_observation, ...),
78 -
          "oscillations_uni"        = predict_parts_oscillations_uni(explainer, new_observation, ...),
79 -
          "oscillations_emp"        = predict_parts_oscillations_emp(explainer, new_observation, ...),
77 +
          "break_down"              = predict_parts_break_down(explainer, new_observation, N = N, ...),
78 +
          "break_down_interactions" = predict_parts_break_down_interactions(explainer, new_observation, N = N, ...),
79 +
          "shap"                    = predict_parts_shap(explainer, new_observation, N = N, ...),
80 +
          "oscillations"            = predict_parts_oscillations(explainer, new_observation, N = N, ...),
81 +
          "oscillations_uni"        = predict_parts_oscillations_uni(explainer, new_observation, N = N, ...),
82 +
          "oscillations_emp"        = predict_parts_oscillations_emp(explainer, new_observation, N = N, ...),
80 83
          stop("The type argument shall be either 'shap' or 'break_down' or 'break_down_interactions' or 'oscillations' or 'oscillations_uni' or 'oscillations_emp'")
81 84
  )
82 85
}
83 86
84 87
#' @name predict_parts
85 88
#' @export
86 -
predict_parts_oscillations <- function(explainer, new_observation, ...) {
89 +
predict_parts_oscillations <- function(explainer, new_observation, N = NULL, ...) {
87 90
  # run checks against the explainer objects
88 91
  test_explainer(explainer, has_data = TRUE, function_name = "predict_parts_oscillations")
89 92
93 +
  # Cut data according to N
94 +
  explainer$data <- cut_data_to_n(explainer$data, N)
95 +
90 96
  # call the ceteris_paribus
91 97
  cp <- ingredients::ceteris_paribus(explainer,
92 98
                                     new_observation = new_observation,
@@ -98,10 +104,13 @@
Loading
98 104
99 105
#' @name predict_parts
100 106
#' @export
101 -
predict_parts_oscillations_uni <- function(explainer, new_observation, variable_splits_type = "uniform", ...) {
107 +
predict_parts_oscillations_uni <- function(explainer, new_observation, variable_splits_type = "uniform", N = NULL, ...) {
102 108
  # run checks against the explainer objects
103 109
  test_explainer(explainer, has_data = TRUE, function_name = "predict_parts_oscillations_uni")
104 110
111 +
  # Cut data according to N
112 +
  explainer$data <- cut_data_to_n(explainer$data, N)
113 +
105 114
  # call the ceteris_paribus
106 115
  cp <- ingredients::ceteris_paribus(explainer,
107 116
                                     new_observation = new_observation,
@@ -118,11 +127,18 @@
Loading
118 127
  # run checks against the explainer objects
119 128
  test_explainer(explainer, has_data = TRUE, function_name = "predict_parts_oscillations_emp")
120 129
  variables <- intersect(variables, colnames(new_observation))
121 -
  N <- min(N, nrow(explainer$data))
122 -
  data_sample <- explainer$data[sample(1:nrow(explainer$data), N),]
130 +
131 +
132 +
  if (is.null(N)) {
133 +
    # Default value, set 500
134 +
    N <- 500
135 +
  }
136 +
137 +
  # Cut data according to N
138 +
  explainer$data <- cut_data_to_n(explainer$data, N)
123 139
124 140
  variable_splits <- lapply(variables, function(var) {
125 -
    data_sample[,var]
141 +
    explainer$data[,var]
126 142
  })
127 143
  names(variable_splits) <- variables
128 144
@@ -139,10 +155,13 @@
Loading
139 155
140 156
#' @name predict_parts
141 157
#' @export
142 -
predict_parts_break_down <- function(explainer, new_observation, ...) {
158 +
predict_parts_break_down <- function(explainer, new_observation, N = NULL, ...) {
143 159
  # run checks against the explainer objects
144 160
  test_explainer(explainer, has_data = TRUE, function_name = "predict_parts_break_down")
145 161
162 +
  # Cut data according to N
163 +
  explainer$data <- cut_data_to_n(explainer$data, N)
164 +
146 165
  # call the break_down
147 166
  res <- iBreakDown::break_down(explainer,
148 167
                                new_observation = new_observation,
@@ -153,10 +172,13 @@
Loading
153 172
154 173
#' @name predict_parts
155 174
#' @export
156 -
predict_parts_break_down_interactions <- function(explainer, new_observation, ...) {
175 +
predict_parts_break_down_interactions <- function(explainer, new_observation, N = NULL, ...) {
157 176
  # run checks against the explainer objects
158 177
  test_explainer(explainer, has_data = TRUE, function_name = "predict_parts_break_down_interactions")
159 178
179 +
  # Cut data according to N
180 +
  explainer$data <- cut_data_to_n(explainer$data, N)
181 +
160 182
  # call the break_down
161 183
  res <- iBreakDown::break_down(explainer,
162 184
                                new_observation = new_observation,
@@ -168,10 +190,13 @@
Loading
168 190
169 191
#' @name predict_parts
170 192
#' @export
171 -
predict_parts_shap <- function(explainer, new_observation, ...) {
193 +
predict_parts_shap <- function(explainer, new_observation, N = NULL, ...) {
172 194
  # run checks against the explainer objects
173 195
  test_explainer(explainer, has_data = TRUE, function_name = "predict_parts_shap")
174 196
197 +
  # Cut data according to N
198 +
  explainer$data <- cut_data_to_n(explainer$data, N)
199 +
175 200
  # call the shap from iBreakDown
176 201
  res <- iBreakDown::shap(explainer,
177 202
                          new_observation = new_observation,

Learn more Showing 3 files with coverage changes found.

Changes in R/explain.R
-2
Loading file...
Changes in R/plot_model_performance.R
New
Loading file...
Changes in R/predict_parts.R
-1
Loading file...
Files Coverage
R 0.08% 86.89%
Project Totals (31 files) 86.89%
Loading