google / flax

@@ -24,6 +24,7 @@
24 24
from flax.linen.initializers import zeros
25 25
from flax.linen.module import compact
26 26
from flax.linen.module import Module
27 +
from flax.linen.dtypes import promote_dtype
27 28
from jax import eval_shape
28 29
from jax import lax
29 30
from jax import ShapedArray
@@ -62,7 +63,7 @@
62 63
      (-2, -1) will apply the transformation to the last two axes.
63 64
    batch_dims: tuple with batch axes.
64 65
    use_bias: whether to add a bias to the output (default: True).
65 -
    dtype: the dtype of the computation (default: float32).
66 +
    dtype: the dtype of the computation (default: infer from input and params).
66 67
    param_dtype: the dtype passed to parameter initializers (default: float32).
67 68
    kernel_init: initializer function for the weight matrix.
68 69
    bias_init: initializer function for the bias.
@@ -73,7 +74,7 @@
73 74
  axis: Union[int, Sequence[int]] = -1
74 75
  batch_dims: Sequence[int] = ()
75 76
  use_bias: bool = True
76 -
  dtype: Dtype = jnp.float32
77 +
  dtype: Optional[Dtype] = None
77 78
  param_dtype: Dtype = jnp.float32
78 79
  kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init
79 80
  bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = zeros
@@ -98,8 +99,6 @@
98 99
        raise ValueError('batch_dims %s must be consecutive leading '
99 100
                         'dimensions starting from 0.' % str(batch_dims))
100 101
101 -
    inputs = jnp.asarray(inputs, self.dtype)
102 -
103 102
    ndim = inputs.ndim
104 103
    n_batch_dims = len(batch_dims)
105 104
    axis = _normalize_axes(axis, ndim)
@@ -122,15 +121,10 @@
122 121
    kernel_shape = tuple([inputs.shape[ax] for ax in axis]) + features
123 122
    kernel = self.param('kernel', kernel_init_wrap, batch_shape + kernel_shape,
124 123
                        self.param_dtype)
125 -
    kernel = jnp.asarray(kernel, self.dtype)
126 124
127 125
    batch_ind = tuple(range(n_batch_dims))
128 126
    contract_ind = tuple(range(n_batch_dims, n_axis + n_batch_dims))
129 -
    out = lax.dot_general(inputs,
130 -
                          kernel,
131 -
                          ((axis, contract_ind), (batch_dims, batch_ind)),
132 -
                          precision=self.precision)
133 -
    # dot_general output has shape [batch_dims/group_dims] + [feature_dims]
127 +
134 128
    if self.use_bias:
135 129
      def bias_init_wrap(rng, shape, dtype=jnp.float32):
136 130
        size_batch_dims = np.prod(shape[:n_batch_dims], dtype=np.int32)
@@ -141,10 +135,20 @@
141 135
142 136
      bias = self.param('bias', bias_init_wrap, batch_shape + features,
143 137
                        self.param_dtype)
138 +
    else:
139 +
      bias = None
140 +
141 +
    inputs, kernel, bias = promote_dtype(inputs, kernel, bias, dtype=self.dtype)
142 +
143 +
    out = lax.dot_general(inputs,
144 +
                          kernel,
145 +
                          ((axis, contract_ind), (batch_dims, batch_ind)),
146 +
                          precision=self.precision)
147 +
    # dot_general output has shape [batch_dims/group_dims] + [feature_dims]
148 +
    if self.use_bias:
144 149
      # expand bias shape to broadcast bias over batch dims.
145 150
      bias = jnp.reshape(bias, expanded_batch_shape + features)
146 -
      bias = jnp.asarray(bias, self.dtype)
147 -
      out = out + bias
151 +
      out += bias
148 152
    return out
149 153
150 154
@@ -154,7 +158,7 @@
154 158
  Attributes:
155 159
    features: the number of output features.
156 160
    use_bias: whether to add a bias to the output (default: True).
157 -
    dtype: the dtype of the computation (default: float32).
161 +
    dtype: the dtype of the computation (default: infer from input and params).
158 162
    param_dtype: the dtype passed to parameter initializers (default: float32).
159 163
    precision: numerical precision of the computation see `jax.lax.Precision`
160 164
      for details.
@@ -163,7 +167,7 @@
163 167
  """
164 168
  features: int
165 169
  use_bias: bool = True
166 -
  dtype: Dtype = jnp.float32
170 +
  dtype: Optional[Dtype] = None
167 171
  param_dtype: Dtype = jnp.float32
168 172
  precision: PrecisionLike = None
169 173
  kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init
@@ -179,19 +183,20 @@
179 183
    Returns:
180 184
      The transformed input.
181 185
    """
182 -
    inputs = jnp.asarray(inputs, self.dtype)
183 186
    kernel = self.param('kernel',
184 187
                        self.kernel_init,
185 -
                        (inputs.shape[-1], self.features),
188 +
                        (jnp.shape(inputs)[-1], self.features),
186 189
                        self.param_dtype)
187 -
    kernel = jnp.asarray(kernel, self.dtype)
188 -
    y = lax.dot_general(inputs, kernel,
189 -
                        (((inputs.ndim - 1,), (0,)), ((), ())),
190 -
                        precision=self.precision)
191 190
    if self.use_bias:
192 191
      bias = self.param('bias', self.bias_init, (self.features,),
193 192
                        self.param_dtype)
194 -
      bias = jnp.asarray(bias, self.dtype)
193 +
    else:
194 +
      bias = None
195 +
    inputs, kernel, bias = promote_dtype(inputs, kernel, bias, dtype=self.dtype)
196 +
    y = lax.dot_general(inputs, kernel,
197 +
                        (((inputs.ndim - 1,), (0,)), ((), ())),
198 +
                        precision=self.precision)
199 +
    if bias is not None:
195 200
      y += jnp.reshape(bias, (1,) * (y.ndim - 1) + (-1,))
196 201
    return y
197 202
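
The Dense changes above replace the unconditional jnp.asarray(..., self.dtype) casts with a single promote_dtype call over inputs, kernel, and bias. A minimal usage sketch of the resulting behavior, assuming the post-change flax.linen API with dtype=None as the new default (illustrative, not part of the diff):

import jax
import jax.numpy as jnp
import flax.linen as nn

x = jnp.ones((1, 3), jnp.bfloat16)

# dtype=None (new default): the computation dtype is inferred from the
# bfloat16 input and the float32 parameters, so the output is float32.
model = nn.Dense(features=4)
variables = model.init(jax.random.PRNGKey(0), x)
print(model.apply(variables, x).dtype)  # float32

# Forcing dtype keeps half-precision compute; parameters stay float32.
model_bf16 = nn.Dense(features=4, dtype=jnp.bfloat16)
print(model_bf16.apply(variables, x).dtype)  # bfloat16
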
@@ -259,7 +264,7 @@
259 264
    feature_group_count: integer, default 1. If specified divides the input
260 265
      features into groups.
261 266
    use_bias: whether to add a bias to the output (default: True).
262 -
    dtype: the dtype of the computation (default: float32).
267 +
    dtype: the dtype of the computation (default: infer from input and params).
263 268
    param_dtype: the dtype passed to parameter initializers (default: float32).
264 269
    precision: numerical precision of the computation see `jax.lax.Precision`
265 270
      for details.
@@ -274,7 +279,7 @@
274 279
  kernel_dilation: Union[None, int, Sequence[int]] = 1
275 280
  feature_group_count: int = 1
276 281
  use_bias: bool = True
277 -
  dtype: Dtype = jnp.float32
282 +
  dtype: Optional[Dtype] = None
278 283
  param_dtype: Dtype = jnp.float32
279 284
  precision: PrecisionLike = None
280 285
  kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init
@@ -307,8 +312,6 @@
307 312
      The convolved data.
308 313
    """
309 314
310 -
    inputs = jnp.asarray(inputs, self.dtype)
311 -
312 315
    if isinstance(self.kernel_size, int):
313 316
      raise TypeError('Expected Conv kernel_size to be a'
314 317
                      ' tuple/list of integers (eg.: [3, 3]) but got'
@@ -348,7 +351,7 @@
348 351
      padding_lax = 'VALID'
349 352
350 353
    dimension_numbers = _conv_dimension_numbers(inputs.shape)
351 -
    in_features = inputs.shape[-1]
354 +
    in_features = jnp.shape(inputs)[-1]
352 355
353 356
    if self.shared_weights:
354 357
      # One shared convolutional kernel for all pixels in the output.
@@ -371,7 +374,9 @@
371 374
              rhs=rhs,
372 375
              window_strides=strides,
373 376
              padding=padding_lax,
374 -
              dimension_numbers=dimension_numbers
377 +
              dimension_numbers=dimension_numbers,
378 +
              lhs_dilation=input_dilation,
379 +
              rhs_dilation=kernel_dilation,
375 380
          ),
376 381
          inputs,
377 382
          ShapedArray(kernel_size + (in_features, self.features), inputs.dtype)
@@ -383,8 +388,20 @@
383 388
384 389
    kernel = self.param('kernel', self.kernel_init, kernel_shape,
385 390
                        self.param_dtype)
386 -
    kernel = jnp.asarray(kernel, self.dtype)
387 391
392 +
    if self.use_bias:
393 +
      if self.shared_weights:
394 +
        # One bias weight per output channel, shared between pixels.
395 +
        bias_shape = (self.features,)
396 +
      else:
397 +
        # One bias weight per output entry, unshared between pixels.
398 +
        bias_shape = conv_output_shape[1:]
399 +
400 +
      bias = self.param('bias', self.bias_init, bias_shape, self.param_dtype)
401 +
    else:
402 +
      bias = None
403 +
404 +
    inputs, kernel, bias = promote_dtype(inputs, kernel, bias, dtype=self.dtype)
388 405
    if self.shared_weights:
389 406
      y = lax.conv_general_dilated(
390 407
          inputs,
@@ -411,15 +428,6 @@
411 428
      )
412 429
413 430
    if self.use_bias:
414 -
      if self.shared_weights:
415 -
        # One bias weight per output channel, shared between pixels.
416 -
        bias_shape = (self.features,)
417 -
      else:
418 -
        # One bias weight per output entry, unshared betwen pixels.
419 -
        bias_shape = y.shape[1:]
420 -
421 -
      bias = self.param('bias', self.bias_init, bias_shape, self.param_dtype)
422 -
      bias = jnp.asarray(bias, self.dtype)
423 431
      bias = bias.reshape((1,) * (y.ndim - bias.ndim) + bias.shape)
424 432
      y += bias
425 433
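
Conv follows the same pattern: bias creation now happens before the promote_dtype call, so inputs, kernel, and bias are promoted together before lax.conv_general_dilated runs. A hedged sketch, again assuming the standard flax.linen.Conv API rather than anything shown verbatim in the diff:

import jax
import jax.numpy as jnp
import flax.linen as nn

x = jnp.ones((1, 8, 3), jnp.bfloat16)        # NWC-shaped input
conv = nn.Conv(features=4, kernel_size=(3,))
variables = conv.init(jax.random.PRNGKey(0), x)
y = conv.apply(variables, x)
print(y.shape, y.dtype)  # (1, 8, 4) float32: bf16 input promoted with f32 params
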
@@ -464,7 +472,7 @@
464 472
      kernel. Convolution with kernel dilation is also known as 'atrous
465 473
      convolution'.
466 474
    use_bias: whether to add a bias to the output (default: True).
467 -
    dtype: the dtype of the computation (default: float32).
475 +
    dtype: the dtype of the computation (default: infer from input and params).
468 476
    param_dtype: the dtype passed to parameter initializers (default: float32).
469 477
    precision: numerical precision of the computation see `jax.lax.Precision`
470 478
      for details.
@@ -477,7 +485,7 @@
477 485
  padding: PaddingLike = 'SAME'
478 486
  kernel_dilation: Optional[Sequence[int]] = None
479 487
  use_bias: bool = True
480 -
  dtype: Dtype = jnp.float32
488 +
  dtype: Optional[Dtype] = None
481 489
  param_dtype: Dtype = jnp.float32
482 490
  precision: PrecisionLike = None
483 491
  kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init
@@ -499,8 +507,6 @@
499 507
    Returns:
500 508
      The convolved data.
501 509
    """
502 -
    inputs = jnp.asarray(inputs, self.dtype)
503 -
504 510
    kernel_size: Tuple[int, ...]
505 511
    if isinstance(self.kernel_size, int):
506 512
      kernel_size = (self.kernel_size,)
@@ -515,16 +521,24 @@
515 521
    strides: Tuple[int, ...]
516 522
    strides = self.strides or (1,) * (inputs.ndim - 2)
517 523
518 -
    in_features = inputs.shape[-1]
524 +
    in_features = jnp.shape(inputs)[-1]
519 525
    kernel_shape = kernel_size + (in_features, self.features)
520 526
    kernel = self.param('kernel', self.kernel_init, kernel_shape,
521 527
                        self.param_dtype)
522 -
    kernel = jnp.asarray(kernel, self.dtype)
523 528
524 529
    padding_lax = canonicalize_padding(self.padding, len(kernel_size))
525 530
    if padding_lax == 'CIRCULAR':
526 531
      padding_lax = 'VALID'
527 532
533 +
    if self.use_bias:
534 +
      bias = self.param('bias', self.bias_init, (self.features,),
535 +
                        self.param_dtype)
536 +
    else:
537 +
      bias = None
538 +
539 +
    inputs, kernel, bias = promote_dtype(inputs, kernel, bias,
540 +
                                         dtype=self.dtype)
541 +
528 542
    y = lax.conv_transpose(
529 543
        inputs,
530 544
        kernel,
@@ -544,7 +558,7 @@
544 558
      # Compute period along each spatial dimension - it's input size scaled
545 559
      # by the stride.
546 560
      scaled_x_dims = [
547 -
          x_dim * stride for x_dim, stride in zip(inputs.shape[1:-1], strides)
561 +
          x_dim * stride for x_dim, stride in zip(jnp.shape(inputs)[1:-1], strides)
548 562
      ]
549 563
      # Compute difference between the current size of y and the final output
550 564
      # size, and complement this difference to 2 * period - that gives how
@@ -570,9 +584,6 @@
570 584
    if is_single_input:
571 585
      y = jnp.squeeze(y, axis=0)
572 586
    if self.use_bias:
573 -
      bias = self.param('bias', self.bias_init, (self.features,),
574 -
                        self.param_dtype)
575 -
      bias = jnp.asarray(bias, self.dtype)
576 587
      y += jnp.reshape(bias, (1,) * (y.ndim - 1) + (-1,))
577 588
    return y
578 589
@@ -588,13 +599,13 @@
588 599
  Attributes:
589 600
    num_embeddings: number of embeddings.
590 601
    features: number of feature dimensions for each embedding.
591 -
    dtype: the dtype of the embedding vectors (default: float32).
602 +
    dtype: the dtype of the embedding vectors (default: same as embedding).
592 603
    param_dtype: the dtype passed to parameter initializers (default: float32).
593 604
    embedding_init: embedding initializer.
594 605
  """
595 606
  num_embeddings: int
596 607
  features: int
597 -
  dtype: Dtype = jnp.float32
608 +
  dtype: Optional[Dtype] = None
598 609
  param_dtype: Dtype = jnp.float32
599 610
  embedding_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_embed_init
600 611
@@ -620,7 +631,7 @@
620 631
      raise ValueError('Input type must be an integer or unsigned integer.')
621 632
    # Use take because fancy indexing numpy arrays with JAX indices does not
622 633
    # work correctly.
623 -
    embedding = jnp.asarray(self.embedding, self.dtype)
634 +
    embedding, = promote_dtype(self.embedding, dtype=self.dtype, inexact=False)
624 635
    return jnp.take(embedding, inputs, axis=0)
625 636
626 637
  def attend(self, query: Array) -> Array:
@@ -635,6 +646,5 @@
635 646
      Commonly used for weight-sharing between embeddings and logit transform
636 647
      in NLP models.
637 648
    """
638 -
    query = jnp.asarray(query, self.dtype)
639 -
    embedding = jnp.asarray(self.embedding, self.dtype)
649 +
    query, embedding = promote_dtype(query, self.embedding, dtype=self.dtype)
640 650
    return jnp.dot(query, embedding.T)
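
Embed is slightly different: plain lookups keep the embedding table's dtype (inexact=False permits integer tables), while attend promotes the query and the table to a common dtype. An illustrative sketch under the same assumptions:

import jax
import jax.numpy as jnp
import flax.linen as nn

embed = nn.Embed(num_embeddings=10, features=4, param_dtype=jnp.bfloat16)
ids = jnp.array([1, 2, 3])
variables = embed.init(jax.random.PRNGKey(0), ids)
print(embed.apply(variables, ids).dtype)  # bfloat16, taken from the table

query = jnp.ones((4,), jnp.float32)
logits = embed.apply(variables, query, method=embed.attend)
print(logits.dtype)  # float32 after promoting the bfloat16 table with the query
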

@@ -17,7 +17,7 @@
17 17
18 18
# pylint: disable=unused-import
19 19
# re-export activation functions from jax.nn
20 -
from typing import Any
20 +
from typing import Any, Optional
21 21
22 22
from flax.linen.module import compact
23 23
from flax.linen.module import Module
@@ -52,14 +52,18 @@
52 52
53 53
54 54
Array = Any
55 +
Dtype = Any
55 56
56 57
57 58
class PReLU(Module):
58 59
  """Parametric Rectified Linear Unit (PReLU) activation function.
59 60
60 61
  Attributes:
61 -
    negative_slope_init: the value to initialize the negative slope.
62 +
    param_dtype: the dtype passed to parameter initializers (default: float32).
63 +
    negative_slope_init: the value to initialize the negative slope
64 +
      (default 0.01).
62 65
  """
66 +
  param_dtype: Dtype = jnp.float32
63 67
  negative_slope_init: float = 0.01
64 68
65 69
  @compact
@@ -74,6 +78,6 @@
74 78
    """
75 79
    negative_slope = self.param(
76 80
        'negative_slope',
77 -
        lambda k: jnp.asarray(self.negative_slope_init, jnp.float32))
81 +
        lambda k: jnp.asarray(self.negative_slope_init, self.param_dtype))
78 82
    return jnp.where(inputs >= 0, inputs,
79 83
                     jnp.asarray(negative_slope, inputs.dtype) * inputs)
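
With this change the learned negative_slope parameter is created in param_dtype instead of hard-coded float32. A small sketch, assuming PReLU is exported as flax.linen.PReLU (illustrative only):

import jax
import jax.numpy as jnp
import flax.linen as nn

prelu = nn.PReLU(param_dtype=jnp.bfloat16)
x = jnp.linspace(-1.0, 1.0, 5)
variables = prelu.init(jax.random.PRNGKey(0), x)
print(variables['params']['negative_slope'].dtype)  # bfloat16
y = prelu.apply(variables, x)  # output follows the input dtype (float32 here)
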

@@ -0,0 +1,98 @@
1 +
# Copyright 2022 The Flax Authors.
2 +
#
3 +
# Licensed under the Apache License, Version 2.0 (the "License");
4 +
# you may not use this file except in compliance with the License.
5 +
# You may obtain a copy of the License at
6 +
#
7 +
#     http://www.apache.org/licenses/LICENSE-2.0
8 +
#
9 +
# Unless required by applicable law or agreed to in writing, software
10 +
# distributed under the License is distributed on an "AS IS" BASIS,
11 +
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 +
# See the License for the specific language governing permissions and
13 +
# limitations under the License.
14 +
28 +
"""APIs for handling dtypes in Linen Modules."""
29 +
30 +
from typing import Any, Optional, List
31 +
32 +
from jax import numpy as jnp
33 +
import jax
34 +
35 +
36 +
Dtype = Any
37 +
Array = Any
38 +
39 +
40 +
def canonicalize_dtype(*args,
41 +
                       dtype: Optional[Dtype] = None,
42 +
                       inexact: bool = True) -> Dtype:
43 +
  """Canonicalize an optional dtype to the definitive dtype.
44 +
45 +
  If ``dtype`` is None this function will infer the dtype from the input
46 +
  arguments using ``jnp.result_type``. If it is not None it will be
47 +
  returned unmodified, or an exception is raised if the dtype is
48 +
  invalid.
49 +
50 +
  Args:
51 +
    *args: JAX array compatible values. None values
52 +
      are ignored.
53 +
    dtype: Optional dtype override. If specified the arguments are cast to
54 +
      the specified dtype instead and dtype inference is disabled.
55 +
    inexact: When True, the output dtype must be a subdtype of
56 +
      ``jnp.inexact``. Inexact dtypes are real or complex floating points.
57 +
      This is useful when you want to apply operations that don't work
58 +
      directly on integers, like taking a mean, for example.
59 +
  Returns:
60 +
    The dtype that *args should be cast to.
61 +
  """
62 +
  if dtype is None:
63 +
    args_filtered = [jnp.asarray(x) for x in args if x is not None]
64 +
    dtype = jnp.result_type(*args_filtered)
65 +
    if inexact and not jnp.issubdtype(dtype, jnp.inexact):
66 +
      dtype = jnp.promote_types(jnp.float32, dtype)
67 +
  if inexact and not jnp.issubdtype(dtype, jnp.inexact):
68 +
    raise ValueError(f'Dtype must be inexact: {dtype}')
69 +
  return dtype
70 +
71 +
72 +
def promote_dtype(*args, dtype=None, inexact=True) -> List[Array]:
73 +
  """"Promotes input arguments to a specified or inferred dtype.
74 +
75 +
  All args are cast to the same dtype. See ``canonicalize_dtype`` for how
76 +
  this dtype is determined.
77 +
78 +
  The behavior of promote_dtype is mostly a convenience wrapper around
79 +
  ``jax.numpy.promote_types``. The differences are that it automatically casts
80 +
  all inputs to the inferred dtype, allows inference to be overridden by a
81 +
  forced dtype, and has an optional check to guarantee the resulting dtype is
82 +
  inexact.
83 +
84 +
  Args:
85 +
    *args: JAX array compatible values. None values
86 +
      are returned as is.
87 +
    dtype: Optional dtype override. If specified the arguments are cast to
88 +
      the specified dtype instead and dtype inference is disabled.
89 +
    inexact: When True, the output dtype must be a subdtype of
90 +
      ``jnp.inexact``. Inexact dtypes are real or complex floating points.
91 +
      This is useful when you want to apply operations that don't work
92 +
      directly on integers, like taking a mean, for example.
93 +
  Returns:
94 +
    The arguments cast to arrays of the same dtype.
95 +
  """
96 +
  dtype = canonicalize_dtype(*args, dtype=dtype, inexact=inexact)
97 +
  return [jnp.asarray(x, dtype) if x is not None else None
98 +
          for x in args]
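
These two helpers are the core of the new dtype handling. A usage sketch of the inference rules they implement (None entries pass through, and integer-only inputs are promoted to an inexact dtype unless inexact=False):

import jax.numpy as jnp
from flax.linen.dtypes import canonicalize_dtype, promote_dtype

x = jnp.ones((2,), jnp.bfloat16)
w = jnp.ones((2,), jnp.float32)

print(canonicalize_dtype(x, w, None))   # float32: result_type(bfloat16, float32)
x2, w2, b2 = promote_dtype(x, w, None)  # b2 stays None, x2/w2 become float32

ids = jnp.arange(3)
print(canonicalize_dtype(ids))                 # float32: promoted to an inexact dtype
print(canonicalize_dtype(ids, inexact=False))  # int32, as used by Embed lookups
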

@@ -16,6 +16,7 @@
16 16
17 17
import functools
18 18
from typing import (Any, Callable, Optional, Tuple)
19 +
from flax.linen.dtypes import promote_dtype
19 20
20 21
from flax.linen.initializers import zeros
21 22
from flax.linen.linear import default_kernel_init
@@ -44,7 +45,7 @@
44 45
                                  dropout_rng: Optional[PRNGKey] = None,
45 46
                                  dropout_rate: float = 0.,
46 47
                                  deterministic: bool = False,
47 -
                                  dtype: Dtype = jnp.float32,
48 +
                                  dtype: Optional[Dtype] = None,
48 49
                                  precision: PrecisionLike = None):
49 50
  """Computes dot-product attention weights given query and key.
50 51
@@ -70,13 +71,16 @@
Loading
70 71
    dropout_rng: JAX PRNGKey: to be used for dropout
71 72
    dropout_rate: dropout rate
72 73
    deterministic: bool, deterministic or not (to apply dropout)
73 -
    dtype: the dtype of the computation (default: float32)
74 +
    dtype: the dtype of the computation (default: infer from inputs and params)
74 75
    precision: numerical precision of the computation see `jax.lax.Precision`
75 76
      for details.
76 77
77 78
  Returns:
78 79
    Output of shape `[batch..., num_heads, q_length, kv_length]`.
79 80
  """
81 +
  query, key = promote_dtype(query, key, dtype=dtype)
82 +
  dtype = query.dtype
83 +
80 84
  assert query.ndim == key.ndim, 'q, k must have same rank.'
81 85
  assert query.shape[:-3] == key.shape[:-3], (
82 86
      'q, k batch dims must match.')
@@ -111,7 +115,7 @@
Loading
111 115
      keep = random.bernoulli(dropout_rng, keep_prob, dropout_shape)
112 116
    else:
113 117
      keep = random.bernoulli(dropout_rng, keep_prob, attn_weights.shape)
114 -
    multiplier = (keep.astype(attn_weights.dtype) /
118 +
    multiplier = (keep.astype(dtype) /
115 119
                  jnp.asarray(keep_prob, dtype=dtype))
116 120
    attn_weights = attn_weights * multiplier
117 121
@@ -127,7 +131,7 @@
127 131
                          dropout_rng: Optional[PRNGKey] = None,
128 132
                          dropout_rate: float = 0.,
129 133
                          deterministic: bool = False,
130 -
                          dtype: Dtype = jnp.float32,
134 +
                          dtype: Optional[Dtype] = None,
131 135
                          precision: PrecisionLike = None):
132 136
  """Computes dot-product attention given query, key, and value.
133 137
@@ -157,13 +161,15 @@
157 161
    dropout_rng: JAX PRNGKey: to be used for dropout
158 162
    dropout_rate: dropout rate
159 163
    deterministic: bool, deterministic or not (to apply dropout)
160 -
    dtype: the dtype of the computation (default: float32)
164 +
    dtype: the dtype of the computation (default: infer from inputs)
161 165
    precision: numerical precision of the computation see `jax.lax.Precision`
162 166
      for details.
163 167
164 168
  Returns:
165 169
    Output of shape `[batch..., q_length, num_heads, v_depth_per_head]`.
166 170
  """
171 +
  query, key, value = promote_dtype(query, key, value, dtype=dtype)
172 +
  dtype = query.dtype
167 173
  assert key.ndim == query.ndim == value.ndim, 'q, k, v must have same rank.'
168 174
  assert query.shape[:-3] == key.shape[:-3] == value.shape[:-3], (
169 175
      'q, k, v batch dims must match.')
@@ -173,8 +179,8 @@
173 179
174 180
  # compute attention weights
175 181
  attn_weights = dot_product_attention_weights(
176 -
    query, key, bias, mask, broadcast_dropout, dropout_rng, dropout_rate,
177 -
    deterministic, dtype, precision)
182 +
      query, key, bias, mask, broadcast_dropout, dropout_rng, dropout_rate,
183 +
      deterministic, dtype, precision)
178 184
179 185
  # return weighted sum over values for each query position
180 186
  return jnp.einsum('...hqk,...khd->...qhd', attn_weights, value,
@@ -187,8 +193,9 @@
187 193
    Attributes:
188 194
      num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1])
189 195
        should be divisible by the number of heads.
190 -
      dtype: the dtype of the computation (default: float32)
191 -
      param_dtype: the dtype passed to parameter initializers (default: float32).
196 +
      dtype: the dtype of the computation
197 +
        (default: infer from inputs and params)
198 +
      param_dtype: the dtype passed to parameter initializers (default: float32)
192 199
      qkv_features: dimension of the key, query, and value.
193 200
      out_features: dimension of the last projection
194 201
      broadcast_dropout: bool: use a broadcasted dropout along batch dims.
@@ -207,7 +214,7 @@
207 214
      decode: whether to prepare and use an autoregressive cache.
208 215
  """
209 216
  num_heads: int
210 -
  dtype: Dtype = jnp.float32
217 +
  dtype: Optional[Dtype] = None
211 218
  param_dtype: Dtype = jnp.float32
212 219
  qkv_features: Optional[int] = None
213 220
  out_features: Optional[int] = None
@@ -255,14 +262,14 @@
255 262
    head_dim = qkv_features // self.num_heads
256 263
257 264
    dense = functools.partial(DenseGeneral,
258 -
                    axis=-1,
259 -
                    dtype=self.dtype,
260 -
                    param_dtype=self.param_dtype,
261 -
                    features=(self.num_heads, head_dim),
262 -
                    kernel_init=self.kernel_init,
263 -
                    bias_init=self.bias_init,
264 -
                    use_bias=self.use_bias,
265 -
                    precision=self.precision)
265 +
                              axis=-1,
266 +
                              dtype=self.dtype,
267 +
                              param_dtype=self.param_dtype,
268 +
                              features=(self.num_heads, head_dim),
269 +
                              kernel_init=self.kernel_init,
270 +
                              bias_init=self.bias_init,
271 +
                              use_bias=self.use_bias,
272 +
                              precision=self.precision)
266 273
    # project inputs_q to multi-headed q/k/v
267 274
    # dimensions are then [batch..., length, n_heads, n_features_per_head]
268 275
    query, key, value = (dense(name='query')(inputs_q),
@@ -345,7 +352,7 @@
345 352
346 353
  @compact
347 354
  def __call__(self, inputs_q: Array, mask: Optional[Array] = None,
348 -
               deterministic: Optional[bool] = None):   
355 +
               deterministic: Optional[bool] = None):
349 356
    """Applies multi-head dot product self-attention on the input data.
350 357
351 358
    Projects the inputs into multi-headed query, key, and value vectors,
@@ -365,7 +372,8 @@
365 372
    Returns:
366 373
      output of shape `[batch_sizes..., length, features]`.
367 374
    """
368 -
    return super().__call__(inputs_q, inputs_q, mask, deterministic=deterministic)
375 +
    return super().__call__(inputs_q, inputs_q, mask,
376 +
                            deterministic=deterministic)
369 377
370 378
371 379
# mask-making utility functions
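
promote_dtype now runs at the top of dot_product_attention_weights and dot_product_attention, so with dtype=None the attention math follows the promoted dtype of query/key/value. A sketch using the public flax.linen function (illustrative, not part of the diff):

import jax.numpy as jnp
import flax.linen as nn

# [batch, length, num_heads, head_dim]
q = k = v = jnp.ones((1, 8, 4, 16), jnp.bfloat16)
out = nn.dot_product_attention(q, k, v)  # dtype inferred as bfloat16
print(out.shape, out.dtype)              # (1, 8, 4, 16) bfloat16
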

@@ -15,6 +15,7 @@
15 15
"""Normalization modules for Flax."""
16 16
17 17
from typing import (Any, Callable, Iterable, Optional, Tuple, Union)
18 +
from flax.linen.dtypes import canonicalize_dtype
18 19
19 20
from flax.linen.module import Module, compact, merge_param  # pylint: disable=g-multiple-import
20 21
from jax import lax
@@ -46,13 +47,14 @@
46 47
47 48
48 49
def _compute_stats(x: Array, axes: Axes,
50 +
                   dtype: Optional[Dtype],
49 51
                   axis_name: Optional[str] = None,
50 52
                   axis_index_groups: Any = None):
51 53
  """Computes mean and variance statistics.
52 54
53 55
  This implementation takes care of a few important details:
54 -
  - Computes in float32 precision for half precision inputs
55 -
  -  mean and variance is computable in a single XLA fusion,
56 +
  - Computes in float32 precision for stability in half precision training.
57 +
  - mean and variance are computable in a single XLA fusion,
56 58
    by using Var = E[|x|^2] - |E[x]|^2 instead of Var = E[|x - E[x]|^2]).
57 59
  - Clips negative variances to zero which can happen due to
58 60
    roundoff errors. This avoids downstream NaNs.
@@ -62,15 +64,21 @@
62 64
  Arguments:
63 65
    x: Input array.
64 66
    axes: The axes in ``x`` to compute mean and variance statistics for.
67 +
    dtype: Optional dtype specifying the minimal precision. Statistics
68 +
      are always at least float32 for stability (default: dtype of x).
65 69
    axis_name: Optional name for the pmapped axis to compute mean over.
66 70
    axis_index_groups: Optional axis indices.
67 71
68 72
  Returns:
69 73
    A pair ``(mean, var)``.
70 74
  """
75 +
  if dtype is None:
76 +
    dtype = jnp.result_type(x)
71 77
  # promote x to at least float32, this avoids half precision computation
72 78
  # but preserves double or complex floating points
73 -
  x = jnp.asarray(x, jnp.promote_types(jnp.float32, jnp.result_type(x)))
79 +
  dtype = jnp.promote_types(dtype, jnp.float32)
80 +
  x = jnp.asarray(x, dtype)
81 +
  
74 82
  mean = jnp.mean(x, axes)
75 83
  mean2 = jnp.mean(_abs_sq(x), axes)
76 84
  if axis_name is not None:
@@ -104,8 +112,8 @@
104 112
    reduction_axes: The axes in ``x`` to reduce.
105 113
    feature_axes: Axes containing features. A separate bias and scale is learned
106 114
      for each specified feature.
107 -
    dtype: Dtype of the returned result.
108 -
    param_dtype: Dtype of the parameters.
115 +
    dtype: The dtype of the result (default: infer from input and params).
116 +
    param_dtype: The dtype of the parameters.
109 117
    epsilon: Normalization epsilon.
110 118
    use_bias: If true, add a bias term to the output.
111 119
    use_scale: If true, scale the output.
@@ -129,15 +137,19 @@
129 137
    reduced_feature_shape.append(x.shape[ax])
130 138
  y = x - mean
131 139
  mul = lax.rsqrt(var + epsilon)
140 +
  args = [x]
132 141
  if use_scale:
133 142
    scale = mdl.param('scale', scale_init, reduced_feature_shape,
134 143
                      param_dtype).reshape(feature_shape)
135 144
    mul *= scale
145 +
    args.append(scale)
136 146
  y *= mul
137 147
  if use_bias:
138 148
    bias = mdl.param('bias', bias_init, reduced_feature_shape,
139 149
                     param_dtype).reshape(feature_shape)
140 150
    y += bias
151 +
    args.append(bias)
152 +
  dtype = canonicalize_dtype(*args, dtype=dtype)
141 153
  return jnp.asarray(y, dtype)
142 154
143 155
@@ -178,7 +190,7 @@
178 190
    momentum: decay rate for the exponential moving average of
179 191
      the batch statistics.
180 192
    epsilon: a small float added to variance to avoid dividing by zero.
181 -
    dtype: the dtype of the computation (default: float32).
193 +
    dtype: the dtype of the result (default: infer from input and params).
182 194
    param_dtype: the dtype passed to parameter initializers (default: float32).
183 195
    use_bias:  if True, bias (beta) is added.
184 196
    use_scale: if True, multiply by scale (gamma).
@@ -198,7 +210,7 @@
198 210
  axis: int = -1
199 211
  momentum: float = 0.99
200 212
  epsilon: float = 1e-5
201 -
  dtype: Dtype = jnp.float32
213 +
  dtype: Optional[Dtype] = None
202 214
  param_dtype: Dtype = jnp.float32
203 215
  use_bias: bool = True
204 216
  use_scale: bool = True
@@ -248,6 +260,7 @@
248 260
    else:
249 261
      mean, var = _compute_stats(
250 262
          x, reduction_axes,
263 +
          dtype=self.dtype,
251 264
          axis_name=self.axis_name if not initializing else None,
252 265
          axis_index_groups=self.axis_index_groups)
253 266
@@ -275,7 +288,7 @@
275 288
276 289
  Attributes:
277 290
    epsilon: A small float added to variance to avoid dividing by zero.
278 -
    dtype: the dtype of the computation (default: float32).
291 +
    dtype: the dtype of the result (default: infer from input and params).
279 292
    param_dtype: the dtype passed to parameter initializers (default: float32).
280 293
    use_bias:  If True, bias (beta) is added.
281 294
    use_scale: If True, multiply by scale (gamma). When the next layer is linear
@@ -285,7 +298,7 @@
285 298
    scale_init: Initializer for scale, by default, one.
286 299
  """
287 300
  epsilon: float = 1e-6
288 -
  dtype: Any = jnp.float32
301 +
  dtype: Optional[Dtype] = None
289 302
  param_dtype: Dtype = jnp.float32
290 303
  use_bias: bool = True
291 304
  use_scale: bool = True
@@ -306,7 +319,7 @@
306 319
    feature_axes = (-1,)
307 320
308 321
    # TODO(jheek) suport axis_name for model parallelism?
309 -
    mean, var = _compute_stats(x, reduction_axes, None, None)
322 +
    mean, var = _compute_stats(x, reduction_axes, self.dtype, None, None)
310 323
311 324
    return _normalize(
312 325
        self, x, mean, var, reduction_axes, feature_axes,
@@ -330,9 +343,8 @@
330 343
        proposed by the original group normalization paper.
331 344
      group_size: the number of channels in a group.
332 345
      epsilon: A small float added to variance to avoid dividing by zero.
333 -
      dtype: the dtype of the computation (default: float32).
334 -
      param_dtype: the dtype passed to parameter initializers (default:
335 -
        float32).
346 +
      dtype: the dtype of the result (default: infer from input and params).
347 +
      param_dtype: the dtype passed to parameter initializers (default: float32).
336 348
      use_bias:  If True, bias (beta) is added.
337 349
      use_scale: If True, multiply by scale (gamma). When the next layer is
338 350
        linear (also e.g. nn.relu), this can be disabled since the scaling will
@@ -343,7 +355,7 @@
343 355
  num_groups: Optional[int] = 32
344 356
  group_size: Optional[int] = None
345 357
  epsilon: float = 1e-6
346 -
  dtype: Any = jnp.float32
358 +
  dtype: Optional[Dtype] = None
347 359
  param_dtype: Dtype = jnp.float32
348 360
  use_bias: bool = True
349 361
  use_scale: bool = True
@@ -396,7 +408,7 @@
396 408
397 409
    # TODO(jheek): suport axis_name for model parallelism?
398 410
    mean, var = _compute_stats(
399 -
        x.reshape(group_shape), reduction_axes, None, None)
411 +
        x.reshape(group_shape), reduction_axes, self.dtype, None, None)
400 412
    mean = broadcast_stat(mean)
401 413
    var = broadcast_stat(var)
402 414
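
For the normalization layers, _compute_stats still promotes to at least float32 so the statistics stay stable, while _normalize casts the result to the dtype inferred from x, scale, and bias via canonicalize_dtype. A sketch with LayerNorm (illustrative, not from the diff):

import jax
import jax.numpy as jnp
import flax.linen as nn

x = jnp.ones((2, 5), jnp.float16)
ln = nn.LayerNorm()                        # dtype=None: infer from input and params
variables = ln.init(jax.random.PRNGKey(0), x)
print(ln.apply(variables, x).dtype)        # float32: float16 input, float32 scale/bias

ln_f16 = nn.LayerNorm(dtype=jnp.float16)   # force the result dtype
print(ln_f16.apply(variables, x).dtype)    # float16 (stats still computed in float32)
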

@@ -24,6 +24,7 @@
24 24
25 25
from flax.linen.activation import sigmoid
26 26
from flax.linen.activation import tanh
27 +
from flax.linen.dtypes import promote_dtype
27 28
from flax.linen.initializers import orthogonal
28 29
from flax.linen.initializers import zeros
29 30
from flax.linen.linear import Conv
@@ -87,7 +88,7 @@
87 88
    recurrent_kernel_init: initializer function for the kernels that transform
88 89
      the hidden state (default: orthogonal).
89 90
    bias_init: initializer for the bias parameters (default: zeros)
90 -
    dtype: the dtype of the computation (default: float32).
91 +
    dtype: the dtype of the computation (default: infer from inputs and params).
91 92
    param_dtype: the dtype passed to parameter initializers (default: float32).
92 93
  """
93 94
  gate_fn: Callable[..., Any] = sigmoid
@@ -95,7 +96,7 @@
95 96
  kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init
96 97
  recurrent_kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = orthogonal()
97 98
  bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = zeros
98 -
  dtype: Dtype = jnp.float32
99 +
  dtype: Optional[Dtype] = None
99 100
  param_dtype: Dtype = jnp.float32
100 101
101 102
  @compact
@@ -153,11 +154,10 @@
153 154
154 155
155 156
class DenseParams(Module):
156 -
  """Dummy module for creating parameters matching `flax.deprecated.nn.Dense`."""
157 +
  """Dummy module for creating parameters matching `flax.linen.Dense`."""
157 158
158 159
  features: int
159 160
  use_bias: bool = True
160 -
  dtype: Dtype = jnp.float32
161 161
  param_dtype: Dtype = jnp.float32
162 162
  precision: PrecisionLike = None
163 163
  kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init
@@ -166,11 +166,12 @@
166 166
  @compact
167 167
  def __call__(self, inputs: Array) -> Tuple[Array, Array]:
168 168
    k = self.param(
169 -
        'kernel', self.kernel_init, (inputs.shape[-1], self.features))
169 +
        'kernel', self.kernel_init, (inputs.shape[-1], self.features),
170 +
        self.param_dtype)
170 171
    if self.use_bias:
171 -
      b = self.param('bias', self.bias_init, (self.features,))
172 +
      b = self.param('bias', self.bias_init, (self.features,), self.param_dtype)
172 173
    else:
173 -
      b = jnp.zeros((self.features,))
174 +
      b = None
174 175
    return k, b
175 176
176 177
@@ -206,7 +207,7 @@
206 207
    recurrent_kernel_init: initializer function for the kernels that transform
207 208
      the hidden state (default: orthogonal).
208 209
    bias_init: initializer for the bias parameters (default: zeros).
209 -
    dtype: the dtype of the computation (default: float32).
210 +
    dtype: the dtype of the computation (default: infer from inputs and params).
210 211
    param_dtype: the dtype passed to parameter initializers (default: float32).
211 212
  """
212 213
  gate_fn: Callable[..., Any] = sigmoid
@@ -214,7 +215,7 @@
214 215
  kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init
215 216
  recurrent_kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = orthogonal()
216 217
  bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = zeros
217 -
  dtype: Dtype = jnp.float32
218 +
  dtype: Optional[Dtype] = None
218 219
  param_dtype: Dtype = jnp.float32
219 220
220 221
  @compact
@@ -233,7 +234,6 @@
233 234
    """
234 235
    c, h = carry
235 236
    hidden_features = h.shape[-1]
236 -
    inputs = jnp.asarray(inputs, self.dtype)
237 237
238 238
    def _concat_dense(inputs: Array,
239 239
                      params: Mapping[str, Tuple[Array, Array]],
@@ -242,15 +242,18 @@
242 242
      # single kernel and single bias for efficiency before applying them using
243 243
      # dot_general.
244 244
      kernels, biases = zip(*params.values())
245 -
      kernel = jnp.asarray(jnp.concatenate(kernels, axis=-1), self.dtype)
246 -
245 +
      kernel = jnp.concatenate(kernels, axis=-1)
246 +
      if use_bias:
247 +
        bias = jnp.concatenate(biases, axis=-1)
248 +
      else:
249 +
        bias = None
250 +
      inputs, kernel, bias = promote_dtype(inputs, kernel, bias, dtype=self.dtype)
247 251
      y = jnp.dot(inputs, kernel)
248 252
      if use_bias:
249 -
        bias = jnp.asarray(jnp.concatenate(biases, axis=-1), self.dtype)
250 253
        y += jnp.reshape(bias, (1,) * (y.ndim - 1) + (-1,))
251 254
252 255
      # Split the result back into individual (i, f, g, o) outputs.
253 -
      split_indices = np.cumsum([b.shape[0] for b in biases[:-1]])
256 +
      split_indices = np.cumsum([kernel.shape[-1] for kernel in kernels[:-1]])
254 257
      ys = jnp.split(y, split_indices, axis=-1)
255 258
      return dict(zip(params.keys(), ys))
256 259
@@ -323,7 +326,7 @@
323 326
    recurrent_kernel_init: initializer function for the kernels that transform
324 327
      the hidden state (default: orthogonal).
325 328
    bias_init: initializer for the bias parameters (default: zeros)
326 -
    dtype: the dtype of the computation (default: float32).
329 +
    dtype: the dtype of the computation (default: infer from inputs and params).
327 330
    param_dtype: the dtype passed to parameter initializers (default: float32).
328 331
  """
329 332
  gate_fn: Callable[..., Any] = sigmoid
@@ -333,7 +336,7 @@
333 336
  recurrent_kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = (
334 337
      orthogonal())
335 338
  bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = zeros
336 -
  dtype: Dtype = jnp.float32
339 +
  dtype: Optional[Dtype] = None
337 340
  param_dtype: Dtype = jnp.float32
338 341
339 342
  @compact
@@ -427,7 +430,7 @@
427 430
      of `n` `(low, high)` integer pairs that give the padding to apply before
428 431
      and after each spatial dimension.
429 432
    bias: whether to add a bias to the output (default: True).
430 -
    dtype: the dtype of the computation (default: float32).
433 +
    dtype: the dtype of the computation (default: infer from inputs and params).
431 434
    param_dtype: the dtype passed to parameter initializers (default: float32).
432 435
  """
433 436
@@ -436,7 +439,7 @@
436 439
  strides: Optional[Sequence[int]] = None
437 440
  padding: Union[str, Sequence[Tuple[int, int]]] = 'SAME'
438 441
  use_bias: bool = True
439 -
  dtype: Dtype = jnp.float32
442 +
  dtype: Optional[Dtype] = None
440 443
  param_dtype: Dtype = jnp.float32
441 444
442 445
  @compact
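
One behavioral detail in the LSTMCell change above: split_indices for the fused i/f/g/o projection is now derived from the kernel output widths rather than the bias shapes, which also works when use_bias=False (the biases are then None). A standalone sketch of that splitting logic:

import numpy as np
import jax.numpy as jnp

hidden = 4
kernels = [jnp.ones((3, hidden)) for _ in range(4)]   # i, f, g, o projections
y = jnp.ones((2, 4 * hidden))                         # fused projection output
split_indices = np.cumsum([k.shape[-1] for k in kernels[:-1]])
i, f, g, o = jnp.split(y, split_indices, axis=-1)
print(i.shape, f.shape, g.shape, o.shape)             # (2, 4) each
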
Files                       Coverage
flax                        75.24%
Project Totals (60 files)   75.24%