-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmixture.py
84 lines (68 loc) · 2.43 KB
/
mixture.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import numpy as np
from attr import attrib, attrs
def _to_mixture_weights(rel_weights):
    """
    Convert relative weights to a normalized float array.

    Args:
        rel_weights: A scalar or array-like of relative component weights.

    Returns:
        A new ``np.float64`` array with at least one dimension whose
        entries sum to 1.

    Raises:
        ValueError: If the weights sum to zero or to a non-finite value,
            so that they cannot be normalized.
    """
    # np.float64 is used instead of the np.float_ alias, which NumPy 2.0
    # removed. copy=True guarantees the in-place division below never
    # touches the caller's data.
    weights = np.atleast_1d(np.array(rel_weights, dtype=np.float64, copy=True))
    total = weights.sum()
    if not np.isfinite(total) or total == 0:
        raise ValueError("Mixture weights must have a finite, non-zero sum")
    weights /= total
    return weights
def _validate_weights(instance, attribute, value):
    """
    Validator: check that the weights are a vector with one entry per model.

    Args:
        instance: Object being validated; its ``models`` attribute is read.
        attribute: The attribute being validated (unused).
        value: The weights array; must expose ``ndim`` and support ``len()``.

    Raises:
        ValueError: If ``value`` is not one-dimensional, or if its length
            differs from ``len(instance.models)``.
    """
    # Raise explicitly instead of using ``assert``: assertions are stripped
    # when Python runs with -O, which would silently disable this check.
    if value.ndim != 1:
        raise ValueError("Weights must be a one-dimensional vector")
    if len(value) != len(instance.models):
        raise ValueError("Weights vector does not match number of models")
def _validate_models(instance, attribute, value):
    """
    Check that every supplied model reports the same ``ndim``.

    The first model's ``ndim`` is taken as the reference and each model in
    ``value`` is compared against it.

    Args:
        instance: Object being validated (unused).
        attribute: The attribute being validated (unused).
        value: Indexable, iterable collection of models exposing ``ndim``.

    Raises:
        ValueError: If any model's ``ndim`` differs from the first model's.
    """
    reference_ndim = value[0].ndim
    if any(candidate.ndim != reference_ndim for candidate in value):
        raise ValueError("Models of different dimensionality supplied")
@attrs
class Mixture(object):
    """
    Represents a mixture model: a weighted combination of component models.

    Attributes:
        models: Sequence of component models. Each one is expected to
            expose ``dist`` (with ``pdf`` and ``rvs`` methods), ``ndim``
            and ``_repr_html_()``, since those are what this class calls.
        weights: Relative component weights, one per model. Converted on
            construction by ``_to_mixture_weights`` and checked against
            ``models`` by ``_validate_weights``.
    """

    # NOTE(review): _validate_models is defined in this module but is not
    # attached to this attribute -- confirm whether that is intentional.
    models = attrib()
    # ``converter`` is the current spelling of the keyword; the older
    # ``convert`` was deprecated in attrs 17.4 and removed in 19.2.
    weights = attrib(converter=_to_mixture_weights,
                     validator=_validate_weights)

    def joint_pdf(self, x):
        """
        Evaluate the mixture density at ``x``.

        Args:
            x: Point(s) passed unchanged to each component's ``dist.pdf``.

        Returns:
            The weight-scaled sum of the component pdf values, or ``None``
            when there are no component models.
        """
        # Start from None rather than a pre-allocated array so the result
        # takes whatever shape the component pdfs return.
        jpdf = None
        for weight, model in zip(self.weights, self.models):
            term = weight * model.dist.pdf(x)
            jpdf = term if jpdf is None else jpdf + term
        return jpdf

    def joint_sample(self, size, shuffle=True):
        """
        Draw ``size`` samples from the mixture.

        The number of samples taken from each component is drawn from a
        multinomial distribution over ``weights``; each component's
        ``dist.rvs`` is then called with its count and the results are
        concatenated. Uses NumPy's global random state.

        Args:
            size: Total number of samples to draw.
            shuffle: If true, shuffle the concatenated samples in place so
                they are not grouped by component.

        Returns:
            The concatenated (and optionally shuffled) samples.
        """
        component_sample_sizes = np.random.multinomial(n=size,
                                                       pvals=self.weights)
        component_samples = [
            model.dist.rvs(count)
            for model, count in zip(self.models, component_sample_sizes)
        ]
        mixture_sample = np.concatenate(component_samples)
        if shuffle:
            np.random.shuffle(mixture_sample)
        return mixture_sample

    def _repr_html_(self):
        """
        Build an HTML summary with one section per component.

        Returns:
            A string of HTML fragments, one per model (index, weight and
            the model's own ``_repr_html_()`` output), joined by newlines.
        """
        template = (
            "\n"
            "<h4>Component {idx}</h4>\n"
            "<p>Weight: {weight:0.3f}\n"
            "{model}\n"
            "</p>\n"
            "<div style=\"clear:left;\"></div>\n"
        )
        output = [
            template.format(idx=idx, weight=self.weights[idx],
                            model=mdl._repr_html_())
            for idx, mdl in enumerate(self.models)
        ]
        return '\n'.join(output)

    @property
    def ndim(self):
        """Dimensionality of the mixture, taken from the first model's ``ndim``."""
        return self.models[0].ndim