1
#!/usr/bin/env python
2 8
"""
3
Utility subroutines
4

5
"""
6

7 8
__all__ = [
    # Exceptions
    "MessageException",
    "IncompatibleUnitError",
    # Generic helpers
    "inherit_docstrings",
    "all_subclasses",
    "temporary_cd",
    "get_data_file_path",
    # Unit / quantity (de)serialization and validation
    "unit_to_string",
    "quantity_to_string",
    "string_to_unit",
    "string_to_quantity",
    "object_to_quantity",
    "check_units_are_compatible",
    "extract_serialized_units_from_dict",
    "attach_units",
    "detach_units",
    # NumPy array (de)serialization
    "serialize_numpy",
    "deserialize_numpy",
    # Recursive conversion of whole SMIRNOFF data structures
    "convert_all_quantities_to_string",
    "convert_all_strings_to_quantity",
    # SMIRNOFF spec version converters
    "convert_0_1_smirnoff_to_0_2",
    "convert_0_2_smirnoff_to_0_3",
    # Miscellaneous
    "get_molecule_parameterIDs",
]
31

32
# =============================================================================================
33
# GLOBAL IMPORTS
34
# =============================================================================================
35

36 8
import contextlib
37 8
import functools
38 8
import logging
39

40 8
from simtk import unit
41

42
# =============================================================================================
43
# CONFIGURE LOGGER
44
# =============================================================================================
45

46 8
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
47

48

49
# =============================================================================================
50
# COMMON EXCEPTION TYPES
51
# =============================================================================================
52

53

54 8
class MessageException(Exception):
    """A base class for exceptions that print out a string given in their constructor.

    Parameters
    ----------
    msg : str
        The message reported by ``str(exception)``. It is stored on the
        ``msg`` attribute; reassigning that attribute changes what ``str``
        returns.
    """

    def __init__(self, msg):
        # Pass only ``msg`` to Exception so that ``args == (msg,)``. Passing
        # ``self`` as well made the exception reference itself through
        # ``args``, which garbles ``repr`` and breaks pickling round-trips.
        super().__init__(msg)
        self.msg = msg

    def __str__(self):
        return self.msg
63

64

65 8
class IncompatibleUnitError(MessageException):
    """
    Raised when a parameter's unit does not fit the unit system a
    ParameterHandler expects.

    Like every MessageException, ``str()`` of this error is the message
    passed to its constructor.
    """
71

72

73
# =============================================================================================
74
# UTILITY SUBROUTINES
75
# =============================================================================================
76

77

78 8
def inherit_docstrings(cls):
    """Inherit docstrings from parent class.

    For every function on ``cls`` that has no docstring, copy the docstring
    of the same-named attribute from the nearest class in ``cls.__mro__``
    (excluding ``cls`` itself) that provides a non-empty one.

    Parameters
    ----------
    cls : type
        The class whose undocumented functions should receive docstrings.

    Returns
    -------
    cls : type
        The same class object, so this can be used as a class decorator.
    """
    from inspect import getmembers, isfunction

    for name, func in getmembers(cls, isfunction):
        if func.__doc__:
            continue
        for parent in cls.__mro__[1:]:
            if not hasattr(parent, name):
                continue
            parent_doc = getattr(parent, name).__doc__
            if parent_doc:
                # Stop at the nearest documented ancestor. Continuing the scan
                # would let more distant ancestors (down to ``object``)
                # overwrite the closest, most specific docstring.
                func.__doc__ = parent_doc
                break
    return cls
89

90

91 8
def all_subclasses(cls):
    """Return every direct and indirect subclass of ``cls``.

    The direct subclasses come first (in ``__subclasses__()`` order). They are
    followed by the recursively collected descendants of each direct subclass,
    in turn.
    """
    direct = cls.__subclasses__()
    descendants = []
    for child in direct:
        descendants.extend(all_subclasses(child))
    return direct + descendants
96

97

98 8
@contextlib.contextmanager
def temporary_cd(dir_path):
    """Temporarily change the current working directory.

    On entry, the working directory is switched to the absolute form of
    ``dir_path``. On exit it is restored, including when the body raises.

    Parameters
    ----------
    dir_path : str
        The directory to use as the working directory inside the context.

    Examples
    --------
    >>> dir_path = '/tmp'
    >>> with temporary_cd(dir_path):
    ...     pass  # do something in dir_path
    """
    import os

    original_cwd = os.getcwd()
    target = os.path.abspath(dir_path)
    os.chdir(target)
    try:
        yield
    finally:
        os.chdir(original_cwd)
119

120

121 8
def get_data_file_path(relative_path):
    """Get the full path to one of the reference data files shipped with the package.

    In the source distribution, these files are in ``openforcefield/data/``,
    but on installation, they're moved to somewhere in the user's python
    site-packages directory.

    Parameters
    ----------
    relative_path : str
        Path of the file to look up, relative to the package's ``data/``
        directory.

    Returns
    -------
    fn : str
        The filesystem path of the requested file, as resolved by
        ``pkg_resources.resource_filename``.

    Raises
    ------
    ValueError
        If no file exists at the resolved path.
    """

    import os

    from pkg_resources import resource_filename

    fn = resource_filename("openforcefield", os.path.join("data", relative_path))

    if not os.path.exists(fn):
        raise ValueError(
            f"Sorry! {fn} does not exist. If you just added it, you'll have to re-install"
        )

    return fn
