---
title: "Introduction to dblinkR"
author: "Neil Marchant and Rebecca C. Steorts"
date: "18 March 2020"
output: rmarkdown::html_vignette
vignette: >
%\VignetteIndexEntry{Introduction to dblinkR}
%\VignetteEngine{knitr::rmarkdown}
%\VignetteEncoding{UTF-8}
---
In this vignette, we demonstrate how to use dblinkR to perform Bayesian
entity resolution. dblinkR implements a generative Bayesian model that
jointly performs blocking and entity resolution of multiple databases as
outlined in the following paper:
> Marchant, N. G., Steorts, R. C., Kaplan, A., Rubinstein, B. I. P., & Elazar,
D. N. (2019). d-blink: Distributed End-to-End Bayesian Entity Resolution.
arXiv preprint arXiv:1909.06039.
We will be using a dataset called `RLdata500` that was formerly available in
the RecordLinkage R package (now deprecated). For convenience, we have
included the `RLdata500` dataset in `dblinkR`.
An outline of the vignette is as follows:
1. We load the required packages and connect to Spark
2. We introduce the RLdata500 dataset
3. We demonstrate how to set the d-blink model parameters
4. We perform inference using MCMC
5. We examine MCMC diagnostics
6. We assess the posterior linkage structure and other model parameters
```{r setup, include = FALSE}
knitr::opts_chunk$set(
collapse = TRUE,
comment = "#>",
fig.width=7,
fig.height=5
)
```
## Loading Required Packages
We begin by loading the required packages. Note that the `dblinkR` package
must be loaded *after* `sparklyr`.
```{r packages, message=FALSE}
# For running dblink
library(sparklyr)
library(dblinkR)
# For generating trace plots
library(stringr)
library(dplyr)
library(tidyr)
library(ggplot2)
# For generating credible interval plots
library(tidybayes)
```
## Connecting to Spark
In this vignette, we will run `dblink` on a local instance of Spark with 2
cores. This is sufficient for testing purposes; however, for larger data sets,
we recommend connecting to a Spark deployment. The `sparklyr`
[documentation](https://spark.rstudio.com/deployment/) contains information
about connecting to Spark deployments.
```{r spark-connect, message=FALSE}
sc <- spark_connect(master = "local[2]", version = "2.4.3")
spark_context(sc) %>% invoke("setLogLevel", "WARN")
```
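For reference, connecting to a managed cluster instead of local mode might
look like the sketch below. The master URL and resource settings are
placeholders that depend on your deployment; they are not part of this
tutorial.
```{r spark-connect-cluster, eval=FALSE}
# A hypothetical sketch: connect to a YARN-managed cluster with
# deployment-specific resource settings (all values are placeholders)
config <- spark_config()
config$spark.executor.memory <- "8G"
config$spark.executor.cores <- 4
sc <- spark_connect(master = "yarn", config = config)
```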
`dblink` requires a location on disk (including HDFS) to save diagnostics,
posterior samples, and the state of the Markov chain. We refer to this location
as the *project path*, and set it to the current working directory below.
In addition, we must also specify a location on disk to save Spark checkpoints.
```{r set-paths, message=FALSE}
projectPath <- paste0(getwd(), "/") # working directory
spark_set_checkpoint_dir(sc, paste0(projectPath, "checkpoints/"))
```
Now that we have loaded the required packages and connected to Spark, we can
proceed with the rest of the tutorial!
## Understanding the RLdata500 dataset
RLdata500 is a synthetic dataset that contains 500 personal records with 10
percent duplication. Thus, there are 450 unique individuals represented in
the data and 50 duplicate records. Each record contains the individual's
full name and date of birth; however, these attributes may be distorted.
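Since the ground-truth entity identifiers `identity.RLdata500` are shipped
alongside the records (we rely on them again in the evaluation section below),
we can confirm these counts directly:
```{r verify-duplication}
# each record's true entity ID; duplicate records share an ID
length(unique(identity.RLdata500))   # 450 unique individuals
sum(duplicated(identity.RLdata500))  # 50 duplicate records
```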
We add an additional column for record identifiers, which are used to specify
the linkage structure later on.
```{r load-data}
records <- RLdata500
records['rec_id'] <- seq_len(nrow(records))
records <- lapply(records, as.character) %>% as.data.frame(stringsAsFactors = FALSE)
# inspect the records (including the added rec_id column)
head(records)
```
## dblink parameters
Here, we specify the parameters for the `dblink` model.
We treat the name-related attributes as string-type with a Levenshtein
similarity function, and the date-related attributes as categorical (with a
constant similarity function). We also place a Beta prior on the distortion
probability for each attribute.
```{r set-attribute-parameters}
distortionPrior <- BetaRV(1, 50)
levSim <- LevenshteinSimFn(threshold = 7, maxSimilarity = 10)
attributeSpecs <- list(
fname_c1 = Attribute(levSim, distortionPrior),
lname_c1 = Attribute(levSim, distortionPrior),
by = CategoricalAttribute(distortionPrior),
bm = CategoricalAttribute(distortionPrior),
bd = CategoricalAttribute(distortionPrior)
)
```
The Beta prior on the distortion probabilities favours low distortion, as
shown below.
```{r plot-distortion-prior}
tibble(prob = seq(from=0, to=1, by=0.01)) %>%
mutate(density = dbeta(prob, distortionPrior@shape1, distortionPrior@shape2)) %>%
ggplot(aes(x = prob, y = density)) + geom_line() +
labs(title = "Prior distribution on the distortion probabilities",
x = "Distortion probability", y = "Density")
```
We can also specify the size of the latent entity population. A larger
population size favours under-linkage, while a smaller population
size favours over-linkage. By default, the population size is set to the
number of records in the data. This is specified by setting the population
size to `NULL`:
```{r set-population-size}
populationSize <- NULL
```
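If prior knowledge about the number of latent entities were available, the
default could be overridden with an explicit integer instead. The value below
is purely illustrative:
```{r set-population-size-explicit, eval=FALSE}
# hypothetical override: fix the latent population size explicitly
populationSize <- 1000L
```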
Finally, we specify the partitioner which is used for blocking the entities
to improve scalability. Since the data is small in this case, we can switch
partitioning off by setting the partitioner to `NULL`. For larger data sets,
we recommend using the `KDTreePartitioner` (please refer to the documentation).
```{r set-partitioner}
partitioner <- NULL
```
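For completeness, a KD-tree configuration might look like the sketch below.
The argument names and values are assumptions made for illustration; consult
the `KDTreePartitioner` documentation for the exact interface.
```{r set-partitioner-kdtree, eval=FALSE}
# Hypothetical sketch (argument names are assumptions): split the space
# of entity attribute values along two name attributes
partitioner <- KDTreePartitioner(numLevels = 2L,
                                 attributes = c("fname_c1", "lname_c1"))
```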
## Running inference
In the previous section, we specified the model parameters. Now, we are ready
to perform inference. The results will be saved in the project path.
```{r run-inference}
state <- initializeState(sc, records, attributeSpecs, recIdColname = 'rec_id',
partitioner = partitioner, populationSize = populationSize,
randomSeed = 1, maxClusterSize = 10L)
result <- runInference(state, projectPath, sampleSize = 1000,
burninInterval = 5000, thinningInterval = 10)
```
To load results from a previous run, we can use the following function.
```{r, eval=FALSE}
result <- loadResult(sc, projectPath)
```
## MCMC diagnostics
It's important to review convergence and mixing diagnostics when performing
inference using Markov chain Monte Carlo (MCMC). Various diagnostics are saved
in the `diagnostics.csv` file located in the project path. We can use the
function below to load these into a local tibble.
```{r read-diagnostics}
diagnostics <- loadDiagnostics(sc, projectPath)
```
Below we provide trace plots of various summary statistics.
```{r diagnostic-plots-1}
ggplot(diagnostics, aes(x=iteration, y=numObservedEntities)) +
geom_line() +
labs(title = 'Trace plot: # entity clusters', x = 'Iteration',
y = '# clusters')
diagnostics %>%
select(iteration, starts_with("aggDist")) %>%
gather(attribute, numDistortions, starts_with("aggDist")) %>%
mutate(attribute = str_match(attribute, "^aggDist(.+)")[,2],
numDistortions = numDistortions / nrow(records) * 100) %>%
ggplot(aes(x=iteration, y=numDistortions)) +
geom_line(aes(colour = attribute)) +
labs(title = 'Trace plot: attribute distortion', x = 'Iteration',
y = '% distorted', colour = 'Attribute')
ggplot(diagnostics, aes(x=iteration, y=logLikelihood)) +
geom_line() +
labs(title = 'Trace plot: log-likelihood (unnormalized)', x = 'Iteration',
y = 'log-likelihood')
```
Additional diagnostic statistics can be computed from the samples of the
linkage structure, which we refer to as the *linkage chain*.
```{r diagnostic-plots-2}
linkageChain <- loadLinkageChain(sc, projectPath)
clustSizeDist <- clusterSizeDistribution(linkageChain)
ggplot(clustSizeDist, aes(x=iteration, y=frequency)) +
geom_line(aes(colour = clusterSize)) +
labs(title = 'Trace plot: cluster size distribution', x = 'Iteration',
y = 'Frequency', colour = 'Cluster size')
partSizes <- partitionSizes(linkageChain)
ggplot(partSizes, aes(x=iteration, y=size)) +
geom_line(aes(colour = partitionId)) +
labs(title = 'Trace plot: partition sizes', x = 'Iteration', y = 'Size',
colour = 'Partition')
```
## Evaluation
In this section, we evaluate the predictions of our model and compare them
against the ground truth (provided with the synthetic data).
We compute pairwise precision and recall using a point estimate of the
linkage structure. We also inspect the posterior distribution over two
important summary statistics:
* the number of unique individuals in the data, and
* the distortion level of each attribute.
### Linkage quality
We compute a point estimate of the posterior linkage structure/clustering
using the *shared most probable maximal matching sets method*
(Steorts et al., 2016). This method creates a globally consistent estimate
that obeys transitivity constraints.
Note that we must use the `collect` function to retrieve the data from Spark
into a local tibble.
```{r smpc}
predClusters <- sharedMostProbableClusters(linkageChain) %>% collect()
```
Next we assess the quality of the predicted linkage structure by
computing pairwise precision and recall using functions from the
[`clevr`](https://github.com/cleanzr/clevr) package.
```{r pairwise-metrics, message=FALSE}
library(clevr)
predMatches <- clusters_to_pairs(predClusters)
trueMatches <- membership_to_pairs(identity.RLdata500, elem_ids = records$rec_id)
numRecords <- nrow(records)
measures <- eval_report_pairs(trueMatches, predMatches, numRecords*(numRecords - 1)/2)
print(measures)
```
### Posterior estimates
We can use the saved posterior samples to infer unknown model parameters.
First, we consider the number of unique entities present in the data.
This summary statistic is stored in `diagnostics$numObservedEntities`.
Below we compute a point estimate based on the median, along with a 95\%
highest density interval.
```{r}
hdiNumEntities <- median_hdi(diagnostics$numObservedEntities)
cat("The predicted number of unique entities in RLdata500 is ",
hdiNumEntities$y, ".\n", sep="")
cat("A 95% HDI is ",
sprintf("[%s,%s]", hdiNumEntities$ymin, hdiNumEntities$ymax), ".\n", sep="")
```
We can also visualize the estimated posterior distribution, comparing it to
the ground truth (red), the posterior median (blue), and the 95\%
highest density interval (green).
```{r posterior-plot}
ggplot(diagnostics) +
geom_histogram(aes(x=numObservedEntities), binwidth = 1) +
geom_vline(xintercept = hdiNumEntities$y, color = 'blue') +
geom_vline(xintercept = 450, color = 'red') +
geom_vline(xintercept = hdiNumEntities$ymin, color = 'green') +
geom_vline(xintercept = hdiNumEntities$ymax, color = 'green') +
labs(x = "# of entities", y = "Frequency", title = "Posterior number of observed entities")
```
Second, we consider the distortion level observed in each attribute, averaged
across all of the records. This information is also stored in
`diagnostics`. Below we plot a point estimate of the distortion level for each
attribute based on the median, along with a 95\% highest density interval.
```{r posterior-attribute-distortion}
diagnostics %>%
select(iteration, starts_with("aggDist")) %>%
gather(attribute, numDistortions, starts_with("aggDist")) %>%
transmute(attribute = str_match(attribute, "^aggDist(.+)")[,2],
percDistorted = numDistortions / nrow(records) * 100) %>%
ggplot(aes(y = attribute, x = percDistorted)) +
stat_pointinterval(point_interval = median_hdi, .width = 0.95) +
labs(y = "Attribute", x = "Distortion level (%)", title = "Posterior distortion level by attribute")
```