MobleyLab / chemper
1
"""
2
fragment_graph.py
3

4
ChemPerGraph is a class for storing smirks decorators for a molecular fragment.
5
These can be used to convert a molecular sub-graph or an entire molecule into a SMIRKS
6
pattern with all decorators specified.
7

8
For example, imagine you want a SMIRKS for the carbon in methane, it would become:
9

10
"[#6AH4X4x0!r+0:1]"
11

12
with decorators:
13
#6: atomic number 6 for carbon
14
A: aliphatic (a would be aromatic)
15
H4: a total hydrogen count of 4, 4 neighbors are hydrogen
16
X4: connectivity of 4, that is number of neighbors, not valence or sum of bond orders
17
x0: ring connectivity of 0, no ring bonds
18
!r: not in a ring, for atoms in a ring this decorator is `rn` where n is the size of the smallest ring
19
+0: 0 formal charge
20

21
To the best of the authors' knowledge, this is the first open source tool capable
of converting a molecule (or sub-graph) into a detailed SMIRKS pattern.
23

24
AUTHORS:
25

26
Caitlin C. Bannan <bannanc@uci.edu>, Mobley Group, University of California Irvine
27
"""
28

29 2
import networkx as nx
30 2
from functools import total_ordering
31 2
from chemper.mol_toolkits import mol_toolkit
32

33

