MobleyLab / chemper
1
"""
2
cluster_graph.py
3

4
ClusterGraph is a class for tracking all possible smirks decorators in a group
5
(or cluster) of molecular fragments. Moving forward these will be used to
6
find the minimum number of smirks decorators that are required to have a
7
set of smirks patterns that maintain a given clustering of fragments.
8
"""
9

10 151
import networkx as nx
11 151
from functools import total_ordering
12 151
from chemper.graphs.single_graph import SingleGraph
13 151
from chemper.graphs.environment import ChemicalEnvironment as CE
14 151
from chemper.mol_toolkits import mol_toolkit
15

16

17 151
@total_ordering
18 151
class ClusterGraph(SingleGraph):
19
    """
20
    ChemPerGraphs are a graph based class for storing atom and bond information.
21
    They use the chemper.mol_toolkits Atoms, Bonds, and Mols
22
    """
23 151
    @total_ordering
    class AtomStorage:
        """
        AtomStorage collects the SMIRKS decorators seen on one or more atoms.

        Every atom added contributes one tuple of decorator strings to the
        ``decorators`` set; ``label`` is the SMIRKS index used when the
        storage is written out as a SMIRKS atom.
        """
        def __init__(self, atoms=None, label=None):
            """
            Parameters
            ----------
            atoms : ChemPer Atom or list of ChemPer Atoms
                one atom, or an iterable of atoms, whose decorators are stored
            label : int
                SMIRKS index (:n) used when writing SMIRKS.
                Values that are None, zero or negative are kept for
                bookkeeping only and are never written into the SMIRKS.
            """
            self.decorators = set()
            if atoms is not None:
                # a lone atom is recognized by its type name and wrapped in
                # a list, anything else is assumed to be iterable
                atom_list = [atoms] if 'Atom' in str(type(atoms)) else atoms
                for single_atom in atom_list:
                    self.decorators.add(self.make_atom_decorators(single_atom))
            self.label = label

        def __lt__(self, other):
            """
            Ordering used to make SMIRKS output predictable and readable.

            Storages with a positive SMIRKS index are ordered by that index,
            which keeps indexed atoms in a "line" such as
            [C:1]([H])([H])~[N:2]([C])~[O:3]
            rather than the equivalent but heavily branched
            [C:1]([N:2]([O:3])[C])([H])[H]

            When neither storage has a positive index the SMIRKS strings
            themselves are compared.

            Parameters
            ----------
            other : AtomStorage

            Returns
            -------
            is_less_than : boolean
                self is less than other
            """
            # a missing label is treated as a large negative index
            own_index = -1000 if self.label is None else self.label
            their_index = -1000 if other.label is None else other.label

            if not (own_index > 0 or their_index > 0):
                # no positive index on either side, fall back to the patterns
                return self.as_smirks() < other.as_smirks()

            # at least one positive index: the larger index sorts last
            return own_index < their_index

        def __eq__(self, other):
            if self.as_smirks() != other.as_smirks():
                return False
            return self.label == other.label

        def __hash__(self):
            return id(self)

        def __str__(self):
            return self.as_smirks()

        def make_atom_decorators(self, atom):
            """
            Build the decorator strings describing a ChemPer Atom

            Parameters
            ----------
            atom : ChemPer atom object

            Returns
            -------
            decorators : tuple of str
                (atomic number, hydrogen count, connectivity,
                ring connectivity, ring size, charge, aromaticity)
            """
            aromatic = 'a' if atom.is_aromatic() else 'A'

            # non-negative charges carry an explicit plus sign
            charge_value = atom.formal_charge()
            charge_format = '+%i' if charge_value >= 0 else '%i'
            charge = charge_format % charge_value

            # a minimum ring size of zero means the atom is not in a ring
            smallest_ring = atom.min_ring_size()
            ring = '!r' if smallest_ring == 0 else 'r%i' % smallest_ring

            atomic_number = '#%i' % atom.atomic_number()
            hydrogens = 'H%i' % atom.hydrogen_count()
            connectivity = 'X%i' % atom.connectivity()
            ring_connectivity = 'x%i' % atom.ring_connectivity()

            return (atomic_number, hydrogens, connectivity,
                    ring_connectivity, ring, charge, aromatic)

        def as_smirks(self, compress=False):
            """
            Parameters
            ----------
            compress : boolean
                when True (and more than one decorator set is stored) the
                decorators shared by every set are written once, and'd to
                the end, for example '#6X4,#7X3;+0!r...'

            Returns
            -------
            smirks : str
                SMIRKS atom for this storage; '[*]' (or '[*:n]') when no
                decorators are stored
            """
            unlabeled = self.label is None or self.label <= 0

            if len(self.decorators) == 0:
                pattern = '*'
            elif compress and len(self.decorators) > 1:
                pattern = self._compress_smirks()
            else:
                pattern = ','.join(sorted(''.join(dec) for dec in self.decorators))

            if unlabeled:
                return '[%s]' % pattern
            return '[%s:%i]' % (pattern, self.label)

        def _sort_decs(self, dec_set, wild=True):
            """
            Parameters
            ----------
            dec_set : list like
                single set of atom decorators
            wild : boolean
                put a * first when the set has no #n decorator

            Returns
            -------
            sorted_dec_set : list
                atomic number (or *) first, aromaticity last, and the
                remaining decorators sorted in between
            """
            remaining = list(dec_set)
            atom_num = [dec for dec in remaining if '#' in dec]
            if wild and len(atom_num) == 0:
                atom_num = ["*"]

            remaining = set(remaining) - set(atom_num)

            aro = [dec for dec in remaining if 'a' in dec.lower()]
            remaining = remaining - set(aro)

            return atom_num + sorted(remaining) + aro

        def _compress_smirks(self):
            """
            Returns
            -------
            smirks : str
                OR'd decorator sets followed by the decorators common to
                every set, which are and'd (;) to the end of the pattern
            """
            set_decs = [set(dec) for dec in self.decorators]

            # decorators present in every stored set
            common = set(set_decs[0])
            for dec_set in set_decs[1:]:
                common &= dec_set

            # a shared atomic number stays with each OR'd set instead of
            # being moved to the and'd tail
            atomic = [dec for dec in common if '#' in dec]
            if len(atomic) == 1:
                common.remove(atomic[0])
                for dec_set in set_decs:
                    dec_set.add(atomic[0])

            ors = sorted(''.join(self._sort_decs(dec_set.difference(common)))
                         for dec_set in set_decs)
            base = ','.join(ors)

            if len(common) > 0:
                base += ';' + ';'.join(self._sort_decs(common, wild=False))
            return base

        def add_atom(self, atom):
            """
            Store the decorators of one more ChemPer Atom

            Parameters
            ----------
            atom : ChemPer Atom
            """
            self.decorators.add(self.make_atom_decorators(atom))

        def compare_atom(self, atom):
            """
            Score how closely a ChemPer atom matches the stored decorators.

            Each stored decorator set is compared with the atom's decorators
            and the best score is returned. The score of one set is the
            number of decorators it shares with the atom; when the atomic
            numbers differ that count is divided by 10.

            For example, with the stored sets
                - #7H1X3x0!r+0A
                - #6H1X4x0!r+0A
            and an input atom with decorators
                - #6H1X3x2r6+0a
            the first set shares (H1, X3, +0) but has a different atomic
            number, giving 0.3; the second shares (#6, H1, +0), giving 3.
            The returned score is 3.

            Parameters
            ----------
            atom : ChemPer Atom

            Returns
            -------
            score : float
                between 0 and 7 (the number of decorators on an atom).
                An empty storage returns 7 since any atom matches a
                wildcard atom.
            """
            # nothing stored yet: behaves like a wildcard, return the maximum
            if len(self.decorators) == 0:
                return 7

            atom_decs = self.make_atom_decorators(atom)
            atom_dec_set = set(atom_decs)

            best = 0
            for ref in self.decorators:
                shared = len(set(ref) & atom_dec_set)

                # the first decorator is the atomic number; a mismatch
                # there down-weights the set but still ranks closer
                # matches above poorer ones
                if ref[0] != atom_decs[0]:
                    shared = shared / 10.0

                if shared > best:
                    best = shared

            return best
