Optimal Transport for Counterfactual Estimation: A Method for Causal Inference

Online Appendix: short examples

Authors

Arthur Charpentier

Emmanuel Flachaire

Ewen Gallic

Content of the Notebook

This online appendix provides R code to apply the methods presented in the companion paper (Charpentier, Flachaire, and Gallic (2023)).

We would like to measure the effect of a binary treatment variable \(T\) on the probability of observing a binary outcome \(y\). The outcome is assumed to depend on other variables \(\boldsymbol{x}^m\) that are themselves influenced by the treatment; these variables are mediators.

Two short examples are provided in this notebook: one in which there is only one mediator (Section 2) and one where we consider two mediators (Section 3).

1 Data

To illustrate the methods, we rely on the Linked Birth/Infant Death Cohort Data (2013 cohort). The CSV files for the 2013 cohort were downloaded from the NBER collection of Birth Cohort Linked Birth and Infant Death Data of the National Vital Statistics System of the National Center for Health Statistics, on the NBER website. We formatted the data beforehand (see ./births_stats.html).

Let us load the birth data.

library(tidyverse)
load("../data/births.RData")

We then keep only the subset of variables used to illustrate the method.

base <- 
  births %>% 
  mutate(
    black_mother = mracerec == "Black",
    nonnatural_delivery = rdmeth_rec != "Vaginal"
  ) %>% 
  select(
    sex, dbwt, cig_rec, wtgain, 
    black_mother, nonnatural_delivery,
    mracerec, rdmeth_rec
  ) %>% 
  mutate(
    cig_rec = replace_na(cig_rec, "Unknown or not stated"),
    black_mother = ifelse(black_mother, yes = "Yes", no = "No"),
    is_girl = ifelse(sex == "Female", yes = "Yes", no = "No")
  ) %>% 
  labelled::set_variable_labels(
    black_mother = "Is the mother Black?",
    is_girl = "Is the newborn a girl?",
    nonnatural_delivery = "Is the delivery method non-natural?"
  ) %>% 
  rename(birth_weight = dbwt)
base
# A tibble: 3,940,764 × 9
   sex    birth_weight cig_rec wtgain black_mo…¹ nonna…² mrace…³ rdmet…⁴ is_girl
   <fct>         <dbl> <fct>    <dbl> <chr>      <lgl>   <fct>   <fct>   <chr>  
 1 Male           3017 No           1 No         FALSE   "White" Vaginal No     
 2 Female         3367 No          18 No         FALSE   "White" Vaginal Yes    
 3 Male           4409 No          36 No         FALSE   "White" Vaginal No     
 4 Male           4047 No          25 No         FALSE   "White" Vaginal No     
 5 Male           3119 Yes          0 No         FALSE   "Ameri… Vaginal No     
 6 Female         4295 Yes         34 No         FALSE   "White" Vaginal Yes    
 7 Male           3160 No          47 No         FALSE   "Ameri… Vaginal No     
 8 Male           3711 No          39 No         FALSE   "White" Vaginal No     
 9 Male           3430 No          NA No         FALSE   "White" Vaginal No     
10 Female         3772 No          36 No         FALSE   "White" Vaginal Yes    
# … with 3,940,754 more rows, and abbreviated variable names ¹​black_mother,
#   ²​nonnatural_delivery, ³​mracerec, ⁴​rdmeth_rec

The variables that were kept are the following:

  • sex: sex of the newborn
  • birth_weight: Birth Weight (in Grams)
  • cig_rec: mother smokes cigarettes
  • wtgain: mother weight gain
  • mracerec: mother’s race
  • black_mother: is mother black?
  • rdmeth_rec: delivery method
  • nonnatural_delivery: is the delivery method non-natural?

Then, let us discard observations with missing values for these variables.

base <- 
  base %>% 
  filter(
    !is.na(birth_weight),
    !is.na(wtgain),
    !is.na(nonnatural_delivery),
    !is.na(black_mother)
  )

Let us define some colours for later use:

library(wesanderson)
colr1 <- wes_palette("Darjeeling1")
colr2 <- wes_palette("Darjeeling2")
couleur1 <- colr1[2]
couleur2 <- colr1[4]
couleur3 <- colr2[2]

coul1 <- "#882255"
coul2 <- "#DDCC77"

We also load the packages used below (among others, splines provides bs() and expm provides sqrtm()).

library(car)
library(transport)
library(splines)
library(mgcv)
library(expm)
library(tidyverse)

2 Univariate Case

Let us first consider the univariate case.

2.1 Objective

The outcome is assumed to depend on the treatment and on a single mediator variable \(\boldsymbol{x}^m\), which is itself influenced by the treatment.

| Variable | Name | Description |
|---|---|---|
| Output (\(y\)) | nonnatural_delivery | Probability of having a non-natural delivery |
| Treatment (\(T\)) | cig_rec | Whether the mother smokes \(\color{couleur2}{t=1}\) or not \(\color{couleur1}{t=0}\) |
| Mediator (\(\boldsymbol{x}^m\)) | birth_weight | Birth weight of the newborn |

Let us print some summary statistics of the variables considered.

library(gtsummary)
base %>% 
  select(nonnatural_delivery, cig_rec, birth_weight) %>% 
  tbl_summary(
    by = cig_rec,
    type = all_continuous() ~ "continuous2",
    statistic = list(
      all_continuous() ~ c("{mean} ({sd})", "{median} ({p25}, {p75})"),
      all_categorical() ~ "{n} ({p}%)"),
    digits = list(
      all_continuous() ~ 2,
      all_categorical() ~ 0
    ),
    missing_text = "Missing value"
  ) %>% 
  add_p() %>% 
  add_overall(col_label = "Whole sample") %>% 
  modify_header(label ~ "**Variable**") %>% 
  modify_spanning_header(c("stat_1", "stat_2") ~ "**Mother is a smoker**") %>% 
  add_stat_label(
    label = list(
      all_continuous() ~ c("Mean (Std)", "Median (IQR)"),
      all_categorical() ~ "n (%)"
    )
  )
Table 1. Summary statistics for the observations

| Variable | Whole sample | Mother is a smoker: Yes, N = 273,685 | Mother is a smoker: No, N = 2,959,847 | Unknown or not stated, N = 147,766 | p-value¹ |
|---|---|---|---|---|---|
| Is the delivery method non-natural?, n (%) | 1,159,776 (34%) | 94,580 (35%) | 1,013,995 (34%) | 51,201 (35%) | <0.001 |
| Birth Weight (in Grams) | | | | | <0.001 |
| Mean (Std) | 3,276.01 (588.22) | 3,102.59 (595.08) | 3,292.17 (583.95) | 3,273.51 (608.59) | |
| Median (IQR) | 3,317.00 (2,977.00, 3,635.00) | 3,147.00 (2,802.00, 3,472.00) | 3,330.00 (3,001.00, 3,657.00) | 3,317.00 (2,970.00, 3,657.00) | |

¹ Pearson's Chi-squared test; Kruskal-Wallis rank sum test
Sample Conditional Average Treatment Effect: \(\text{SCATE}(\boldsymbol{x})\)

Consider two models, \(\color{couleur1}{\widehat{m}_0(x)}\) and \(\color{couleur2}{\widehat{m}_1(x)}\), that estimate, respectively, \(\color{couleur1}{\mathbb{E}[Y|X=x,T=0]}\) and \(\color{couleur2}{\mathbb{E}[Y|X=x,T=1]}\). The mutatis mutandis SCATE is then \[ \text{SCATE}(x)=\color{couleur2}{\widehat{m}_1}\big(\widehat{\mathcal{T}}(x)\big)\color{black}{} - \color{couleur1}{\widehat{m}_0}\big(x\big), \] where \(\widehat{\mathcal{T}}(\cdot)\) is a transport function:

  • if we consider a quantile-based matching, \(\widehat{\mathcal{T}}(x)= \color{couleur2}{\widehat{F}_1^{-1}}\color{black}{}\circ \color{couleur1}{\widehat{F}_0}(x)\), with \(\color{couleur1}{\widehat{F}_0}\) and \(\color{couleur2}{\widehat{F}_1}\) denoting the empirical distribution functions of \(x\) conditional on \(\color{couleur1}{t=0}\) and \(\color{couleur2}{t=1}\), respectively.
  • assuming a Gaussian distribution of the mediator variable, we can consider \(\widehat{\mathcal{T}}(x) := \widehat{\mathcal{T}}_{\mathcal{N}}(x)= \color{couleur2}{\overline{x}_1}\color{black}{}+\color{couleur2}{s_1}\color{couleur1}{s_0^{-1}}\color{black}{} (x-\color{couleur1}{\overline{x}_0}\color{black}{})\), \(\color{couleur1}{\overline{x}_0}\) and \(\color{couleur2}{\overline{x}_1}\) being respectively the averages of \(x\) in the two sub-populations, and \(\color{couleur1}{s_0}\) and \(\color{couleur2}{s_1}\) the sample standard deviations.
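As a quick illustration of the two transport maps above, the sketch below simulates Gaussian mediators for the two groups (hypothetical means and standard deviations, loosely inspired by Table 1, not the actual data) and compares the quantile-based map \(\widehat{F}_1^{-1}\circ\widehat{F}_0\) with the Gaussian map \(\widehat{\mathcal{T}}_{\mathcal{N}}\). When the mediator really is Gaussian in both groups, the two should nearly coincide. The helper names transport_q() and transport_n() are ours, for illustration only.

```r
# Sketch: quantile-based vs Gaussian transport on simulated Gaussian mediators.
# All values below are hypothetical (not the birth data).
set.seed(123)
x_ctrl  <- rnorm(1e5, mean = 3300, sd = 580)  # mediator in the t = 0 group
x_treat <- rnorm(1e5, mean = 3100, sd = 595)  # mediator in the t = 1 group

# Quantile-based transport: F1^{-1}(F0(x)) with empirical cdf and quantiles
transport_q <- function(x) {
  u <- ecdf(x_ctrl)(x)                  # probability level in the t = 0 group
  unname(quantile(x_treat, probs = u))  # value at the same level in the t = 1 group
}

# Gaussian transport: xbar_1 + (s_1 / s_0) * (x - xbar_0)
transport_n <- function(x) {
  mean(x_treat) + sd(x_treat) / sd(x_ctrl) * (x - mean(x_ctrl))
}

x_eval <- c(2500, 3000, 3500, 4000)
cbind(x = x_eval, quantile = transport_q(x_eval), gaussian = transport_n(x_eval))
```

With skewed or heavy-tailed mediators, the two maps can differ more noticeably.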
What is Needed

To estimate the mutatis mutandis Sample Conditional Average Treatment Effect, the following are required:

  1. The two estimated models \(\color{couleur1}{\widehat{m}_0(x)}\) and \(\color{couleur2}{\widehat{m}_1(x)}\).
  2. A function to predict new values with each of these two models.
  3. A transport method \(\mathcal{T}(\cdot)\).

We will compute the SCATE at the following values of \(\boldsymbol{x}\):

x_0_birth_weight <- seq(2000,4500, by=500)
x_0_birth_weight
[1] 2000 2500 3000 3500 4000 4500

2.2 Models

Let us estimate \(\color{couleur1}{\widehat{m}_0(x)}\) on the subset of non-smokers, with a logistic regression in which birth weight enters through a B-spline basis (splines::bs()), i.e., a GAM-type model:

reg_0 <- glm(nonnatural_delivery ~ bs(birth_weight),
             data=base, family=binomial,subset = (cig_rec == "No"))

And \(\color{couleur2}{\widehat{m}_1(x)}\), on the subset of the smokers:

reg_1 <- glm(nonnatural_delivery ~ bs(birth_weight),
             data=base, family=binomial,subset = (cig_rec == "Yes"))
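A remark on the specification of reg_0 and reg_1: with its default arguments, splines::bs() returns a cubic B-spline basis without interior knots (three columns), so the fitted logistic regressions are, in effect, cubic polynomials in birth weight on the logit scale. The sketch below (on hypothetical birth weights) checks this and shows how the df argument adds interior knots; df = 6 is only an illustrative choice, not the one used in this notebook.

```r
library(splines)

x <- seq(1500, 5000, length.out = 200)  # hypothetical birth weights

# Default call, as in reg_0 / reg_1: cubic basis (degree = 3), no interior knots
basis_default <- bs(x)
dim(basis_default)            # 200 rows, 3 columns
attr(basis_default, "knots")  # empty: no interior knots

# More flexible alternative (illustrative assumption):
# df = 6 places 3 interior knots at quantiles of x
basis_df6 <- bs(x, df = 6)
attr(basis_df6, "knots")
```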
Type of Model

Other types of models can be used: the methodology is not restricted to GAMs.
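As an example of another type of model, the sketch below fits a GAM with a penalised regression spline through mgcv::gam() (mgcv is loaded above but not used elsewhere). It relies on simulated data: the data-generating process and the object names toy and reg_gam are hypothetical. It also shows that a prediction wrapper identical to model_spline_predict() (Section 2.3) works unchanged, since predict.gam() also accepts type = "response".

```r
library(mgcv)

# Simulated stand-in for the birth data (hypothetical values)
set.seed(1)
n <- 5000
toy <- data.frame(birth_weight = rnorm(n, mean = 3300, sd = 580))
p_true <- plogis(-0.8 + 1e-6 * (toy$birth_weight - 3300)^2)  # U-shaped risk
toy$nonnatural_delivery <- rbinom(n, size = 1, prob = p_true)

# Penalised regression spline s() instead of the bs() basis
reg_gam <- gam(nonnatural_delivery ~ s(birth_weight), data = toy, family = binomial)

# Same wrapper as in Section 2.3: predictions on the probability scale
model_spline_predict <- function(object, newdata){
  predict(object, newdata = newdata, type = "response")
}
p_hat <- model_spline_predict(reg_gam, data.frame(birth_weight = seq(2000, 4500, by = 500)))
p_hat
```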

2.3 Prediction Function

Let us define a prediction function for these types of models:

#' @param object regression model (GAM)
#' @param newdata data frame in which to look for the mediator variable used to predict the target
model_spline_predict <- function(object, newdata){
  predict(object, newdata = newdata, type="response")
}
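Why type = "response" matters: for a logistic glm, predict() returns log-odds by default (type = "link"), whereas the SCATE is defined on the probability scale. The short check below, on simulated data (hypothetical), confirms that response-scale predictions are the inverse logit, plogis(), of the link-scale ones.

```r
# Toy logistic regression on simulated data (hypothetical values)
set.seed(42)
d <- data.frame(x = rnorm(500))
d$y <- rbinom(500, size = 1, prob = plogis(0.5 * d$x - 0.3))
fit <- glm(y ~ x, data = d, family = binomial)

new_d <- data.frame(x = c(-1, 0, 1))
p_response <- predict(fit, newdata = new_d, type = "response")  # probabilities
eta_link   <- predict(fit, newdata = new_d, type = "link")      # log-odds (default)
cbind(p_response, plogis(eta_link))
```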

2.4 Transport

Let us now turn to the transport of the mediator variable.

2.4.1 Quantile-based

We define a function to transport the values of a mediator.

#' Quantile-based transport (Univariate)
#' @param x_0 vector of values to be transported
#' @param x_m_name name of the mediator variable
#' @param treatment_name name of the treatment variable (column in `data`)
#' @param treatment_0 value for non treated
#' @param treatment_1 value for treated
#' @param data data frame containing both non-treated (T=0) and treated (T=1) observations
transport_quantile <- function(x_0, x_m_name, treatment_name, treatment_0, treatment_1, data){
  treatment <- pull(data, treatment_name)
  x_val <- pull(data, x_m_name)
  # Empirical Cumulative Distribution Function of the mediator variable for the non-treated only
  Fn <- ecdf(x_val[treatment == treatment_0])
  # Probability associated with each value of x_0 among the non-treated
  u <- Fn(x_0)
  # Transported values: quantiles of the mediator among the treated, at the same probabilities
  x_1 <- x_val[treatment == treatment_1]
  x_t_quantile <- quantile(x_1, u)
  
  list(x_0 = x_0, u = u, x_t_quantile = x_t_quantile)
}
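A small sanity check of the quantile-based map, on hypothetical toy samples in which the treated values are exactly the non-treated values shifted by 100 grams: the empirical map \(\widehat{F}_1^{-1}\circ\widehat{F}_0\) should then send each non-treated value to its shifted counterpart. The sketch uses quantile(..., type = 1), which inverts the empirical distribution function exactly; transport_quantile() relies on the default type = 7, which interpolates between order statistics and returns slightly different values on such a small sample.

```r
# Hypothetical toy samples: treated = non-treated shifted by +100
x_ctrl  <- c(2900, 3050, 3200, 3330, 3480)
x_treat <- x_ctrl + 100

F0 <- ecdf(x_ctrl)  # empirical cdf of the non-treated

# type = 1: inverse of the empirical cdf (maps order statistics to order statistics)
transport_q_type1 <- function(x) unname(quantile(x_treat, probs = F0(x), type = 1))
# type = 7 (default of quantile()): linear interpolation between order statistics
transport_q_type7 <- function(x) unname(quantile(x_treat, probs = F0(x)))

cbind(x_ctrl, type1 = transport_q_type1(x_ctrl), type7 = transport_q_type7(x_ctrl))
```

On samples as large as the birth data, the difference between the two definitions is negligible.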

The transported values of \(\boldsymbol{x}\), i.e., \(\mathcal{T}(\boldsymbol{x})\):

x_0_birth_weight_t_q <- 
  transport_quantile(x_0 = x_0_birth_weight,
                     x_m_name = "birth_weight", 
                     treatment_name = "cig_rec",
                     treatment_0 = "No",
                     treatment_1 = "Yes",
                     data = base)
x_0_birth_weight_t_q
$x_0
[1] 2000 2500 3000 3500 4000 4500

$u
[1] 0.02883478 0.07896465 0.26154364 0.65182572 0.92058079 0.98907431

$x_t_quantile
2.883478% 7.896465% 26.15436% 65.18257% 92.05808% 98.90743% 
     1800      2301      2815      3340      3850      4338 

2.4.2 Gaussian Assumption

Now, let us consider the case in which the mediator is assumed to be Gaussian.

The transport function can be defined as follows:

#' Gaussian transport (Univariate)
#' @param x_0 vector of values to be transported
#' @param x_m_name name of the mediator variable
#' @param treatment_name name of the treatment variable (column in `data`)
#' @param treatment_0 value for non treated
#' @param treatment_1 value for treated
#' @param data data frame with both T=0 or T=1
transport_univ_gaussian <- function(x_0, x_m_name, treatment_name, treatment_0, treatment_1, data){
  x0 <- mean(pull(data, x_m_name)[pull(data, treatment_name) == treatment_0])
  x1 <- mean(pull(data, x_m_name)[pull(data, treatment_name) == treatment_1])
  s0 <- sd(pull(data, x_m_name)[pull(data, treatment_name) == treatment_0])
  s1 <- sd(pull(data, x_m_name)[pull(data, treatment_name) == treatment_1])
  u_N <- pnorm(x_0,x0,s0)
  x_t_N <- qnorm(u_N, x1, s1)
  list(x_0 = x_0, u_N = u_N, x_t_N = x_t_N)
}
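transport_univ_gaussian() goes through the Gaussian cdf and quantile function (pnorm(), then qnorm()) rather than the affine formula \(\widehat{\mathcal{T}}_{\mathcal{N}}(x)=\overline{x}_1+s_1s_0^{-1}(x-\overline{x}_0)\) of Section 2.1; both define the same map. The sketch below checks this numerically with the means and standard deviations reported in Table 1 (rounded to two decimals), which should reproduce the transported values printed below up to rounding.

```r
# Summary statistics copied from Table 1 (rounded to two decimals)
xbar_0 <- 3292.17; s_0 <- 583.95   # non-smokers (t = 0)
xbar_1 <- 3102.59; s_1 <- 595.08   # smokers (t = 1)

x_0 <- seq(2000, 4500, by = 500)

# Route used by transport_univ_gaussian(): Gaussian cdf, then Gaussian quantile
x_t_pq <- qnorm(pnorm(x_0, xbar_0, s_0), xbar_1, s_1)
# Closed-form affine map
x_t_affine <- xbar_1 + (s_1 / s_0) * (x_0 - xbar_0)

round(cbind(x_0, x_t_pq, x_t_affine), 1)
```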

The transported values of \(\boldsymbol{x}\), i.e., \(\mathcal{T}_\mathcal{N}(\boldsymbol{x})\):

x_0_birth_weight_t_n <- 
  transport_univ_gaussian(x_0 = x_0_birth_weight,
                     x_m_name = "birth_weight", 
                     treatment_name = "cig_rec",
                     treatment_0 = "No", 
                     treatment_1 = "Yes",
                     data = base)
x_0_birth_weight_t_n
$x_0
[1] 2000 2500 3000 3500 4000 4500

$u_N
[1] 0.01345452 0.08745575 0.30841597 0.63904204 0.88727139 0.98069826

$x_t_N
[1] 1785.778 2295.311 2804.845 3314.379 3823.913 4333.446

2.5 Estimation of the Sample Conditional Average Treatment Effect

Let us define a function that computes the Mutatis Mutandis Sample Conditional Average Treatment Effect:

#' Computes the Sample Conditional Average Treatment Effect with and without transport
#' @param x_0 vector (univariate case) or data frame (multivariate case) of values at which to compute $\hat{m}_0(x_0)$, $\hat{m}_1(x_0)$
#' @param x_t vector or data frame of transported values $\mathcal{T}(x_0)$, at which to compute $\hat{m}_1(\mathcal{T}(x_0))$
#' @param x_m_names vector of names of the mediator variables (used to build the data in the univariate case)
#' @param mod_0 model $\hat{m}_0$
#' @param mod_1 model $\hat{m}_1$
#' @param pred_mod_0 prediction function for model $\hat{m}_0(\cdot)$
#' @param pred_mod_1 prediction function for model $\hat{m}_1(\cdot)$
#' @param return_x if `TRUE` (default) the mediator variables are returned in the table, as well as their transported values
sate <- function(x_0, x_t, x_m_names, mod_0, mod_1, pred_mod_0, pred_mod_1, return_x = TRUE){
  if(is.vector(x_0)){
    # Univariate case
    new_data <- tibble(!!x_m_names := x_0)
  }else{
    new_data <- x_0
  }
  if(is.vector(x_t)){
    # Univariate case
    new_data_t <- tibble(!!x_m_names := x_t)
  }else{
    new_data_t <- x_t
  }
  
  
  # $\hat{m}_0(x_0)$
  y_0 <- pred_mod_0(object = mod_0, newdata = new_data)
  # $\hat{m}_1(x_0)$
  y_1 <- pred_mod_1(object = mod_1, newdata = new_data)
  # $\hat{m}_1(\mathcal{T}(x_0))$
  y_1_t <- pred_mod_1(object = mod_1, newdata = new_data_t)
  
  
  scate_tab <- 
    tibble(y_0 = y_0, y_1 = y_1, y_1_t = y_1_t,
         CATE = y_1-y_0, SCATE = y_1_t-y_0)
  
  if(return_x){
    new_data_t <- new_data_t %>% rename_all(~paste0(.x, "_t"))
    scate_tab <- bind_cols(new_data, new_data_t, scate_tab)
  }
  scate_tab
}

The quantile-based SCATE (\(SCATE(\boldsymbol{x})\)), computed using GAM models for \(\color{couleur1}{\widehat{m}_0(x)}\) and \(\color{couleur2}{\widehat{m}_1(x)}\):

cate_q_gam_smoke <-
  sate(x_0 = x_0_birth_weight,
       x_t = x_0_birth_weight_t_q$x_t_quantile,
       x_m_names = "birth_weight",
       mod_0 = reg_0,
       mod_1 = reg_1,
       pred_mod_0 = model_spline_predict,
       pred_mod_1 = model_spline_predict)
cate_q_gam_smoke
# A tibble: 6 × 7
  birth_weight birth_weight_t   y_0   y_1 y_1_t     CATE    SCATE
         <dbl>          <dbl> <dbl> <dbl> <dbl>    <dbl>    <dbl>
1         2000           1800 0.512 0.468 0.506 -0.0441  -0.00553
2         2500           2301 0.404 0.378 0.412 -0.0255   0.00775
3         3000           2815 0.326 0.319 0.336 -0.00687  0.0101 
4         3500           3340 0.301 0.309 0.306  0.00790  0.00488
5         4000           3850 0.349 0.369 0.342  0.0197  -0.00770
6         4500           4338 0.513 0.537 0.469  0.0250  -0.0436 

The SCATE (\(SCATE(\boldsymbol{x})\)) assuming a Gaussian distribution of the mediator variable, computed using GAM models for \(\color{couleur1}{\widehat{m}_0(x)}\) and \(\color{couleur2}{\widehat{m}_1(x)}\):

cate_n_gam_smoke <-
  sate(x_0 = x_0_birth_weight,
       x_t = x_0_birth_weight_t_n$x_t_N,
       x_m_names = "birth_weight",
       mod_0 = reg_0,
       mod_1 = reg_1,
       pred_mod_0 = model_spline_predict,
       pred_mod_1 = model_spline_predict)
cate_n_gam_smoke
# A tibble: 6 × 7
  birth_weight birth_weight_t   y_0   y_1 y_1_t     CATE    SCATE
         <dbl>          <dbl> <dbl> <dbl> <dbl>    <dbl>    <dbl>
1         2000          1786. 0.512 0.468 0.509 -0.0441  -0.00282
2         2500          2295. 0.404 0.378 0.413 -0.0255   0.00876
3         3000          2805. 0.326 0.319 0.337 -0.00687  0.0112 
4         3500          3314. 0.301 0.309 0.306  0.00790  0.00499
5         4000          3824. 0.349 0.369 0.338  0.0197  -0.0115 
6         4500          4333. 0.513 0.537 0.467  0.0250  -0.0453 

3 Multivariate Case

Now, let us turn to the multivariate case, where we consider two mediator variables: the birth weight and the weight gain of the mother.

3.1 Objective

The objective is the same as that in Section 2.

| Variable | Name | Description |
|---|---|---|
| Output (\(y\)) | nonnatural_delivery | Probability of having a non-natural delivery |
| Treatment (\(T\)) | cig_rec | Whether the mother smokes \(\color{couleur2}{t=1}\) or not \(\color{couleur1}{t=0}\) |
| Mediators (\(\boldsymbol{x}^m\)) | birth_weight, wtgain | Birth weight of the newborn, weight gain of the mother |

Let us print some summary statistics of the variables considered.

library(gtsummary)
base %>% 
  select(nonnatural_delivery, cig_rec, birth_weight, wtgain) %>% 
  tbl_summary(
    by = cig_rec,
    type = all_continuous() ~ "continuous2",
    statistic = list(
      all_continuous() ~ c("{mean} ({sd})", "{median} ({p25}, {p75})"),
      all_categorical() ~ "{n} ({p}%)"),
    digits = list(
      all_continuous() ~ 2,
      all_categorical() ~ 0
    ),
    missing_text = "Missing value"
  ) %>% 
  add_p() %>% 
  add_overall(col_label = "Whole sample") %>% 
  modify_header(label ~ "**Variable**") %>% 
  modify_spanning_header(c("stat_1", "stat_2") ~ "**Mother is a smoker**") %>% 
  add_stat_label(
    label = list(
      all_continuous() ~ c("Mean (Std)", "Median (IQR)"),
      all_categorical() ~ "n (%)"
    )
  )
Table 2. Summary statistics for the observations

| Variable | Whole sample | Mother is a smoker: Yes, N = 273,685 | Mother is a smoker: No, N = 2,959,847 | Unknown or not stated, N = 147,766 | p-value¹ |
|---|---|---|---|---|---|
| Is the delivery method non-natural?, n (%) | 1,159,776 (34%) | 94,580 (35%) | 1,013,995 (34%) | 51,201 (35%) | <0.001 |
| Birth Weight (in Grams) | | | | | <0.001 |
| Mean (Std) | 3,276.01 (588.22) | 3,102.59 (595.08) | 3,292.17 (583.95) | 3,273.51 (608.59) | |
| Median (IQR) | 3,317.00 (2,977.00, 3,635.00) | 3,147.00 (2,802.00, 3,472.00) | 3,330.00 (3,001.00, 3,657.00) | 3,317.00 (2,970.00, 3,657.00) | |
| Weight Gain | | | | | <0.001 |
| Mean (Std) | 30.41 (15.04) | 30.55 (17.25) | 30.41 (14.80) | 30.17 (15.43) | |
| Median (IQR) | 30.00 (20.00, 39.00) | 30.00 (19.00, 41.00) | 30.00 (21.00, 39.00) | 30.00 (20.00, 40.00) | |

¹ Pearson's Chi-squared test; Kruskal-Wallis rank sum test
Mutatis Mutandis \(\text{CATE}(\boldsymbol{x})\), in the Multivariate Case

The Mutatis Mutandis CATE is given by: \[ \color{couleur2}{m_1}\color{black}{}(\mathcal{T}(\boldsymbol{x}^m),\boldsymbol{x}^c,\not\!\boldsymbol{x}^p)-\color{couleur1}{m_0}\color{black}{}(\boldsymbol{x}^m,\boldsymbol{x}^c,\not\!\boldsymbol{x}^p), \]

where \(\boldsymbol{x}^c\) are collider (exogenous) variables that influence \(y\) but are not influenced by the treatment, and \(\not\!\boldsymbol{x}^p\) are confounding (noise, proxy) variables that are influenced by the treatment but do not influence \(y\). As the latter are only correlated with \(y\) (no causal relationship), they are excluded from \(m(\cdot)\).

Gaussian Assumption

In the case where \(\color{couleur2}{\boldsymbol{X}|t=1\sim\mathcal{N}(\boldsymbol{\mu}_1,\boldsymbol{\Sigma}_1)}\) and \(\color{couleur1}{\boldsymbol{X}|t=0\sim\mathcal{N}(\boldsymbol{\mu}_0,\boldsymbol{\Sigma}_0)}\), there is an explicit expression for the optimal transport, which is simply an affine map (see Villani (2003) for more details). In the univariate case, \(\color{couleur2}{x_1}\color{black}{} = \mathcal{T}^\star_{\mathcal{N}}(\color{couleur1}{x_0}\color{black}{}) = \color{couleur2}{\mu_1}\color{black}{}+ \displaystyle{\frac{\color{couleur2}{\sigma_1}}{\color{couleur1}{\sigma_0}}(\color{couleur1}{x_0}\color{black}{}-\color{couleur1}{\mu_0}\color{black}{})}\), while in the multivariate case, an analogous expression can be derived: \[ \color{couleur2}{\boldsymbol{x}_1}\color{black}{} = \mathcal{T}^\star_{\mathcal{N}}(\color{couleur1}{\boldsymbol{x}_0}\color{black}{})=\color{couleur2}{\boldsymbol{\mu}_1}\color{black}{} + \boldsymbol{A}(\color{couleur1}{\boldsymbol{x}_0}\color{black}{}-\color{couleur1}{\boldsymbol{\mu}_0}\color{black}{}), \] where \(\boldsymbol{A}\) is a symmetric positive matrix that satisfies \(\boldsymbol{A}\boldsymbol{\Sigma}_0\boldsymbol{A}=\boldsymbol{\Sigma}_1\), which has a unique solution given by \(\boldsymbol{A}=\color{couleur1}{\boldsymbol{\Sigma}_0}^{\color{black}{-1/2}}\color{black}{}\big(\color{couleur1}{\boldsymbol{\Sigma}_0}^{\color{black}{1/2}}\color{couleur2}{\boldsymbol{\Sigma}_1}\color{couleur1}{\boldsymbol{\Sigma}_0}^{\color{black}{1/2}}\color{black}{}\big)^{1/2}\color{couleur1}{\boldsymbol{\Sigma}_0}^{\color{black}{-1/2}}\), where \(\boldsymbol{M}^{1/2}\) is the square root of the square (symmetric) positive matrix \(\boldsymbol{M}\) based on the Schur decomposition (\(\boldsymbol{M}^{1/2}\) is a positive symmetric matrix), as described in Higham (2008).
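The closed-form matrix \(\boldsymbol{A}\) can be checked numerically. The sketch below makes two assumptions: instead of the Schur-based square root of expm::sqrtm(), it computes the symmetric square root from the spectral decomposition (eigen()), which yields the same positive definite square root for symmetric positive definite matrices; and it uses hypothetical 2-by-2 covariance matrices. It then verifies that \(\boldsymbol{A}\) is symmetric, positive definite, and satisfies \(\boldsymbol{A}\boldsymbol{\Sigma}_0\boldsymbol{A}=\boldsymbol{\Sigma}_1\).

```r
# Symmetric square root via the spectral decomposition
# (valid for symmetric positive definite matrices)
sqrt_spd <- function(M) {
  e <- eigen(M, symmetric = TRUE)
  e$vectors %*% diag(sqrt(e$values), nrow = length(e$values)) %*% t(e$vectors)
}

# A = S0^{-1/2} (S0^{1/2} S1 S0^{1/2})^{1/2} S0^{-1/2}
gaussian_ot_matrix <- function(S0, S1) {
  S0_half <- sqrt_spd(S0)
  S0_half_inv <- solve(S0_half)
  S0_half_inv %*% sqrt_spd(S0_half %*% S1 %*% S0_half) %*% S0_half_inv
}

# Hypothetical covariance matrices (illustrative values only)
S0 <- matrix(c(4, 1, 1, 2), nrow = 2)
S1 <- matrix(c(5, -1, -1, 3), nrow = 2)
A <- gaussian_ot_matrix(S0, S1)
A
```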

What is Needed

As in the univariate case, to estimate the mutatis mutandis Sample Conditional Average Treatment Effect, the following are required:

  1. The two estimated models \(\color{couleur1}{\widehat{m}_0(x)}\) and \(\color{couleur2}{\widehat{m}_1(x)}\).
  2. A function to predict new values with each of these two models.
  3. A transport method \(\mathcal{T}(\cdot)\).

We will compute the SCATE at the following values of \(\boldsymbol{x}\):

val_birth_weight <- seq(1800, 4600, length = 251)
val_wtgain <- seq(0, 90, length = 251)
val_grid <- expand.grid(wtgain = val_wtgain,  birth_weight = val_birth_weight) %>% as_tibble()
val_grid
# A tibble: 63,001 × 2
   wtgain birth_weight
    <dbl>        <dbl>
 1   0            1800
 2   0.36         1800
 3   0.72         1800
 4   1.08         1800
 5   1.44         1800
 6   1.8          1800
 7   2.16         1800
 8   2.52         1800
 9   2.88         1800
10   3.24         1800
# … with 62,991 more rows

3.2 Models

Let us estimate \(\color{couleur1}{\widehat{m}_0(x)}\) on the subset of non-smokers, with a logistic regression in which each mediator enters through a B-spline basis (a GAM-type model):

reg_0 <- glm(nonnatural_delivery ~ bs(birth_weight)+bs(wtgain),
             data=base, family=binomial,subset = (cig_rec == "No"))

And \(\color{couleur2}{\widehat{m}_1(x)}\), on the subset of the smokers:

reg_1 <- glm(nonnatural_delivery ~ bs(birth_weight)+bs(wtgain),
             data=base, family=binomial,subset = (cig_rec == "Yes"))
Type of Model

As in the univariate case, other types of models can be used: the methodology is not restricted to GAMs.

3.3 Prediction Function

The prediction function is the same as the one defined in the univariate case:

#' @param object regression model (GAM)
#' @param newdata data frame in which to look for the mediator variable used to predict the target
model_spline_predict <- function(object, newdata){
  predict(object, newdata = newdata, type="response")
}

3.4 Transport

Let us turn to the transport function.

3.4.1 Gaussian Assumption

We assume that the two mediator variables are jointly Gaussian within each treatment group. The parameters used to transport the mediator variables under this Gaussian assumption are estimated with the following function:

#' Optimal Transport assuming Gaussian distribution for the mediator variables (helper function)
#' @return A list with the mean and variance of the mediator in each subset, and the symmetric matrix A.
#' @param target name of the target variable (not used in the computations)
#' @param x_m_names vector of names of the mediator variables
#' @param scale vector of scaling to apply to each `x_m_names` variable to transport (default to 1)
#' @param treatment_name name of the treatment variable (column in `data`)
#' @param treatment_0 value for non treated
#' @param treatment_1 value for treated
#' @note The observations are taken from the data frame `base` of the global environment.
transport_gaussian_param <- function(target, x_m_names, scale=1, treatment_name, treatment_0, treatment_1){
  base_0_unscaled <- 
    base %>% 
    select(!!c(x_m_names, treatment_name)) %>% 
    filter(!!sym(treatment_name) ==  treatment_0)
  base_1_unscaled <- 
    base %>% 
    select(!!c(x_m_names, treatment_name)) %>% 
    filter(!!sym(treatment_name) ==  treatment_1)
  
  # One scaling factor per mediator (recycled if a single value is given)
  scale <- rep_len(scale, length(x_m_names))
  # Scale each mediator in turn (the scaled tables are updated cumulatively,
  # so that every mediator gets its own scaling factor)
  base_0_scaled <- base_0_unscaled
  base_1_scaled <- base_1_unscaled
  for(i in seq_along(x_m_names)){
    base_0_scaled <- 
      base_0_scaled %>% 
      mutate(!!sym(x_m_names[i]) := !!sym(x_m_names[i]) * scale[i])
    base_1_scaled <- 
      base_1_scaled %>% 
      mutate(!!sym(x_m_names[i]) := !!sym(x_m_names[i]) * scale[i])
  }
  
  # Mean in each subset (i.e., T=0, T=1)
  m_0 <- base_0_scaled %>% summarise(across(!!x_m_names, mean)) %>% as_vector()
  m_1 <- base_1_scaled %>% summarise(across(!!x_m_names, mean)) %>% as_vector()
  # Variance-covariance matrices
  S_0 <- base_0_scaled %>% select(!!x_m_names) %>% var()
  S_1 <- base_1_scaled %>% select(!!x_m_names) %>% var()
  # Matrix A = S_0^{-1/2} (S_0^{1/2} S_1 S_0^{1/2})^{1/2} S_0^{-1/2}
  S_0_sqrt <- sqrtm(S_0)
  S_0_sqrt_inv <- solve(S_0_sqrt)
  A <- S_0_sqrt_inv %*% sqrtm(S_0_sqrt %*% S_1 %*% S_0_sqrt) %*% S_0_sqrt_inv
  
  list(m_0 = m_0, m_1 = m_1, S_0 = S_0, S_1 = S_1, A = A, scale = scale, x_m_names = x_m_names)
}

We create a function that transports a single observation, given the different parameters:

#' Gaussian transport for a single observation (helper function)
#' @param z vector of variables to transport (for a single observation)
#' @param A symmetric positive definite matrix that satisfies \(A\Sigma_0A=\Sigma_1\)
#' @param m_0,m_1 vectors of mean values in the subsets \(\mathcal{D}_0\) and \(\mathcal{D}_1\)
#' @param scale vector of scaling to apply to each variable to transport (default to 1)
T_C_single <- function(z, A, m_0, m_1, scale = 1){
  z <- z*scale
  as.vector(m_1 + A %*% (z-m_0))*(1/scale)
}
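Applied to a whole sample, the map used by T_C_single() should turn the empirical mean and covariance of the non-treated group into the target ones. The base-R sketch below checks this on simulated data; the target moments m_1 and S_1 and the helper names sqrt_spd() and T_single() are hypothetical, and the spectral square root replaces expm::sqrtm() (an assumption). When A is computed from the sample covariance S_0, the transported sample has mean m_1 and covariance S_1 up to floating-point error.

```r
set.seed(2023)
# Symmetric square root via the spectral decomposition (stand-in for expm::sqrtm())
sqrt_spd <- function(M) {
  e <- eigen(M, symmetric = TRUE)
  e$vectors %*% diag(sqrt(e$values), nrow = length(e$values)) %*% t(e$vectors)
}

# Simulated "non-treated" sample with two correlated mediators (hypothetical)
X0 <- cbind(rnorm(2000, 10, 2), rnorm(2000, 5, 1))
X0[, 2] <- X0[, 2] + 0.3 * X0[, 1]
m_0 <- colMeans(X0); S_0 <- cov(X0)
# Target moments for the "treated" group (hypothetical)
m_1 <- c(12, 9); S_1 <- matrix(c(5, 1.5, 1.5, 2), nrow = 2)

# Matrix A of the Gaussian optimal transport map
S0_half <- sqrt_spd(S_0)
A <- solve(S0_half) %*% sqrt_spd(S0_half %*% S_1 %*% S0_half) %*% solve(S0_half)

# Row-wise transport, mirroring T_C_single() with scale = 1
T_single <- function(z, A, m_0, m_1) as.vector(m_1 + A %*% (z - m_0))
X_t <- t(apply(X0, 1, T_single, A = A, m_0 = m_0, m_1 = m_1))

colMeans(X_t)  # close to m_1
cov(X_t)       # close to S_1
```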

Lastly, we define a function that transports the mediator variables under the Gaussian assumption. If the params argument is NULL, the arguments needed to estimate the parameters with the transport_gaussian_param() function must be provided (x_m_names, scale, treatment_name, treatment_0, treatment_1).

#' Gaussian transport
#' @param z data frame of variables to transport
#' @param params (optional) parameters to use for the transport (result of `transport_gaussian_param()`, default to NULL)
#' @param target name of the target variable
#' @param x_m_names vector of names of the mediator variables
#' @param scale vector of scaling to apply to each `x_m_names` variable to transport (default to 1)
#' @param treatment_name name of the treatment variable (column in `data`)
#' @param treatment_0 value for non treated
#' @param treatment_1 value for treated
gaussian_transport <- function(z, params = NULL, target = NULL, x_m_names = NULL, scale = 1,
                               treatment_name = NULL, treatment_0 = NULL, treatment_1 = NULL){
  if(is.null(params)){
    # If the parameters of the transport function are not provided
    # they need to be computed first
    params <- 
      transport_gaussian_param(target = target, x_m_names = x_m_names, 
                               scale = scale, treatment_name = treatment_name, 
                               treatment_0 = treatment_0, treatment_1 = treatment_1)
  }
  A <- params$A
  m_0 <- params$m_0 ; m_1 <- params$m_1
  scale <- params$scale
  x_m_names <- params$x_m_names
  
  # Columns are taken in the order of `x_m_names`, which may differ from their order in `z`
  values_to_transport <- z %>% select(!!x_m_names)
  transported_val <- 
    apply(values_to_transport, 1, T_C_single, A = A, m_0 = m_0, m_1 = m_1,
          scale = scale, simplify = FALSE) %>% 
    do.call("rbind", .)
  # Label the columns with the mediator names (same order as `values_to_transport`)
  colnames(transported_val) <- x_m_names
  transported_val <- as_tibble(transported_val)
  
  structure(.Data = transported_val, params = params)
}

The transported values:

val_grid_t_n <- 
  gaussian_transport(z = val_grid, target = "nonnatural_delivery",
                     x_m_names = c("birth_weight", "wtgain"), 
                     scale = c(1, 1/100), treatment_name = "cig_rec",
                     treatment_0 = "No", treatment_1 = "Yes")
head(val_grid_t_n)
# A tibble: 6 × 2
  birth_weight wtgain
         <dbl>  <dbl>
1        1582.  -6.71
2        1582.  -6.29
3        1582.  -5.88
4        1582.  -5.46
5        1582.  -5.04
6        1582.  -4.63
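gaussian_transport() calls T_C_single() once per row through apply(), which may be slow on large grids (val_grid has 63,001 rows). A vectorised alternative, sketched below, applies the affine map to all rows with a single matrix product. The function name transport_matrix() and the parameter values are hypothetical, chosen only to mimic the structure of attr(val_grid_t_n, "params"); the check compares the result with the row-by-row version.

```r
# Vectorised Gaussian transport of all rows at once (hypothetical helper)
transport_matrix <- function(Z, A, m_0, m_1, scale = rep(1, ncol(Z))) {
  Zs <- sweep(Z, 2, scale, `*`)                  # rescale each column
  centered <- sweep(Zs, 2, m_0, `-`)             # subtract m_0 (scaled units)
  out <- sweep(centered %*% t(A), 2, m_1, `+`)   # apply A to every row, add m_1
  out <- sweep(out, 2, scale, `/`)               # back to the original units
  colnames(out) <- colnames(Z)
  out
}

# Row-by-row version, identical to T_C_single(), for comparison
T_C_single <- function(z, A, m_0, m_1, scale = 1){
  z <- z * scale
  as.vector(m_1 + A %*% (z - m_0)) * (1 / scale)
}

# Hypothetical parameters and grid
A <- matrix(c(1.02, 0.01, 0.01, 1.15), nrow = 2)
m_0 <- c(3292, 0.304); m_1 <- c(3103, 0.306); scale <- c(1, 1/100)
Z <- as.matrix(expand.grid(birth_weight = c(2000, 3000, 4000), wtgain = c(0, 30, 60)))

fast <- transport_matrix(Z, A, m_0, m_1, scale)
slow <- t(apply(Z, 1, T_C_single, A = A, m_0 = m_0, m_1 = m_1, scale = scale))
```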

The estimated parameters can be extracted as follows:

attr(val_grid_t_n, "params")
$m_0
birth_weight       wtgain 
3292.1742614    0.3040885 

$m_1
birth_weight       wtgain 
3102.5903539    0.3055169 

$S_0
             birth_weight      wtgain
birth_weight 340992.00650 14.54058206
wtgain           14.54058  0.02189568

$S_1
             birth_weight      wtgain
birth_weight 354119.68908 22.09039380
wtgain           22.09039  0.02976208

$A
             [,1]         [,2]
[1,] 1.0190674679 0.0000143175
[2,] 0.0000143175 1.1550371625

$scale
[1] 1.00 0.01

$x_m_names
[1] "birth_weight" "wtgain"      

These can be used to transport other values, without estimating the parameters again:

gaussian_transport(z = tibble(birth_weight = 2500, wtgain = 2),
                   params = attr(val_grid_t_n, "params"))
# A tibble: 1 × 2
  birth_weight wtgain
*        <dbl>  <dbl>
1        2295.  -3.40
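As a hand check of the transport formula, the sketch below plugs the printed parameters (copied from the output above, hence rounded) into \(\boldsymbol{m}_1+\boldsymbol{A}(\boldsymbol{z}-\boldsymbol{m}_0)\), applied to the scaled values and then rescaled, for \(\boldsymbol{z}=(2500, 2)\). It should recover the transported values printed above (2295 for birth weight and -3.40 for weight gain).

```r
# Parameters as printed by attr(val_grid_t_n, "params") (copied, rounded)
m_0 <- c(birth_weight = 3292.1742614, wtgain = 0.3040885)
m_1 <- c(birth_weight = 3102.5903539, wtgain = 0.3055169)
A <- matrix(c(1.0190674679, 0.0000143175,
              0.0000143175, 1.1550371625), nrow = 2, byrow = TRUE)
scale <- c(1, 1/100)

z <- c(birth_weight = 2500, wtgain = 2)
# Scale, apply the affine map m_1 + A (z - m_0), then undo the scaling
z_t <- as.vector(m_1 + A %*% (z * scale - m_0)) / scale
round(z_t, 2)
```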

3.5 Estimation of the Sample Conditional Average Treatment Effect

Let us use the sate() function defined in Section 2.5, which computes the Mutatis Mutandis Sample Conditional Average Treatment Effect.

The Mutatis Mutandis SCATE (\(SCATE(\boldsymbol{x})\)), computed using GAM models for \(\color{couleur1}{\widehat{m}_0(x)}\) and \(\color{couleur2}{\widehat{m}_1(x)}\), under the Gaussian assumption can then be obtained as follows:

mm_sate <- 
  sate(
    x_0 = val_grid,
    x_t = val_grid_t_n,
    x_m_names = c("wtgain", "birth_weight"),
    mod_0 = reg_0,
    mod_1 = reg_1,
    pred_mod_0 = model_spline_predict, pred_mod_1 = model_spline_predict,
    return_x = TRUE)
mm_sate
# A tibble: 63,001 × 9
   wtgain birth_weight wtgain_t birth_weig…¹   y_0   y_1    y_1_t    CATE  SCATE
    <dbl>        <dbl>    <dbl>        <dbl> <dbl> <dbl>    <dbl>   <dbl>  <dbl>
 1   0            1800    1582.        -6.71 0.589 0.536 2.22e-16 -0.0536 -0.589
 2   0.36         1800    1582.        -6.29 0.587 0.534 2.22e-16 -0.0531 -0.587
 3   0.72         1800    1582.        -5.88 0.585 0.533 2.22e-16 -0.0527 -0.585
 4   1.08         1800    1582.        -5.46 0.584 0.531 2.22e-16 -0.0523 -0.584
 5   1.44         1800    1582.        -5.04 0.582 0.530 2.22e-16 -0.0519 -0.582
 6   1.8          1800    1582.        -4.63 0.580 0.528 2.22e-16 -0.0515 -0.580
 7   2.16         1800    1582.        -4.21 0.578 0.527 2.22e-16 -0.0511 -0.578
 8   2.52         1800    1582.        -3.80 0.576 0.526 2.22e-16 -0.0508 -0.576
 9   2.88         1800    1582.        -3.38 0.575 0.524 2.22e-16 -0.0504 -0.575
10   3.24         1800    1582.        -2.97 0.573 0.523 2.22e-16 -0.0500 -0.573
# … with 62,991 more rows, and abbreviated variable name ¹​birth_weight_t

References

Charpentier, Arthur, Emmanuel Flachaire, and Ewen Gallic. 2023. “Causal Inference with Optimal Transport.” In Optimal Transport Statistics for Economics and Related Topics, edited by Nguyen Ngoc Thach, Vladik Kreinovich, Doan Thanh Ha, and Nguyen Duc Trung. Springer Verlag.
Higham, Nicholas J. 2008. Functions of Matrices: Theory and Computation. Society for Industrial and Applied Mathematics. https://doi.org/10.1137/1.9780898717778.
Villani, Cédric. 2003. Topics in Optimal Transportation. Vol. 58. American Mathematical Society. https://doi.org/10.1090/gsm/058.