34 2
@total_ordering
class ChemPerGraph(object):
    """
    ChemPerGraphs are a graph based class for storing atom and bond information.
    They use the chemper.mol_toolkits Atoms, Bonds, and Mols.

    Atoms are stored as AtomStorage nodes of an underlying networkx Graph and
    bonds are stored as BondStorage objects under the ``bond`` key of each edge.
    """
    @total_ordering
    class AtomStorage(object):
        """
        AtomStorage tracks information about an atom
        """
        def __init__(self, atom=None, label=None):
            """
            Initializes AtomStorage based on a provided atom

            Parameters
            ----------
            atom: chemper Atom object
                if None every decorator is None and the atom is written as a wildcard [*]
            label: int
                integer for labeling this atom in a SMIRKS
                or if negative number just used to track the atom locally
            """
            self.atom = atom

            if atom is None:
                self.atomic_number = None
                self.aromatic = None
                self.charge = None
                self.hydrogen_count = None
                self.connectivity = None
                self.ring_connectivity = None
                self.min_ring_size = None
                self.atom_index = None

            else:
                self.atomic_number = atom.atomic_number()
                self.aromatic = atom.is_aromatic()
                self.charge = atom.formal_charge()
                self.hydrogen_count = atom.hydrogen_count()
                self.connectivity = atom.connectivity()
                self.ring_connectivity = atom.ring_connectivity()
                self.min_ring_size = atom.min_ring_size()
                self.atom_index = atom.get_index()

            self.label = label

        def __lt__(self, other):
            """
            Overrides the default implementation
            This method was primarily written for making SMIRKS patterns predictable.
            If atoms are sortable, then the SMIRKS patterns are always the same making
            tests easier to write. However, the specific sorting was created to also make SMIRKS
            output as human readable as possible, that is to at least make it easier for a
            human to see how the indexed atoms are related to each other.
            It is typically easier for humans to read SMILES/SMARTS/SMIRKS with less branching (indicated with ()).

            For example in:
            [C:1]([H])([H])~[N:2]([C])~[O:3]
            it is easier to see that the atoms C~N~O are connected in a "line" instead of:
            [C:1]([N:2]([O:3])[C])([H])[H]
            which is equivalent, but with all the () it is hard for a human to read the branching

            Parameters
            ----------
            other: AtomStorage

            Returns
            -------
            is_less_than: boolean
                self is less than other
            """
            # if either smirks index is None, then you can't directly compare
            # make a temporary index that is negative if it was None
            self_index = self.label if self.label is not None else -1000
            other_index = other.label if other.label is not None else -1000
            # if either index is greater than 0, the one that is largest should go at the end of the list
            if self_index > 0 or other_index > 0:
                return self_index < other_index

            # Both SMIRKS indices are not positive or None so compare the SMIRKS patterns instead
            return self.as_smirks() < other.as_smirks()

        def __eq__(self, other):
            # Equal when both the SMIRKS string and the label match.
            return self.as_smirks() == other.as_smirks() and self.label == other.label

        # Identity based hash: distinct atoms with identical decorators must remain
        # distinct nodes in the networkx graph, so hashing deliberately ignores __eq__.
        def __hash__(self): return id(self)

        def __str__(self): return self.as_smirks()

        def as_smirks(self, compress=False):
            """
            Parameters
            ----------
            compress: boolean
                if True only the atomic number is written, not the full decorator list

            Returns
            -------
            smirks: str
                how this atom would be represented in a SMIRKS string
            """
            if self.atom is None:
                if self.label is None or self.label <= 0:
                    return '[*]'
                return '[*:%i]' % self.label

            aromatic = 'a' if self.aromatic else 'A'
            if self.charge >= 0:
                charge = '+%i' % self.charge
            else:
                charge = '%i' % self.charge
            if self.min_ring_size == 0:
                ring = '!r'
            else:
                ring = 'r%i' % self.min_ring_size

            if compress:
                base_smirks = "#%i" % self.atomic_number
            else:
                base_smirks = '#%i%sH%iX%ix%i%s%s' % (self.atomic_number,
                                                      aromatic,
                                                      self.hydrogen_count,
                                                      self.connectivity,
                                                      self.ring_connectivity,
                                                      ring,
                                                      charge)

            # only positive labels are real SMIRKS indices
            if self.label is None or self.label <= 0:
                return '[%s]' % base_smirks

            return '[%s:%i]' % (base_smirks, self.label)

    @total_ordering
    class BondStorage(object):
        """
        BondStorage tracks information about a bond
        """
        def __init__(self, bond=None, label=None):
            """
            Parameters
            ----------
            bond: chemper Bond object
                if None the bond is written as the generic bond ~
            label: int or float
                Bonds don't have SMIRKS indices so this is only used for internal
                tracking of the object.
            """
            if bond is None:
                self.order = None
                self.ring = None
                self.bond_index = None
            else:
                self.order = bond.get_order()
                self.ring = bond.is_ring()
                self.bond_index = bond.get_index()

            self._bond = bond
            self.label = label

        def __str__(self): return self.as_smirks()

        def __lt__(self, other):
            """
            Order bonds by SMIRKS string, breaking ties with the label.
            A None label sorts like AtomStorage does (as -1000) so that
            unlabeled bonds can be compared without raising TypeError.
            """
            self_smirks = self.as_smirks()
            other_smirks = other.as_smirks()
            if self_smirks == other_smirks:
                self_label = self.label if self.label is not None else -1000
                other_label = other.label if other.label is not None else -1000
                return self_label < other_label
            return self_smirks < other_smirks

        def __eq__(self, other):
            # Equal when both the label and the SMIRKS string match.
            return self.label == other.label and self.as_smirks() == other.as_smirks()

        # Identity based hash so distinct bond objects stay distinct in sets/dicts.
        def __hash__(self): return id(self)

        def as_smirks(self):
            """
            Returns
            -------
            SMIRKS: str
                how this bond should appear in a SMIRKS string
            """
            if self.ring is None:
                ring = ''
            elif self.ring:
                ring = '@'
            else:
                ring = '!@'

            order = {1:'-', 1.5:':', 2:'=', 3:'#', None:'~'}.get(self.order)

            return order+ring

    def __init__(self):
        """
        Initialize empty ChemPerGraph
        """
        self._graph = nx.Graph()
        self.atom_by_label = dict() # stores a dictionary of atoms by label
        self.bond_by_label = dict() # stores a dictionary of bonds by label

    def __str__(self): return self.as_smirks()

    def __lt__(self, other): return self.as_smirks() < other.as_smirks()

    def __eq__(self, other):
        # Graphs are equal when they produce the same SMIRKS pattern.
        return self.as_smirks() == other.as_smirks()

    def __hash__(self): return id(self)

    def as_smirks(self, compress=False):
        """
        Parameters
        ----------
        compress: boolean
                  returns the shorter version of atom SMIRKS patterns
                  that is the atoms only include atomic numbers rather
                  than the full list of decorators
        Returns
        -------
        SMIRKS: str
            a SMIRKS string matching the exact atom and bond information stored,
            or None if the graph has no atoms
        """

        # If no atoms have been added
        if len(self._graph.nodes()) == 0:
            return None

        if self.atom_by_label:
            # sometimes we use negative numbers for internal indexing
            # the first atom in a smirks pattern should be based on actual smirks indices (positive)
            smirks_indices = [k for k in self.atom_by_label.keys() if k > 0]
            if len(smirks_indices) != 0:
                min_smirks = min(smirks_indices)
            else:
                min_smirks = min([k for k in self.atom_by_label.keys()])
            init_atom = self.atom_by_label[min_smirks]
        else:
            init_atom = self.get_atoms()[0]

        # sort neighboring atoms to keep consist output
        neighbors = sorted(self.get_neighbors(init_atom))
        return self._as_smirks(init_atom, neighbors, compress)

    def _as_smirks(self, init_atom, neighbors, compress=False):
        """
        This is an internal/private method used to add all AtomStorage to the SMIRKS pattern

        Parameters
        ----------
        init_atom: AtomStorage object
            current atom
        neighbors: list of AtomStorage objects
            list of neighbor atoms you wanted added to the SMIRKS pattern
        compress: boolean
            passed on to AtomStorage.as_smirks

        Returns
        -------
        SMIRKS: str
            This graph as a SMIRKS string
        """
        smirks = init_atom.as_smirks(compress)
        for idx, neighbor in enumerate(neighbors):
            bond = self.get_connecting_bond(init_atom, neighbor)
            bond_smirks = bond.as_smirks()

            # do not walk back to the atom we came from
            new_neighbors = sorted(self.get_neighbors(neighbor))
            new_neighbors.remove(init_atom)

            atom_smirks = self._as_smirks(neighbor, new_neighbors, compress)

            # every branch but the last is wrapped in parentheses
            if idx < len(neighbors) - 1:
                smirks += '(' + bond_smirks + atom_smirks + ')'
            else:
                smirks += bond_smirks + atom_smirks

        return smirks

    def get_atoms(self):
        """
        Returns
        -------
        atoms: list of AtomStorage objects
            all atoms stored in the graph
        """
        return list(self._graph.nodes())

    def get_connecting_bond(self, atom1, atom2):
        """
        Parameters
        ----------
        atom1: AtomStorage object
        atom2: AtomStorage object

        Returns
        -------
        bond: BondStorage object
            bond between the two given atoms or None if not connected
        """
        bond = self._graph.get_edge_data(atom1, atom2)
        if bond is not None:
            return bond['bond']
        return None

    def get_bonds(self):
        """
        Returns
        -------
        bonds: list of BondStorage objects
            all bonds stored as edges in this graph
        """
        return [data['bond'] for a1, a2, data in self._graph.edges(data=True)]

    def get_neighbors(self, atom):
        """
        Parameters
        ----------
        atom: an AtomStorage object

        Returns
        -------
        atoms: list of AtomStorage objects
            list of atoms one bond (edge) away from the given atom
        """
        return list(self._graph.neighbors(atom))

    def remove_atom(self, atom):
        """
        Removes the provided atom and every atom that, once it is gone, is no
        longer connected to the first positively indexed atom in the graph.

        Parameters
        ----------
        atom: AtomStorage object

        Returns
        -------
        removed: boolean
            False if the atom is not in the graph or is indexed (label > 0),
            True once the atom (and its dangling branch) was removed
        """
        # if atom isn't in the graph, it can't be removed
        if not self._graph.has_node(atom):
            return False
        # if atom is "indexed" that is has a SMIRKS index > 0 it can't be removed
        if atom.label is not None and atom.label > 0:
            return False
        # remove specified atom
        self._graph.remove_node(atom)

        # find atoms on that "branch" of the molecule
        # we do this by looking for atoms that are no longer connected to
        # the base of the graph, where we consider the base a positively indexed atom
        ref_atoms = [n for n in self._graph.nodes
                     if n.label is not None and n.label > 0]
        if ref_atoms:
            # one connected-component query instead of a path search per node
            keep = nx.node_connected_component(self._graph, ref_atoms[0])
            remove_atoms_list = [n for n in self._graph.nodes if n not in keep]
            # remove the disconnected atoms
            self._graph.remove_nodes_from(remove_atoms_list)

        self._prune_label_maps()
        return True

    def _prune_label_maps(self):
        """
        Drop atom_by_label / bond_by_label entries whose atom or bond is no
        longer part of the graph, so later lookups (e.g. in as_smirks) never
        return removed objects. The dictionaries are mutated in place.
        """
        stale_atoms = [k for k, v in self.atom_by_label.items()
                       if not self._graph.has_node(v)]
        for key in stale_atoms:
            del self.atom_by_label[key]

        present_bonds = {id(b) for b in self.get_bonds()}
        stale_bonds = [k for k, v in self.bond_by_label.items()
                       if id(v) not in present_bonds]
        for key in stale_bonds:
            del self.bond_by_label[key]

    def add_atom(self, new_atom, new_bond=None, bond_to_atom=None,
                 new_label=None, new_bond_label=None):
        """
        Expand the graph by adding one new atom including relevant bond

        Parameters
        ----------
        new_atom: a chemper Atom object
        new_bond: a chemper Bond object
        bond_to_atom: AtomStorage object
            This is where you want to connect the new atom, required if the graph isn't empty.
            It must already be an atom of this graph.
        new_label: int
            (optional) index for SMIRKS or internal storage if less than zero
        new_bond_label: int or float
            (optional) index used to track bond storage

        Returns
        -------
        AtomStorage: AtomStorage object or None
            If the atom was successfully added then the AtomStorage object is returned
            None is returned if the atom wasn't able to be added
        """
        if bond_to_atom is None and len(self.get_atoms()) > 0:
            return None
        # connecting to an atom outside the graph would silently create a new node
        if bond_to_atom is not None and not self._graph.has_node(bond_to_atom):
            return None

        new_atom_storage = self.AtomStorage(new_atom, label=new_label)
        self._graph.add_node(new_atom_storage)
        if new_label is not None:
            self.atom_by_label[new_label] = new_atom_storage

        # This is the first atom added to the graph
        if bond_to_atom is None:
            return new_atom_storage

        new_bond_storage = self.BondStorage(new_bond, new_bond_label)
        # mirror atom handling: only labeled bonds are tracked by label
        if new_bond_label is not None:
            self.bond_by_label[new_bond_label] = new_bond_storage

        self._graph.add_edge(bond_to_atom, new_atom_storage, bond=new_bond_storage)
        return new_atom_storage
