1
"""
cluster_graph.py

ClusterGraph is a class for tracking all possible SMIRKS decorators in a group
(or cluster) of molecular fragments. Moving forward these will be used to
find the minimum number of SMIRKS decorators that are required to have a
set of SMIRKS patterns that maintain a given clustering of fragments.
"""
9

10 100
import networkx as nx
11 100
from functools import total_ordering
12 100
from chemper.graphs.single_graph import SingleGraph
13 100
from chemper.graphs.environment import ChemicalEnvironment as CE
14 100
from chemper.mol_toolkits import mol_toolkit
15

16

17 100
@total_ordering
18 100
class ClusterGraph(SingleGraph):
19
    """
    ClusterGraph is a graph based class for storing atom and bond information
    for a group (or cluster) of molecular fragments.
    It uses the chemper.mol_toolkits Atoms, Bonds, and Mols.
    """
23 100
    @total_ordering
    class AtomStorage:
        """
        AtomStorage tracks information about an atom

        Every atom added to the storage is reduced to a tuple of SMIRKS
        decorators (see `make_atom_decorators`) and the distinct tuples are
        kept in the `decorators` set. The optional `label` is the SMIRKS
        index used when the storage is written as a SMIRKS atom.
        """
        def __init__(self, atoms=None, label=None):
            """
            Parameters
            ----------
            atoms : ChemPer Atom or list of ChemPer Atoms
                this is one or more atoms whose information should be stored
            label : int
                SMIRKS index (:n) for writing SMIRKS
                if the value is less than zero it is used for storage purposes
                only as SMIRKS can only be written with positive integer indices
            """
            self.decorators = set()
            if atoms is not None:
                # a single atom is recognized by 'Atom' appearing in its type name
                if 'Atom' in str(type(atoms)):
                    atoms = [atoms]

                # otherwise it should be iterable
                for atom in atoms:
                    self.decorators.add(self.make_atom_decorators(atom))
            self.label = label

        def __lt__(self, other):
            """
            Overrides the default implementation
            This method was primarily written for making SMIRKS patterns predictable.
            If atoms are sortable, then the SMIRKS patterns are always the same making
            tests easier to write. However, the specific sorting was created to also make SMIRKS
            output as human readable as possible, that is to at least make it easier for a
            human to see how the indexed atoms are related to each other.
            It is typically easier for humans to read SMILES/SMARTS/SMIRKS with less branching (indicated with ()).

            For example in:
            [C:1]([H])([H])~[N:2]([C])~[O:3]
            it is easier to see that the atoms C~N~O are connected in a "line" instead of:
            [C:1]([N:2]([O:3])[C])([H])[H]
            which is equivalent, but with all the () it is hard for a human to read the branching

            Parameters
            ----------
            other : AtomStorage

            Returns
            -------
            is_less_than : boolean
                self is less than other
            """
            # if either smirks index is None, then you can't directly compare
            # make a temporary index that is negative if it was None
            self_index = self.label if self.label is not None else -1000
            other_index = other.label if other.label is not None else -1000
            # if either index is greater than 0, the one that is largest should go at the end of the list
            if self_index > 0 or other_index > 0:
                return self_index < other_index

            # Both SMIRKS indices are not positive or None so compare the SMIRKS patterns instead
            return self.as_smirks() < other.as_smirks()

        def __eq__(self, other):
            """
            Two storages are equal when they have the same SMIRKS and the same label.

            Objects that do not provide `as_smirks` and `label` cannot be
            compared, so NotImplemented is returned and Python falls back to
            its default handling instead of raising an AttributeError.
            """
            if not (hasattr(other, 'as_smirks') and hasattr(other, 'label')):
                return NotImplemented
            return self.as_smirks() == other.as_smirks() and self.label == other.label

        # Hash by identity: storages with equal SMIRKS and label still hash
        # as distinct objects (they are used as graph nodes by ClusterGraph)
        def __hash__(self): return id(self)

        def __str__(self): return self.as_smirks()

        def make_atom_decorators(self, atom):
            """
            extract information from a ChemPer Atom that would be useful in a smirks

            parameters
            ----------
            atom : ChemPer atom object

            returns
            -------
            decorators : tuple of str
                tuple of all possible decorators for this atom, in the order
                (atomic number, hydrogen count, connectivity, ring connectivity,
                ring size, formal charge, aromaticity)
            """
            aromatic = 'a' if atom.is_aromatic() else 'A'

            # non-negative charges need an explicit '+', negative numbers
            # already carry their sign when formatted
            charge = atom.formal_charge()
            if charge >= 0:
                charge = '+%i' % charge
            else:
                charge = '%i' % charge

            # a minimum ring size of 0 is written as "not in a ring"
            min_ring_size = atom.min_ring_size()
            if min_ring_size == 0:
                ring = '!r'
            else:
                ring = 'r%i' % min_ring_size

            return (
                '#%i' % atom.atomic_number(),
                'H%i' % atom.hydrogen_count(),
                'X%i' % atom.connectivity(),
                'x%i' % atom.ring_connectivity(),
                ring,
                charge,
                aromatic,
                )

        def as_smirks(self, compress=False):
            """
            Parameters
            ----------
            compress : boolean
                should decorators common to all sets be combined
                for example '#6X4,#7X3;+0!r...'

            Returns
            -------
            smirks : str
                how this atom would be represented in a SMIRKS string
                with the minimal combination of SMIRKS decorators
            """
            # only positive labels can be written as SMIRKS indices
            has_index = self.label is not None and self.label > 0

            # with no stored atoms the storage is written as a wildcard
            if len(self.decorators) == 0:
                if has_index:
                    return '[*:%i]' % self.label
                return '[*]'

            if compress and len(self.decorators) > 1:
                base_smirks = self._compress_smirks()
            else:
                # each decorator tuple becomes one option, options are OR'd (,)
                base_smirks = ','.join(sorted(''.join(decs) for decs in self.decorators))

            if has_index:
                return '[%s:%i]' % (base_smirks, self.label)
            return '[%s]' % base_smirks

        def _sort_decs(self, dec_set, wild=True):
            """
            Parameters
            ----------
            dec_set : list like
                single set of atom decorators
            wild : boolean
                insert * for decorator lists with no #n decorator

            Returns
            -------
            sorted_dec_set : list
                same set of decorators sorted with atomic number or * first
                and the aromaticity decorator (a/A) last
            """
            temp_dec_list = list(dec_set)
            atom_num = [d for d in temp_dec_list if '#' in d]
            if len(atom_num) == 0 and wild:
                atom_num = ["*"]

            remaining = set(temp_dec_list) - set(atom_num)

            # aromaticity is the only decorator containing the letter a/A
            aro = [d for d in remaining if 'a' in d.lower()]
            remaining = remaining - set(aro)

            return atom_num + sorted(remaining) + aro

        def _compress_smirks(self):
            """
            Build the compressed SMIRKS body for this storage.
            `as_smirks` only calls this when more than one decorator set is stored.

            Returns
            -------
            smirks : str
                This SMIRKS is compressed with all common decorators and'd to
                the end of the pattern
            """
            set_decs = [set(d) for d in self.decorators]

            # decorators shared by every stored atom (always a new set, so
            # modifying it below never changes an entry of set_decs)
            ands = set.intersection(*set_decs)

            # check for atomic number in the "ands"
            atomic = [a for a in ands if '#' in a]
            if len(atomic) == 1:
                # remove from and
                ands.remove(atomic[0])
                # put in all sets
                for s in set_decs:
                    s.add(atomic[0])

            or_sets = [self._sort_decs(d.difference(ands)) for d in set_decs]
            ors = [''.join(o) for o in or_sets]

            # add commas between ors
            base = ','.join(sorted(ors))
            # add and decorators
            if len(ands) > 0:
                base += ';' + ';'.join(self._sort_decs(ands, wild=False))
            return base

        def add_atom(self, atom):
            """
            Expand current AtomStorage by adding information about
            a new ChemPer Atom

            Parameters
            ----------
            atom : ChemPer Atom
            """
            self.decorators.add(self.make_atom_decorators(atom))

        def compare_atom(self, atom):
            """
            Compares decorators in this AtomStorage with the provided
            ChemPer atom. The decorators are compared separately and
            the highest score is returned. For example,
            if this storage had two sets of decorators
                - #7H1X3x0!r+0A
                - #6H1X4x0!r+0A
            and the input atom would have the decorators:
                - #6H1X3x2r6+0a

            The score is calculated by finding the number of decorators
            in common which would be
                - #7H1X3x0!r+0A and #6H1X3x2r6+0a
                    have 3 decorators in common (H1,X3,+0)
                - #6H1X4x0!r+0A and #6H1X3x2r6+0a
                    also have 3 decorators in common (#6, H1, +0)
            However, we weight atoms with the same atomic number as more
            similar by dividing the score by 10 if the atomic numbers do
            not agree. Therefore the final scores will be:
                - 0.3 for #7H1X3x0!r+0A
                - 3 for #6H1X4x0!r+0A

            The highest score for any set of decorators is returned so
            3 is the returned score in this example.

            Parameters
            ----------
            atom : ChemPer Atom

            Returns
            -------
            score : float
                A score describing how similar the input atom is to any set of
                decorators currently in this storage, based on its SMIRKS decorators.
                This score ranges from 0 to 7. 7 comes from the number of decorators
                on any atom, if this atom matches perfectly with one of the
                current decorator sets then 7 decorators agree. However, if the atomic
                number doesn't agree, then that set of decorators is considered
                less ideal, thus if the atomic numbers don't agree, then the score
                is given by the number other decorators divided by 10.
                If the current storage is empty, then the score is given as 7
                since any atom matches a wildcard atom.
            """
            # If decorators is empty (no known atom information), return 7 (current max)
            if len(self.decorators) == 0:
                return 7

            score = 0
            decs = self.make_atom_decorators(atom)
            dec_set = set(decs)

            for ref in self.decorators:
                # number of decorators this stored set shares with the atom
                current = len(set(ref) & dec_set)

                # the atomic number is the first entry of each decorator tuple
                # if atomic numbers don't agree, get the number of common decorators / 10
                # if there are no matching atomic numbers, priority should still be given
                # when the current atom matches stored decorators most closely
                if ref[0] != decs[0]:
                    current = current / 10.0

                if current > score:
                    score = current

            return score
