#' @title Maps a function over lists or vectors in parallel.
#'
#' @description
#' Uses the parallelization mode and the other options specified in
#' [parallelStart()].
#'
#' Libraries and source files can be initialized on the slaves with
#' [parallelLibrary()] and [parallelSource()].
#'
#' Large objects can be exported separately via [parallelExport()];
#' they can then simply be used under their exported names in the slave body code.
#'
#' Regarding error handling, see the argument `impute.error`.
#'
#' @param fun [function]\cr
#'   Function to map over `...`.
#' @param ... (any)\cr
#'   Arguments to vectorize over (list or vector).
#' @param more.args [list]\cr
#'   A list of other arguments passed to `fun`.
#'   Default is an empty list.
#' @param simplify (`logical(1)`)\cr
#'   Should the result be simplified? See [simplify2array]. If `TRUE`,
#'   `simplify2array(higher = TRUE)` will be called on the result object.
#'   Default is `FALSE`.
#' @param use.names (`logical(1)`)\cr
#'   Should the result be named?
#'   If the first `...` argument has names, they are used; if it is a
#'   character vector, that character vector is used as the names.
#'   Default is `FALSE`.
#' @param impute.error (`NULL` | `function(x)`)\cr
#'   This argument can be used for improved error handling. `NULL` means that,
#'   if an exception is generated on one of the slaves, it is also thrown on the
#'   master. Usually, all slave jobs have to terminate before this exception can
#'   be thrown on the master. If you pass a constant value or a function,
#'   all jobs are guaranteed to return a result object, without generating an
#'   exception on the master for slave errors. In case of an error, the result is a
#'   [simpleError()] object containing the error message. If you passed a
#'   constant object, the error objects are substituted with this object. If
#'   you passed a function, it is used to operate on these error objects
#'   (it is ONLY applied to the error results). For example, `identity`
#'   would keep and return the `simpleError` object, while `function(x) 99`
#'   would impute a constant value (which could be achieved more easily by
#'   simply passing `99`). See the examples for a short sketch.
#'   Default is `NULL`.
#' @param level (`character(1)`)\cr
#'   If a (non-missing) level is specified in [parallelStart()],
#'   this call is only parallelized if the level specified here matches.
#'   Useful if this function is used inside a package.
#'   Default is `NA`.
#' @param show.info (`logical(1)`)\cr
#'   Verbose output on the console?
#'   Can be used to override the setting from the options / [parallelStart()].
#'   Default is `NA`, which means no overriding.
#' @return Result.
#' @export
#' @examples
#' parallelStart()
#' parallelMap(identity, 1:2)
#' parallelStop()
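#'
#' # A minimal sketch of `impute.error` (illustrative helper `f`, not part of
#' # the package): failing elements are imputed instead of stopping the master.
#' parallelStart()
#' f = function(x) if (x == 2) stop("f failed for x = 2") else x^2
#' parallelMap(f, 1:3, impute.error = function(e) NA)
#' parallelStop()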
parallelMap = function(fun, ..., more.args = list(), simplify = FALSE,
  use.names = FALSE, impute.error = NULL, level = NA_character_,
  show.info = NA) {

  assertFunction(fun)
  assertList(more.args)
  assertFlag(simplify)
  assertFlag(use.names)
  # if impute.error is a constant value, construct a function that imputes it
  if (!is.null(impute.error)) {
    if (is.function(impute.error)) {
      impute.error.fun = impute.error
    } else {
      impute.error.fun = function(x) impute.error
    }
  }
  assertString(level, na.ok = TRUE)
  assertFlag(show.info, na.ok = TRUE)

  if (!is.na(level) && level %nin% unlist(getPMOption("registered.levels", list()))) {
    stopf("Level '%s' not registered", level)
  }

  cpus = getPMOptCpus()
  load.balancing = getPMOptLoadBalancing()
  logging = getPMOptLogging()
  reproducible = getPMOptReproducible()
  # use NA to encode "no logging" in logdir
  logdir = ifelse(logging, getNextLogDir(), NA_character_)

  if (isModeLocal() || !isParallelizationLevel(level) || getPMOptOnSlave()) {
    if (!is.null(impute.error)) {
      # so that local mode behaves the same way as slaveWrapper does on the slaves
      fun2 = function(...) {
        res = try(fun(...), silent = getOption("parallelMap.suppress.local.errors"))
        if (BBmisc::is.error(res)) {
          res = list(try.object = res)
          class(res) = "parallelMapErrorWrapper"
        }
        return(res)
      }
    } else {
      fun2 = fun
    }
    assignInFunctionNamespace(fun, env = PKG_LOCAL_ENV)
    res = mapply(fun2, ..., MoreArgs = more.args, SIMPLIFY = FALSE, USE.NAMES = FALSE)
  } else {
    iters = seq_along(..1)
    showInfoMessage("Mapping in parallel%s: mode = %s; level = %s; cpus = %i; elements = %i.",
      ifelse(load.balancing, " (load balanced)", ""), getPMOptMode(),
      level, getPMOptCpus(), length(iters), show.info = show.info)

    if (isModeMulticore()) {
      more.args = c(list(.fun = fun, .logdir = logdir), more.args)
      if (reproducible) {
        old.seed = .Random.seed
        old.rng.kind = RNGkind()
        seed = sample(1:100000, 1)
        # we need to reset the seed first in case the user supplied a seed,
        # otherwise "L'Ecuyer-CMRG" won't be used
        rm(.Random.seed, envir = globalenv())
        set.seed(seed, "L'Ecuyer-CMRG")
      }
      res = MulticoreClusterMap(slaveWrapper, ..., .i = iters,
        MoreArgs = more.args, mc.cores = cpus,
        SIMPLIFY = FALSE, USE.NAMES = FALSE)
      if (reproducible) {
        # restore the initial RNG kind
        .Random.seed = old.seed
        RNGkind(old.rng.kind[1], old.rng.kind[2], old.rng.kind[3])
      }
    } else if (isModeSocket() || isModeMPI()) {
      more.args = c(list(.fun = fun, .logdir = logdir), more.args)
      if (load.balancing) {
        res = clusterMapLB(cl = NULL, slaveWrapper, ..., .i = iters,
          MoreArgs = more.args)
      } else {
        res = clusterMap(cl = NULL, slaveWrapper, ..., .i = iters,
          MoreArgs = more.args, SIMPLIFY = FALSE, USE.NAMES = FALSE)
      }
    } else if (isModeBatchJobs()) {
      # don't log extra in BatchJobs
      more.args = c(list(.fun = fun, .logdir = NA_character_), more.args)
      suppressMessages({
        reg = getBatchJobsReg()
        # FIXME: this should be exported by BatchJobs ...
        asNamespace("BatchJobs")$dbRemoveJobs(reg, BatchJobs::getJobIds(reg))
        BatchJobs::batchMap(reg, slaveWrapper, ..., more.args = more.args)
        # increase max.retries a bit, we don't want to abort here prematurely;
        # if no resources are set, we submit with the defaults from the BatchJobs config
        BatchJobs::submitJobs(reg, resources = getPMOptBatchJobsResources(), max.retries = 15)
        ok = BatchJobs::waitForJobs(reg, stop.on.error = is.null(impute.error))
      })
      # copy log files of terminated jobs to the designated directory
      if (!is.na(logdir)) {
        term = BatchJobs::findTerminated(reg)
        fns = BatchJobs::getLogFiles(reg, term)
        dests = file.path(logdir, sprintf("%05i.log", term))
        file.copy(from = fns, to = dests)
      }
      ids = BatchJobs::getJobIds(reg)
      ids.err = BatchJobs::findErrors(reg)
      ids.exp = BatchJobs::findExpired(reg)
      ids.done = BatchJobs::findDone(reg)
      ids.notdone = c(ids.err, ids.exp)
      # construct error messages for jobs that did not finish
      msgs = rep("Job expired!", length(ids.notdone))
      msgs[ids.err] = BatchJobs::getErrorMessages(reg, ids.err)
      # handle errors (no impute): kill other jobs + stop on master
      if (is.null(impute.error) && length(c(ids.notdone)) > 0) {
        extra.msg = sprintf("Please note that remaining jobs were killed when 1st error occurred to save cluster time.\nIf you want to further debug errors, your BatchJobs registry is here:\n%s",
          reg$file.dir)
        onsys = BatchJobs::findOnSystem(reg)
        suppressMessages(
          BatchJobs::killJobs(reg, onsys)
        )
        onsys = BatchJobs::findOnSystem(reg)
        if (length(onsys) > 0L) {
          warningf("Still %i jobs from operation on system! kill them manually!", length(onsys))
        }
        if (length(ids.notdone) > 0L) {
          stopWithJobErrorMessages(ids.notdone, msgs, extra.msg)
        }
      }
      # if we reach this line and an error occurred, impute.error is non-NULL (NULL would have stopped above)
      res = vector("list", length(ids))
      res[ids.done] = BatchJobs::loadResults(reg, simplify = FALSE, use.names = FALSE)
      res[ids.notdone] = lapply(msgs, function(s) impute.error.fun(simpleError(s)))
    } else if (isModeBatchtools()) {
      # don't log extra in batchtools
      more.args = insert(more.args, list(.fun = fun, .logdir = NA_character_))

      old = getOption("batchtools.verbose")
      options(batchtools.verbose = FALSE)
      on.exit(options(batchtools.verbose = old))

      reg = getBatchtoolsReg()
      if (nrow(reg$status) > 0L) {
        batchtools::clearRegistry(reg = reg)
      }
      ids = batchtools::batchMap(fun = slaveWrapper, ..., more.args = more.args, reg = reg)
      batchtools::submitJobs(ids = ids, resources = getPMOptBatchtoolsResources(), reg = reg)
      ok = batchtools::waitForJobs(ids = ids, stop.on.error = is.null(impute.error), reg = reg)

      # copy log files of terminated jobs to the designated directory
      if (!is.na(logdir)) {
        x = batchtools::findStarted(reg = reg)
        x$log.file = file.path(reg$file.dir, "logs", sprintf("%s.log", x$job.hash))
        .mapply(function(id, fn) writeLines(batchtools::getLog(id, reg = reg), con = fn), x, NULL)
      }

      if (ok) {
        res = batchtools::reduceResultsList(ids, reg = reg)
      } else {
        if (is.null(impute.error)) {
          extra.msg = sprintf("Please note that remaining jobs were killed when 1st error occurred to save cluster time.\nIf you want to further debug errors, your batchtools registry is here:\n%s",
            reg$file.dir)
          batchtools::killJobs(reg = reg)
          ids.notdone = batchtools::findNotDone(reg = reg)
          stopWithJobErrorMessages(
            inds = ids.notdone$job.id,
            batchtools::getErrorMessages(ids.notdone, missing.as.error = TRUE, reg = reg)$message,
            extra.msg)
        } else { # if we reach this line and an error occurred, impute.error is non-NULL (NULL would have stopped above)
          res = batchtools::findJobs(reg = reg)
          res$result = list(NULL)
          ids.complete = batchtools::findDone(reg = reg)
          ids.incomplete = batchtools::findNotDone(reg = reg)
          res[match(ids.complete$job.id, res$job.id), "result"] = list(batchtools::reduceResultsList(ids.complete, reg = reg))
          res[match(ids.incomplete$job.id, res$job.id), "result"] = list(lapply(batchtools::getErrorMessages(ids.incomplete, reg = reg)$message, simpleError))
          res = res$result
        }
      }
    }
  }

  # handle potential errors in res, depending on the user setting
  if (is.null(impute.error)) {
    checkResultsAndStopWithErrorsMessages(res)
  } else {
    res = lapply(res, function(x) {
      if (inherits(x, "parallelMapErrorWrapper")) {
        impute.error.fun(attr(x$try.object, "condition"))
      } else {
        x
      }
    })
  }

  if (use.names && !is.null(names(..1))) {
    names(res) = names(..1)
  } else if (use.names && is.character(..1)) {
    names(res) = ..1
  } else if (!use.names) {
    names(res) = NULL
  }
  if (isTRUE(simplify) && length(res) > 0L) {
    res = simplify2array(res, higher = simplify)
  }

  # count the number of mapping operations for the log dir
  options(parallelMap.nextmap = (getPMOptNextMap() + 1L))

  return(res)
}

slaveWrapper = function(..., .i, .fun, .logdir = NA_character_) {

  if (!is.na(.logdir)) {
    options(warning.length = 8170L, warn = 1L)
    .fn = file.path(.logdir, sprintf("%05i.log", .i))
    .fn = file(.fn, open = "wt")
    .start.time = as.integer(Sys.time())
    sink(.fn)
    sink(.fn, type = "message")
    on.exit(sink(NULL))
  }

  # make sure we don't parallelize any further
  options(parallelMap.on.slave = TRUE)
  # just to be safe; we should not have changed anything on the master,
  # except for BatchJobs / interactive mode
  on.exit(options(parallelMap.on.slave = FALSE))

  # wrap in a try block so we can handle the error on the master
  res = try(.fun(...))
  # we can't simply return the error object, because clusterMap would act on it
  if (BBmisc::is.error(res)) {
    res = list(try.object = res)
    class(res) = "parallelMapErrorWrapper"
  }
  if (!is.na(.logdir)) {
    .end.time = as.integer(Sys.time())
    print(gc())
    message(sprintf("Job time in seconds: %i", .end.time - .start.time))
    # not sure why this is needed again, but without it we crash in multicore
    sink(NULL)
  }
  return(res)
}

assignInFunctionNamespace = function(fun, li = list(), env = new.env()) {
  # copy exported objects from PKG_LOCAL_ENV into the environment of fun,
  # so they can always be found during the call
  ee = environment(fun)
  ns = ls(env)
  for (n in ns) {
    assign(n, get(n, envir = env), envir = ee)
  }
  ns = names(li)
  for (n in ns) {
    assign(n, li[[n]], envir = ee)
  }
}
