1
"""
2
cluster_graph.py
3

4
ClusterGraph is a class for tracking all possible smirks decorators in a group (or cluster)
5
of molecular fragments. Moving forward these will be used to find the minimum number of
6
smirks decorators that are required to have a set of smirks patterns that maintain
7
a given clustering of fragments.
8

9
# TODO: add specific example like the one at the top of fragment_graph
10

11
AUTHORS:
12

13
Caitlin C. Bannan <bannanc@uci.edu>, Mobley Group, University of California Irvine
14
"""
15

16 9
import networkx as nx
17 9
from functools import total_ordering
18 9
from chemper.graphs.fragment_graph import ChemPerGraph
19 9
from chemper.graphs.environment import ChemicalEnvironment as CE
20 9
from chemper.mol_toolkits import mol_toolkit
21

22

23 9
@total_ordering
24 9
class ClusterGraph(ChemPerGraph):
25
    """
26
    ClusterGraphs are ChemPerGraphs that store the combined atom and bond information of a cluster of fragments.
27
    They use the chemper.mol_toolkits Atoms, Bonds, and Mols
28
    """
29 9
    @total_ordering
    class AtomStorage:
        """
        AtomStorage collects the SMIRKS decorators for every chemper Atom
        assigned to one position (label) of the cluster.
        Each stored atom contributes one tuple of decorators to ``self.decorators``.
        """
        def __init__(self, atoms=None, label=None):
            """
            Parameters
            ----------
            atoms: chemper Atom or list of chemper Atoms (optional)
                atom(s) whose decorators should be stored
            label: int (optional)
                SMIRKS index (:n) for writing SMIRKS.
                None or values less than or equal to zero are only used for
                bookkeeping; they are never written into the SMIRKS string.
            """
            self.decorators = set()
            if atoms is not None:
                # a single Atom (recognized by its type name) is wrapped in a list;
                # anything else is assumed to already be iterable
                atom_list = [atoms] if 'Atom' in str(type(atoms)) else atoms
                for each_atom in atom_list:
                    self.decorators.add(self.make_atom_decorators(each_atom))
            self.label = label

        def __lt__(self, other):
            """
            Order storages so that generated SMIRKS are predictable and readable.

            Indexed atoms (positive labels) are ordered by label, which keeps the
            indexed atoms in a "line" with as little branching as possible, e.g.
                [C:1]([H])([H])~[N:2]([C])~[O:3]
            rather than the equivalent but harder to read
                [C:1]([N:2]([O:3])[C])([H])[H]
            When neither storage has a positive label, the SMIRKS text is compared.

            Parameters
            ----------
            other: AtomStorage

            Returns
            -------
            is_less_than: boolean
                True if self sorts before other
            """
            # None labels become a large negative placeholder so they still compare
            placeholder = -1000
            mine = placeholder if self.label is None else self.label
            theirs = placeholder if other.label is None else other.label

            if mine <= 0 and theirs <= 0:
                # neither atom is indexed: fall back on the SMIRKS text
                return self.as_smirks() < other.as_smirks()
            return mine < theirs

        def __eq__(self, other):
            same_pattern = self.as_smirks() == other.as_smirks()
            return same_pattern and self.label == other.label

        def __hash__(self):
            # identity hash: decorators change as atoms are added, so the
            # hash cannot be derived from them
            return id(self)

        def __str__(self):
            return self.as_smirks()

        def make_atom_decorators(self, atom):
            """
            Build the tuple of SMIRKS decorators describing one chemper atom.

            Parameters
            ----------
            atom: chemper Atom

            Returns
            -------
            decorators: tuple of str
                (atomic number '#n', hydrogen count 'Hn', connectivity 'Xn',
                ring connectivity 'xn', smallest ring 'rn' or '!r',
                formal charge '+n'/'-n', aromaticity 'a'/'A')
            """
            aromaticity = 'A'
            if atom.is_aromatic():
                aromaticity = 'a'

            formal = atom.formal_charge()
            charge_fmt = '+%i' if formal >= 0 else '%i'
            charge = charge_fmt % formal

            smallest_ring = atom.min_ring_size()
            ring = '!r' if smallest_ring == 0 else 'r%i' % smallest_ring

            atomic_number = '#%i' % atom.atomic_number()
            hydrogens = 'H%i' % atom.hydrogen_count()
            connectivity = 'X%i' % atom.connectivity()
            ring_connectivity = 'x%i' % atom.ring_connectivity()
            return (atomic_number, hydrogens, connectivity, ring_connectivity,
                    ring, charge, aromaticity)

        def as_smirks(self, compress=False):
            """
            Parameters
            ----------
            compress: boolean
                if True and more than one decorator set is stored, decorators
                shared by every set are AND'd (';') onto the end of the pattern

            Returns
            -------
            smirks: str
                this atom as a bracketed SMIRKS atom, e.g. '[#6H1X3x0!r+0A:1]';
                '[*]' (or '[*:n]') when no atom information is stored
            """
            has_index = not (self.label is None or self.label <= 0)
            suffix = ':%i' % self.label if has_index else ''

            if not self.decorators:
                return '[*%s]' % suffix

            if compress and len(self.decorators) > 1:
                body = self._compress_smirks()
            else:
                body = ','.join(sorted(''.join(decs) for decs in self.decorators))
            return '[%s%s]' % (body, suffix)

        def _sort_decs(self, dec_set, wild=True):
            """
            Parameters
            ----------
            dec_set: list like
                a single set of atom decorators
            wild: boolean
                use '*' as the leading entry when no '#n' decorator is present

            Returns
            -------
            sorted_dec_set: list
                atomic number (or '*') first, then the remaining decorators in
                sorted order, with aromaticity ('a'/'A') last
            """
            decs = list(dec_set)
            leading = [d for d in decs if '#' in d]
            if wild and not leading:
                leading = ["*"]

            remaining = set(decs) - set(leading)
            aromatic = [d for d in remaining if 'a' in d.lower()]
            remaining = remaining - set(aromatic)

            return leading + sorted(remaining) + aromatic

        def _compress_smirks(self):
            """
            Returns
            -------
            smirks: str
                OR'd (',') decorator sets holding only what differs between
                them, followed by the decorators common to every set AND'd
                (';') to the end. A shared atomic number stays in each OR'd set.
            """
            decorator_sets = [set(d) for d in self.decorators]
            shared = set.intersection(*decorator_sets)

            # keep a common atomic number at the front of every OR'd set
            atomic = [d for d in shared if '#' in d]
            if len(atomic) == 1:
                shared.remove(atomic[0])
                for decs in decorator_sets:
                    decs.add(atomic[0])

            or_parts = sorted(''.join(self._sort_decs(decs - shared))
                              for decs in decorator_sets)
            compressed = ','.join(or_parts)
            if shared:
                compressed += ';' + ';'.join(self._sort_decs(shared, wild=False))
            return compressed

        def add_atom(self, atom):
            """
            Store the decorators of another chemper Atom in this storage.

            Parameters
            ----------
            atom: chemper Atom
            """
            self.decorators.add(self.make_atom_decorators(atom))

        def compare_atom(self, atom):
            """
            Score how well an atom matches the decorator sets in this storage.

            Parameters
            ----------
            atom: chemper Atom

            Returns
            -------
            score: number
                the best score over all stored decorator sets. For a set with the
                same atomic number the score is the number of shared decorators
                (up to 7). For a set with a different atomic number it is the
                number of shared decorators divided by 10, so such sets rank
                below any matching atomic number. An empty storage matches
                anything (like a wildcard atom) and scores 7.
            """
            if not self.decorators:
                return 7

            new_decs = self.make_atom_decorators(atom)
            new_set = set(new_decs)

            best = 0
            for ref in self.decorators:
                overlap = len(set(ref) & new_set)
                # index 0 is the '#n' atomic number decorator
                if ref[0] != new_decs[0]:
                    overlap = overlap / 10.0
                best = max(best, overlap)
            return best