291

292 100
    @total_ordering
293
    class BondStorage:
294
        """
295
        BondStorage tracks information about a bond
296
        """
297 100
        def __init__(self, bonds=None, label=None):
298
            """
299
            Parameters
300
            ----------
301
            bonds : list of ChemPer Bonds
302
                this is one or more bonds whose information should be stored
303
            label : a label for the object, it can be anything
304
                unlike atoms, bonds in smirks don't have labels
305
                so this is only used for labeling the object if wanted
306
            """
307 100
            self.order = set()
308 100
            self.ring = set()
309 100
            self.order_dict = {1:'-', 1.5:':', 2:'=', 3:'#'}
310 100
            if bonds is not None:
311 100
                if 'Bond' in str(type(bonds)):
312 100
                    bonds = [bonds]
313 100
                for bond in bonds:
314 100
                    self.order.add(bond.get_order())
315 100
                    self.ring.add(bond.is_ring())
316

317 100
            self.label = label
318

319 100
        def __str__(self): return self.as_smirks()
320

321 100
        def __lt__(self, other):
322 0
            if self.as_smirks() == other.as_smirks():
323 0
                return self.label < other.label
324 0
            return self.as_smirks() < other.as_smirks()
325

326 100
        def __eq__(self, other):
327 0
            return self.label == other.label and self.as_smirks() == other.as__smirks()
