google / flax

Compare 97cd5bc ... +5 ... 5c677ea


@@ -420,7 +420,9 @@
Loading
420 420
         split_rngs: Mapping[PRNGSequenceFilter, bool] = {},
421 421
         in_axes=0, out_axes=0,
422 422
         length: Optional[int] = None,
423 -
         reverse: bool = False) -> Callable[..., Any]:
423 +
         reverse: bool = False,
424 +
         data_transform: Optional[Callable[..., Any]] = None,
425 +
         ) -> Callable[..., Any]:
424 426
  """A lifted version of ``jax.lax.scan``.
425 427
426 428
  See ``jax.lax.scan`` for the unlifted scan in Jax.
@@ -478,6 +480,9 @@
Loading
478 480
    length: Specifies the number of loop iterations. This only needs
479 481
      to be specified if it cannot be derived from the scan arguments.
480 482
    reverse: If true, scan from end to start in reverse order.
483 +
    data_transform: Optional function to transform raw variable and rng groups,
484 +
      intended for inline SPMD annotations.
485 +
481 486
  Returns:
482 487
    The scan function with the signature ``(scope, carry, *xxs) -> (carry, yys)``,
483 488
    where ``xxs`` and ``yys`` are the scan values that go in and out of the loop.
@@ -524,6 +529,9 @@
Loading
524 529
    def scanned(broadcast_vars, carry, scan_variable_groups, rng_groups, args):
525 530
      carry_vars, c = carry
526 531
      variable_groups = (broadcast_vars, carry_vars) + scan_variable_groups
532 +
      if data_transform is not None:
533 +
        variable_groups, rng_groups = data_transform(variable_groups,
534 +
                                                     rng_groups)
527 535
      scope = scope_fn(variable_groups, rng_groups)
528 536
      c, y = fn(scope, c, *args)
529 537
      out_vars = repack_fn(scope)
@@ -684,7 +692,7 @@
Loading
684 692
        backend: Union[str, None] = None,
685 693
        ) -> Callable[..., Any]:
686 694
  """Lifted version of ``jax.jit``.
687 -
  
695 +
688 696
  Args:
689 697
    fn: Scope function to be jitted.
690 698
    variables: The variable collections that are lifted. By default all
@@ -727,7 +735,7 @@
Loading
727 735
  # while jitted has 3 functions before the user arguments.
728 736
  static_argnums = (0,) + tuple(i + 2 for i in static_argnums if i > 0)
729 737
  donate_argnums = tuple(i + 2 for i in donate_argnums if i > 0)
730 -
  
738 +
731 739
  # Close over scope_fn & repack_fn to avoid recompilation
732 740
  # this is impure but we use the fingerprint arg to differentiate between cases
733 741
  # where scope_fn or repack_fn actually produce non-identical results.

@@ -238,7 +238,7 @@
Loading
238 238
  ``vmap`` can be used to add a batch axis to a ``Module``.
239 239
  For example we could create a version of ``Dense`` with
240 240
  a batch axis that does not share parameters::
241 -
  
241 +
242 242
    BatchDense = nn.vmap(
243 243
        nn.Dense,
244 244
        in_axes=0, out_axes=0,
@@ -296,7 +296,7 @@
Loading
296 296
        backend: Union[str, None] = None,
297 297
        methods=None) -> Target:
298 298
  """Lifted version of ``jax.jit``.
299 -
  
299 +
300 300
  Args:
301 301
    target: a ``Module`` or a function taking a ``Module``
302 302
      as its first argument.
@@ -348,7 +348,7 @@
Loading
348 348
        concrete: bool = False,
349 349
        methods=None) -> Target:
350 350
  """Lifted version of ``jax.checkpoint``.
351 -
  
351 +
352 352
  This function is aliased to ``lift.remat`` just like ``jax.remat``.
353 353
354 354
  Args:
@@ -385,6 +385,7 @@
Loading
385 385
         in_axes=0, out_axes=0,
386 386
         length: Optional[int] = None,
387 387
         reverse: bool = False,
388 +
         data_transform: Optional[Callable[..., Any]] = None,
388 389
         methods=None) -> Target:
