1
#!/usr/bin/env python
2

3
# =============================================================================================
4
# MODULE DOCSTRING
5
# =============================================================================================
6 8
"""
7
Parameter handlers for the SMIRNOFF force field engine
8

9
This file contains standard parameter handlers for the SMIRNOFF force field engine.
10
These classes implement the object model for self-contained parameter assignment.
11
New pluggable handlers can be created by creating subclasses of :class:`ParameterHandler`.
12

13
"""
14

15 8
# Public API of this module: the names exported by a star-import.
# Every public exception class defined in this file is listed so that it can
# be imported and caught by client code.
__all__ = [
    "SMIRNOFFSpecError",
    "IncompatibleParameterError",
    "UnassignedValenceParameterException",
    "UnassignedBondParameterException",
    "UnassignedAngleParameterException",
    "UnassignedProperTorsionParameterException",
    "UnassignedMoleculeChargeException",
    "NonintegralMoleculeChargeException",
    "DuplicateParameterError",
    "NonbondedMethod",
    "ParameterList",
    "ParameterType",
    "ParameterHandler",
    "ParameterAttribute",
    "IndexedParameterAttribute",
    "IndexedMappedParameterAttribute",
    "ConstraintHandler",
    "BondHandler",
    "AngleHandler",
    "ProperTorsionHandler",
    "ImproperTorsionHandler",
    "vdWHandler",
    "GBSAHandler",
]
36

37

38
# =============================================================================================
39
# GLOBAL IMPORTS
40
# =============================================================================================
41

42 8
import copy
43 8
import functools
44 8
import inspect
45 8
import logging
46 8
import re
47 8
from collections import OrderedDict, defaultdict
48 8
from enum import Enum
49

50 8
from simtk import openmm, unit
51

52 8
from openforcefield.topology import ImproperDict, SortedDict, ValenceDict
53 8
from openforcefield.topology.molecule import Molecule
54 8
from openforcefield.typing.chemistry import ChemicalEnvironment
55 8
from openforcefield.utils import (
56
    GLOBAL_TOOLKIT_REGISTRY,
57
    IncompatibleUnitError,
58
    MessageException,
59
    attach_units,
60
    extract_serialized_units_from_dict,
61
    object_to_quantity,
62
)
63 8
from openforcefield.utils.collections import ValidatedDict, ValidatedList
64

65
# =============================================================================================
66
# CONFIGURE LOGGER
67
# =============================================================================================
68

69 8
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
70

71

72
# ======================================================================
73
# CUSTOM EXCEPTIONS
74
# ======================================================================
75

76

77 8
class SMIRNOFFSpecError(MessageException):
    """Raised when data does not comply with the SMIRNOFF data specification."""
83

84

85 8
class IncompatibleParameterError(MessageException):
    """Raised when a set of parameters is scientifically or technically
    incompatible with another one."""
91

92

93 8
class UnassignedValenceParameterException(Exception):
    """Raised when a ParameterHandler cannot find parameters for some valence terms.

    This is the common base class of the term-specific "unassigned parameter"
    exceptions.
    """
97

98

99 8
class UnassignedBondParameterException(UnassignedValenceParameterException):
    """Raised when a ParameterHandler cannot find parameters for some bond terms."""
103

104

105 8
class UnassignedAngleParameterException(UnassignedValenceParameterException):
    """Raised when a ParameterHandler cannot find parameters for some angle terms."""
109

110

111 8
class UnassignedProperTorsionParameterException(UnassignedValenceParameterException):
    """Raised when a ParameterHandler cannot find parameters for some proper torsion terms."""
115

116

117 8
class UnassignedMoleculeChargeException(Exception):
    """Raised when none of the charge methods can assign charges to a molecule."""
121

122

123 8
class NonintegralMoleculeChargeException(Exception):
    """Raised when the partial charges of a molecule do not add up to its formal charge."""
127

128

129 8
class DuplicateParameterError(MessageException):
    """Raised on an attempt to add a ParameterType that already exists."""
131

132

133
# ======================================================================
134
# ENUM TYPES
135
# ======================================================================
136

137

138 8
class NonbondedMethod(Enum):
    """Enumeration of the nonbonded methods.

    Members are numbered consecutively from 0 in declaration order.
    """

    # NoCutoff=0, CutoffPeriodic=1, CutoffNonPeriodic=2, Ewald=3, PME=4
    (
        NoCutoff,
        CutoffPeriodic,
        CutoffNonPeriodic,
        Ewald,
        PME,
    ) = range(5)
148

149

150
# ======================================================================
151
# PARAMETER ATTRIBUTES
152
# ======================================================================
153

154
# TODO: Think about adding attrs to the dependencies and inherit from attr.ib
155 8
class ParameterAttribute:
    """A descriptor for ``ParameterType`` attributes.

    The descriptor allows associating to the parameter a default value,
    which makes the attribute optional, a unit, and a custom converter.

    Because we may want to have ``None`` as a default value, required
    attributes have the ``default`` set to the special type ``UNDEFINED``.

    Converters can be both static or instance functions/methods with
    respective signatures

    converter(value): -> converted_value
    converter(instance, parameter_attribute, value): -> converted_value

    A decorator syntax is available (see example below). The attribute
    created through the decorator keeps the default value and the unit of
    the attribute it decorates.

    Parameters
    ----------
    default : object, optional
        When specified, the descriptor makes this attribute optional by
        attaching a default value to it.
    unit : simtk.unit.Quantity, optional
        When specified, only quantities with compatible units are allowed
        to be set, and string expressions are automatically parsed into a
        ``Quantity``.
    converter : callable, optional
        An optional function that can be used to convert values before
        setting the attribute.

    See Also
    --------
    IndexedParameterAttribute
        A parameter attribute with multiple terms.

    Examples
    --------

    Create a parameter type with an optional and a required attribute.

    >>> class MyParameter:
    ...     attr_required = ParameterAttribute()
    ...     attr_optional = ParameterAttribute(default=2)
    ...
    >>> my_par = MyParameter()

    Even without explicit assignment, the default value is returned.

    >>> my_par.attr_optional
    2

    If you try to access an attribute without setting it first, an
    exception is raised.

    >>> my_par.attr_required
    Traceback (most recent call last):
    ...
    AttributeError: 'MyParameter' object has no attribute '_attr_required'

    The attribute allow automatic conversion and validation of units.

    >>> from simtk import unit
    >>> class MyParameter:
    ...     attr_quantity = ParameterAttribute(unit=unit.angstrom)
    ...
    >>> my_par = MyParameter()
    >>> my_par.attr_quantity = '1.0 * nanometer'
    >>> my_par.attr_quantity
    Quantity(value=1.0, unit=nanometer)
    >>> my_par.attr_quantity = 3.0
    Traceback (most recent call last):
    ...
    openforcefield.utils.utils.IncompatibleUnitError: attr_quantity=3.0 dimensionless should have units of angstrom

    You can attach a custom converter to an attribute.

    >>> class MyParameter:
    ...     # Both strings and integers convert nicely to floats with float().
    ...     attr_all_to_float = ParameterAttribute(converter=float)
    ...     attr_int_to_float = ParameterAttribute()
    ...     @attr_int_to_float.converter
    ...     def attr_int_to_float(self, attr, value):
    ...         # This converter converts only integers to float
    ...         # and raise an exception for the other types.
    ...         if isinstance(value, int):
    ...             return float(value)
    ...         elif not isinstance(value, float):
    ...             raise TypeError(f"Cannot convert '{value}' to float")
    ...         return value
    ...
    >>> my_par = MyParameter()

    attr_all_to_float accepts and convert to float both strings and integers

    >>> my_par.attr_all_to_float = 1
    >>> my_par.attr_all_to_float
    1.0
    >>> my_par.attr_all_to_float = '2.0'
    >>> my_par.attr_all_to_float
    2.0

    The custom converter associated to attr_int_to_float converts only integers instead.
    >>> my_par.attr_int_to_float = 3
    >>> my_par.attr_int_to_float
    3.0
    >>> my_par.attr_int_to_float = '4.0'
    Traceback (most recent call last):
    ...
    TypeError: Cannot convert '4.0' to float

    """

    class UNDEFINED:
        """Custom type used by ``ParameterAttribute`` to differentiate between ``None`` and undeclared default."""

        pass

    def __init__(self, default=UNDEFINED, unit=None, converter=None):
        self.default = default
        self._unit = unit
        self._converter = converter

    def __set_name__(self, owner, name):
        # The value is stored on the owning instance under the public
        # attribute name prefixed with an underscore.
        self._name = "_" + name

    @property
    def name(self):
        """The public name of the attribute this descriptor is bound to."""
        # Get rid of the initial underscore.
        return self._name[1:]

    def __get__(self, instance, owner):
        if instance is None:
            # This is called from the class. Return the descriptor object.
            return self

        try:
            return getattr(instance, self._name)
        except AttributeError:
            # The attribute has not been initialized. Check if there's a default.
            if self.default is ParameterAttribute.UNDEFINED:
                # Required attribute never set: propagate the AttributeError.
                raise
            return self.default

    def __set__(self, instance, value):
        # Convert and validate the value.
        value = self._convert_and_validate(instance, value)
        setattr(instance, self._name, value)

    def converter(self, converter):
        """Create a new ParameterAttribute with an associated converter.

        This is meant to be used as a decorator (see main examples).

        The new descriptor is of the same class as this one and inherits
        its default value and its unit, so that decorating an attribute
        declared with a unit does not silently disable unit validation.

        Parameters
        ----------
        converter : callable
            The static or instance converter to attach to the attribute.

        Returns
        -------
        ParameterAttribute
            A new descriptor with the same default and unit and the given converter.

        """
        return self.__class__(
            default=self.default, unit=self._unit, converter=converter
        )

    def _convert_and_validate(self, instance, value):
        """Convert to Quantity, validate units, and call custom converter."""
        # The default value is always allowed.
        if self._is_valid_default(value):
            return value
        # Convert and validate units.
        value = self._validate_units(value)
        # Call the custom converter before setting the value.
        value = self._call_converter(value, instance)
        return value

    def _is_valid_default(self, value):
        """Return True if a default is defined and ``value`` is equal to it."""
        return (
            self.default is not ParameterAttribute.UNDEFINED and value == self.default
        )

    def _validate_units(self, value):
        """Convert strings expressions to Quantity and validate the units if requested.

        Nothing is done when the attribute was declared without a unit.

        Raises
        ------
        IncompatibleUnitError
            If the value has no ``unit`` attribute or its unit is not
            compatible with the unit of this attribute.

        """
        if self._unit is not None:
            # Convert eventual strings to Quantity objects.
            value = object_to_quantity(value)

            # Check if units are compatible.
            try:
                if not self._unit.is_compatible(value.unit):
                    raise IncompatibleUnitError(
                        f"{self.name}={value} should have units of {self._unit}"
                    )
            except AttributeError:
                # This is not a Quantity object.
                raise IncompatibleUnitError(
                    f"{self.name}={value} should have units of {self._unit}"
                )
        return value

    def _call_converter(self, value, instance):
        """Correctly calls static and instance converters.

        The converter is first called with the single-argument (static)
        signature. If that raises ``TypeError``, it is called again with the
        (instance, parameter_attribute, value) signature. A ``TypeError``
        raised by a static converter because of a bad value cannot be told
        apart from a signature mismatch, so it also triggers the second call.
        """
        if self._converter is not None:
            try:
                # Static function.
                return self._converter(value)
            except TypeError:
                # Instance method.
                return self._converter(instance, self, value)
        return value
356

357

358 8
class IndexedParameterAttribute(ParameterAttribute):
    """The attribute of a parameter with an unspecified number of terms.

    Some parameters can be associated to multiple terms. For example,
    torsions have parameters such as k1, k2, ..., and ``IndexedParameterAttribute``
    can be used to encapsulate the sequence of terms.

    The only substantial difference with ``ParameterAttribute`` is that
    only sequences are supported as values and converters and units are
    checked on each element of the sequence.

    The assigned sequence is stored as a ``ValidatedList`` configured with
    the unit check and the custom converter of this attribute, with the
    intent that elements are validated also when they are modified after
    the assignment.

    Parameters
    ----------
    default : object, optional
        When specified, the descriptor makes this attribute optional by
        attaching a default value to it.
    unit : simtk.unit.Quantity, optional
        When specified, only sequences of quantities with compatible units
        are allowed to be set.
    converter : callable, optional
        An optional function that can be used to validate and cast each
        element of the sequence before setting the attribute.

    See Also
    --------
    ParameterAttribute
        A simple parameter attribute.

    Examples
    --------

    Create an optional indexed attribute with unit of angstrom.

    >>> from simtk import unit
    >>> class MyParameter:
    ...     length = IndexedParameterAttribute(default=None, unit=unit.angstrom)
    ...
    >>> my_par = MyParameter()
    >>> my_par.length is None
    True

    Strings are parsed into Quantity objects.

    >>> my_par.length = ['1 * angstrom', 0.5 * unit.nanometer]
    >>> my_par.length[0]
    Quantity(value=1, unit=angstrom)

    Similarly, custom converters work as with ``ParameterAttribute``, but
    they are used to validate each value in the sequence.

    >>> class MyParameter:
    ...     attr_indexed = IndexedParameterAttribute(converter=float)
    ...
    >>> my_par = MyParameter()
    >>> my_par.attr_indexed = [1, '1.0', '1e-2', 4.0]
    >>> my_par.attr_indexed
    [1.0, 1.0, 0.01, 4.0]

    """

    def _convert_and_validate(self, instance, value):
        """Wrap the sequence in a ValidatedList instead of validating a scalar.

        Overrides ``ParameterAttribute._convert_and_validate``.
        """
        # A value equal to the declared default is stored untouched.
        if self._is_valid_default(value):
            return value

        # ValidatedList wants single-argument converters, hence the owning
        # instance is bound to the custom converter call through a partial.
        per_element_converters = [
            self._validate_units,
            functools.partial(self._call_converter, instance=instance),
        ]
        return ValidatedList(value, converter=per_element_converters)
438

439