273

274 9
    @total_ordering
    class BondStorage:
        """
        BondStorage collects the bond orders and ring membership of every
        chemper Bond assigned to one edge of the cluster.
        """
        def __init__(self, bonds=None, label=None):
            """
            Parameters
            ----------
            bonds: chemper Bond or list of chemper Bonds (optional)
                bond(s) whose information should be stored
            label: any (optional)
                a label for this object. Bonds in SMIRKS are not indexed, so the
                label is only used to identify the storage and to break ties
                when ordering storages with identical SMIRKS
            """
            self.order = set()
            self.ring = set()
            # SMIRKS symbol for each bond order; any other order is written as '~'
            self.order_dict = {1:'-', 1.5:':', 2:'=', 3:'#'}
            if bonds is not None:
                # a single Bond (recognized by its type name) is wrapped in a list
                if 'Bond' in str(type(bonds)):
                    bonds = [bonds]
                for bond in bonds:
                    self.order.add(bond.get_order())
                    self.ring.add(bond.is_ring())

            self.label = label

        def __str__(self): return self.as_smirks()

        def __lt__(self, other):
            if self.as_smirks() == other.as_smirks():
                return self.label < other.label
            return self.as_smirks() < other.as_smirks()

        def __eq__(self, other):
            # fixed: previously called the non-existent ``other.as__smirks()``,
            # so every equality check raised AttributeError
            return self.label == other.label and self.as_smirks() == other.as_smirks()

        def __hash__(self):
            # identity hash: stored orders change as bonds are added
            return id(self)

        def as_smirks(self):
            """
            Returns
            -------
            smirks: str
                how this bond would be represented in a SMIRKS string: the stored
                bond orders OR'd together (',' separated, '~' when none are
                stored), followed by ';@' when every stored bond is in a ring or
                ';!@' when none is. No ring decorator is written when both ring
                and non-ring bonds have been stored.
            """
            if len(self.order) == 0:
                order = '~'
            else:
                order = ','.join([self.order_dict.get(o, '~') for o in sorted(self.order)])

            # the ring set holds booleans; a single value means every stored bond
            # agrees on ring membership, so that decorator is AND'd to the bond
            if len(self.ring) == 1:
                if list(self.ring)[0]:
                    return order + ';@'
                return order + ';!@'

            return order

        def add_bond(self, bond):
            """
            Expand current BondStorage by adding information about
            a new chemper Bond

            Parameters
            ----------
            bond: chemper Bond
            """
            self.order.add(bond.get_order())
            self.ring.add(bond.is_ring())

        def compare_bond(self, bond):
            """
            Score how well a bond matches the information in this storage.

            Parameters
            ----------
            bond: chemper Bond
                bond you want to compare to the current storage

            Returns
            -------
            score: int (0, 1, or 2)
                1 if the bond order is stored (or no order is stored yet), plus
                1 if every stored bond has the same ring membership as this bond
            """
            score = 0
            if bond.get_order() in self.order or len(self.order) == 0:
                score += 1

            # ring membership only counts when all stored bonds agree on it
            if len(self.ring) == 1 and list(self.ring)[0] == bond.is_ring():
                score += 1

            return score
