1
"""ArviZ rcparams. Based on matplotlib's implementation."""
2 6
import locale
3 6
import logging
4 6
import os
5 6
import pprint
6 6
import re
7 6
import sys
8 6
import warnings
9 6
from collections.abc import MutableMapping
10 6
from pathlib import Path
11

12 6
import numpy as np
13

14 6
# Module-level logger, used to report problems found while reading arvizrc files.
_log = logging.getLogger(name=__name__)
15

16

17 6
def _make_validate_choice(accepted_values, allow_none=False, typeof=str):
    """Build a validator checking that a value is one of ``accepted_values``.

    Parameters
    ----------
    accepted_values : iterable
        Iterable containing all accepted values. String entries must be lower case,
        because string inputs are lower-cased before the membership test.
    allow_none : bool, optional
        Whether to accept ``None`` (or the string ``"none"``, case-insensitive) in
        addition to the values in ``accepted_values``.
    typeof : type, optional
        Type the values are converted to before the membership test.

    Returns
    -------
    callable
        ``validate_choice(value)``, which returns the converted (and, for strings,
        lower-cased) value. The strings ``"true"`` and ``"false"`` are returned as
        Python booleans.

    Notes
    -----
    The returned validator raises ``ValueError`` if ``value`` cannot be converted
    with ``typeof`` or is not one of ``accepted_values``.
    """
    # no blank lines allowed after function docstring by pydocstyle,
    # but black requires white line before function

    def validate_choice(value):
        if allow_none and (value is None or isinstance(value, str) and value.lower() == "none"):
            return None
        try:
            value = typeof(value)
        except (ValueError, TypeError) as err:
            raise ValueError("Could not convert to {}".format(typeof.__name__)) from err
        if isinstance(value, str):
            value = value.lower()

        if value in accepted_values:
            # Convert value to python boolean if string matches
            return {"true": True, "false": False}.get(value, value)
        raise ValueError(
            "{} is not one of {}{}".format(
                value, accepted_values, " nor None" if allow_none else ""
            )
        )

    return validate_choice
53

54

55 6
def _make_validate_choice_regex(accepted_values, accepted_values_regex, allow_none=False):
    """Build a validator accepting fixed choices or strings fully matching a regex.

    Parameters
    ----------
    accepted_values : iterable
        Iterable containing all accepted values (lower case).
    accepted_values_regex : iterable
        Iterable of regex pattern strings. A value is also accepted when the
        *whole* lower-cased string matches one of them.
    allow_none : bool, optional
        Whether to accept ``None`` (or the string ``"none"``, case-insensitive) in
        addition to the values in ``accepted_values``.

    Returns
    -------
    callable
        ``validate_choice_regex(value)``, which returns ``str(value).lower()``.
        The strings ``"true"``/``"false"`` are returned as Python booleans when
        they are in ``accepted_values``. It raises ``ValueError`` for any other
        value.
    """
    # no blank lines allowed after function docstring by pydocstyle,
    # but black requires white line before function
    compiled_regex = tuple(re.compile(pattern) for pattern in accepted_values_regex)

    def validate_choice_regex(value):
        if allow_none and (value is None or isinstance(value, str) and value.lower() == "none"):
            return None
        value = str(value).lower()

        if value in accepted_values:
            # Convert value to python boolean if string matches
            return {"true": True, "false": False}.get(value, value)
        # fullmatch (not match) so that trailing junk such as "3columnXYZ" is rejected
        if any(pattern.fullmatch(value) for pattern in compiled_regex):
            return value
        raise ValueError(
            "{} is not one of {} or in regex {}{}".format(
                value, accepted_values, accepted_values_regex, " nor None" if allow_none else ""
            )
        )

    return validate_choice_regex
92

93

94 6
def _validate_positive_int(value):
    """Validate value is a natural number (a strictly positive integer).

    Parameters
    ----------
    value : object
        Anything accepted by :func:`int`, e.g. ``3`` or ``"3"``.

    Returns
    -------
    int

    Raises
    ------
    ValueError
        If ``value`` cannot be converted to ``int`` or is not positive.
    """
    try:
        value = int(value)
    except (ValueError, TypeError) as err:
        # TypeError covers inputs such as None or lists; re-raising it as ValueError
        # lets RcParams.__setitem__ prefix the message with the offending key.
        raise ValueError("Could not convert to int") from err
    if value <= 0:
        raise ValueError("Only positive values are valid")
    return value