291

292 151
    @total_ordering
    class BondStorage:
        """
        BondStorage tracks information about a bond

        It records the set of bond orders and the set of ring memberships
        (booleans) of every bond added to it, and can write that
        information as a SMIRKS bond.
        """
        def __init__(self, bonds=None, label=None):
            """
            Parameters
            ----------
            bonds : ChemPer Bond or list of ChemPer Bonds
                this is one or more bonds whose information should be stored
            label : a label for the object, it can be anything
                unlike atoms, bonds in smirks don't have labels
                so this is only used for labeling the object if wanted
            """
            self.order = set()
            self.ring = set()
            # SMIRKS symbol for each supported bond order
            self.order_dict = {1:'-', 1.5:':', 2:'=', 3:'#'}
            if bonds is not None:
                # a lone bond is recognized by its type name and wrapped
                # in a list, anything else is assumed to be iterable
                if 'Bond' in str(type(bonds)):
                    bonds = [bonds]
                for bond in bonds:
                    self.add_bond(bond)

            self.label = label

        def __str__(self):
            return self.as_smirks()

        def __lt__(self, other):
            """
            Order by SMIRKS string first and by label when the SMIRKS agree.
            """
            own_smirks = self.as_smirks()
            other_smirks = other.as_smirks()
            if own_smirks == other_smirks:
                return self.label < other.label
            return own_smirks < other_smirks

        def __eq__(self, other):
            """
            Two storages are equal when both label and SMIRKS string agree.
            """
            # the original called the non-existent `as__smirks`, which raised
            # AttributeError whenever the labels matched
            return self.label == other.label and self.as_smirks() == other.as_smirks()

        def __hash__(self):
            # identity hash, so distinct storages stay distinct in sets/dicts
            return id(self)

        def as_smirks(self):
            """
            Returns
            -------
            smirks : str
                how this bond would be represented in a SMIRKS string:
                the stored bond orders OR'd together ('~' if none are
                stored, or for an order with no known symbol), followed by
                ';@' or ';!@' when every stored bond agrees on ring
                membership
            """
            if len(self.order) == 0:
                order = '~'
            else:
                order = ','.join([self.order_dict.get(o, '~') for o in sorted(self.order)])

            # the ring set holds booleans; exactly one entry means only ring
            # (@) or only non-ring (!@) bonds were added, so that decorator
            # is and'd to the end of the bond
            if len(self.ring) == 1:
                if next(iter(self.ring)):
                    return order + ';@'
                return order + ';!@'

            return order

        def add_bond(self, bond):
            """
            Expand current BondStorage by adding information about
            a new ChemPer Bond

            Parameters
            ----------
            bond : ChemPer Bond
            """
            self.order.add(bond.get_order())
            self.ring.add(bond.is_ring())

        def compare_bond(self, bond):
            """
            Score how similar a bond is to the bonds already stored.

            Parameters
            ----------
            bond : ChemPer Bond
                bond you want to compare to the current storage

            Returns
            -------
            score : int (0,1,2)
                1 if the bond order is already stored (or no order is
                stored at all) plus
                1 if all stored bonds share one ring membership and the
                input bond has that same membership
            """
            score = 0
            if bond.get_order() in self.order or len(self.order) == 0:
                score += 1

            # the ring set holds booleans; it only contributes when a single
            # value is stored, i.e. all stored bonds are ring or all non-ring
            if len(self.ring) == 1 and next(iter(self.ring)) == bond.is_ring():
                score += 1

            return score