144

145

146 8
def unit_to_string(input_unit):
    """
    Serialize a simtk.unit.Unit and return it as a string.

    The result is a ``" * "``-joined product of ``name`` or ``name**exponent``
    terms, one per base or scaled unit component. Spaces in unit names are
    replaced by underscores so that each name is a valid identifier for
    ``string_to_unit``.

    Parameters
    ----------
    input_unit : A simtk.unit
        The unit to serialize

    Returns
    -------
    unit_string : str or None
        The serialized unit. This is ``"dimensionless"`` for the dimensionless
        unit, or ``None`` if the unit has no base or scaled components.
    """

    if input_unit == unit.dimensionless:
        return "dimensionless"

    # Decompose input_unit into (base_or_scaled_unit, exponent) pairs
    contributions = []
    for unit_component, exponent in input_unit.iter_base_or_scaled_units():
        # Convert, for example "elementary charge" --> "elementary_charge"
        unit_component_name = unit_component.name.replace(" ", "_")
        if exponent == 1:
            contributions.append(unit_component_name)
            continue
        # Integral exponents (including floats such as 2.0) are written as
        # ints. Non-integral ones are written as floats, because a blanket
        # int() would silently truncate e.g. 0.5 to 0.
        if float(exponent).is_integer():
            exponent_string = str(int(exponent))
        else:
            exponent_string = repr(float(exponent))
        contributions.append(f"{unit_component_name}**{exponent_string}")

    if not contributions:
        return None
    return " * ".join(contributions)
181

182

183 8
def quantity_to_string(input_quantity):
    """
    Serialize a simtk.unit.Quantity to a string.

    The result has the form ``"<value> * <unit string>"``, where the unit
    string comes from ``unit_to_string``. NumPy array values are written as
    (possibly nested) Python lists.

    Parameters
    ----------
    input_quantity : simtk.unit.Quantity or None
        The quantity to serialize

    Returns
    -------
    output_string : str or None
        The serialized quantity, or ``None`` if ``input_quantity`` is None.
    """
    import numpy as np

    if input_quantity is None:
        return None
    unitless_value = input_quantity.value_in_unit(input_quantity.unit)
    # The string representation of a numpy array doesn't have commas and breaks
    # the parser, so convert arrays to Python lists. ``tolist()`` recurses into
    # multi-dimensional arrays (``list()`` only unwraps the first axis). It
    # also turns NumPy scalars into Python numbers, whose repr parses cleanly
    # (NumPy 2 reprs scalars as e.g. ``np.float64(1.5)``).
    if isinstance(unitless_value, np.ndarray):
        unitless_value = unitless_value.tolist()
    unit_string = unit_to_string(input_quantity.unit)
    output_string = "{} * {}".format(unitless_value, unit_string)
    return output_string
210

211

212 8
def _ast_eval(node):
    """
    Performs an algebraic syntax tree evaluation of a unit.

    Supported nodes are:

    - numeric literals
    - the binary operators ``+ - * / ** ^``
    - unary ``+`` and ``-``
    - bare names, looked up as attributes of ``simtk.unit``
    - list literals, evaluated with ``ast.literal_eval``

    Parameters
    ----------
    node : An ast parsing tree node

    Returns
    -------
    The evaluated value: a number, a simtk unit/quantity expression, or a list.

    Raises
    ------
    TypeError
        If the node, or an operator inside it, is not supported.
    AttributeError
        If a name is not an attribute of ``simtk.unit``.
    """
    import ast
    import operator as op

    operators = {
        ast.Add: op.add,
        ast.Sub: op.sub,
        ast.Mult: op.mul,
        ast.Div: op.truediv,
        ast.Pow: op.pow,
        ast.BitXor: op.xor,
        ast.USub: op.neg,
        ast.UAdd: op.pos,
    }

    def _lookup_operator(op_node):
        # Report unsupported operators as TypeError (like other unsupported
        # nodes) rather than leaking a KeyError from the dict lookup.
        try:
            return operators[type(op_node)]
        except KeyError:
            raise TypeError(op_node) from None

    if isinstance(node, ast.Constant):  # <number> (Python >= 3.8)
        value = node.value
        # bool is an int subclass, but True/False are not accepted as numbers.
        if isinstance(value, (int, float, complex)) and not isinstance(value, bool):
            return value
        raise TypeError(node)
    elif type(node).__name__ == "Num":  # <number> as parsed by Python < 3.8
        return node.n
    elif isinstance(node, ast.BinOp):  # <left> <operator> <right>
        operator = _lookup_operator(node.op)
        return operator(_ast_eval(node.left), _ast_eval(node.right))
    elif isinstance(node, ast.UnaryOp):  # <operator> <operand> e.g., -1
        operator = _lookup_operator(node.op)
        return operator(_ast_eval(node.operand))
    elif isinstance(node, ast.Name):
        # see if this is a simtk unit
        return getattr(unit, node.id)
    # TODO: This was a quick hack that surprisingly worked. We should validate this further.
    elif isinstance(node, ast.List):
        return ast.literal_eval(node)
    else:
        raise TypeError(node)
248

249