375

376
    # Initiate ClusterGraph
377 9
    def __init__(self, mols=None, smirks_atoms_lists=None, layers=0):
378
        """
379
        Initialize a ChemPerGraph from a molecule and list of indexed atoms
380

381
        For the example, imagine we wanted to get a SMIRKS that
382
        would match the carbon-carbon bonds in ethane and propane.
383
        The carbon atoms are have indices (0,1) in ethane and (0,1) and (1,2)
384
        in propane. For this example, we will assume we also want to include
385
        the atoms one bond away from the indexed atoms (1 layer away).
386

387
        Parameters
388
        ----------
389
        mols: list of molecules (optional)
390
            these can be ChemPer Mols or molecule objects from
391
            any supported toolkit (currently OpenEye or RDKit)
392

393
        smirks_atoms_lists: list of list of tuples (optional)
394
            There is a list of tuples for each molecule, where each tuple specifies
395
            a molecular fragment using the atoms' indices.
396
            In the ethane and propane example, the `smirks_atoms_lists` would be
397
                [ [ (0,1) ], [ (0,1), (1,2) ] ]
398
            with one carbon-carbon bond in ethane and two carbon-carbon bonds in propane
399

400
        layers: int (optional, default=0)
401
            layers specifies how many bonds away from the indexed atoms should be included in the
402
            the SMIRKS patterns.
403
            Instead of an int, the string 'all' would lead to all atoms in the molecules
404
            being included in the SMIRKS (not recommended)
405
        """
406 9
        ChemPerGraph.__init__(self)
407

408 9
        self.mols = list()
409 9
        self.smirks_atoms_lists = list()
410 9
        self.layers = layers
411 9
        self._symmetry_funct = self._no_symmetry
412

413 9
        if mols is not None:
