google / flax

@@ -16,7 +16,7 @@
Loading
16 16
"""
17 17
18 18
import functools
19 -
from typing import NamedTuple, Any, Callable, Union, Sequence
19 +
from typing import NamedTuple, Any, Optional, Callable, Union, Sequence
20 20
21 21
from .. import struct
22 22
@@ -86,7 +86,7 @@
Loading
86 86
  def value_and_grad(self, fun: Callable[..., Any],
87 87
                     argnums: Union[int, Sequence[int]] = 0,
88 88
                     has_aux: bool = False,
89 -
                     axis_name: str = None,
89 +
                     axis_name: Optional[str] = None,
90 90
                     ) -> Callable[..., DynamicScaleResult]:
91 91
    """Wrapper around `jax.value_and_grad`.
92 92

@@ -113,7 +113,7 @@
Loading
113 113
  return rng
114 114
115 115
116 -
def _fold_in_static(rng: PRNGKey, data: Iterable[PRNGFoldable]) -> PRNGKey:
116 +
def _fold_in_static(rng: PRNGKey, data: Sequence[PRNGFoldable]) -> PRNGKey:
117 117
  """Folds static data (strings & ints) into a jax.random.PRNGKey using its SHA-1 hash.
118 118
119 119
  This is faster than splitting an PRNGKey because it allows generating new PRNG
@@ -454,6 +454,17 @@
Loading
454 454
  def _validate_trace_level(self):
455 455
    tracers.check_trace_level(self.trace_level)
456 456
457 +
  def rewind(self, rewind_rngs: bool = False):
458 +
    """Resets reservations and optionally rng counters in this Scope *in place*.
459 +
460 +
    Args:
461 +
      rewind_rngs: if true, reset the RNG counter of this scope.
462 +
    """
463 +
    self._check_valid()
464 +
    self.reservations = set()
465 +
    if rewind_rngs:
466 +
      self.rng_counters = {key: 0 for key in self.rngs}
467 +
457 468
  def rewound(self, rewind_rngs: bool = False) -> 'Scope':
458 469
    """Returns a rewound version of this Scope.
459 470
@@ -461,7 +472,7 @@
Loading
461 472
      rewind_rngs: if true, reset the RNG counter of this scope.
462 473
463 474
    Returns:
464 -
      A rewound version of this scope, which means reservations and children are
475 +
      A rewound version of this scope, which means reservations are
465 476
      emptied, and the rng counter is optionally rewound.
466 477
    """
467 478
    self._check_valid()

@@ -639,7 +639,7 @@
Loading
639 639
    finally:
640 640
      _context.module_stack.pop()
641 641
      if is_compact_method:
642 -
        object.__setattr__(self, 'scope', self.scope.rewound())
642 +
        self.scope.rewind()
643 643
      # setup or compact calls can be recurrent for example due to super calls
644 644
      # resetting the state would cause is compact/setup method
645 645
      # to be set to False prematurely.
Files Coverage
flax 83.17%
Project Totals (69 files) 83.17%

No yaml found.

Create your codecov.yml to customize your Codecov experience

Sunburst
The inner-most circle is the entire project, moving away from the center are folders then, finally, a single file. The size and color of each slice is representing the number of statements and the coverage, respectively.
Icicle
The top section represents the entire project. Proceeding with folders and finally individual files. The size and color of each slice is representing the number of statements and the coverage, respectively.
Grid
Each block represents a single file in the project. The size and color of each block is represented by the number of statements and the coverage, respectively.
Loading