250 8
def string_to_unit(unit_string):
    """
    Deserializes a simtk.unit.Unit from a string representation, for
    example: "kilocalories_per_mole / angstrom ** 2"

    The string is parsed as a Python expression and evaluated by ``_ast_eval``.
    Names are resolved as attributes of ``simtk.unit``, so the strings
    produced by ``unit_to_string`` can be read back.

    Parameters
    ----------
    unit_string : str
        Serialized representation of a simtk.unit.Unit.

    Returns
    -------
    output_unit : simtk.unit.Unit
        The deserialized unit from the string

    Raises
    ------
    SyntaxError
        If ``unit_string`` is not a valid Python expression.
    AttributeError
        If a name in the expression is not an attribute of ``simtk.unit``.
    TypeError
        If the expression contains unsupported syntax.
    """
    import ast

    output_unit = _ast_eval(ast.parse(unit_string, mode="eval").body)
    return output_unit
283

284

285 8
def string_to_quantity(quantity_string):
    """
    Parse the string form of a quantity back into a simtk.unit.Quantity.

    The string is treated as a Python expression, e.g. ``"1.5 * angstrom"``,
    and evaluated exactly as ``string_to_unit`` does.

    Parameters
    ----------
    quantity_string : str or None
        The serialized quantity

    Returns
    -------
    output_quantity : simtk.unit.Quantity or None
        The evaluated expression, or ``None`` when ``quantity_string`` is None.
    """
    import ast

    if quantity_string is None:
        return None
    expression = ast.parse(quantity_string, mode="eval")
    return _ast_eval(expression.body)
306

307

308 8
def convert_all_strings_to_quantity(smirnoff_data):
    """
    Traverses a SMIRNOFF data structure, attempting to convert all
    quantity-defining strings into simtk.unit.Quantity objects.

    Integers and floats are ignored and not converted into a dimensionless
    ``simtk.unit.Quantity`` object. Dicts and lists are updated in place.
    Any value that ``object_to_quantity`` cannot convert is returned
    unchanged. Examples are SMIRKS patterns, names such as "harmonic", and
    expressions with unsupported operators or arithmetic errors.

    Parameters
    ----------
    smirnoff_data : dict
        A hierarchical dict structured in compliance with the SMIRNOFF spec

    Returns
    -------
    converted_smirnoff_data : dict
        A hierarchical dict structured in compliance with the SMIRNOFF spec,
        with quantity-defining strings converted to simtk.unit.Quantity objects
    """
    # Errors that mean "this value is not a parseable quantity". The value is
    # then left as-is. ValueError comes from ast.literal_eval on malformed
    # lists (and ast.parse on strings with null bytes on older Pythons).
    # KeyError comes from unsupported operators; ArithmeticError covers e.g.
    # "1/0".
    unconvertible_errors = (
        AttributeError,
        TypeError,
        SyntaxError,
        ValueError,
        KeyError,
        ArithmeticError,
    )

    if isinstance(smirnoff_data, dict):
        for key, value in smirnoff_data.items():
            smirnoff_data[key] = convert_all_strings_to_quantity(value)
        obj_to_return = smirnoff_data

    elif isinstance(smirnoff_data, list):
        for index, item in enumerate(smirnoff_data):
            smirnoff_data[index] = convert_all_strings_to_quantity(item)
        obj_to_return = smirnoff_data

    elif isinstance(smirnoff_data, (int, float)):
        obj_to_return = smirnoff_data

    else:
        try:
            obj_to_return = object_to_quantity(smirnoff_data)
        except unconvertible_errors:
            obj_to_return = smirnoff_data

    return obj_to_return
347

348

349 8
def convert_all_quantities_to_string(smirnoff_data):
    """
    Walk a SMIRNOFF data structure and replace every simtk.unit.Quantity
    with its string serialization.

    Dicts and lists are rewritten in place and returned. Quantities are
    serialized with ``quantity_to_string``. Any other value is returned as-is.

    Parameters
    ----------
    smirnoff_data : dict
        A hierarchical dict structured in compliance with the SMIRNOFF spec

    Returns
    -------
    converted_smirnoff_data : dict
        The same structure with every simtk.unit.Quantity replaced by a string
    """
    if isinstance(smirnoff_data, dict):
        for key in smirnoff_data:
            smirnoff_data[key] = convert_all_quantities_to_string(smirnoff_data[key])
        return smirnoff_data

    if isinstance(smirnoff_data, list):
        for position, element in enumerate(smirnoff_data):
            smirnoff_data[position] = convert_all_quantities_to_string(element)
        return smirnoff_data

    if isinstance(smirnoff_data, unit.Quantity):
        return quantity_to_string(smirnoff_data)

    return smirnoff_data
380

381

382 8
@functools.singledispatch
def object_to_quantity(obj):
    """
    Attempts to turn the provided object into simtk.unit.Quantity(s).

    Can handle float, int, strings, quantities, or iterables over
    the same. Raises an exception if unable to convert all inputs.

    The conversion depends on the type of ``obj``:

    - Quantity objects are returned unchanged.
    - Strings are parsed with ``string_to_quantity``. A string without units
      may therefore yield a plain number.
    - ints and floats become dimensionless quantities.
    - Any other object is iterated, and each element is converted
      recursively.

    Parameters
    ----------
    obj : int, float, string, quantity, or iterable of these
        The object to convert to a ``simtk.unit.Quantity`` object.

    Returns
    -------
    converted_object : simtk.unit.Quantity or List[simtk.unit.Quantity]
    """
    # No more specific implementation is registered for this type, so treat
    # it as a generic iterable. Non-iterables raise TypeError here.
    return [object_to_quantity(sub_obj) for sub_obj in obj]