414 9
            temp_mols = [mol_toolkit.Mol(m) for m in mols]
415 9
            if len(temp_mols) != len(smirks_atoms_lists):
416 9
                raise Exception('Number of molecules and smirks dictionaries should be equal')
417

418 9
            for idx, mol in enumerate(temp_mols):
419 9
                self.add_mol(mol, smirks_atoms_lists[idx])
420

421 9
    def as_smirks(self, compress=False):
422
        """
423
        Parameters
424
        ----------
425
        compress: boolean
426
                  returns the shorter version of atom SMIRKS patterns
427
                  that is atoms have decorators "anded" to the end rather than listed
428
                  in each set that are OR'd together.
429
                  For example "[#6AH2X3x0r0+0,#6AH1X3x0r0+0:1]-;!@[#1AH0X1x0r0+0]"
430
                  compresses to: "[#6H2,#6H1;AX3x0r0+0:1]-;!@[#1AH0X1x0r0+0]"
431

432
        Returns
433
        -------
434
        SMIRKS: str
435
            a SMIRKS string matching the exact atom and bond information stored
436
        """
437 9
        return ChemPerGraph.as_smirks(self, compress)
438

439 9
    def get_symmetry_funct(self, sym_label):
440
        """
441
        Parameters
442
        ----------
443
        sym_label: str or None
444
            type of symmetry, options which will change the way symmetry is
445
            handled in the graph are "bond", "angle", "ProperTorsion", and "ImproperTorsion"
446

447
        Returns
448
        -------
449
        symmetry_funct: function
450
            returns the function that should be used to handle the appropriate symmetry
451
        """
452 9
        if sym_label is None:
453 0
            return self._no_symmetry
454 9
        if sym_label.lower() == 'bond':
455 9
            return self._bond_symmetry
456 9
        if sym_label.lower() == 'angle':
457 9
            return self._angle_symmetry
458 9
        if sym_label.lower() == 'propertorsion':
459 9
            return self._proper_torsion_symmetry
460 9
        if sym_label.lower() == 'impropertorsion':
461 0
            return self._improper_torsion_symmetry
462 9
        return self._no_symmetry
463

464 9
    def add_mol(self, input_mol, smirks_atoms_list):
465
        """
466
        Expand the information in this graph by adding a new molecule
467

468
        Parameters
469
        ----------
470
        input_mol: chemper Mol object
471
        smirks_atoms_list: list of tuples
472
            This is a list of tuples with atom indices [ (indices), ... ]
473
        """
474 9
        mol = mol_toolkit.Mol(input_mol)
475

476 9
        if len(smirks_atoms_list) == 0:
477 9
            return
478

479 9
        if len(self.mols) == 0:
480 9
            self._add_first_smirks_atoms(mol, smirks_atoms_list[0])
481 9
            self._symmetry_funct = self.get_symmetry_funct(CE(self.as_smirks()).getType())
482 9
            self._add_mol(mol, smirks_atoms_list[1:])
483
        else:
484 9
            self._add_mol(mol, smirks_atoms_list)
485

486 9
        self.mols.append(mol)
487 9
        self.smirks_atoms_lists.append(smirks_atoms_list)
488

489 9
    def _add_first_smirks_atoms(self, mol, smirks_atoms):
490
        """
491
        private function for adding the first molecule to an empty ClusterGraph
492
        add_mol calls this if the graph is empty
493

494
        Parameters
495
        ----------
496
        mol: chemper Mol
497
        smirks_atoms: tuple
498
            tuple of atom indices for the first atoms to add to the graph. i.e. (0, 1)
499
        """
500 9
        atom_dict = dict()
501 9
        for key, atom_index in enumerate(smirks_atoms, 1):
502 9
            atom_dict[atom_index] = key
503

504 9
            atom1 = mol.get_atom_by_index(atom_index)
505 9
            new_atom_storage = self.AtomStorage([atom1], key)
506 9
            self._graph.add_node(new_atom_storage)
507 9
            self.atom_by_label[key] = new_atom_storage
508

509
            # Check for bonded atoms already in the graph
510 9
            for neighbor_key in range(len(smirks_atoms), 0, -1):
511 9
                if neighbor_key not in self.atom_by_label:
512 9
                    continue
513

514
                # check if atoms are already connected on the graph
515 9
                neighbor_storage = self.atom_by_label[neighbor_key]
