toy_implementation.py
import numpy as np
import networkx as nx
from scipy.special import expit # Logistic function
from scipy.optimize import minimize
import random
random.seed(42)
np.random.seed(42)  # also seed NumPy: np.random.rand is used below to initialize q and w
debug = True
########################################################################################################################
# Toy data
# Genes
genes = ['G1', 'G2']
# Diseases
diseases = ['D1', 'D2', 'D3']
# Symptoms
symptoms = ['S1', 'S2', 'S3']
# Frequency categories and their midpoints
frequency_categories = {
'always_present': 1.00,
'very_frequent': 0.90,
'frequent': 0.55,
'occasional': 0.17,
'rare': 0.02
}
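# Note: these midpoints resemble the Orphanet/HPO frequency classes (e.g.,
# "very frequent" = 80-99%, midpoint ~0.90); treat the exact values here as
# illustrative choices for this toy example, not an authoritative mapping.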
# Mapping from genes to diseases
delta = {
'G1': {'D1'},
'G2': {'D2', 'D3'}
}
# Symptom frequencies for diseases
# (disease, symptom): frequency_category
disease_symptom_freq = {
('D1', 'S1'): 'very_frequent',
('D1', 'S2'): 'occasional',
('D2', 'S2'): 'frequent',
('D2', 'S3'): 'rare',
('D3', 'S1'): 'always_present',
('D3', 'S3'): 'frequent'
}
# Build X: set of (disease, symptom, frequency_category) triples (defined for completeness; not used below)
X = set()
for (d, s), f in disease_symptom_freq.items():
X.add((d, s, f))
########################################################################################################################
# Functions and mappings
# xi: frequency category to midpoint fraction
xi = frequency_categories
# gamma: maps (g, s) to the set of frequency categories of s across the diseases linked to g
def gamma(g, s):
freqs = set()
for d in delta[g]:
if (d, s) in disease_symptom_freq:
freqs.add(disease_symptom_freq[(d, s)])
return freqs
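if debug:
    # Worked example from the toy data: G2 maps to D2 and D3, and S3 is
    # 'rare' for D2 and 'frequent' for D3
    assert gamma('G2', 'S3') == {'rare', 'frequent'}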
# phi: maps (g, s) to the maximum symptom-frequency midpoint over gamma(g, s), falling back to avg_phi
def phi(g, s, A, avg_phi):
if (g, s) in A:
freqs = gamma(g, s)
if freqs:
return max([xi[f] for f in freqs])
else:
return avg_phi # If no frequency found, use average
else:
return avg_phi # For (g, s) not in A
# Build A: set of (g, s)
A = set()
for g in genes:
for d in delta[g]:
for s in symptoms:
if (d, s) in disease_symptom_freq:
A.add((g, s))
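if debug:
    # Sanity check: A should contain every (gene, symptom) pair reachable
    # through a disease in the toy data
    assert A == {('G1', 'S1'), ('G1', 'S2'), ('G2', 'S1'), ('G2', 'S2'), ('G2', 'S3')}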
# Build B: negative (g, s) pairs, formed by pairing elements of A with shuffled symptoms
B = set()
symptoms_shuffled = symptoms.copy()
random.shuffle(symptoms_shuffled)
# Note: zip stops at the shorter sequence, so at most len(symptoms) pairs are
# considered, and iteration order over the set A is arbitrary
for (g, s), s_shuffled in zip(A, symptoms_shuffled):
    if (g, s_shuffled) not in A:
        B.add((g, s_shuffled))
# Truncate B to at most |A| elements; B is disjoint from A by construction,
# though it may end up smaller than A
B = set(list(B)[:len(A)])
# y_gs: indicator function
def y(g, s):
return 1 if (g, s) in A else 0
# Calculate avg_phi for use in the phi function
# TODO: confirm this is the intended average (currently the unweighted mean of the category midpoints)
avg_phi = np.mean(list(xi.values()))
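if debug:
    # Worked examples: ('G2', 'S3') is in A, so phi is the larger midpoint of
    # {'rare': 0.02, 'frequent': 0.55}; ('G1', 'S3') is not in A, so phi falls
    # back to avg_phi
    assert np.isclose(phi('G2', 'S3', A, avg_phi), 0.55)
    assert np.isclose(phi('G1', 'S3', A, avg_phi), avg_phi)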
# sigma: maps g to the set of symptoms s with (g, s) in the union of A and B
def sigma(g):
return {s for s in symptoms if (g, s) in A.union(B)}
# psi: per-gene weight normalization (equation 13)
def psi(g):
return sum([phi(g, s, A, avg_phi) for s in sigma(g)])
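# Note: psi(g) depends on the random negatives in B, so its value varies with
# the shuffle above; it rescales each gene's summed loss in the objective below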
########################################################################################################################
# Dead simple knowledge graph
G = nx.DiGraph()
# Add nodes: genes, diseases, symptoms
G.add_nodes_from(genes, type='gene')
G.add_nodes_from(diseases, type='disease')
G.add_nodes_from(symptoms, type='symptom')
# Add edges between genes and diseases
for g in genes:
for d in delta[g]:
G.add_edge(g, d, predicate='gene_associated_with_disease')
# Add edges between diseases and symptoms
for (d, s), f in disease_symptom_freq.items():
G.add_edge(d, s, predicate='disease_has_symptom')
# For simplicity, only two predicates are used
P = ['gene_associated_with_disease', 'disease_has_symptom']
predicate_indices = {p: i for i, p in enumerate(P)}
########################################################################################################################
# Model implementation
# Number of nodes
N = G.number_of_nodes()
nodes = list(G.nodes())
node_indices = {node: i for i, node in enumerate(nodes)}
# Number of predicates
P_size = len(P)
# Adjacency matrices for each predicate
A_tilde = np.zeros((P_size, N, N))
for u, v, data in G.edges(data=True):
p = data['predicate']
p_idx = predicate_indices[p]
u_idx = node_indices[u]
v_idx = node_indices[v]
A_tilde[p_idx, u_idx, v_idx] = 1
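if debug:
    # Sanity check against the toy data: 3 gene->disease edges and
    # 6 disease->symptom edges
    assert A_tilde[predicate_indices['gene_associated_with_disease']].sum() == 3
    assert A_tilde[predicate_indices['disease_has_symptom']].sum() == 6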
# Initial weights for predicates
w = np.random.rand(P_size)
# Baseline offset, initialized to a negative value (alternative: f = 0)
f = -1.0
# Hyperparameters: regularization strengths for q (a: L1 term, b: L2 term) and
# for w (c: L2 term); all disabled here
a = 0
b = 0
c = 0
max_path_len = 4 # Maximum path length
# List of all genes
Gene_list = genes.copy()
# Initialize q for each gene
Q = {g: np.random.rand(N) for g in Gene_list}
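# Each q_g assigns a weight to every node in the graph; these are the
# per-gene node weights the optimization learns (kept nonnegative via bounds below)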
# Functions to compute z_gs and y_hat_gs
def compute_z_gs(q_g, w, f, s_idx, g_idx):
    diag_q = np.diag(q_g)
    A_weighted = sum(w[p] * A_tilde[p] for p in range(P_size))
    M = diag_q @ A_weighted  # one node-weighted, predicate-weighted step
    # Accumulate the matrix powers M^2 .. M^max_path_len with one
    # multiplication per step
    z = np.zeros_like(M)
    M_power = M.copy()
    for l in range(2, max_path_len + 1):
        M_power = M_power @ M
        z += M_power
    # Index as [g_idx, s_idx]: paths run gene -> disease -> symptom in the directed graph
    z_gs = z[g_idx, s_idx] + f
    return z_gs
def compute_y_hat(z_gs):
    return expit(z_gs)  # logistic (inverse-logit) function
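# Interpretation (a reading of the construction above, not a claim from any
# reference text): (M^l)[i, j] sums over all directed paths of length l from
# node i to node j, each weighted by the product of the traversed predicate
# weights w and the q-values of the source and intermediate nodes, so z_gs
# scores all weighted gene-to-symptom paths of length 2..max_path_len before
# the offset f is added.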
########################################################################################################################
# Model training
# Collect all (g, s) pairs from A and B
all_pairs = list(A.union(B))
# Build mapping from genes to indices
gene_indices = {g: node_indices[g] for g in genes}
symptom_indices = {s: node_indices[s] for s in symptoms}
# Flatten Q into a vector for optimization
def flatten_Q(Q_dict):
return np.concatenate([Q_dict[g] for g in Gene_list])
def unflatten_Q(q_flat):
Q_dict = {}
n = N
for i, g in enumerate(Gene_list):
Q_dict[g] = q_flat[i*n:(i+1)*n]
return Q_dict
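if debug:
    # Quick round-trip sanity check on the flatten/unflatten helpers
    assert np.allclose(flatten_Q(unflatten_Q(flatten_Q(Q))), flatten_Q(Q))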
# Objective function to minimize: equation (14)
def objective(params):
    # Unpack parameters: params = [q (N * len(Gene_list) values), w (P_size values), f]
n_q = N * len(Gene_list)
q_flat = params[:n_q]
w = params[n_q:n_q+P_size]
f = params[-1]
Q_dict = unflatten_Q(q_flat)
loss = 0
for g in Gene_list:
q_g = Q_dict[g]
g_idx = gene_indices[g]
psi_g = psi(g)
sum_loss = 0
        for s in symptoms:  # loop over all symptoms, not just sigma(g)
s_idx = symptom_indices[s]
z_gs = compute_z_gs(q_g, w, f, s_idx, g_idx)
y_hat = compute_y_hat(z_gs)
y_true = y(g, s)
phi_gs = phi(g, s, A, avg_phi)
cel = - (y_true * np.log(y_hat + 1e-15) + (1 - y_true) * np.log(1 - y_hat + 1e-15))
sum_loss += phi_gs * cel
reg_q = (a / N) * np.sum(np.abs(q_g)) + (b / N) * np.sqrt(np.sum(q_g ** 2)) # regularization terms
loss += (sum_loss / psi_g) + reg_q
reg_w = (c / P_size) * np.sum(w ** 2) # regularization term from equation (16)
total_loss = (loss / len(Gene_list)) + reg_w
return total_loss
# Initial parameters
q_initial = flatten_Q(Q)
params_initial = np.concatenate([q_initial, w, [f]])
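if debug:
    # Report the loss at initialization for reference
    print(f"Initial loss: {objective(params_initial):.4f}")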
# Enforce bounds on the parameters
# Number of q parameters
n_q = N * len(Gene_list)
# Bounds for q parameters (non-negative)
bounds_q = [(0, None) for _ in range(n_q)]
# Bounds for w parameters (non-negative)
bounds_w = [(0, None) for _ in range(P_size)]
# Bounds for f: fixed at -4 here (the commented alternative lets f vary below -1)
# bounds_f = [(-np.inf, -1)]
bounds_f = [(-4, -4)]
# Combine all bounds
bounds = bounds_q + bounds_w + bounds_f
# Optimization
result = minimize(objective, params_initial, method='L-BFGS-B', bounds=bounds)
# Alternative run with tighter tolerances and verbose output:
# result = minimize(objective, params_initial, method='L-BFGS-B', bounds=bounds,
#                   options={'ftol': 1e-10, 'gtol': 1e-10, 'eps': 1e-10, 'iprint': 99})
# Extract optimized parameters
optimized_params = result.x
q_optimized_flat = optimized_params[:N * len(Gene_list)]
w_optimized = optimized_params[N * len(Gene_list):N * len(Gene_list) + P_size]
f_optimized = optimized_params[-1]
Q_optimized = unflatten_Q(q_optimized_flat)
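if debug:
    print(f"Optimized predicate weights w = {w_optimized}, offset f = {f_optimized}")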
# Check how well the predictions y_hat match the labels y
# Compute y_hat for each (g, s) pair
if debug:
y_hats = {}
y_true = {}
for g in Gene_list:
q_g = Q_optimized[g]
        for s in symptoms:  # evaluate on all symptoms, not just sigma(g)
s_idx = symptom_indices[s]
g_idx = gene_indices[g]
z_gs = compute_z_gs(q_g, w_optimized, f_optimized, s_idx, g_idx)
y_hats[(g, s)] = compute_y_hat(z_gs)
y_true[(g, s)] = y(g, s)
print(f"y_hat({g}, {s}) = {y_hats[(g, s)]}, y_true({g}, {s}) = {y_true[(g, s)]}")
########################################################################################################################
# View intermediate node weights
def get_top_node_weights(g, Q_dict, top_k=3):
q_g = Q_dict[g]
# Get indices of top_k nodes
top_indices = np.argsort(q_g)[::-1][:top_k]
top_nodes = [nodes[i] for i in top_indices]
return top_nodes
# Example: top-weighted nodes for gene 'G1' from the training-time Q
top_nodes_G1 = get_top_node_weights('G1', Q_optimized)
print(f"Top-weighted nodes for gene G1: {top_nodes_G1}")
# Example: top-weighted nodes for gene 'G2'
top_nodes_G2 = get_top_node_weights('G2', Q_optimized)
print(f"Top-weighted nodes for gene G2: {top_nodes_G2}")
########################################################################################################################
# Make predictions
# Function to re-optimize q_g for prediction
def predict_q_g(g, w_opt, f_opt):
# Initialize q_g with random values
q_g_initial = np.random.rand(N)
# Bounds for q_g (non-negative)
bounds_q_g = [(0, None) for _ in range(N)]
# Symptoms associated with gene g (sigma'(g))
sigma_prime_g = {s for s in symptoms if (g, s) in A}
# Indices mapping
g_idx = gene_indices[g]
s_indices = [symptom_indices[s] for s in sigma_prime_g]
# psi_g for normalization
psi_g = sum([phi(g, s, A, avg_phi) for s in sigma_prime_g])
# Objective function to minimize for q_g
def objective_q(q_g):
loss = 0
for s_idx in s_indices:
z_gs = compute_z_gs(q_g, w_opt, f_opt, s_idx, g_idx)
y_hat = compute_y_hat(z_gs)
            # y_{g,s} is fixed to 1 at prediction time, so the cross-entropy reduces to -log(y_hat)
            phi_gs = phi(g, nodes[s_idx], A, avg_phi)
            cel = -np.log(y_hat + 1e-15)
loss += phi_gs * cel
reg_q = (a / N) * np.sum(np.abs(q_g)) + (b / N) * np.sqrt(np.sum(q_g ** 2)) # regularization terms
total_loss = (loss / psi_g) + reg_q
return total_loss
# Optimize q_g
result_q = minimize(objective_q, q_g_initial, method='L-BFGS-B', bounds=bounds_q_g)
q_g_optimized = result_q.x
return q_g_optimized
# Function to predict intermediate nodes for gene g
def predict_intermediate_nodes(g, w_opt, f_opt, top_k=6):
q_g_optimized = predict_q_g(g, w_opt, f_opt)
# Get indices of top_k nodes
top_indices = np.argsort(q_g_optimized)[-top_k:][::-1]
top_nodes = [nodes[i] for i in top_indices]
return top_nodes
# Example prediction for gene 'G1'
predicted_nodes_G1 = predict_intermediate_nodes('G1', w_optimized, f_optimized)
print(f"Predicted intermediate nodes for gene G1: {predicted_nodes_G1}")
# Example prediction for gene 'G2'
predicted_nodes_G2 = predict_intermediate_nodes('G2', w_optimized, f_optimized)
print(f"Predicted intermediate nodes for gene G2: {predicted_nodes_G2}")