440 8
class IndexedMappedParameterAttribute(ParameterAttribute):
    """The attribute of a parameter with an unspecified number of terms, where
    each term is a mapping.

    Some parameters can be associated to multiple terms,
    where those terms have multiple components.
    For example, torsions with fractional bond orders have parameters such as
    k1_bondorder1, k1_bondorder2, k2_bondorder1, k2_bondorder2, ..., and
    ``IndexedMappedParameterAttribute`` can be used to encapsulate the sequence of
    terms as mappings (typically, `dict`s) of their components.

    The only substantial difference with ``IndexedParameterAttribute`` is that
    only sequences of mappings are supported as values and converters and units are
    checked on each component of each element in the sequence.

    The assigned sequence is stored as a ``ValidatedList`` whose elements are
    ``ValidatedDict`` objects configured with the unit check and the custom
    converter of this attribute, with the intent that components are validated
    also when an element is modified or replaced after the assignment.

    Parameters
    ----------
    default : object, optional
        When specified, the descriptor makes this attribute optional by
        attaching a default value to it.
    unit : simtk.unit.Quantity, optional
        When specified, only sequences of mappings where values are quantities with
        compatible units are allowed to be set.
    converter : callable, optional
        An optional function that can be used to validate and cast each
        component of each element of the sequence before setting the attribute.

    See Also
    --------
    IndexedParameterAttribute
        A parameter attribute representing a sequence.

    Examples
    --------

    Create an optional indexed attribute with unit of angstrom.

    >>> from simtk import unit
    >>> class MyParameter:
    ...     length = IndexedMappedParameterAttribute(default=None, unit=unit.angstrom)
    ...
    >>> my_par = MyParameter()
    >>> my_par.length is None
    True

    Strings are parsed into Quantity objects.

    >>> my_par.length = [{1:'1 * angstrom'}, {1: 0.5 * unit.nanometer}]
    >>> my_par.length[0]
    {1: Quantity(value=1, unit=angstrom)}

    Similarly, custom converters work as with ``ParameterAttribute``, but
    they are used to validate each value in the sequence.

    >>> class MyParameter:
    ...     attr_indexed = IndexedMappedParameterAttribute(converter=float)
    ...
    >>> my_par = MyParameter()
    >>> my_par.attr_indexed = [{1: 1}, {2: '1.0', 3: '1e-2'}, {4: 4.0}]
    >>> my_par.attr_indexed
    [{1: 1.0}, {2: 1.0, 3: 0.01}, {4: 4.0}]

    """

    def _convert_and_validate(self, instance, value):
        """Overwrite ParameterAttribute._convert_and_validate to make the value
        a ValidatedList of ValidatedDict objects."""
        # The default value is always allowed.
        if self._is_valid_default(value):
            return value

        # The validated containers expect converters that take the value as a
        # single argument so we create a partial function with the instance assigned.
        static_converter = functools.partial(self._call_converter, instance=instance)
        component_converters = [self._validate_units, static_converter]

        # The list-level converter turns every element into a ValidatedDict
        # that carries the component converters. Wrapping the elements in a
        # converter-less ValidatedDict here would drop the unit/converter
        # checks from the stored mappings and from any element assigned later.
        # A partial (rather than a local function) is used, like for
        # static_converter above.
        to_validated_mapping = functools.partial(
            ValidatedDict, converter=component_converters
        )

        return ValidatedList(value, converter=to_validated_mapping)

    @staticmethod
    def _index_converter(x):
        """Wrap ``x`` in a ValidatedDict without any converter.

        Not used by ``_convert_and_validate``; retained only so that code
        referring to this private helper keeps working.
        """
        return ValidatedDict(x)
537

538

539 8
class _ParameterAttributeHandler:
540
    """A base class for ``ParameterType`` and ``ParameterHandler`` objects.
541

542
    Encapsulate shared code of ``ParameterType`` and ``ParameterHandler``.
543
    In particular, this base class provides an ``__init__`` method that
544
    automatically initialize the attributes defined through the ``ParameterAttribute``
545
    and ``IndexedParameterAttribute`` descriptors, as well as handling
546
    cosmetic attributes.
547

548
    See Also
549
    --------
550
    ParameterAttribute
551
        A simple parameter attribute.
552
    IndexedParameterAttribute
553
        A parameter attribute with multiple terms.
554

555
    Examples
556
    --------
557

558
    This base class was design to encapsulate shared code between ``ParameterType``
559
    and ``ParameterHandler``, which both need to deal with parameter and cosmetic
560
    attributes.
561

562
    To create a new type/handler, you can use the ``ParameterAttribute`` descriptors.
563

564
    >>> class ParameterTypeOrHandler(_ParameterAttributeHandler):
565
    ...     length = ParameterAttribute(unit=unit.angstrom)
566
    ...     k = ParameterAttribute(unit=unit.kilocalorie_per_mole / unit.angstrom**2)
567
    ...
568

569
    ``_ParameterAttributeHandler`` and the descriptors take care of performing
570
    sanity checks on initialization and assignment of the single attributes. Because
571
    we attached units to the parameters, we need to pass them with compatible units.
572

573
    >>> my_par = ParameterTypeOrHandler(
574
    ...     length='1.01 * angstrom',
575
    ...     k=5 * unit.kilocalorie_per_mole / unit.angstrom**2
576
    ... )
577

578
    Note that ``_ParameterAttributeHandler`` took care of implementing
579
    a constructor, and that unit parameters support string assignments.
580
    These are automatically converted to ``Quantity`` objects.
581

582
    >>> my_par.length
583
    Quantity(value=1.01, unit=angstrom)
584

585
    While assigning incompatible units is forbidden.
586

587
    >>> my_par.k = 3.0 * unit.gram
588
    Traceback (most recent call last):
589
    ...
590
    openforcefield.utils.utils.IncompatibleUnitError: k=3.0 g should have units of kilocalorie/(angstrom**2*mole)
591

592
    On top of type checking, the constructor implemented in ``_ParameterAttributeHandler``
593
    checks if some required parameters are not given.
594

595
    >>> ParameterTypeOrHandler(length=3.0*unit.nanometer)
596
    Traceback (most recent call last):
597
    ...
598
    openforcefield.typing.engines.smirnoff.parameters.SMIRNOFFSpecError: <class 'openforcefield.typing.engines.smirnoff.parameters.ParameterTypeOrHandler'> require the following missing parameters: ['k']. Defined kwargs are ['length']
599

600
    Each attribute can be made optional by specifying a default value,
601
    and you can attach a converter function by passing a callable as an
602
    argument or through the decorator syntax.
603

604
    >>> class ParameterTypeOrHandler(_ParameterAttributeHandler):
605
    ...     attr_optional = ParameterAttribute(default=2)
606
    ...     attr_all_to_float = ParameterAttribute(converter=float)
607
    ...     attr_int_to_float = ParameterAttribute()
608
    ...
609
    ...     @attr_int_to_float.converter
610
    ...     def attr_int_to_float(self, attr, value):
611
    ...         # This converter converts only integers to floats
612
    ...         # and raise an exception for the other types.
613
    ...         if isinstance(value, int):
614
    ...             return float(value)
615
    ...         elif not isinstance(value, float):
616
    ...             raise TypeError(f"Cannot convert '{value}' to float")
617
    ...         return value
618
    ...
619
    >>> my_par = ParameterTypeOrHandler(attr_all_to_float='3.0', attr_int_to_float=1)
620
    >>> my_par.attr_optional
621
    2
622
    >>> my_par.attr_all_to_float
623
    3.0
624
    >>> my_par.attr_int_to_float
625
    1.0
626

627
    The float() function can convert strings to integers, but our custom
628
    converter forbids it
629

630
    >>> my_par.attr_all_to_float = '2.0'
631
    >>> my_par.attr_int_to_float = '4.0'
632
    Traceback (most recent call last):
633
    ...
634
    TypeError: Cannot convert '4.0' to float
635

636
    Parameter attributes that can be indexed can be handled with the
637
    ``IndexedParameterAttribute``. These support unit validation and
638
    converters exactly as ``ParameterAttribute``s, but the validation/conversion
639
    is performed for each indexed attribute.
640

641
    >>> class MyTorsionType(_ParameterAttributeHandler):
642
    ...     periodicity = IndexedParameterAttribute(converter=int)
643
    ...     k = IndexedParameterAttribute(unit=unit.kilocalorie_per_mole)
644
    ...
645
    >>> my_par = MyTorsionType(
646
    ...     periodicity1=2,
647
    ...     k1=5 * unit.kilocalorie_per_mole,
648
    ...     periodicity2='3',
649
    ...     k2=6 * unit.kilocalorie_per_mole,
650
    ... )
651
    >>> my_par.periodicity
652
    [2, 3]
653

654
    Indexed attributes, can be accessed both as a list or as their indexed
655
    parameter name.
656

657
    >>> my_par.periodicity2 = 6
658
    >>> my_par.periodicity[0] = 1
659
    >>> my_par.periodicity
660
    [1, 6]
661

662
    """
663

664 8
    def __init__(self, allow_cosmetic_attributes=False, **kwargs):
        """
        Set up the parameter attributes and the cosmetic attributes from kwargs.

        Parameters
        ----------
        allow_cosmetic_attributes : bool optional. Default = False
            Whether to permit non-spec kwargs ("cosmetic attributes").
            If True, non-spec kwargs will be stored as an attribute of
            this parameter which can be accessed and written out. Otherwise,
            an exception will be raised.
        **kwargs
            The attribute values. A deep copy is processed, so the objects
            passed by the caller are not modified.

        Raises
        ------
        SMIRNOFFSpecError
            If a required attribute is missing, or if an unknown kwarg is
            given while ``allow_cosmetic_attributes`` is False.

        """
        # Names of the cosmetic attributes read from a SMIRNOFF data source
        # are recorded in this list.
        self._cosmetic_attribs = []

        # Work on a copy so that the caller's data is left untouched.
        data = copy.deepcopy(kwargs)

        # Stack indexed-mapped kwargs first, then plain indexed kwargs.
        data, mapped_lengths = self._process_indexed_mapped_attributes(data)
        data = self._process_indexed_attributes(data, mapped_lengths)

        # Every required attribute must have been provided.
        given = set(data.keys())
        required = set(self._get_required_parameter_attributes().keys())
        missing = required - given
        if missing:
            raise SMIRNOFFSpecError(
                f"{self.__class__} require the following missing parameters: {sorted(missing)}."
                f" Defined kwargs are {sorted(data.keys())}"
            )

        # Assign the known attributes; anything else is either cosmetic or an error.
        known = set(self._get_parameter_attributes().keys())
        for name, value in data.items():
            if name in known:
                setattr(self, name, value)
                continue

            if not allow_cosmetic_attributes:
                raise SMIRNOFFSpecError(
                    f"Unexpected kwarg ({name}: {value})  passed to {self.__class__} constructor. "
                    "If this is a desired cosmetic attribute, consider setting "
                    "'allow_cosmetic_attributes=True'"
                )

            # Unknown kwargs are kept as cosmetic so they can be written back out.
            self.add_cosmetic_attribute(name, value)
718

719 8
    def _process_indexed_mapped_attributes(self, smirnoff_data):
        """Stack the indexed-mapped kwargs into a list of dicts per attribute.

        Each kwarg that ``_split_attribute_index_mapping`` decomposes into an
        attribute name, an index and a key, and whose attribute name is an
        indexed-mapped parameter attribute of this class, is removed from
        ``smirnoff_data`` and stored under ``smirnoff_data[attr_name]`` as
        element ``index`` of a list of ``{key: value}`` dicts.

        Parameters
        ----------
        smirnoff_data : dict
            The kwargs to process. It is modified in place.

        Returns
        -------
        smirnoff_data : dict
            The same dict passed as input, after processing.
        indexed_mapped_attr_lengths : dict
            The number of terms found for each attribute stacked here.

        Raises
        ------
        SMIRNOFFSpecError
            If the indices found for an attribute are not 0, 1, 2, ...
            without gaps.

        """
        # TODO: construct data structure for holding indexed_mapped attrs, which
        # will get fed into setattr
        lengths = {}
        # Attributes created below as dicts, to be turned into lists afterwards.
        pending = set()
        # attr_name -> index -> key -> original kwarg, needed to report gaps.
        origin = defaultdict(dict)

        # Iterate over a snapshot of the keys since the dict is modified in the loop.
        for kwarg in list(smirnoff_data.keys()):
            attr_name, index, key = self._split_attribute_index_mapping(kwarg)

            is_indexed_mapped = (
                (key is not None)
                and (index is not None)
                and attr_name in self._get_indexed_mapped_parameter_attributes()
            )
            if not is_indexed_mapped:
                continue

            # The kwargs can come in any order, so the terms are first
            # collected in a dict keyed by index and converted to a list later.
            if attr_name not in smirnoff_data:
                smirnoff_data[attr_name] = {}
                pending.add(attr_name)

            terms = smirnoff_data[attr_name]
            if index not in terms:
                terms[index] = {}

            terms[index][key] = smirnoff_data[kwarg]
            del smirnoff_data[kwarg]

            # Remember which kwarg produced this component.
            origin[attr_name].setdefault(index, {})[key] = kwarg

        # Convert each dict of terms into a list, refusing skipped indices,
        # e.g. k1_bondorder*, k3_bondorder* defined, but not k2_bondorder*
        for attr_name in pending:
            terms = smirnoff_data[attr_name]
            stacked = []
            for expected, found in enumerate(sorted(terms.keys())):
                if int(found) != expected:
                    # Any key will do; only the top-level index matters here.
                    key = sorted(origin[attr_name][found].keys())[0]
                    kwarg = origin[attr_name][found][key]
                    val = terms[found][key]

                    raise SMIRNOFFSpecError(
                        f"Unexpected kwarg ({kwarg}: {val})  passed to {self.__class__} constructor. "
                        "If this is a desired cosmetic attribute, consider setting "
                        "'allow_cosmetic_attributes=True'"
                    )
                stacked.append(terms[found])

            smirnoff_data[attr_name] = stacked

            # The lengths are checked downstream against the other indexed attributes.
            lengths[attr_name] = len(stacked)

        return smirnoff_data, lengths
785

786 8
    def _process_indexed_attributes(self, smirnoff_data, indexed_attr_lengths=None):
787
        # Check for indexed attributes and stack them into a list.
788
        # Keep track of how many indexed attribute we find to make sure they all have the same length.
789

790
        # TODO: REFACTOR ME; try looping over contents of `smirnoff_data`, using
791
        # `split_attribute_index` to extract values
792

793 8
        if indexed_attr_lengths is None:
794 0
            indexed_attr_lengths = {}
795

796 8
        for attrib_basename in self._get_indexed_parameter_attributes().keys():
797 8
            index = 1
798 8
            while True:
799 8
                attrib_w_index = "{}{}".format(attrib_basename, index)
800

801
                # Exit the while loop if the indexed attribute is not given.
802
                # this is the stop condition
803 8
                try:
804 8
                    attrib_w_index_value = smirnoff_data[attrib_w_index]
805 8
                except KeyError:
806 8
                    break
807

808
                # Check if this is the first iteration.
809 8
                if index == 1:
810
                    # Check if this attribute has been specified with and without index.
811 8
                    if attrib_basename in smirnoff_data:
812 8
                        err_msg = (
813
                            f"The attribute '{attrib_basename}' has been specified "
814
                            f"with and without index: '{attrib_w_index}'"
815
                        )
816 8
                        raise TypeError(err_msg)
817

818
                    # Otherwise create the list object.
819 8
                    smirnoff_data[attrib_basename] = list()
820

821
                # Append the new value to the list.
822 8
                smirnoff_data[attrib_basename].append(attrib_w_index_value)
823

824
                # Remove the indexed attribute from the kwargs as it will
825
                # be exposed only as an element of the list.
826 8
                del smirnoff_data[attrib_w_index]
827 8
                index += 1
828

829
            # Update the lengths with this attribute (if it was found).
830 8
            if index > 1:
831 8
                indexed_attr_lengths[attrib_basename] = len(
832
                    smirnoff_data[attrib_basename]
833
                )
834

835
        # Raise an error if we there are different indexed
836
        # attributes with a different number of terms.
837 8
        if len(set(indexed_attr_lengths.values())) > 1:
838 8
            raise TypeError(
839
                "The following indexed attributes have "
840
                f"different lengths: {indexed_attr_lengths}"
841
            )
842

843 8
        return smirnoff_data
844

