1
#!/usr/bin/env python
2 60
"""
basic_plot.py

Basic plots (lines or bars) of information in SD tags.

Version:    Apr 12 2019
By:         Victoria T. Lim


TODO:
- plot different tags in the same file

"""
15 60
import numpy as np
16 60
import itertools
17

18 60
import matplotlib.pyplot as plt
19

20 60
import quanformer.proc_tags as pt
21 60
import quanformer.reader as reader
22

23

24 60
def basic_plot(infile, tag, style, molname=None, take_relative=False, har_to_kcal=False):
    """
    Plot the per-conformer values of one SD tag for each molecule in an SDF file.

    One figure is produced per (selected) molecule. It is saved as
    ``output_{i}.png``, where ``i`` is the molecule's index in the file, and
    then shown.

    Parameters
    ----------
    infile : string
        Name of SDF file with information in SD tags.
    tag : string
        Full tag string directly as listed in the SD file.
    style : string
        Plot style: 'scatter', 'line', or 'bar'. Any other value falls back
        to a line plot.
    molname : string, optional
        If given, only the molecule whose title equals this name is plotted.
    take_relative : Boolean
        Subtract the lowest value, so the minimum becomes zero.
    har_to_kcal : Boolean
        Multiply data in Hartrees by 627.5095 to yield kcal/mol.

    """
    # Open molecule file.
    mols = reader.read_mols(infile)

    for i, mol_i in enumerate(mols):
        if molname is not None and mol_i.GetTitle() != molname:
            continue

        # Get array of all conformer data of this mol. If the tag values
        # cannot be parsed as floats, treat every conformer as missing (NaN).
        try:
            data_array = np.fromiter(pt.get_sd_list(mol_i, datum='', taglabel=tag), dtype=np.float64)
        except ValueError:
            data_array = np.full(mol_i.NumConfs(), np.nan)

        # Exclude conformers for which the job did not finish (NaN).
        data_array = data_array[~np.isnan(data_array)]

        # np.amin raises on an empty array, so only shift when data remains.
        if take_relative and data_array.size:
            data_array = data_array - np.amin(data_array)
        if har_to_kcal:
            data_array = 627.5095*data_array

        # Draw on a fresh figure so molecules do not accumulate on one set of
        # axes (with non-interactive backends plt.show() leaves it open).
        plt.figure()
        conf_idx = np.arange(data_array.size)
        if style == 'bar':
            plt.bar(conf_idx, data_array)
        elif style == 'scatter':
            plt.scatter(conf_idx, data_array)
        else:
            plt.plot(data_array)
        plt.grid()
        plt.title(mol_i.GetTitle()+'\n'+tag, fontsize=14)
        plt.savefig(f'output_{i}.png', bbox_inches='tight')
        plt.show()
        plt.close()
72

73