328

329 100
        def __hash__(self): return id(self)
330
        
331 100
        def as_smirks(self):
332
            """
333
            Returns
334
            -------
335
            smirks : str
336
                how this bond would be represented in a SMIRKS string
337
                using only the required number of
338
            """
339 100
            if len(self.order) == 0:
340 100
                order = '~'
341
            else:
342 100
                order = ','.join([self.order_dict.get(o, '~') for o in sorted(list(self.order))])
343

344
            # the ring set has booleans, if the length of the set is 1 then only ring (@) or non-ring (!@)
345
            # bonds haven been added to this storage and we AND that decorator to the end of the bond
346 100
            if len(self.ring) == 1:
347 100
                if list(self.ring)[0]:
348 100
                    return order+';@'
349
                else:
350 100
                    return order+';!@'
351

352 100
            return order
353

354 100
        def add_bond(self, bond):
355
            """
356
            Expand current BondStorage by adding information about
357
            a new ChemPer Bond
358

359
            Parameters
360
            ----------
361
            bond : ChemPer Bond
362
            """
363 100
            self.order.add(bond.get_order())
364 100
            self.ring.add(bond.is_ring())
365

366 100
        def compare_bond(self, bond):
367
            """
368

369
            Parameters
370
            ----------
371
            bond : ChemPer Bond
372
                bond you want to compare to the current storage
373

374
            Returns
375
            -------
376
            score : int (0,1,2)
377
                A score describing how similar the input bond is to any set of decorators currently
378
                in this storage, based on its SMIRKS decorators.
379

380
                1 for the bond order +
381
                1 base on if this is a ring bond
382
            """
383 100
            score = 0
384 100
            if bond.get_order() in self.order or len(self.order) == 0:
385 100
                score += 1
386

387
            # the ring set has booleans, if the length of the set is 1 then only ring or non-ring
388
            # bonds haven been added to this storage. That is the only time the ring contributes to the score
389 100
            if len(self.ring) == 1 and list(self.ring)[0] == bond.is_ring():
390 100
                score += 1
391

392 100
            return score
393

394
    # Initiate ClusterGraph
395 100
    def __init__(self, mols=None, smirks_atoms_lists=None, layers=0):
        """
        Initialize a ClusterGraph from molecules and lists of indexed atoms

        For the example, imagine we wanted to get a SMIRKS that
        would match the carbon-carbon bonds in ethane and propane.
        The carbon atoms have indices (0,1) in ethane and (0,1) and (1,2)
        in propane. For this example, we will assume we also want to include
        the atoms one bond away from the indexed atoms (1 layer away).

        Parameters
        ----------
        mols : list of molecules (optional)
            default = None (makes an empty graph)
            these can be ChemPer Mols or molecule objects from
            any supported toolkit (currently OpenEye or RDKit)

        smirks_atoms_lists : list of list of tuples (optional)
            default = None (must be paired with mols=None)
            There is a list of tuples for each molecule, where each tuple specifies
            a molecular fragment using the atoms' indices.
            In the ethane and propane example, the `smirks_atoms_lists` would be
                [ [ (0,1) ], [ (0,1), (1,2) ] ]
            with one carbon-carbon bond in ethane and two carbon-carbon bonds in propane

        layers : int (optional)
            default = 0
            layers specifies how many bonds away from the indexed atoms should be included in
            the SMIRKS patterns.
            Instead of an int, the string 'all' would lead to all atoms in the molecules
            being included in the SMIRKS (not recommended)

        Raises
        ------
        TypeError
            if mols is provided without smirks_atoms_lists
        Exception
            if the number of molecules and the number of atom lists differ
        """
        SingleGraph.__init__(self)

        self.mols = list()
        self.smirks_atoms_lists = list()
        self.layers = layers
        # until the first fragment is added the fragment type is unknown,
        # so atoms are taken in the order they are provided
        self._symmetry_funct = self._no_symmetry

        # no molecules means an empty graph
        if mols is None:
            return

        # fail with a clear message instead of the len(None) TypeError
        if smirks_atoms_lists is None:
            raise TypeError('smirks_atoms_lists must be provided when mols is not None')

        temp_mols = [mol_toolkit.Mol(m) for m in mols]
        if len(temp_mols) != len(smirks_atoms_lists):
            raise Exception('Number of molecules and smirks dictionaries should be equal')

        for idx, mol in enumerate(temp_mols):
            self.add_mol(mol, smirks_atoms_lists[idx])