389 390
  """A lifted version of ``jax.lax.scan``.
390 391
@@ -458,6 +459,10 @@
Loading
458 459
    length: Specifies the number of loop iterations. This only needs
459 460
      to be specified if it cannot be derived from the scan arguments.
460 461
    reverse: If true, scan from end to start in reverse order.
462 +
    data_transform: Optional function to transform raw functional-core variable
463 +
      and rng groups inside lifted scan body_fn, intended for inline SPMD
464 +
      annotations.
465 +
461 466
  Returns:
462 467
    The scan function with the signature ``(scope, carry, *xxs) -> (carry, yys)``,
463 468
    where ``xxs`` and ``yys`` are the scan values that go in and out of the loop.
@@ -471,6 +476,7 @@
Loading
471 476
      in_axes=in_axes, out_axes=out_axes,
472 477
      length=length,
473 478
      reverse=reverse,
479 +
      data_transform=data_transform,
474 480
      methods=methods)
475 481
476 482

@@ -170,7 +170,7 @@
Loading
170 170
171 171
def _all_names_on_object(obj: Any) -> Set[str]:
172 172
  """Gets all names of attributes on `obj` and its classes throughout MRO.
173 -
  
173 +
174 174
  Args:
175 175
    obj: The object to get names for.
176 176
  Returns:
@@ -185,10 +185,14 @@
Loading
185 185
def _freeze_attr(val: Any) -> Any:
186 186
  if isinstance(val, (dict, FrozenDict)):
187 187
    return FrozenDict({k: _freeze_attr(v) for k, v in val.items()})
188 -
  elif  isinstance(val, tuple) and hasattr(val, '_fields'):
189 -
    # Special case named tuple otherwise they would be downgraded to normal tuples.
190 -
    return type(val)(*[_freeze_attr(v) for v in val])
191 -
  elif isinstance(val, (list, tuple)):
188 +
  elif isinstance(val, tuple):
189 +
    # Special case namedtuples and special JAX tuple structures otherwise they
190 +
    # would be downgraded to normal tuples.
191 +
    if hasattr(val, '_fields') or type(val).__name__ == 'PartitionSpec':
192 +
      return type(val)(*[_freeze_attr(v) for v in val])
193 +
    else:
194 +
      return tuple(_freeze_attr(v) for v in val)
195 +
  elif isinstance(val, list):
192 196
    return tuple(_freeze_attr(v) for v in val)
193 197
  else:
194 198
    return val
@@ -198,7 +202,7 @@
Loading
198 202
# -----------------------------------------------------------------------------
199 203
def compact(fun: _CallableT) -> _CallableT:
200 204
  """Marks the given module method allowing inlined submodules.
201 -
  
205 +
202 206
  Methods wrapped in @compact can define submodules directly within the method.
203 207
204 208
  For instance::
@@ -207,7 +211,7 @@
Loading
207 211
    __call__(self, x, features):
208 212
      x = nn.Dense(features)(x)
209 213
      ...
210 -
  
214 +
211 215
  At most one method in each Module may be wrapped with @compact.
212 216
213 217
  Args:
@@ -221,7 +225,7 @@
Loading
221 225
222 226
def _get_local_method_names(cls: Any, exclude: Iterable[str] = ()) -> Tuple[str]:
223 227
  """Gets method names of a class, excluding class and static methods.
224 -
  
228 +
225 229
  Args:
226 230
    cls: The class to get method names for.
227 231
    excludes: Names to exclude from output.
@@ -239,7 +243,7 @@
Loading
239 243
240 244
def wrap_method_once(fun: Callable[..., Any]) -> Callable[..., Any]:
241 245
  """Manages Module state for a given user-defined method.
242 -
  
246 +
243 247
  Args:
244 248
    fun: User-defined Module method to manage state for.
245 249
  Returns:
@@ -306,7 +310,7 @@
Loading
306 310
307 311
def _get_unbound_fn(method_or_fn: Callable[..., Any]) -> Callable[..., Any]:
308 312
  """Returns an unbound function from a method that is possibly bound.
309 -
  
313 +
310 314
  This means that if the passed function belongs of an instance of a class, then
311 315
  the returned function does no longer depend on the instance, which is passed
312 316
  as the first argument to the function.
@@ -345,7 +349,7 @@
Loading
345 349
346 350
  def reset(self):
347 351
    """Resets transient state.
348 -
    
352 +
349 353
    This function is called after each module method, so only attributes that
350 354
    are method-dependent are reset.
351 355
    """
