-
Notifications
You must be signed in to change notification settings - Fork 2
/
sib.py
95 lines (71 loc) · 2.52 KB
/
sib.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
from _sib import *
import sys
import _sib
# Filesystem path of the imported _sib module, exposed as an attribute of this
# module (presumably so callers can locate the native extension — confirm).
__module_file__ = _sib.__file__
def marginal(n):
    """
    Return the marginal of node ``n``.

    Thin convenience wrapper: it simply delegates to ``n.marginal()``.
    """
    compute = n.marginal
    return compute()
def marginals_t(f, t):
    '''
    returns the marginals of every node of a factor graph at fixed time
    - f: sib.FactorGraph instance (anything exposing a ``nodes`` sequence)
    - t: time at which each node's marginal is evaluated (via marginal_t)
    return: dict - {i : [prob_S, prob_I, prob_R]} keyed by node index i
    '''
    # Read f.nodes once and enumerate it: indexing f.nodes inside the loop
    # re-fetches the attribute on every step, which may copy the whole node
    # vector each time if the binding converts it to a Python list.
    return {i: marginal_t(node, t) for i, node in enumerate(f.nodes)}
def marginal_t(n, t):
    '''
    returns the marginal of a single node at fixed time
    - n: sib.Node instance (must provide ``times`` and ``marginal_index``)
    - t: time; must be one of the values in ``n.times``
    return: list - probability to be [prob_S, prob_I, prob_R]
    raises: ValueError if ``t`` is not one of ``n.times``
    '''
    times = list(n.times)
    try:
        position = times.index(t)
    except ValueError:
        # Same exception type as before, but with a message that says what
        # was actually looked up and where.
        raise ValueError(f"time {t!r} is not one of this node's times") from None
    # We use "-1": marginal_index removes the source times (presumably the
    # leading entry of n.times — confirm against the C++ side).
    return n.marginal_index(position - 1)
class FactorGraph(_sib.FactorGraph):
    """
    Python-side subclass of ``_sib.FactorGraph`` that also accepts
    observations written as plain integer state codes.

    Wherever an observation state ``s`` is expected it may be:
      - a ``Test`` instance, used unchanged;
      - ``-1``, mapped to the shared "fake" observation ``Test(1, 1, 1)``
        (all three flags set);
      - ``0``, ``1`` or ``2``, mapped to ``Test(s == 0, s == 1, s == 2)``,
        i.e. only flag ``s`` set (presumably S, I, R in that order).
    """

    def gettest(self, s):
        """
        Convert an observation state into a ``Test`` object.

        - s: a ``Test``, or an integer in {-1, 0, 1, 2} (NumPy integer
          scalars are accepted too)
        return: the corresponding ``Test`` instance
        raises: IndexError if ``s`` is an integer outside {-1, 0, 1, 2}
        """
        import numbers

        if isinstance(s, Test):
            return s
        if isinstance(s, numbers.Integral):
            # numbers.Integral also matches NumPy integer scalars; a plain
            # isinstance(s, int) misses them, so np.int64(-1) would fall
            # through to puretest[-1] and silently become the "2" test.
            code = int(s)
            if code == -1:
                return self._fakeobs
            if not 0 <= code < len(self.puretest):
                # Reject other negatives explicitly instead of letting
                # Python's negative indexing pick a wrong test.
                raise IndexError(f"observation state must be -1, 0, 1 or 2, got {s!r}")
            return self.puretest[code]
        # Any other type: keep the historical behaviour of indexing puretest.
        return self.puretest[s]

    def __init__(self, params=None,
                 contacts=None,
                 observations=None,
                 tests=None,
                 times=None,
                 individuals=None):
        """
        Build the factor graph.

        - params: a ``_sib.Params``; when omitted, a fresh
          ``Params(Uniform(1.0), Exponential(0.5), 0.1, 0.45, 0.0, 0.0)``
          is created for this graph
        - contacts: passed to ``_sib.FactorGraph`` unchanged (default: empty)
        - observations, tests: iterables of ``(i, s, t)`` triples; each state
          ``s`` is converted with ``gettest`` and both are concatenated,
          observations first
        - times: accepted for backward compatibility; not used here
        - individuals: passed to ``_sib.FactorGraph`` unchanged (default: empty)
        """
        if params is None:
            # Create the default per graph rather than once in the signature,
            # so graphs built without explicit params never share one Params
            # object (whether the C++ side copies it is not visible here).
            params = _sib.Params(_sib.Uniform(1.0), _sib.Exponential(0.5), 0.1, 0.45, 0.0, 0.0)
        if contacts is None:
            contacts = []
        if individuals is None:
            individuals = []
        # These must exist before gettest() is used below.
        self.puretest = [_sib.Test(s == 0, s == 1, s == 2) for s in range(3)]
        self._fakeobs = _sib.Test(1, 1, 1)
        # list() lets callers pass tuples or other iterables; `a + b` on the
        # raw arguments only worked when both were lists.
        observations = [] if observations is None else list(observations)
        tests = [] if tests is None else list(tests)
        tests = [(i, self.gettest(s), t) for (i, s, t) in observations + tests]
        _sib.FactorGraph.__init__(self, params=params, contacts=contacts,
                                  tests=tests, individuals=individuals)

    def append_observation(self, i, s, t):
        """
        Add a single observation ``(i, s, t)``; ``s`` is converted with
        ``gettest`` before being passed to ``_sib.FactorGraph``.
        """
        _sib.FactorGraph.append_observation(self, i, self.gettest(s), t)

    def iterate(self,
                maxit=100,
                tol=1e-3,
                damping=0.0,
                learn=False,
                callback=False
                ):
        """
        Repeatedly call ``self.update`` until convergence or ``maxit`` sweeps.

        - maxit: maximum number of ``update`` calls
        - tol: stop once the error returned by ``update`` is below ``tol``
        - damping, learn: forwarded to every ``update`` call
        - callback: called as ``callback(t, err, graph)`` after each sweep;
          a return value equal to ``False`` stops the loop. ``False`` (the
          default) prints a single, continuously overwritten progress line;
          ``None`` disables reporting.
        """
        newline = False
        # Identity checks for the sentinels: `callback == False` also matches
        # 0 and may invoke a custom __eq__ on a callable object.
        if callback is False:
            callback = lambda t, e, f: print(f"sib.iterate(damp={damping}): {t}/{maxit} {e:1.3e}/{tol}", end=' \r', flush=True)
            newline = True
        elif callback is None:
            callback = lambda t, e, f: None
        try:
            for t in range(maxit):
                err = self.update(damping=damping, learn=learn)
                # `== False` (not `is False`) on purpose: it also stops on
                # falsy-equal results such as 0 or a NumPy bool.
                if callback(t, err, self) == False:  # noqa: E712
                    break
                if err < tol:
                    break
        finally:
            # Terminate the '\r' progress line even if update() raises.
            if newline:
                print()
def iterate(f, maxit=100, tol=1e-3, damping=0.0, learn=False, callback=False):
    """
    Module-level shortcut for ``f.iterate(...)``.

    Every argument is forwarded by keyword to ``f.iterate`` and its return
    value is handed back unchanged.
    """
    forwarded = {
        "maxit": maxit,
        "tol": tol,
        "damping": damping,
        "learn": learn,
        "callback": callback,
    }
    return f.iterate(**forwarded)