MI2DataLab / survxai
Showing 10 of 22 files from the diff.
Newly tracked file
R/plot_variable_response.R changed.
Newly tracked file
R/ceteris_paribus.R changed.
Newly tracked file
R/model_performance.R changed.
Newly tracked file
R/plot_ceteris_paribus.R changed.
Newly tracked file
R/explain.R changed.
Newly tracked file
R/plot_prediction_breakdown.R changed.
Newly tracked file
R/plot_model_performance.R changed.
Newly tracked file
R/variable_response.R changed.
Newly tracked file
R/plot_explainer.R changed.
Newly tracked file
R/prediction_breakdown.R changed.

@@ -19,7 +19,7 @@
 19 19 `#' prob <- rms::survest(model, data, times = times)\$surv` 20 20 `#' return(prob)` 21 21 `#' }` 22 - `#' cph_model <- cph(Surv(years, status)~., data=pbcTrain, surv=TRUE, x = TRUE, y=TRUE)` 22 + `#' cph_model <- cph(Surv(years, status)~sex + bili + stage, data=pbcTrain, surv=TRUE, x = TRUE, y=TRUE)` 23 23 `#' surve_cph <- explain(model = cph_model, data = pbcTest[,-c(1,5)],` 24 24 `#' y = Surv(pbcTest\$years, pbcTest\$status), predict_function = predict_times)` 25 25 `#' svr_cph <- variable_response(surve_cph, "sex")`

@@ -1,5 +1,5 @@
 1 1 `#' Ceteris Paribus` 2 - `#' ` 2 + `#'` 3 3 `#' @description The \code{ceteris_paribus()} function computes the predictions for the neighbor of our chosen observation. The neighbour is defined as the observations with changed value of one of the variable.` 4 4 `#'` 5 5 `#' @param explainer a model to be explained, preprocessed by the 'survxai::explain' function`
@@ -14,19 +14,19 @@
 14 14 `#' @importFrom stats quantile` 15 15 `#' @importFrom utils head` 16 16 `#'` 17 - `#' @examples ` 17 + `#' @examples` 18 18 `#' \donttest{` 19 19 `#' library(survxai)` 20 - `#' library(rms) ` 20 + `#' library(rms)` 21 21 `#' data("pbcTrain")` 22 22 `#' data("pbcTest")` 23 - `#' predict_times <- function(model, data, times){ ` 23 + `#' predict_times <- function(model, data, times){` 24 24 `#' prob <- rms::survest(model, data, times = times)\$surv` 25 25 `#' return(prob)` 26 26 `#' }` 27 - `#' cph_model <- cph(Surv(years, status)~., data = pbcTrain, surv = TRUE, x = TRUE, y=TRUE)` 28 - `#' surve_cph <- explain(model = cph_model, data = pbcTest[,-c(1,5)], ` 29 - `#' y = Surv(pbcTest\$years, pbcTest\$status), ` 27 + `#' cph_model <- cph(Surv(years, status)~ sex + bili + stage, data = pbcTrain, surv = TRUE, x = TRUE, y=TRUE)` 28 + `#' surve_cph <- explain(model = cph_model, data = pbcTest[,-c(1,5)],` 29 + `#' y = Surv(pbcTest\$years, pbcTest\$status),` 30 30 `#' predict_function = predict_times)` 31 31 `#' cp_cph <- ceteris_paribus(surve_cph, pbcTest[1,-c(1,5)])` 32 32 `#' }`