516 9
                if nx.has_path(self._graph, new_atom_storage, neighbor_storage):
517 9
                    continue
518

519
                # check if atoms are connected in the molecule
520 9
                atom2 = mol.get_atom_by_index(smirks_atoms[neighbor_key-1])
521 9
                bond = mol.get_bond_by_atoms(atom1, atom2)
522

523 9
                if bond is not None: # Atoms are connected add edge
524 9
                    bond_smirks = tuple(sorted([neighbor_key, key]))
525 9
                    bond_storage = self.BondStorage([bond], bond_smirks)
526 9
                    self.bond_by_label[bond_smirks] = bond_storage
527 9
                    self._graph.add_edge(new_atom_storage,
528
                                         neighbor_storage,
529
                                         bond=bond_storage)
530

531
        # for each indexed atoms add unindexed atoms for the number of specified layers
532 9
        for atom_label, atom_index in enumerate(smirks_atoms, 1):
533 9
            atom = mol.get_atom_by_index(atom_index)
534 9
            storage = self.atom_by_label[atom_label]
535 9
            self._add_layers(mol, atom, storage, self.layers, atom_dict, is_first=True)
536

537 9
    def _add_layers(self, mol, atom, storage, layers, idx_dict, is_first=False):
538
        """
539
        Parameters
540
        ----------
541
        mol: chemper Mol
542
            molecule containing provided atom
543
        atom: chemper Atom
544
        storage: AtomStorage
545
            corresponding to the chemper Atom provided
546
        layers: int or 'all'
547
            number of layers left to add (or all)
548
        idx_dict: dict
549
            form {atom index: label} for this smirks_list in this molecule
550
        """
551
        # if layers is 0 there are no more atoms to add so end the recursion
552 9
        if layers == 0:
553 9
            return
554

555
        # find atom neighbors that are not already included in SMIRKS indexed atoms
556 9
        atom_neighbors = [(a, mol.get_bond_by_atoms(a,atom)) for a in atom.get_neighbors() \
557
                          if a.get_index() not in idx_dict]
558

559
        # get the smirks indices already added to the storage
560
        # This includes all previous layers since the idx_dict is updated as you go
561 9
        storage_labels = [e for k,e in idx_dict.items()]
562

563
        # similar to atoms find neighbors already in the graph that haven't already been used
564 9
        storage_neighbors = [(s, self.get_connecting_bond(s, storage)) for s in self.get_neighbors(storage) \
565
                             if s.label not in storage_labels]
566

567 9
        new_pairs = list()
568
        # if this is the first set of atoms added, just make a new
569
        # storage for all neighboring atoms
570 9
        if is_first:
571 9
            min_smirks = storage.label * 10
572 9
            if min_smirks > 0:
573 9
                min_smirks = min_smirks * -1
574

575 9
            for a, b in atom_neighbors:
576 9
                new_bond_smirks = tuple(sorted([storage.label, min_smirks]))
577

578 9
                adding_new_storage = self.add_atom(a,b,storage,
579
                                                   min_smirks, new_bond_smirks)
580

581 9
                idx_dict[a.get_index()] = min_smirks
582 9
                self.atom_by_label[min_smirks] = adding_new_storage
583 9
                min_smirks -= 1
584 9
                new_pairs.append((a, adding_new_storage))
585

586
        else: # this isn't the first set of atoms so you need to
587
            # pair up the atoms with their storage
588 9
            pairs = self.find_pairs(atom_neighbors, storage_neighbors)
589 9
            for new_atom, new_bond, new_storage_atom, new_storage_bond in pairs:
590
                # if no storage is paired to this atom skip it
591 9
                if new_storage_atom is None:
592 9
                    continue
593
                # if there is no atom paired to a storage remove that branch
594 9
                if new_atom is None:
595 9
                    self.remove_atom(new_storage_atom)
596 9
                    continue
597
                # add atom and bond information to the storage
598 9
                new_storage_atom.add_atom(new_atom)
599 9
                new_storage_bond.add_bond(new_bond)
600 9
                new_pairs.append((new_atom, new_storage_atom))
601 9
                idx_dict[new_atom.get_index()] = new_storage_atom.label
602

603
        # Repeat for the extra layers
604 9
        if layers == 'all':
605 9
            new_layers = 'all'
606
        else:
607 9
            new_layers = layers - 1