441

442 100
    def as_smirks(self, compress=False):
        """
        Write the stored atom and bond information as a SMIRKS string.

        Parameters
        ----------
        compress : boolean
            returns the shorter version of atom SMIRKS patterns
            that is atoms have decorators "anded" to the end rather than listed
            in each set that are OR'd together.
            For example, without compression an atom could be written as
            "[#6AH2X3x0!r+0,#6AH1X3x0!r+0:1]"; with compression the decorators
            shared by both options are written once after the OR'd options.

        Returns
        -------
        SMIRKS : str
            a SMIRKS string matching the exact atom and bond information stored
        """
        # Only the atom compression differs from the parent class
        # (SingleGraph), so the work is delegated to its implementation
        parent_as_smirks = SingleGraph.as_smirks
        return parent_as_smirks(self, compress)
461

462 100
    def get_symmetry_funct(self, sym_label):
        """
        Determine the symmetry function that should be used
        when adding atoms to this graph.

        For example, imagine a user is trying to make a
        SMIRKS for all of the C-H bonds in methane. In most
        toolkits the index for the carbon is 0 and the hydrogens are 1,2,3,4.
        The final SMIRKS should have the form [#6AH4X4x0!r+0:1]-;!@[#1AH0X1x0!r+0]
        no matter what order the atoms are input into ClusterGraph.
        So if the user provides (0,1), (0,2), (3,0), (4,0) ClusterGraph
        should figure out that the carbons in (3,0) and (4,0) should be in
        the atom index :1 place like they were in the first set of atoms.

        Bond atoms in (1,2) or (2,1) are symmetric, for angles it's (1,2,3) or (3,2,1),
        for proper torsions (1,2,3,4) or (4,3,2,1) and for
        improper torsions (1,2,3,4), (3,2,1,4), (4,2,1,3).
        For any other fragment type the atoms will be added to the graph in
        the order they are provided since the symmetry function is unknown.

        # TODO: In theory you could generalize this for generic linear fragments
        # where those with an odd number of atoms behave like angles and an
        # even number behave like proper torsions, however I think that is
        # going to be outside the scope of ChemPer for the foreseeable future.

        Parameters
        ----------
        sym_label : str or None
            type of symmetry, options which will change the way symmetry is
            handled in the graph are "bond", "angle", "ProperTorsion", and "ImproperTorsion"
            (the comparison is case insensitive)

        Returns
        -------
        symmetry_funct : function
            returns the function that should be used to handle the appropriate symmetry
        """
        fallback = '_no_symmetry'
        if sym_label is None:
            return getattr(self, fallback)

        # lower-cased symmetry label -> name of the method handling it;
        # the method is only looked up once the label has been resolved
        handler_names = {
            'bond': '_bond_symmetry',
            'angle': '_angle_symmetry',
            'propertorsion': '_proper_torsion_symmetry',
            'impropertorsion': '_improper_torsion_symmetry',
        }
        return getattr(self, handler_names.get(sym_label.lower(), fallback))
509

510 100
    def add_mol(self, input_mol, smirks_atoms_list):
        """
        Expand the information in this graph by adding a new molecule

        Nothing is added or recorded when `smirks_atoms_list` is empty.

        Parameters
        ----------
        input_mol : ChemPer Mol
        smirks_atoms_list : list of tuples
            This is a list of tuples with atom indices [ (indices), ... ]
        """
        mol = mol_toolkit.Mol(input_mol)

        # without indexed atoms there is nothing to add
        if len(smirks_atoms_list) < 1:
            return

        remaining_atoms = smirks_atoms_list
        if len(self.mols) < 1:
            # The graph is still empty: the first set of atoms defines the
            # graph and, through the fragment type of the resulting SMIRKS,
            # the symmetry function used for everything added afterwards
            self._add_first_smirks_atoms(mol, smirks_atoms_list[0])
            fragment_type = CE(self.as_smirks()).get_type()
            self._symmetry_funct = self.get_symmetry_funct(fragment_type)
            remaining_atoms = smirks_atoms_list[1:]

        self._add_mol(mol, remaining_atoms)

        # keep a record of the molecule and all of its indexed atoms
        self.mols.append(mol)
        self.smirks_atoms_lists.append(smirks_atoms_list)