@@ -420,7 +424,7 @@
Loading
420 424
      def __call__(self, x):
421 425
        return self.dense2(nn.relu(self.dense1(x)))
422 426
423 -
  Optionally, for more concise module implementations where submodules 
427 +
  Optionally, for more concise module implementations where submodules
424 428
  definitions are co-located with their usage, you can use the
425 429
  :meth:`compact` wrapper.
426 430
  """
@@ -521,7 +525,7 @@
Loading
521 525
522 526
  def __setattr__(self, name: str, val: Any):
523 527
    """Sets an attribute on this Module.
524 -
    
528 +
525 529
    We overload setattr solely to support pythonic naming via assignment of
526 530
    submodules in the special :meth:`setup` function::
527 531
@@ -713,9 +717,9 @@
Loading
713 717
            parent: Optional[Union[Scope, 'Module']] = None,
714 718
            **updates) -> 'Module':
715 719
    """Creates a clone of this Module, with optionally updated arguments.
716 -
    
720 +
717 721
    Args:
718 -
      parent: The parent of the clone. The clone will have no parent if no 
722 +
      parent: The parent of the clone. The clone will have no parent if no
719 723
        explicit parent is specified.
720 724
      **updates: Attribute updates.
721 725
    Returns:
@@ -805,7 +809,7 @@
Loading
805 809
806 810
    See :mod:`flax.core.variables` for more explanation on variables and
807 811
    collections.
808 -
    
812 +
809 813
    Args:
810 814
      col: The variable collection name.
811 815
      name: The name of the variable.
@@ -824,7 +828,7 @@
Loading
824 828
825 829
  def make_rng(self, name: str) -> PRNGKey:
826 830
    """Returns a new RNG key from a given RNG sequence for this Module.
827 -
    
831 +
828 832
    The new RNG key is split from the previous one. Thus, every call to
829 833
    `make_rng` returns a new RNG key, while still guaranteeing full
830 834
    reproducibility.
@@ -913,7 +917,7 @@
Loading
913 917
914 918
      model = Transformer()
915 919
      encoded = model.apply({'params': params}, x, method=Transformer.encode)
916 -
  
920 +
917 921
    If a function instance is provided, the unbound function is used. For
918 922
    instance, the example below is equivalent to the one above::
919 923
@@ -1076,7 +1080,7 @@
Loading
1076 1080
      variables = model.init(jax.random.PRNGKey(0), x)
1077 1081
      y, state = model.apply(variables, x, mutable=['intermediates'])
1078 1082
      print(state['intermediates'])  # {'h': (...,)}
1079 -
    
1083 +
1080 1084
    By default the values are stored in a tuple and each stored value
1081 1085
    is appended at the end. This way all intermediates can be tracked when
1082 1086
    the same module is called multiple times. Alternatively, a custom
@@ -1180,7 +1184,7 @@
Loading
1180 1184
      y = foo.decode(z)
1181 1185
      # ...
1182 1186
      return y
1183 -
    
1187 +
1184 1188
    foo = Foo()
1185 1189
    f_jitted = jax.jit(nn.apply(f, foo))
1186 1190
    f_jitted(variables, x)
@@ -1238,7 +1242,7 @@
Loading
1238 1242
      y = foo.decode(z)
1239 1243
      # ...
1240 1244
      return y
1241 -
    
1245 +
1242 1246
    foo = Foo()
1243 1247
    f_jitted = jax.jit(nn.init_with_output(f, foo))
1244 1248
    y, variables = f_jitted(rng, x)
@@ -1282,7 +1286,7 @@
Loading
1282 1286
      y = foo.decode(z)
1283 1287
      # ...
1284 1288
      return y
1285 -
    
1289 +
1286 1290
    foo = Foo()
1287 1291
    f_jitted = jax.jit(nn.init(f, foo))
1288 1292
    variables = f_jitted(rng, x)

Everything is accounted for!

No changes detected that need to be reviewed.
What changes does Codecov check for?
Lines, not adjusted in diff, that have changed coverage data.
Files that introduced coverage data that had none before.
Files that have missing coverage data that once were tracked.
Files Coverage
flax -0.01% 82.28%
Project Totals (65 files) 82.28%
Loading