forked from RUBi-ZA/MD-TASK
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcalc_correlation.py
executable file
·164 lines (110 loc) · 4.56 KB
/
calc_correlation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
#!/usr/bin/env python
#
# Calculate correlation in MD trajectory
#
# Script distributed under GNU GPL 3.0
#
# Author: Caroline Ross
# Date: 17-11-2016
from matplotlib import cm
import numpy as np
from lib.cli import CLI
from lib.utils import Logger
from lib.trajectory import load_trajectory
import argparse, math, matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
def parse_traj(traj, topology=None, step=1, selected_atoms=("CA",), lazy_load=False):
    """Collect coordinates of the selected atoms from a trajectory, grouped by residue.

    Args:
        traj: trajectory file, forwarded to ``load_trajectory``.
        topology: optional topology/reference file, forwarded to ``load_trajectory``.
        step: frame step, forwarded to ``load_trajectory``.
        selected_atoms: atom names to keep (default: ``"CA"``). Any container
            supporting ``in`` works; a tuple default avoids the shared
            mutable-default pitfall of a list.
        lazy_load: forwarded to ``load_trajectory``.

    Returns:
        dict mapping ``atom.residue.resSeq`` to a list of ``[x, y, z]``
        coordinates, appended in iteration order. Each item yielded by
        iterating the loaded trajectory contributes the coordinates of its
        first frame (``frame.xyz[0, ...]``) for every matching atom.
    """
    # Only the first element of load_trajectory's return value is used.
    traj = load_trajectory(traj, topology, step, lazy_load)[0]
    residues = {}
    for frame in traj:
        for atom in frame.topology.atoms:
            if atom.name not in selected_atoms:
                continue
            # NOTE(review): keyed by resSeq only. Residues from different
            # chains that share a number (or several selected atoms in one
            # residue) end up in the same list, so that list no longer holds
            # one entry per frame. Confirm inputs are single-chain.
            xyz = frame.xyz[0, atom.index]
            residues.setdefault(atom.residue.resSeq, []).append([xyz[0], xyz[1], xyz[2]])
    return residues
def mean_dot(m1, m2, size):
    """Return the mean of the row-wise dot products of ``m1`` and ``m2``.

    Only rows ``0 .. size-1`` are used. Each per-row product is stored in a
    float64 buffer before averaging, so the result is
    ``mean([dot(m1[t], m2[t]) for t in range(size)])``.
    """
    per_row = (np.dot(m1[row], m2[row]) for row in range(size))
    products = np.fromiter(per_row, dtype=float, count=size)
    return np.mean(products)
def correlate(residues):
    """Compute the normalised cross-correlation matrix of residue displacements.

    With per-frame displacements ``d_i(t) = r_i(t) - <r_i>`` (``<.>`` is the
    mean over frames), each entry is

        C[i, j] = <d_i . d_j> / sqrt(<d_i . d_i> * <d_j . d_j>)

    Args:
        residues: mapping of residue id -> list of per-frame ``[x, y, z]``
            coordinates (as built by ``parse_traj``). Every residue must have
            the same number of frames.

    Returns:
        float64 array of shape ``(n, n)`` whose rows/columns follow the
        residue ids in ascending order. Entries involving a residue whose
        displacement is exactly zero are NaN (0/0). An empty mapping yields a
        ``(0, 0)`` array.
    """
    ordered_ids = sorted(residues)
    if not ordered_ids:
        return np.zeros((0, 0))

    # (residues, frames, 3) stack; float64 so accumulation is done accurately.
    coords = np.array([residues[key] for key in ordered_ids], dtype=float)
    num_frames = coords.shape[1]

    # Displacement from each residue's mean position, computed once per
    # residue rather than once per (i, j) pair.
    deltas = coords - coords.mean(axis=1, keepdims=True)

    # covariance[i, j] = mean over frames of dot(deltas[i, t], deltas[j, t])
    covariance = np.einsum("itk,jtk->ij", deltas, deltas) / num_frames
    magnitudes = np.sqrt(np.diag(covariance))
    return covariance / np.outer(magnitudes, magnitudes)