608 9
            if new_layers == 0:
609 9
                return
610

611 9
        for new_atom, new_storage in new_pairs:
612 9
            self._add_layers(mol, new_atom, new_storage, new_layers, idx_dict, is_first)
613

614 9
    def find_pairs(self, atoms_and_bonds, storages):
615
        """
616
        Find pairs is used to determine which current AtomStorage from storages
617
        atoms should be paired with.
618
        This function takes advantage of the maximum scoring function in networkx
619
        to find the pairing with the highest "score".
620
        Scores are determined using functions in the atom and bond storage objects
621
        that compare those storages to the new atom or bond.
622

623
        If there are less atoms than storages then the atoms with the lowest pair are
624
        assigned a None pairing.
625

626
        Parameters
627
        ----------
628
        atoms_and_bonds: list of tuples in form (chemper Atom, chemper Bond, ...)
629
        storages: list of tuples in form (AtomStorage, BondStorage, ...)
630

631
        Tuples can be of any length as long as they are the same, so for example, in
632
        a bond you might only care about the outer atoms for comparison so you would compare
633
        (atom1,) and (atom2,) with (atom_storage1,) and (atom_storage2,)
634
        However, in a torsion, you might want the atoms and bonds for each outer bond
635
        so in that case you would compare
636
        (atom1, bond1, atom2) and (atom4, bond3, atom3)
637
        with the corresponding storage objects.
638

639
        Returns
640
        -------
641
        pairs: list of lists
642
            pairs of atoms and storage objects that are most similar,
643
            these lists always come in the form (all atom/bonds, all storage objects)
644
            for the bond example above you might get
645
            [ [atom1, storage1], [atom2, storage2] ]
646
            for the torsion example you might get
647
            [ [atom4, bond4, atom3, atom_storage1, bond_storage1, atom_storage2],
648
              [atom1, bond1, atom2, atom_storage4, bond_storage3, atom_storage3]
649

650
        """
651
        # store paired stets of atoms/bonds and corresponding storages
652 9
        pairs = list()
653
        # check for odd cases
654 9
        combo = atoms_and_bonds + storages
655
        # 1. both lists are empty
656 9
        if len(combo) == 0:
657 9
            return pairs
658

659 9
        nones = [None] * len(combo[0])
660
        # 2. no atom/bond storage
661 9
        if len(atoms_and_bonds) == 0:
662 9
            for storage_set in storages:
663 9
                pairs.append(nones + list(storage_set))
664 9
            return pairs
665

666
        # 3. no storages
667 9
        if len(storages) == 0:
668 9
            for atom_set in atoms_and_bonds:
669 9
                pairs.append(list(atom_set) + nones)
670 9
            return pairs
671

672 9
        g = nx.Graph()
673

674 9
        atom_dict = dict()
675 9
        storage_dict = dict()
676

677
        # create a bipartite graph with atoms/bonds on one side
678 9
        for idx, atom_set in enumerate(atoms_and_bonds):
679 9
            g.add_node(idx+1, bipartite=0)
680 9
            atom_dict[idx+1] = atom_set
681
        # and atom/bond storage objects on the other
682 9
        for idx, storage_set in enumerate(storages):
683 9
            g.add_node((idx*-1)-1, bipartite=1)
684 9
            storage_dict[(idx*-1)-1] = storage_set
685

686
        # Fill in the weight on each edge of the graph using the compare_atom/bond functions
687 9
        for a_idx, atom_set in atom_dict.items():
688 9
            for s_idx, storage_set in storage_dict.items():
689
                # sum up score for every entry in the atom and storage set
690 9
                score = 0
691 9
                for sa, a in zip(storage_set, atom_set):
692 9
                    if isinstance(sa, self.BondStorage):
693 9
                        score += sa.compare_bond(a)
694
                    else:
695 9
                        score += sa.compare_atom(a)
696
                # score can't be zero so save score+1
697 9
                g.add_edge(a_idx,s_idx,weight=score+1)
698

699
        # calculate maximum matching, that is the pairing of atoms/bonds to
700
        # storage objects that leads the the highest overall score
701 9
        matching = nx.algorithms.max_weight_matching(g,maxcardinality=False)
702
        # track the atoms assigned a paired storage object
703 9
        pair_set = set()
704

705
        # store all pairs
