MALTS is a learning-to-match method for interpretable causal inference proposed by Harsh Parikh, Cynthia Rudin and Alexander Volfovsky in their 2019 paper "MALTS: Matching After Learning to Stretch". This repository contains `PyMALTS` and `MALTS`, Python and R implementations, respectively, of the MALTS algorithm.
Python
`PyMALTS` is a Python 3 library that requires `numpy`, `pandas`, `scipy`, `scikit-learn`, `matplotlib` and `seaborn`.
R
`MALTS` requires `nloptr` for distance metric learning.
Python
```bash
pip install git+https://github.com/almost-matching-exactly/MALTS.git
```
R
```r
devtools::install_github('https://github.com/almost-matching-exactly/MALTS',
                         subdir = 'RMALTS/MALTS')
```
Python
```python
import pymalts
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
np.random.seed(0)
sns.set()
```
R
```r
library(MALTS)
```
Python
```python
df = pd.read_csv('example/example_data.csv', index_col=0)
print(df.shape)
df.head()
```

```
(2500, 20)
```
|   | X1 | X2 | X3 | X4 | X5 | X6 | X7 | X8 | X9 | X10 | X11 | X12 | X13 | X14 | X15 | X16 | X17 | X18 | outcome | treated |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
1355 | 1.881335 | 1.684164 | 0.532332 | 2.002254 | 1.435032 | 1.450196 | 1.974763 | 1.321659 | 0.709443 | -1.141244 | 0.883130 | 0.956721 | 2.498229 | 2.251677 | 0.375271 | -0.545129 | 3.334220 | 0.081259 | -15.679894 | 0 |
1320 | 0.666476 | 1.263065 | 0.657558 | 0.498780 | 1.096135 | 1.002569 | 0.881916 | 0.740392 | 2.780857 | -0.765889 | 1.230980 | -1.214324 | -0.040029 | 1.554477 | 4.235513 | 3.596213 | 0.959022 | 0.513409 | -7.068587 | 0 |
1233 | -0.193200 | 0.961823 | 1.652723 | 1.117316 | 0.590318 | 0.566765 | 0.775715 | 0.938379 | -2.055124 | 1.942873 | -0.606074 | 3.329552 | -1.822938 | 3.240945 | 2.106121 | 0.857190 | 0.577264 | -2.370578 | -5.133200 | 0 |
706 | 1.378660 | 1.794625 | 0.701158 | 1.815518 | 1.129920 | 1.188477 | 0.845063 | 1.217270 | 5.847379 | 0.566517 | -0.045607 | 0.736230 | 0.941677 | 0.835420 | -0.560388 | 0.427255 | 2.239003 | -0.632832 | 39.684984 | 1 |
438 | 0.434297 | 0.296656 | 0.545785 | 0.110366 | 0.151758 | -0.257326 | 0.601965 | 0.499884 | -0.973684 | -0.552586 | -0.778477 | 0.936956 | 0.831105 | 2.060040 | 3.153799 | 0.027665 | 0.376857 | -1.221457 | -2.954324 | 0 |
R
```r
df <- read.csv('example/example_data.csv', row.names = 1)
# dim(df)
# head(df)
```
Python
```python
# Default settings
m = pymalts.malts_mf(outcome='outcome', treatment='treated', data=df)
```
R
```r
# Default settings
m <- MALTS(data = df, outcome = 'outcome', treatment = 'treated')
```
Python
The matched group matrix (`MG_matrix`) is an N×N matrix in which each row corresponds to a query unit and each column to a potential matched unit. Cell (i, j) holds the weight of unit j in the matched group of unit i, i.e. the number of times unit j is included in unit i's matched group across the M folds.
```python
m.MG_matrix
```
|   | 1355 | 1320 | 1233 | 706 | 438 | 184 | 1108 | 1612 | 816 | 131 | ... | 1181 | 1698 | 916 | 59 | 2267 | 1520 | 1408 | 909 | 603 | 2285 |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
1355 | 4.0 | 0.0 | 0.0 | 2.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 3.0 | 0.0 | 3.0 |
1320 | 0.0 | 4.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 4.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
1233 | 0.0 | 0.0 | 4.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
706 | 0.0 | 0.0 | 0.0 | 4.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 |
438 | 0.0 | 0.0 | 0.0 | 0.0 | 4.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
1520 | 0.0 | 0.0 | 0.0 | 0.0 | 2.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 4.0 | 0.0 | 0.0 | 0.0 | 0.0 |
1408 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 3.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 4.0 | 0.0 | 0.0 | 0.0 |
909 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 4.0 | 0.0 | 0.0 |
603 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 0.0 | 2.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 4.0 | 0.0 |
2285 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 4.0 |
2500 rows × 2500 columns
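To make the weights concrete, here is a minimal sketch of how such a count matrix can be assembled, assuming the matched group of every query unit is available for each of several runs (folds) as a list of unit labels. The fold-wise groups and the unit labels 10 to 14 are made up for illustration; this is not PyMALTS's internal implementation.

```python
import numpy as np
import pandas as pd

# Hypothetical matched groups from 4 runs (folds): each run maps a query unit
# to the labels of the units in its matched group. Not real MALTS output.
index = [10, 11, 12, 13, 14]
runs = [
    {10: [10, 11], 11: [11, 12], 12: [12, 10], 13: [13, 14], 14: [14, 13]},
    {10: [10, 11], 11: [11, 10], 12: [12, 13], 13: [13, 14], 14: [14, 12]},
    {10: [10, 12], 11: [11, 12], 12: [12, 10], 13: [13, 12], 14: [14, 13]},
    {10: [10, 11], 11: [11, 10], 12: [12, 10], 13: [13, 14], 14: [14, 13]},
]

# Cell (i, j) counts how many runs placed unit j in the matched group of unit i.
mg_matrix = pd.DataFrame(0.0, index=index, columns=index)
for run in runs:
    for query, group in run.items():
        for unit in group:
            mg_matrix.loc[query, unit] += 1

print(mg_matrix)
print(np.diag(mg_matrix.values))  # each unit is in its own group in every run
```

Here unit 11 lands in unit 10's group in three of the four runs, so cell (10, 11) is 3, while the diagonal equals the number of runs, much like the 4.0 diagonal in the rows shown above.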
Visualizing the matched group matrix as a heatmap:
```python
fig = plt.figure(figsize=(10, 10))
sns.heatmap(m.MG_matrix)
```
Accessing the matched group of an example unit (index label 1) and visualizing its weights as a bar chart:
```python
MG1 = m.MG_matrix.loc[1]  # matched group of unit 1
# Bar chart of all units matched to unit 1 more than once, sorted by weight
MG1[MG1 > 1].sort_values(ascending=False).plot(kind='bar', figsize=(20, 5))
```
R
Matched groups are stored in the `MGs` entry of the object returned by `MALTS`. Each top-level entry corresponds to a different split, i.e. a different assignment of folds to distance metric estimation versus matching.
```r
# Matched group of first unit in first split.
# m$MGs[[1]][[1]]
# Matched group of third unit in second split.
# m$MGs[[2]][[3]]
```
Additional information on a matched group can be obtained by creating (with `make_MG`) and printing an object of class `mg.malts`:
```r
# Only include units matching at least 3 times in the matched group
mg <- make_MG(1, m, threshold_n = 3)
# Unpruned CATEs are computed with respect to all units ever matched.
# Pruned CATEs are computed only with respect to the thresholded units.
print(mg)
```
```
The main matched group of unit 1:
Matched to 88 units, 45 treated and 43 control.
The unpruned CATE estimate is: 70.40668.
The pruned CATE estimate (with a threshold of 3) is: 70.1119176.
```
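The pruned/unpruned distinction can be illustrated with a small sketch. It assumes the CATE of a matched group is a plain difference between the mean treated and mean control outcomes, which is a simplification (the estimator MALTS actually uses, e.g. with `smooth_CATE`, may differ). The matched group, its weights and outcomes are hypothetical, and `diff_in_means` is a helper introduced here, not part of either package.

```python
import pandas as pd

# Hypothetical matched group of one query unit: how often each unit was
# matched ("weight"), plus its outcome and treatment indicator.
mg = pd.DataFrame({
    "weight":  [4, 3, 1, 2, 4, 1],
    "outcome": [50.0, 48.0, 90.0, -10.0, -12.0, 20.0],
    "treated": [1, 1, 1, 0, 0, 0],
})


def diff_in_means(group: pd.DataFrame) -> float:
    """Treated-minus-control mean outcome within a matched group (simplified CATE)."""
    treated = group.loc[group["treated"] == 1, "outcome"].mean()
    control = group.loc[group["treated"] == 0, "outcome"].mean()
    return treated - control


threshold_n = 3
unpruned = diff_in_means(mg[mg["weight"] > 0])           # every unit ever matched
pruned = diff_in_means(mg[mg["weight"] >= threshold_n])  # frequently matched units only
print(f"unpruned: {unpruned:.3f}, pruned: {pruned:.3f}")
```

Pruning drops the rarely matched units (here the treated unit with outcome 90 and the control unit with outcome 20), which is why the two estimates differ. If thresholding removes every treated or every control unit, the simplified estimate becomes `NaN`.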
Python
```python
m.CATE_df  # each row holds the CATE estimate for the corresponding unit
```
|   | avg.CATE | std.CATE | outcome | treated |
---|---|---|---|---|
0 | 47.232061 | 21.808950 | -15.313091 | 0.0 |
1 | 40.600643 | 21.958906 | -16.963202 | 0.0 |
2 | 40.877320 | 22.204570 | 9.527929 | 1.0 |
3 | 37.768578 | 19.740320 | -3.940218 | 0.0 |
4 | 39.920257 | 21.744433 | -8.011915 | 0.0 |
... | ... | ... | ... | ... |
2495 | 49.227788 | 21.581176 | -14.529871 | 0.0 |
2496 | 42.352355 | 21.385861 | 19.570055 | 1.0 |
2497 | 43.737763 | 19.859275 | -16.342666 | 0.0 |
2498 | 41.189297 | 20.346711 | -9.165242 | 0.0 |
2499 | 45.427037 | 23.762884 | -17.604829 | 0.0 |
2500 rows × 4 columns
Estimating the average treatment effect (ATE) as the mean of the per-unit CATE estimates:
```python
ATE = m.CATE_df['avg.CATE'].mean()
ATE
```

```
42.29673993471417
```
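Written out, the line above averages the per-unit estimates over all $N$ units (here $N = 2500$), where $\widehat{\mathrm{CATE}}_i$ is the `avg.CATE` entry of unit $i$:

$$
\widehat{\mathrm{ATE}} = \frac{1}{N}\sum_{i=1}^{N}\widehat{\mathrm{CATE}}_i
$$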
R
```r
head(m$data[, c('CATE', 'sd_CATE', 'outcome', 'treated')])
summary(m)
```
|   | CATE | sd_CATE | outcome | treated |
---|---|---|---|---|
1355 | 72.734244 | 3.467431 | -15.679894 | 0 |
1320 | 28.543205 | 3.498173 | -7.068587 | 0 |
1233 | 26.917941 | 4.022228 | -5.133200 | 0 |
706 | 52.258803 | 4.115889 | 39.684984 | 1 |
438 | 6.766184 | 4.370843 | -2.954324 | 0 |
184 | 19.066282 | 5.476048 | -6.901449 | 0 |
```
The average treatment effect of `treated` on `outcome` is estimated to be
41.262.
Average Stretch Values:
            Minimum      Maximum
Continuous  0.151 (X9)   2.46 (X7)
```
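To build intuition for the stretch values, the sketch below matches under a simplified metric: Euclidean distance after multiplying each continuous covariate by its stretch value. This is an assumption for illustration (see the MALTS paper for the exact learned metric, including how discrete covariates are handled). The helpers `stretched_distance` and `k_nearest` and the two-covariate data are hypothetical; only the values 0.151 and 2.46 are borrowed from the summary above.

```python
import numpy as np


def stretched_distance(query: np.ndarray, candidates: np.ndarray,
                       stretch: np.ndarray) -> np.ndarray:
    """Euclidean distance after scaling each covariate by its stretch value.

    `query` has shape (n_features,); `candidates` has shape (n_units, n_features).
    A simplified stand-in for the learned MALTS metric (assumption).
    """
    return np.linalg.norm((candidates - query) * stretch, axis=1)


def k_nearest(query: np.ndarray, candidates: np.ndarray,
              stretch: np.ndarray, k: int) -> np.ndarray:
    """Row indices of the k candidates closest to `query` under the stretched metric."""
    dist = stretched_distance(query, candidates, stretch)
    return np.argsort(dist, kind="stable")[:k]


# Covariate 0 is stretched strongly, covariate 1 barely at all.
stretch = np.array([2.46, 0.151])
query = np.array([0.0, 0.0])
candidates = np.array([
    [0.1, 3.0],   # close on covariate 0, far on covariate 1
    [1.0, 0.0],   # far on covariate 0, identical on covariate 1
    [0.2, -2.0],  # fairly close on covariate 0, far on covariate 1
])

print(k_nearest(query, candidates, stretch, k=2))     # stretched metric
print(k_nearest(query, candidates, np.ones(2), k=2))  # plain Euclidean
```

Under the stretched metric the two nearest candidates are 0 and 2, which agree closely on the heavily stretched covariate; with unit stretches candidate 1 becomes the nearest because its exact agreement on the second covariate now counts fully. A small stretch thus means a covariate barely influences who gets matched.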
Please refer to the R package's internal documentation for full details on `MALTS`.
| Argument Type | Python | R |
|---|---|---|
| Input data | `data` | `data` |
| Name of outcome column | `outcome` | `outcome` |
| Name of treatment column | `treatment` | `treatment` |
| Names of discrete columns | `discrete` | `discrete` |
| Loss regularization | `C` | `C` |
| Matched group sizes | `k_tr`, `k_est` | `k_tr`, `k_est` |
| Reweighting | `reweight` | `reweight` |
| CATE smoothing | `smooth_CATE`, `estimator` | `smooth_CATE`, NA |
| Refitting | `n_repeats`, `n_folds` | `n_repeats`, `n_folds` |
| CATE data frame formatting | `output_format` | NA |
| Missing data handling | NA | `missing_data`, `impute_with_outcome`, `impute_with_treatment` |
| Optimization parameters | NA | `...` |
Python
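```python
fig = plt.figure(figsize=(10, 5))
# `fill=True` replaces the deprecated `shade=True` (seaborn >= 0.11)
sns.kdeplot(m.CATE_df['avg.CATE'], fill=True)
plt.axvline(ATE, c='black')
# Raw string so that `\h` is not treated as a string escape
plt.text(ATE - 4, 0.04, r'$\hat{ATE}$', rotation=90)
```

```
Text(38.29673993471417, 0.04, '$\\hat{ATE}$')
```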
R
```r
plot(m, which_plots = 2)
```
R
```r
# Boxplots across the `n_repeats` x `n_folds` different matrices learned
plot(m, which_plots = 1)
```
Plotting the (X1, X2) marginal of the matched group of unit 0:
Python
```python
MG0 = m.MG_matrix.loc[0]  # matched group of unit 0
matched_units_idx = MG0[MG0 != 0].index  # labels of all units ever matched to unit 0
matched_units = df.loc[matched_units_idx]  # data of the matched units
# Plot the matched group on (X1, X2), colored by treatment status
sns.lmplot(x='X1', y='X2', hue='treated', data=matched_units, palette='Set1')
# Mark unit 0 itself in black
plt.scatter(x=[df.loc[0, 'X1']], y=[df.loc[0, 'X2']], c='black', s=100)
plt.title('Matched Group for Unit-0')
```

```
Text(0.5, 1, 'Matched Group for Unit-0')
```
R
```r
mg <- make_MG(1, m)
plot(mg, 'X1', 'X2', smooth = TRUE)
```
Plotting CATE vs. X1:
Python
```python
# Join the CATE estimates onto the data by index; drop the duplicated outcome/treated columns
data_w_cate = df.join(m.CATE_df, rsuffix='_').drop(columns=['outcome_', 'treated_'])
# Scatter CATE against X1 and overlay a degree-2 polynomial fit of CATE on X1
sns.regplot(x='X1', y='avg.CATE', data=data_w_cate,
            scatter_kws={'alpha': 0.5, 's': 2}, line_kws={'color': 'black'}, order=2)
```
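With `order=2`, `regplot` overlays a quadratic fitted by least squares. A minimal stand-alone sketch of the same kind of fit with `numpy.polyfit`, on hypothetical noise-free data in which CATE is an exact quadratic in X1, so the recovered coefficients are easy to check:

```python
import numpy as np

# Hypothetical, noise-free data: CATE is an exact quadratic function of X1.
x1 = np.linspace(-2.0, 2.0, 21)
cate = 3.0 + 2.0 * x1 + 0.5 * x1**2

# Least-squares fit of a degree-2 polynomial of CATE on X1.
# np.polyfit returns coefficients from the highest degree down: [x1^2, x1, constant].
coef = np.polyfit(x1, cate, deg=2)
fitted = np.polyval(coef, x1)  # fitted curve, analogous to regplot's black line

print(np.round(coef, 6))
```

On real CATE estimates the data are noisy, so the fitted curve summarizes the trend of CATE in X1 rather than passing through every point.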
R
```r
plot_CATE(m, 'X1', smooth = TRUE)
```