393

394
    # Initiate ClusterGraph
395 151
    def __init__(self, mols=None, smirks_atoms_lists=None, layers=0):
        """
        Initialize a ClusterGraph from molecules and lists of indexed atoms

        As an example, imagine we wanted a SMIRKS that matches the
        carbon-carbon bonds in ethane and propane. The carbon atoms have
        indices (0,1) in ethane and (0,1) and (1,2) in propane. If the atoms
        one bond away from the indexed atoms should also be included,
        layers would be 1.

        Parameters
        ----------
        mols : list of molecules (optional)
            default = None (makes an empty graph)
            these can be ChemPer Mols or molecule objects from
            any supported toolkit (currently OpenEye or RDKit)

        smirks_atoms_lists : list of list of tuples (optional)
            default = None (must be paired with mols=None)
            One list of tuples per molecule, each tuple giving the atom
            indices of one molecular fragment.
            For the ethane and propane example this would be
                [ [ (0,1) ], [ (0,1), (1,2) ] ]

        layers : int (optional)
            default = 0
            how many bonds away from the indexed atoms should be included
            in the SMIRKS patterns.
            The string 'all' instead of an int includes every atom in the
            molecules (not recommended)
        """
        SingleGraph.__init__(self)

        self.layers = layers
        self.mols = []
        self.smirks_atoms_lists = []
        # until the first fragment is added no symmetry is assumed
        self._symmetry_funct = self._no_symmetry

        # nothing to add: leave the graph empty
        if mols is None:
            return

        converted_mols = [mol_toolkit.Mol(m) for m in mols]
        if len(converted_mols) != len(smirks_atoms_lists):
            raise Exception('Number of molecules and smirks dictionaries should be equal')

        for position, converted_mol in enumerate(converted_mols):
            self.add_mol(converted_mol, smirks_atoms_lists[position])
