TFCox.py
import math
import warnings

import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.initializers import Zeros
from tensorflow.keras.layers import BatchNormalization, Dense, Input
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import SGD, Adam, Ftrl, RMSprop
from tensorflow.keras.regularizers import l1_l2


class TFCox():
    """Cox proportional hazards model fitted with TensorFlow/Keras."""

    def __init__(self, seed=42, norm=False, optimizer='Ftrl', l1_ratio=1,
                 lbda=0.0001, max_it=50, learn_rate=0.001, momentum=0.1,
                 stop_if_nan=True, stop_at_value=False, cscore_metric=False,
                 suppress_warnings=True, verbose=0):
        self.max_it = max_it
        self.tnan = stop_if_nan        # revert and stop if the loss becomes nan/inf
        self.tcscore = stop_at_value   # stop once concordance reaches this value
        self.lr = learn_rate
        self.cscore = cscore_metric    # also track concordance during training
        np.random.seed(seed)
        tf.random.set_seed(seed)
        self.op = optimizer
        self.l1r = l1_ratio            # 1 = pure L1 (lasso), 0 = pure L2 (ridge)
        self.lbda = lbda               # overall regularization strength
        self.norm = norm               # batch-normalize inputs before the linear layer
        self.verbose = verbose
        if suppress_warnings:
            warnings.filterwarnings('ignore')
        if optimizer == 'Adam':
            self.opt = Adam(learn_rate)
        elif optimizer == 'SGD':
            self.opt = SGD(learn_rate)
        elif optimizer == 'SGDmomentum':
            self.opt = SGD(learn_rate, momentum=momentum)
        elif optimizer == 'RMSprop':
            self.opt = RMSprop(learn_rate)
        elif optimizer == 'Ftrl':
            # Ftrl applies the elastic-net penalty inside the optimizer itself;
            # the factor of 100 rescales lbda to a comparable strength.
            self.opt = Ftrl(learn_rate,
                            l1_regularization_strength=l1_ratio * lbda * 100,
                            l2_regularization_strength=(1 - l1_ratio) * lbda * 100)
    def coxloss(self, state):
        # Negative Breslow partial log-likelihood. Rows must be sorted by
        # ascending time so the reverse cumulative sum of exp(y_pred) gives
        # each subject's risk set; censored rows (state == 0) contribute 0.
        # The small constant 0.0001 guards the log against underflow.
        def loss(y_true, y_pred):
            return -K.mean((y_pred - K.log(tf.math.cumsum(K.exp(y_pred), reverse=True, axis=0) + 0.0001)) * state, axis=0)
        return loss
    def cscore_metric(self, state):
        # Concordance (C-index) over all comparable pairs. Rows are assumed
        # sorted by ascending time, so index a precedes index b in time; a
        # pair only counts when subject a had an observed event. Note this
        # is O(n^2) in the batch size.
        def loss(y_true, y_pred):
            con = 0
            dis = 0
            for a in range(len(y_pred)):
                for b in range(a + 1, len(y_pred)):
                    if (y_pred[a] > y_pred[b]) & (y_pred[a] * state[a] != 0):
                        con += 1
                    elif (y_pred[a] < y_pred[b]) & (y_pred[a] * state[a] != 0):
                        dis += 1
            return con / (con + dis)
        return loss
    def fit(self, X, state, time):
        from tensorflow.python.framework.ops import disable_eager_execution
        disable_eager_execution()
        K.clear_session()
        # Sort all inputs by ascending survival time; coxloss relies on this order.
        self.time = np.array(time)
        self.newindex = pd.DataFrame(self.time).sort_values(0).index
        self.X = pd.DataFrame(np.array(X)).reindex(self.newindex)
        self.state = np.array(pd.DataFrame(np.array(state)).reindex(self.newindex))
        self.time = np.array(pd.DataFrame(np.array(time)).reindex(self.newindex))
        inputsx = Input(shape=(self.X.shape[1],))
        state = Input(shape=(1,))  # event indicator enters the graph as a second input
        # For Ftrl the elastic-net penalty is applied by the optimizer, so the
        # Dense layer gets no kernel regularizer; otherwise penalize the kernel.
        if self.op != 'Ftrl':
            reg = l1_l2(self.lbda * self.l1r, self.lbda * (1 - self.l1r))
        else:
            reg = None
        out = BatchNormalization()(inputsx) if self.norm else inputsx
        # A single linear unit without bias: its kernel holds the Cox coefficients.
        out = Dense(1, activation='linear', kernel_initializer=Zeros(),
                    kernel_regularizer=reg, use_bias=False)(out)
        model = Model(inputs=[inputsx, state], outputs=out)
        if (self.tcscore != False) or (self.cscore == True):
            model.compile(optimizer=self.opt,
                          loss=self.coxloss(state),
                          metrics=[self.cscore_metric(state)],
                          experimental_run_tf_function=False)
        else:
            model.compile(optimizer=self.opt,
                          loss=self.coxloss(state),
                          experimental_run_tf_function=False)
        self.model = model
        if self.verbose == 1:
            self.model.summary()  # summary() prints directly; wrapping it in print() emits "None"
        self.loss_history_ = []
        for its in range(self.max_it):
            # Keep a copy of the weights so we can roll back if the loss diverges.
            self.temp_weights = self.model.get_weights()
            tr = self.model.train_on_batch([self.X, self.state],
                                           np.zeros(self.state.shape))
            self.loss_history_.append(tr)
            # With metrics enabled train_on_batch returns [loss, c-score];
            # otherwise it returns the scalar loss.
            if (self.tcscore != False) or (self.cscore == True):
                last_loss = self.loss_history_[-1][0]
            else:
                last_loss = self.loss_history_[-1]
            if self.verbose == 1:
                if (self.tcscore != False) or (self.cscore == True):
                    print('loss:', last_loss, ' C-score: ', self.loss_history_[-1][1])
                else:
                    print('loss:', last_loss)
            if self.tcscore != False:
                if self.loss_history_[-1][1] >= self.tcscore:
                    print('Terminated early because concordance >= '
                          + str(self.tcscore) + ' as set by stop_at_value flag.')
                    break
            if (math.isnan(last_loss) or math.isinf(last_loss)) and self.tnan:
                self.model.set_weights(self.temp_weights)
                print('Terminated because the loss became nan or inf; '
                      'reverted to the last valid weight set.')
                break
        # Fitted coefficients of the linear layer.
        self.beta_ = self.model.get_weights()[-1]
    def predict(self, X):
        # The state input is unused at prediction time, so feed zeros.
        preds = self.model.predict([X, np.zeros(len(X))])
        return preds
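

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original file). The
# synthetic data, feature count, and hyperparameter values below are
# assumptions chosen only to demonstrate the fit/predict API.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    rng = np.random.default_rng(0)
    X = rng.normal(size=(200, 5))                # 200 samples, 5 covariates
    time = rng.exponential(scale=1.0, size=200)  # observed/censoring times
    state = rng.integers(0, 2, size=200)         # 1 = event observed, 0 = censored

    model = TFCox(optimizer='Ftrl', lbda=0.001, max_it=100, verbose=1)
    model.fit(X, state, time)
    risk = model.predict(X)                      # higher score = higher predicted hazard
    print('fitted coefficients:', model.beta_.ravel())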