534

535 100
    def _add_first_smirks_atoms(self, mol, smirks_atoms):
        """
        private function for adding the first molecule to an empty ClusterGraph
        add_mol calls this if the graph is empty

        Every indexed atom gets a new AtomStorage labeled with its position
        in `smirks_atoms` (starting at 1); bonds between indexed atoms become
        edges and the requested layers of unindexed atoms are added afterwards.

        Parameters
        ----------
        mol : ChemPer Mol
        smirks_atoms : tuple
            tuple of atom indices for the first atoms to add to the graph. i.e. (0, 1)
        """
        index_to_label = dict()
        indexed_count = len(smirks_atoms)

        for label, atom_index in enumerate(smirks_atoms, 1):
            index_to_label[atom_index] = label

            current_atom = mol.get_atom_by_index(atom_index)
            current_storage = self.AtomStorage([current_atom], label)
            self._graph.add_node(current_storage)
            self.atom_by_label[label] = current_storage

            # Look for bonded atoms already in the graph, highest label first
            for other_label in range(indexed_count, 0, -1):
                if other_label in self.atom_by_label:
                    other_storage = self.atom_by_label[other_label]

                    # only atoms not yet connected on the graph need an edge
                    if not nx.has_path(self._graph, current_storage, other_storage):
                        # the edge is only added if the atoms are bonded in the molecule
                        other_atom = mol.get_atom_by_index(smirks_atoms[other_label-1])
                        bond = mol.get_bond_by_atoms(current_atom, other_atom)

                        if bond is not None:
                            bond_label = tuple(sorted([other_label, label]))
                            new_bond_storage = self.BondStorage([bond], bond_label)
                            self.bond_by_label[bond_label] = new_bond_storage
                            self._graph.add_edge(current_storage,
                                                 other_storage,
                                                 bond=new_bond_storage)

        # for each indexed atom add unindexed atoms for the number of specified layers
        for label, atom_index in enumerate(smirks_atoms, 1):
            self._add_layers(mol,
                             mol.get_atom_by_index(atom_index),
                             self.atom_by_label[label],
                             self.layers,
                             index_to_label,
                             is_first=True)
582

583 100
    def _add_layers(self, mol, atom, storage, layers, idx_dict, is_first=False):
584
        """
585
        Parameters
586
        ----------
587
        mol : ChemPer Mol
588
            molecule containing provided atom
589
        atom : ChemPer Atom
590
        storage: AtomStorage
591
            corresponding to the ChemPer Atom provided
592
        layers : int or 'all'
593
            number of layers left to add (or all)
594
        idx_dict : dict
595
            form {atom index: label} for this smirks_list in this molecule
596
        """
597
        # if layers is 0 there are no more atoms to add so end the recursion
598 100
        if layers == 0:
599 100
            return
600

601
        # find atom neighbors that are not already included in SMIRKS indexed atoms
602 100
        atom_neighbors = [(a, mol.get_bond_by_atoms(a,atom)) for a in atom.get_neighbors() \
603
                          if a.get_index() not in idx_dict]
604

605
        # get the smirks indices already added to the storage
606
        # This includes all previous layers since the idx_dict is updated as you go
607 100
        storage_labels = [e for k,e in idx_dict.items()]
608

609
        # similar to atoms find neighbors already in the graph that haven't already been used
610 100
        storage_neighbors = [(s, self.get_connecting_bond(s, storage)) for s in self.get_neighbors(storage) \
611
                             if s.label not in storage_labels]
612

613 100
        new_pairs = list()
614
        # if this is the first set of atoms added, just make a new
615
        # storage for all neighboring atoms
616 100
        if is_first:
617 100
            min_smirks = storage.label * 10
618 100
            if min_smirks > 0:
619 100
                min_smirks = min_smirks * -1
620

621 100
            for a, b in atom_neighbors:
622 100
                new_bond_smirks = tuple(sorted([storage.label, min_smirks]))
623

624 100
                adding_new_storage = self.add_atom(a,b,storage,
625
                                                   min_smirks, new_bond_smirks)
626

627 100
                idx_dict[a.get_index()] = min_smirks
628 100
                self.atom_by_label[min_smirks] = adding_new_storage
629 100
                min_smirks -= 1
630 100
                new_pairs.append((a, adding_new_storage))
631

632
        else: # this isn't the first set of atoms so you need to
633
            # pair up the atoms with their storage
634 100
            pairs = self.find_pairs(atom_neighbors, storage_neighbors)
635 100
            for new_atom, new_bond, new_storage_atom, new_storage_bond in pairs:
636
                # if no storage is paired to this atom skip it
637 100
                if new_storage_atom is None:
638 100
                    continue
639
                # if there is no atom paired to a storage remove that branch
640 100
                if new_atom is None:
641 100
                    self.remove_atom(new_storage_atom)
642 100
                    continue
643
                # add atom and bond information to the storage