845 8
    def to_dict(self, discard_cosmetic_attributes=False):
        """
        Serialize this object into a flat, ordered dictionary.

        The result holds every defined ``ParameterAttribute``. Indexed
        attributes are expanded into ``<name><index>`` entries and indexed
        mapped attributes into ``<name><index>_<mapping><key>`` entries, with
        indices starting at 1. Cosmetic attributes are appended at the end
        unless ``discard_cosmetic_attributes`` is ``True``.

        Parameters
        ----------
        discard_cosmetic_attributes : bool, optional. Default = False
            Whether to leave the non-spec (cosmetic) attributes out.

        Returns
        -------
        smirnoff_dict : OrderedDict
            The SMIRNOFF-compliant dict representation of this object.

        """
        # Optional attributes left at a None default are not part of the
        # "defined" attributes and therefore never reach the output.
        defined_names = list(self._get_defined_parameter_attributes().keys())
        indexed_names = set(self._get_indexed_parameter_attributes().keys())
        indexed_mapped_names = set(
            self._get_indexed_mapped_parameter_attributes().keys()
        )

        smirnoff_dict = OrderedDict()

        # The order of `defined_names` is effectively the output ordering.
        for name in defined_names:
            value = getattr(self, name)

            if name in indexed_mapped_names:
                # A list of mappings: one output entry per (index, key) pair.
                for position, mapping in enumerate(value, start=1):
                    for map_key, map_value in mapping.items():
                        indexed_part, mapped_part = name.split("_")
                        smirnoff_dict[
                            f"{indexed_part}{position}_{mapped_part}{map_key}"
                        ] = map_value
            elif name in indexed_names:
                # A plain list: one output entry per element.
                for position, element in enumerate(value, start=1):
                    smirnoff_dict[f"{name}{position}"] = element
            else:
                smirnoff_dict[name] = value

        # Cosmetic attributes are stored on the instance with a leading "_".
        if not discard_cosmetic_attributes:
            for cosmetic_name in self._cosmetic_attribs:
                smirnoff_dict[cosmetic_name] = getattr(self, "_" + cosmetic_name)

        return smirnoff_dict
899

900 8
    def __getattr__(self, item):
        """Resolve numbered attribute names to elements of their list attribute.

        Names of the form ``<name><index>_<mapping><key>`` are looked up in
        the matching indexed mapped attribute, and names of the form
        ``<name><index>`` in the matching indexed attribute. Any other name
        is forwarded to the next class in the MRO.

        Raises
        ------
        IndexError, KeyError
            If the name refers to a declared indexed (mapped) attribute but
            the index or key does not exist.
        AttributeError
            If the name cannot be resolved at all.
        """
        # Doubly-numbered names are tried first.
        mapped_name, mapped_index, mapped_key = self._split_attribute_index_mapping(
            item
        )
        if (
            mapped_key is not None
            and mapped_index is not None
            and mapped_name in self._get_indexed_mapped_parameter_attributes()
        ):
            mappings = getattr(self, mapped_name)
            try:
                return mappings[mapped_index][mapped_key]
            except (IndexError, KeyError) as err:
                # Re-raise the same exception object with an explanatory
                # message appended to its arguments.
                err.args = (err.args or ("",)) + (
                    f"'{item}' is out of bound for indexed attribute '{mapped_name}'",
                )
                raise

        # Next, singly-numbered names.
        list_name, list_index = self._split_attribute_index(item)
        if (
            list_index is not None
            and list_name in self._get_indexed_parameter_attributes()
        ):
            values = getattr(self, list_name)
            try:
                return values[list_index]
            except IndexError:
                raise IndexError(
                    f"'{item}' is out of bound for indexed attribute '{list_name}'"
                )

        # Not a numbered attribute: defer to the next class in the MRO.
        try:
            return super().__getattr__(item)
        except AttributeError as err:
            # An error that mentions "__getattr__" is taken to mean that the
            # next classes in the MRO do not implement __getattr__(); the
            # standard missing-attribute error is raised instead. Any other
            # AttributeError is re-raised untouched.
            if "__getattr__" not in str(err):
                raise
            raise AttributeError(
                f"{self.__class__} object has no attribute '{item}'"
            )
952

953 8
    def __setattr__(self, key, value):
        """Route assignments to numbered attribute names into their list attribute.

        Names of the form ``<name><index>_<mapping><key>`` update the matching
        indexed mapped attribute, and names of the form ``<name><index>``
        update the matching indexed attribute. Any other name is forwarded to
        the next class in the MRO.

        Raises
        ------
        IndexError, KeyError
            If the name refers to a declared indexed (mapped) attribute but
            the index does not exist.
        """
        # Doubly-numbered names are tried first. getattr() is only called
        # once the base name is known to be a declared attribute, which
        # avoids an infinite recursion on non-existing keys.
        mapped_name, mapped_index, mapped_key = self._split_attribute_index_mapping(
            key
        )
        if (
            mapped_key is not None
            and mapped_index is not None
            and mapped_name in self._get_indexed_mapped_parameter_attributes()
        ):
            mappings = getattr(self, mapped_name)
            try:
                mappings[mapped_index][mapped_key] = value
            except (IndexError, KeyError) as err:
                # Re-raise the same exception object with an explanatory
                # message appended to its arguments.
                err.args = (err.args or ("",)) + (
                    f"'{key}' is out of bound for indexed attribute '{mapped_name}'",
                )
                raise
            else:
                return

        # Next, singly-numbered names, with the same recursion guard.
        list_name, list_index = self._split_attribute_index(key)
        if (
            list_index is not None
            and list_name in self._get_indexed_parameter_attributes()
        ):
            values = getattr(self, list_name)
            try:
                values[list_index] = value
            except IndexError:
                raise IndexError(
                    f"'{key}' is out of bound for indexed attribute '{list_name}'"
                )
            else:
                return

        # Plain attribute: forward the request to the next class in the MRO.
        super().__setattr__(key, value)
999

1000 8
    def add_cosmetic_attribute(self, attr_name, attr_value):
        """
        Add a cosmetic attribute to this object.

        This attribute will not have a functional effect on the object
        in the Open Force Field toolkit, but can be written out during
        output. Adding a name that is already registered as cosmetic
        replaces its value.

        .. warning :: The API for modifying cosmetic attributes is experimental
           and may change in the future (see issue #338).

        Parameters
        ----------
        attr_name : str
            Name of the attribute to define for this object.
        attr_value : str
            The value of the attribute to define for this object.

        """
        # The value is stored on the instance under a "_"-prefixed name.
        setattr(self, "_" + attr_name, attr_value)
        # Register the name only once: a duplicate entry would outlive a
        # later delete_cosmetic_attribute() call and point at a value that
        # no longer exists.
        if attr_name not in self._cosmetic_attribs:
            self._cosmetic_attribs.append(attr_name)
1021

1022 8
    def delete_cosmetic_attribute(self, attr_name):
        """
        Delete a cosmetic attribute from this object.

        .. warning :: The API for modifying cosmetic attributes is experimental
           and may change in the future (see issue #338).

        Parameters
        ----------
        attr_name : str
            Name of the cosmetic attribute to delete.

        Raises
        ------
        AttributeError
            If ``attr_name`` is not a cosmetic attribute of this object.
        """
        # TODO: Can we handle this by overriding __delattr__ instead?
        #  Would we also need to override __del__ as well to cover both deletion methods?

        # Check the registry before touching the instance. Cosmetic values
        # live under "_<name>", and so may other private values, which must
        # never be deleted through this method.
        if attr_name not in self._cosmetic_attribs:
            raise AttributeError(
                f"{self.__class__} object has no cosmetic attribute '{attr_name}'"
            )
        delattr(self, "_" + attr_name)
        self._cosmetic_attribs.remove(attr_name)
1038

1039 8
    def attribute_is_cosmetic(self, attr_name):
        """
        Tell whether a name is registered as a cosmetic attribute of this object.

        .. warning :: The API for modifying cosmetic attributes is experimental
           and may change in the future (see issue #338).

        Parameters
        ----------
        attr_name : str
            The attribute name to check

        Returns
        -------
        is_cosmetic : bool
            True if ``attr_name`` is among the registered cosmetic attribute
            names, False otherwise.
        """
        registered_names = self._cosmetic_attribs
        return attr_name in registered_names
1057

1058 8
    @staticmethod
1059
    def _split_attribute_index(item):
1060
        """Split the attribute name from the final index.
1061

1062
        For example, the method takes 'k2' and returns the tuple ('k', 1).
1063
        If attribute_name doesn't end with an integer, it returns (item, None).
1064
        """
1065

1066
        # Match any number (\d+) at the end of the string ($).
1067 8
        match = re.search(r"\d+$", item)
1068 8
        if match is None:
1069 8
            return item, None
1070

1071 8
        index = match.group()  # This is a str.
1072 8
        attr_name = item[: -len(index)]
1073 8
        index = int(match.group()) - 1
1074 8
        return attr_name, index
1075

1076 8
    @staticmethod
1077
    def _split_attribute_index_mapping(item):
1078
        """Split the attribute name from the final index.
1079

1080
        For example, the method takes 'k2' and returns the tuple ('k', 1).
1081
        If attribute_name doesn't end with an integer, it returns (item, None).
1082
        """
1083
        # Match items of the form <item><index>_<mapping><key>
1084
        # where <index> and <key> always integers
1085 8
        match = re.search(r"\d+_[A-z]+\d+$", item)
1086 8
        if match is None:
1087 8
            return item, None, None
1088

1089
        # Match any number (\d+) at the end of the string ($).
1090 8
        i_match = r"\d+$"
1091

1092 8
        indexed, mapped = item.split("_")
1093

1094
        # process indexed component
1095 8
        match_indexed = re.search(i_match, indexed)
1096 8
        index = match_indexed.group()  # This is a str.
1097 8
        attr_name = indexed[: -len(index)]
1098 8
        index = int(index) - 1
1099

1100
        # process mapped component
1101 8
        match_mapping = re.search(i_match, mapped)
1102 8
        key = match_mapping.group()  # This is a str.
1103 8
        attr_name = f"{attr_name}_{mapped[:-len(key)]}"
1104 8
        key = int(key)  # we don't subtract 1 here, because these are keys, not indices
1105

1106 8
        return attr_name, index, key
1107

1108 8
    @classmethod
    def _get_parameter_attributes(cls, filter=None):
        """Collect the ParameterAttribute descriptors declared on this class.

        The descriptors are gathered by introspection of the class and of
        all its parent classes. Parents are visited first, and within each
        class the declaration order is kept (guaranteed since Python 3.6,
        see PEP 520).

        Parameters
        ----------
        filter : Callable, optional
            An optional function with signature filter(ParameterAttribute) -> bool.
            If specified, only attributes for which this functions returns
            True are returned.

        Returns
        -------
        parameter_attributes : Dict[str, ParameterAttribute]
            A map from the name of the controlled parameter to the
            ParameterAttribute descriptor handling it.

        Examples
        --------
        >>> parameter_attributes = ParameterType._get_parameter_attributes()
        >>> sorted(parameter_attributes.keys())
        ['id', 'parent_id', 'smirks']
        >>> isinstance(parameter_attributes['id'], ParameterAttribute)
        True

        """
        # Without a filter every descriptor is accepted.
        accept = (lambda descriptor: True) if filter is None else filter

        # inspect.getmembers() would resolve the MRO for us, but it sorts the
        # attributes alphabetically by name. Walking the reversed MRO and each
        # class __dict__ keeps the declaration order, starting from the
        # parent class.
        parameter_attributes = OrderedDict()
        for klass in reversed(inspect.getmro(cls)):
            for name, descriptor in klass.__dict__.items():
                if isinstance(descriptor, ParameterAttribute) and accept(descriptor):
                    parameter_attributes[name] = descriptor
        return parameter_attributes
1158

1159 8
    @classmethod
    def _get_indexed_mapped_parameter_attributes(cls):
        """Return only the attributes that are IndexedMappedParameterAttributes."""

        def is_indexed_mapped(descriptor):
            return isinstance(descriptor, IndexedMappedParameterAttribute)

        return cls._get_parameter_attributes(filter=is_indexed_mapped)
1165

1166 8
    @classmethod
    def _get_indexed_parameter_attributes(cls):
        """Return only the attributes that are IndexedParameterAttributes."""

        def is_indexed(descriptor):
            return isinstance(descriptor, IndexedParameterAttribute)

        return cls._get_parameter_attributes(filter=is_indexed)
1172

1173 8
    @classmethod
1174
    def _get_required_parameter_attributes(cls):
1175
        """Shortcut to retrieve only required ParameterAttributes."""
1176 8
        return cls._get_parameter_attributes(filter=lambda x: x.default is x.UNDEFINED)
1177

1178 8
    @classmethod
1179
    def _get_optional_parameter_attributes(cls):
1180
        """Shortcut to retrieve only required ParameterAttributes."""
1181 8
        return cls._get_parameter_attributes(
1182
            filter=lambda x: x.default is not x.UNDEFINED
1183
        )
1184

1185 8
    def _get_defined_parameter_attributes(self):
1186
        """Returns all the attributes except for the optional attributes that have None default value.
1187

1188
        This returns first the required attributes and then the defined optional
1189
        attribute in their respective declaration order.
1190
        """
1191 8
        required = self._get_required_parameter_attributes()
1192 8
        optional = self._get_optional_parameter_attributes()
1193
        # Filter the optional parameters that are set to their default.
1194 8
        optional = OrderedDict(
1195
            (name, descriptor)
1196
            for name, descriptor in optional.items()
1197
            if not (
1198
                descriptor.default is None and getattr(self, name) == descriptor.default
1199
            )
1200
        )
1201 8
        required.update(optional)
1202 8
        return required
1203

1204

1205
# ======================================================================
1206
# PARAMETER TYPE/LIST
1207
# ======================================================================
1208