@@ -37,27 +37,27 @@
 37 37 ` stop("The ceteris_paribus() function requires an object created with explain() function from survxai package.")` 38 38 ` if (is.null(explainer\$data))` 39 39 ` stop("The ceteris_paribus() function requires explainers created with specified 'data' parameter.")` 40 - ` ` 40 + 41 41 ` data <- base::as.data.frame(explainer\$data)` 42 42 ` model <- explainer\$model` 43 43 ` predict_function <- explainer\$predict_function` 44 44 ` names_to_present <- colnames(data)` 45 45 ` grid_points <- grid_points` 46 - ` ` 46 + 47 47 ` if (!is.null(selected_variables)) {` 48 48 ` names_to_present <- intersect(names_to_present, selected_variables)` 49 49 ` }` 50 - ` ` 50 + 51 51 ` times <- explainer\$times` 52 52 ` times <- sort(times)` 53 - ` ` 53 + 54 54 ` responses <- lapply(names_to_present, function(vname, times_s, observation_s, model_s, explainer_s, grid_points_s, data_s, predict_function_s) calculate_responses(vname,times_s = times, observation_s = observation, model_s = model, explainer_s = explainer, grid_points_s = grid_points, data_s = data, predict_function_s = predict_function))` 55 - ` ` 55 + 56 56 ` all_responses <- do.call(rbind, responses)` 57 57 ` new_y_hat <- predict_function(model, observation, times)` 58 58 ` attr(all_responses, "prediction") <- list(observation = observation, new_y_hat = new_y_hat, times = times)` 59 59 ` attr(all_responses, "grid_points") <- grid_points` 60 - ` ` 60 + 61 61 ` class(all_responses) <- c("surv_ceteris_paribus_explainer", "data.frame")` 62 62 ` all_responses` 63 63 `}`

@@ -17,7 +17,7 @@
 17 17 `#' library(rms)` 18 18 `#' data("pbcTrain")` 19 19 `#' data("pbcTest")` 20 - `#' cph_model <- cph(Surv(years, status)~., data=pbcTrain, surv=TRUE, x = TRUE, y=TRUE)` 20 + `#' cph_model <- cph(Surv(years, status)~sex + bili + stage, data=pbcTrain, surv=TRUE, x = TRUE, y=TRUE)` 21 21 `#' surve_cph <- explain(model = cph_model, data = pbcTest[,-c(1,5)],` 22 22 `#' y = Surv(pbcTest\$years, pbcTest\$status))` 23 23 `#' mp_cph <- model_performance(surve_cph)`

@@ -21,7 +21,7 @@
 21 21 `#' prob <- rms::survest(model, data, times = times)\$surv` 22 22 `#' return(prob)` 23 23 `#' }` 24 - `#' cph_model <- cph(Surv(years, status)~., data=pbcTrain, surv=TRUE, x = TRUE, y=TRUE)` 24 + `#' cph_model <- cph(Surv(years, status)~sex + bili + stage, data=pbcTrain, surv=TRUE, x = TRUE, y=TRUE)` 25 25 `#' surve_cph <- explain(model = cph_model, data = pbcTest[,-c(1,5)],` 26 26 `#' y = Surv(pbcTest\$years, pbcTest\$status), predict_function = predict_times)` 27 27 `#' cp_cph <- ceteris_paribus(surve_cph, pbcTest[1,-c(1,5)])`

@@ -45,7 +45,7 @@
 45 45 `#' prob <- rms::survest(model, data, times = times)\$surv` 46 46 `#' return(prob)` 47 47 `#' }` 48 - `#' cph_model <- cph(Surv(days/365, status)~., data=pbc, surv=TRUE, x = TRUE, y=TRUE)` 48 + `#' cph_model <- cph(Surv(days/365, status)~ sex + bili + stage, data=pbc, surv=TRUE, x = TRUE, y=TRUE)` 49 49 `#' surve_cph <- explain(model = cph_model, data = pbc[,-c(1,2)], y = Surv(pbc\$days/365, pbc\$status),` 50 50 `#' predict_function = predict_times)` 51 51 `#' }`

@@ -22,7 +22,7 @@
 22 22 `#' prob <- rms::survest(model, data, times = times)\$surv` 23 23 `#' return(prob)` 24 24 `#' }` 25 - `#' cph_model <- cph(Surv(years, status)~., data=pbcTrain, surv=TRUE, x = TRUE, y=TRUE)` 25 + `#' cph_model <- cph(Surv(years, status)~sex + bili + stage, data=pbcTrain, surv=TRUE, x = TRUE, y=TRUE)` 26 26 `#' surve_cph <- explain(model = cph_model, data = pbcTest[,-c(1,5)],` 27 27 `#' y = Surv(pbcTest\$years, pbcTest\$status), predict_function = predict_times)` 28 28 `#' broken_prediction <- prediction_breakdown(surve_cph, pbcTest[1,-c(1,5)])`