706 9
        for idx_1, idx_2 in matching:
707 9
            pair_set.add(idx_1)
708 9
            pair_set.add(idx_2)
709 9
            if idx_1 in atom_dict:
710 9
                atom_set = atom_dict[idx_1]
711 9
                storage_set = storage_dict[idx_2]
712
            else:
713 9
                atom_set = atom_dict[idx_2]
714 9
                storage_set = storage_dict[idx_1]
715 9
            pairs.append(list(atom_set) + list(storage_set))
716

717
        # check for missing atom storages
718 9
        for a_idx, atom_set in atom_dict.items():
719 9
            if a_idx not in pair_set:
720 9
                pairs.append(list(atom_set) + nones)
721

722
        # check for missing atoms
723 9
        for s_idx, storage_set in storage_dict.items():
724 9
            if s_idx not in pair_set:
725 9
                pairs.append(nones + list(storage_set))
726

727 9
        return pairs
728

729 9
    def _add_mol(self, mol, smirks_atoms_list):
730
        """
731
        private function for adding a new molecule
732
        This is used by add_mol if the graph is not empty, allowing the user to
733
        not have to track if the graph already has information before adding molecules
734

735
        Parameters
736
        ----------
737
        mol: any Mol
738
        smirks_atoms_list: list of dicts
739
            This is a list of dictionaries of the form [{smirks index: atom index}]
740
            each atom (by index) in the dictionary will be added the relevant
741
            AtomStorage by smirks index
742
        """
743 9
        for smirks_atoms in smirks_atoms_list:
744 9
            atom_dict = dict()
745 9
            sorted_smirks_atoms = self._symmetry_funct(mol, smirks_atoms)
746 9
            for key, atom_index in enumerate(sorted_smirks_atoms, 1):
747 9
                atom_dict[atom_index] = key
748 9
                atom1 = mol.get_atom_by_index(atom_index)
749 9
                self.atom_by_label[key].add_atom(atom1)
750

751 9
                for neighbor_key, neighbor_index in enumerate(sorted_smirks_atoms, 1):
752
                    # check for connecting bond
753 9
                    atom2 = mol.get_atom_by_index(neighbor_index)
754 9
                    bond = mol.get_bond_by_atoms(atom1, atom2)
755 9
                    if bond is not None and (neighbor_key, key) in self.bond_by_label:
756 9
                        bond_smirks = tuple(sorted([neighbor_key, key]))
757 9
                        self.bond_by_label[bond_smirks].add_bond(bond)
758

759 9
            for atom_label, atom_index in enumerate(sorted_smirks_atoms, 1):
760 9
                atom = mol.get_atom_by_index(atom_index)
761 9
                storage = self.atom_by_label[atom_label]
762 9
                self._add_layers(mol, atom, storage, self.layers, atom_dict)
763

764 9
    def _no_symmetry(self, mol, smirks_atoms):
765 9
        return smirks_atoms
766

767 9
    def _bond_symmetry(self, mol, smirks_atoms):
768
        """
769
        Returns a tuple of two atom indices in the order that
770
        leads to the atoms that match with previously stored atoms.
771
        """
772
        # pair atoms and bonds
773 9
        atom1 = mol.get_atom_by_index(smirks_atoms[0])
774 9
        atom2 = mol.get_atom_by_index(smirks_atoms[1])
775
        # Find potential storages for those atoms and bonds
776 9
        atoms_and_bonds = [(atom1,), (atom2,)]
777 9
        storages = [
778
            (self.atom_by_label[1],),
779
            (self.atom_by_label[2],)
780
        ]
781 9
        pairs = self.find_pairs(atoms_and_bonds, storages)
782 9
        ordered_smirks_atoms = [p[0].get_index() for p in sorted(pairs, key=lambda x: x[1].label)]
783 9
        return tuple(ordered_smirks_atoms)
784

785 9
    def _angle_symmetry(self, mol, smirks_atoms):
786
        """
787
        Returns a tuple of three atom indices in the order that
788
        leads to the atoms that match with previously stored atoms.
789
        """
790
        # get all three atoms
791 9
        atom1 = mol.get_atom_by_index(smirks_atoms[0])
792 9
        atom2 = mol.get_atom_by_index(smirks_atoms[1])
793 9
        atom3 = mol.get_atom_by_index(smirks_atoms[2])
794
        # get both bonds