441

442 151
    def as_smirks(self, compress=False):
        """
        Write the graph as a SMIRKS string.

        Parameters
        ----------
        compress : boolean
            request the shorter form of the atom SMIRKS, where decorators
            shared by all OR'd sets are "anded" to the end of the atom.
            For example "[#6AH2X3x0!r+0,#6AH1X3x0!r+0:1]-;!@[#1AH0X1x0!r+0]"
            compresses to: "[#6H2,#6H1;AX3x0!r+0:1]-;!@[#1AH0X1x0!r+0]"

        Returns
        -------
        SMIRKS : str
            a SMIRKS string matching the exact atom and bond information stored
        """
        # Only the atom compression differs from the parent class, so the
        # work is delegated to SingleGraph
        parent_as_smirks = SingleGraph.as_smirks
        return parent_as_smirks(self, compress)
461

462 151
    def get_symmetry_funct(self, sym_label):
        """
        Pick the symmetry function used when adding atoms to this graph.

        For example, imagine a user wants a SMIRKS for all of the C-H bonds
        in methane, where the carbon has index 0 and the hydrogens 1,2,3,4.
        The final SMIRKS should have the form
        [#6AH4X4x0!r+0:1]-;!@[#1AH0X1x0!r+0]
        regardless of the order the atoms are given in. So for
        (0,1), (0,2), (3,0), (4,0) the carbons in (3,0) and (4,0) should
        end up in the :1 position like in the first set of atoms.

        Bond atoms in (1,2) or (2,1) are symmetric, for angles it is
        (1,2,3) or (3,2,1), for proper torsions (1,2,3,4) or (4,3,2,1) and
        for improper torsions (1,2,3,4), (3,2,1,4), (4,2,1,3).
        For any other fragment type the atoms are added in the order they
        are provided since the symmetry function is unknown.

        # TODO: In theory you could generalize this for generic linear fragments
        # where those with an odd number of atoms behave like angles and an
        # even number behave like proper torsions, however I think that is
        # going to be outside the scope of ChemPer for the foreseeable future.

        Parameters
        ----------
        sym_label : str or None
            type of symmetry (case insensitive); the labels that change how
            symmetry is handled are "bond", "angle", "ProperTorsion",
            and "ImproperTorsion"

        Returns
        -------
        symmetry_funct : function
            the function handling the requested symmetry, or the
            no-symmetry function for None and unrecognized labels
        """
        if sym_label is None:
            return self._no_symmetry

        # lower-cased label -> name of the method handling that symmetry;
        # the attribute is only looked up for the label actually requested
        handler_names = {
            'bond': '_bond_symmetry',
            'angle': '_angle_symmetry',
            'propertorsion': '_proper_torsion_symmetry',
            'impropertorsion': '_improper_torsion_symmetry',
        }
        chosen = handler_names.get(sym_label.lower(), '_no_symmetry')
        return getattr(self, chosen)
509