@object_to_quantity.register(unit.Quantity)
def _(obj):
    # Already a quantity: nothing to convert.
    return obj


@object_to_quantity.register(str)
def _(obj):
    return string_to_quantity(obj)


@object_to_quantity.register(int)
@object_to_quantity.register(float)
def _(obj):
    # Bare numbers become dimensionless quantities.
    return unit.Quantity(obj)
418

419

420 8
def check_units_are_compatible(object_name, object, unit_to_check, context=None):
    """
    Checks whether a simtk.unit.Quantity or list of simtk.unit.Quantitys is compatible with given unit.

    Lists are checked element by element, recursively. Any value that is
    neither a list nor a Quantity is reported as incompatible.

    Parameters
    ----------
    object_name : string
        Name of object, used in printing exception.
    object : A simtk.unit.Quantity or list of simtk.unit.Quantitys
        The value to check. (The parameter name shadows the builtin, but it
        is kept for backward compatibility with keyword callers.)
    unit_to_check : A simtk.unit.Unit
        The unit the value must be compatible with.
    context : string, optional. Default=None
        Additional information to provide at the beginning of the exception message if raised

    Raises
    ------
    IncompatibleUnitError
        If the value (or any list element) is not a Quantity compatible with
        ``unit_to_check``.
    """
    if isinstance(object, list):
        # Pass the caller's original ``context`` down. Passing a space-padded
        # copy would add one more space per nesting level.
        for sub_object in object:
            check_units_are_compatible(
                object_name, sub_object, unit_to_check, context=context
            )
        return

    if isinstance(object, unit.Quantity) and object.unit.is_compatible(unit_to_check):
        return

    # Either not a Quantity at all, or a Quantity with incompatible units.
    # A space separates the (optional) context from the rest of the message.
    prefix = "" if context is None else f"{context} "
    msg = (
        f"{prefix}{object_name} with "
        f"value {object} is incompatible with expected unit {unit_to_check}"
    )
    raise IncompatibleUnitError(msg)
464

465

466 8
def extract_serialized_units_from_dict(input_dict):
    """
    Create a mapping of (potentially unit-bearing) quantities from a dictionary, where some keys exist in pairs like
    {'length': 8, 'length_unit':'angstrom'}.

    Parameters
    ----------
    input_dict : dict
       Dictionary where some keys are paired like {'X': 1.0, 'X_unit': angstrom}.
       It is not modified.

    Returns
    -------
    unitless_dict : dict
       A copy of input_dict, with keys ending in ``_unit`` removed.
    attached_units : dict str : simtk.unit.Unit
       ``attached_units[parameter_name]`` is the simtk.unit.Unit combination that should be attached to corresponding
       parameter ``parameter_name``. For example ``attached_units['X'] = simtk.unit.angstrom``.

    Raises
    ------
    Exception
        Whatever ``string_to_unit`` raises for an unparseable unit string.
        The exception's ``msg`` attribute is prefixed with the offending
        unit string before it is re-raised.
    """

    # TODO: Should this scheme also convert "1" to int(1) and "8.0" to float(8.0)?
    from collections import OrderedDict

    attached_units = OrderedDict()
    unitless_dict = input_dict.copy()
    for key, parameter_units_string in input_dict.items():
        if not key.endswith("_unit"):
            continue
        parameter_name = key[: -len("_unit")]
        try:
            parameter_units = string_to_unit(parameter_units_string)
        except Exception as e:
            # Most exception types (e.g. AttributeError, TypeError) have no
            # ``msg`` attribute. Fall back to str(e) so this lookup cannot
            # raise a second, unrelated AttributeError that hides the real one.
            e.msg = "Could not parse units {}\n{}".format(
                parameter_units_string, getattr(e, "msg", e)
            )
            raise
        attached_units[parameter_name] = parameter_units
        # Safe to delete inside the loop: we iterate input_dict, not the copy.
        del unitless_dict[key]

    return unitless_dict, attached_units
511

512

513 8
def _attach_unit_to_value(temp_dict, key, units_to_attach):
    """Replace ``temp_dict[key]`` in place with ``float(temp_dict[key]) * units_to_attach``."""
    parameter_attrib_string = temp_dict[key]
    try:
        temp_dict[key] = float(parameter_attrib_string) * units_to_attach
    except ValueError as e:
        msg = (
            "Expected numeric value for parameter '{}', "
            "instead found '{}' when trying to attach units '{}'\n"
        ).format(key, parameter_attrib_string, units_to_attach)
        # Raise a ValueError that actually carries the message (setting
        # ``.msg`` on the original ValueError never showed up in str()).
        # ``.msg`` is kept for callers that read it.
        error = ValueError(msg)
        error.msg = msg
        raise error from e