104

105

106 6
def _validate_positive_int_or_none(value):
    """Validate value is a natural number or None.

    ``None`` and the string ``"none"`` (case-insensitive) are returned as ``None``.
    Anything else is delegated to ``_validate_positive_int``.
    """
    if value is None or isinstance(value, str) and value.lower() == "none":
        return None
    return _validate_positive_int(value)
112

113

114 6
def _validate_float(value):
    """Validate value is a float.

    Parameters
    ----------
    value : object
        Anything accepted by :func:`float`, e.g. ``0.5`` or ``"0.5"``.

    Returns
    -------
    float

    Raises
    ------
    ValueError
        If ``value`` cannot be converted to ``float``.
    """
    try:
        return float(value)
    except (ValueError, TypeError) as err:
        # TypeError covers inputs such as None or lists; re-raising it as ValueError
        # lets RcParams.__setitem__ prefix the message with the offending key.
        raise ValueError("Could not convert to float") from err
121

122

123 6
def _validate_float_or_none(value):
    """Validate value is a float or None.

    ``None`` and the string ``"none"`` (case-insensitive) are returned as ``None``.
    Anything else is delegated to ``_validate_float``.
    """
    if value is None or isinstance(value, str) and value.lower() == "none":
        return None
    return _validate_float(value)
129

130

131 6
def _validate_probability(value):
    """Validate a probability: a float between 0 and 1 (both inclusive).

    Returns
    -------
    float

    Raises
    ------
    ValueError
        If ``value`` cannot be converted to float or is outside ``[0, 1]``
        (NaN included).
    """
    value = _validate_float(value)
    # Written as a chained range test so NaN (for which every comparison is False)
    # is rejected too.
    if not 0 <= value <= 1:
        raise ValueError("Only values between 0 and 1 are valid.")
    return value
137

138

139 6
def _validate_boolean(value):
    """Validate value is a boolean.

    Accepted inputs are ``True``/``False``, values equal to them (e.g. ``1``, ``0``
    or ``numpy.bool_`` scalars) and the strings ``"true"``/``"false"``
    (case-insensitive).

    Returns
    -------
    bool

    Raises
    ------
    ValueError
        For any other input.
    """
    if isinstance(value, str):
        value = value.lower()
    try:
        is_valid = value in {True, "true", False, "false"}
    except TypeError:
        # Unhashable inputs (lists, arrays, ...) cannot be booleans.
        is_valid = False
    if not is_valid:
        raise ValueError("Only boolean values are valid.")
    if isinstance(value, str):
        return value == "true"
    # bool() rather than ``value is True``: 1 and numpy.True_ pass the check above
    # and must map to True.
    return bool(value)
144

145

146 6
def _validate_bokeh_marker(value):
    """Validate the markers.

    The check is case-sensitive: ``value`` must be exactly one of the marker
    names listed below, and it is returned unchanged.

    Raises
    ------
    ValueError
        If ``value`` is not one of the listed marker names.
    """
    all_markers = (
        "Asterisk",
        "Circle",
        "CircleCross",
        "CircleX",
        "Cross",
        "Dash",
        "Diamond",
        "DiamondCross",
        "Hex",
        "InvertedTriangle",
        "Square",
        "SquareCross",
        "SquareX",
        "Triangle",
        "X",
    )
    if value not in all_markers:
        raise ValueError("{} is not one of {}".format(value, all_markers))
    return value
168

169

170 6
def _validate_dict_of_lists(values):
    """Validate a mapping from keys to sequences of names.

    Parameters
    ----------
    values : dict or iterable of str
        Either a dict mapping keys to iterables (a single string value is treated
        as a one-element list), or an iterable of ``"key: [item1, item2, ...]"``
        strings, as collected from a multi-line entry of an arvizrc file.

    Returns
    -------
    dict
        Mapping of each key to a tuple of items. Empty items are dropped, so
        ``"key: []"`` maps to an empty tuple.

    Raises
    ------
    ValueError
        If a string entry has no ``:`` separating key and values.

    Warns
    -----
    UserWarning
        If a key appears more than once; the last occurrence wins.
    """
    if isinstance(values, dict):
        # A bare string must not be split into its characters by ``tuple()``.
        return {
            key: (item,) if isinstance(item, str) else tuple(item)
            for key, item in values.items()
        }
    validated_dict = {}
    for value in values:
        key, sep, vals = value.partition(":")
        if not sep:
            raise ValueError(f"Could not interpret '{value}' as key: list or str")
        key = key.strip(' "')
        items = (val.strip(' "') for val in vals.strip(" [],").split(","))
        if key in validated_dict:
            warnings.warn(f"Repeated key {key} when validating dict of lists")
        validated_dict[key] = tuple(item for item in items if item)
    return validated_dict
