google / flax
Showing 1 of 7 files from the diff.

@@ -243,3 +243,81 @@
Loading
243 243
  c, ys = _scan_nd(body_wrapper, init, xs, n=len(axis), unroll=unroll)
244 244
  ys = jax.tree_map(transpose_out, ys)
245 245
  return c, ys
246 +
247 +
248 +
# Copied from https://github.com/google-research/big_vision
249 +
def pad_shard_unpad(wrapped, static_argnums=(0,), static_argnames=(),
250 +
                    static_return=False):
251 +
  """Wraps a function with code that pads, shards, then un-shards, un-pads.
252 +
253 +
  Args:
254 +
    wrapped: the function to be wrapped. Signature is `params, *args, *kwargs`.
255 +
    static_argnums: indices of arguments to `wrapped` that should _not_ be
256 +
      padded and sharded, but instead be forwarded as-is. The default is (0,)
257 +
      because by far the most common use-case is to pass `params` first.
258 +
    static_argnames: names of kwargs to `wrapped` that should _not_ be padded
259 +
      and sharded, but instead be forwarded as-is.
260 +
    static_return: whether not to un-shard, and un-pad the return value; static
261 +
      return values are typically used with eval steps that compute metrics
262 +
263 +
  Returns:
264 +
    A new function that pads and shards its arguments before passing them to
265 +
    the wrapped function, and un-shards and un-pads the returned pytree.
266 +
267 +
    This is useful for calling a pmap'ed function with inputs that aren't
268 +
    divisible by the number of devices. A typical use is:
269 +
      @pad_shard_unpad
270 +
      @jax.pmap
271 +
      def forward(params, x): ...
272 +
273 +
  Notes:
274 +
    The padding is done in host-memory before being passed to the function, and
275 +
    the values returned by the function are transferred back to host memory.
276 +
277 +
    The returned function is augmented with a new keyword-only argument
278 +
    `min_device_batch` that, if specified, forces padding inputs to at least
279 +
    this size per device. This can be useful to avoid recompiles for the last
280 +
    batch and reduce memory fragmentation.
281 +
282 +
    For more information refer to
283 +
    https://flax.readthedocs.io/en/latest/howtos/full_eval.html
284 +
  """
285 +
286 +
  def pad_shard_unpad_wrapper(*args, min_device_batch=None, **kw):
287 +
    d = jax.local_device_count()  # d = devices, b = batch
288 +
    batch_sizes = set()
289 +
    for i, a in enumerate(args):
290 +
      if i not in static_argnums:
291 +
        batch_sizes |= {t.shape[0] for t in jax.tree_leaves(a)}
292 +
    for k, v in kw.items():
293 +
      if k not in static_argnames:
294 +
        batch_sizes |= {t.shape[0] for t in jax.tree_leaves(v)}
295 +
    assert len(batch_sizes) == 1, f"Inconsistent batch-sizes: {batch_sizes}"
296 +
    b = batch_sizes.pop()
297 +
298 +
    def pad(x):
299 +
      _, *shape = x.shape
300 +
      db, rest = divmod(b, d)
301 +
      if rest:
302 +
        x = np.concatenate([x, np.zeros((d - rest, *shape), x.dtype)], axis=0)
303 +
        db += 1
304 +
      if min_device_batch and db < min_device_batch:
305 +
        x = np.concatenate(
306 +
            [x, np.zeros((d * (min_device_batch - db), *shape), x.dtype)])
307 +
        db = min_device_batch
308 +
      return x.reshape(d, db, *shape)
309 +
310 +
    def maybe_pad(tree, actually_pad=True):
311 +
      if not actually_pad: return tree  # For call-site convenience below.
312 +
      return jax.tree_map(pad, tree)
313 +
314 +
    args = [maybe_pad(a, i not in static_argnums) for i, a in enumerate(args)]
315 +
    kw = {k: maybe_pad(v, k not in static_argnames) for k, v in kw.items()}
316 +
    out = wrapped(*args, **kw)
317 +
318 +
    def unpad(x):
319 +
      # Transfer back before cutting, to reduce on-device shape diversity.
320 +
      return jax.device_get(x).reshape([np.prod(x.shape[:2]), *x.shape[2:]])[:b]
321 +
    return out if static_return else jax.tree_map(unpad, out)
322 +
323 +
  return pad_shard_unpad_wrapper
Files Coverage
flax 75.11%
Project Totals (59 files) 75.11%

No yaml found.

Create your codecov.yml to customize your Codecov experience

Sunburst
The inner-most circle is the entire project, moving away from the center are folders then, finally, a single file. The size and color of each slice is representing the number of statements and the coverage, respectively.
Icicle
The top section represents the entire project. Proceeding with folders and finally individual files. The size and color of each slice is representing the number of statements and the coverage, respectively.
Grid
Each block represents a single file in the project. The size and color of each block is represented by the number of statements and the coverage, respectively.
Loading