1209
# We can't actually make this derive from dict, because it's possible for the user to change SMIRKS
1210
# of parameters already in the list, which would cause the ParameterType object's SMIRKS and
1211
# the dictionary key's SMIRKS to be out of sync.
1212 8
class ParameterList(list):
    """
    Parameter list that also supports accessing items by SMARTS string.

    Items can be looked up, tested for membership and deleted either by
    their position in the list or by their ``smirks`` attribute.

    .. warning :: This API is experimental and subject to change.

    """

    # TODO: Make this faster by caching SMARTS -> index lookup?

    # TODO: Override __del__ to make sure we don't remove root atom type

    # TODO: Allow retrieval by `id` as well

    def __init__(self, input_parameter_list=None):
        """
        Initialize a new ParameterList, optionally providing a list of ParameterType objects
        to initially populate it.

        Parameters
        ----------
        input_parameter_list: list[ParameterType], default=None
            A pre-existing list of ParameterType-based objects. If None, this ParameterList
            will be initialized empty.
        """
        super().__init__()

        # TODO: Should a ParameterList only contain a single kind of ParameterType?
        for input_parameter in input_parameter_list or []:
            self.append(input_parameter)

    def append(self, parameter):
        """
        Add a ParameterType object to the end of the ParameterList

        Parameters
        ----------
        parameter : a ParameterType object

        """
        # TODO: Ensure that newly added parameter is the same type as existing?
        super().append(parameter)

    def extend(self, other):
        """
        Add a ParameterList object to the end of the ParameterList

        Parameters
        ----------
        other : a ParameterList

        Raises
        ------
        TypeError
            If ``other`` is not a ParameterList.
        """
        if not isinstance(other, ParameterList):
            msg = (
                "ParameterList.extend(other) expected instance of ParameterList, "
                "but received {} (type {}) instead".format(other, type(other))
            )
            raise TypeError(msg)
        # TODO: Check if other ParameterList contains the same ParameterTypes?
        super().extend(other)

    def index(self, item):
        """
        Get the numerical index of a ParameterType object or SMIRKS in this ParameterList.

        Parameters
        ----------
        item : ParameterType object or str
            The parameter or SMIRKS to look up in this ParameterList

        Returns
        -------
        index : int
            The index of the found item

        Raises
        ------
        ValueError
            If ``item`` is a ParameterType that is not in this list.
        IndexError
            If ``item`` is not a ParameterType and no parameter in this list
            has a ``smirks`` equal to it.
        """
        if isinstance(item, ParameterType):
            return super().index(item)

        # Anything else is compared against the SMIRKS of each parameter;
        # the position of the first match is returned.
        for position, parameter in enumerate(self):
            if parameter.smirks == item:
                return position
        raise IndexError(
            "SMIRKS {item} not found in ParameterList".format(item=item)
        )

    def insert(self, index, parameter):
        """
        Add a ParameterType object as if this were a list

        Parameters
        ----------
        index : int
            The numerical position to insert the parameter at
        parameter : a ParameterType object
            The parameter to insert
        """
        # TODO: Ensure that newly added parameter is the same type as existing?
        super().insert(index, parameter)

    def __delitem__(self, item):
        """
        Delete item by index, slice or SMIRKS.

        Parameters
        ----------
        item : str, int or slice
            SMIRKS, numerical index or slice of items in this ParameterList
        """
        # Positions and slices are handled by the list itself, mirroring
        # __getitem__; everything else is resolved through index().
        if isinstance(item, (int, slice)):
            super().__delitem__(item)
        else:
            super().__delitem__(self.index(item))

    def __getitem__(self, item):
        """
        Retrieve item by index, slice or SMIRKS

        Parameters
        ----------
        item : str, int or slice
            SMIRKS, numerical index or slice of items in this ParameterList
        """
        if isinstance(item, (int, slice)):
            return super().__getitem__(item)
        return super().__getitem__(self.index(item))

    # TODO: Override __setitem__ and __del__ to ensure we can slice by SMIRKS as well
    # This is needed for pickling. See https://github.com/openforcefield/openforcefield/issues/411
    # for more details.
    # TODO: Is there a cleaner way (getstate/setstate perhaps?) to allow FFs to be
    #       pickled?
    def __reduce__(self):
        # Rebuild from a plain list of the items plus the instance __dict__.
        return (__class__, (list(self),), self.__dict__)

    def __contains__(self, item):
        """Check to see if either Parameter or SMIRKS is contained in parameter list.

        Parameters
        ----------
        item : str or ParameterType
            SMIRKS of an item, or an item, in this ParameterList
        """
        # Special case for SMIRKS strings; stops at the first match.
        if isinstance(item, str) and any(
            item == parameter.smirks for parameter in self
        ):
            return True
        # Fall back to traditional access
        return list.__contains__(self, item)

    def to_list(self, discard_cosmetic_attributes=True):
        """
        Render this ParameterList to a normal list, serializing each ParameterType object in it to dict.

        Parameters
        ----------

        discard_cosmetic_attributes : bool, optional. Default = True
            Whether to discard non-spec attributes of each ParameterType object.

        Returns
        -------
        parameter_list : List[dict]
            A serialized representation of a ParameterList, with each ParameterType it contains converted to dict.
        """
        return [
            parameter.to_dict(discard_cosmetic_attributes=discard_cosmetic_attributes)
            for parameter in self
        ]
1393

1394

1395
# TODO: Rename to better reflect role as parameter base class?
1396 8
class ParameterType(_ParameterAttributeHandler):
    """
    Base class for SMIRNOFF parameter types.

    Subclass it to create new parameter types; the examples below show
    how to do this.

    .. warning :: This API is experimental and subject to change.

    Attributes
    ----------
    smirks : str
        The SMIRKS pattern that this parameter matches.
    id : str or None
        An optional identifier for the parameter.
    parent_id : str or None
        Optionally, the identifier of the parameter of which this parameter
        is a specialization.

    See Also
    --------
    ParameterAttribute
    IndexedParameterAttribute

    Examples
    --------

    A new parameter type is defined by simply listing its attributes.
    In the example below, ``_VALENCE_TYPE`` and ``_ELEMENT_NAME``
    are used for the validation of the SMIRKS pattern associated to the
    parameter and the automatic serialization/deserialization into a ``dict``.

    >>> class MyBondParameter(ParameterType):
    ...     _VALENCE_TYPE = 'Bond'
    ...     _ELEMENT_NAME = 'Bond'
    ...     length = ParameterAttribute(unit=unit.angstrom)
    ...     k = ParameterAttribute(unit=unit.kilocalorie_per_mole / unit.angstrom**2)
    ...

    The parameter automatically inherits the required smirks attribute
    from ``ParameterType``. Associating a ``unit`` to a ``ParameterAttribute``
    causes the attribute to accept only values in compatible units and to
    parse string expressions.

    >>> my_par = MyBondParameter(
    ...     smirks='[*:1]-[*:2]',
    ...     length='1.01 * angstrom',
    ...     k=5 * unit.kilocalorie_per_mole / unit.angstrom**2
    ... )
    >>> my_par.length
    Quantity(value=1.01, unit=angstrom)
    >>> my_par.k = 3.0 * unit.gram
    Traceback (most recent call last):
    ...
    openforcefield.utils.utils.IncompatibleUnitError: k=3.0 g should have units of kilocalorie/(angstrom**2*mole)

    Each attribute can be made optional by specifying a default value,
    and you can attach a converter function by passing a callable as an
    argument or through the decorator syntax.

    >>> class MyParameterType(ParameterType):
    ...     _VALENCE_TYPE = 'Atom'
    ...     _ELEMENT_NAME = 'Atom'
    ...
    ...     attr_optional = ParameterAttribute(default=2)
    ...     attr_all_to_float = ParameterAttribute(converter=float)
    ...     attr_int_to_float = ParameterAttribute()
    ...
    ...     @attr_int_to_float.converter
    ...     def attr_int_to_float(self, attr, value):
    ...         # This converter converts only integers to floats
    ...         # and raise an exception for the other types.
    ...         if isinstance(value, int):
    ...             return float(value)
    ...         elif not isinstance(value, float):
    ...             raise TypeError(f"Cannot convert '{value}' to float")
    ...         return value
    ...
    >>> my_par = MyParameterType(smirks='[*:1]', attr_all_to_float='3.0', attr_int_to_float=1)
    >>> my_par.attr_optional
    2
    >>> my_par.attr_all_to_float
    3.0
    >>> my_par.attr_int_to_float
    1.0

    The float() function can convert strings to floats, but our custom
    converter forbids it

    >>> my_par.attr_all_to_float = '2.0'
    >>> my_par.attr_int_to_float = '4.0'
    Traceback (most recent call last):
    ...
    TypeError: Cannot convert '4.0' to float

    Parameter attributes that can be indexed can be handled with the
    ``IndexedParameterAttribute``. These support unit validation and
    converters exactly as ``ParameterAttribute``s, but the validation/conversion
    is performed for each indexed attribute.

    >>> class MyTorsionType(ParameterType):
    ...     _VALENCE_TYPE = 'ProperTorsion'
    ...     _ELEMENT_NAME = 'Proper'
    ...     periodicity = IndexedParameterAttribute(converter=int)
    ...     k = IndexedParameterAttribute(unit=unit.kilocalorie_per_mole)
    ...
    >>> my_par = MyTorsionType(
    ...     smirks='[*:1]-[*:2]-[*:3]-[*:4]',
    ...     periodicity1=2,
    ...     k1=5 * unit.kilocalorie_per_mole,
    ...     periodicity2='3',
    ...     k2=6 * unit.kilocalorie_per_mole,
    ... )
    >>> my_par.periodicity
    [2, 3]

    Indexed attributes can be accessed both as a list and through their
    indexed parameter name.

    >>> my_par.periodicity2 = 6
    >>> my_par.periodicity[0] = 1
    >>> my_par.periodicity
    [1, 6]

    """

    # ChemicalEnvironment valence type string expected by SMARTS string for this Handler
    _VALENCE_TYPE = None
    # The string mapping to this ParameterType in a SMIRNOFF data source
    _ELEMENT_NAME = None

    # Parameter attributes shared among all parameter types.
    smirks = ParameterAttribute()
    id = ParameterAttribute(default=None)
    parent_id = ParameterAttribute(default=None)

    @smirks.converter
    def smirks(self, attr, smirks):
        # The SMIRKS string is handed to ChemicalEnvironment for validation
        # and returned unchanged; an invalid pattern is expected to make the
        # validation call raise.

        # TODO: Add check to make sure we can't make tree non-hierarchical
        #       This would require parameter type knows which ParameterList it belongs to
        ChemicalEnvironment.validate_smirks(smirks, validate_valence_type=True)
        return smirks

    def __init__(self, smirks, allow_cosmetic_attributes=False, **kwargs):
        """
        Create a ParameterType.

        Parameters
        ----------
        smirks : str
            The SMIRKS match for the provided parameter type.
        allow_cosmetic_attributes : bool optional. Default = False
            Whether to permit non-spec kwargs ("cosmetic attributes"). If True, non-spec kwargs will be stored as
            an attribute of this parameter which can be accessed and written out. Otherwise an exception will
            be raised.

        """
        # `smirks` is a named argument only to make it required; it is passed
        # on together with (and after) the remaining keyword arguments.
        attributes = dict(kwargs, smirks=smirks)
        super().__init__(
            allow_cosmetic_attributes=allow_cosmetic_attributes, **attributes
        )

    def __repr__(self):
        # "<ClassName with attr: value  attr: value  >", built from to_dict().
        fields = "".join(f"{attr}: {val}  " for attr, val in self.to_dict().items())
        return "<{} with {}>".format(self.__class__.__name__, fields)
1567

1568

1569
# ======================================================================
1570
# PARAMETER HANDLERS
1571
#
1572
# The following classes are Handlers that know how to create Force
1573
# subclasses and add them to a System that is being created. Each Handler
1574
# class must define three methods:
1575
# 1) a constructor which takes as input hierarchical dictionaries of data
1576
#    conformant to the SMIRNOFF spec;
1577
# 2) a create_force() method that constructs the Force object and adds it
1578
#    to the System; and
1579
# 3) a labelForce() method that provides access to which terms are applied
1580
#    to which atoms in specified mols.
1581
# ======================================================================
1582

1583
# TODO: Should we have a parameter handler registry?
1584

1585

1586 8
class ParameterHandler(_ParameterAttributeHandler):
1587
    """Base class for parameter handlers.
1588

1589
    Parameter handlers are configured with some global parameters for a
1590
    given section. They may also contain a :class:`ParameterList` populated
1591
    with :class:`ParameterType` objects if they are responsible for assigning
1592
    SMIRKS-based parameters.
1593

1594
    .. warning ::
1595

1596
       Parameter handler objects can only belong to a single :class:`ForceField` object.
1597
       If you need to create a copy to attach to a different :class:`ForceField` object,
1598
       use ``create_copy()``.
1599

1600
    .. warning :: This API is experimental and subject to change.
1601

1602
    """
1603

1604 8
    _TAGNAME = None  # str of section type handled by this ParameterHandler (XML element name for SMIRNOFF XML representation)
1605 8
    _INFOTYPE = None  # container class with type information that will be stored in self._parameters
1606 8
    _OPENMMTYPE = None  # OpenMM Force class (or None if no equivalent)
1607 8
    _DEPENDENCIES = (
1608
        None  # list of ParameterHandler classes that must precede this, or None
1609
    )
1610

1611 8
    _KWARGS = []  # Kwargs to catch when create_force is called
1612 8
    _SMIRNOFF_VERSION_INTRODUCED = (
1613
        0.0  # the earliest version of SMIRNOFF spec that supports this ParameterHandler
1614
    )
1615 8
    _SMIRNOFF_VERSION_DEPRECATED = (
1616
        None  # if deprecated, the first SMIRNOFF version number it is no longer used
1617
    )
1618 8
    _MIN_SUPPORTED_SECTION_VERSION = 0.3
1619 8
    _MAX_SUPPORTED_SECTION_VERSION = 0.3
1620

1621 8
    version = ParameterAttribute()
1622

1623 8
    @version.converter
    def version(self, attr, new_version):
        """
        Validate a section version against the range supported by this handler.

        The version is compared with ``_MIN_SUPPORTED_SECTION_VERSION`` and
        ``_MAX_SUPPORTED_SECTION_VERSION`` using ``packaging.version`` ordering
        on their string representations.

        Returns
        -------
        new_version
            The value that was passed in, unchanged.

        Raises
        ------
        SMIRNOFFVersionError if an incompatible version is passed in.

        """
        import packaging.version

        from openforcefield.typing.engines.smirnoff import SMIRNOFFVersionError

        def as_version(value):
            # Versions may be given as floats or strings; normalise through str().
            return packaging.version.parse(str(value))

        requested = as_version(new_version)
        too_new = requested > as_version(self._MAX_SUPPORTED_SECTION_VERSION)
        too_old = requested < as_version(self._MIN_SUPPORTED_SECTION_VERSION)

        if too_new or too_old:
            raise SMIRNOFFVersionError(
                f"SMIRNOFF offxml file was written with version {new_version}, but this version "
                f"of ForceField only supports version {self._MIN_SUPPORTED_SECTION_VERSION} "
                f"to version {self._MAX_SUPPORTED_SECTION_VERSION}"
            )
        return new_version
1651

1652 8
    def __init__(
        self, allow_cosmetic_attributes=False, skip_version_check=False, **kwargs
    ):
        """
        Initialize a ParameterHandler, optionally with a list of parameters and other kwargs.

        Parameters
        ----------
        allow_cosmetic_attributes : bool, optional. Default = False
            Whether to permit non-spec kwargs. If True, non-spec kwargs will be stored as attributes of this object
            and can be accessed and modified. Otherwise an exception will be raised if a non-spec kwarg is encountered.
        skip_version_check: bool, optional. Default = False
            If True and no ``version`` kwarg is given, the handler is initialized with version set to
            _MAX_SUPPORTED_SECTION_VERSION. If False, a missing ``version`` kwarg is an error.
        **kwargs : dict
            The dict representation of the SMIRNOFF data source

        Raises
        ------
        SMIRNOFFSpecError if ``version`` is missing and ``skip_version_check`` is False.

        """
        if "version" not in kwargs:
            if not skip_version_check:
                raise SMIRNOFFSpecError(
                    f"Missing version while trying to construct {self.__class__}. "
                    f"0.3 SMIRNOFF spec requires each parameter section to have its own version."
                )
            # The caller opted out of supplying a version: fall back to the newest supported one.
            kwargs["version"] = self._MAX_SUPPORTED_SECTION_VERSION

        # Container for the ParameterType objects owned by this handler. It is
        # created before the parent initializer runs, as in the original ordering.
        self._parameters = ParameterList()

        # The parent initializer processes ParameterAttributes and cosmetic attributes.
        super().__init__(allow_cosmetic_attributes=allow_cosmetic_attributes, **kwargs)
1685