def attach_units(unitless_dict, attached_units):
    """
    Attach units to dict entries for which units are specified.

    For each ``parameter_name`` in ``attached_units``, the entry
    ``parameter_name`` is converted if present. So are the indexed entries
    ``parameter_name1``, ``parameter_name2``, ..., up to the first missing
    index. Each value is converted with ``float(value) * unit``.

    Parameters
    ----------
    unitless_dict : dict
       Dictionary, where some items are to have units applied. It is not modified.
    attached_units : dict [str : simtk.unit.Unit]
       ``attached_units[parameter_name]`` is the simtk.unit.Unit combination that should be attached to corresponding
       parameter ``parameter_name``

    Returns
    -------
    unit_bearing_dict : dict
       Updated copy with simtk.unit.Unit units attached to values for which units were specified for their keys

    Raises
    ------
    ValueError
        If a value to which a unit should be attached is not numeric.
    """
    temp_dict = unitless_dict.copy()
    for parameter_name, units_to_attach in attached_units.items():
        if parameter_name in temp_dict:
            _attach_unit_to_value(temp_dict, parameter_name, units_to_attach)

        # Now check for indexed matches like "phase1", "phase2"
        index = 1
        while (parameter_name + str(index)) in temp_dict:
            _attach_unit_to_value(
                temp_dict, parameter_name + str(index), units_to_attach
            )
            index += 1
    return temp_dict
562

563

564 8
def detach_units(unit_bearing_dict, output_units=None):
    """
    Strip the units off every simtk.unit.Quantity value in a dict.

    Returns a shallow copy of ``unit_bearing_dict`` in which each Quantity is
    replaced by its bare value. A second dict maps ``"<key>_unit"`` to the
    unit that value is expressed in.

    Parameters
    ----------
    unit_bearing_dict : dict
        A dictionary potentially containing simtk.unit.Quantity objects as values.
    output_units : dict[str : simtk.unit.Unit], optional. Default = None
        Requested output units, keyed by ``"<key>_unit"`` (for example
        ``{'length_unit': unit.angstrom}``). When no unit is requested for a
        Quantity-valued key, the Quantity's own unit is used.

    Returns
    -------
    unitless_dict : dict
        Copy of the input with every Quantity converted to a unitless value.
    unit_dict : dict
        Maps ``"<key>_unit"`` to the simtk.unit.Unit used for each converted key.

    Raises
    ------
    ValueError
        If a requested output unit is incompatible with the Quantity's unit.
    """
    requested_units = {} if output_units is None else output_units

    unitless_dict = unit_bearing_dict.copy()
    unit_dict = {}

    for key, value in unit_bearing_dict.items():
        # Non-Quantity values are copied through unchanged.
        if isinstance(value, unit.Quantity):
            unit_key = key + "_unit"
            target_unit = (
                requested_units[unit_key] if unit_key in requested_units else value.unit
            )
            if not target_unit.is_compatible(value.unit):
                raise ValueError(
                    f"Requested output unit {target_unit} is not compatible with "
                    f"quantity unit {value.unit} ."
                )
            unitless_dict[key] = value.value_in_unit(target_unit)
            unit_dict[unit_key] = target_unit

    return unitless_dict, unit_dict
618

619

620 8
def serialize_numpy(np_array):
    """
    Serializes a numpy array into raw bytes plus its shape, so that
    ``deserialize_numpy`` can rebuild an array with the same shape.

    The bytes are the array's float64 values in C order and in the machine's
    native byte order. That is exactly what ``deserialize_numpy`` reads.

    from https://stackoverflow.com/questions/30698004/how-can-i-serialize-a-numpy-array-while-preserving-matrix-dimensions#30699208

    Parameters
    ----------
    np_array : A numpy array (or array-like)
        Input array. It is converted to float64 before serialization.

    Returns
    -------
    serialized : bytes
        A serialized representation of the numpy array.
    shape : tuple of ints
        The shape of the serialized array
    """
    import numpy as np

    # ``ndarray.newbyteorder`` only relabels the dtype without swapping bytes,
    # and it was removed in NumPy 2.0. Emit native-endian float64 explicitly,
    # which matches the dtype deserialize_numpy reads and its byte order.
    float_array = np.asarray(np_array, dtype=float)
    serialized = float_array.tobytes()
    shape = float_array.shape
    return serialized, shape
644

645

646 8
def deserialize_numpy(serialized_np, shape):
    """
    Deserializes a numpy array from the bytes produced by ``serialize_numpy``.

    from https://stackoverflow.com/questions/30698004/how-can-i-serialize-a-numpy-array-while-preserving-matrix-dimensions#30699208

    Parameters
    ----------
    serialized_np : bytes
        A serialized numpy array: float64 values in C order, native byte order.
    shape : tuple of ints
        The shape of the serialized array

    Returns
    -------
    np_array : numpy.ndarray
        The deserialized float64 array. It is a writable copy, not a
        read-only view of ``serialized_np``.
    """

    import numpy as np

    # Native-endian float64. (``dtype.newbyteorder`` returns a new dtype, so
    # calling it without using the result, as before, had no effect.)
    dt = np.dtype(float)
    np_array = np.frombuffer(serialized_np, dtype=dt).reshape(shape)
    # frombuffer returns a read-only view of the bytes object; copy so that
    # callers can modify the result in place.
    return np_array.copy()
671

672