186

187

188 6
def make_iterable_validator(scalar_validator, length=None, allow_none=False, allow_auto=False):
    """Build a validator for ordered iterables whose items pass ``scalar_validator``.

    Parameters
    ----------
    scalar_validator : callable
        Validator applied to every item.
    length : int, optional
        Required number of items. ``None`` accepts any length.
    allow_none : bool, optional
        Whether ``None`` (or the string ``"none"``, case-insensitive) is accepted.
    allow_auto : bool, optional
        Whether the string ``"auto"`` (case-insensitive) is accepted.

    Returns
    -------
    callable
        ``validate_iterable(value)``, which returns ``None``, ``"auto"`` or a tuple
        of validated items. A string value is parsed as a comma-separated list,
        optionally wrapped in ``()`` or ``[]``. Sets, non-iterables and tuples of
        the wrong length raise ``ValueError``.
    """
    # based on matplotlib's _listify_validator function

    def validate_iterable(value):
        if allow_none and (value is None or isinstance(value, str) and value.lower() == "none"):
            return None
        if isinstance(value, str):
            if allow_auto and value.lower() == "auto":
                return "auto"
            # Strip brackets before discarding empty items, so "[]" parses as ().
            stripped = (v.strip("([ ])") for v in value.split(","))
            value = tuple(v for v in stripped if v)
        if np.iterable(value) and not isinstance(value, (set, frozenset)):
            val = tuple(scalar_validator(v) for v in value)
            if length is not None and len(val) != length:
                raise ValueError("Iterable must be of length: {}".format(length))
            return val
        raise ValueError("Only ordered iterable values are valid")

    return validate_iterable
207

208

209 6
# Validator for bokeh axis bounds: None, "auto", or a pair whose items are floats or None.
_validate_bokeh_bounds = make_iterable_validator(  # pylint: disable=invalid-name
    _validate_float_or_none, length=2, allow_none=True, allow_auto=True
)
212

213 6
# Default value of the ``data.metagroups`` rcParam: each metagroup name maps to a
# list of group names (presumably InferenceData group names -- confirm against callers).
METAGROUPS = {
    "posterior_groups": ["posterior", "posterior_predictive", "sample_stats", "log_likelihood"],
    "prior_groups": ["prior", "prior_predictive", "sample_stats_prior"],
    "posterior_groups_warmup": [
        "_warmup_posterior",
        "_warmup_posterior_predictive",
        "_warmup_sample_stats",
    ],
    "latent_vars": ["posterior", "prior"],
    "observed_vars": ["posterior_predictive", "observed_data", "prior_predictive"],
}
224

