almost-matching-exactly/MALTS

Introduction

MALTS is a learning-to-match method for interpretable causal inference proposed by Harsh Parikh, Cynthia Rudin and Alexander Volfovsky in their 2019 paper "MALTS: Matching After Learning to Stretch". This repository contains PyMALTS and MALTS, Python and R implementations, respectively, of the MALTS algorithm.

Setup

Dependencies

Python

PyMALTS is a Python 3 library that requires numpy, pandas, scipy, scikit-learn, matplotlib, and seaborn.

R

MALTS requires nloptr for distance metric learning.

Installation

Python

pip install git+https://github.com/almost-matching-exactly/MALTS.git

R

devtools::install_github('https://github.com/almost-matching-exactly/MALTS', 
                         subdir = 'RMALTS/MALTS')

Importing

Python

import pymalts
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
np.random.seed(0)
sns.set()

R

library(MALTS)

Reading Data

Python

df = pd.read_csv('example/example_data.csv',index_col=0)
print(df.shape)
df.head()
(2500, 20)
X1 X2 X3 X4 X5 X6 X7 X8 X9 X10 X11 X12 X13 X14 X15 X16 X17 X18 outcome treated
1355 1.881335 1.684164 0.532332 2.002254 1.435032 1.450196 1.974763 1.321659 0.709443 -1.141244 0.883130 0.956721 2.498229 2.251677 0.375271 -0.545129 3.334220 0.081259 -15.679894 0
1320 0.666476 1.263065 0.657558 0.498780 1.096135 1.002569 0.881916 0.740392 2.780857 -0.765889 1.230980 -1.214324 -0.040029 1.554477 4.235513 3.596213 0.959022 0.513409 -7.068587 0
1233 -0.193200 0.961823 1.652723 1.117316 0.590318 0.566765 0.775715 0.938379 -2.055124 1.942873 -0.606074 3.329552 -1.822938 3.240945 2.106121 0.857190 0.577264 -2.370578 -5.133200 0
706 1.378660 1.794625 0.701158 1.815518 1.129920 1.188477 0.845063 1.217270 5.847379 0.566517 -0.045607 0.736230 0.941677 0.835420 -0.560388 0.427255 2.239003 -0.632832 39.684984 1
438 0.434297 0.296656 0.545785 0.110366 0.151758 -0.257326 0.601965 0.499884 -0.973684 -0.552586 -0.778477 0.936956 0.831105 2.060040 3.153799 0.027665 0.376857 -1.221457 -2.954324 0

R

df <- read.csv('example/example_data.csv', row.names = 1)
# dim(df)
# head(df)

Using MALTS

Distance Metric Learning

Python

# Default settings
m = pymalts.malts_mf(outcome='outcome', treatment='treated', data=df)
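
For intuition about what a learned "stretch" does, the following is a minimal sketch and not the PyMALTS implementation. It assumes the learned metric acts as a per-covariate scaling applied before a Euclidean distance is computed. The data, the stretch values, and the helper names `stretched_distances` and `nearest_units` are all hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic covariates: 6 units, 2 continuous covariates (illustrative only).
X = rng.normal(size=(6, 2))

# Hypothetical stretch values: covariate 0 is stretched, covariate 1 is shrunk,
# so differences in covariate 0 dominate the distance.
stretch = np.array([2.0, 0.1])


def stretched_distances(X, stretch, query):
    """Euclidean distance from unit `query` to every unit after scaling each covariate."""
    diff = (X - X[query]) * stretch
    return np.sqrt((diff ** 2).sum(axis=1))


def nearest_units(X, stretch, query, k):
    """Positions of the k closest units to `query` (the query itself is included)."""
    return np.argsort(stretched_distances(X, stretch, query))[:k]


neighbors = nearest_units(X, stretch, query=0, k=3)
print(neighbors)
```

In this sketch a covariate with a small stretch value contributes little to the distance, so units can differ on it and still be matched.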

R

# Default settings
m <- MALTS(data = df, outcome = 'outcome', treatment = 'treated')

Matched Groups

Python

The matched group matrix (MG_matrix) is an N×N matrix in which each row corresponds to a query unit and each column to a candidate matched unit. Cell (i, j) holds the weight of unit j in the matched group of unit i, where the weight is the number of times unit j is included in that matched group across the M folds.

m.MG_matrix
1355 1320 1233 706 438 184 1108 1612 816 131 ... 1181 1698 916 59 2267 1520 1408 909 603 2285
1355 4.0 0.0 0.0 2.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0 0.0 3.0 0.0 3.0
1320 0.0 4.0 0.0 0.0 0.0 0.0 0.0 1.0 4.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
1233 0.0 0.0 4.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
706 0.0 0.0 0.0 4.0 0.0 0.0 0.0 1.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0
438 0.0 0.0 0.0 0.0 4.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
1520 0.0 0.0 0.0 0.0 2.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 4.0 0.0 0.0 0.0 0.0
1408 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 3.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0 4.0 0.0 0.0 0.0
909 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0 0.0 4.0 0.0 0.0
603 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 2.0 0.0 0.0 0.0 0.0 0.0 0.0 4.0 0.0
2285 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 4.0

2500 rows × 2500 columns
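
The weights described above can be illustrated with a small, self-contained sketch. The unit labels and per-fold matched groups below are made up, and the accumulation loop is only an assumption about how such counts could be built, not the PyMALTS code.

```python
import pandas as pd

# Hypothetical matched groups from two folds: query unit -> units matched to it.
fold_groups = [
    {"a": ["a", "b"], "b": ["b", "c"], "c": ["c", "a"]},
    {"a": ["a", "b"], "b": ["b", "a"], "c": ["c", "b"]},
]

units = ["a", "b", "c"]
mg_matrix = pd.DataFrame(0.0, index=units, columns=units)

# Cell (i, j) counts how many folds placed unit j in the matched group of unit i.
for groups in fold_groups:
    for query, matched in groups.items():
        for unit in matched:
            mg_matrix.loc[query, unit] += 1.0

print(mg_matrix)

# Units matched to "a" in every fold (weight equal to the number of folds).
row_a = mg_matrix.loc["a"]
always_matched = row_a[row_a == len(fold_groups)].index.tolist()
print(always_matched)
```

Each row can be read the same way as a row of `m.MG_matrix`: larger weights mark units that were matched to the query unit more consistently.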

Visualizing the matched group matrix as a heatmap

fig = plt.figure(figsize=(10,10))
sns.heatmap(m.MG_matrix)

Accessing the matched group of an example unit with index "1" and visualizing its weights as a bar chart

MG1 = m.MG_matrix.loc[1] #matched group for unit "1"
MG1[MG1>1].sort_values(ascending=False).plot(kind='bar',figsize=(20,5)) #Visualizing all the units matched to unit 1 more than once

R

Matched groups can be found in the MGs entry of the output of MALTS. Entries correspond to different folds being used for distance metric estimation vs. matching.

# Matched group of first unit in first split.
# m$MGs[[1]][[1]]

# Matched group of third unit in second split. 
# m$MGs[[2]][[3]]

Additional information on matched groups can be found by creating and printing objects of type mg.malts:

# Only include units matching at least 3 times in the matched group
mg <- make_MG(1, m, threshold_n = 3)

# Unpruned CATEs are computed with respect to all units ever matched. 
# Pruned CATEs are computed only with respect to the thresholded units.
print(mg)
The main matched group of unit 1:
  Matched to 88 units, 45 treated and 43 control.
  The unpruned CATE estimate is: 70.40668.
  The pruned CATE estimate (with a threshold of 3) is: 70.1119176.

Treatment Effect Estimates

Python

m.CATE_df #each row is a cate estimate for a corresponding unit
avg.CATE std.CATE outcome treated
0 47.232061 21.808950 -15.313091 0.0
1 40.600643 21.958906 -16.963202 0.0
2 40.877320 22.204570 9.527929 1.0
3 37.768578 19.740320 -3.940218 0.0
4 39.920257 21.744433 -8.011915 0.0
... ... ... ... ...
2495 49.227788 21.581176 -14.529871 0.0
2496 42.352355 21.385861 19.570055 1.0
2497 43.737763 19.859275 -16.342666 0.0
2498 41.189297 20.346711 -9.165242 0.0
2499 45.427037 23.762884 -17.604829 0.0

2500 rows × 4 columns

Estimating the Average Treatment Effect (ATE)

ATE = m.CATE_df['avg.CATE'].mean()
ATE
42.29673993471417
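
The averaging step can be tried on a toy stand-in for `m.CATE_df`. The values below are illustrative rather than MALTS output, and the treated-only and control-only averages are an extra assumption added for illustration; the README itself only computes the overall mean.

```python
import pandas as pd

# Toy stand-in for m.CATE_df (same column names, invented values).
cate_df = pd.DataFrame(
    {
        "avg.CATE": [40.0, 44.0, 38.0, 46.0],
        "std.CATE": [20.0, 22.0, 19.0, 21.0],
        "treated": [0.0, 1.0, 0.0, 1.0],
    }
)

# ATE: average of the unit-level CATE estimates over all units.
ate = cate_df["avg.CATE"].mean()

# The same averaging restricted to treated units and to control units.
att = cate_df.loc[cate_df["treated"] == 1, "avg.CATE"].mean()
atc = cate_df.loc[cate_df["treated"] == 0, "avg.CATE"].mean()

print(ate, att, atc)
```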

R

head(m$data[, c('CATE', 'sd_CATE', 'outcome', 'treated')])
summary(m)
CATE sd_CATE outcome treated
1355 72.734244 3.467431 -15.679894 0
1320 28.543205 3.498173 -7.068587 0
1233 26.917941 4.022228 -5.133200 0
706 52.258803 4.115889 39.684984 1
438 6.766184 4.370843 -2.954324 0
184 19.066282 5.476048 -6.901449 0
The average treatment effect of `treated` on `outcome` is estimated to be
41.262.

 Average Stretch Values:
             Minimum    Maximum
  Continuous 0.151 (X9) 2.46 (X7)

Documentation

Please refer to the internal R package documentation for full details on MALTS.

| Argument Type | Python | R |
| --- | --- | --- |
| Input data | data | data |
| Name of outcome column | outcome | outcome |
| Name of treatment column | treatment | treatment |
| Names of discrete columns | discrete | discrete |
| Loss regularization | C | C |
| Matched group sizes | k_tr, k_est | k_tr, k_est |
| Reweighting | reweight | reweight |
| CATE Smoothing | smooth_CATE, estimator | smooth_CATE, NA |
| Refitting | n_repeats, n_folds | n_repeats, n_folds |
| CATE data frame formatting | output_format | NA |
| Missing data handling | NA | missing_data, impute_with_outcome, impute_with_treatment |
| Optimization parameters | NA | ... |

Visualization

Visualizing CATE Estimates

Python

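fig = plt.figure(figsize=(10,5))
sns.kdeplot(m.CATE_df['avg.CATE'],fill=True) #`fill` replaces the deprecated `shade` argument (seaborn >= 0.11)
plt.axvline(ATE,c='black') #vertical line at the estimated ATE
plt.text(ATE-4,0.04,r'$\hat{ATE}$',rotation=90) #raw string so that \h is not treated as an escape sequence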
Text(38.29673993471417, 0.04, '$\\hat{ATE}$')

R

plot(m, which_plots = 2)

Visualizing the Stretch Matrix

R

# Boxplots across `n_repeats` x `n_folds` different matrices learned
plot(m, which_plots = 1)

Looking Inside a Matched Group

Plotting the (X1, X2) marginal of the matched group of unit "0"

Python

MG0 = m.MG_matrix.loc[0] #fetching the matched group
matched_units_idx = MG0[MG0!=0].index #getting the indices of the matched units 
matched_units = df.loc[matched_units_idx] #fetching the data of matched units

sns.lmplot(x='X1', y='X2', hue='treated', data=matched_units,palette="Set1") #plotting the MG on (X1,X2)
plt.scatter(x=[df.loc[0,'X1']],y=[df.loc[0,'X2']],c='black',s=100) #plotting the unit-0 on (X1,X2)
plt.title('Matched Group for Unit-0') #setting title of the plot
Text(0.5, 1, 'Matched Group for Unit-0')

R

mg <- make_MG(1, m)
plot(mg, 'X1', 'X2', smooth = TRUE)

Plotting CATE versus covariate

Plotting CATE vs. X1

Python

data_w_cate = df.join(m.CATE_df, rsuffix='_').drop(columns=['outcome_','treated_']) #joining cate dataframe with data

sns.regplot( x='X1', y='avg.CATE', data=data_w_cate, scatter_kws={'alpha':0.5,'s':2}, line_kws={'color':'black'}, order=2 ) #fitting a degree 2 polynomial X1 on CATE
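
To see what the degree-2 fit computes without plotting, here is a minimal sketch on synthetic data. The arrays stand in for the X1 and avg.CATE columns and are not MALTS output. It relies on `numpy.polyfit`, which seaborn's `regplot` documents as the estimator used when `order` is greater than 1.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic stand-ins for the X1 and avg.CATE columns.
x1 = rng.normal(size=500)
cate = 40.0 + 3.0 * x1 + 2.0 * x1 ** 2 + rng.normal(scale=0.5, size=500)

# Least-squares quadratic fit; coefficients are returned highest degree first.
coeffs = np.polyfit(x1, cate, deg=2)

# Fitted curve evaluated at the observed X1 values.
fitted = np.polyval(coeffs, x1)

print(coeffs)
```

A clearly non-zero leading coefficient suggests that the CATE varies non-linearly with the covariate.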

R

plot_CATE(m, 'X1', smooth = TRUE)