def plot_map(correlation, title, output_prefix):
    """Render ``correlation`` as a heat map and save it as ``<output_prefix>.png``.

    Colours span the fixed range [-1, 1]. The colour map starts at white and
    continues through a slice of ``jet``. Tick marks are hidden, tick labels
    use font size 8, and x tick labels are rotated 90 degrees.

    Args:
        correlation: 2-D array-like of correlation coefficients.
        title: heat-map title.
        output_prefix: path prefix of the image; ``.png`` is appended.
    """
    M = np.array(correlation)
    fig, ax = plt.subplots()
    try:
        colors = ['white'] + [cm.jet(i) for i in range(40, 250)]
        new_map = matplotlib.colors.LinearSegmentedColormap.from_list('new_map', colors, N=300)
        heatmap = ax.pcolor(M, cmap=new_map, vmin=-1, vmax=1)
        ax.set_frame_on(False)
        ax.grid(False)
        # tick_params hides the tick marks and sets the label size for all
        # current and future major ticks. Tick.set_visible must be called, not
        # assigned, and Tick.label no longer exists in recent Matplotlib.
        ax.tick_params(axis='both', which='major', bottom=False, top=False,
                       left=False, right=False, labelsize=8)
        ax.tick_params(axis='x', labelrotation=90)
        ax.set_title(title, fontsize=16)
        ax.set_xlabel('Residue Index', fontsize=12)
        ax.set_ylabel("Residue Index", fontsize=12)
        fig.colorbar(heatmap, ax=ax, orientation="vertical")
        fig.savefig('%s.png' % output_prefix, dpi=300)
    finally:
        # Release only this figure, even if plotting or saving failed.
        plt.close(fig)
def print_correlation(correlation, output_prefix):
    """Write the matrix to ``<output_prefix>.txt`` as plain text.

    Each row of ``correlation`` becomes one line. Every value is rendered
    with ``str()`` and followed by a single space, so each line ends with a
    trailing space before the newline.
    """
    with open("%s.txt" % output_prefix, "w") as out:
        n_rows = correlation.shape[0]
        n_cols = correlation.shape[1]
        out.writelines(
            "".join('%s ' % str(correlation[row, col]) for col in range(n_cols)) + '\n'
            for row in range(n_rows)
        )
def main(args):
    """Run the pipeline: load coordinates, correlate them, then write the outputs.

    ``args`` is the parsed command-line namespace. The fields read are
    trajectory, topology, step, lazy_load, title and prefix. Writes
    ``<prefix>.png`` (heat map) and ``<prefix>.txt`` (matrix text).
    """
    log.info("Preparing a trajectory matrix...\n")
    coordinates = parse_traj(args.trajectory, args.topology, args.step,
                             lazy_load=args.lazy_load)

    log.info("Correlating...\n")
    matrix = correlate(coordinates)

    log.info("Plotting heat map...\n")
    plot_map(matrix, args.title, args.prefix)
    print_correlation(matrix, args.prefix)
# Module-level logger; main() reads this global.
log = Logger()

if __name__ == "__main__":
    # parse cmd arguments
    parser = argparse.ArgumentParser()
    # custom arguments
    parser.add_argument("--trajectory", help="Trajectory file")
    parser.add_argument("--topology", help="Reference PDB file (must contain the same number of atoms as the trajectory)")
    # Default matches parse_traj's own step=1 default; without it main()
    # would forward None.
    parser.add_argument("--step", help="Size of the step to take when iterating the trajectory frames", type=int, default=1)
    parser.add_argument("--lazy-load", help="Iterate through trajectory, loading one frame into memory at a time (memory-efficient for large trajectories)", action='store_true', default=False)
    parser.add_argument("--title", help="Title for heatmap", default="Protein")
    parser.add_argument("--prefix", help="Prefix for output files", default="correlation")
    CLI(parser, main, log)