225 6
# Each entry maps an rcParam key to ``(default value, validator)``. ``RcParams``
# builds its per-key validators from this table and only accepts these keys.
defaultParams = {  # pylint: disable=invalid-name
    "data.http_protocol": ("https", _make_validate_choice({"https", "http"})),
    "data.load": ("lazy", _make_validate_choice({"lazy", "eager"})),
    "data.metagroups": (METAGROUPS, _validate_dict_of_lists),
    "data.index_origin": (0, _make_validate_choice({0, 1}, typeof=int)),
    "data.save_warmup": (False, _validate_boolean),
    "data.pandas_float_precision": (
        "high",
        _make_validate_choice({"high", "round_trip"}, allow_none=True),
    ),
    "plot.backend": ("matplotlib", _make_validate_choice({"matplotlib", "bokeh"})),
    "plot.max_subplots": (40, _validate_positive_int_or_none),
    "plot.point_estimate": (
        "mean",
        _make_validate_choice({"mean", "median", "mode"}, allow_none=True),
    ),
    "plot.bokeh.bounds_x_range": ("auto", _validate_bokeh_bounds),
    "plot.bokeh.bounds_y_range": ("auto", _validate_bokeh_bounds),
    "plot.bokeh.figure.dpi": (60, _validate_positive_int),
    "plot.bokeh.figure.height": (500, _validate_positive_int),
    "plot.bokeh.figure.width": (500, _validate_positive_int),
    "plot.bokeh.layout.order": (
        "default",
        # "column"/"row" may be prefixed by digits, e.g. "3column"
        _make_validate_choice_regex(
            {"default", "column", "row", "square", "square_trimmed"}, {r"\d*column", r"\d*row"}
        ),
    ),
    "plot.bokeh.layout.sizing_mode": (
        "fixed",
        _make_validate_choice(
            {
                "fixed",
                "stretch_width",
                "stretch_height",
                "stretch_both",
                "scale_width",
                "scale_height",
                "scale_both",
            }
        ),
    ),
    "plot.bokeh.layout.toolbar_location": (
        "above",
        _make_validate_choice({"above", "below", "left", "right"}, allow_none=True),
    ),
    "plot.bokeh.marker": ("Cross", _validate_bokeh_marker),
    "plot.bokeh.output_backend": ("webgl", _make_validate_choice({"canvas", "svg", "webgl"})),
    "plot.bokeh.show": (True, _validate_boolean),
    "plot.bokeh.tools": (
        "reset,pan,box_zoom,wheel_zoom,lasso_select,undo,save,hover",
        # not validated: the value is stored exactly as given
        lambda x: x,
    ),
    "plot.matplotlib.constrained_layout": (True, _validate_boolean),
    "plot.matplotlib.show": (False, _validate_boolean),
    "stats.hdi_prob": (0.94, _validate_probability),
    "stats.information_criterion": ("loo", _make_validate_choice({"waic", "loo"})),
    "stats.ic_pointwise": (False, _validate_boolean),
    "stats.ic_scale": ("log", _make_validate_choice({"deviance", "log", "negative_log"})),
}
284

285

286 6
class RcParams(MutableMapping, dict):  # pylint: disable=too-many-ancestors
    """Class to contain ArviZ default parameters.

    It is implemented as a dict with validation when setting items: only keys
    present in ``defaultParams`` are accepted, and every value is passed through
    the validator registered for its key. Keys cannot be deleted.
    """

    validate = {key: validate_fun for key, (_, validate_fun) in defaultParams.items()}

    # validate values on the way in
    def __init__(self, *args, **kwargs):  # pylint: disable=super-init-not-called
        self.update(*args, **kwargs)

    def __setitem__(self, key, val):
        """Validate ``val`` with the validator registered for ``key``, then store it.

        Raises
        ------
        KeyError
            If ``key`` is not a valid rc parameter.
        ValueError
            If ``val`` is rejected by the validator of ``key``.
        """
        try:
            validator = self.validate[key]
        except KeyError as err:
            raise KeyError(
                "{} is not a valid rc parameter (see rcParams.keys() for "
                "a list of valid parameters)".format(key)
            ) from err
        # Only the lookup above is guarded against KeyError, so a KeyError raised
        # inside a validator is not misreported as an unknown rc parameter.
        try:
            cval = validator(val)
        except ValueError as verr:
            raise ValueError("Key %s: %s" % (key, str(verr))) from verr
        dict.__setitem__(self, key, cval)

    def __getitem__(self, key):
        """Use dict getitem method."""
        return dict.__getitem__(self, key)

    def __delitem__(self, key):
        """Raise TypeError if someone ever tries to delete a key from RcParams."""
        raise TypeError("RcParams keys cannot be deleted")

    def clear(self):
        """Raise TypeError if someone ever tries to delete all keys from RcParams."""
        raise TypeError("RcParams keys cannot be deleted")

    def pop(self, key, default=None):
        """Raise TypeError if someone ever tries to delete a key from RcParams."""
        raise TypeError(
            "RcParams keys cannot be deleted. Use .get(key) or RcParams[key] to check values"
        )

    def popitem(self):
        """Raise TypeError if someone ever tries to delete a key from RcParams."""
        raise TypeError(
            "RcParams keys cannot be deleted. Use .get(key) or RcParams[key] to check values"
        )

    def setdefault(self, key, default=None):
        """Raise error when using setdefault, defaults are handled on initialization."""
        raise TypeError(
            "Defaults in RcParams are handled on object initialization during library "
            "import. Use arvizrc file instead."
        )

    def items(self):
        """Explicit use of MutableMapping attributes."""
        return MutableMapping.items(self)

    def keys(self):
        """Explicit use of MutableMapping attributes."""
        return MutableMapping.keys(self)

    def values(self):
        """Explicit use of MutableMapping attributes."""
        return MutableMapping.values(self)

    def __repr__(self):
        """Customize repr of RcParams objects."""
        class_name = self.__class__.__name__
        indent = len(class_name) + 1
        repr_split = pprint.pformat(dict(self), indent=1, width=80 - indent).split("\n")
        repr_indented = ("\n" + " " * indent).join(repr_split)
        return "{}({})".format(class_name, repr_indented)

    def __str__(self):
        """Customize str/print of RcParams objects."""
        return "\n".join(map("{0[0]:<22}: {0[1]}".format, sorted(self.items())))

    def __iter__(self):
        """Yield sorted list of keys."""
        yield from sorted(dict.__iter__(self))

    def __len__(self):
        """Use dict len method."""
        return dict.__len__(self)

    def find_all(self, pattern):
        """
        Find keys that match a regex pattern.

        Return the subset of this RcParams dictionary whose keys match,
        using :func:`re.search`, the given ``pattern``.

        Notes
        -----
            Changes to the returned dictionary are *not* propagated to
            the parent RcParams dictionary.
        """
        pattern_re = re.compile(pattern)
        return RcParams((key, value) for key, value in self.items() if pattern_re.search(key))

    def copy(self):
        """Get a copy of the RcParams object.

        The copy is a plain :class:`dict` (no validation on assignment), with keys
        in sorted order.
        """
        return {k: dict.__getitem__(self, k) for k in self}