1686 8
    def _add_parameters(self, section_dict, allow_cosmetic_attributes=False):
        """
        Extend the ParameterList in this ParameterHandler using a SMIRNOFF data source.

        Only the entry of ``section_dict`` whose key equals
        ``self._INFOTYPE._ELEMENT_NAME`` is read; every other key is ignored here.

        Parameters
        ----------
        section_dict : dict
            The dict representation of a SMIRNOFF data source containing parameters to add to this ParameterHandler
        allow_cosmetic_attributes : bool, optional. Default = False
            Whether to allow non-spec fields in section_dict. If True, non-spec kwargs will be stored as an
            attribute of the parameter. If False, non-spec kwargs will raise an exception.

        """
        unitless_kwargs, attached_units = extract_serialized_units_from_dict(
            section_dict
        )
        smirnoff_data = attach_units(unitless_kwargs, attached_units)

        # Name of the per-parameter element, or None when this handler has no parameter type.
        element_name = (
            self._INFOTYPE._ELEMENT_NAME if self._INFOTYPE is not None else None
        )

        for key, val in smirnoff_data.items():
            # Everything except the parameter list itself is skipped.
            if key != element_name:
                continue

            # A single parameter arrives as a bare dict; several arrive as a list.
            entries = val if isinstance(val, list) else [val]

            # Attach units to each parameter dict and build a ParameterType from it.
            for unitless_param_dict in entries:
                param_dict = attach_units(unitless_param_dict, attached_units)
                self._parameters.append(
                    self._INFOTYPE(
                        **param_dict,
                        allow_cosmetic_attributes=allow_cosmetic_attributes,
                    )
                )
1723

1724 8
    @property
    def parameters(self):
        """ParameterList: the parameter objects held by this ParameterHandler."""
        return self._parameters
1728

1729 8
    @property
    def TAGNAME(self):
        """
        The SMIRNOFF tag name handled by this ParameterHandler.

        Returns
        -------
        handler_name : str
            The value of the class-level ``_TAGNAME`` (``None`` on the base class).

        """
        return self._TAGNAME
1741

1742
    # TODO: Do we need to return these, or can we handle this internally
1743 8
    @property
    def known_kwargs(self):
        """Set of kwarg names listed in ``_KWARGS`` (a fresh set on every access)."""
        # TODO: Should we use introspection to inspect the function signature instead?
        return {*self._KWARGS}
1748

1749 8
    def check_handler_compatibility(self, handler_kwargs):
        """
        Checks if a set of kwargs used to create a ParameterHandler are compatible with this ParameterHandler. This is
        called if a second handler is attempted to be initialized for the same tag.

        The base implementation performs no checks and returns ``None``;
        subclasses override it to compare the settings they care about.

        Parameters
        ----------
        handler_kwargs : dict
            The kwargs that would be used to construct the second handler.

        Raises
        ------
        IncompatibleParameterError if handler_kwargs are incompatible with existing parameters
        (raised by overriding implementations; never by this base method).
        """
1764

1765
    # TODO: Can we ensure SMIRKS and other parameters remain valid after manipulation?
1766 8
    def add_parameter(
        self, parameter_kwargs=None, parameter=None, after=None, before=None
    ):
        """Add a parameter to the forcefield, ensuring all parameters are valid.

        Parameters
        ----------
        parameter_kwargs: dict, optional
            The kwargs to pass to the ParameterHandler.INFOTYPE (a ParameterType) constructor
        parameter: ParameterType, optional
            A ParameterType to add to the ParameterHandler
        after : str or int, optional
            The SMIRKS pattern (if str) or index (if int) of the parameter directly before where
            the new parameter will be added
        before : str or int, optional
            The SMIRKS pattern (if str) or index (if int) of the parameter directly after where
            the new parameter will be added

        Raises
        ------
        TypeError
            If `before` or `after` is given and is neither a str nor an int.
        ValueError
            If neither `parameter_kwargs` nor `parameter` is specified, or if both
            `before` and `after` are given and `after` refers to a later position than `before`.
        DuplicateParameterError
            If a parameter with the same SMIRKS pattern is already present.

        Note that one of (parameter_kwargs, parameter) must be specified
        Note that when `before` and `after` are both None, the new parameter will be appended
            to the END of the parameter list.
        Note that when `before` and `after` are both specified, the new parameter
            will be added immediately after the parameter matching the `after` pattern or index.
        Note that negative integer indices that fall inside the list count from the end,
            so ``after=-1`` places the new parameter after the current last parameter.

        Examples
        --------

        Add a ParameterType to an existing ParameterList at a specified position.

        Given an existing parameter handler and a new parameter to add to it:

        >>> from simtk import unit
        >>> bh = BondHandler(skip_version_check=True)
        >>> length = 1.5 * unit.angstrom
        >>> k = 100 * unit.kilocalorie_per_mole / unit.angstrom ** 2
        >>> bh.add_parameter({'smirks': '[*:1]-[*:2]', 'length': length, 'k': k, 'id': 'b1'})
        >>> bh.add_parameter({'smirks': '[*:1]=[*:2]', 'length': length, 'k': k, 'id': 'b2'})
        >>> bh.add_parameter({'smirks': '[*:1]#[*:2]', 'length': length, 'k': k, 'id': 'b3'})
        >>> [p.id for p in bh.parameters]
        ['b1', 'b2', 'b3']

        >>> param = {'smirks': '[#1:1]-[#6:2]', 'length': length, 'k': k, 'id': 'b4'}

        Add a new parameter immediately after the parameter with the smirks '[*:1]=[*:2]'

        >>> bh.add_parameter(param, after='[*:1]=[*:2]')
        >>> [p.id for p in bh.parameters]
        ['b1', 'b2', 'b4', 'b3']
        """
        # Validate with ``is not None`` rather than truthiness so that falsy values of
        # the wrong type (e.g. 0.0 or []) are rejected here instead of failing later
        # with an unrelated UnboundLocalError.
        for position in (before, after):
            if position is not None and not isinstance(position, (str, int)):
                raise TypeError(
                    "`before` and `after` must be a SMIRKS pattern (str) or an index "
                    f"(int), not {type(position).__name__}"
                )

        # If a dict was passed, construct it; if a ParameterType was passed, do nothing
        if parameter_kwargs:
            new_parameter = self._INFOTYPE(**parameter_kwargs)
        elif parameter:
            new_parameter = parameter
        else:
            raise ValueError("One of (parameter, parameter_kwargs) must be specified")

        if any(p.smirks == new_parameter.smirks for p in self._parameters):
            msg = f"A parameter SMIRKS pattern {new_parameter.smirks} already exists."
            raise DuplicateParameterError(msg)

        def resolve_index(position):
            """Translate a SMIRKS pattern or integer into a list position."""
            if isinstance(position, str):
                return self._parameters.index(position)
            # In-range negative indices are converted to their non-negative equivalent.
            # Without this, ``after=-1`` would insert at position 0 (i.e. -1 + 1)
            # instead of after the last parameter. Out-of-range values are left as
            # given and keep list.insert's usual clamping.
            count = len(self._parameters)
            if -count <= position < 0:
                return position + count
            return position

        before_index = resolve_index(before) if before is not None else None
        after_index = resolve_index(after) if after is not None else None

        # When both anchors are given, `after` must not sit later in the list than `before`.
        if before_index is not None and after_index is not None:
            if after_index > before_index:
                raise ValueError("before arg must be before after arg")

        if after_index is not None:
            self._parameters.insert(after_index + 1, new_parameter)
        elif before_index is not None:
            self._parameters.insert(before_index, new_parameter)
        else:
            self._parameters.append(new_parameter)
1853

1854 8
    def get_parameter(self, parameter_attrs):
        """
        Return the parameters in this ParameterHandler that match the parameter_attrs argument.
        When multiple attrs are passed, parameters that have any (not all) matching attributes
        are returned.

        Parameters
        ----------
        parameter_attrs : dict of {attr: value}
            The attrs mapped to desired values (for example {"smirks": "[*:1]~[#16:2]=,:[#6:3]~[*:4]", "id": "t105"} )

        Returns
        -------
        params : list of ParameterType objects
            A list of matching ParameterType objects. Each parameter appears at most once,
            ordered by the first requested attribute it matched and then by its position
            in the handler.

        Examples
        --------

        Create a parameter handler and populate it with some data.

        >>> from simtk import unit
        >>> handler = BondHandler(skip_version_check=True)
        >>> handler.add_parameter(
        ...     {
        ...         'smirks': '[*:1]-[*:2]',
        ...         'length': 1*unit.angstrom,
        ...         'k': 10*unit.kilocalorie_per_mole/unit.angstrom**2,
        ...     }
        ... )

        Look up, from this handler, all parameters matching some SMIRKS pattern

        >>> handler.get_parameter({'smirks': '[*:1]-[*:2]'})
        [<BondType with smirks: [*:1]-[*:2]  length: 1 A  k: 10 kcal/(A**2 mol)  >]

        """
        params = []
        for attr, value in parameter_attrs.items():
            for param in self.parameters:
                # A parameter already selected by an earlier attribute is not added twice.
                if param in params:
                    continue
                # TODO: Cleaner accessing of cosmetic attributes
                # See issue #338
                # Cosmetic attributes are looked up under a leading underscore. The
                # name is resolved separately for every parameter: rebinding ``attr``
                # itself would leak the underscore-prefixed name to all parameters
                # examined afterwards.
                lookup_name = "_" + attr if param.attribute_is_cosmetic(attr) else attr
                if hasattr(param, lookup_name) and getattr(param, lookup_name) == value:
                    params.append(param)
        return params
1904

1905 8
    class _Match:
        """Pairs a ParameterType with the chemical environment match
        that it was assigned to.
        """

        def __init__(self, parameter_type, environment_match):
            """Constructs a new ParameterHandlerMatch object.

            Parameters
            ----------
            parameter_type: ParameterType
                The matched parameter type.
            environment_match: Topology._ChemicalEnvironmentMatch
                The environment which matched the type.
            """
            self._environment_match = environment_match
            self._parameter_type = parameter_type

        @property
        def parameter_type(self):
            """ParameterType: The matched parameter type."""
            return self._parameter_type

        @property
        def environment_match(self):
            """Topology._ChemicalEnvironmentMatch: The environment which matched the type."""
            return self._environment_match
1932

1933 8
    def find_matches(self, entity):
        """Find the elements of the topology/molecule matched by a parameter type.

        This is a thin public wrapper that delegates to ``_find_matches`` with
        its default dictionary class.

        Parameters
        ----------
        entity : openforcefield.topology.Topology
            Topology to search.

        Returns
        ---------
        matches : ValenceDict[Tuple[int], ParameterHandler._Match]
            ``matches[particle_indices]`` is the ``ParameterHandler._Match``
            object for the tuple of particle indices in ``entity``; its
            ``parameter_type`` attribute holds the matching ``ParameterType``.
        """
        # TODO: Right now, this method is only ever called with an entity that is a Topology.
        #  Should we reduce its scope and have a check here to make sure entity is a Topology?
        matches = self._find_matches(entity)
        return matches
1951

1952 8
    def _find_matches(self, entity, transformed_dict_cls=ValenceDict):
        """Implement find_matches() and allow using a different valence dictionary.

        Parameter types are processed in list order and each one overwrites the
        entries written by earlier types for the same key, so the last matching
        parameter type wins.

        Parameters
        ----------
        entity : openforcefield.topology.Topology
            Topology to search.
        transformed_dict_cls: class
            The type of dictionary to store the matches in. This
            will determine how groups of atom indices are stored
            and accessed (e.g for angles indices should be 0-1-2
            and not 2-1-0).

        Returns
        ---------
        matches : `transformed_dict_cls` of ParameterHandlerMatch
            ``matches[particle_indices]`` is the ``ParameterHandler._Match``
            object for the tuple of particle indices in ``entity``.
        """
        logger.debug(f"Finding matches for {self.__class__.__name__}")

        matches = transformed_dict_cls()

        # TODO: There are probably performance gains to be had here
        #       by performing this loop in reverse order, and breaking early once
        #       all environments have been matched.
        for parameter_type in self._parameters:
            smirks = parameter_type.smirks

            # Collect this type's matches in a plain dict first so they can be counted for the log line.
            matches_for_this_type = {
                environment_match.topology_atom_indices: self._Match(
                    parameter_type, environment_match
                )
                for environment_match in entity.chemical_environment_matches(smirks)
            }

            # Fold them into the overall result, replacing matches from earlier types.
            matches.update(matches_for_this_type)

            logger.debug(f"{smirks:64} : {len(matches_for_this_type):8} matches")

        logger.debug(f"{len(matches)} matches identified")
        return matches
2001

2002 8
    @staticmethod
    def _assert_correct_connectivity(match, expected_connectivity=None):
        """A more performant version of the `topology.assert_bonded` method
        to ensure that the results of `_find_matches` are valid.

        For every pair in ``expected_connectivity`` the bond between the two
        corresponding reference atoms is requested from the reference molecule;
        the return value is discarded.

        Raises
        ------
        Whatever ``reference_molecule.get_bond_between`` raises when the two
        atoms are not bonded (documented upstream as a ValueError; not
        verifiable from this method alone).

        Parameters
        ----------
        match: ParameterHandler._Match
            The match found by `_find_matches`
        expected_connectivity: list of tuple of int, optional
            The expected connectivity of the match (e.g. for a torsion
            expected_connectivity=[(0, 1), (1, 2), (2, 3)]). If `None`,
            no check is performed.
        """

        # I'm not 100% sure this is really necessary... but this should do
        # the same checks as the more costly assert_bonded method in the
        # ParameterHandler.create_force methods.
        if expected_connectivity is None:
            return

        environment_match = match.environment_match
        reference_molecule = environment_match.reference_molecule

        for connectivity in expected_connectivity:
            # Map positions within the match onto atom indices of the reference molecule.
            reference_atom_indices = environment_match.reference_atom_indices
            atom_i = reference_atom_indices[connectivity[0]]
            atom_j = reference_atom_indices[connectivity[1]]

            reference_molecule.get_bond_between(atom_i, atom_j)
2037

2038 8
    def assign_parameters(self, topology, system):
        """Assign parameters for the given Topology to the specified System object.

        The base implementation does nothing and returns ``None``.

        Parameters
        ----------
        topology : openforcefield.topology.Topology
            The Topology for which parameters are to be assigned.
            Either a new Force will be created or parameters will be appended to an existing Force.
        system : simtk.openmm.System
            The OpenMM System object to add the Force (or append new parameters) to.
        """
        return None
2050

2051 8
    def postprocess_system(self, topology, system, **kwargs):
        """Allow the force to perform a final post-processing pass on the System following parameter assignment, if needed.

        The base implementation does nothing and returns ``None``.

        Parameters
        ----------
        topology : openforcefield.topology.Topology
            The Topology for which parameters are to be assigned.
            Either a new Force will be created or parameters will be appended to an existing Force.
        system : simtk.openmm.System
            The OpenMM System object to add the Force (or append new parameters) to.
        **kwargs : dict
            Accepted and ignored by the base implementation.
        """
        return None
2063

2064 8
    def to_dict(self, discard_cosmetic_attributes=False):
        """
        Convert this ParameterHandler to an OrderedDict, compliant with the SMIRNOFF data spec.

        The parameter list (when this handler has an ``_INFOTYPE``) is placed
        first, followed by the handler's own header and cosmetic attributes.

        Parameters
        ----------
        discard_cosmetic_attributes : bool, optional. Default = False.
            Whether to discard non-spec parameter and header attributes in this ParameterHandler.

        Returns
        -------
        smirnoff_data : OrderedDict
            SMIRNOFF-spec compliant representation of this ParameterHandler and its internal ParameterList.

        """
        # Serialize the parameters held by this handler.
        parameter_list = self._parameters.to_list(
            discard_cosmetic_attributes=discard_cosmetic_attributes
        )

        # NOTE: This assumes that a ParameterHandler will have just one homogenous ParameterList under it
        leading_entries = []
        if self._INFOTYPE is not None:
            leading_entries.append((self._INFOTYPE._ELEMENT_NAME, parameter_list))
        smirnoff_data = OrderedDict(leading_entries)

        # Header-level parameter attributes and cosmetic attributes come from the parent class.
        smirnoff_data.update(
            super().to_dict(discard_cosmetic_attributes=discard_cosmetic_attributes)
        )

        return smirnoff_data