644 100
                new_storage_atom.add_atom(new_atom)
645 100
                new_storage_bond.add_bond(new_bond)
646 100
                new_pairs.append((new_atom, new_storage_atom))
647 100
                idx_dict[new_atom.get_index()] = new_storage_atom.label
648

649
        # Repeat for the extra layers
650 100
        if layers == 'all':
651 100
            new_layers = 'all'
652
        else:
653 100
            new_layers = layers - 1
654 100
            if new_layers == 0:
655 100
                return
656

657 100
        for new_atom, new_storage in new_pairs:
658 100
            self._add_layers(mol, new_atom, new_storage, new_layers, idx_dict, is_first)
659

660 100
    def find_pairs(self, atoms_and_bonds, storages):
661
        """
662
        Find pairs is used to determine which current AtomStorage from storages
663
        atoms should be paired with.
664
        This function takes advantage of the maximum scoring function in networkx
665
        to find the pairing with the highest "score".
666
        Scores are determined using functions in the atom and bond storage objects
667
        that compare those storages to the new atom or bond.
668

669
        If there are less atoms than storages then the atoms with the lowest pair are
670
        assigned a None pairing.
671

672
        Parameters
673
        ----------
674
        atoms_and_bonds : list of tuples in form (ChemPer Atom, ChemPer Bond, ...)
675
        storages: list of tuples in form (AtomStorage, BondStorage, ...)
676

677
        Tuples can be of any length as long as they are the same, so for example, in
678
        a bond you might only care about the outer atoms for comparison so you would compare
679
        (atom1,) and (atom2,) with (atom_storage1,) and (atom_storage2,)
680
        However, in a torsion, you might want the atoms and bonds for each outer bond
681
        so in that case you would compare
682
        (atom1, bond1, atom2) and (atom4, bond3, atom3)
683
        with the corresponding storage objects.
684

685
        Returns
686
        -------
687
        pairs : list of lists
688
            pairs of atoms and storage objects that are most similar,
689
            these lists always come in the form (all atom/bonds, all storage objects)
690
            for the bond example above you might get
691
            [ [atom1, storage1], [atom2, storage2] ]
692
            for the torsion example you might get
693
            [ [atom4, bond4, atom3, atom_storage1, bond_storage1, atom_storage2],
694
              [atom1, bond1, atom2, atom_storage4, bond_storage3, atom_storage3]
695

696
        """
697
        # store paired stets of atoms/bonds and corresponding storages
698 100
        pairs = list()
699
        # check for odd cases
700 100
        combo = atoms_and_bonds + storages
701
        # 1. both lists are empty
702 100
        if len(combo) == 0:
703 100
            return pairs
704

705 100
        nones = [None] * len(combo[0])
706
        # 2. no atom/bond storage
707 100
        if len(atoms_and_bonds) == 0:
708 100
            for storage_set in storages:
709 100
                pairs.append(nones + list(storage_set))
710 100
            return pairs
711

712
        # 3. no storages
713 100
        if len(storages) == 0:
714 100
            for atom_set in atoms_and_bonds:
715 100
                pairs.append(list(atom_set) + nones)
716 100
            return pairs
717

718 100
        g = nx.Graph()
719

720 100
        atom_dict = dict()
721 100
        storage_dict = dict()
722

723
        # create a bipartite graph with atoms/bonds on one side
724 100
        for idx, atom_set in enumerate(atoms_and_bonds):
725 100
            g.add_node(idx+1, bipartite=0)
726 100
            atom_dict[idx+1] = atom_set
727
        # and atom/bond storage objects on the other
728 100
        for idx, storage_set in enumerate(storages):
729 100
            g.add_node((idx*-1)-1, bipartite=1)
730 100
            storage_dict[(idx*-1)-1] = storage_set
731

732
        # Fill in the weight on each edge of the graph using the compare_atom/bond functions
733 100
        for a_idx, atom_set in atom_dict.items():
734 100
            for s_idx, storage_set in storage_dict.items():
735
                # sum up score for every entry in the atom and storage set
736 100
                score = 0
737 100
                for sa, a in zip(storage_set, atom_set):
738 100
                    if isinstance(sa, self.BondStorage):
739 100
                        score += sa.compare_bond(a)
740
                    else:
741 100
                        score += sa.compare_atom(a)
742
                # score can't be zero so save score+1
743 100
                g.add_edge(a_idx,s_idx,weight=score+1)
744

745
        # calculate maximum matching, that is the pairing of atoms/bonds to
746
        # storage objects that leads to the highest overall score
747 100
        matching = nx.algorithms.max_weight_matching(g,maxcardinality=False)
748
        # track the atoms assigned a paired storage object
749 100
        pair_set = set()
750

751
        # store all pairs
752 100
        for idx_1, idx_2 in matching:
753 100
            pair_set.add(idx_1)
754 100
            pair_set.add(idx_2)
755 100
            if idx_1 in atom_dict:
756 100
                atom_set = atom_dict[idx_1]
757 100
                storage_set = storage_dict[idx_2]
758
            else:
759 100
                atom_set = atom_dict[idx_2]