394

395

396 6
def get_arviz_rcfile():
    """Get arvizrc file.

    The file location is determined in the following order:

    - ``$PWD/arvizrc``
    - ``$ARVIZ_DATA/arvizrc``
    - On Linux and FreeBSD,
        - ``$XDG_CONFIG_HOME/arviz/arvizrc`` (if ``$XDG_CONFIG_HOME``
          is defined and not empty)
        - or ``$HOME/.config/arviz/arvizrc`` (otherwise)
    - On other platforms,
        - ``$HOME/.arviz/arvizrc`` if the home directory can be determined

    The first candidate that exists and is not a directory is returned.
    Otherwise ``None`` is returned, and the defaults defined in ``rcparams.py``
    are used.
    """
    # no blank lines allowed after function docstring by pydocstyle,
    # but black requires white line before function

    def home_dir():
        # Path.home() raises (RuntimeError, or KeyError on older Pythons) when the
        # home directory cannot be resolved; home-based candidates are then skipped.
        try:
            return Path.home()
        except (RuntimeError, KeyError):
            return None

    def gen_candidates():
        yield os.path.join(os.getcwd(), "arvizrc")
        arviz_data_dir = os.environ.get("ARVIZ_DATA")
        if arviz_data_dir:
            yield os.path.join(arviz_data_dir, "arvizrc")
        if sys.platform.startswith(("linux", "freebsd")):
            # An empty $XDG_CONFIG_HOME is treated as unset (XDG Base Directory spec).
            xdg_base = os.environ.get("XDG_CONFIG_HOME")
            if not xdg_base:
                home = home_dir()
                if home is None:
                    return
                xdg_base = str(home / ".config")
            configdir = str(Path(xdg_base, "arviz"))
        else:
            home = home_dir()
            if home is None:
                return
            configdir = str(home / ".arviz")
        yield os.path.join(configdir, "arvizrc")

    for fname in gen_candidates():
        if os.path.exists(fname) and not os.path.isdir(fname):
            return fname

    return None
433

434