2098

2099
    # -------------------------------
2100
    # Utilities for children classes.
2101
    # -------------------------------
2102

2103 8
    @classmethod
    def _check_all_valence_terms_assigned(
        cls,
        assigned_terms,
        valence_terms,
        exception_cls=UnassignedValenceParameterException,
    ):
        """Check that all valence terms have been assigned and print a user-friendly error message.

        Returns silently when the number of assigned terms equals the number of
        valence terms. Otherwise both collections are compared key by key and an
        exception describing the discrepancies is raised.

        Parameters
        ----------
        assigned_terms : ValenceDict
            Atom index tuples defining added valence terms.
        valence_terms : Iterable[TopologyAtom] or Iterable[Iterable[TopologyAtom]]
            Atom or atom tuples defining topological valence terms.
            Must support ``len()`` and indexing.
        exception_cls : UnassignedValenceParameterException
            A specific exception class to raise to allow catching only specific
            types of errors.

        Raises
        ------
        exception_cls
            If some valence terms received no parameters, or if terms were assigned
            that are not present in ``valence_terms``. The exception carries
            ``unassigned_topology_atom_tuples`` (a list, empty when only the second
            condition applies) and ``handler_class``.

        """
        from openforcefield.topology import TopologyAtom

        # Provided there are no duplicates in either list,
        # or something weird like a bond has been added to
        # a torsions list - this should work just fine I think.
        # If we expect either of those assumptions to be incorrect,
        # (i.e len(not_found_terms) > 0) we have bigger issues
        # in the code and should be catching those cases elsewhere!
        # The fact that we graph match all topol molecules to ref
        # molecules should avoid the len(not_found_terms) > 0 case.
        if len(assigned_terms) == len(valence_terms):
            return

        # Convert the valence term to a valence dictionary to make sure
        # the order of atom indices doesn't matter for comparison.
        valence_terms_dict = assigned_terms.__class__()
        for atoms in valence_terms:
            try:
                # valence_terms is a list of TopologyAtom tuples. Creating the
                # generator calls iter(atoms) right away, so a non-iterable
                # single atom raises TypeError here.
                atom_indices = (a.topology_particle_index for a in atoms)
            except TypeError:
                # valence_terms is a list of TopologyAtom.
                atom_indices = (atoms.topology_particle_index,)
            valence_terms_dict[atom_indices] = atoms

        # Check that both valence dictionaries have the same keys (i.e. terms).
        assigned_terms_set = set(assigned_terms.keys())
        valence_terms_set = set(valence_terms_dict.keys())
        unassigned_terms = valence_terms_set.difference(assigned_terms_set)
        not_found_terms = assigned_terms_set.difference(valence_terms_set)

        err_msg = ""

        # Bound unconditionally: it is attached to the exception below even when
        # only "not found" terms are reported, which previously failed with an
        # UnboundLocalError instead of raising exception_cls.
        unassigned_topology_atom_tuples = []

        if len(unassigned_terms) > 0:
            # Gain access to the relevant topology
            first_term = valence_terms[0]
            if isinstance(first_term, TopologyAtom):
                topology = first_term.topology_molecule.topology
            else:
                topology = first_term[0].topology_molecule.topology

            term_descriptions = []
            for unassigned_tuple in unassigned_terms:
                # Pull and add additional helpful info on missing terms
                unassigned_topology_atoms = tuple(
                    topology.atom(atom_idx) for atom_idx in unassigned_tuple
                )
                unassigned_topology_atom_tuples.append(unassigned_topology_atoms)
                names_and_elements = "".join(
                    f"({topology_atom.atom.name} {topology_atom.atom.element.symbol}), "
                    for topology_atom in unassigned_topology_atoms
                )
                term_descriptions.append(
                    "\n- Topology indices "
                    + str(unassigned_tuple)
                    + ": names and elements "
                    + names_and_elements
                )
            unassigned_str = "".join(term_descriptions)
            err_msg += (
                "{parameter_handler} was not able to find parameters for the following valence terms:\n"
                "{unassigned_str}"
            ).format(parameter_handler=cls.__name__, unassigned_str=unassigned_str)

        if len(not_found_terms) > 0:
            if err_msg != "":
                err_msg += "\n"
            not_found_str = "\n- ".join([str(x) for x in not_found_terms])
            err_msg += (
                "{parameter_handler} assigned terms that were not found in the topology:\n"
                "- {not_found_str}"
            ).format(parameter_handler=cls.__name__, not_found_str=not_found_str)

        if err_msg != "":
            err_msg += "\n"
            exception = exception_cls(err_msg)
            exception.unassigned_topology_atom_tuples = unassigned_topology_atom_tuples
            exception.handler_class = cls
            raise exception
2198

2199 8
    def _check_attributes_are_equal(
        self, other, identical_attrs=(), tolerance_attrs=(), tolerance=1e-6
    ):
        """Utility function to check that the given attributes of the two handlers are equal.

        Parameters
        ----------
        other : object
            The object whose attributes are compared with those of this handler.
        identical_attrs : List[str]
            Names of the parameters that must be checked with the equality operator.
        tolerance_attrs : List[str]
            Names of the parameters that must be equal up to a tolerance.
        tolerance : float
            The absolute tolerance used to compare the parameters.

        Raises
        ------
        IncompatibleParameterError
            If an attribute in ``identical_attrs`` differs, or an attribute in
            ``tolerance_attrs`` differs by more than ``tolerance``.
        """

        def comparable_pair(name):
            """Return (own value, other value), both divided by the own value's unit if it has one."""
            mine = getattr(self, name)
            theirs = getattr(other, name)
            try:
                reference_unit = mine.unit
            except AttributeError:
                # No ``unit`` attribute: compare the raw values.
                return mine, theirs
            return mine / reference_unit, theirs / reference_unit

        # NOTE: both messages below lack a closing parenthesis; the text is kept
        # unchanged so that the emitted messages stay the same.
        for attr in identical_attrs:
            this_val, other_val = comparable_pair(attr)
            if this_val != other_val:
                raise IncompatibleParameterError(
                    f"{attr} values are not identical. "
                    f"(handler value: {this_val}, incompatible value: {other_val}"
                )

        for attr in tolerance_attrs:
            this_val, other_val = comparable_pair(attr)
            if abs(this_val - other_val) > tolerance:
                raise IncompatibleParameterError(
                    f"Difference between '{attr}' values is beyond allowed tolerance {tolerance}. "
                    f"(handler value: {this_val}, incompatible value: {other_val}"
                )
2244

2245

2246
# =============================================================================================
2247

2248

2249 8
class ConstraintHandler(ParameterHandler):
    """Handle SMIRNOFF ``<Constraints>`` tags

    ``ConstraintHandler`` must be applied before ``BondHandler`` and ``AngleHandler``,
    since those classes add constraints for which equilibrium geometries are needed from those tags.

    .. warning :: This API is experimental and subject to change.
    """

    class ConstraintType(ParameterType):
        """A SMIRNOFF constraint type

        .. warning :: This API is experimental and subject to change.
        """

        _VALENCE_TYPE = "Bond"
        _ELEMENT_NAME = "Constraint"

        # Optional constraint distance; None when the section does not specify one.
        distance = ParameterAttribute(default=None, unit=unit.angstrom)

    _TAGNAME = "Constraints"
    _INFOTYPE = ConstraintType
    _OPENMMTYPE = None  # don't create a corresponding OpenMM Force class

    def create_force(self, system, topology, **kwargs):
        """Register the constraints matched in ``topology``.

        Matches with an explicit distance are added both to ``system`` and to
        ``topology``. Matches without one are only recorded on ``topology``
        with the value ``True``.

        Parameters
        ----------
        system : simtk.openmm.System
            The system that receives constraints with an explicit distance.
        topology : openforcefield.topology.Topology
            The topology to search and on which every constraint is recorded.
        **kwargs : dict
            Accepted and ignored.
        """
        for atoms, constraint_match in self.find_matches(topology).items():
            distance = constraint_match.parameter_type.distance

            if distance is not None:
                # An explicit distance was given: constrain the pair right away.
                system.addConstraint(*atoms, distance)
                topology.add_constraint(*atoms, distance)
            else:
                # No distance given: only flag the pair as constrained. Per the
                # original note, the equilibrium bond length is filled in later
                # by the bond handler.
                topology.add_constraint(*atoms, True)
2287

2288

2289
# =============================================================================================
2290

2291

2292 8
class BondHandler(ParameterHandler):
    """Handle SMIRNOFF ``<Bonds>`` tags

    .. warning :: This API is experimental and subject to change.
    """

    class BondType(ParameterType):
        """A SMIRNOFF bond type

        .. warning :: This API is experimental and subject to change.
        """

        # ChemicalEnvironment valence type string expected by SMARTS string for this Handler
        _VALENCE_TYPE = "Bond"
        _ELEMENT_NAME = "Bond"

        # These attributes may be indexed (by integer bond order) if fractional bond orders are used.
        length = ParameterAttribute(unit=unit.angstrom)
        k = ParameterAttribute(unit=unit.kilocalorie_per_mole / unit.angstrom ** 2)

    _TAGNAME = "Bonds"  # SMIRNOFF tag name to process
    _INFOTYPE = BondType  # class to hold force type info
    _OPENMMTYPE = openmm.HarmonicBondForce  # OpenMM force class to create
    _DEPENDENCIES = [ConstraintHandler]  # ConstraintHandler must be executed first

    potential = ParameterAttribute(default="harmonic")
    fractional_bondorder_method = ParameterAttribute(default=None)
    fractional_bondorder_interpolation = ParameterAttribute(default="linear")

    def check_handler_compatibility(self, other_handler):
        """
        Verify that another handler for the same tag describes the same physics as this one.

        The ``potential``, ``fractional_bondorder_method`` and
        ``fractional_bondorder_interpolation`` attributes are required to be identical.

        Parameters
        ----------
        other_handler : a ParameterHandler object
            The handler to compare to.

        Raises
        ------
        IncompatibleParameterError if handler_kwargs are incompatible with existing parameters.
        """
        self._check_attributes_are_equal(
            other_handler,
            identical_attrs=[
                "potential",
                "fractional_bondorder_method",
                "fractional_bondorder_interpolation",
            ],
        )

    def create_force(self, system, topology, **kwargs):
        """Add a harmonic bond term (or a constraint) for every matched bond.

        Parameters
        ----------
        system : simtk.openmm.System
            System that receives the bond force and any constraints added here.
        topology : openforcefield.topology.Topology
            Topology whose bonds are matched against this handler's parameters.
        **kwargs
            Accepted for interface compatibility; not used here.

        Raises
        ------
        NotImplementedError
            If a matched parameter carries a ``k_bondorder1`` attribute.
        """
        # Create or retrieve existing OpenMM Force object
        # TODO: The commented line below should replace the system.getForce search
        # force = super(BondHandler, self).create_force(system, topology, **kwargs)
        candidates = [
            system.getForce(index) for index in range(system.getNumForces())
        ]
        candidates = [f for f in candidates if type(f) == self._OPENMMTYPE]
        if candidates:
            force = candidates[0]
        else:
            force = self._OPENMMTYPE()
            system.addForce(force)

        bond_matches = self.find_matches(topology)

        # Number of matched bonds that were constrained and therefore got no bond term.
        n_constrained = 0

        for topology_atom_indices, bond_match in bond_matches.items():
            # Ensure atoms are actually bonded correct pattern in Topology
            self._assert_correct_connectivity(bond_match)

            bond_params = bond_match.parameter_type
            env_match = bond_match.environment_match

            # The reference bond is looked up but its value is not used yet; it is the
            # hook for the (unimplemented) fractional bond order interpolation.
            env_match.reference_molecule.get_bond_between(
                *env_match.reference_atom_indices
            )

            # TODO: Interpolate k and length using fractional bond orders.
            if hasattr(bond_params, "k_bondorder1"):
                raise NotImplementedError(
                    "Partial bondorder treatment is not implemented for bonds."
                )

            k, length = bond_params.k, bond_params.length

            constraint_state = topology.is_constrained(*topology_atom_indices)
            if not constraint_state:
                # Unconstrained pair: add harmonic bond to HarmonicBondForce
                force.addBond(*topology_atom_indices, length, k)
                continue

            # Atom pair is constrained; we don't need to add a bond term.
            n_constrained += 1
            if constraint_state is True:
                # The constraint was flagged without a distance, so pin it to the
                # equilibrium bond length on both the topology and the system.
                topology.add_constraint(*topology_atom_indices, length)
                system.addConstraint(*topology_atom_indices, length)

        logger.info(
            "{} bonds added ({} skipped due to constraints)".format(
                len(bond_matches) - n_constrained, n_constrained
            )
        )

        # Check that no topological bonds are missing force parameters.
        self._check_all_valence_terms_assigned(
            assigned_terms=bond_matches,
            valence_terms=[list(b.atoms) for b in topology.topology_bonds],
            exception_cls=UnassignedBondParameterException,
        )
2427

2428

2429
# =============================================================================================
2430

2431

2432 8
class AngleHandler(ParameterHandler):
    """Handle SMIRNOFF ``<Angles>`` tags

    .. warning :: This API is experimental and subject to change.
    """

    class AngleType(ParameterType):
        """A SMIRNOFF angle type.

        .. warning :: This API is experimental and subject to change.
        """

        # ChemicalEnvironment valence type string expected by SMARTS string for this Handler
        _VALENCE_TYPE = "Angle"
        _ELEMENT_NAME = "Angle"

        angle = ParameterAttribute(unit=unit.degree)
        k = ParameterAttribute(unit=unit.kilocalorie_per_mole / unit.degree ** 2)

    _TAGNAME = "Angles"  # SMIRNOFF tag name to process
    _INFOTYPE = AngleType  # class to hold force type info
    _OPENMMTYPE = openmm.HarmonicAngleForce  # OpenMM force class to create
    _DEPENDENCIES = [ConstraintHandler]  # ConstraintHandler must be executed first

    potential = ParameterAttribute(default="harmonic")

    def check_handler_compatibility(self, other_handler):
        """
        Verify that another handler for the same tag describes the same physics as this one.

        Only the ``potential`` attribute is compared, and it must be identical.

        Parameters
        ----------
        other_handler : a ParameterHandler object
            The handler to compare to.

        Raises
        ------
        IncompatibleParameterError if handler_kwargs are incompatible with existing parameters.
        """
        self._check_attributes_are_equal(other_handler, identical_attrs=["potential"])

    def create_force(self, system, topology, **kwargs):
        """Add a harmonic angle term for every matched angle that is not fully constrained.

        Parameters
        ----------
        system : simtk.openmm.System
            System that receives (or already holds) the angle force.
        topology : openforcefield.topology.Topology
            Topology whose angles are matched against this handler's parameters.
        **kwargs
            Accepted for interface compatibility; not used here.
        """
        # force = super(AngleHandler, self).create_force(system, topology, **kwargs)
        candidates = [
            system.getForce(index) for index in range(system.getNumForces())
        ]
        candidates = [f for f in candidates if type(f) == self._OPENMMTYPE]
        if candidates:
            force = candidates[0]
        else:
            force = self._OPENMMTYPE()
            system.addForce(force)

        angle_matches = self.find_matches(topology)

        # Number of matched angles skipped because all three atom pairs are constrained.
        n_constrained = 0

        for atoms, angle_match in angle_matches.items():
            # Ensure atoms are actually bonded correct pattern in Topology
            self._assert_correct_connectivity(angle_match)

            # An angle is rigid only when both bonds AND the 1-3 distance are constrained.
            rigid = (
                topology.is_constrained(atoms[0], atoms[1])
                and topology.is_constrained(atoms[1], atoms[2])
                and topology.is_constrained(atoms[0], atoms[2])
            )
            if rigid:
                n_constrained += 1
                continue

            angle_params = angle_match.parameter_type
            force.addAngle(*atoms, angle_params.angle, angle_params.k)

        logger.info(
            "{} angles added ({} skipped due to constraints)".format(
                len(angle_matches) - n_constrained,
                n_constrained,
            )
        )

        # Check that no topological angles are missing force parameters
        self._check_all_valence_terms_assigned(
            assigned_terms=angle_matches,
            valence_terms=list(topology.angles),
            exception_cls=UnassignedAngleParameterException,
        )
