predict.py
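
"""Predict the ligand type of each detected pocket in a protein structure.

Pipeline: clean the input PDB (waters and hetero atoms removed), detect
pockets with fpocket, convert the protein to gninatypes, voxelize each pocket
center with molgrid, score each grid with a trained 3D CNN, and write the
per-pocket predictions to <protein_id>.csv.
"""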
import torch
import os
import molgrid
import argparse
import pathlib
import csv
import shutil
from cleanpdb import pdb_clean
from get_centers import get_centers
from types_and_gninatyper import create_gninatype_file, create_types_file
# The model classes below must be importable so torch.load can unpickle a
# saved model object of any of these architectures.
from deeplearningmodels.cbam import ResNet18_CBAM_3D, BasicBlock3D, ChannelAttention3D, SpatialAttention3D
from deeplearningmodels.seresnet import SEResNet, ResidualBlock, SEBlock
from deeplearningmodels.cnn import CNNModel
from deeplearningmodels.resnet18 import ResidualBlock_Resnet18, ResNet18
from deeplearningmodels.densenet import DenseNet3D, DenseBlock, TransitionBlock

current_directory = os.path.dirname(os.path.abspath(__file__))
bestmodels_dir = os.path.join(current_directory, 'bestmodels')

def parse_arguments(args=None):
    parser = argparse.ArgumentParser(description='Classify by Ligand Type')
    parser.add_argument('-p', '--protein', type=str, required=True, help="Input PDB file")
    parser.add_argument('-t', '--trainedpth', type=str, required=True, help="Trained model (.pth) in bestmodels/")
    args = parser.parse_args(args)
    # Rebuild the command line from non-default arguments (useful for logging).
    arg_dict = vars(args)
    arg_str = ''
    for name, value in arg_dict.items():
        if value != parser.get_default(name):
            arg_str += f' --{name}={value}'
    return args, arg_str

def to_cuda(*models):
    """Move the given models to the GPU and return them as a list."""
    return [model.to("cuda") for model in models]
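
# Note: to_cuda is currently unused in this script; a caller could move
# several models to the GPU at once, e.g. models = to_cuda(model_a, model_b)
# (hypothetical usage).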

if __name__ == '__main__':
    args, cmdline = parse_arguments()
    project_path = os.path.dirname(os.path.abspath(__file__))
    trainedpth = args.trainedpth
    trainedpth_dir = os.path.join(bestmodels_dir, trainedpth)
    deep_model = torch.load(trainedpth_dir)
    deep_model.eval()
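    # eval() puts dropout/batch-norm layers into inference mode. Assumption:
    # the .pth stores a full model object saved on a CUDA machine; on a
    # CPU-only host, torch.load(trainedpth_dir, map_location='cpu') would be
    # needed instead.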
    protein_file = args.protein
    pro_id = protein_file.split("/")[-1].split(".")[0]  # four-character protein ID from the file name
    protein_nowat_file = protein_file.replace('.pdb', '_nowat.pdb')
    pdb_clean(protein_file, protein_nowat_file)  # clean the PDB: drop waters, hetero atoms, non-standard residues
    os.system('fpocket -f ' + protein_nowat_file)  # run fpocket pocket detection
    fpocket_dir = os.path.join(protein_nowat_file.replace('.pdb', '_out'), 'pockets')
    fpocket_result_folder = pathlib.Path(fpocket_dir).parent  # the xxxx_nowat_out folder
    get_centers(fpocket_dir)  # create bary_centers.txt
    barycenter_file = os.path.join(fpocket_dir, 'bary_centers.txt')
    protein_gninatype = create_gninatype_file(protein_nowat_file)  # path to the .gninatypes file
    class_types = create_types_file(barycenter_file, protein_gninatype)  # create bary_centers_ranked.types
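    # Each line of the ranked .types file has five whitespace-separated fields:
    # a leading label column (unused here), the x, y, z pocket-center
    # coordinates, and the path to the .gninatypes file (parsed below).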
    with open(class_types, 'r') as types_file:
        types_lines = types_file.readlines()
    # cap the number of pockets scored to avoid CUDA out-of-memory errors
    batch_size = min(len(types_lines), 50)
    gmaker2 = molgrid.GridMaker(binary=False)
    dims = gmaker2.grid_dimensions(24)
    tensor_shape_2 = (1,) + dims  # (1, 24, 48, 48, 48)
    # inputfile.types holds "X Y Z xxxx_nowat.gninatypes" lines for the ExampleProvider
    inputfile_name = 'inputfile.types'
    inputfile_dir = os.path.join(project_path, inputfile_name)
    formatted_lines = []
    # Reformat each ranked .types line as "x y z <gninatypes filename>"
    with open(os.path.abspath(class_types), "r") as infile:
        for line in infile:
            parts = line.strip().split()
            if len(parts) == 5:
                # keep only the filename from the path in the last column
                filename = parts[4].split('/')[-1]
                formatted_lines.append(f"{parts[1]} {parts[2]} {parts[3]} {filename}")
    # Write the formatted lines to inputfile.types
    with open(inputfile_dir, "w") as outfile:
        for line in formatted_lines:
            outfile.write(line + "\n")
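    # inputfile.types now holds one pocket per line in the form
    # "<x> <y> <z> xxxx_nowat.gninatypes" (illustrative; the coordinates are
    # fpocket barycenters).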
    e_test_1 = molgrid.ExampleProvider(data_root=project_path, shuffle=False, stratify_receptor=True, balanced=False)
    e_test_1.populate(inputfile_dir)
    input_tensor_1 = torch.zeros(tensor_shape_2, dtype=torch.float32, device='cuda')  # (1, 24, 48, 48, 48)
    float_labels_1 = torch.zeros((1, 3), dtype=torch.float32, device='cuda')
    categorization_prediction = []
    with torch.no_grad():  # inference only; skip autograd bookkeeping
        for i in range(batch_size):
            batch = e_test_1.next_batch()
            # extract the pocket-center (x, y, z) labels of this example
            batch.extract_labels(float_labels_1)
            centers = float_labels_1[:, 0:]
            for b in range(1):  # next_batch() yields a single example here
                center = molgrid.float3(float(centers[b][0]), float(centers[b][1]), float(centers[b][2]))
                # update the input tensor with the b'th datapoint of the batch
                gmaker2.forward(center, batch[b].coord_sets[0], input_tensor_1[b])
            output = deep_model(input_tensor_1[:, :24])
            predicted = torch.argmax(output, dim=1)
            categorization_prediction.append(predicted.item())
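    # categorization_prediction[i] is the argmax class index (0-4) for the
    # i'th pocket; indices are mapped to ligand-type labels below.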
    # class index -> ligand-type label
    mapping = {0: 'Other', 1: 'Antagonist', 2: 'Inhibitor', 3: 'Activator', 4: 'Agonist'}

    def get_category(prediction):
        return mapping.get(prediction, "None")  # "None" string for any unexpected index
    with open(inputfile_dir, 'r') as input_file:
        lines = input_file.readlines()
    # only the first batch_size pockets were scored above
    inputfiledata = [line.strip().split() for line in lines[:batch_size]]
    for i, row in enumerate(inputfiledata):
        row.append(get_category(categorization_prediction[i]))
        # the second-to-last column is the gninatypes filename
        # (xxxx_nowat.gninatypes); keep only the protein ID before the underscore
        row[-2] = row[-2].split('_')[0]
    csvfilename = "{}.csv".format(pro_id)
    with open(csvfilename, 'w', newline='') as output_csv_file:
        csv_writer = csv.writer(output_csv_file)
        csv_writer.writerow(["x", "y", "z", "protein_id", "LigandType"])
        csv_writer.writerows(inputfiledata)
    # remove intermediate artifacts: fpocket output, cleaned PDB, gninatypes file, inputfile.types
    shutil.rmtree(fpocket_result_folder)
    os.remove(protein_nowat_file)
    os.remove(protein_gninatype)
    os.remove(inputfile_dir)
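
# Example invocation (hypothetical file names; the .pth is resolved inside the
# bestmodels/ directory next to this script):
#   python predict.py -p 1abc.pdb -t seresnet_best.pth
# Output: 1abc.csv with columns x, y, z, protein_id, LigandType.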