510 151
    def add_mol(self, input_mol, smirks_atoms_list):
        """
        Add the fragments of one more molecule to this graph

        Parameters
        ----------
        input_mol : ChemPer Mol
        smirks_atoms_list : list of tuples
            This is a list of tuples with atom indices [ (indices), ... ]
            An empty list leaves the graph unchanged.
        """
        mol = mol_toolkit.Mol(input_mol)

        if len(smirks_atoms_list) == 0:
            return

        fragments_to_merge = smirks_atoms_list
        if len(self.mols) == 0:
            # the first fragment of the first molecule builds the graph and
            # determines which symmetry function is used from then on
            self._add_first_smirks_atoms(mol, smirks_atoms_list[0])
            fragment_type = CE(self.as_smirks()).get_type()
            self._symmetry_funct = self.get_symmetry_funct(fragment_type)
            fragments_to_merge = smirks_atoms_list[1:]

        self._add_mol(mol, fragments_to_merge)

        self.mols.append(mol)
        self.smirks_atoms_lists.append(smirks_atoms_list)
534

535 151
    def _add_first_smirks_atoms(self, mol, smirks_atoms):
        """
        private function that builds the graph from the first fragment,
        add_mol calls this if the graph is empty

        Parameters
        ----------
        mol : ChemPer Mol
        smirks_atoms : tuple
            tuple of atom indices for the first atoms to add to the graph. i.e. (0, 1)
        """
        # atom index in the molecule -> SMIRKS label (1, 2, ...)
        index_to_label = {}
        for label, atom_index in enumerate(smirks_atoms, 1):
            index_to_label[atom_index] = label

            current_atom = mol.get_atom_by_index(atom_index)
            current_storage = self.AtomStorage([current_atom], label)
            self._graph.add_node(current_storage)
            self.atom_by_label[label] = current_storage

            # look, from the highest label down, for indexed atoms already
            # in the graph that this atom should be bonded to
            for other_label in range(len(smirks_atoms), 0, -1):
                if other_label in self.atom_by_label:
                    other_storage = self.atom_by_label[other_label]

                    # only consider storages not yet connected on the graph
                    if not nx.has_path(self._graph, current_storage, other_storage):
                        other_atom = mol.get_atom_by_index(smirks_atoms[other_label-1])
                        bond = mol.get_bond_by_atoms(current_atom, other_atom)

                        # a bond in the molecule becomes an edge in the graph
                        if bond is not None:
                            bond_label = tuple(sorted([other_label, label]))
                            bond_storage = self.BondStorage([bond], bond_label)
                            self.bond_by_label[bond_label] = bond_storage
                            self._graph.add_edge(current_storage,
                                                 other_storage,
                                                 bond=bond_storage)

        # grow the requested number of layers of unindexed atoms around
        # every indexed atom
        for label, atom_index in enumerate(smirks_atoms, 1):
            indexed_atom = mol.get_atom_by_index(atom_index)
            indexed_storage = self.atom_by_label[label]
            self._add_layers(mol, indexed_atom, indexed_storage, self.layers,
                             index_to_label, is_first=True)
582