@@ -17,7 +17,7 @@
 17 17 `#' prob <- rms::survest(model, data, times = times)\$surv` 18 18 `#' return(prob)` 19 19 `#' }` 20 - `#' cph_model <- cph(Surv(years, status)~., data=pbcTrain, surv=TRUE, x = TRUE, y=TRUE)` 20 + `#' cph_model <- cph(Surv(years, status)~sex + bili + stage, data=pbcTrain, surv=TRUE, x = TRUE, y=TRUE)` 21 21 `#'surve_cph <- explain(model = cph_model, data = pbcTest[,-c(1,5)],` 22 22 `#' y = Surv(pbcTest\$years, pbcTest\$status), predict_function = predict_times)` 23 23 `#' mp_cph <- model_performance(surve_cph)`

@@ -18,8 +18,8 @@
 18 18 `#' prob <- rms::survest(model, data, times = times)\$surv` 19 19 `#' return(prob)` 20 20 `#' }` 21 - `#' cph_model <- cph(Surv(years, status)~., data=pbcTrain, surv=TRUE, x = TRUE, y=TRUE)` 22 - `#' surve_cph <- explain(model = cph_model, data = pbcTest[,-c(1,5)], ` 21 + `#' cph_model <- cph(Surv(years, status)~sex + bili + stage, data=pbcTrain, surv=TRUE, x = TRUE, y=TRUE)` 22 + `#' surve_cph <- explain(model = cph_model, data = pbcTest[,-c(1,5)],` 23 23 `#' y = Surv(pbcTest\$years, pbcTest\$status), predict_function = predict_times)` 24 24 `#' svr_cph <- variable_response(surve_cph, "sex")` 25 25 `#' }`

@@ -8,7 +8,7 @@
 8 8 `#' @import ggplot2` 9 9 `#' @importFrom survival survfit` 10 10 `#' @importFrom survminer ggsurvplot` 11 - `#' @examples ` 11 + `#' @examples` 12 12 `#' \donttest{` 13 13 `#' library(survxai)` 14 14 `#' library(rms)`
@@ -18,8 +18,8 @@
 18 18 `#' prob <- rms::survest(model, data, times = times)\$surv` 19 19 `#' return(prob)` 20 20 `#' }` 21 - `#' cph_model <- cph(Surv(years, status)~., data=pbcTrain, surv=TRUE, x = TRUE, y=TRUE)` 22 - `#' surve_cph <- explain(model = cph_model, data = pbcTest[,-c(1,5)], ` 21 + `#' cph_model <- cph(Surv(years, status)~sex + bili + stage, data=pbcTrain, surv=TRUE, x = TRUE, y=TRUE)` 22 + `#' surve_cph <- explain(model = cph_model, data = pbcTest[,-c(1,5)],` 23 23 `#' y = Surv(pbcTest\$years, pbcTest\$status), predict_function = predict_times)` 24 24 `#' plot(surve_cph)` 25 25 `#' }`

@@ -24,7 +24,7 @@
 24 24 `#' prob <- rms::survest(model, data, times = times)\$surv` 25 25 `#' return(prob)` 26 26 `#' }` 27 - `#' cph_model <- cph(Surv(years, status)~., data=pbcTrain, surv=TRUE, x = TRUE, y=TRUE)` 27 + `#' cph_model <- cph(Surv(years, status)~sex + bili + stage, data=pbcTrain, surv=TRUE, x = TRUE, y=TRUE)` 28 28 `#' surve_cph <- explain(model = cph_model, data = pbcTest[,-c(1,5)],` 29 29 `#' y = Surv(pbcTest\$years, pbcTest\$status), predict_function = predict_times)` 30 30 `#' broken_prediction <- prediction_breakdown(surve_cph, pbcTest[1,-c(1,5)])`
 1 ```comment: false ```