760 100
                storage_set = storage_dict[idx_1]
761 100
            pairs.append(list(atom_set) + list(storage_set))
762

763
        # check for missing atom storages
764 100
        for a_idx, atom_set in atom_dict.items():
765 100
            if a_idx not in pair_set:
766 100
                pairs.append(list(atom_set) + nones)
767

768
        # check for missing atoms
769 100
        for s_idx, storage_set in storage_dict.items():
770 100
            if s_idx not in pair_set:
771 100
                pairs.append(nones + list(storage_set))
772

773 100
        return pairs
774

775 100
    def _add_mol(self, mol, smirks_atoms_list):
        """
        private function for adding a new molecule
        This is used by add_mol if the graph is not empty, allowing the user to
        not have to track if the graph already has information before adding molecules

        Parameters
        ----------
        mol : any Mol
        smirks_atoms_list : list of tuples
            Each entry is an ordered sequence of atom indices in mol.
            The entry is first reordered with self._symmetry_funct, then the
            atom in position n (counting from 1) is added to the AtomStorage
            with SMIRKS index n.
        """
        for smirks_atoms in smirks_atoms_list:
            # reorder the indices so they line up with what is already stored
            sorted_smirks_atoms = self._symmetry_funct(mol, smirks_atoms)

            # fetch every atom from the molecule once: (label, index, atom)
            labeled_atoms = [
                (label, atom_index, mol.get_atom_by_index(atom_index))
                for label, atom_index in enumerate(sorted_smirks_atoms, 1)
            ]
            # atom index -> SMIRKS label, handed to _add_layers below
            atom_dict = {
                atom_index: label for label, atom_index, _ in labeled_atoms
            }

            for label, _, atom in labeled_atoms:
                self.atom_by_label[label].add_atom(atom)

                for neighbor_label, _, neighbor in labeled_atoms:
                    label_pair = (neighbor_label, label)
                    # only label pairs that already have a bond storage matter,
                    # so check that before asking the molecule for a bond
                    if label_pair not in self.bond_by_label:
                        continue
                    bond = mol.get_bond_by_atoms(atom, neighbor)
                    if bond is not None:
                        bond_smirks = tuple(sorted(label_pair))
                        self.bond_by_label[bond_smirks].add_bond(bond)

            # extend outward from each indexed atom once all of them are stored
            for label, _, atom in labeled_atoms:
                storage = self.atom_by_label[label]
                self._add_layers(mol, atom, storage, self.layers, atom_dict)
809

810 100
    def _no_symmetry(self, mol, smirks_atoms):
        """
        Symmetry function that keeps the atom order as given:
        the indices are handed back exactly as they were passed in

        Parameters
        -----------
        mol : ChemPer Mol
            not used here, accepted so the call signature is the same
            as the other symmetry functions
        smirks_atoms : tuple
            tuple of atom indices

        Returns
        --------
        smirks_atoms : tuple
            the very same object that was passed in
        """
        return smirks_atoms
815

816 100
    def _bond_symmetry(self, mol, smirks_atoms):
        """
        Decide which of the two bonded atoms should be stored under
        SMIRKS index 1 and which under SMIRKS index 2.

        Parameters
        -----------
        mol : ChemPer Mol
        smirks_atoms : two tuple
            tuple of atom indices

        Returns
        --------
        ordered_smirks_atoms : two tuple
            the two atom indices, first the atom paired with the storage
            for label 1, then the atom paired with the storage for label 2
        """
        first = mol.get_atom_by_index(smirks_atoms[0])
        second = mol.get_atom_by_index(smirks_atoms[1])

        # find_pairs works on tuples, so wrap every atom and storage in one
        slots = [(self.atom_by_label[label],) for label in (1, 2)]
        matches = self.find_pairs([(first,), (second,)], slots)

        # every match is [atom, storage]; walk them in storage label order
        by_label = sorted(matches, key=lambda match: match[1].label)
        return tuple(match[0].get_index() for match in by_label)
844

845 100
    def _angle_symmetry(self, mol, smirks_atoms):
        """
        Returns a tuple of three atom indices in the order that
        leads to the atoms that match with previously stored atoms.
        The central atom always stays in the middle, only the two
        outer atoms can trade places.

        Parameters
        -----------
        mol : ChemPer Mol
        smirks_atoms : three tuple
            tuple of atom indices

        Returns
        --------
        ordered_smirks_atoms : three tuple
            tuple of atom indices as they should be added to the graph,
            if either bond of the angle is missing in mol the indices
            are returned in their original order
        """
        # get all three atoms
        atom1 = mol.get_atom_by_index(smirks_atoms[0])
        atom2 = mol.get_atom_by_index(smirks_atoms[1])
        atom3 = mol.get_atom_by_index(smirks_atoms[2])
        # get both bonds
        bond1 = mol.get_bond_by_atoms(atom1, atom2)
        bond2 = mol.get_bond_by_atoms(atom2, atom3)
        # identity check so a bond type with its own __eq__ can't confuse this,
        # and always hand back a tuple like the reordering path below does
        if bond1 is None or bond2 is None:
            return tuple(smirks_atoms)
        # save atom and bond pairs that could be reordered
        atoms_and_bonds = [(atom1, bond1), (atom3, bond2)]
        # find current atom and bond storage
        storages = [
            (self.atom_by_label[1], self.bond_by_label[(1, 2)]),
            (self.atom_by_label[3], self.bond_by_label[(2, 3)]),
        ]
        pairs = self.find_pairs(atoms_and_bonds, storages)
        # each pair is [atom, bond, atom storage, bond storage];
        # sorting on the atom storage label puts the label 1 atom first
        order = [p[0].get_index() for p in sorted(pairs, key=lambda x: x[2].label)]
        return (order[0], smirks_atoms[1], order[1])