583 151
    def _add_layers(self, mol, atom, storage, layers, idx_dict, is_first=False):
        """
        Recursively add the atoms surrounding `atom` to the graph.

        Parameters
        ----------
        mol : ChemPer Mol
            molecule containing provided atom
        atom : ChemPer Atom
        storage: AtomStorage
            corresponding to the ChemPer Atom provided
        layers : int or 'all'
            number of layers left to add (or all)
        idx_dict : dict
            form {atom index: label} for this smirks_list in this molecule,
            it is updated in place as atoms are placed
        is_first : boolean
            True when this is the first fragment added to the graph, in
            which case a new storage is created for every neighboring atom
            instead of pairing atoms with existing storages
        """
        # no layers left means the recursion is finished
        if layers == 0:
            return

        # neighbors of the atom, with their connecting bond, that have not
        # been given a label yet
        unplaced_atoms = []
        for neighbor in atom.get_neighbors():
            if neighbor.get_index() not in idx_dict:
                unplaced_atoms.append((neighbor, mol.get_bond_by_atoms(neighbor, atom)))

        # labels already used for this fragment, including earlier layers
        # since idx_dict is updated as the recursion proceeds
        used_labels = list(idx_dict.values())

        # neighbors of the storage, with their connecting bond storage,
        # whose label has not been used yet
        unplaced_storages = []
        for neighbor_storage in self.get_neighbors(storage):
            if neighbor_storage.label not in used_labels:
                unplaced_storages.append(
                    (neighbor_storage, self.get_connecting_bond(neighbor_storage, storage)))

        next_round = []
        if is_first:
            # first fragment: every neighboring atom gets a brand new
            # storage with a negative label derived from the parent label
            next_label = storage.label * 10
            if next_label > 0:
                next_label = -next_label

            for neighbor, neighbor_bond in unplaced_atoms:
                bond_label = tuple(sorted([storage.label, next_label]))

                created_storage = self.add_atom(neighbor, neighbor_bond, storage,
                                                next_label, bond_label)

                idx_dict[neighbor.get_index()] = next_label
                self.atom_by_label[next_label] = created_storage
                next_label -= 1
                next_round.append((neighbor, created_storage))

        else:
            # later fragments: match each atom with the most similar storage
            matches = self.find_pairs(unplaced_atoms, unplaced_storages)
            for matched_atom, matched_bond, atom_storage, bond_storage in matches:
                # an atom without a storage is ignored
                if atom_storage is None:
                    continue
                # a storage without an atom means that branch is removed
                if matched_atom is None:
                    self.remove_atom(atom_storage)
                    continue
                # merge the atom and bond information into the storages
                atom_storage.add_atom(matched_atom)
                bond_storage.add_bond(matched_bond)
                next_round.append((matched_atom, atom_storage))
                idx_dict[matched_atom.get_index()] = atom_storage.label

        # continue outwards; 'all' keeps going until no atoms are left
        remaining_layers = layers if layers == 'all' else layers - 1
        if remaining_layers == 0:
            return

        for placed_atom, placed_storage in next_round:
            self._add_layers(mol, placed_atom, placed_storage, remaining_layers,
                             idx_dict, is_first)
659

660 151
    def find_pairs(self, atoms_and_bonds, storages):
661
        """
662
        Find pairs is used to determine which current AtomStorage from storages
663
        atoms should be paired with.
664
        This function takes advantage of the maximum scoring function in networkx
665
        to find the pairing with the highest "score".
666
        Scores are determined using functions in the atom and bond storage objects
667
        that compare those storages to the new atom or bond.
668

669
        If there are less atoms than storages then the atoms with the lowest pair are
670
        assigned a None pairing.
671

672
        Parameters
673
        ----------
674
        atoms_and_bonds : list of tuples in form (ChemPer Atom, ChemPer Bond, ...)
675
        storages: list of tuples in form (AtomStorage, BondStorage, ...)
676

677
        Tuples can be of any length as long as they are the same, so for example, in
678
        a bond you might only care about the outer atoms for comparison so you would compare
679
        (atom1,) and (atom2,) with (atom_storage1,) and (atom_storage2,)
680
        However, in a torsion, you might want the atoms and bonds for each outer bond
681
        so in that case you would compare
682
        (atom1, bond1, atom2) and (atom4, bond3, atom3)
683
        with the corresponding storage objects.
684

685
        Returns
686
        -------
687
        pairs : list of lists
688
            pairs of atoms and storage objects that are most similar,
689
            these lists always come in the form (all atom/bonds, all storage objects)
690
            for the bond example above you might get
691
            [ [atom1, storage1], [atom2, storage2] ]
692
            for the torsion example you might get
693
            [ [atom4, bond4, atom3, atom_storage1, bond_storage1, atom_storage2],
694
              [atom1, bond1, atom2, atom_storage4, bond_storage3, atom_storage3]
695

696
        """
697
        # store paired stets of atoms/bonds and corresponding storages
698 151
        pairs = list()
699
        # check for odd cases
700 151
        combo = atoms_and_bonds + storages
701
        # 1. both lists are empty
702 151
        if len(combo) == 0:
703 151
            return pairs
704

705 151
        nones = [None] * len(combo[0])
706
        # 2. no atoms/bonds