673 8
def convert_0_2_smirnoff_to_0_3(smirnoff_data_0_2):
    """
    Convert an 0.2-compliant SMIRNOFF dict to an 0.3-compliant one.

    The conversion makes three changes:

    - Units given on section headers (``*_unit`` keys) are appended to the
      values they describe, via ``recursive_attach_unit_strings``.
    - "charmm" ProperTorsions/ImproperTorsions potentials are renamed to
      the explicit form "k*(1+cos(periodicity*theta-phase))".
    - Every section except Author, Date, version and aromaticity_model, and
      the top-level tag, is stamped with version 0.3.

    Note that the input dict is modified in place.

    Parameters
    ----------
    smirnoff_data_0_2 : dict
        Hierarchical dict representing a SMIRNOFF data structure according to the 0.2 spec

    Returns
    -------
    smirnoff_data_0_3
        Hierarchical dict representing a SMIRNOFF data structure according to the 0.3 spec
    """
    top_level = smirnoff_data_0_2["SMIRNOFF"]

    # Legacy force fields sometimes give NonbondedForce's sigma_unit but specify
    # atom sizes as rmin_half, or the other way round. Copy whichever of the
    # two units is defined onto the other. If both or neither are defined,
    # leave the section untouched.
    if "vdW" in top_level:
        vdw = top_level["vdW"]
        rmin_half_unit = vdw.get("rmin_half_unit", None)
        sigma_unit = vdw.get("sigma_unit", None)
        if rmin_half_unit is not None and sigma_unit is None:
            vdw["sigma_unit"] = rmin_half_unit
        elif sigma_unit is not None and rmin_half_unit is None:
            vdw["rmin_half_unit"] = sigma_unit

    # Recursively attach unit strings
    smirnoff_data = recursive_attach_unit_strings(smirnoff_data_0_2, {})
    sections = smirnoff_data["SMIRNOFF"]

    # "charmm" technically implies support for a harmonic torsion potential at
    # periodicity 0, so the periodic form is spelled out explicitly instead.
    # More at: https://github.com/openforcefield/openforcefield/issues/303#issuecomment-490156779
    for torsion_tag in ("ProperTorsions", "ImproperTorsions"):
        if torsion_tag not in sections:
            continue
        torsion_section = sections[torsion_tag]
        if "potential" in torsion_section and torsion_section["potential"] == "charmm":
            torsion_section["potential"] = "k*(1+cos(periodicity*theta-phase))"

    # Stamp each section (other than the metadata attributes) with the version.
    unversioned_tags = ("Author", "Date", "version", "aromaticity_model")
    for tag in sections.keys():
        if tag in unversioned_tags:
            continue
        if sections[tag] is None:
            # Empty entries, such as the ToolkitAM1BCC handler, arrive as None.
            sections[tag] = {}
        sections[tag]["version"] = 0.3

    # Update top-level tag
    sections["version"] = 0.3

    return smirnoff_data
740

741