795 9
        bond1 = mol.get_bond_by_atoms(atom1, atom2)
796 9
        bond2 = mol.get_bond_by_atoms(atom2, atom3)
797 9
        if None in (bond1, bond2):
798 0
            return smirks_atoms
799
        # save atom and bond pairs that could be reordered
800 9
        atoms_and_bonds = [(atom1, bond1), (atom3, bond2)]
801
        # find current atom and bond storage
802 9
        storages = [
803
            (self.atom_by_label[1], self.bond_by_label[(1,2)]),
804
            (self.atom_by_label[3], self.bond_by_label[(2,3)])
805
        ]
806 9
        pairs = self.find_pairs(atoms_and_bonds, storages)
807 9
        order = [p[0].get_index() for p in sorted(pairs, key=lambda x: x[2].label)]
808 9
        return tuple((order[0], smirks_atoms[1], order[1]))
809

810 9
    def _proper_torsion_symmetry(self, mol, smirks_atoms):
811
        """
812
        Returns a tuple of four atom indices for a proper torsion
813
        reordered to match with previously stored atoms.
814
        """
815
        # get all four atoms
816 9
        atom1 = mol.get_atom_by_index(smirks_atoms[0])
817 9
        atom2 = mol.get_atom_by_index(smirks_atoms[1])
818 9
        atom3 = mol.get_atom_by_index(smirks_atoms[2])
819 9
        atom4 = mol.get_atom_by_index(smirks_atoms[3])
820
        # get two relevant bonds
821 9
        bond1 = mol.get_bond_by_atoms(atom1, atom2)
822 9
        bond3 = mol.get_bond_by_atoms(atom3, atom4)
823 9
        if None in (bond1, bond3):
824 0
            return smirks_atoms
825
        # make pairs
826 9
        atoms_and_bonds = [ (atom2, bond1, atom1), (atom3, bond3, atom4) ]
827
        # get atom and bond storages
828 9
        storages = [
829
            (self.atom_by_label[2], self.bond_by_label[(1,2)], self.atom_by_label[1]),
830
            (self.atom_by_label[3], self.bond_by_label[(3,4)], self.atom_by_label[4])
831
        ]
832 9
        pairs = self.find_pairs(atoms_and_bonds, storages)
833 9
        order = [p[0].get_index() for p in sorted(pairs, key=lambda x: x[3].label)]
834 9
        if order[0] == smirks_atoms[1]:
835 9
            return smirks_atoms
836 9
        temp = list(smirks_atoms)
837 9
        temp.reverse()
838 9
        return tuple(temp)
839

840 9
    def _improper_torsion_symmetry(self, mol, smirks_atoms):
841
        """
842
        Returns a tuple of four atom indices for an improper torsion
843
        reordered to match with previously stored atoms.
844
        """
845
        # get all four atoms
846 0
        atom1 = mol.get_atom_by_index(smirks_atoms[0])
847 0
        atom2 = mol.get_atom_by_index(smirks_atoms[1])
848 0
        atom3 = mol.get_atom_by_index(smirks_atoms[2])
849 0
        atom4 = mol.get_atom_by_index(smirks_atoms[3])
850
        # get all three bonds
851 0
        bond1 = mol.get_bond_by_atoms(atom1, atom2)
852 0
        bond2 = mol.get_bond_by_atoms(atom2, atom3)
853 0
        bond3 = mol.get_bond_by_atoms(atom2, atom4)
854 0
        if None in (bond1, bond2, bond3):
855 0
            return smirks_atoms
856
        # make pairs of atoms and bonds to be reordered
857 0
        atoms_and_bonds = [
858
            (atom1, bond1), (atom3, bond2), (atom4, bond3)
859
        ]
860
        # find current atom and bond storages
861 0
        storages = [
862
            (self.atom_by_label[1], self.bond_by_label[(1,2)]),
863
            (self.atom_by_label[3], self.bond_by_label[(2,3)]),
864
            (self.atom_by_label[4], self.bond_by_label[(2,4)])
865
        ]
866 0
        pairs = self.find_pairs(atoms_and_bonds, storages)
867 0
        order = [p[0].get_index() for p in sorted(pairs, key=lambda x: x[2].label)]
868 0
        return tuple((order[0], smirks_atoms[1], order[1], order[2]))

Read our documentation on viewing source code.

Loading