707 151
        if len(atoms_and_bonds) == 0:
708 151
            for storage_set in storages:
709 151
                pairs.append(nones + list(storage_set))
710 151
            return pairs
711

712
        # 3. no storages
713 151
        if len(storages) == 0:
714 151
            for atom_set in atoms_and_bonds:
715 151
                pairs.append(list(atom_set) + nones)
716 151
            return pairs
717

718 151
        g = nx.Graph()
719

720 151
        atom_dict = dict()
721 151
        storage_dict = dict()
722

723
        # create a bipartite graph with atoms/bonds on one side
724 151
        for idx, atom_set in enumerate(atoms_and_bonds):
725 151
            g.add_node(idx+1, bipartite=0)
726 151
            atom_dict[idx+1] = atom_set
727
        # and atom/bond storage objects on the other
728 151
        for idx, storage_set in enumerate(storages):
729 151
            g.add_node((idx*-1)-1, bipartite=1)
730 151
            storage_dict[(idx*-1)-1] = storage_set
731

732
        # Fill in the weight on each edge of the graph using the compare_atom/bond functions
733 151
        for a_idx, atom_set in atom_dict.items():
734 151
            for s_idx, storage_set in storage_dict.items():
735
                # sum up score for every entry in the atom and storage set
736 151
                score = 0
737 151
                for sa, a in zip(storage_set, atom_set):
738 151
                    if isinstance(sa, self.BondStorage):
739 151
                        score += sa.compare_bond(a)
740
                    else:
741 151
                        score += sa.compare_atom(a)
742
                # score can't be zero so save score+1
743 151
                g.add_edge(a_idx,s_idx,weight=score+1)
744

745
        # calculate maximum matching, that is the pairing of atoms/bonds to
746
        # storage objects that leads to the highest overall score
747 151
        matching = nx.algorithms.max_weight_matching(g,maxcardinality=False)
748
        # track the atoms assigned a paired storage object
749 151
        pair_set = set()
750

751
        # store all pairs
752 151
        for idx_1, idx_2 in matching:
753 151
            pair_set.add(idx_1)
754 151
            pair_set.add(idx_2)
755 151
            if idx_1 in atom_dict:
756 151
                atom_set = atom_dict[idx_1]
757 151
                storage_set = storage_dict[idx_2]
758
            else:
759 151
                atom_set = atom_dict[idx_2]
760 151
                storage_set = storage_dict[idx_1]
761 151
            pairs.append(list(atom_set) + list(storage_set))
762

763
        # check for missing atom storages
764 151
        for a_idx, atom_set in atom_dict.items():
765 151
            if a_idx not in pair_set:
766 151
                pairs.append(list(atom_set) + nones)
767

768
        # check for missing atoms
769 151
        for s_idx, storage_set in storage_dict.items():
770 151
            if s_idx not in pair_set:
771 151
                pairs.append(nones + list(storage_set))
772

773 151
        return pairs
774

775 151
    def _add_mol(self, mol, smirks_atoms_list):
776
        """
777
        private function for adding a new molecule
778
        This is used by add_mol if the graph is not empty, allowing the user to
779
        not have to track if the graph already has information before adding molecules
780

781
        Parameters
782
        ----------
783
        mol : any Mol
784
        smirks_atoms_list : list of dicts
785
            This is a list of dictionaries of the form [{smirks index: atom index}]
786
            each atom (by index) in the dictionary will be added the relevant
787
            AtomStorage by smirks index
788
        """
789 151
        for smirks_atoms in smirks_atoms_list:
790 151
            atom_dict = dict()
791 151
            sorted_smirks_atoms = self._symmetry_funct(mol, smirks_atoms)
792 151
            for key, atom_index in enumerate(sorted_smirks_atoms, 1):
793 151
                atom_dict[atom_index] = key
794 151
                atom1 = mol.get_atom_by_index(atom_index)
795 151
                self.atom_by_label[key].add_atom(atom1)
796

797 151
                for neighbor_key, neighbor_index in enumerate(sorted_smirks_atoms, 1):
798
                    # check for connecting bond
799 151
                    atom2 = mol.get_atom_by_index(neighbor_index)