410

411

412
# ==============================================================================
413
# TODO: Isn't this the same thing as starting with a ChemPerGraph with mols=None
414
# and smirks_atoms=None as the default?
415
# ==============================================================================
416 2
class ChemPerGraphFromMol(ChemPerGraph):
    """
    Creates a ChemPerGraph from a chemper Mol object
    """
    def __init__(self, mol, smirks_atoms, layers=0):
        """
        Parameters
        ----------
        mol: Mol
            this can be a chemper mol or a molecule from any supported toolkit
            (currently OpenEye or RDKit)
        smirks_atoms: tuple of integers
            This is a tuple of the atom indices which will have SMIRKS indices.
            For example, if (1,2) is provided then the atom in molecule with indices
            1 and 2 will be used to create a SMIRKS with two indexed atoms.
        layers: int or 'all'
            how many atoms out from the smirks indexed atoms do you wish save (default=0)
            'all' will lead to all atoms in the molecule being specified
        """
        ChemPerGraph.__init__(self)

        self.mol = mol_toolkit.Mol(mol)
        self.atom_by_index = dict()
        self._add_smirks_atoms(smirks_atoms)
        # Expand from every indexed atom at once (breadth first) so each new atom
        # is reached along its shortest path from the indexed atoms.
        start_atoms = [self.atom_by_label[key] for key in list(self.atom_by_label.keys())]
        self._expand_layers(start_atoms, layers)

    def _add_smirks_atoms(self, smirks_atoms):
        """
        private function for adding atoms to the graph

        Parameters
        ----------
        smirks_atoms: tuple of integers
            This is a tuple of the atom indices which will have SMIRKS indices.
            The atom at position i (1-based) in the tuple gets SMIRKS label i.
        """
        # add all smirks atoms to the graph
        for key, atom_index in enumerate(smirks_atoms, 1):
            atom1 = self.mol.get_atom_by_index(atom_index)
            new_atom_storage = self.AtomStorage(atom1, key)
            self._graph.add_node(new_atom_storage)
            self.atom_by_label[key] = new_atom_storage
            self.atom_by_index[atom_index] = new_atom_storage
            # Check for bonded atoms already in the graph
            for neighbor_key, neighbor_index in enumerate(smirks_atoms, 1):
                if neighbor_key not in self.atom_by_label:
                    continue

                # check if atoms are already connected on the graph;
                # skipping them keeps the graph a tree (no ring closures)
                neighbor_storage = self.atom_by_label[neighbor_key]
                if nx.has_path(self._graph, new_atom_storage, neighbor_storage):
                    continue

                # check if atoms are connected in the molecule
                atom2 = self.mol.get_atom_by_index(neighbor_index)
                bond = self.mol.get_bond_by_atoms(atom1, atom2)

                if bond is not None: # Atoms are connected add edge
                    # bond label is one less than the larger SMIRKS label it joins
                    bond_index = max(neighbor_key, key)-1
                    bond_storage = self.BondStorage(bond, bond_index)
                    self.bond_by_label[bond_index] = bond_storage
                    self._graph.add_edge(new_atom_storage,
                                         self.atom_by_label[neighbor_key],
                                         bond=bond_storage)

    def _expand_layers(self, start_atoms, layers):
        """
        private function for expanding beyond the initial SMIRKS atoms.

        The expansion is breadth first from all start atoms together. A depth
        first walk can first reach an atom (e.g. in an odd-membered ring) along
        a longer path and then never expand it from its shorter one. Breadth
        first avoids this. For acyclic fragments the resulting graph is the same.

        Parameters
        ----------
        start_atoms: list of AtomStorage objects
            atoms already in the graph whose neighbors should be added
        layers: int or 'all'
            how many bonds away from the start atoms to expand;
            'all' expands until no unvisited neighbors remain,
            values <= 0 add nothing
        """
        from collections import deque

        queue = deque((storage, layers) for storage in start_atoms)
        while queue:
            atom_storage, remaining = queue.popleft()
            if remaining != 'all' and remaining <= 0:
                continue

            # neighbors of indexed atoms get label 0, each further layer one less
            new_label = min(1, atom_storage.label) - 1
            next_remaining = remaining if remaining == 'all' else remaining - 1

            for new_atom in atom_storage.atom.get_neighbors():
                if new_atom.get_index() in self.atom_by_index:
                    continue

                new_bond = self.mol.get_bond_by_atoms(atom_storage.atom, new_atom)
                new_storage = self.add_atom(new_atom, new_bond, atom_storage,
                                            new_label, new_label)
                self.atom_by_index[new_atom.get_index()] = new_storage
                queue.append((new_storage, next_remaining))

    def _add_layers(self, atom_storage, add_layer):
        """
        private function for expanding beyond a single atom;
        kept for backward compatibility, delegates to _expand_layers.

        Parameters
        ----------
        atom_storage: AtomStorage object
            atom whose neighbors you currently need to add
        add_layer: int or 'all'
            how many more layers need to be added
        """
        self._expand_layers([atom_storage], add_layer)

Read our documentation on viewing source code .

Loading