PiperOrigin-RevId: 449727140
Showing 1 of 7 files from the diff.
flax/jax_utils.py
has changed.
Other files ignored by Codecov
tests/jax_utils_test.py
is new.
docs/howtos/full_eval.rst
is new.
CHANGELOG.md
has changed.
docs/flax.jax_utils.rst
has changed.
docs/notebooks/full_eval.ipynb
is new.
docs/index.rst
has changed.
@@ -243,3 +243,81 @@
Loading
243 | 243 | c, ys = _scan_nd(body_wrapper, init, xs, n=len(axis), unroll=unroll) |
|
244 | 244 | ys = jax.tree_map(transpose_out, ys) |
|
245 | 245 | return c, ys |
|
246 | + | ||
247 | + | ||
248 | + | # Copied from https://github.com/google-research/big_vision |
|
def pad_shard_unpad(wrapped, static_argnums=(0,), static_argnames=(),
                    static_return=False):
    """Wraps a function with code that pads, shards, then un-shards, un-pads.

    This is useful for calling a pmap'ed function with inputs that aren't
    divisible by the number of devices. A typical use is:

        @pad_shard_unpad
        @jax.pmap
        def forward(params, x): ...

    Args:
      wrapped: the function to be wrapped. Signature is
        `params, *args, **kwargs`.
      static_argnums: index (or indices) of arguments to `wrapped` that should
        _not_ be padded and sharded, but instead be forwarded as-is. The
        default is (0,) because by far the most common use-case is to pass
        `params` first. A single int is accepted as shorthand for a 1-tuple.
      static_argnames: name (or names) of kwargs to `wrapped` that should
        _not_ be padded and sharded, but instead be forwarded as-is. A single
        str is accepted as shorthand for a 1-tuple.
      static_return: whether not to un-shard, and un-pad the return value;
        static return values are typically used with eval steps that compute
        metrics.

    Returns:
      A new function that pads and shards its arguments before passing them to
      the wrapped function, and un-shards and un-pads the returned pytree.

    Notes:
      The padding is done in host-memory before being passed to the function,
      and the values returned by the function are transferred back to host
      memory.

      The returned function is augmented with a new keyword-only argument
      `min_device_batch` that, if specified, forces padding inputs to at least
      this size per device. This can be useful to avoid recompiles for the
      last batch and reduce memory fragmentation.

      The wrapper asserts that all non-static inputs share one leading
      (batch) dimension.

      For more information refer to
      https://flax.readthedocs.io/en/latest/howtos/full_eval.html
    """
    # Accept a bare int / str. Without this, `i not in 0` raises TypeError and
    # `k not in "name"` silently performs a substring test.
    if isinstance(static_argnums, int):
        static_argnums = (static_argnums,)
    if isinstance(static_argnames, str):
        static_argnames = (static_argnames,)

    def pad_shard_unpad_wrapper(*args, min_device_batch=None, **kw):
        d = jax.local_device_count()  # d = devices, b = batch

        # Collect the leading dimension of every leaf that will be sharded.
        # `jax.tree_util` is used because the top-level `jax.tree_leaves` /
        # `jax.tree_map` aliases were deprecated and later removed from JAX.
        batch_sizes = set()
        for i, a in enumerate(args):
            if i not in static_argnums:
                batch_sizes |= {
                    t.shape[0] for t in jax.tree_util.tree_leaves(a)}
        for k, v in kw.items():
            if k not in static_argnames:
                batch_sizes |= {
                    t.shape[0] for t in jax.tree_util.tree_leaves(v)}
        assert len(batch_sizes) == 1, f"Inconsistent batch-sizes: {batch_sizes}"
        b = batch_sizes.pop()

        # Per-device batch size: ceil(b / d), raised to `min_device_batch`
        # when requested. Identical for every leaf, so compute it once.
        db, rest = divmod(b, d)
        if rest:
            db += 1
        if min_device_batch and db < min_device_batch:
            db = min_device_batch
        padding = d * db - b  # Number of zero rows to append to each leaf.

        def pad(x):
            _, *shape = x.shape
            if padding:
                x = np.concatenate(
                    [x, np.zeros((padding, *shape), x.dtype)], axis=0)
            return x.reshape(d, db, *shape)

        def maybe_pad(tree, actually_pad=True):
            if not actually_pad:
                return tree  # For call-site convenience below.
            return jax.tree_util.tree_map(pad, tree)

        args = [maybe_pad(a, i not in static_argnums)
                for i, a in enumerate(args)]
        kw = {k: maybe_pad(v, k not in static_argnames)
              for k, v in kw.items()}
        out = wrapped(*args, **kw)

        def unpad(x):
            # Transfer back before cutting, to reduce on-device shape
            # diversity. Merges the (device, per-device batch) axes and drops
            # the rows that were added as padding.
            return jax.device_get(x).reshape(
                [np.prod(x.shape[:2]), *x.shape[2:]])[:b]

        return out if static_return else jax.tree_util.tree_map(unpad, out)

    return pad_shard_unpad_wrapper
Files | Coverage |
---|---|
flax | 75.11% |
Project Totals (59 files) | 75.11% |
2352254665
Sunburst
The innermost circle represents the entire project; moving away from the center are folders and then, finally, individual files.
The size and color of each slice represent the number of statements and the coverage, respectively.
Icicle
The top section represents the entire project, followed by folders and, finally, individual files.
The size and color of each slice represent the number of statements and the coverage, respectively.