742 8
def convert_0_1_smirnoff_to_0_2(smirnoff_data_0_1):
    """
    Convert an 0.1-compliant SMIRNOFF dict to an 0.2-compliant one.
    This involves renaming several tags, adding Electrostatics and ToolkitAM1BCC tags, and
    separating improper torsions into their own section.

    The input is deep-copied first, so the caller's dict is not modified.
    Default values assumed during the conversion are reported with
    ``logger.warning``.

    Parameters
    ----------
    smirnoff_data_0_1 : dict
        Hierarchical dict representing a SMIRNOFF data structure according to the 0.1 spec

    Returns
    -------
    smirnoff_data_0_2
        Hierarchical dict representing a SMIRNOFF data structure according to the 0.2 spec
    """
    import copy

    # Deep copy: a shallow dict.copy() still shares every nested section with
    # the input, and the code below mutates those sections.
    smirnoff_data = copy.deepcopy(smirnoff_data_0_1)

    l0_replacement_dict = {"SMIRFF": "SMIRNOFF"}
    l1_replacement_dict = {
        "HarmonicBondForce": "Bonds",
        "HarmonicAngleForce": "Angles",
        "PeriodicTorsionForce": "ProperTorsions",
        "NonbondedForce": "vdW",
    }
    for old_l0_tag, new_l0_tag in l0_replacement_dict.items():
        # Convert first-level smirnoff_data tags.
        # Right now this just changes the SMIRFF tag to SMIRNOFF
        if old_l0_tag in smirnoff_data:
            smirnoff_data[new_l0_tag] = smirnoff_data[old_l0_tag]
            del smirnoff_data[old_l0_tag]

    # SMIRFF tag will have been converted to SMIRNOFF here
    # Convert second-level tags here
    for old_l1_tag, new_l1_tag in l1_replacement_dict.items():
        if old_l1_tag in smirnoff_data["SMIRNOFF"]:
            smirnoff_data["SMIRNOFF"][new_l1_tag] = smirnoff_data["SMIRNOFF"][
                old_l1_tag
            ]
            del smirnoff_data["SMIRNOFF"][old_l1_tag]

    # Add 'potential' field to each l1 tag
    default_potential = {
        "Bonds": "harmonic",
        "Angles": "harmonic",
        "ProperTorsions": "charmm",
        # Note that "charmm" isn't actually correct, and was later changed
        # in the 0.3 spec. More info at
        # https://github.com/openforcefield/openforcefield/pull/311#commitcomment-33494506
        "vdW": "Lennard-Jones-12-6",
    }
    for l1_tag in smirnoff_data["SMIRNOFF"].keys():
        if l1_tag in default_potential:
            # If the file already specifies a potential (not expected in the
            # 0.1 spec), keep it rather than overwrite it with the default.
            if "potential" in smirnoff_data["SMIRNOFF"][l1_tag]:
                continue
            # Issue an informative warning about assumptions made during conversion.
            logger.warning(
                f"0.1 SMIRNOFF spec file does not contain 'potential' attribute for '{l1_tag}' tag. "
                f"The SMIRNOFF spec converter is assuming it has a value of '{default_potential[l1_tag]}'"
            )
            smirnoff_data["SMIRNOFF"][l1_tag]["potential"] = default_potential[l1_tag]

    # Separate improper torsions from propers
    if "ProperTorsions" in smirnoff_data["SMIRNOFF"]:
        if "Improper" in smirnoff_data["SMIRNOFF"]["ProperTorsions"]:
            # First generate an ImproperTorsions header, taking the relevant values from the ProperTorsions header
            improper_section = {
                "k_unit": smirnoff_data["SMIRNOFF"]["ProperTorsions"]["k_unit"],
                "phase_unit": smirnoff_data["SMIRNOFF"]["ProperTorsions"]["phase_unit"],
                "potential": smirnoff_data["SMIRNOFF"]["ProperTorsions"]["potential"],
                "Improper": smirnoff_data["SMIRNOFF"]["ProperTorsions"]["Improper"],
            }

            # Then, attach the newly-made ImproperTorsions section
            smirnoff_data["SMIRNOFF"]["ImproperTorsions"] = improper_section
            del smirnoff_data["SMIRNOFF"]["ProperTorsions"]["Improper"]

    # Add Electrostatics tag, setting several values to their defaults and
    # warning about assumptions that are being made
    electrostatics_section = {
        "method": "PME",
        "scale12": 0.0,
        "scale13": 0.0,
        "scale15": 1.0,
        "cutoff": 9.0,
        "cutoff_unit": "angstrom",
    }
    logger.warning(
        "0.1 SMIRNOFF spec did not allow the 'Electrostatics' tag. Adding it in 0.2 spec conversion, and "
        "assuming the following values:"
    )
    for key, val in electrostatics_section.items():
        logger.warning(f"\t{key}: {val}")

    # Take electrostatics 1-4 scaling term from 0.1 spec's NonBondedForce tag
    electrostatics_section["scale14"] = smirnoff_data["SMIRNOFF"]["vdW"][
        "coulomb14scale"
    ]
    del smirnoff_data["SMIRNOFF"]["vdW"]["coulomb14scale"]
    smirnoff_data["SMIRNOFF"]["Electrostatics"] = electrostatics_section

    # Change vdW's lj14scale to 14scale, add other scaling terms
    vdw_section_additions = {
        "method": "cutoff",
        "combining_rules": "Lorentz-Berthelot",
        "scale12": "0.0",
        "scale13": "0.0",
        "scale15": "1",
        "switch_width": "1.0",
        "switch_width_unit": "angstrom",
        "cutoff": "9.0",
        "cutoff_unit": "angstrom",
    }
    for key, val in vdw_section_additions.items():
        if key not in smirnoff_data["SMIRNOFF"]["vdW"]:
            logger.warning(
                f"0.1 SMIRNOFF spec file does not contain '{key}' attribute for 'NonBondedMethod/vdW'' tag. "
                f"The SMIRNOFF spec converter is assuming it has a value of '{val}'"
            )
            smirnoff_data["SMIRNOFF"]["vdW"][key] = val

    # Rename L-J 1-4 scaling term from 0.1 spec's NonBondedForce tag to vdW's scale14
    smirnoff_data["SMIRNOFF"]["vdW"]["scale14"] = smirnoff_data["SMIRNOFF"]["vdW"][
        "lj14scale"
    ]
    del smirnoff_data["SMIRNOFF"]["vdW"]["lj14scale"]

    # Add <ToolkitAM1BCC/> tag
    smirnoff_data["SMIRNOFF"]["ToolkitAM1BCC"] = {}

    # Update top-level tag
    smirnoff_data["SMIRNOFF"]["version"] = 0.2

    return smirnoff_data
878

879