2522

2523

2524
# =============================================================================================
2525

2526
# TODO: This is technically a validator, not a converter, but ParameterAttribute doesn't support them yet (it'll be easy if we switch to use the attrs library).
2527 8
def _allow_only(allowed_values):
2528
    """A converter that checks the new value is only in a set."""
2529 8
    allowed_values = frozenset(allowed_values)
2530

2531 8
    def _value_checker(instance, attr, new_value):
2532
        # This statement means that, in the "SMIRNOFF Data Dict" format, the string "None"
2533
        # and the Python None are the same thing
2534 8
        if new_value == "None":
2535 8
            new_value = None
2536

2537
        # Ensure that the new value is in the list of allowed values
2538 8
        if new_value not in allowed_values:
2539

2540 8
            err_msg = (
2541
                f"Attempted to set {instance.__class__.__name__}.{attr.name} "
2542
                f"to {new_value}. Currently, only the following values "
2543
                f"are supported: {sorted(allowed_values)}."
2544
            )
2545 8
            raise SMIRNOFFSpecError(err_msg)
2546 8
        return new_value
2547

2548 8
    return _value_checker
2549

2550

2551
# TODO: There's a lot of duplicated code in ProperTorsionHandler and ImproperTorsionHandler
2552 8
class ProperTorsionHandler(ParameterHandler):
    """Handle SMIRNOFF ``<ProperTorsions>`` tags

    .. warning :: This API is experimental and subject to change.
    """

    class ProperTorsionType(ParameterType):
        """A SMIRNOFF torsion type for proper torsions.

        .. warning :: This API is experimental and subject to change.
        """

        _VALENCE_TYPE = "ProperTorsion"
        _ELEMENT_NAME = "Proper"

        periodicity = IndexedParameterAttribute(converter=int)
        phase = IndexedParameterAttribute(unit=unit.degree)
        k = IndexedParameterAttribute(default=None, unit=unit.kilocalorie_per_mole)
        idivf = IndexedParameterAttribute(default=None, converter=float)

        # fractional bond order params
        k_bondorder = IndexedMappedParameterAttribute(
            default=None, unit=unit.kilocalorie_per_mole
        )

    _TAGNAME = "ProperTorsions"  # SMIRNOFF tag name to process
    _KWARGS = ["partial_bond_orders_from_molecules"]
    _INFOTYPE = ProperTorsionType  # info type to store
    _OPENMMTYPE = openmm.PeriodicTorsionForce  # OpenMM force class to create

    potential = ParameterAttribute(
        default="k*(1+cos(periodicity*theta-phase))",
        converter=_allow_only(["k*(1+cos(periodicity*theta-phase))"]),
    )
    default_idivf = ParameterAttribute(default="auto")
    fractional_bondorder_method = ParameterAttribute(default="AM1-Wiberg")
    fractional_bondorder_interpolation = ParameterAttribute(default="linear")

    def check_handler_compatibility(self, other_handler):
        """
        Checks whether this ParameterHandler encodes compatible physics as another ParameterHandler. This is
        called if a second handler is attempted to be initialized for the same tag.

        ``default_idivf`` is compared as a string when it is ``"auto"`` and with a
        numerical tolerance otherwise.

        Parameters
        ----------
        other_handler : a ParameterHandler object
            The handler to compare to.

        Raises
        ------
        IncompatibleParameterError if handler_kwargs are incompatible with existing parameters.
        """
        float_attrs_to_compare = []
        string_attrs_to_compare = [
            "potential",
            "fractional_bondorder_method",
            "fractional_bondorder_interpolation",
        ]

        if self.default_idivf == "auto":
            string_attrs_to_compare.append("default_idivf")
        else:
            float_attrs_to_compare.append("default_idivf")

        self._check_attributes_are_equal(
            other_handler,
            identical_attrs=string_attrs_to_compare,
            tolerance_attrs=float_attrs_to_compare,
        )

    def check_partial_bond_orders_from_molecules_duplicates(self, pb_mols):
        """Reject a list of fractional-bond-order molecules that contains duplicates.

        Parameters
        ----------
        pb_mols : list of openforcefield.topology.Molecule
            User-provided molecules carrying fractional bond orders.

        Raises
        ------
        ValueError
            If two or more molecules produce the same SMILES string.
        """
        if len(set(map(Molecule.to_smiles, pb_mols))) < len(pb_mols):
            raise ValueError(
                "At least two user-provided fractional bond order "
                "molecules are isomorphic"
            )

    def assign_partial_bond_orders_from_molecules(self, topology, pbo_mols):
        """Copy user-provided fractional bond orders onto matching reference molecules.

        For each reference molecule in ``topology``, the first isomorphic molecule in
        ``pbo_mols`` supplies the fractional bond order of every bond.

        Parameters
        ----------
        topology : openforcefield.topology.Topology
            Topology whose reference molecules are updated in place.
        pbo_mols : list of openforcefield.topology.Molecule
            Molecules providing ``fractional_bond_order`` on their bonds.

        Raises
        ------
        ValueError
            If a matching molecule has a bond without ``fractional_bond_order``.
        """
        for ref_mol in topology.reference_molecules:
            for pbo_mol in pbo_mols:
                # we are as stringent as we are in the ElectrostaticsHandler
                # TODO: figure out whether bond order matching is redundant with aromatic matching
                isomorphic, topology_atom_map = Molecule.are_isomorphic(
                    ref_mol,
                    pbo_mol,
                    return_atom_map=True,
                    aromatic_matching=True,
                    formal_charge_matching=True,
                    bond_order_matching=True,
                    atom_stereochemistry_matching=True,
                    bond_stereochemistry_matching=True,
                )
                if not isomorphic:
                    continue

                for bond in ref_mol.bonds:
                    # use atom mapping to translate to pbo_molecule bond
                    pbo_bond = pbo_mol.get_bond_between(
                        topology_atom_map[bond.atom1_index],
                        topology_atom_map[bond.atom2_index],
                    )
                    if pbo_bond.fractional_bond_order is None:
                        raise ValueError(
                            f"Molecule '{ref_mol}' was requested to be parameterized "
                            f"with user-provided fractional bond orders from '{pbo_mol}', but not "
                            "all bonds were provided with `fractional_bond_order` specified"
                        )

                    bond.fractional_bond_order = pbo_bond.fractional_bond_order

                # first match wins; move on to the next reference molecule
                break

    def create_force(self, system, topology, **kwargs):
        """Add periodic torsion terms for every matched proper torsion.

        Parameters
        ----------
        system : simtk.openmm.System
            System that receives (or already holds) the torsion force.
        topology : openforcefield.topology.Topology
            Topology whose proper torsions are matched against this handler.
        **kwargs
            ``partial_bond_orders_from_molecules``, when present, is checked for
            duplicates and used to pre-assign fractional bond orders. All kwargs
            are forwarded to the interpolation step, which reads ``toolkit_registry``.
        """
        # force = super(ProperTorsionHandler, self).create_force(system, topology, **kwargs)
        existing = [system.getForce(i) for i in range(system.getNumForces())]
        existing = [f for f in existing if type(f) == self._OPENMMTYPE]

        if len(existing) == 0:
            force = self._OPENMMTYPE()
            system.addForce(force)
        else:
            force = existing[0]

        # check whether any of the reference molecules in the topology
        # are in the partial_bond_orders_from_molecules list
        if "partial_bond_orders_from_molecules" in kwargs:
            pbo_mols = kwargs["partial_bond_orders_from_molecules"]
            self.check_partial_bond_orders_from_molecules_duplicates(pbo_mols)
            self.assign_partial_bond_orders_from_molecules(topology, pbo_mols)

        # find all proper torsions for which we have parameters
        # operates on reference molecules in topology
        # but gives back matches for atoms for instance molecules
        torsion_matches = self.find_matches(topology)

        for (atom_indices, torsion_match) in torsion_matches.items():
            # Ensure atoms are actually bonded correct pattern in Topology
            # Currently does nothing
            self._assert_correct_connectivity(torsion_match)

            # TODO: add a check here that we have same number of terms for
            # `kX_bondorder*`, `periodicityX`, `phaseX`
            # only count a given `kX_bondorder*` once
            if torsion_match.parameter_type.k_bondorder is None:
                # assign torsion with no interpolation
                self._assign_torsion(atom_indices, torsion_match, force)
            else:
                # assign torsion with interpolation
                self._assign_fractional_bond_orders(
                    atom_indices, torsion_match, force, **kwargs
                )

        logger.info("{} torsions added".format(len(torsion_matches)))

        # Check that no topological torsions are missing force parameters

        # I can see the appeal of these kind of methods as an 'absolute' check
        # that things have gone well, but I think just making sure that the
        # reference molecule has been fully parametrised should have the same
        # effect! It would be good to eventually refactor things so that everything
        # is focused on the single unique molecules, and then simply just cloned
        # onto the system. It seems like John's proposed System object would do
        # exactly this.
        self._check_all_valence_terms_assigned(
            assigned_terms=torsion_matches,
            valence_terms=list(topology.propers),
            exception_cls=UnassignedProperTorsionParameterException,
        )

    def _assign_torsion(self, atom_indices, torsion_match, force):
        """Add one torsion per (periodicity, phase, k, idivf) term, without interpolation.

        Raises
        ------
        NotImplementedError
            If any ``idivf`` value is ``"auto"``.
        """
        torsion_params = torsion_match.parameter_type

        for (periodicity, phase, k, idivf) in zip(
            torsion_params.periodicity,
            torsion_params.phase,
            torsion_params.k,
            torsion_params.idivf,
        ):

            if idivf == "auto":
                # TODO: Implement correct "auto" behavior
                raise NotImplementedError(
                    "The OpenForceField toolkit hasn't implemented "
                    "support for the torsion `idivf` value of 'auto'"
                )

            force.addTorsion(
                atom_indices[0],
                atom_indices[1],
                atom_indices[2],
                atom_indices[3],
                periodicity,
                phase,
                k / idivf,
            )

    def _assign_fractional_bond_orders(
        self, atom_indices, torsion_match, force, **kwargs
    ):
        """Add one torsion per term, with ``k`` interpolated from the central bond order.

        The fractional bond order of the central bond (atoms 1 and 2 of the match)
        is computed on the reference molecule if it is not already set.

        Raises
        ------
        ValueError
            If a ``k_bondorder`` mapping has fewer than two entries.
        NotImplementedError
            If any ``idivf`` value is ``"auto"``.
        Exception
            If ``fractional_bondorder_interpolation`` is not ``"linear"``.
        """
        from openforcefield.utils.toolkits import GLOBAL_TOOLKIT_REGISTRY

        torsion_params = torsion_match.parameter_type
        match = torsion_match.environment_match

        for (periodicity, phase, k_bondorder, idivf) in zip(
            torsion_params.periodicity,
            torsion_params.phase,
            torsion_params.k_bondorder,
            torsion_params.idivf,
        ):

            if len(k_bondorder) < 2:
                raise ValueError(
                    "At least 2 bond order values required for `k_bondorder`; "
                    "got {}".format(len(k_bondorder))
                )

            if idivf == "auto":
                # TODO: Implement correct "auto" behavior
                raise NotImplementedError(
                    "The OpenForceField toolkit hasn't implemented "
                    "support for the torsion `idivf` value of 'auto'"
                )

            # get central bond for reference molecule
            central_bond = match.reference_molecule.get_bond_between(
                match.reference_atom_indices[1], match.reference_atom_indices[2]
            )

            # if fractional bond order not calculated yet, we calculate it
            # should only happen once per reference molecule for which we care
            # about fractional bond interpolation
            # and not at all for reference molecules we don't
            if central_bond.fractional_bond_order is None:
                toolkit_registry = kwargs.get(
                    "toolkit_registry", GLOBAL_TOOLKIT_REGISTRY
                )
                match.reference_molecule.assign_fractional_bond_orders(
                    toolkit_registry=toolkit_registry,
                    use_conformers=match.reference_molecule.conformers,
                )

            # scale k based on the bondorder of the central bond
            if self.fractional_bondorder_interpolation == "linear":
                # we only interpolate on k
                k = self._linear_interpolate_k(
                    k_bondorder, central_bond.fractional_bond_order
                )
            else:
                # Report the attribute that was actually checked (the interpolation
                # scheme), not the bond order calculation method.
                raise Exception(
                    "Fractional bondorder interpolation method {} is not implemented.".format(
                        self.fractional_bondorder_interpolation
                    )
                )

            # add a torsion with given parameters for topology atoms
            force.addTorsion(
                atom_indices[0],
                atom_indices[1],
                atom_indices[2],
                atom_indices[3],
                periodicity,
                phase,
                k / idivf,
            )

    @staticmethod
    def _linear_interpolate_k(k_bondorder, fractional_bond_order):
        """Linearly interpolate (or extrapolate) ``k`` at a fractional bond order.

        Parameters
        ----------
        k_bondorder : mapping
            Maps bond orders to ``k`` values.
        fractional_bond_order : number
            Bond order at which ``k`` is wanted.

        Returns
        -------
        The ``k`` stored for ``fractional_bond_order`` if it is a key; otherwise the
        value on the line through the two nearest bracketing keys, or through the two
        lowest / two highest keys when the bond order lies outside the key range.

        Raises
        ------
        NotImplementedError
            If no key lies above or below ``fractional_bond_order``.
        """
        # pre-empt case where no interpolation is necessary
        if fractional_bond_order in k_bondorder:
            return k_bondorder[fractional_bond_order]

        # TODO: error out for nonsensical fractional bond orders

        # nearest defined bond orders on either side of our fractional value
        below = max(
            (bo for bo in k_bondorder if bo < fractional_bond_order), default=None
        )
        above = min(
            (bo for bo in k_bondorder if bo > fractional_bond_order), default=None
        )

        # handle case where we can clearly interpolate
        if (above is not None) and (below is not None):
            return k_bondorder[below] + (k_bondorder[above] - k_bondorder[below]) * (
                (fractional_bond_order - below) / (above - below)
            )

        # error if we can't hope to interpolate at all
        if (above is None) and (below is None):
            raise NotImplementedError(
                f"Failed to find interpolation references for "
                f"`fractional bond order` '{fractional_bond_order}', "
                f"with `k_bond_order` '{k_bondorder}'"
            )

        bond_orders = sorted(k_bondorder)

        # extrapolate for fractional bond orders below our lowest defined bond order
        if below is None:
            return k_bondorder[bond_orders[0]] - (
                (k_bondorder[bond_orders[1]] - k_bondorder[bond_orders[0]])
                / (bond_orders[1] - bond_orders[0])
            ) * (bond_orders[0] - fractional_bond_order)

        # extrapolate for fractional bond orders above our highest defined bond order
        return k_bondorder[bond_orders[-1]] + (
            (k_bondorder[bond_orders[-1]] - k_bondorder[bond_orders[-2]])
            / (bond_orders[-1] - bond_orders[-2])
        ) * (fractional_bond_order - bond_orders[-1])