880

881 100
    def _proper_torsion_symmetry(self, mol, smirks_atoms):
        """
        Returns a tuple of four atom indices for a proper torsion
        reordered to match with previously stored atoms.
        The only two possible results are the given order and its reverse.

        Parameters
        -----------
        mol : ChemPer Mol
        smirks_atoms : four tuple
            tuple of atom indices

        Returns
        --------
        ordered_smirks_atoms : four tuple
            tuple of atom indices as they should be added to the graph,
            if either outer bond is missing in mol the indices
            are returned in their original order
        """
        # get all four atoms
        atom1 = mol.get_atom_by_index(smirks_atoms[0])
        atom2 = mol.get_atom_by_index(smirks_atoms[1])
        atom3 = mol.get_atom_by_index(smirks_atoms[2])
        atom4 = mol.get_atom_by_index(smirks_atoms[3])
        # get two relevant bonds
        bond1 = mol.get_bond_by_atoms(atom1, atom2)
        bond3 = mol.get_bond_by_atoms(atom3, atom4)
        # identity check so a bond type with its own __eq__ can't confuse this,
        # and always hand back a tuple like the reversing path below does
        if bond1 is None or bond3 is None:
            return tuple(smirks_atoms)
        # make pairs, each half of the torsion listed from the inside out
        atoms_and_bonds = [(atom2, bond1, atom1), (atom3, bond3, atom4)]
        # get atom and bond storages
        storages = [
            (self.atom_by_label[2], self.bond_by_label[(1, 2)], self.atom_by_label[1]),
            (self.atom_by_label[3], self.bond_by_label[(3, 4)], self.atom_by_label[4]),
        ]
        pairs = self.find_pairs(atoms_and_bonds, storages)
        # entry 3 of each pair is the storage matched to the inner atom,
        # so after sorting order[0] is the atom that belongs at label 2
        order = [p[0].get_index() for p in sorted(pairs, key=lambda x: x[3].label)]
        if order[0] == smirks_atoms[1]:
            return tuple(smirks_atoms)
        # the other half matched label 2 better, flip the whole torsion
        return tuple(reversed(smirks_atoms))
921

922 100
    def _improper_torsion_symmetry(self, mol, smirks_atoms):
        """
        Returns a tuple of four atom indices for an improper torsion
        reordered to match with previously stored atoms.
        The second atom is the one bonded to the other three, it always
        stays in the second position while the other three can be permuted.

        Parameters
        -----------
        mol : ChemPer Mol
        smirks_atoms : four tuple
            tuple of atom indices

        Returns
        --------
        ordered_smirks_atoms : four tuple
            tuple of atom indices as they should be added to the graph,
            if any of the three bonds to the second atom is missing in mol
            the indices are returned in their original order
        """
        # get all four atoms
        atom1 = mol.get_atom_by_index(smirks_atoms[0])
        atom2 = mol.get_atom_by_index(smirks_atoms[1])
        atom3 = mol.get_atom_by_index(smirks_atoms[2])
        atom4 = mol.get_atom_by_index(smirks_atoms[3])
        # get all three bonds
        bond1 = mol.get_bond_by_atoms(atom1, atom2)
        bond2 = mol.get_bond_by_atoms(atom2, atom3)
        bond3 = mol.get_bond_by_atoms(atom2, atom4)
        # identity check so a bond type with its own __eq__ can't confuse this,
        # and always hand back a tuple like the reordering path below does
        if bond1 is None or bond2 is None or bond3 is None:
            return tuple(smirks_atoms)
        # make pairs of atoms and bonds to be reordered
        atoms_and_bonds = [(atom1, bond1), (atom3, bond2), (atom4, bond3)]
        # find current atom and bond storages
        storages = [
            (self.atom_by_label[1], self.bond_by_label[(1, 2)]),
            (self.atom_by_label[3], self.bond_by_label[(2, 3)]),
            (self.atom_by_label[4], self.bond_by_label[(2, 4)]),
        ]
        pairs = self.find_pairs(atoms_and_bonds, storages)
        # each pair is [atom, bond, atom storage, bond storage];
        # sorting on the atom storage label gives the atoms for labels 1, 3, 4
        order = [p[0].get_index() for p in sorted(pairs, key=lambda x: x[2].label)]
        return (order[0], smirks_atoms[1], order[1], order[2])

Read our documentation on viewing source code .

Loading