880 8
def recursive_attach_unit_strings(smirnoff_data, units_to_attach):
    """
    Recursively traverse a SMIRNOFF data structure, appending "* {unit}" to values in key:value pairs
    where "key_unit":"unit_string" is present at the same or a higher level in the hierarchy.
    This function expects all items in smirnoff_data to be formatted as strings.

    A unit declared as ``<attr>_unit`` applies to the key ``<attr>`` itself and to indexed
    variants of it (``<attr>1``, ``<attr>2``, ...). It does NOT apply to other keys that merely
    start with ``<attr>`` (e.g. a ``k`` unit is not attached to ``kind`` or ``k_bondorder1``).

    Dicts and lists inside ``smirnoff_data`` are modified in place: ``*_unit`` entries are
    removed, and matching values are replaced by their unit-annotated strings.

    Parameters
    ----------
    smirnoff_data : dict
        Any level of hierarchy that is part of a SMIRNOFF dict, with all data members
        formatted as string.
    units_to_attach : dict
        Dict of the form {key:unit_string}. This dict itself is not modified.

    Returns
    -------
    unit_appended_smirnoff_data: dict
        ``smirnoff_data`` with units attached. It is the same object when it is a dict or a list;
        any other value is returned unchanged.
    """
    import re

    # Work on a copy so that units declared in one section (e.g. k_unit) cannot
    # leak into sibling sections or back to the caller.
    units_to_attach = units_to_attach.copy()

    if isinstance(smirnoff_data, dict):
        # First pass: collect unit declarations at this level. Units can be declared in
        # the same dict as the values they apply to, so this pass must finish before
        # any value is annotated. Iterate over a snapshot because we delete keys.
        for key, value in list(smirnoff_data.items()):
            if isinstance(key, str) and key.endswith("_unit"):
                units_to_attach[key[: -len("_unit")]] = value
                del smirnoff_data[key]

        # The attribute name must match exactly, optionally followed by a numeric index
        # (k, k1, k2, ...). re.escape guards against regex metacharacters in names.
        unit_patterns = [
            (re.compile(re.escape(str(unit_key)) + r"[0-9]*"), unit_string)
            for unit_key, unit_string in units_to_attach.items()
        ]

        # Second pass: attach units, then recurse into deeper levels of the hierarchy.
        for key in smirnoff_data:
            attach_unit = None
            if isinstance(key, str):
                for pattern, unit_string in unit_patterns:
                    if pattern.fullmatch(key):
                        # If several declarations match, the last one wins.
                        attach_unit = unit_string

            if attach_unit is not None:
                smirnoff_data[key] = str(smirnoff_data[key]) + " * " + attach_unit

            smirnoff_data[key] = recursive_attach_unit_strings(
                smirnoff_data[key], units_to_attach
            )

    # If it's a list, operate on each member of the list
    elif isinstance(smirnoff_data, list):
        for index, value in enumerate(smirnoff_data):
            smirnoff_data[index] = recursive_attach_unit_strings(value, units_to_attach)

    # Anything else (e.g. a string leaf) is returned unchanged
    return smirnoff_data
945

946

947 8
def get_molecule_parameterIDs(molecules, forcefield):
    """Process a list of molecules with a specified SMIRNOFF ffxml file and determine which parameters are used by
    which molecules, returning collated results.

    Parameters
    ----------
    molecules : list of openforcefield.topology.Molecule
        List of molecules (with explicit hydrogens) to parse
    forcefield : openforcefield.typing.engines.smirnoff.ForceField
        The ForceField to apply

    Returns
    -------
    parameters_by_molecule : dict
        Parameter IDs used in each molecule, keyed by the isomeric SMILES string
        produced by ``molecule.to_smiles()`` for each provided molecule. Each entry
        in the dict is a list which does not necessarily have unique entries; i.e.
        parameter IDs which are used more than once will occur multiple times.

    parameters_by_ID : dict
        Molecules in which each parameter ID occurs, keyed by parameter ID.
        Each entry in the dict is a set of SMILES strings for molecules
        in which that parameter occurs. No frequency information is stored.

    Raises
    ------
    ValueError
        If two or more of the provided molecules produce the same SMILES string.
    """
    from collections import Counter

    from openforcefield.topology import Topology

    # SMILES strings are used as dictionary keys below, so each molecule must map
    # to a distinct one.
    isosmiles = [molecule.to_smiles() for molecule in molecules]
    duplicates = {smi for smi, count in Counter(isosmiles).items() if count > 1}
    if duplicates:
        raise ValueError(
            "Error: get_molecule_parameterIDs has been provided a list of molecules which contains some duplicates: {}".format(
                duplicates
            )
        )

    # Assemble molecules into a Topology and label them all in one call
    topology = Topology()
    for molecule in molecules:
        topology.add_molecule(molecule)
    labels = forcefield.label_molecules(topology)

    parameters_by_molecule = dict()
    parameters_by_ID = dict()

    # labels[idx] is taken to describe molecules[idx] (same order as they were added
    # to the topology). Indexing rather than zip keeps an IndexError if fewer label
    # sets come back than molecules were provided.
    for idx, smi in enumerate(isosmiles):
        molecule_pids = parameters_by_molecule[smi] = []
        for force_labels in labels[idx].values():
            # Each force's labels map atom indices -> the assigned parameter type;
            # only the parameter's id is recorded here.
            for parameter_type in force_labels.values():
                pid = parameter_type.id
                molecule_pids.append(pid)
                parameters_by_ID.setdefault(pid, set()).add(smi)

    return parameters_by_molecule, parameters_by_ID
1025

1026

1027 8
def sort_smirnoff_dict(data):
    """
    Recursively sort the keys in a dict of SMIRNOFF data.

    Nested dicts (e.g. each ParameterHandler section) are sorted recursively. Lists
    (e.g. ParameterLists, which appear as lists of dicts) keep their element order,
    but each dict or list element inside them is processed recursively. Any other
    value, including non-dict list items such as strings, is kept as-is. The result
    is always a new plain ``dict``; ``data`` is not modified.

    Adapted from https://stackoverflow.com/a/47882384/4248961

    TODO: Should this live elsewhere?
    """

    def _sort_value(value):
        # Dispatch on container type. Only dicts have keys to sort.
        if isinstance(value, dict):
            return sort_smirnoff_dict(value)
        if isinstance(value, list):
            return [_sort_value(item) for item in value]
        # Metadata or the bottom of a recursive dict
        return value

    # Dict keys are unique, so sorting the items only ever compares keys.
    return {key: _sort_value(val) for key, val in sorted(data.items())}

Read our documentation on viewing source code.

Loading