74 60
def combine_files_plot(infile, figname='combined.png', molname=None, verbose=False, take_relative=False, har_to_kcal=False):
    """
    Plot one molecule's per-conformer SD-tag data from several SDF files on one figure.

    This only supports plotting of ONE specified molecule across different files.
    If ``molname`` is None, the last molecule read from each file is used.

    Note on take_relative:
        Subtracting a single global minimum from all energies doesn't work,
        since each file's data is still on a different scale. Instead, each
        file's data is taken relative to that file's first conformer.

    Parameters
    ----------
    infile : str
        Filename with information on the files to read in, and
        the SDF tags to be extracted from each. Columns are:
        (1) QM method/basis, (2) sdf file, (3) tag key in sdf (like 'QM spe'),
        (4) arbitrary label for plotting. Separate columns by comma.
    figname : str
        Output filename for the saved figure.
    molname : str, optional
        Title of the molecule to extract from each file.
    verbose : bool
        If True, also write the plotted data to 'combined.dat'.
    take_relative : bool
        Subtract each file's first-conformer value from that file's data.
    har_to_kcal : bool
        Multiply data in Hartrees by 627.5095 to yield kcal/mol.

    Raises
    ------
    ValueError
        If a listed file has no molecule matching ``molname``, or if
        ``infile`` lists no files at all.

    """
    wholedict = reader.read_text_input(infile)

    numFiles = len(wholedict)
    xarray = []
    yarray = []
    labels = []
    titles = []
    for i in wholedict:
        fname = wholedict[i]['fname']
        print("Reading molecule(s) from file: ", fname)
        mols = reader.read_mols(fname)
        qmethod, qbasis = reader.separated_theory(wholedict[i]['theory'])
        short_tag = wholedict[i]['tagkey']

        # Capture the data and title of the (last) matching molecule inside
        # the loop: after the loop the iteration variable points at the last
        # molecule read, which need not be the matching one.
        data_array = None
        title = None
        for mol_j in mols:
            if molname is not None and mol_j.GetTitle() != molname:
                continue
            title = mol_j.GetTitle()
            data_array = np.array([float(value) for value in
                pt.get_sd_list(mol_j, short_tag, 'Psi4', qmethod, qbasis)])
        if data_array is None:
            # Without this check a missing molecule would silently reuse the
            # previous file's data (or fail with NameError on the first file).
            raise ValueError(f"no molecule matching {molname!r} found in {fname}")

        # Relative to the first conformer; empty data has no index 0.
        if take_relative and data_array.size:
            data_array = data_array - data_array[0]
        if har_to_kcal:
            data_array = 627.5095*data_array

        titles.append(title)
        labels.append(wholedict[i]['label'])
        yarray.append(data_array)
        xarray.append(range(len(data_array)))

    if not yarray:
        raise ValueError(f"no files listed in {infile}")

    if verbose:
        header = '{}\n'.format(molname) + ''.join("%s\n" % l for l in labels)
        xydata = np.vstack((xarray[0], yarray)).T
        np.savetxt('combined.dat', xydata, delimiter='\t', header=header,
            fmt=' '.join(['%i'] + ['%10.4f']*numFiles))

    # letter labels for x-axis
    num_confs = len(xarray[0])
    xlabs = _conformer_letter_labels(num_confs)

    fig = plt.figure()
    ax = fig.add_subplot(111)
    xlabel = 'conformer'
    ylabel = "energy"

    # Print the largest spread across files at any single conformer.
    spreads = np.ptp(np.vstack(yarray), axis=0)
    if spreads.size:
        print(f'mol {molname} max range: {spreads.max()}')

    ax.set_prop_cycle(plt.cycler('color', plt.cm.rainbow(np.linspace(0, 1, len(yarray)))))
    for xs, ys, label in zip(xarray, yarray, labels):
        plt.plot(xs, ys, '-o', lw=0.8, label=label)

    # publication view
#    plt.ylabel(ylabel,fontsize=8)
#    plt.xlabel(xlabel,fontsize=8)
#    plt.legend(bbox_to_anchor=(0.08,1.05),loc=3,fontsize=8)
#    fig.set_size_inches(3.37,1.7)

    # standard view
    plt.ylabel(ylabel, fontsize=14)
    plt.xlabel(xlabel, fontsize=14)
    plt.xticks(list(range(num_confs)), xlabs)
    plt.legend(bbox_to_anchor=(1.05, 1), loc=2)

    # Fall back to the molecule's own title when no name was requested.
    plt.title(molname if molname is not None else titles[0])
    plt.grid()
    plt.savefig(figname, bbox_inches='tight', dpi=300)
    plt.show()
    plt.close(fig)


def _conformer_letter_labels(num_confs):
    """Return ``num_confs`` fixed-width letter labels: 'A', 'B', ... or 'AA', 'AB', ...

    The width is the smallest number of letters whose combinations cover
    ``num_confs`` labels. Only the needed combinations are generated, so
    large conformer counts do not build the whole Cartesian product.
    """
    letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
    width = 1
    while len(letters) ** width < num_confs:
        width += 1
    combos = itertools.islice(itertools.product(letters, repeat=width), num_confs)
    return [''.join(combo) for combo in combos]

Read our documentation on viewing source code .

Loading