2892

2893

2894
# TODO: There's a lot of duplicated code in ProperTorsionHandler and ImproperTorsionHandler
2895 8
class ImproperTorsionHandler(ParameterHandler):
    """Handle SMIRNOFF ``<ImproperTorsions>`` tags

    .. warning :: This API is experimental and subject to change.
    """

    class ImproperTorsionType(ParameterType):
        """A SMIRNOFF torsion type for improper torsions.

        .. warning :: This API is experimental and subject to change.
        """

        _VALENCE_TYPE = "ImproperTorsion"
        _ELEMENT_NAME = "Improper"

        periodicity = IndexedParameterAttribute(converter=int)
        phase = IndexedParameterAttribute(unit=unit.degree)
        k = IndexedParameterAttribute(unit=unit.kilocalorie_per_mole)
        idivf = IndexedParameterAttribute(default=None, converter=float)

    _TAGNAME = "ImproperTorsions"  # SMIRNOFF tag name to process
    _INFOTYPE = ImproperTorsionType  # info type to store
    _OPENMMTYPE = openmm.PeriodicTorsionForce  # OpenMM force class to create

    potential = ParameterAttribute(
        default="k*(1+cos(periodicity*theta-phase))",
        converter=_allow_only(["k*(1+cos(periodicity*theta-phase))"]),
    )
    default_idivf = ParameterAttribute(default="auto")

    def check_handler_compatibility(self, other_handler):
        """
        Checks whether this ParameterHandler encodes compatible physics as another ParameterHandler. This is
        called if a second handler is attempted to be initialized for the same tag.

        ``default_idivf`` is compared as a string when it is ``"auto"`` and with a
        numerical tolerance otherwise.

        Parameters
        ----------
        other_handler : a ParameterHandler object
            The handler to compare to.

        Raises
        ------
        IncompatibleParameterError if handler_kwargs are incompatible with existing parameters.
        """
        float_attrs_to_compare = []
        string_attrs_to_compare = ["potential"]

        if self.default_idivf == "auto":
            string_attrs_to_compare.append("default_idivf")
        else:
            float_attrs_to_compare.append("default_idivf")

        self._check_attributes_are_equal(
            other_handler,
            identical_attrs=string_attrs_to_compare,
            tolerance_attrs=float_attrs_to_compare,
        )

    def find_matches(self, entity):
        """Find the improper torsions in the topology/molecule matched by a parameter type.

        Parameters
        ----------
        entity : openforcefield.topology.Topology
            Topology to search.

        Returns
        -------
        matches : ImproperDict[Tuple[int], ParameterHandler._Match]
            ``matches[atom_indices]`` is the ``ParameterType`` object
            matching the 4-tuple of atom indices in ``entity``.

        """
        return self._find_matches(entity, transformed_dict_cls=ImproperDict)

    def create_force(self, system, topology, **kwargs):
        """Add periodic torsion terms for every matched improper torsion.

        Each term of a matched improper is added three times, once per cyclic
        permutation of the three non-central atoms, with ``k`` divided by ``idivf``.

        Parameters
        ----------
        system : simtk.openmm.System
            System that receives (or already holds) the torsion force.
        topology : openforcefield.topology.Topology
            Topology whose impropers are matched against this handler.
        **kwargs
            Accepted for interface compatibility; not used here.
        """
        # force = super(ImproperTorsionHandler, self).create_force(system, topology, **kwargs)
        # force = super().create_force(system, topology, **kwargs)
        existing = [system.getForce(i) for i in range(system.getNumForces())]
        existing = [f for f in existing if type(f) == self._OPENMMTYPE]
        if len(existing) == 0:
            force = self._OPENMMTYPE()
            system.addForce(force)
        else:
            force = existing[0]

        # Add all improper torsions to the system
        improper_matches = self.find_matches(topology)
        for (atom_indices, improper_match) in improper_matches.items():
            # Ensure atoms are actually bonded correct pattern in Topology
            # For impropers, central atom is atom 1
            self._assert_correct_connectivity(improper_match, [(0, 1), (1, 2), (1, 3)])

            improper = improper_match.parameter_type

            # TODO: This is a lazy hack. idivf should be set according to the ParameterHandler's default_idivf attrib
            # The fallback is kept local so that building a system does not rewrite
            # the force field's own parameter objects.
            idivfs = improper.idivf
            if idivfs is None:
                idivfs = [3 for _ in improper.k]

            # Impropers are applied in three paths around the trefoil having the same handedness
            for (
                improper_periodicity,
                improper_phase,
                improper_k,
                improper_idivf,
            ) in zip(improper.periodicity, improper.phase, improper.k, idivfs):
                # TODO: Implement correct "auto" behavior
                if improper_idivf == "auto":
                    improper_idivf = 3
                    logger.warning(
                        "The OpenForceField toolkit hasn't implemented "
                        "support for the torsion `idivf` value of 'auto'. "
                        "Currently assuming a value of '3' for impropers."
                    )

                # Permute non-central atoms
                others = [atom_indices[0], atom_indices[2], atom_indices[3]]
                # ((0, 1, 2), (1, 2, 0), and (2, 0, 1)) are the three paths around the trefoil
                for (a, b, c) in [(0, 1, 2), (1, 2, 0), (2, 0, 1)]:
                    # The torsion force gets added three times, since the k is divided by three
                    force.addTorsion(
                        atom_indices[1],
                        others[a],
                        others[b],
                        others[c],
                        improper_periodicity,
                        improper_phase,
                        improper_k / improper_idivf,
                    )

        logger.info(
            "{} impropers added, each applied in a three-fold trefoil".format(
                len(improper_matches)
            )
        )
3032

3033

3034 8
class _NonbondedHandler(ParameterHandler):
    """Common base for ParameterHandlers that operate on an OpenMM NonbondedForce."""

    _OPENMMTYPE = openmm.NonbondedForce

    def create_force(self, system, topology, **kwargs):
        """
        Return the system's force of type ``_OPENMMTYPE``, creating it if none exists yet.

        As a side effect, the charge-assignment bookkeeping dict is attached to
        ``topology`` the first time this is called for that topology.

        Parameters
        ----------
        system : simtk.openmm.System
            The system whose forces are searched and, if needed, extended.
        topology : openforcefield.topology.Topology
            The topology supplying the reference molecules and particles.
        **kwargs
            Accepted but not used here.

        Returns
        -------
        force
            The first existing force whose type is exactly ``_OPENMMTYPE``, or a newly
            created one holding one placeholder particle per topology particle.

        """
        # TODO: This bookkeeping should be an attribute of the _system_, not the _topology_. However, since
        #  we're still using OpenMM's System class, it lives on the OFF Topology until there is an OFF System class.
        # Start with every reference molecule mapped to None (no charge method recorded yet).
        if not hasattr(topology, "_ref_mol_to_charge_method"):
            topology._ref_mol_to_charge_method = dict.fromkeys(
                topology.reference_molecules
            )

        # Look for a force of exactly the expected type already present in the system.
        matching_forces = [
            candidate
            for candidate in (
                system.getForce(index) for index in range(system.getNumForces())
            )
            if type(candidate) == self._OPENMMTYPE
        ]
        if matching_forces:
            return matching_forces[0]

        # None found: register a fresh force and give it one particle per topology particle.
        force = self._OPENMMTYPE()
        system.addForce(force)
        for _ in topology.topology_particles:
            force.addParticle(0.0, 1.0, 0.0)
        return force

    def mark_charges_assigned(self, ref_mol, topology):
        """
        Note on the topology that this handler class assigned charges to a reference molecule.

        Parameters
        ----------
        ref_mol : openforcefield.topology.Molecule
            The molecule to mark as having charges assigned
        topology : openforcefield.topology.Topology
            The topology to record this information on.

        """
        # TODO: Change this to interface with system object instead of topology once we move away from OMM's System
        assigning_handler = self.__class__
        charge_registry = topology._ref_mol_to_charge_method
        charge_registry[ref_mol] = assigning_handler

    @staticmethod
    def check_charges_assigned(ref_mol, topology):
        """
        Report whether a charge method has been recorded for a reference molecule.

        Parameters
        ----------
        ref_mol : openforcefield.topology.Molecule
            The molecule to check for having charges assigned
        topology : openforcefield.topology.Topology
            The topology to query for this information

        Returns
        -------
        charges_assigned : bool
            Whether charges have already been assigned to this molecule

        """
        # TODO: Change this to interface with system object instead of topology once we move away from OMM's System
        recorded_method = topology._ref_mol_to_charge_method[ref_mol]
        return recorded_method is not None
3100

3101

3102 8
class vdWHandler(_NonbondedHandler):
3103
    """Handle SMIRNOFF ``<vdW>`` tags
3104

3105
    .. warning :: This API is experimental and subject to change.
3106
    """
3107

3108 8
    class vdWType(ParameterType):
        """A single SMIRNOFF vdW parameter type.

        Exactly one of ``sigma`` and ``rmin_half`` must be supplied.

        .. warning :: This API is experimental and subject to change.
        """

        _VALENCE_TYPE = "Atom"  # ChemicalEnvironment valence type expected for SMARTS
        _ELEMENT_NAME = "Atom"

        epsilon = ParameterAttribute(unit=unit.kilocalorie_per_mole)
        sigma = ParameterAttribute(default=None, unit=unit.angstrom)
        rmin_half = ParameterAttribute(default=None, unit=unit.angstrom)

        def __init__(self, **kwargs):
            """Check the sigma / rmin_half combination, then defer to the parent initializer.

            Raises
            ------
            SMIRNOFFSpecError
                If neither or both of ``sigma`` and ``rmin_half`` are given
                (a value of ``None`` counts as not given).
            """
            has_sigma = kwargs.get("sigma") is not None
            has_rmin_half = kwargs.get("rmin_half") is not None

            if not (has_sigma or has_rmin_half):
                raise SMIRNOFFSpecError("Either sigma or rmin_half must be specified.")
            if has_sigma and has_rmin_half:
                raise SMIRNOFFSpecError(
                    "BOTH sigma and rmin_half cannot be specified simultaneously."
                )

            super().__init__(**kwargs)
3132

3133 8
    # SMIRNOFF tag processed by this handler, and the parameter type stored per entry.
    _TAGNAME = "vdW"
    _INFOTYPE = vdWType
    # _KWARGS = ['ewaldErrorTolerance',
    #            'useDispersionCorrection',
    #            'usePbc'] # Kwargs to catch when create_force is called

    # Functional form and combining rules: a single value of each is accepted.
    potential = ParameterAttribute(
        default="Lennard-Jones-12-6",
        converter=_allow_only(["Lennard-Jones-12-6"]),
    )
    combining_rules = ParameterAttribute(
        default="Lorentz-Berthelot",
        converter=_allow_only(["Lorentz-Berthelot"]),
    )

    # Scaling factors for 1-2, 1-3, 1-4 and 1-5 interactions.
    scale12 = ParameterAttribute(default=0.0, converter=float)
    scale13 = ParameterAttribute(default=0.0, converter=float)
    scale14 = ParameterAttribute(default=0.5, converter=float)
    scale15 = ParameterAttribute(default=1.0, converter=float)

    # Cutoff distance, switching width and long-range treatment.
    cutoff = ParameterAttribute(default=9.0 * unit.angstroms, unit=unit.angstrom)
    switch_width = ParameterAttribute(default=1.0 * unit.angstroms, unit=unit.angstrom)
    method = ParameterAttribute(
        default="cutoff",
        converter=_allow_only(["cutoff", "PME"]),
    )
3156

3157
    # TODO: Use _allow_only when ParameterAttribute will support multiple converters (it'll be easy when we switch to use the attrs library)
3158 8
    @scale12.converter
    def scale12(self, attrs, new_scale12):
        """Accept a proposed 1-2 scaling factor only if it equals 0.0.

        Parameters
        ----------
        attrs
            Unused here. NOTE(review): presumably the ParameterAttribute
            descriptor supplied by the converter machinery -- confirm.
        new_scale12
            The proposed 1-2 scaling value.

        Returns
        -------
        new_scale12
            The proposed value, unchanged.

        Raises
        ------
        SMIRNOFFSpecError
            If ``new_scale12`` is not equal to 0.0.
        """
        if new_scale12 != 0.0:
            # Report the rejected value itself; the stored ``self._scale12`` is the
            # previous value and may not have been set yet.
            raise SMIRNOFFSpecError(
                "Current OFF toolkit is unable to handle scale12 values other than 0.0. "
                "Specified 1-2 scaling was {}".format(new_scale12)
            )
        return new_scale12
3166

3167 8
    @scale13.converter
    def scale13(self, attrs, new_scale13):
        """Accept a proposed 1-3 scaling factor only if it equals 0.0.

        Parameters
        ----------
        attrs
            Unused here. NOTE(review): presumably the ParameterAttribute
            descriptor supplied by the converter machinery -- confirm.
        new_scale13
            The proposed 1-3 scaling value.

        Returns
        -------
        new_scale13
            The proposed value, unchanged.

        Raises
        ------
        SMIRNOFFSpecError
            If ``new_scale13`` is not equal to 0.0.
        """
        if new_scale13 != 0.0:
            # Report the rejected value itself; the stored ``self._scale13`` is the
            # previous value and may not have been set yet.
            raise SMIRNOFFSpecError(
                "Current OFF toolkit is unable to handle scale13 values other than 0.0. "
                "Specified 1-3 scaling was {}".format(new_scale13)
            )
        return new_scale13
3175

3176 8
    @scale15.converter
    def scale15(self, attrs, new_scale15):
        """Accept a proposed 1-5 scaling factor only if it equals 1.0.

        Parameters
        ----------
        attrs
            Unused here. NOTE(review): presumably the ParameterAttribute
            descriptor supplied by the converter machinery -- confirm.
        new_scale15
            The proposed 1-5 scaling value.

        Returns
        -------
        new_scale15
            The proposed value, unchanged.

        Raises
        ------
        SMIRNOFFSpecError
            If ``new_scale15`` is not equal to 1.0.
        """
        if new_scale15 != 1.0:
            # Report the rejected value itself; the stored ``self._scale15`` is the
            # previous value and may not have been set yet.
            raise SMIRNOFFSpecError(
                "Current OFF toolkit is unable to handle scale15 values other than 1.0. "
                "Specified 1-5 scaling was {}".format(new_scale15)
            )
        return new_scale15
3184

3185
    # Tolerance applied when float-valued attributes of two handlers are
    # compared for compatibility.
    _SCALETOL = 1e-5
3187

3188 8
    def check_handler_compatibility(self, other_handler):
3189
        """