435 6
def read_rcfile(fname):
    """Return :class:`arviz.RcParams` from the contents of the given file.

    Unlike `rc_params`, the configuration class only contains the
    parameters specified in the file (i.e. default values are not filled in).

    Each non-empty line has the form ``key: value``; anything after ``#`` is
    ignored. ``data.metagroups`` takes a multi-line value: the lines after its key,
    up to a line containing only ``}``, are collected into a list and validated
    together. Lines without ``:`` and duplicate keys are logged as warnings.

    Parameters
    ----------
    fname : str or path-like
        Path of the file to read.

    Returns
    -------
    RcParams

    Raises
    ------
    ValueError
        If a value is rejected by the validator of its key.
    KeyError
        If a key is not a valid rc parameter.
    UnicodeDecodeError
        If the file cannot be decoded with the default text encoding.
    """
    _error_details_fmt = 'line #%d\n\t"%s"\n\tin file "%s"'

    config = RcParams()
    with open(fname, "r") as rcfile:
        try:
            multiline = False
            for line_no, line in enumerate(rcfile, 1):
                strippedline = line.split("#", 1)[0].strip()
                if not strippedline:
                    continue
                if multiline:
                    if strippedline == "}":
                        multiline = False
                        val = aux_val
                    else:
                        aux_val.append(strippedline)
                        continue
                else:
                    tup = strippedline.split(":", 1)
                    if len(tup) != 2:
                        error_details = _error_details_fmt % (line_no, line, fname)
                        _log.warning("Illegal %s", error_details)
                        continue
                    key, val = tup
                    key = key.strip()
                    val = val.strip()
                    if key in config:
                        _log.warning("Duplicate key in file %r line #%d.", fname, line_no)
                    if key in {"data.metagroups"}:
                        # Multi-line value: collect the following lines until "}".
                        aux_val = []
                        multiline = True
                        continue
                try:
                    config[key] = val
                except ValueError as verr:
                    error_details = _error_details_fmt % (line_no, line, fname)
                    raise ValueError(
                        "Bad val {} on {}\n\t{}".format(val, error_details, str(verr))
                    ) from verr
            if multiline:
                # The file ended before the closing "}"; report it instead of
                # silently dropping the key.
                _log.warning(
                    "Missing closing '}' for key %r in file %r; the key was ignored.",
                    key,
                    fname,
                )

        except UnicodeDecodeError:
            _log.warning(
                "Cannot decode configuration file %s with encoding "
                "%s, check LANG and LC_* variables.",
                fname,
                locale.getpreferredencoding(do_setlocale=False) or "utf-8 (default)",
            )
            raise

        return config
491

492

493 6
def rc_params(ignore_files=False):
    """Read and validate arvizrc file.

    Parameters
    ----------
    ignore_files : bool, optional
        If ``True``, do not look for an arvizrc file and return only the
        built-in defaults.

    Returns
    -------
    RcParams
        Defaults from ``defaultParams``, overridden by the values of the first
        arvizrc file found by :func:`get_arviz_rcfile` (if any).
    """
    fname = None if ignore_files else get_arviz_rcfile()
    defaults = RcParams([(key, default) for key, (default, _) in defaultParams.items()])
    if fname is not None:
        defaults.update(read_rcfile(fname))
    return defaults
503

504

505 6
# Module-wide configuration: built-in defaults overlaid with the first arvizrc file found.
rcParams = rc_params(ignore_files=False)  # pylint: disable=invalid-name
506

507

508 6
class rc_context:  # pylint: disable=invalid-name
    """
    Return a context manager for managing rc settings.

    Parameters
    ----------
    rc : dict, optional
        Mapping containing the rcParams to modify temporarily.
    fname : str, optional
        Filename of the file containing the rcParams to use inside the rc_context.

    Examples
    --------
    This allows one to do::

        with az.rc_context(fname='pystan.rc'):
            idata = az.load_arviz_data("radon")
            az.plot_posterior(idata, var_names=["gamma"])

    The plot would have settings from 'pystan.rc'

    A dictionary can also be passed to the context manager::

        with az.rc_context(rc={'plot.max_subplots': None}, fname='pystan.rc'):
            idata = az.load_arviz_data("radon")
            az.plot_posterior(idata, var_names=["gamma"])

    The 'rc' dictionary takes precedence over the settings loaded from
    'fname'. Passing a dictionary only is also valid.

    Notes
    -----
    The new settings are applied when the object is created, not in
    ``__enter__``, and the previous settings are restored in ``__exit__``. If
    applying ``fname`` or ``rc`` fails, the previous settings are restored before
    the error propagates.
    """

    # Based on mpl.rc_context

    def __init__(self, rc=None, fname=None):
        self._orig = rcParams.copy()
        try:
            if fname:
                file_rcparams = read_rcfile(fname)
                rcParams.update(file_rcparams)
            if rc:
                rcParams.update(rc)
        except Exception:
            # __exit__ never runs when construction fails, so roll back whatever
            # was applied before the error to avoid leaking settings globally.
            rcParams.update(self._orig)
            raise

    def __enter__(self):
        """Define enter method of context manager."""
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        """Define exit method of context manager: restore the previous settings."""
        rcParams.update(self._orig)

Read our documentation on viewing source code .

Loading