Optimal Transport for Counterfactual Estimation: A Method for Causal Inference

Online Appendix: reproduction of the results (multivariate case)

Authors

Arthur Charpentier

Emmanuel Flachaire

Ewen Gallic

This online appendix provides R code to apply the methods presented in the companion paper (Charpentier, Flachaire, and Gallic (2023)).

1 Data

To illustrate the methods, we rely on the Linked Birth/Infant Death Cohort Data (2013 cohort). The CSV files for the 2013 cohort were downloaded from the NBER collection of Birth Cohort Linked Birth and Infant Death Data of the National Vital Statistics System of the National Center for Health Statistics, on the NBER website. We formatted the data (see ./births_stats.html).

Let us load the birth data.

library(tidyverse)
load("../data/births.RData")

Then, we keep only a subset of the variables to illustrate the method.

base <- 
  births %>% 
  mutate(
    black_mother = mracerec == "Black",
    nonnatural_delivery = rdmeth_rec != "Vaginal"
  ) %>% 
  select(
    sex, dbwt, cig_rec, wtgain, 
    black_mother, nonnatural_delivery,
    mracerec, rdmeth_rec
  ) %>% 
  mutate(
    cig_rec = replace_na(cig_rec, "Unknown or not stated"),
    black_mother = ifelse(black_mother, yes = "Yes", no = "No"),
    is_girl = ifelse(sex == "Female", yes = "Yes", no = "No")
  ) %>% 
  labelled::set_variable_labels(
    black_mother = "Is the mother Black?",
    is_girl = "Is the newborn a girl?",
    nonnatural_delivery = "Is the delivery method non-natural?"
  ) %>% 
  rename(birth_weight = dbwt)
base
# A tibble: 3,940,764 × 9
   sex    birth_weight cig_rec wtgain black_mo…¹ nonna…² mrace…³ rdmet…⁴ is_girl
   <fct>         <dbl> <fct>    <dbl> <chr>      <lgl>   <fct>   <fct>   <chr>  
 1 Male           3017 No           1 No         FALSE   "White" Vaginal No     
 2 Female         3367 No          18 No         FALSE   "White" Vaginal Yes    
 3 Male           4409 No          36 No         FALSE   "White" Vaginal No     
 4 Male           4047 No          25 No         FALSE   "White" Vaginal No     
 5 Male           3119 Yes          0 No         FALSE   "Ameri… Vaginal No     
 6 Female         4295 Yes         34 No         FALSE   "White" Vaginal Yes    
 7 Male           3160 No          47 No         FALSE   "Ameri… Vaginal No     
 8 Male           3711 No          39 No         FALSE   "White" Vaginal No     
 9 Male           3430 No          NA No         FALSE   "White" Vaginal No     
10 Female         3772 No          36 No         FALSE   "White" Vaginal Yes    
# … with 3,940,754 more rows, and abbreviated variable names ¹​black_mother,
#   ²​nonnatural_delivery, ³​mracerec, ⁴​rdmeth_rec

The variables that were kept are the following:

  • sex: sex of the newborn
  • birth_weight: birth weight (in grams)
  • cig_rec: whether the mother smokes cigarettes
  • wtgain: mother’s weight gain (in pounds)
  • mracerec: mother’s race
  • black_mother: is the mother Black?
  • rdmeth_rec: delivery method
  • nonnatural_delivery: is the delivery method non-natural?
  • is_girl: is the newborn a girl?

Then, let us discard individuals with missing values.

base <- 
  base %>% 
  filter(
    !is.na(birth_weight),
    !is.na(wtgain),
    !is.na(nonnatural_delivery),
    !is.na(black_mother)
  )

Let us define some colours for later use:

library(wesanderson)
colr1 <- wes_palette("Darjeeling1")
colr2 <- wes_palette("Darjeeling2")
couleur1 <- colr1[2]
couleur2 <- colr1[4]
couleur3 <- colr2[2]

coul1 <- "#882255"
coul2 <- "#DDCC77"

We need to load some packages.

library(car)
library(transport)
library(splines)
library(mgcv)
library(expm)
library(tidyverse)

Let us visualize a sample of the data using a scatter plot of \(\boldsymbol{x} = (\text{wtgain}, \text{birth weight})\), conditional on the treatment \(T\). For each value of the treatment, we add to the scatter plot the iso-density curve expected to contain 90% of the points (levels = 0.9 in the code below), assuming a Gaussian distribution for the mediator variables.

set.seed(123)
# A sample with only a few points
base_s <- base[sample(1:nrow(base), size = 8000), ]
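
The iso-density curves mentioned above can be illustrated with a minimal base R sketch on simulated data. The mean vector and covariance matrix are hypothetical. Under a bivariate Gaussian, the 90% contour is the set of points whose squared Mahalanobis distance equals the 0.9 quantile of a chi-squared distribution with 2 degrees of freedom. car::dataEllipse() may use a slightly different finite-sample radius, so this illustrates the idea rather than reimplementing that function.

```r
set.seed(42)
n <- 5000
# Hypothetical mean vector and covariance matrix (not estimated from the birth data)
mu <- c(3300, 30)
Sigma <- matrix(c(250000, 1500, 1500, 200), nrow = 2)

# Simulate X = mu + L z, where L is the lower Cholesky factor of Sigma
L <- t(chol(Sigma))
X <- t(mu + L %*% matrix(rnorm(2 * n), nrow = 2))

# Share of points inside the 90% Gaussian iso-density ellipse
centre <- colMeans(X)
S <- cov(X)
q <- qchisq(0.9, df = 2)
coverage <- mean(mahalanobis(X, center = centre, cov = S) <= q)
coverage

# Points on the ellipse boundary: centre + sqrt(q) * L_hat %*% (cos(theta), sin(theta))
theta <- seq(0, 2 * pi, length.out = 200)
ellipse <- t(centre + sqrt(q) * t(chol(S)) %*% rbind(cos(theta), sin(theta)))
```

The coverage should be close to 0.9 when the Gaussian assumption holds; on real data, a clear departure from the nominal level is one informal sign that the assumption is questionable.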

The treatment variable \(T\) (cig_rec) indicates whether the mother smokes (\(\color{couleur2}{t=1}\)) or not (\(\color{couleur1}{t=0}\)).

# Ellipse for smoker mothers
E_C_Y <- dataEllipse(base_s$birth_weight[base_s$cig_rec=="Yes"], 
                     base_s$wtgain[base_s$cig_rec=="Yes"],
                     levels=0.9,draw = FALSE)
# Ellipse for non-smoker mothers
E_C_N <- dataEllipse(base_s$birth_weight[base_s$cig_rec=="No"],
                     base_s$wtgain[base_s$cig_rec=="No"],
                     levels=0.9,draw = FALSE)
# Smokers (t = 1) in couleur2, non-smokers (t = 0) in couleur1,
# so that the points match the ellipses and the legend
plot(base_s$birth_weight[base_s$cig_rec=="Yes"],
     base_s$wtgain[base_s$cig_rec=="Yes"],
     col = yarrr::transparent(couleur2, trans.val = .9), pch = 19, 
     main = "", xlab = "Weight of the baby (g)",
     xlim = c(1800,4600), ylim = c(0,90),
     ylab = "Weight gain of the mother (pounds)", axes = FALSE,cex=.3)
axis(1)
axis(2)
points(base_s$birth_weight[base_s$cig_rec=="No"],
       base_s$wtgain[base_s$cig_rec=="No"],
       col = yarrr::transparent(couleur1, trans.val = .9), pch = 19, cex = .3)

lines(E_C_N,col=couleur1,lwd=2)
lines(E_C_Y,col=couleur2,lwd=2)

legend("topleft",c("Smoker","Non-smoker"),lty=1,col=c(couleur2,couleur1),bty="n")

Figure 1. Joint distributions of \(\boldsymbol{X}\) (weight of the newborn infant and weight gain of the mother), conditional on the treatment T, when T indicates whether the mother is a smoker or not.

The treatment variable \(T\) (black_mother) indicates whether the mother is Black (\(\color{couleur2}{t=1}\)) or not (\(\color{couleur1}{t=0}\)).

# Ellipse for Black mothers
E_C_Y_blacm <- dataEllipse(base_s$birth_weight[base_s$black_mother=="Yes"], 
                     base_s$wtgain[base_s$black_mother=="Yes"],
                     levels=0.9,draw = FALSE)
# Ellipse for non-Black mothers
E_C_N_blackm <- dataEllipse(base_s$birth_weight[base_s$black_mother=="No"],
                     base_s$wtgain[base_s$black_mother=="No"],
                     levels=0.9,draw = FALSE)
# Black mothers (t = 1) in couleur2, non-Black mothers (t = 0) in couleur1,
# so that the points match the ellipses and the legend
plot(base_s$birth_weight[base_s$black_mother=="Yes"],
     base_s$wtgain[base_s$black_mother=="Yes"],
     col = yarrr::transparent(couleur2, trans.val = .9), pch = 19, 
     main = "", xlab = "Weight of the baby (g)",
     xlim = c(1600,4600), ylim = c(0,90),
     ylab = "Weight gain of the mother (pounds)", axes = FALSE,cex=.3)
axis(1)
axis(2)
points(base_s$birth_weight[base_s$black_mother=="No"],
       base_s$wtgain[base_s$black_mother=="No"],
       col = yarrr::transparent(couleur1, trans.val = .9), pch = 19, cex = .3)

lines(E_C_N_blackm,col=couleur1,lwd=2)
lines(E_C_Y_blacm,col=couleur2,lwd=2)

legend("topleft",c("Black","Non-Black"),lty=1,col=c(couleur2,couleur1),bty="n")

Figure 2. Joint distributions of \(\boldsymbol{X}\) (weight of the newborn infant and weight gain of the mother), conditional on the treatment T, when T indicates whether the mother is Black or not.

The treatment variable \(T\) (sex) indicates whether the baby is a girl (\(\color{couleur2}{t=1}\)) or not (\(\color{couleur1}{t=0}\)).

# Ellipse for baby girls
E_C_Y_sex <- dataEllipse(base_s$birth_weight[base_s$sex=="Female"], 
                     base_s$wtgain[base_s$sex=="Female"],
                     levels=0.9,draw = FALSE)
# Ellipse for baby boys
E_C_N_sex <- dataEllipse(base_s$birth_weight[base_s$sex=="Male"],
                     base_s$wtgain[base_s$sex=="Male"],
                     levels=0.9,draw = FALSE)
# Girls (t = 1) in couleur2, boys (t = 0) in couleur1,
# so that the points match the ellipses and the legend
plot(base_s$birth_weight[base_s$sex=="Female"],
     base_s$wtgain[base_s$sex=="Female"],
     col = yarrr::transparent(couleur2, trans.val = .9), pch = 19, 
     main = "", xlab = "Weight of the baby (g)",
     xlim = c(1800,4600), ylim = c(0,90),
     ylab = "Weight gain of the mother (pounds)", axes = FALSE,cex=.3)
axis(1)
axis(2)
points(base_s$birth_weight[base_s$sex=="Male"],
       base_s$wtgain[base_s$sex=="Male"],
       col = yarrr::transparent(couleur1, trans.val = .9), pch = 19, cex = .3)

lines(E_C_N_sex,col=couleur1,lwd=2)
lines(E_C_Y_sex,col=couleur2,lwd=2)

legend("topleft",c("Girl","Boy"),lty=1,col=c(couleur2,couleur1),bty="n")

Figure 3. Joint distributions of \(\boldsymbol{X}\) (weight of the newborn infant and weight gain of the mother), conditional on the treatment T, when T indicates whether the baby is a girl or a boy.

2 Objectives

We would like to measure the effect of a binary treatment variable \(T\) on the probability of observing a binary outcome \(y\). The outcome is assumed to depend on other variables \(\boldsymbol{x}^m\) that are also influenced by the treatment; these variables are mediators.

  • Output (\(y\), nonnatural_delivery): Probability of having a non-natural delivery

  • Treatment (\(T\)), either:

    • cig_rec: Whether the mother smokes \(\color{couleur2}{t=1}\) or not \(\color{couleur1}{t=0}\)
    • black_mother: Whether the mother is Black \(\color{couleur2}{t=1}\) or not \(\color{couleur1}{t=0}\)
    • sex: Whether the baby is a girl \(\color{couleur2}{t=1}\) or not \(\color{couleur1}{t=0}\)
  • Mediator variables (\(\boldsymbol{x}^m\): birth_weight, wtgain): birth weight of the newborn and weight gain of the mother.

Variable                           Name                   Description
Output (\(y\))                     nonnatural_delivery    Probability of having a non-natural delivery
Treatment (\(T\))                  cig_rec                Whether the mother smokes \(\color{couleur2}{t=1}\) or not \(\color{couleur1}{t=0}\)
Mediators (\(\boldsymbol{x}^m\))   birth_weight, wtgain   Birth weight of the newborn, weight gain of the mother

Mutatis Mutandis \(\text{CATE}(\boldsymbol{x})\), in the Multivariate Case

The Mutatis Mutandis CATE is given by: \[ \color{couleur2}{m_1}\color{black}{}(\mathcal{T}(\boldsymbol{x}^m),\boldsymbol{x}^c,\not\!\boldsymbol{x}^p)-\color{couleur1}{m_0}\color{black}{}(\boldsymbol{x}^m,\boldsymbol{x}^c,\not\!\boldsymbol{x}^p), \]

where \(\boldsymbol{x}^c\) are collider (exogenous) variables that influence \(y\) but are not influenced by the treatment and \(\not\!\boldsymbol{x}^p\) are confounding (noise, proxy) variables that are influenced by the treatment but do not influence \(y\). As the latter are only correlated with \(y\) (no causal relationship), they are excluded from \(m(\cdot)\).

Gaussian Assumption

In the case where \(\color{couleur2}{\boldsymbol{X}|t=1\sim\mathcal{N}(\boldsymbol{\mu}_1,\boldsymbol{\Sigma}_1)}\) and \(\color{couleur1}{\boldsymbol{X}|t=0\sim\mathcal{N}(\boldsymbol{\mu}_0,\boldsymbol{\Sigma}_0)}\), there is an explicit expression for the optimal transport, which is simply an affine map (see Villani (2003) for more details). In the univariate case, \(\color{couleur2}{x_1}\color{black}{} = \mathcal{T}^\star_{\mathcal{N}}(\color{couleur1}{x_0}\color{black}{}) = \color{couleur2}{\mu_1}\color{black}{}+ \displaystyle{\frac{\color{couleur2}{\sigma_1}}{\color{couleur1}{\sigma_0}}(\color{couleur1}{x_0}\color{black}{}-\color{couleur1}{\mu_0}\color{black}{})}\), while in the multivariate case, an analogous expression can be derived: \[ \color{couleur2}{\boldsymbol{x}_1}\color{black}{} = \mathcal{T}^\star_{\mathcal{N}}(\color{couleur1}{\boldsymbol{x}_0}\color{black}{})=\color{couleur2}{\boldsymbol{\mu}_1}\color{black}{} + \boldsymbol{A}(\color{couleur1}{\boldsymbol{x}_0}\color{black}{}-\color{couleur1}{\boldsymbol{\mu}_0}\color{black}{}), \] where \(\boldsymbol{A}\) is a symmetric positive matrix that satisfies \(\boldsymbol{A}\boldsymbol{\Sigma}_0\boldsymbol{A}=\boldsymbol{\Sigma}_1\), which has a unique solution given by \(\boldsymbol{A}=\color{couleur1}{\boldsymbol{\Sigma}_0}^{\color{black}{-1/2}}\color{black}{}\big(\color{couleur1}{\boldsymbol{\Sigma}_0}^{\color{black}{1/2}}\color{couleur2}{\boldsymbol{\Sigma}_1}\color{couleur1}{\boldsymbol{\Sigma}_0}^{\color{black}{1/2}}\color{black}{}\big)^{1/2}\color{couleur1}{\boldsymbol{\Sigma}_0}^{\color{black}{-1/2}}\), where \(\boldsymbol{M}^{1/2}\) is the square root of the square (symmetric) positive matrix \(\boldsymbol{M}\) based on the Schur decomposition (\(\boldsymbol{M}^{1/2}\) is a positive symmetric matrix), as described in Higham (2008).
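
The closed-form expression for \(\boldsymbol{A}\) can be checked numerically. The following is a minimal base R sketch with two hypothetical covariance matrices. The helper sqrt_sym() is an illustrative name and uses an eigendecomposition rather than the Schur-based square root mentioned above; for symmetric positive definite matrices the two should agree.

```r
# Symmetric square root of a symmetric positive definite matrix,
# via eigendecomposition (illustrative helper, base R only)
sqrt_sym <- function(M) {
  e <- eigen(M, symmetric = TRUE)
  e$vectors %*% diag(sqrt(e$values), nrow = length(e$values)) %*% t(e$vectors)
}

# Hypothetical covariance matrices for the two groups
Sigma_0 <- matrix(c(2, 0.5, 0.5, 1), nrow = 2)
Sigma_1 <- matrix(c(3, -0.4, -0.4, 1.5), nrow = 2)

# A = Sigma_0^{-1/2} (Sigma_0^{1/2} Sigma_1 Sigma_0^{1/2})^{1/2} Sigma_0^{-1/2}
S0_half <- sqrt_sym(Sigma_0)
S0_half_inv <- solve(S0_half)
A <- S0_half_inv %*% sqrt_sym(S0_half %*% Sigma_1 %*% S0_half) %*% S0_half_inv

# Both quantities should be numerically zero:
max(abs(A %*% Sigma_0 %*% A - Sigma_1))  # A Sigma_0 A = Sigma_1
max(abs(A - t(A)))                       # A is symmetric
```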

What is Needed

As in the univariate case, to estimate the mutatis mutandis Sample Conditional Average Treatment Effect, the following are required:

  1. The two estimated models \(\color{couleur1}{\widehat{m}_0(x)}\) and \(\color{couleur2}{\widehat{m}_1(x)}\).
  2. A function to predict new values with each of these two models.
  3. A transport method \(\mathcal{T}(\cdot)\).
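
The way these three ingredients fit together can be sketched on simulated univariate data. Everything in this sketch is an assumption made for illustration: the data-generating process, the plain logistic regressions standing in for the two models, and the names toy, trt, transport and scate. It is not the estimation procedure applied to the birth data.

```r
set.seed(1)
n <- 2000
trt <- rbinom(n, 1, 0.5)
# The mediator x depends on the treatment; the outcome y depends on both
x <- rnorm(n, mean = ifelse(trt == 1, 1, 0), sd = ifelse(trt == 1, 1.5, 1))
y <- rbinom(n, 1, plogis(-0.5 + 0.8 * x + 0.3 * trt))
toy <- data.frame(y = y, x = x, trt = trt)

# 1. Two models, one per treatment group
m_0 <- glm(y ~ x, family = binomial, data = toy, subset = trt == 0)
m_1 <- glm(y ~ x, family = binomial, data = toy, subset = trt == 1)

# 3. Univariate Gaussian transport: x_1 = mu_1 + (sigma_1 / sigma_0) * (x_0 - mu_0)
mu_0 <- mean(toy$x[toy$trt == 0]); sd_0 <- sd(toy$x[toy$trt == 0])
mu_1 <- mean(toy$x[toy$trt == 1]); sd_1 <- sd(toy$x[toy$trt == 1])
transport <- function(x0) mu_1 + sd_1 / sd_0 * (x0 - mu_0)

# 2. Predictions, then the difference m_1(T(x)) - m_0(x) at a few values of x
x_grid <- c(-1, 0, 1)
scate <- predict(m_1, newdata = data.frame(x = transport(x_grid)), type = "response") -
  predict(m_0, newdata = data.frame(x = x_grid), type = "response")
scate
```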

We will compute the SCATE at the following values of \(\boldsymbol{x}\):

val_birth_weight <- seq(1800, 4600, length = 251)
val_wtgain <- seq(0, 90, length = 251)
val_grid <- expand.grid(wtgain = val_wtgain,  birth_weight = val_birth_weight) %>% as_tibble()
val_grid
# A tibble: 63,001 × 2
   wtgain birth_weight
    <dbl>        <dbl>
 1   0            1800
 2   0.36         1800
 3   0.72         1800
 4   1.08         1800
 5   1.44         1800
 6   1.8          1800
 7   2.16         1800
 8   2.52         1800
 9   2.88         1800
10   3.24         1800
# … with 62,991 more rows

3 Models

Now, let us estimate \(\color{couleur1}{\widehat{m}_0}\) and \(\color{couleur2}{\widehat{m}_1}\) with a logistic GAM (cubic splines first, then splines with more degrees of freedom). To that end, we define a function that estimates both models and returns them in a list.

library(splines)

#' Returns $\hat{m}_0()$ and $\hat{m}_1()$, using GAM
#' @param target name of the target variable
#' @param treatment_name name of the treatment variable (column in `data`)
#' @param x_m_names names of the mediator variables
#' @param data data frame with the observations
#' @param treatment_0 value for non treated
#' @param treatment_1 value for treated
#' @param df degrees of freedom (in the `bs()` function, default to `df=3`)
models_spline <- function(target, treatment_name, x_m_names, data, treatment_0, treatment_1, df = 3){
  
  # \hat{m}_0()
  formula_glm <- paste0(target, "~", paste0("bs(", x_m_names, ", df = ",df, ")", collapse = " + "))
  reg_0 <- bquote(
    glm(formula = .(formula_glm), data=data, family = binomial,
        subset = (.(treatment_name) == .(treatment_0))),
    list(
      formula_glm = formula_glm,
      treatment_name = as.name(treatment_name),
      treatment_0 = treatment_0)
  ) %>% eval()
  # \hat{m}_1()
  reg_1 <- bquote(
    glm(formula = .(formula_glm), data=data, family = binomial,
        subset = (.(treatment_name) == .(treatment_1))),
    list(
      formula_glm = formula_glm,
      treatment_name = as.name(treatment_name),
      treatment_1 = treatment_1)
  ) %>% eval()
  
  list(reg_0 = reg_0, reg_1 = reg_1)
}
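
To see what the function above passes to glm(), here is the formula string it builds when one value of df is supplied per mediator. This sketch only reproduces the paste0() step, with the variable names used in this appendix and df values of 12 and 6.

```r
target <- "nonnatural_delivery"
x_m_names <- c("wtgain", "birth_weight")
df <- c(12, 6)

# paste0() is vectorised: the i-th mediator is paired with the i-th value of df
formula_glm <- paste0(target, "~", paste0("bs(", x_m_names, ", df = ", df, ")", collapse = " + "))
formula_glm
# "nonnatural_delivery~bs(wtgain, df = 12) + bs(birth_weight, df = 6)"

# The string can be converted to a formula object; its variables are as expected
all.vars(as.formula(formula_glm))
```

With a single value such as df = 3, the same value is recycled for every mediator.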

3.1 GAM (with cubic splines)

target         <- "nonnatural_delivery"
x_m_names       <- c("wtgain", "birth_weight")
scale          <- c(1,1/100)
treatment_name <- "cig_rec"
treatment_0    <- "No"
treatment_1    <- "Yes"

\(\color{couleur1}{\widehat{m}_0}\) and \(\color{couleur2}{\widehat{m}_1}\):

reg_gam_smoker <- 
  models_spline(target = target,
              treatment_name = treatment_name, 
              x_m_names = x_m_names,
              data = base, treatment_0 = treatment_0, treatment_1 = treatment_1, df = c(3, 3))
reg_gam_smoker_0 <- reg_gam_smoker$reg_0
reg_gam_smoker_1 <- reg_gam_smoker$reg_1
target         <- "nonnatural_delivery"
x_m_names       <- c("wtgain", "birth_weight")
scale          <- c(1,1/100)
treatment_name <- "black_mother"
treatment_0    <- "No"
treatment_1    <- "Yes"
reg_gam_blackm <- 
  models_spline(target = target,
              treatment_name = treatment_name, 
              x_m_names = x_m_names,
              data = base, treatment_0 = treatment_0, treatment_1 = treatment_1, df = c(3, 3))
reg_gam_blackm_0 <- reg_gam_blackm$reg_0
reg_gam_blackm_1 <- reg_gam_blackm$reg_1
target         <- "nonnatural_delivery"
x_m_names       <- c("wtgain", "birth_weight")
scale          <- c(1,1/100)
treatment_name <- "sex"
treatment_0    <- "Male"
treatment_1    <- "Female"

\(\color{couleur1}{\widehat{m}_0}\) and \(\color{couleur2}{\widehat{m}_1}\):

reg_gam_sex <- 
  models_spline(target = target,
              treatment_name = treatment_name, 
              x_m_names = x_m_names,
              data = base, treatment_0 = treatment_0, treatment_1 = treatment_1, df = c(3, 3))
reg_gam_sex_0 <- reg_gam_sex$reg_0
reg_gam_sex_1 <- reg_gam_sex$reg_1

3.2 GAM (with more knots)

target         <- "nonnatural_delivery"
x_m_names       <- c("wtgain", "birth_weight")
scale          <- c(1,1/100)
treatment_name <- "cig_rec"
treatment_0    <- "No"
treatment_1    <- "Yes"

\(\color{couleur1}{\widehat{m}_0}\) and \(\color{couleur2}{\widehat{m}_1}\):

reg_gam_2_smoker <- 
  models_spline(target = target,
              treatment_name = treatment_name, 
              x_m_names = x_m_names,
              data = base, treatment_0 = treatment_0, treatment_1 = treatment_1, df = c(12, 6))
reg_gam_2_smoker_0 <- reg_gam_2_smoker$reg_0
reg_gam_2_smoker_1 <- reg_gam_2_smoker$reg_1
target         <- "nonnatural_delivery"
x_m_names       <- c("wtgain", "birth_weight")
scale          <- c(1,1/100)
treatment_name <- "black_mother"
treatment_0    <- "No"
treatment_1    <- "Yes"
reg_gam_2_blackm <- 
  models_spline(target = target,
              treatment_name = treatment_name, 
              x_m_names = x_m_names,
              data = base, treatment_0 = treatment_0, treatment_1 = treatment_1, df = c(12, 6))
reg_gam_2_blackm_0 <- reg_gam_2_blackm$reg_0
reg_gam_2_blackm_1 <- reg_gam_2_blackm$reg_1
target         <- "nonnatural_delivery"
x_m_names       <- c("wtgain", "birth_weight")
scale          <- c(1,1/100)
treatment_name <- "sex"
treatment_0    <- "Male"
treatment_1    <- "Female"

\(\color{couleur1}{\widehat{m}_0}\) and \(\color{couleur2}{\widehat{m}_1}\):

reg_gam_2_sex <- 
  models_spline(target = target,
              treatment_name = treatment_name, 
              x_m_names = x_m_names,
              data = base, treatment_0 = treatment_0, treatment_1 = treatment_1, df = c(12, 6))
reg_gam_2_sex_0 <- reg_gam_2_sex$reg_0
reg_gam_2_sex_1 <- reg_gam_2_sex$reg_1

4 Prediction Functions

4.1 GAM

The prediction function for GAM:

#' @param object regression model (GAM)
#' @param newdata data frame in which to look for the mediator variable used to predict the target
model_spline_predict <- function(object, newdata){
  predict(object, newdata = newdata, type="response")
}
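
A minimal usage sketch on simulated data follows; the data frame toy and the fitted model fit are hypothetical. With type = "response", predict() returns probabilities, whereas the default for a glm object is the scale of the linear predictor (log-odds for a logistic model).

```r
library(splines)

# Same prediction function as above, repeated so that the sketch is self-contained
model_spline_predict <- function(object, newdata){
  predict(object, newdata = newdata, type = "response")
}

set.seed(2)
toy <- data.frame(x = runif(500, 0, 10))
toy$y <- rbinom(500, 1, plogis(-1 + 0.3 * toy$x))

# Logistic regression on a cubic B-spline basis of x
fit <- glm(y ~ bs(x, df = 3), family = binomial, data = toy)

# Predicted probabilities at three new values of x (inside the observed range)
p <- model_spline_predict(fit, newdata = data.frame(x = c(2, 5, 8)))
p
```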

5 Transport

Let us now turn to the transport of the mediator variables.

5.1 Gaussian assumption

We assume that, within each treatment group, the mediator variables are jointly normally distributed. The parameters used to transport the mediator variables under the Gaussian assumption can be estimated with the following function:

#' Optimal Transport assuming Gaussian distribution for the mediator variables (helper function)
#' @return A list with the mean and variance of the mediator in each subset, and the symmetric matrix A.
#' @param target name of the target variable
#' @param x_m_names vector of names of the mediator variables
#' @param scale vector of scaling to apply to each `x_m_names` variable to transport (default to 1)
#' @param treatment_name name of the treatment variable (column in `data`)
#' @param treatment_0 value for non treated
#' @param treatment_1 value for treated
transport_gaussian_param <- function(target, x_m_names, scale=1, treatment_name, treatment_0, treatment_1){
  base_0_unscaled <- 
    base %>% 
    select(!!c(x_m_names, treatment_name)) %>% 
    filter(!!sym(treatment_name) ==  treatment_0)
  base_1_unscaled <- 
    base %>% 
    select(!!c(x_m_names, treatment_name)) %>% 
    filter(!!sym(treatment_name) ==  treatment_1)
  
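  # Start from the unscaled data, then rescale each mediator in turn
  # (so that every element of `scale` is applied, not only the last one)
  base_0_scaled <- base_0_unscaled
  base_1_scaled <- base_1_unscaled
  for(i in seq_along(scale)){
    base_0_scaled <- 
      base_0_scaled %>% 
      mutate(!!sym(x_m_names[i]) := !!sym(x_m_names[i]) * scale[i])
    base_1_scaled <- 
      base_1_scaled %>% 
      mutate(!!sym(x_m_names[i]) := !!sym(x_m_names[i]) * scale[i])
  }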
  
  # Mean in each subset (i.e., T=0, T=1)
  m_0 <- base_0_scaled %>% summarise(across(!!x_m_names, mean)) %>% as_vector()
  m_1 <- base_1_scaled %>% summarise(across(!!x_m_names, mean)) %>% as_vector()
  # Variance
  S_0 <- base_0_scaled %>% select(!!x_m_names) %>% var()
  S_1 <- base_1_scaled %>% select(!!x_m_names) %>% var()
  # Matrix A
  A <- (solve(sqrtm(S_0))) %*% sqrtm( sqrtm(S_0) %*% S_1 %*% (sqrtm(S_0)) ) %*% solve(sqrtm(S_0))
  
  list(m_0 = m_0, m_1 = m_1, S_0 = S_0, S_1 = S_1, A = A, scale = scale, x_m_names = x_m_names)
}

We create a function that transports a single observation, given the different parameters:

#' Gaussian transport for a single observation (helper function)
#' @param z vector of variables to transport (for a single observation)
#' @param A symmetric positive matrix that satisfies \(A\Sigma_0A=\Sigma_1\)
#' @param m_0,m_1 vectors of mean values in the subsets \(\mathcal{D}_0\) and \(\mathcal{D}_1\)
#' @param scale vector of scaling to apply to each variable to transport (default to 1)
T_C_single <- function(z, A, m_0, m_1, scale = 1){
  z <- z*scale
  as.vector(m_1 + A %*% (z-m_0))*(1/scale)
}
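
As a worked example, the function can be applied to one individual with a weight gain of 10.8 pounds and a birth weight of 2584 g, using the parameter estimates reported later in this appendix for the smoking treatment (params_smoker). The numbers below are copied from that output; the function is repeated so that the sketch is self-contained.

```r
T_C_single <- function(z, A, m_0, m_1, scale = 1){
  z <- z*scale
  as.vector(m_1 + A %*% (z-m_0))*(1/scale)
}

# Estimates reported for the smoking treatment (order: wtgain, birth_weight)
A   <- matrix(c(1.16471785, 0.01715655,
                0.01715655, 1.01085009), nrow = 2, byrow = TRUE)
m_0 <- c(wtgain = 30.40885, birth_weight = 32.92174)
m_1 <- c(wtgain = 30.55169, birth_weight = 31.02590)
scale <- c(1, 1/100)

# The input is rescaled, transported, then mapped back to the original units
z_t <- T_C_single(z = c(10.8, 2584), A = A, m_0 = m_0, m_1 = m_1, scale = scale)
round(z_t, 1)
```

The transported weight gain is about 7.6 pounds and the transported birth weight about 2353.1 g, which are the values appearing in Table 1 for this individual.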

Lastly, we define a function that transports the mediator variables, under the Gaussian assumption. If the params argument is NULL, the arguments needed to estimate the parameters through the transport_gaussian_param() function are required (target, x_m_names, scale, treatment_name, treatment_0, treatment_1).

#' Gaussian transport
#' @param z data frame of variables to transport
#' @param params (optional) parameters to use for the transport (result of `transport_gaussian_param()`, default to NULL)
#' @param target name of the target variable
#' @param x_m_names vector of names of the mediator variables
#' @param scale vector of scaling to apply to each `x_m_names` variable to transport (default to 1)
#' @param treatment_name name of the treatment variable (column in `data`)
#' @param treatment_0 value for non treated
#' @param treatment_1 value for treated
gaussian_transport <- function(z, params = NULL, ..., target = NULL, x_m_names, scale, treatment_name, treatment_0, treatment_1){
  if(is.null(params)){
    # If the parameters of the transport function are not provided
    # they need to be computed first
    params <- 
      transport_gaussian_param(target = target, x_m_names = x_m_names, 
                               scale = scale, treatment_name = treatment_name, 
                               treatment_0 = treatment_0, treatment_1 = treatment_1)
  }
  A <- params$A
  m_0 <- params$m_0 ; m_1 <- params$m_1
  scale <- params$scale
  x_m_names <- params$x_m_names
  
  values_to_transport <- z %>% select(!!x_m_names)
  transported_val <- 
    apply(values_to_transport, 1, T_C_single, A = A, m_0 = m_0, m_1 = m_1,
          scale = scale, simplify = FALSE) %>% 
    do.call("rbind", .)
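  # Columns follow the order of `x_m_names` (the order used by select() above),
  # which may differ from the column order of `z`
  colnames(transported_val) <- x_m_names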
  transported_val <- as_tibble(transported_val)
  
  structure(.Data = transported_val, params = params)
}
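
Because the Gaussian transport is an affine map, the row-by-row apply() call can also be written with matrix operations, which may be noticeably faster on a grid of 63,001 points. The sketch below is one possible vectorised alternative, not the implementation used in this appendix. It reuses the estimates reported later for the smoking treatment and checks that both approaches agree on four points.

```r
# Estimates reported for the smoking treatment (order: wtgain, birth_weight)
A   <- matrix(c(1.16471785, 0.01715655,
                0.01715655, 1.01085009), nrow = 2, byrow = TRUE)
m_0 <- c(30.40885, 32.92174)
m_1 <- c(30.55169, 31.02590)
scale <- c(1, 1/100)

# Four points to transport, with columns in the same order as the parameters
Z <- cbind(wtgain = c(10.8, 46.8, 10.8, 46.8),
           birth_weight = c(2584, 2584, 4152, 4152))

# Row-by-row version (same formula as T_C_single())
row_wise <- t(apply(Z, 1, function(z) as.vector(m_1 + A %*% (z * scale - m_0)) / scale))

# Vectorised version: rescale, centre, multiply by t(A), shift, rescale back
Z_centred <- sweep(sweep(Z, 2, scale, "*"), 2, m_0, "-")
vectorised <- sweep(sweep(Z_centred %*% t(A), 2, m_1, "+"), 2, scale, "/")

all.equal(row_wise, vectorised, check.attributes = FALSE)
round(vectorised, 1)
```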

The transported values:

target         <- "nonnatural_delivery"
x_m_names       <- c("wtgain", "birth_weight")
scale          <- c(1,1/100)
treatment_name <- "cig_rec"
treatment_0    <- "No"
treatment_1    <- "Yes"
val_grid_t_n_smoker <- 
  gaussian_transport(z = val_grid, target = target,
                     x_m_names = x_m_names, 
                     scale = scale, treatment_name = treatment_name,
                     treatment_0 = treatment_0, treatment_1 = treatment_1)
val_grid_t_n_smoker
# A tibble: 63,001 × 2
   wtgain birth_weight
 *  <dbl>        <dbl>
 1  -5.12        1542.
 2  -4.70        1543.
 3  -4.28        1543.
 4  -3.86        1544.
 5  -3.44        1545.
 6  -3.03        1545.
 7  -2.61        1546.
 8  -2.19        1546.
 9  -1.77        1547.
10  -1.35        1548.
# … with 62,991 more rows
target         <- "nonnatural_delivery"
x_m_names       <- c("wtgain", "birth_weight")
scale          <- c(1,1/100)
treatment_name <- "black_mother"
treatment_0    <- "No"
treatment_1    <- "Yes"
val_grid_t_n_blackm <- 
  gaussian_transport(z = val_grid, target = target,
                     x_m_names = x_m_names, 
                     scale = scale, treatment_name = treatment_name,
                     treatment_0 = treatment_0, treatment_1 = treatment_1)
val_grid_t_n_blackm
# A tibble: 63,001 × 2
   wtgain birth_weight
 *  <dbl>        <dbl>
 1  -5.85        1429.
 2  -5.44        1430.
 3  -5.03        1430.
 4  -4.62        1430.
 5  -4.21        1430.
 6  -3.80        1430.
 7  -3.39        1430.
 8  -2.98        1430.
 9  -2.58        1430.
10  -2.17        1430.
# … with 62,991 more rows
target         <- "nonnatural_delivery"
x_m_names       <- c("wtgain", "birth_weight")
scale          <- c(1,1/100)
treatment_name <- "sex"
treatment_0    <- "Male"
treatment_1    <- "Female"
val_grid_t_n_sex <- 
  gaussian_transport(z = val_grid, target = target,
                     x_m_names = x_m_names, 
                     scale = scale, treatment_name = treatment_name,
                     treatment_0 = treatment_0, treatment_1 = treatment_1)
val_grid_t_n_sex
# A tibble: 63,001 × 2
   wtgain birth_weight
 *  <dbl>        <dbl>
 1 -0.470        1770.
 2 -0.113        1770.
 3  0.244        1770.
 4  0.602        1769.
 5  0.959        1769.
 6  1.32         1769.
 7  1.67         1769.
 8  2.03         1769.
 9  2.39         1768.
10  2.74         1768.
# … with 62,991 more rows

The estimated parameters can be extracted as follows:

params_smoker <- attr(val_grid_t_n_smoker, "params")
params_smoker
$m_0
      wtgain birth_weight 
    30.40885     32.92174 

$m_1
      wtgain birth_weight 
    30.55169     31.02590 

$S_0
                wtgain birth_weight
wtgain       218.95677     14.54058
birth_weight  14.54058     34.09920

$S_1
                wtgain birth_weight
wtgain       297.62083     22.09039
birth_weight  22.09039     35.41197

$A
           [,1]       [,2]
[1,] 1.16471785 0.01715655
[2,] 0.01715655 1.01085009

$scale
[1] 1.00 0.01

$x_m_names
[1] "wtgain"       "birth_weight"

The estimated parameters can be extracted as follows:

params_blackm <- attr(val_grid_t_n_blackm, "params")
params_blackm
$m_0
      wtgain birth_weight 
    30.67049     33.09063 

$m_1
      wtgain birth_weight 
    29.02406     31.00026 

$S_0
                wtgain birth_weight
wtgain       215.83720     14.15844
birth_weight  14.15844     32.86016

$S_1
               wtgain birth_weight
wtgain       278.8963     17.99050
birth_weight  17.9905     40.18716

$A
             [,1]         [,2]
[1,] 1.1366862866 0.0007016672
[2,] 0.0007016672 1.1055783355

$scale
[1] 1.00 0.01

$x_m_names
[1] "wtgain"       "birth_weight"

The estimated parameters can be extracted as follows:

params_sex <- attr(val_grid_t_n_sex, "params")
params_sex
$m_0
      wtgain birth_weight 
    30.79877     33.32068 

$m_1
      wtgain birth_weight 
    30.00267     32.17229 

$S_0
               wtgain birth_weight
wtgain       227.8193     16.11780
birth_weight  16.1178     35.86006

$S_1
                wtgain birth_weight
wtgain       224.11168     13.81543
birth_weight  13.81543     32.60395

$A
            [,1]        [,2]
[1,]  0.99222733 -0.00565936
[2,] -0.00565936  0.95595986

$scale
[1] 1.00 0.01

$x_m_names
[1] "wtgain"       "birth_weight"

Let us now consider four different individuals \(\boldsymbol{x}\) in the control group:

individuals <- 
  tibble(birth_weight = c(2584, 2584, 4152, 4152), wtgain = c(10.8, 46.8, 10.8, 46.8))

individuals_trans <- 
  individuals %>% bind_cols(
    gaussian_transport(z = individuals, params = params_smoker) %>% 
      rename_all(~paste0(., "_t")) %>% 
      mutate(target = "smoker")) %>%
  bind_rows(
    individuals %>% bind_cols(
      gaussian_transport(z = individuals, params = params_blackm) %>% 
        rename_all(~paste0(., "_t")) %>% 
        mutate(target = "black")
    )
  ) %>% 
  bind_rows(
    individuals %>% bind_cols(
      gaussian_transport(z = individuals, params = params_sex) %>% 
        rename_all(~paste0(., "_t")) %>% 
        mutate(target = "girl")
    )
  )
individuals_trans %>% 
  pivot_wider(names_from = target, values_from = c(birth_weight_t, wtgain_t)) %>% 
  relocate(paste0(rep(c("birth_weight_t_", "wtgain_t_"), 3), rep(c("smoker", "black", "girl"), each = 2)), .after = wtgain) %>% 
  rename_with(~str_replace(., "birth_weight_t", "bw"), contains("birth_weight_t")) %>% 
  rename_with(~str_replace(., "wtgain_t", "wg"), contains("wtgain_t")) %>% 
  knitr::kable(format = "html", digits = 1) %>% 
  kableExtra::kable_classic(full_width = F, html_font = "Cambria") %>%
  kableExtra::add_header_above(c(" "=2, "Smoker" = 2, "Black" = 2, "Girl" = 2))
Table 1. Bivariate optimal transport, \(\boldsymbol{x}\mapsto \mathcal{T}_{\mathcal{N}}(\boldsymbol{x})\), for the three treatments, for four different individuals \(\boldsymbol{x}\) in the control group. Columns prefixed with bw and wg give the transported birth weight (in grams) and the transported weight gain (in pounds), respectively.

                       Smoker                 Black                  Girl
birth_weight  wtgain   bw_smoker  wg_smoker   bw_black  wg_black    bw_girl  wg_girl
        2584    10.8      2353.1        7.6     2297.0       6.4     2513.4     10.2
        2584    46.8      2414.9       49.5     2299.5      47.4     2493.1     45.9
        4152    10.8      3938.1        7.9     4030.6       6.4     4012.4     10.1
        4152    46.8      3999.9       49.8     4033.1      47.4     3992.0     45.8

5.1.1 Vector Field

Let us visualize on a plot the transported values for some points, using arrows. For each point, the origin of an arrow corresponds to \(\boldsymbol{x} = (\text{birth weight}, \text{weight gain})\), while its end corresponds to \(\mathcal{T}_\mathcal{N} (\boldsymbol{x})\).

# To draw arrows that change color from start to end: code from Greg Snow
# Source: https://stackoverflow.com/a/20430004
csa <- function(x1, y1, x2, y2, first_col,second_col, length = .035, ...) {
  cols <- colorRampPalette( c(first_col,second_col) )(250)
  x <- approx(c(0,1),c(x1,x2), xout=seq(0,1,length.out=251))$y
  y <- approx(c(0,1),c(y1,y2), xout=seq(0,1,length.out=251))$y
  
  arrows(x[250],y[250],x[251],y[251], col=cols[250], length = length, ...)
  segments(x[-251],y[-251],x[-1],y[-1],col=cols, ...)
  
}
color.scale.arrow <- Vectorize(csa, c('x1','y1','x2','y2') )
image(val_birth_weight, val_wtgain, matrix(0, length(val_wtgain), length(val_birth_weight), byrow = TRUE),
      xlab = "",
      ylab = "Weight gain of the mother",
      axes = FALSE, ylim = c(0, 90), xlim = c(1800, 4600), col = "white")

mtext(expression("Weight of the baby (mother " * phantom("non-smoker") * " → " * phantom("smoker") *")"),
      side=1,line=3, col = "black")
mtext(expression(phantom("Weight of the baby (mother ") * "non-smoker" * phantom(" → smoker)")),
      side=1,line=3, col = couleur1)
mtext(expression(phantom("Weight of the baby (mother non-smoker → ") * "smoker" * phantom(")")),
      side=1,line=3, col = couleur2)
axis(1)
axis(2)
rect(1800,0,4600,90)
for(i in seq(11,251,by=20)){
  for(j in seq(11,251,by=20)){
    z <- gaussian_transport(z = tibble(wtgain = val_wtgain[j], birth_weight = val_birth_weight[i]),
                            params = params_smoker)
    
    csa(x1 = val_birth_weight[i], y1 = val_wtgain[j], x2 = z$birth_weight, y2 = z$wtgain, 
        first_col = couleur1, second_col = couleur2, lwd = 2)
  }}

Figure 4. Vector field associated with optimal Gaussian transport, in dimension two (weight of the newborn infant and weight gain of the mother) from non-smoker to smoker mother.

image(val_birth_weight, val_wtgain, matrix(0, length(val_wtgain), length(val_birth_weight), byrow = TRUE),
      xlab = "",
      ylab = "Weight gain of the mother",
      axes = FALSE, ylim = c(0, 90), xlim = c(1800, 4600), col = "white")

mtext(expression("Weight of the baby (" * phantom("non-Black") * " → " * phantom("Black") *" mother)"),
      side=1,line=3, col = "black")
mtext(expression(phantom("Weight of the baby (") * "non-Black" * phantom(" → Black mother)")),
      side=1,line=3, col = couleur1)
mtext(expression(phantom("Weight of the baby (non-Black → ") * "Black" * phantom(" mother)")),
      side=1,line=3, col = couleur2)
axis(1)
axis(2)
rect(1800,0,4600,90)
for(i in seq(11,251,by=20)){
  for(j in seq(11,251,by=20)){
    z <- gaussian_transport(z = tibble(wtgain = val_wtgain[j], birth_weight = val_birth_weight[i]),
                            params = params_blackm)
    
    csa(x1 = val_birth_weight[i], y1 = val_wtgain[j], x2 = z$birth_weight, y2 = z$wtgain, 
        first_col = couleur1, second_col = couleur2, lwd = 2)
  }}

Figure 5. Vector field associated with optimal Gaussian transport, in dimension two (weight of the newborn infant and weight gain of the mother), from non-Black to Black mother.

image(val_birth_weight, val_wtgain, matrix(0, length(val_wtgain), length(val_birth_weight), byrow = TRUE),
      xlab = "",
      ylab = "Weight gain of the mother",
      axes = FALSE, ylim = c(0, 90), xlim = c(1800, 4600), col = "white")

mtext(expression("Weight of the baby (" * phantom("baby boy") * " → " * phantom("baby girl") *")"),
      side=1,line=3, col = "black")
mtext(expression(phantom("Weight of the baby (") * "baby boy" * phantom(" → baby girl)")),
      side=1,line=3, col = couleur1)
mtext(expression(phantom("Weight of the baby (baby boy → ") * "baby girl" * phantom(")")),
      side=1,line=3, col = couleur2)
axis(1)
axis(2)
rect(1800,0,4600,90)
for(i in seq(11,251,by=20)){
  for(j in seq(11,251,by=20)){
    z <- gaussian_transport(z = tibble(wtgain = val_wtgain[j], birth_weight = val_birth_weight[i]),
                            params = params_sex)
    
    csa(x1 = val_birth_weight[i], y1 = val_wtgain[j], x2 = z$birth_weight, y2 = z$wtgain, 
        first_col = couleur1, second_col = couleur2, lwd = 2)
  }}

Figure 6. Vector field associated with optimal Gaussian transport, in dimension two (weight of the newborn infant and weight gain of the mother), from baby boy to baby girl.

6 Estimation

Now that the two models \(\color{couleur1}{\widehat{m}_0(x)}\) and \(\color{couleur2}{\widehat{m}_1(x)}\) are defined and fitted to the data, we can compute the Mutatis Mutandis Sample Conditional Average Effect (SCATE).
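
For reference, the two effects returned by the function defined next can be written as follows. This is a restatement read off the `CATE` and `SCATE` columns computed in the code, with \(\mathcal{T}\) denoting the transport map from the non-treated to the treated group; the colouring of \(\widehat{m}_0\) and \(\widehat{m}_1\) used elsewhere in the text is dropped here for brevity.

```latex
\[
\text{CATE}[\boldsymbol{x}] = \widehat{m}_1(\boldsymbol{x}) - \widehat{m}_0(\boldsymbol{x}),
\qquad
\text{SCATE}[\boldsymbol{x}] = \widehat{m}_1\big(\mathcal{T}(\boldsymbol{x})\big) - \widehat{m}_0(\boldsymbol{x}).
\]
```

The only difference between the two is the point at which \(\widehat{m}_1\) is evaluated: the original mediators \(\boldsymbol{x}\) (ceteris paribus) or their transported values \(\mathcal{T}(\boldsymbol{x})\) (mutatis mutandis).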

#' Computes the Sample Average Treatment Effect with and without transport
#' @param x_0 vector of values at which to compute $\hat{m}_0(x_0)$, $\hat{m}_1(x_0)$
#' @param x_t vector of transported values $\mathcal{T}(x_0)$, at which to compute $\hat{m}_1(\mathcal{T}(x_0))$
#' @param x_m_names vector of names of the mediator variables
#' @param mod_0 model $\hat{m}_0$
#' @param mod_1 model $\hat{m}_1$
#' @param pred_mod_0 prediction function for model $\hat{m}_0(\cdot)$
#' @param pred_mod_1 prediction function for model $\hat{m}_1(\cdot)$
#' @param return_x if `TRUE` (default) the mediator variables are returned in the table, as well as their transported values
sate <- function(x_0, x_t, x_m_names, mod_0, mod_1, pred_mod_0, pred_mod_1, return_x = TRUE){
  if(is.vector(x_0)){
    # Univariate case
    new_data <- tibble(!!x_m_names := x_0)
  }else{
    new_data <- x_0
  }
  if(is.vector(x_t)){
    # Univariate case
    new_data_t <- tibble(!!x_m_names := x_t)
  }else{
    new_data_t <- x_t
  }
  
  
  # $\hat{m}_0(x_0)$
  y_0 <- pred_mod_0(object = mod_0, newdata = new_data)
  # $\hat{m}_1(x_0)$
  y_1 <- pred_mod_1(object = mod_1, newdata = new_data)
  # $\hat{m}_1(\mathcal{T}(x_0))$
  y_1_t <- pred_mod_1(object = mod_1, newdata = new_data_t)
  
  
  scate_tab <- 
    tibble(y_0 = y_0, y_1 = y_1, y_1_t = y_1_t,
         CATE = y_1-y_0, SCATE = y_1_t-y_0)
  
  if(return_x){
    new_data_t <- new_data_t %>% rename_all(~paste0(.x, "_t"))
    scate_tab <- bind_cols(new_data, new_data_t, scate_tab)
  }
  scate_tab
}
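
To make the three predictions concrete, here is a small self-contained sketch on simulated data with a single mediator. Everything in it is an assumption made for illustration: the data-generating process, the plain logistic `glm()` models standing in for \(\widehat{m}_0\) and \(\widehat{m}_1\), and the univariate Gaussian transport that matches the mean and standard deviation of the mediator. It does not call `sate()` but reproduces the same columns.

```r
set.seed(1)
n <- 2000
# Simulated data (hypothetical): mediator x, binary treatment, binary outcome y
treat <- rbinom(n, 1, 0.5)
x <- rnorm(n, mean = ifelse(treat == 1, 28, 31), sd = ifelse(treat == 1, 15, 14))
y <- rbinom(n, 1, plogis(0.5 - 0.02 * x - 0.3 * treat))
d <- data.frame(x = x, treat = treat, y = y)

# One logistic model per treatment group
mod_0 <- glm(y ~ x, family = binomial, data = d[d$treat == 0, ])
mod_1 <- glm(y ~ x, family = binomial, data = d[d$treat == 1, ])

# Univariate Gaussian transport: match the mean and standard deviation of x
mu_0 <- mean(d$x[d$treat == 0]); s_0 <- sd(d$x[d$treat == 0])
mu_1 <- mean(d$x[d$treat == 1]); s_1 <- sd(d$x[d$treat == 1])
transport <- function(x) mu_1 + (s_1 / s_0) * (x - mu_0)

# Points at which the effects are evaluated, and their transported values
x_0 <- seq(0, 60, by = 10)
x_t <- transport(x_0)

y_0   <- predict(mod_0, newdata = data.frame(x = x_0), type = "response")
y_1   <- predict(mod_1, newdata = data.frame(x = x_0), type = "response")
y_1_t <- predict(mod_1, newdata = data.frame(x = x_t), type = "response")

toy <- data.frame(x = x_0, x_t = x_t, y_0 = y_0, y_1 = y_1, y_1_t = y_1_t,
                  CATE = y_1 - y_0, SCATE = y_1_t - y_0)
toy
```

The layout of `toy` mirrors the tables printed in the following subsections: the mediator, its transported value, the three predictions, and the two effects.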

6.1 GAM (with cubic splines), Gaussian assumption for transport

Let us compute here the SCATE where the models \(\color{couleur1}{\widehat{m}_0(x)}\) and \(\color{couleur2}{\widehat{m}_1(x)}\) are GAM and where the transport method assumes a Gaussian distribution for both mediators.

mm_sate_gam_smoker <- 
  sate(
    x_0 = val_grid,
    x_t = val_grid_t_n_smoker,
    x_m_names = c("wtgain", "birth_weight"),
    mod_0 = reg_gam_smoker_0,
    mod_1 = reg_gam_smoker_1,
    pred_mod_0 = model_spline_predict, pred_mod_1 = model_spline_predict,
    return_x = TRUE)
mm_sate_gam_smoker
# A tibble: 63,001 × 9
   wtgain birth_weight wtgain_t birth_weight_t   y_0   y_1 y_1_t    CATE  SCATE
    <dbl>        <dbl>    <dbl>          <dbl> <dbl> <dbl> <dbl>   <dbl>  <dbl>
 1   0            1800    -5.12          1542. 0.589 0.536 0.606 -0.0536 0.0167
 2   0.36         1800    -4.70          1543. 0.587 0.534 0.604 -0.0531 0.0164
 3   0.72         1800    -4.28          1543. 0.585 0.533 0.601 -0.0527 0.0160
 4   1.08         1800    -3.86          1544. 0.584 0.531 0.599 -0.0523 0.0157
 5   1.44         1800    -3.44          1545. 0.582 0.530 0.597 -0.0519 0.0154
 6   1.8          1800    -3.03          1545. 0.580 0.528 0.595 -0.0515 0.0150
 7   2.16         1800    -2.61          1546. 0.578 0.527 0.593 -0.0511 0.0147
 8   2.52         1800    -2.19          1546. 0.576 0.526 0.591 -0.0508 0.0144
 9   2.88         1800    -1.77          1547. 0.575 0.524 0.589 -0.0504 0.0140
10   3.24         1800    -1.35          1548. 0.573 0.523 0.587 -0.0500 0.0137
# … with 62,991 more rows
mm_sate_gam_blackm <- 
  sate(
    x_0 = val_grid,
    x_t = val_grid_t_n_blackm,
    x_m_names = c("wtgain", "birth_weight"),
    mod_0 = reg_gam_blackm_0,
    mod_1 = reg_gam_blackm_1,
    pred_mod_0 = model_spline_predict, pred_mod_1 = model_spline_predict,
    return_x = TRUE)
mm_sate_gam_blackm
# A tibble: 63,001 × 9
   wtgain birth_weight wtgain_t birth_weight_t   y_0   y_1 y_1_t    CATE  SCATE
    <dbl>        <dbl>    <dbl>          <dbl> <dbl> <dbl> <dbl>   <dbl>  <dbl>
 1   0            1800    -5.85          1429. 0.592 0.543 0.621 -0.0485 0.0299
 2   0.36         1800    -5.44          1430. 0.590 0.542 0.620 -0.0475 0.0303
 3   0.72         1800    -5.03          1430. 0.588 0.541 0.618 -0.0466 0.0307
 4   1.08         1800    -4.62          1430. 0.586 0.540 0.617 -0.0457 0.0311
 5   1.44         1800    -4.21          1430. 0.584 0.539 0.615 -0.0447 0.0315
 6   1.8          1800    -3.80          1430. 0.582 0.538 0.614 -0.0438 0.0319
 7   2.16         1800    -3.39          1430. 0.580 0.537 0.613 -0.0430 0.0323
 8   2.52         1800    -2.98          1430. 0.579 0.536 0.611 -0.0421 0.0327
 9   2.88         1800    -2.58          1430. 0.577 0.536 0.610 -0.0412 0.0330
10   3.24         1800    -2.17          1430. 0.575 0.535 0.609 -0.0404 0.0334
# … with 62,991 more rows
mm_sate_gam_sex <- 
  sate(
    x_0 = val_grid,
    x_t = val_grid_t_n_sex,
    x_m_names = c("wtgain", "birth_weight"),
    mod_0 = reg_gam_sex_0,
    mod_1 = reg_gam_sex_1,
    pred_mod_0 = model_spline_predict, pred_mod_1 = model_spline_predict,
    return_x = TRUE)
mm_sate_gam_sex
# A tibble: 63,001 × 9
   wtgain birth_weight wtgain_t birth_weight_t   y_0   y_1 y_1_t    CATE   SCATE
    <dbl>        <dbl>    <dbl>          <dbl> <dbl> <dbl> <dbl>   <dbl>   <dbl>
 1   0            1800   -0.470          1770. 0.583 0.584 0.593 6.88e-4 0.00972
 2   0.36         1800   -0.113          1770. 0.581 0.582 0.591 8.24e-4 0.00988
 3   0.72         1800    0.244          1770. 0.579 0.580 0.589 9.58e-4 0.0100 
 4   1.08         1800    0.602          1769. 0.578 0.579 0.588 1.09e-3 0.0102 
 5   1.44         1800    0.959          1769. 0.576 0.577 0.586 1.22e-3 0.0103 
 6   1.8          1800    1.32           1769. 0.574 0.575 0.584 1.34e-3 0.0105 
 7   2.16         1800    1.67           1769. 0.572 0.574 0.583 1.47e-3 0.0106 
 8   2.52         1800    2.03           1769. 0.570 0.572 0.581 1.59e-3 0.0108 
 9   2.88         1800    2.39           1768. 0.569 0.571 0.580 1.71e-3 0.0109 
10   3.24         1800    2.74           1768. 0.567 0.569 0.578 1.82e-3 0.0110 
# … with 62,991 more rows

6.2 GAM (with more knots), Gaussian assumption for transport

mm_sate_gam_2_smoker <- 
  sate(
    x_0 = val_grid,
    x_t = val_grid_t_n_smoker,
    x_m_names = c("wtgain", "birth_weight"),
    mod_0 = reg_gam_2_smoker_0,
    mod_1 = reg_gam_2_smoker_1,
    pred_mod_0 = model_spline_predict, pred_mod_1 = model_spline_predict,
    return_x = TRUE)
mm_sate_gam_2_smoker
# A tibble: 63,001 × 9
   wtgain birth_weight wtgain_t birth_weight_t   y_0   y_1 y_1_t    CATE   SCATE
    <dbl>        <dbl>    <dbl>          <dbl> <dbl> <dbl> <dbl>   <dbl>   <dbl>
 1   0            1800    -5.12          1542. 0.675 0.589 0.610 -0.0867 -0.0656
 2   0.36         1800    -4.70          1543. 0.675 0.590 0.615 -0.0852 -0.0605
 3   0.72         1800    -4.28          1543. 0.675 0.591 0.619 -0.0839 -0.0558
 4   1.08         1800    -3.86          1544. 0.675 0.592 0.623 -0.0827 -0.0514
 5   1.44         1800    -3.44          1545. 0.675 0.593 0.627 -0.0817 -0.0474
 6   1.8          1800    -3.03          1545. 0.674 0.594 0.631 -0.0808 -0.0437
 7   2.16         1800    -2.61          1546. 0.674 0.594 0.634 -0.0801 -0.0403
 8   2.52         1800    -2.19          1546. 0.674 0.594 0.637 -0.0795 -0.0372
 9   2.88         1800    -1.77          1547. 0.674 0.595 0.639 -0.0791 -0.0343
10   3.24         1800    -1.35          1548. 0.674 0.595 0.642 -0.0788 -0.0318
# … with 62,991 more rows
mm_sate_gam_2_blackm <- 
  sate(
    x_0 = val_grid,
    x_t = val_grid_t_n_blackm,
    x_m_names = c("wtgain", "birth_weight"),
    mod_0 = reg_gam_2_blackm_0,
    mod_1 = reg_gam_2_blackm_1,
    pred_mod_0 = model_spline_predict, pred_mod_1 = model_spline_predict,
    return_x = TRUE)
mm_sate_gam_2_blackm
# A tibble: 63,001 × 9
   wtgain birth_weight wtgain_t birth_weight_t   y_0   y_1 y_1_t    CATE  SCATE
    <dbl>        <dbl>    <dbl>          <dbl> <dbl> <dbl> <dbl>   <dbl>  <dbl>
 1   0            1800    -5.85          1429. 0.675 0.619 0.728 -0.0562 0.0523
 2   0.36         1800    -5.44          1430. 0.675 0.618 0.725 -0.0569 0.0491
 3   0.72         1800    -5.03          1430. 0.675 0.618 0.722 -0.0575 0.0462
 4   1.08         1800    -4.62          1430. 0.675 0.618 0.719 -0.0580 0.0435
 5   1.44         1800    -4.21          1430. 0.675 0.617 0.716 -0.0583 0.0410
 6   1.8          1800    -3.80          1430. 0.675 0.617 0.714 -0.0585 0.0388
 7   2.16         1800    -3.39          1430. 0.675 0.617 0.712 -0.0585 0.0368
 8   2.52         1800    -2.98          1430. 0.675 0.617 0.710 -0.0585 0.0350
 9   2.88         1800    -2.58          1430. 0.675 0.617 0.708 -0.0584 0.0335
10   3.24         1800    -2.17          1430. 0.675 0.617 0.707 -0.0582 0.0322
# … with 62,991 more rows
mm_sate_gam_2_sex <- 
  sate(
    x_0 = val_grid,
    x_t = val_grid_t_n_sex,
    x_m_names = c("wtgain", "birth_weight"),
    mod_0 = reg_gam_2_sex_0,
    mod_1 = reg_gam_2_sex_1,
    pred_mod_0 = model_spline_predict, pred_mod_1 = model_spline_predict,
    return_x = TRUE)
mm_sate_gam_2_sex
# A tibble: 63,001 × 9
   wtgain birth_weight wtgain_t birth_weigh…¹   y_0   y_1 y_1_t     CATE   SCATE
    <dbl>        <dbl>    <dbl>         <dbl> <dbl> <dbl> <dbl>    <dbl>   <dbl>
 1   0            1800   -0.470         1770. 0.666 0.661 0.670 -0.00501 0.00390
 2   0.36         1800   -0.113         1770. 0.666 0.661 0.670 -0.00458 0.00442
 3   0.72         1800    0.244         1770. 0.665 0.661 0.670 -0.00416 0.00492
 4   1.08         1800    0.602         1769. 0.665 0.661 0.670 -0.00376 0.00541
 5   1.44         1800    0.959         1769. 0.665 0.661 0.671 -0.00338 0.00588
 6   1.8          1800    1.32          1769. 0.664 0.661 0.671 -0.00301 0.00634
 7   2.16         1800    1.67          1769. 0.664 0.661 0.671 -0.00266 0.00679
 8   2.52         1800    2.03          1769. 0.664 0.661 0.671 -0.00232 0.00722
 9   2.88         1800    2.39          1768. 0.663 0.661 0.671 -0.00199 0.00764
10   3.24         1800    2.74          1768. 0.663 0.661 0.671 -0.00169 0.00804
# … with 62,991 more rows, and abbreviated variable name ¹​birth_weight_t

7 Results

7.1 GAM (with cubic splines), Gaussian assumption for transport

7.1.1 Contours of \(\boldsymbol{x}\mapsto\mathbb{E}[Y|\boldsymbol{X}=\boldsymbol{x},T=0]\)

Now, we can plot some level curves to visualize the estimated probabilities of the target variable at different points in the 2-dimensional space in which the mediators lie. Let us focus on the sub-sample of the non-treated.

treatment_name <- "cig_rec"
treatment_0    <- "No"
treatment_1    <- "Yes"

mat_gam_0_smoker <- matrix(mm_sate_gam_smoker$y_0, length(val_wtgain), length(val_birth_weight), byrow = TRUE)
mat_gam_1_smoker <- matrix(mm_sate_gam_smoker$y_1, length(val_wtgain), length(val_birth_weight), byrow = TRUE)
mat_gam_1_t_smoker <- matrix(mm_sate_gam_smoker$y_1_t, length(val_wtgain), length(val_birth_weight), byrow = TRUE)
image(val_birth_weight, val_wtgain, mat_gam_0_smoker,
      axes = FALSE, xlab = "", ylab = "", ylim = c(0, 90), xlim = c(1800, 4600),
      col = hcl.colors(20, "YlOrRd", rev = TRUE), breaks=(0:20)/20)
mtext("Weight of the baby (quantile)",side=3,line=3)
mtext(expression("Weight of the baby (" * phantom("non-smoker mother") * ")"),
      side=1,line=3, col = "black")
mtext(expression(phantom("Weight of the baby (") * "non-smoker mother" * phantom(")")),
      side=1,line=3, col = couleur1)
mtext(expression("Weight gain of the mother (" * phantom("non-smoker") * ")"),
      side=2,line=3)
mtext(expression(phantom("Weight gain of the mother (") * "non-smoker" * phantom(")")),
      side=2, line=3, col = couleur1)
axis(1)
axis(2)
axis(3, at = quantile(base$birth_weight[pull(base, treatment_name) == treatment_0],
                      seq(10, 90, by = 10)/100), 
     label = paste(seq(10, 90, by = 10), "%"))
axis(4, at = quantile(base$wtgain[pull(base, treatment_name) == treatment_0],
                      seq(10, 90, by = 10)/100), label = paste(seq(10, 90, by = 10), "%"))
contour(val_birth_weight, val_wtgain, mat_gam_0_smoker, add=TRUE)

Figure 7. Contours of \(\boldsymbol{x}\mapsto\mathbb{E}[Y|\boldsymbol{X}=\boldsymbol{x},T=0]\), with \(T=0\) indicating a non-smoker mother, estimated with logistic GAM models (cubic splines).
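
Each of these plots relies on reshaping a vector of predictions, computed on a grid, into a matrix whose rows follow the x axis of `image()` and whose columns follow the y axis. The sketch below illustrates that step on a small, deliberately non-square grid with made-up values; the names `wt`, `bw` and `pred` are hypothetical stand-ins for `val_wtgain`, `val_birth_weight` and the model predictions. It assumes, as in the tables printed in Section 6, that the weight gain varies fastest in the grid.

```r
# Small non-square grid (hypothetical values), so that a row/column mix-up would show
wt <- c(0, 30, 60, 90)       # plays the role of val_wtgain
bw <- c(1800, 3200, 4600)    # plays the role of val_birth_weight

# expand.grid() makes its first argument vary fastest
grid <- expand.grid(wtgain = wt, birth_weight = bw)

# Any function of the two mediators stands in for the model predictions
grid$pred <- grid$birth_weight / 1000 + grid$wtgain / 100

# Rows index birth_weight (x axis of image()), columns index wtgain (y axis)
z <- matrix(grid$pred, nrow = length(bw), ncol = length(wt), byrow = TRUE)

image(bw, wt, z, xlab = "birth_weight", ylab = "wtgain")
contour(bw, wt, z, add = TRUE)
```

In the appendix both grids have 251 points (the tables have 63,001 = 251 × 251 rows), so the matrices are square and the order of the two dimension arguments passed to `matrix()` has no visible consequence there.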

treatment_name <- "black_mother"
treatment_0    <- "No"
treatment_1    <- "Yes"

mat_gam_0_blackm <- matrix(mm_sate_gam_blackm$y_0, length(val_wtgain), length(val_birth_weight), byrow = TRUE)
mat_gam_1_blackm <- matrix(mm_sate_gam_blackm$y_1, length(val_wtgain), length(val_birth_weight), byrow = TRUE)
mat_gam_1_t_blackm <- matrix(mm_sate_gam_blackm$y_1_t, length(val_wtgain), length(val_birth_weight), byrow = TRUE)
image(val_birth_weight, val_wtgain, mat_gam_0_blackm,
      axes = FALSE, xlab = "", ylab = "", ylim = c(0, 90), xlim = c(1800, 4600),
      col = hcl.colors(20, "YlOrRd", rev = TRUE), breaks=(0:20)/20)
mtext("Weight of the baby (quantile)",side=3,line=3)
mtext(expression("Weight of the baby (" * phantom("non-Black mother") * ")"),
      side=1,line=3, col = "black")
mtext(expression(phantom("Weight of the baby (") * "non-Black mother" * phantom(")")),
      side=1,line=3, col = couleur1)
mtext(expression("Weight gain of the mother (" * phantom("non-Black") * ")"),
      side=2,line=3)
mtext(expression(phantom("Weight gain of the mother (") * "non-Black" * phantom(")")),
      side=2, line=3, col = couleur1)
axis(1)
axis(2)
axis(3, at = quantile(base$birth_weight[pull(base, treatment_name) == treatment_0],
                      seq(10, 90, by = 10)/100), 
     label = paste(seq(10, 90, by = 10), "%"))
axis(4, at = quantile(base$wtgain[pull(base, treatment_name) == treatment_0],
                      seq(10, 90, by = 10)/100), label = paste(seq(10, 90, by = 10), "%"))
contour(val_birth_weight, val_wtgain, mat_gam_0_blackm, add=TRUE)

Figure 8. Contours of \(\boldsymbol{x}\mapsto\mathbb{E}[Y|\boldsymbol{X}=\boldsymbol{x},T=0]\), with \(T=0\) indicating a non-Black mother, estimated with logistic GAM models (cubic splines).

treatment_name <- "sex"
treatment_0    <- "Male"
treatment_1    <- "Female"

mat_gam_0_sex <- matrix(mm_sate_gam_sex$y_0, length(val_wtgain), length(val_birth_weight), byrow = TRUE)
mat_gam_1_sex <- matrix(mm_sate_gam_sex$y_1, length(val_wtgain), length(val_birth_weight), byrow = TRUE)
mat_gam_1_t_sex <- matrix(mm_sate_gam_sex$y_1_t, length(val_wtgain), length(val_birth_weight), byrow = TRUE)
image(val_birth_weight, val_wtgain, mat_gam_0_sex,
      axes = FALSE, xlab = "", ylab = "", ylim = c(0, 90), xlim = c(1800, 4600),
      col = hcl.colors(20, "YlOrRd", rev = TRUE), breaks=(0:20)/20)
mtext("Weight of the baby (quantile)",side=3,line=3)
mtext(expression("Weight of the baby (" * phantom("baby boy") * ")"),
      side=1,line=3, col = "black")
mtext(expression(phantom("Weight of the baby (") * "baby boy" * phantom(")")),
      side=1,line=3, col = couleur1)
mtext(expression("Weight gain of the mother (" * phantom("baby boy") * ")"),
      side=2,line=3)
mtext(expression(phantom("Weight gain of the mother (") * "baby boy" * phantom(")")),
      side=2, line=3, col = couleur1)
axis(1)
axis(2)
axis(3, at = quantile(base$birth_weight[pull(base, treatment_name) == treatment_0],
                      seq(10, 90, by = 10)/100), 
     label = paste(seq(10, 90, by = 10), "%"))
axis(4, at = quantile(base$wtgain[pull(base, treatment_name) == treatment_0],
                      seq(10, 90, by = 10)/100), label = paste(seq(10, 90, by = 10), "%"))
contour(val_birth_weight, val_wtgain, mat_gam_0_sex, add=TRUE)

Figure 9. Contours of \(\boldsymbol{x}\mapsto\mathbb{E}[Y|\boldsymbol{X}=\boldsymbol{x},T=0]\), with \(T=0\) indicating a baby boy, estimated with logistic GAM models (cubic splines).

7.1.2 Contours of \(\boldsymbol{x}\mapsto\mathbb{E}[Y|\boldsymbol{X}=\boldsymbol{x},T=1]\)

Let us focus now on the sub-sample of the treated.

treatment_name <- "cig_rec"
treatment_0    <- "No"
treatment_1    <- "Yes"
image(val_birth_weight, val_wtgain, mat_gam_1_smoker,
      xlab = "", ylab = "", axes = FALSE, ylim = c(0, 90), xlim = c(1800, 4600),
      col = hcl.colors(20, "YlOrRd", rev = TRUE), breaks=(0:20)/20)
mtext("Weight of the baby (quantile)",side=3,line=3)
mtext(expression("Weight of the baby (" * phantom("smoker mother") * ")"),
      side=1,line=3, col = "black")
mtext(expression(phantom("Weight of the baby (") * "smoker mother" * phantom(")")),
      side=1,line=3, col = couleur2)
mtext(expression("Weight gain of the mother (" * phantom("smoker") * ")"),
      side=2,line=3)
mtext(expression(phantom("Weight gain of the mother (") * "smoker" * phantom(")")),
      side=2, line=3, col = couleur2)
axis(1)
axis(2)
axis(3, at = quantile(base$birth_weight[pull(base, treatment_name) == treatment_1],
                      seq(10, 90, by = 10)/100), 
     label = paste(seq(10, 90, by = 10), "%"))
axis(4, at = quantile(base$wtgain[pull(base, treatment_name) == treatment_1],
                      seq(10, 90, by = 10)/100), label = paste(seq(10, 90, by = 10), "%"))
contour(val_birth_weight, val_wtgain, mat_gam_1_smoker, add=TRUE)

Figure 10. Contours of \(\boldsymbol{x}\mapsto\mathbb{E}[Y|\boldsymbol{X}=\boldsymbol{x},T=1]\), with \(T=1\) indicating a smoker mother, estimated with logistic GAM models (cubic splines).

treatment_name <- "black_mother"
treatment_0    <- "No"
treatment_1    <- "Yes"
image(val_birth_weight, val_wtgain, mat_gam_1_blackm,
      xlab = "", ylab = "", axes = FALSE, ylim = c(0, 90), xlim = c(1800, 4600),
      col = hcl.colors(20, "YlOrRd", rev = TRUE), breaks=(0:20)/20)
mtext("Weight of the baby (quantile)",side=3,line=3)
mtext(expression("Weight of the baby (" * phantom("Black mother") * ")"),
      side=1,line=3, col = "black")
mtext(expression(phantom("Weight of the baby (") * "Black mother" * phantom(")")),
      side=1,line=3, col = couleur2)
mtext(expression("Weight gain of the mother (" * phantom("Black") * ")"),
      side=2,line=3)
mtext(expression(phantom("Weight gain of the mother (") * "Black" * phantom(")")),
      side=2, line=3, col = couleur2)
axis(1)
axis(2)
axis(3, at = quantile(base$birth_weight[pull(base, treatment_name) == treatment_1],
                      seq(10, 90, by = 10)/100), 
     label = paste(seq(10, 90, by = 10), "%"))
axis(4, at = quantile(base$wtgain[pull(base, treatment_name) == treatment_1],
                      seq(10, 90, by = 10)/100), label = paste(seq(10, 90, by = 10), "%"))
contour(val_birth_weight, val_wtgain, mat_gam_1_blackm, add=TRUE)

Figure 11. Contours of \(\boldsymbol{x}\mapsto\mathbb{E}[Y|\boldsymbol{X}=\boldsymbol{x},T=1]\), with \(T=1\) indicating a Black mother, estimated with logistic GAM models (cubic splines).

When the treatment \(T\) indicates that the baby is a girl:

treatment_name <- "sex"
treatment_0    <- "Male"
treatment_1    <- "Female"
image(val_birth_weight, val_wtgain, mat_gam_1_sex,
      xlab = "", ylab = "", axes = FALSE, ylim = c(0, 90), xlim = c(1800, 4600),
      col = hcl.colors(20, "YlOrRd", rev = TRUE), breaks=(0:20)/20)
mtext("Weight of the baby (quantile)",side=3,line=3)
mtext(expression("Weight of the baby (" * phantom("baby girl") * ")"),
      side=1,line=3, col = "black")
mtext(expression(phantom("Weight of the baby (") * "baby girl" * phantom(")")),
      side=1,line=3, col = couleur2)
mtext(expression("Weight gain of the mother (" * phantom("baby girl") * ")"),
      side=2,line=3)
mtext(expression(phantom("Weight gain of the mother (") * "baby girl" * phantom(")")),
      side=2, line=3, col = couleur2)
axis(1)
axis(2)
axis(3, at = quantile(base$birth_weight[pull(base, treatment_name) == treatment_1],
                      seq(10, 90, by = 10)/100), 
     label = paste(seq(10, 90, by = 10), "%"))
axis(4, at = quantile(base$wtgain[pull(base, treatment_name) == treatment_1],
                      seq(10, 90, by = 10)/100), label = paste(seq(10, 90, by = 10), "%"))
contour(val_birth_weight, val_wtgain, mat_gam_1_sex, add=TRUE)

Figure 12. Contours of \(\boldsymbol{x}\mapsto\mathbb{E}[Y|\boldsymbol{X}=\boldsymbol{x},T=1]\), with \(T=1\) indicating a baby girl, estimated with logistic GAM models (cubic splines).

7.1.3 Contours of the ceteris paribus \(\boldsymbol{x}\mapsto\text{CATE}[\boldsymbol{x}]\) without any transport

Now, we can turn to the estimation of the treatment effect. We will first compute the CATE without any transport, and then present the Mutatis Mutandis SCATE.

Let us create a palette of colors ranging from dark blue to dark red.

nb_colors <- 18
CLR <- c(hcl.colors(nb_colors, palette = "Blues"), rev(hcl.colors(nb_colors, palette = "Reds")))
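
As a quick check of how this palette is used in the plots that follow, the sketch below rebuilds it together with the breaks passed to `image()` and looks up the colour bin of a few effect sizes. The vector `effects` contains illustrative values only, and `findInterval()` is used here as an approximation of the binning done by `image()`.

```r
nb_colors <- 18
CLR <- c(hcl.colors(nb_colors, palette = "Blues"),
         rev(hcl.colors(nb_colors, palette = "Reds")))

# Same breaks as in the plots: 36 bins of width 0.01 between -0.18 and 0.18
brk <- seq(-nb_colors / 100, nb_colors / 100, length.out = 2 * nb_colors + 1)

# image() needs exactly one more break than colours
length(brk) == length(CLR) + 1

# Colour bin of a few illustrative effect sizes
effects <- c(-0.0536, 0.0005, 0.0167)
bin <- findInterval(effects, brk, rightmost.closed = TRUE)
data.frame(effect = effects, lower = brk[bin], upper = brk[bin + 1],
           colour = CLR[bin])
```

The first 18 colours (blues) cover negative effects and the last 18 (reds) cover positive ones, so the change of hue happens at zero.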
treatment_name <- "cig_rec"
treatment_0    <- "No"
treatment_1    <- "Yes"
image(val_birth_weight, val_wtgain, mat_gam_1_smoker-mat_gam_0_smoker,
      xlab = "",
      ylab = "",
      axes=FALSE, ylim=c(0, 90), xlim = c(1800, 4600),
      col = CLR, breaks = seq(-nb_colors/100, nb_colors/100, length = (2*nb_colors)+1))
mtext(expression("Weight of the baby (" * phantom("smoker mother") * ", no transport)"),
      side=1,line=3, col = "black")
mtext(expression(phantom("Weight of the baby (") * "smoker mother" * phantom(", no transport)")),
      side=1,line=3, col = couleur2)
mtext(expression("Weight gain of the mother (" * phantom("smoker") * ", no transport)"),
      side=2,line=3)
mtext(expression(phantom("Weight gain of the mother (") * "smoker" * phantom(", no transport)")),
      side=2, line=3, col = couleur2)
axis(1)
axis(2)
axis(3, at = quantile(base$birth_weight[pull(base, treatment_name) == treatment_0],
                      seq(10, 90, by = 10)/100), 
     label = paste(seq(10, 90, by = 10), "%"))
axis(4, at = quantile(base$wtgain[pull(base, treatment_name) == treatment_0],
                      seq(10, 90, by = 10)/100), label = paste(seq(10, 90, by = 10), "%"))
contour(val_birth_weight, val_wtgain, mat_gam_1_smoker-mat_gam_0_smoker, add = TRUE,
        levels = seq(-nb_colors/100, nb_colors/100, length = (2*nb_colors)+1))

Figure 13. Contours of the ceteris paribus \(\boldsymbol{x}\mapsto\text{CATE}[\boldsymbol{x}]\) without any transport, with \(T=1\) indicating a smoker mother.

treatment_name <- "black_mother"
treatment_0    <- "No"
treatment_1    <- "Yes"
image(val_birth_weight, val_wtgain, mat_gam_1_blackm - mat_gam_0_blackm,
      xlab = "",
      ylab = "",
      axes=FALSE, ylim=c(0, 90), xlim = c(1800, 4600),
      col = CLR, breaks = seq(-nb_colors/100, nb_colors/100, length = (2*nb_colors)+1))
mtext(expression("Weight of the baby (" * phantom("Black mother") * ", no transport)"),
      side=1,line=3, col = "black")
mtext(expression(phantom("Weight of the baby (") * "Black mother" * phantom(", no transport)")),
      side=1,line=3, col = couleur2)
mtext(expression("Weight gain of the mother (" * phantom("Black") * ", no transport)"),
      side=2,line=3)
mtext(expression(phantom("Weight gain of the mother (") * "Black" * phantom(", no transport)")),
      side=2, line=3, col = couleur2)
axis(1)
axis(2)
axis(3, at = quantile(base$birth_weight[pull(base, treatment_name) == treatment_0],
                      seq(10, 90, by = 10)/100), 
     label = paste(seq(10, 90, by = 10), "%"))
axis(4, at = quantile(base$wtgain[pull(base, treatment_name) == treatment_0],
                      seq(10, 90, by = 10)/100), label = paste(seq(10, 90, by = 10), "%"))
contour(val_birth_weight, val_wtgain, mat_gam_1_blackm-mat_gam_0_blackm, add = TRUE,
        levels = seq(-nb_colors/100, nb_colors/100, length = (2*nb_colors)+1))

Figure 14. Contours of the ceteris paribus \(\boldsymbol{x}\mapsto\text{CATE}[\boldsymbol{x}]\) without any transport, with \(T=1\) indicating a Black mother.

treatment_name <- "sex"
treatment_0    <- "Male"
treatment_1    <- "Female"
image(val_birth_weight, val_wtgain, mat_gam_1_sex - mat_gam_0_sex,
      xlab = "",
      ylab = "",
      axes=FALSE, ylim=c(0, 90), xlim = c(1800, 4600),
      col = CLR, breaks = seq(-nb_colors/100, nb_colors/100, length = (2*nb_colors)+1))
mtext(expression("Weight of the baby (" * phantom("baby girl") * ", no transport)"),
      side=1,line=3, col = "black")
mtext(expression(phantom("Weight of the baby (") * "baby girl" * phantom(", no transport)")),
      side=1,line=3, col = couleur2)
mtext(expression("Weight gain of the mother (" * phantom("girl") * ", no transport)"),
      side=2,line=3)
mtext(expression(phantom("Weight gain of the mother (") * "girl" * phantom(", no transport)")),
      side=2, line=3, col = couleur2)
axis(1)
axis(2)
axis(3, at = quantile(base$birth_weight[pull(base, treatment_name) == treatment_0],
                      seq(10, 90, by = 10)/100), 
     label = paste(seq(10, 90, by = 10), "%"))
axis(4, at = quantile(base$wtgain[pull(base, treatment_name) == treatment_0],
                      seq(10, 90, by = 10)/100), label = paste(seq(10, 90, by = 10), "%"))
contour(val_birth_weight, val_wtgain, mat_gam_1_sex-mat_gam_0_sex, add = TRUE,
        levels = seq(-nb_colors/100, nb_colors/100, length = (2*nb_colors)+1))

Figure 15. Contours of the ceteris paribus \(\boldsymbol{x}\mapsto\text{CATE}[\boldsymbol{x}]\) without any transport, with \(T=1\) indicating a baby girl.

7.1.4 Contours of the mutatis mutandis \(\boldsymbol{x}\mapsto\text{SCATE}[\boldsymbol{x}]\)

treatment_name <- "cig_rec"
treatment_0    <- "No"
treatment_1    <- "Yes"
image(val_birth_weight, val_wtgain, mat_gam_1_t_smoker-mat_gam_0_smoker,
      xlab = "",
      ylab = "",
      axes=FALSE, ylim=c(0, 90), xlim = c(1800, 4600),
      col = CLR, breaks = seq(-nb_colors/100, nb_colors/100, length = (2*nb_colors)+1))
mtext(expression("Weight of the baby (" * phantom("smoker mother") * ", with transport)"),
      side=1,line=3, col = "black")
mtext(expression(phantom("Weight of the baby (") * "smoker mother" * phantom(", with transport)")),
      side=1,line=3, col = couleur2)
mtext(expression("Weight gain of the mother (" * phantom("smoker") * ", w/ transport)"),
      side=2,line=3)
mtext(expression(phantom("Weight gain of the mother (") * "smoker" * phantom(", w/ transport)")),
      side=2, line=3, col = couleur2)
axis(1)
axis(2)
axis(3, at = quantile(base$birth_weight[pull(base, treatment_name) == treatment_0],
                      seq(10, 90, by = 10)/100), 
     label = paste(seq(10, 90, by = 10), "%"))
axis(4, at = quantile(base$wtgain[pull(base, treatment_name) == treatment_0],
                      seq(10, 90, by = 10)/100), label = paste(seq(10, 90, by = 10), "%"))
contour(val_birth_weight, val_wtgain, mat_gam_1_t_smoker-mat_gam_0_smoker, add = TRUE,
        levels = seq(-nb_colors/100, nb_colors/100, length = (2*nb_colors)+1))

Figure 16. Contours of the mutatis mutandis \(\boldsymbol{x}\mapsto\text{SCATE}[\boldsymbol{x}]\), with \(T=1\) indicating a smoker mother.

treatment_name <- "black_mother"
treatment_0    <- "No"
treatment_1    <- "Yes"
image(val_birth_weight, val_wtgain, mat_gam_1_t_blackm-mat_gam_0_blackm,
      xlab = "",
      ylab = "",
      axes=FALSE, ylim=c(0, 90), xlim = c(1800, 4600),
      col = CLR, breaks = seq(-nb_colors/100, nb_colors/100, length = (2*nb_colors)+1))
mtext(expression("Weight of the baby (" * phantom("Black mother") * ", with transport)"),
      side=1,line=3, col = "black")
mtext(expression(phantom("Weight of the baby (") * "Black mother" * phantom(", with transport)")),
      side=1,line=3, col = couleur2)
mtext(expression("Weight gain of the mother (" * phantom("Black") * ", w/ transport)"),
      side=2,line=3)
mtext(expression(phantom("Weight gain of the mother (") * "Black" * phantom(", w/ transport)")),
      side=2, line=3, col = couleur2)
axis(1)
axis(2)
axis(3, at = quantile(base$birth_weight[pull(base, treatment_name) == treatment_0],
                      seq(10, 90, by = 10)/100), 
     label = paste(seq(10, 90, by = 10), "%"))
axis(4, at = quantile(base$wtgain[pull(base, treatment_name) == treatment_0],
                      seq(10, 90, by = 10)/100), label = paste(seq(10, 90, by = 10), "%"))
contour(val_birth_weight, val_wtgain, mat_gam_1_t_blackm-mat_gam_0_blackm, add = TRUE,
        levels = seq(-nb_colors/100, nb_colors/100, length = (2*nb_colors)+1))

Figure 17. Contours of the mutatis mutandis \(\boldsymbol{x}\mapsto\text{SCATE}[\boldsymbol{x}]\), with \(T=1\) indicating a Black mother.

treatment_name <- "sex"
treatment_0    <- "Male"
treatment_1    <- "Female"
image(val_birth_weight, val_wtgain, mat_gam_1_t_sex-mat_gam_0_sex,
      xlab = "",
      ylab = "",
      axes=FALSE, ylim=c(0, 90), xlim = c(1800, 4600),
      col = CLR, breaks = seq(-nb_colors/100, nb_colors/100, length = (2*nb_colors)+1))
mtext(expression("Weight of the baby (" * phantom("baby girl") * ", with transport)"),
      side=1,line=3, col = "black")
mtext(expression(phantom("Weight of the baby (") * "baby girl" * phantom(", with transport)")),
      side=1,line=3, col = couleur2)
mtext(expression("Weight gain of the mother (" * phantom("girl") * ", w/ transport)"),
      side=2,line=3)
mtext(expression(phantom("Weight gain of the mother (") * "girl" * phantom(", w/ transport)")),
      side=2, line=3, col = couleur2)
axis(1)
axis(2)
axis(3, at = quantile(base$birth_weight[pull(base, treatment_name) == treatment_0],
                      seq(10, 90, by = 10)/100), 
     label = paste(seq(10, 90, by = 10), "%"))
axis(4, at = quantile(base$wtgain[pull(base, treatment_name) == treatment_0],
                      seq(10, 90, by = 10)/100), label = paste(seq(10, 90, by = 10), "%"))
contour(val_birth_weight, val_wtgain, mat_gam_1_t_sex-mat_gam_0_sex, add = TRUE,
        levels = seq(-nb_colors/100, nb_colors/100, length = (2*nb_colors)+1))

Figure 18. Contours of the mutatis mutandis \(\boldsymbol{x}\mapsto\text{SCATE}[\boldsymbol{x}]\), with \(T=1\) indicating a baby girl.

7.1.5 Positive vs negative CATE: boundaries, without transport

We can graph a positive/negative distinction for the conditional average treatment effect (positive is red, negative is blue).

treatment_name <- "cig_rec"
treatment_0    <- "No"
treatment_1    <- "Yes"
image(val_birth_weight, val_wtgain, mat_gam_1_smoker-mat_gam_0_smoker,
      xlab = "",
      ylab = "", 
      axes = FALSE, ylim = c(0, 90), xlim = c(1800, 4600),
      col = CLR[c(10, 24)], breaks = c(-5, 0, 5))
mtext(expression("Weight of the baby (" * phantom("smoker mother") * ", no transport)"),
      side=1,line=3, col = "black")
mtext(expression(phantom("Weight of the baby (") * "smoker mother" * phantom(", no transport)")),
      side=1,line=3, col = couleur2)
mtext(expression("Weight gain of the mother (" * phantom("smoker") * ", no transport)"),
      side=2,line=3)
mtext(expression(phantom("Weight gain of the mother (") * "smoker" * phantom(", no transport)")),
      side=2, line=3, col = couleur2)
axis(1)
axis(2)
axis(3, at = quantile(base$birth_weight[pull(base, treatment_name) == treatment_0],
                      seq(10, 90, by = 10)/100), 
     label = paste(seq(10, 90, by = 10), "%"))
axis(4, at = quantile(base$wtgain[pull(base, treatment_name) == treatment_0],
                      seq(10, 90, by = 10)/100), label = paste(seq(10, 90, by = 10), "%"))
contour(val_birth_weight, val_wtgain, mat_gam_1_smoker-mat_gam_0_smoker,
        add = TRUE, levels = c(-1, 0, 1), lwd = 2)

Figure 19. Positive/negative distinction for the conditional average treatment effect, without transport, with \(T=1\) indicating a smoker mother.

treatment_name <- "black_mother"
treatment_0    <- "No"
treatment_1    <- "Yes"
image(val_birth_weight, val_wtgain, mat_gam_1_blackm-mat_gam_0_blackm,
      xlab = "",
      ylab = "", 
      axes = FALSE, ylim = c(0, 90), xlim = c(1800, 4600),
      col = CLR[c(10, 24)], breaks = c(-5, 0, 5))
mtext(expression("Weight of the baby (" * phantom("Black mother") * ", no transport)"),
      side=1,line=3, col = "black")
mtext(expression(phantom("Weight of the baby (") * "Black mother" * phantom(", no transport)")),
      side=1,line=3, col = couleur2)
mtext(expression("Weight gain of the mother (" * phantom("Black") * ", no transport)"),
      side=2,line=3)
mtext(expression(phantom("Weight gain of the mother (") * "Black" * phantom(", no transport)")),
      side=2, line=3, col = couleur2)
axis(1)
axis(2)
axis(3, at = quantile(base$birth_weight[pull(base, treatment_name) == treatment_0],
                      seq(10, 90, by = 10)/100), 
     label = paste(seq(10, 90, by = 10), "%"))
axis(4, at = quantile(base$wtgain[pull(base, treatment_name) == treatment_0],
                      seq(10, 90, by = 10)/100), label = paste(seq(10, 90, by = 10), "%"))
contour(val_birth_weight, val_wtgain, mat_gam_1_blackm-mat_gam_0_blackm,
        add = TRUE, levels = c(-1, 0, 1), lwd = 2)

Figure 20. Positive/negative distinction for the conditional average treatment effect, without transport, with \(T=1\) indicating a Black mother.

treatment_name <- "sex"
treatment_0    <- "Male"
treatment_1    <- "Female"
image(val_birth_weight, val_wtgain, mat_gam_1_sex-mat_gam_0_sex,
      xlab = "",
      ylab = "", 
      axes = FALSE, ylim = c(0, 90), xlim = c(1800, 4600),
      col = CLR[c(10, 24)], breaks = c(-5, 0, 5))
mtext(expression("Weight of the baby (" * phantom("baby girl") * ", no transport)"),
      side=1,line=3, col = "black")
mtext(expression(phantom("Weight of the baby (") * "baby girl" * phantom(", no transport)")),
      side=1,line=3, col = couleur2)
mtext(expression("Weight gain of the mother (" * phantom("girl") * ", no transport)"),
      side=2,line=3)
mtext(expression(phantom("Weight gain of the mother (") * "girl" * phantom(", no transport)")),
      side=2, line=3, col = couleur2)
axis(1)
axis(2)
axis(3, at = quantile(base$birth_weight[pull(base, treatment_name) == treatment_0],
                      seq(10, 90, by = 10)/100), 
     label = paste(seq(10, 90, by = 10), "%"))
axis(4, at = quantile(base$wtgain[pull(base, treatment_name) == treatment_0],
                      seq(10, 90, by = 10)/100), label = paste(seq(10, 90, by = 10), "%"))
contour(val_birth_weight, val_wtgain, mat_gam_1_sex-mat_gam_0_sex,
        add = TRUE, levels = c(-1, 0, 1), lwd = 2)

Figure 21. Positive/negative distinction for the conditional average treatment effect, without transport, with \(T=1\) indicating a baby girl.

7.1.6 Positive vs negative CATE: boundaries, with transport

treatment_name <- "cig_rec"
treatment_0    <- "No"
treatment_1    <- "Yes"
image(val_birth_weight, val_wtgain, mat_gam_1_t_smoker-mat_gam_0_smoker,
      xlab = "",
      ylab = "",
      axes = FALSE, ylim = c(0, 90), xlim = c(1800, 4600),
      col = CLR[c(10, 24)], breaks = c(-5, 0, 5))
mtext(expression("Weight of the baby (" * phantom("smoker mother") * ", with transport)"),
      side=1,line=3, col = "black")
mtext(expression(phantom("Weight of the baby (") * "smoker mother" * phantom(", with transport)")),
      side=1,line=3, col = couleur2)
mtext(expression("Weight gain of the mother (" * phantom("smoker") * ", w/ transport)"),
      side=2,line=3)
mtext(expression(phantom("Weight gain of the mother (") * "smoker" * phantom(", w/ transport)")),
      side=2, line=3, col = couleur2)
axis(1)
axis(2)
axis(3, at = quantile(base$birth_weight[pull(base, treatment_name) == treatment_0],
                      seq(10, 90, by = 10)/100), 
     label = paste(seq(10, 90, by = 10), "%"))
axis(4, at = quantile(base$wtgain[pull(base, treatment_name) == treatment_0],
                      seq(10, 90, by = 10)/100), label = paste(seq(10, 90, by = 10), "%"))
contour(val_birth_weight, val_wtgain, mat_gam_1_t_smoker-mat_gam_0_smoker,
        add = TRUE, levels = c(-1, 0, 1), lwd = 2)

Figure 22. Positive/negative distinction for the conditional average treatment effect, with transport, with \(T=1\) indicating a smoker mother.

treatment_name <- "black_mother"
treatment_0    <- "No"
treatment_1    <- "Yes"
image(val_birth_weight, val_wtgain, mat_gam_1_t_blackm-mat_gam_0_blackm,
      xlab = "",
      ylab = "",
      axes = FALSE, ylim = c(0, 90), xlim = c(1800, 4600),
      col = CLR[c(10, 24)], breaks = c(-5, 0, 5))
mtext(expression("Weight of the baby (" * phantom("Black mother") * ", with transport)"),
      side=1,line=3, col = "black")
mtext(expression(phantom("Weight of the baby (") * "Black mother" * phantom(", with transport)")),
      side=1,line=3, col = couleur2)
mtext(expression("Weight gain of the mother (" * phantom("Black") * ", w/ transport)"),
      side=2,line=3)
mtext(expression(phantom("Weight gain of the mother (") * "Black" * phantom(", w/ transport)")),
      side=2, line=3, col = couleur2)
axis(1)
axis(2)
axis(3, at = quantile(base$birth_weight[pull(base, treatment_name) == treatment_0],
                      seq(10, 90, by = 10)/100), 
     label = paste(seq(10, 90, by = 10), "%"))
axis(4, at = quantile(base$wtgain[pull(base, treatment_name) == treatment_0],
                      seq(10, 90, by = 10)/100), label = paste(seq(10, 90, by = 10), "%"))
contour(val_birth_weight, val_wtgain, mat_gam_1_t_blackm-mat_gam_0_blackm,
        add = TRUE, levels = c(-1, 0, 1), lwd = 2)

Figure 23. Positive/negative distinction for the conditional average treatment effect, with transport, with \(T=1\) indicating a Black mother.

treatment_name <- "sex"
treatment_0    <- "Male"
treatment_1    <- "Female"
image(val_birth_weight, val_wtgain, mat_gam_1_t_sex-mat_gam_0_sex,
      xlab = "",
      ylab = "",
      axes = FALSE, ylim = c(0, 90), xlim = c(1800, 4600),
      col = CLR[c(10, 24)], breaks = c(-5, 0, 5))
mtext(expression("Weight of the baby (" * phantom("baby girl") * ", with transport)"),
      side=1,line=3, col = "black")
mtext(expression(phantom("Weight of the baby (") * "baby girl" * phantom(", with transport)")),
      side=1,line=3, col = couleur2)
mtext(expression("Weight gain of the mother (" * phantom("girl") * ", w/ transport)"),
      side=2,line=3)
mtext(expression(phantom("Weight gain of the mother (") * "girl" * phantom(", w/ transport)")),
      side=2, line=3, col = couleur2)
axis(1)
axis(2)
axis(3, at = quantile(base$birth_weight[pull(base, treatment_name) == treatment_0],
                      seq(10, 90, by = 10)/100), 
     label = paste(seq(10, 90, by = 10), "%"))
axis(4, at = quantile(base$wtgain[pull(base, treatment_name) == treatment_0],
                      seq(10, 90, by = 10)/100), label = paste(seq(10, 90, by = 10), "%"))
contour(val_birth_weight, val_wtgain, mat_gam_1_t_sex-mat_gam_0_sex,
        add = TRUE, levels = c(-1, 0, 1), lwd = 2)

Figure 24. Positive/negative distinction for the conditional average treatment effect, with transport, with \(T=1\) indicating a baby girl.

7.2 GAM (with more knots), Gaussian assumption for transport

7.2.1 Contours of \(\boldsymbol{x}\mapsto\mathbb{E}[Y|\boldsymbol{X}=\boldsymbol{x},T=0]\)

Now, we can plot level curves to visualize the estimated probabilities of the target variable at different points of the 2-dimensional space in which the mediators lie. Let us first focus on the sub-sample of the non-treated.

treatment_name <- "cig_rec"
treatment_0    <- "No"
treatment_1    <- "Yes"

mat_gam_2_0_smoker <- matrix(mm_sate_gam_2_smoker$y_0, length(val_wtgain), length(val_birth_weight), byrow = TRUE)
mat_gam_2_1_smoker <- matrix(mm_sate_gam_2_smoker$y_1, length(val_wtgain), length(val_birth_weight), byrow = TRUE)
mat_gam_2_1_t_smoker <- matrix(mm_sate_gam_2_smoker$y_1_t, length(val_wtgain), length(val_birth_weight), byrow = TRUE)
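
The `matrix(..., byrow = TRUE)` calls above turn a long vector of predictions into the grid that `image()` and `contour()` expect, where `z[i, j]` is the value at `x[i]`, `y[j]`. A minimal sketch on a toy grid follows. It assumes (this is not shown in this section) that the prediction grid has as many birth-weight values as weight-gain values and that weight gain varies fastest in the prediction vector; `x`, `y`, `grid` and `v` are illustrative names.

```r
# Toy axes: x plays the role of val_birth_weight, y of val_wtgain
x <- c(1, 2, 3)
y <- c(10, 20, 30)

# Grid in which y varies fastest (assumption about the ordering)
grid <- expand.grid(y = y, x = x)

# Fake "prediction" that encodes both coordinates, to track positions
v <- grid$x * 100 + grid$y

# Same reshaping pattern as in the appendix
z <- matrix(v, length(y), length(x), byrow = TRUE)

# Row i corresponds to x[i], column j to y[j]
z[2, 3] == x[2] * 100 + y[3]
```

If the two axes had different lengths, the dimensions of `z` would have to be `length(x)` by `length(y)` for `image(x, y, z)` to accept it.
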
image(val_birth_weight, val_wtgain, mat_gam_2_0_smoker,
      axes = FALSE, xlab = "", ylab = "", ylim = c(0, 90), xlim = c(1800, 4600),
      col = hcl.colors(20, "YlOrRd", rev = TRUE), breaks=(0:20)/20)
mtext("Weight of the baby (quantile)",side=3,line=3)
mtext(expression("Weight of the baby (" * phantom("non-smoker mother") * ")"),
      side=1,line=3, col = "black")
mtext(expression(phantom("Weight of the baby (") * "non-smoker mother" * phantom(")")),
      side=1,line=3, col = couleur1)
mtext(expression("Weight gain of the mother (" * phantom("non-smoker") * ")"),
      side=2,line=3)
mtext(expression(phantom("Weight gain of the mother (") * "non-smoker" * phantom(")")),
      side=2, line=3, col = couleur1)
axis(1)
axis(2)
axis(3, at = quantile(base$birth_weight[pull(base, treatment_name) == treatment_0],
                      seq(10, 90, by = 10)/100), 
     label = paste(seq(10, 90, by = 10), "%"))
axis(4, at = quantile(base$wtgain[pull(base, treatment_name) == treatment_0],
                      seq(10, 90, by = 10)/100), label = paste(seq(10, 90, by = 10), "%"))
contour(val_birth_weight, val_wtgain, mat_gam_2_0_smoker, add=TRUE)

Figure 25. Contours of \(\boldsymbol{x}\mapsto\mathbb{E}[Y|\boldsymbol{X}=\boldsymbol{x},T=0]\), with \(T=0\) indicating a non-smoker mother, estimated with logistic GAM models (cubic splines, with more knots).

treatment_name <- "black_mother"
treatment_0    <- "No"
treatment_1    <- "Yes"

mat_gam_2_0_blackm <- matrix(mm_sate_gam_2_blackm$y_0, length(val_wtgain), length(val_birth_weight), byrow = TRUE)
mat_gam_2_1_blackm <- matrix(mm_sate_gam_2_blackm$y_1, length(val_wtgain), length(val_birth_weight), byrow = TRUE)
mat_gam_2_1_t_blackm <- matrix(mm_sate_gam_2_blackm$y_1_t, length(val_wtgain), length(val_birth_weight), byrow = TRUE)
image(val_birth_weight, val_wtgain, mat_gam_2_0_blackm,
      axes = FALSE, xlab = "", ylab = "", ylim = c(0, 90), xlim = c(1800, 4600),
      col = hcl.colors(20, "YlOrRd", rev = TRUE), breaks=(0:20)/20)
mtext("Weight of the baby (quantile)",side=3,line=3)
mtext(expression("Weight of the baby (" * phantom("non-Black mother") * ")"),
      side=1,line=3, col = "black")
mtext(expression(phantom("Weight of the baby (") * "non-Black mother" * phantom(")")),
      side=1,line=3, col = couleur1)
mtext(expression("Weight gain of the mother (" * phantom("non-Black") * ")"),
      side=2,line=3)
mtext(expression(phantom("Weight gain of the mother (") * "non-Black" * phantom(")")),
      side=2, line=3, col = couleur1)
axis(1)
axis(2)
axis(3, at = quantile(base$birth_weight[pull(base, treatment_name) == treatment_0],
                      seq(10, 90, by = 10)/100), 
     label = paste(seq(10, 90, by = 10), "%"))
axis(4, at = quantile(base$wtgain[pull(base, treatment_name) == treatment_0],
                      seq(10, 90, by = 10)/100), label = paste(seq(10, 90, by = 10), "%"))
contour(val_birth_weight, val_wtgain, mat_gam_2_0_blackm, add=TRUE)

Figure 26. Contours of \(\boldsymbol{x}\mapsto\mathbb{E}[Y|\boldsymbol{X}=\boldsymbol{x},T=0]\), with \(T=0\) indicating a non-Black mother, estimated with logistic GAM models (cubic splines, with more knots).

treatment_name <- "sex"
treatment_0    <- "Male"
treatment_1    <- "Female"

mat_gam_2_0_sex <- matrix(mm_sate_gam_2_sex$y_0, length(val_wtgain), length(val_birth_weight), byrow = TRUE)
mat_gam_2_1_sex <- matrix(mm_sate_gam_2_sex$y_1, length(val_wtgain), length(val_birth_weight), byrow = TRUE)
mat_gam_2_1_t_sex <- matrix(mm_sate_gam_2_sex$y_1_t, length(val_wtgain), length(val_birth_weight), byrow = TRUE)
image(val_birth_weight, val_wtgain, mat_gam_2_0_sex,
      axes = FALSE, xlab = "", ylab = "", ylim = c(0, 90), xlim = c(1800, 4600),
      col = hcl.colors(20, "YlOrRd", rev = TRUE), breaks=(0:20)/20)
mtext("Weight of the baby (quantile)",side=3,line=3)
mtext(expression("Weight of the baby (" * phantom("baby boy") * ")"),
      side=1,line=3, col = "black")
mtext(expression(phantom("Weight of the baby (") * "baby boy" * phantom(")")),
      side=1,line=3, col = couleur1)
mtext(expression("Weight gain of the mother (" * phantom("baby boy") * ")"),
      side=2,line=3)
mtext(expression(phantom("Weight gain of the mother (") * "baby boy" * phantom(")")),
      side=2, line=3, col = couleur1)
axis(1)
axis(2)
axis(3, at = quantile(base$birth_weight[pull(base, treatment_name) == treatment_0],
                      seq(10, 90, by = 10)/100), 
     label = paste(seq(10, 90, by = 10), "%"))
axis(4, at = quantile(base$wtgain[pull(base, treatment_name) == treatment_0],
                      seq(10, 90, by = 10)/100), label = paste(seq(10, 90, by = 10), "%"))
contour(val_birth_weight, val_wtgain, mat_gam_2_0_sex, add=TRUE)

Figure 27. Contours of \(\boldsymbol{x}\mapsto\mathbb{E}[Y|\boldsymbol{X}=\boldsymbol{x},T=0]\), with \(T=0\) indicating a baby boy, estimated with logistic GAM models (cubic splines, with more knots).

7.2.2 Contours of \(\boldsymbol{x}\mapsto\mathbb{E}[Y|\boldsymbol{X}=\boldsymbol{x},T=1]\)

Let us focus now on the sub-sample of the treated.
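
In the plots for the treated, the top and right axes mark the deciles of each mediator within the treated sub-sample (`treatment_1`) rather than the non-treated one. A small sketch on simulated data (the data frame `toy` and its values are invented for illustration) shows how such tick positions are obtained:

```r
set.seed(1)

# Simulated stand-in for the birth data: two groups of 500 observations
toy <- data.frame(
  cig_rec      = rep(c("No", "Yes"), each = 500),
  birth_weight = c(rnorm(500, 3300, 500), rnorm(500, 3100, 500))
)

# Deciles from 10% to 90%, as used for the secondary axes
probs <- seq(10, 90, by = 10) / 100
ticks_control <- quantile(toy$birth_weight[toy$cig_rec == "No"],  probs)
ticks_treated <- quantile(toy$birth_weight[toy$cig_rec == "Yes"], probs)

# Tick positions (in grams) for the treated group
round(ticks_treated)
```

Because the two groups have different distributions, the same decile label sits at a different position on the axis depending on which sub-sample is used.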

treatment_name <- "cig_rec"
treatment_0    <- "No"
treatment_1    <- "Yes"
image(val_birth_weight, val_wtgain, mat_gam_2_1_smoker,
      xlab = "", ylab = "", axes = FALSE, ylim = c(0, 90), xlim = c(1800, 4600),
      col = hcl.colors(20, "YlOrRd", rev = TRUE), breaks=(0:20)/20)
mtext("Weight of the baby (quantile)",side=3,line=3)
mtext(expression("Weight of the baby (" * phantom("smoker mother") * ")"),
      side=1,line=3, col = "black")
mtext(expression(phantom("Weight of the baby (") * "smoker mother" * phantom(")")),
      side=1,line=3, col = couleur2)
mtext(expression("Weight gain of the mother (" * phantom("smoker") * ")"),
      side=2,line=3)
mtext(expression(phantom("Weight gain of the mother (") * "smoker" * phantom(")")),
      side=2, line=3, col = couleur2)
axis(1)
axis(2)
axis(3, at = quantile(base$birth_weight[pull(base, treatment_name) == treatment_1],
                      seq(10, 90, by = 10)/100), 
     label = paste(seq(10, 90, by = 10), "%"))
axis(4, at = quantile(base$wtgain[pull(base, treatment_name) == treatment_1],
                      seq(10, 90, by = 10)/100), label = paste(seq(10, 90, by = 10), "%"))
contour(val_birth_weight, val_wtgain, mat_gam_2_1_smoker, add=TRUE)

Figure 28. Contours of \(\boldsymbol{x}\mapsto\mathbb{E}[Y|\boldsymbol{X}=\boldsymbol{x},T=1]\), with \(T=1\) indicating a smoker mother, estimated with logistic GAM models (cubic splines, with more knots).

treatment_name <- "black_mother"
treatment_0    <- "No"
treatment_1    <- "Yes"
image(val_birth_weight, val_wtgain, mat_gam_2_1_blackm,
      xlab = "", ylab = "", axes = FALSE, ylim = c(0, 90), xlim = c(1800, 4600),
      col = hcl.colors(20, "YlOrRd", rev = TRUE), breaks=(0:20)/20)
mtext("Weight of the baby (quantile)",side=3,line=3)
mtext(expression("Weight of the baby (" * phantom("Black mother") * ")"),
      side=1,line=3, col = "black")
mtext(expression(phantom("Weight of the baby (") * "Black mother" * phantom(")")),
      side=1,line=3, col = couleur2)
mtext(expression("Weight gain of the mother (" * phantom("Black") * ")"),
      side=2,line=3)
mtext(expression(phantom("Weight gain of the mother (") * "Black" * phantom(")")),
      side=2, line=3, col = couleur2)
axis(1)
axis(2)
axis(3, at = quantile(base$birth_weight[pull(base, treatment_name) == treatment_1],
                      seq(10, 90, by = 10)/100), 
     label = paste(seq(10, 90, by = 10), "%"))
axis(4, at = quantile(base$wtgain[pull(base, treatment_name) == treatment_1],
                      seq(10, 90, by = 10)/100), label = paste(seq(10, 90, by = 10), "%"))
contour(val_birth_weight, val_wtgain, mat_gam_2_1_blackm, add=TRUE)

Figure 29. Contours of \(\boldsymbol{x}\mapsto\mathbb{E}[Y|\boldsymbol{X}=\boldsymbol{x},T=1]\), with \(T=1\) indicating a Black mother, estimated with logistic GAM models (cubic splines, with more knots).

When the treatment \(T\) indicates that the baby is a girl:

treatment_name <- "sex"
treatment_0    <- "Male"
treatment_1    <- "Female"
image(val_birth_weight, val_wtgain, mat_gam_2_1_sex,
      xlab = "", ylab = "", axes = FALSE, ylim = c(0, 90), xlim = c(1800, 4600),
      col = hcl.colors(20, "YlOrRd", rev = TRUE), breaks=(0:20)/20)
mtext("Weight of the baby (quantile)",side=3,line=3)
mtext(expression("Weight of the baby (" * phantom("baby girl") * ")"),
      side=1,line=3, col = "black")
mtext(expression(phantom("Weight of the baby (") * "baby girl" * phantom(")")),
      side=1,line=3, col = couleur2)
mtext(expression("Weight gain of the mother (" * phantom("baby girl") * ")"),
      side=2,line=3)
mtext(expression(phantom("Weight gain of the mother (") * "baby girl" * phantom(")")),
      side=2, line=3, col = couleur2)
axis(1)
axis(2)
axis(3, at = quantile(base$birth_weight[pull(base, treatment_name) == treatment_1],
                      seq(10, 90, by = 10)/100), 
     label = paste(seq(10, 90, by = 10), "%"))
axis(4, at = quantile(base$wtgain[pull(base, treatment_name) == treatment_1],
                      seq(10, 90, by = 10)/100), label = paste(seq(10, 90, by = 10), "%"))
contour(val_birth_weight, val_wtgain, mat_gam_2_1_sex, add=TRUE)

Figure 30. Contours of \(\boldsymbol{x}\mapsto\mathbb{E}[Y|\boldsymbol{X}=\boldsymbol{x},T=1]\), with \(T=1\) indicating a baby girl, estimated with logistic GAM models (cubic splines, with more knots).

7.2.3 Contours of the ceteris paribus \(\boldsymbol{x}\mapsto\text{CATE}[\boldsymbol{x}]\) without any transport

Now, we can turn to the estimation of the treatment effect. We will first compute the CATE without any transport, and then present the mutatis mutandis SCATE.
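
To make the distinction concrete: the ceteris paribus CATE compares the two fitted surfaces at the same point \(\boldsymbol{x}\), whereas the mutatis mutandis SCATE evaluates the treated surface at the transported point, as the `y_1_t` column suggests. The sketch below illustrates this with the standard closed-form optimal transport map between two Gaussian distributions, which is presumably what the Gaussian assumption of this section refers to. The means, covariances and the two response functions `m0` and `m1` are invented for illustration and are not the fitted GAMs of this appendix.

```r
# Symmetric square root of a positive-definite matrix
sqrtm_sym <- function(S) {
  e <- eigen(S, symmetric = TRUE)
  e$vectors %*% diag(sqrt(e$values), nrow(S)) %*% t(e$vectors)
}

# Made-up Gaussian parameters for (birth_weight, wtgain) in each group
mu0 <- c(3300, 30); S0 <- matrix(c(250000, 1500, 1500, 200), 2)
mu1 <- c(3100, 28); S1 <- matrix(c(260000, 1200, 1200, 220), 2)

# Closed-form Gaussian transport map: T(x) = mu1 + A (x - mu0)
S0_h  <- sqrtm_sym(S0)
S0_ih <- solve(S0_h)
A <- S0_ih %*% sqrtm_sym(S0_h %*% S1 %*% S0_h) %*% S0_ih
transport <- function(x) as.vector(mu1 + A %*% (x - mu0))

# Hypothetical response surfaces (probabilities)
m0 <- function(x) plogis(-2.0 + 0.0002 * (3300 - x[1]) - 0.01 * x[2])
m1 <- function(x) plogis(-1.8 + 0.0002 * (3300 - x[1]) - 0.01 * x[2])

x <- c(3000, 25)
cate  <- m1(x) - m0(x)             # same x in both groups
scate <- m1(transport(x)) - m0(x)  # treated surface at the transported x
c(CATE = cate, SCATE = scate)
```

By construction, the map sends the mean of the non-treated group to the mean of the treated group and reproduces the treated covariance, so that `A %*% S0 %*% t(A)` equals `S1` up to numerical error.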

Let us create a palette of colors ranging from dark blue to dark red.

nb_colors <- 18
CLR <- c(hcl.colors(nb_colors, palette = "Blues"), rev(hcl.colors(nb_colors, palette = "Reds")))
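
With `nb_colors <- 18`, the palette has 36 colours, and the `breaks` used in the plots below form 37 equally spaced cut points from -0.18 to 0.18. Each colour therefore covers an interval of width 0.01, and the switch from blue to red happens at zero. A quick check of this mapping (the helper `color_of` is hypothetical and only mimics the binning with `findInterval()`):

```r
nb_colors <- 18
CLR <- c(hcl.colors(nb_colors, palette = "Blues"),
         rev(hcl.colors(nb_colors, palette = "Reds")))
brks <- seq(-nb_colors/100, nb_colors/100, length.out = (2*nb_colors) + 1)

# Hypothetical helper: index of the colour assigned to an effect value,
# using intervals that are open on the left and closed on the right
color_of <- function(v) {
  findInterval(v, brks, rightmost.closed = TRUE, left.open = TRUE)
}

length(CLR)        # 36 colours
length(brks)       # 37 break points
color_of(-0.005)   # 18: last blue
color_of(0.005)    # 19: first red
```
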
treatment_name <- "cig_rec"
treatment_0    <- "No"
treatment_1    <- "Yes"
image(val_birth_weight, val_wtgain, mat_gam_2_1_smoker-mat_gam_2_0_smoker,
      xlab = "",
      ylab = "",
      axes=FALSE, ylim=c(0, 90), xlim = c(1800, 4600),
      col = CLR, breaks = seq(-nb_colors/100, nb_colors/100, length = (2*nb_colors)+1))
mtext(expression("Weight of the baby (" * phantom("smoker mother") * ", no transport)"),
      side=1,line=3, col = "black")
mtext(expression(phantom("Weight of the baby (") * "smoker mother" * phantom(", no transport)")),
      side=1,line=3, col = couleur2)
mtext(expression("Weight gain of the mother (" * phantom("smoker") * ", no transport)"),
      side=2,line=3)
mtext(expression(phantom("Weight gain of the mother (") * "smoker" * phantom(", no transport)")),
      side=2, line=3, col = couleur2)
axis(1)
axis(2)
axis(3, at = quantile(base$birth_weight[pull(base, treatment_name) == treatment_0],
                      seq(10, 90, by = 10)/100), 
     label = paste(seq(10, 90, by = 10), "%"))
axis(4, at = quantile(base$wtgain[pull(base, treatment_name) == treatment_0],
                      seq(10, 90, by = 10)/100), label = paste(seq(10, 90, by = 10), "%"))
contour(val_birth_weight, val_wtgain, mat_gam_2_1_smoker-mat_gam_2_0_smoker, add = TRUE,
        levels = seq(-nb_colors/100, nb_colors/100, length = (2*nb_colors)+1))

Figure 31. Contours of the ceteris paribus \(\boldsymbol{x}\mapsto\text{CATE}[\boldsymbol{x}]\) (GAM with more knots) without any transport, with \(T=1\) indicating a smoker mother.

treatment_name <- "black_mother"
treatment_0    <- "No"
treatment_1    <- "Yes"
image(val_birth_weight, val_wtgain, mat_gam_2_1_blackm - mat_gam_2_0_blackm,
      xlab = "",
      ylab = "",
      axes=FALSE, ylim=c(0, 90), xlim = c(1800, 4600),
      col = CLR, breaks = seq(-nb_colors/100, nb_colors/100, length = (2*nb_colors)+1))
mtext(expression("Weight of the baby (" * phantom("Black mother") * ", no transport)"),
      side=1,line=3, col = "black")
mtext(expression(phantom("Weight of the baby (") * "Black mother" * phantom(", no transport)")),
      side=1,line=3, col = couleur2)
mtext(expression("Weight gain of the mother (" * phantom("Black") * ", no transport)"),
      side=2,line=3)
mtext(expression(phantom("Weight gain of the mother (") * "Black" * phantom(", no transport)")),
      side=2, line=3, col = couleur2)
axis(1)
axis(2)
axis(3, at = quantile(base$birth_weight[pull(base, treatment_name) == treatment_0],
                      seq(10, 90, by = 10)/100), 
     label = paste(seq(10, 90, by = 10), "%"))
axis(4, at = quantile(base$wtgain[pull(base, treatment_name) == treatment_0],
                      seq(10, 90, by = 10)/100), label = paste(seq(10, 90, by = 10), "%"))
contour(val_birth_weight, val_wtgain, mat_gam_2_1_blackm-mat_gam_2_0_blackm, add = TRUE,
        levels = seq(-nb_colors/100, nb_colors/100, length = (2*nb_colors)+1))

Figure 32. Contours of the ceteris paribus \(\boldsymbol{x}\mapsto\text{CATE}[\boldsymbol{x}]\) (GAM with more knots) without any transport, with \(T=1\) indicating a Black mother.

treatment_name <- "sex"
treatment_0    <- "Male"
treatment_1    <- "Female"
image(val_birth_weight, val_wtgain, mat_gam_2_1_sex - mat_gam_2_0_sex,
      xlab = "",
      ylab = "",
      axes=FALSE, ylim=c(0, 90), xlim = c(1800, 4600),
      col = CLR, breaks = seq(-nb_colors/100, nb_colors/100, length = (2*nb_colors)+1))
mtext(expression("Weight of the baby (" * phantom("baby girl") * ", no transport)"),
      side=1,line=3, col = "black")
mtext(expression(phantom("Weight of the baby (") * "baby girl" * phantom(", no transport)")),
      side=1,line=3, col = couleur2)
mtext(expression("Weight gain of the mother (" * phantom("girl") * ", no transport)"),
      side=2,line=3)
mtext(expression(phantom("Weight gain of the mother (") * "girl" * phantom(", no transport)")),
      side=2, line=3, col = couleur2)
axis(1)
axis(2)
axis(3, at = quantile(base$birth_weight[pull(base, treatment_name) == treatment_0],
                      seq(10, 90, by = 10)/100), 
     label = paste(seq(10, 90, by = 10), "%"))
axis(4, at = quantile(base$wtgain[pull(base, treatment_name) == treatment_0],
                      seq(10, 90, by = 10)/100), label = paste(seq(10, 90, by = 10), "%"))
contour(val_birth_weight, val_wtgain, mat_gam_2_1_sex-mat_gam_2_0_sex, add = TRUE,
        levels = seq(-nb_colors/100, nb_colors/100, length = (2*nb_colors)+1))

Figure 33. Contours of the ceteris paribus \(\boldsymbol{x}\mapsto\text{CATE}[\boldsymbol{x}]\) (GAM with more knots) without any transport, with \(T=1\) indicating a baby girl.

7.2.4 Contours of the mutatis mutandis \(\boldsymbol{x}\mapsto\text{SCATE}[\boldsymbol{x}]\)

treatment_name <- "cig_rec"
treatment_0    <- "No"
treatment_1    <- "Yes"
image(val_birth_weight, val_wtgain, mat_gam_2_1_t_smoker-mat_gam_2_0_smoker,
      xlab = "",
      ylab = "",
      axes=FALSE, ylim=c(0, 90), xlim = c(1800, 4600),
      col = CLR, breaks = seq(-nb_colors/100, nb_colors/100, length = (2*nb_colors)+1))
mtext(expression("Weight of the baby (" * phantom("smoker mother") * ", with transport)"),
      side=1,line=3, col = "black")
mtext(expression(phantom("Weight of the baby (") * "smoker mother" * phantom(", with transport)")),
      side=1,line=3, col = couleur2)
mtext(expression("Weight gain of the mother (" * phantom("smoker") * ", w/ transport)"),
      side=2,line=3)
mtext(expression(phantom("Weight gain of the mother (") * "smoker" * phantom(", w/ transport)")),
      side=2, line=3, col = couleur2)
axis(1)
axis(2)
axis(3, at = quantile(base$birth_weight[pull(base, treatment_name) == treatment_0],
                      seq(10, 90, by = 10)/100), 
     label = paste(seq(10, 90, by = 10), "%"))
axis(4, at = quantile(base$wtgain[pull(base, treatment_name) == treatment_0],
                      seq(10, 90, by = 10)/100), label = paste(seq(10, 90, by = 10), "%"))
contour(val_birth_weight, val_wtgain, mat_gam_2_1_t_smoker-mat_gam_2_0_smoker, add = TRUE,
        levels = seq(-nb_colors/100, nb_colors/100, length = (2*nb_colors)+1))

Figure 34. Contours of the mutatis mutandis \(\boldsymbol{x}\mapsto\text{SCATE}[\boldsymbol{x}]\) (GAM with more knots), with \(T=1\) indicating a smoker mother.

treatment_name <- "black_mother"
treatment_0    <- "No"
treatment_1    <- "Yes"
image(val_birth_weight, val_wtgain, mat_gam_2_1_t_blackm-mat_gam_2_0_blackm,
      xlab = "",
      ylab = "",
      axes=FALSE, ylim=c(0, 90), xlim = c(1800, 4600),
      col = CLR, breaks = seq(-nb_colors/100, nb_colors/100, length = (2*nb_colors)+1))
mtext(expression("Weight of the baby (" * phantom("Black mother") * ", with transport)"),
      side=1,line=3, col = "black")
mtext(expression(phantom("Weight of the baby (") * "Black mother" * phantom(", with transport)")),
      side=1,line=3, col = couleur2)
mtext(expression("Weight gain of the mother (" * phantom("Black") * ", w/ transport)"),
      side=2,line=3)
mtext(expression(phantom("Weight gain of the mother (") * "Black" * phantom(", w/ transport)")),
      side=2, line=3, col = couleur2)
axis(1)
axis(2)
axis(3, at = quantile(base$birth_weight[pull(base, treatment_name) == treatment_0],
                      seq(10, 90, by = 10)/100), 
     label = paste(seq(10, 90, by = 10), "%"))
axis(4, at = quantile(base$wtgain[pull(base, treatment_name) == treatment_0],
                      seq(10, 90, by = 10)/100), label = paste(seq(10, 90, by = 10), "%"))
contour(val_birth_weight, val_wtgain, mat_gam_2_1_t_blackm-mat_gam_2_0_blackm, add = TRUE,
        levels = seq(-nb_colors/100, nb_colors/100, length = (2*nb_colors)+1))

Figure 35. Contours of the mutatis mutandis \(\boldsymbol{x}\mapsto\text{SCATE}[\boldsymbol{x}]\) (GAM with more knots), with \(T=1\) indicating a Black mother.

treatment_name <- "sex"
treatment_0    <- "Male"
treatment_1    <- "Female"
image(val_birth_weight, val_wtgain, mat_gam_2_1_t_sex-mat_gam_2_0_sex,
      xlab = "",
      ylab = "",
      axes=FALSE, ylim=c(0, 90), xlim = c(1800, 4600),
      col = CLR, breaks = seq(-nb_colors/100, nb_colors/100, length = (2*nb_colors)+1))
mtext(expression("Weight of the baby (" * phantom("baby girl") * ", with transport)"),
      side=1,line=3, col = "black")
mtext(expression(phantom("Weight of the baby (") * "baby girl" * phantom(", with transport)")),
      side=1,line=3, col = couleur2)
mtext(expression("Weight gain of the mother (" * phantom("girl") * ", w/ transport)"),
      side=2,line=3)
mtext(expression(phantom("Weight gain of the mother (") * "girl" * phantom(", w/ transport)")),
      side=2, line=3, col = couleur2)
axis(1)
axis(2)
axis(3, at = quantile(base$birth_weight[pull(base, treatment_name) == treatment_0],
                      seq(10, 90, by = 10)/100), 
     label = paste(seq(10, 90, by = 10), "%"))
axis(4, at = quantile(base$wtgain[pull(base, treatment_name) == treatment_0],
                      seq(10, 90, by = 10)/100), label = paste(seq(10, 90, by = 10), "%"))
contour(val_birth_weight, val_wtgain, mat_gam_2_1_t_sex-mat_gam_2_0_sex, add = TRUE,
        levels = seq(-nb_colors/100, nb_colors/100, length = (2*nb_colors)+1))

Figure 36. Contours of the mutatis mutandis \(\boldsymbol{x}\mapsto\text{SCATE}[\boldsymbol{x}]\) (GAM with more knots), with \(T=1\) indicating a baby girl.

7.2.5 Positive vs negative CATE: boundaries, without transport

We can graph a positive/negative distinction for the conditional average treatment effect (positive is red, negative is blue).
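
The thick contour at level 0 in these maps is the boundary between the regions of positive and negative estimated effect. A minimal sketch, on an invented effect surface (not the estimates of this appendix), of how that boundary could be extracted as coordinates with `grDevices::contourLines()`:

```r
# Toy grid covering the same ranges as the plots
x <- seq(1800, 4600, length.out = 50)   # stands in for val_birth_weight
y <- seq(0, 90, length.out = 50)        # stands in for val_wtgain

# Invented effect surface: negative for low values, positive for high ones
z <- outer(x, y, function(bw, wg) (bw - 3200) / 10000 + (wg - 45) / 2000)

# Coordinates of the zero-level curve(s)
boundary <- contourLines(x, y, z, levels = 0)
length(boundary)
head(cbind(x = boundary[[1]]$x, y = boundary[[1]]$y))

# Share of grid cells with a positive effect
mean(z > 0)
```

Each element of `boundary` is a list with the contour `level` and the `x` and `y` coordinates of one connected curve, which makes it possible to compare the boundaries obtained with and without transport numerically rather than by eye.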

treatment_name <- "cig_rec"
treatment_0    <- "No"
treatment_1    <- "Yes"
image(val_birth_weight, val_wtgain, mat_gam_2_1_smoker-mat_gam_2_0_smoker,
      xlab = "",
      ylab = "", 
      axes = FALSE, ylim = c(0, 90), xlim = c(1800, 4600),
      col = CLR[c(10, 24)], breaks = c(-5, 0, 5))
mtext(expression("Weight of the baby (" * phantom("smoker mother") * ", no transport)"),
      side=1,line=3, col = "black")
mtext(expression(phantom("Weight of the baby (") * "smoker mother" * phantom(", no transport)")),
      side=1,line=3, col = couleur2)
mtext(expression("Weight gain of the mother (" * phantom("smoker") * ", no transport)"),
      side=2,line=3)
mtext(expression(phantom("Weight gain of the mother (") * "smoker" * phantom(", no transport)")),
      side=2, line=3, col = couleur2)
axis(1)
axis(2)
axis(3, at = quantile(base$birth_weight[pull(base, treatment_name) == treatment_0],
                      seq(10, 90, by = 10)/100), 
     label = paste(seq(10, 90, by = 10), "%"))
axis(4, at = quantile(base$wtgain[pull(base, treatment_name) == treatment_0],
                      seq(10, 90, by = 10)/100), label = paste(seq(10, 90, by = 10), "%"))
contour(val_birth_weight, val_wtgain, mat_gam_2_1_smoker-mat_gam_2_0_smoker,
        add = TRUE, levels = c(-1, 0, 1), lwd = 2)

Figure 37. Positive/negative distinction for the conditional average treatment effect (GAM with more knots), without transport, with \(T=1\) indicating a smoker mother.

treatment_name <- "black_mother"
treatment_0    <- "No"
treatment_1    <- "Yes"
image(val_birth_weight, val_wtgain, mat_gam_2_1_blackm-mat_gam_2_0_blackm,
      xlab = "",
      ylab = "", 
      axes = FALSE, ylim = c(0, 90), xlim = c(1800, 4600),
      col = CLR[c(10, 24)], breaks = c(-5, 0, 5))
mtext(expression("Weight of the baby (" * phantom("Black mother") * ", no transport)"),
      side=1,line=3, col = "black")
mtext(expression(phantom("Weight of the baby (") * "Black mother" * phantom(", no transport)")),
      side=1,line=3, col = couleur2)
mtext(expression("Weight gain of the mother (" * phantom("Black") * ", no transport)"),
      side=2,line=3)
mtext(expression(phantom("Weight gain of the mother (") * "Black" * phantom(", no transport)")),
      side=2, line=3, col = couleur2)
axis(1)
axis(2)
axis(3, at = quantile(base$birth_weight[pull(base, treatment_name) == treatment_0],
                      seq(10, 90, by = 10)/100), 
     label = paste(seq(10, 90, by = 10), "%"))
axis(4, at = quantile(base$wtgain[pull(base, treatment_name) == treatment_0],
                      seq(10, 90, by = 10)/100), label = paste(seq(10, 90, by = 10), "%"))
contour(val_birth_weight, val_wtgain, mat_gam_2_1_blackm-mat_gam_2_0_blackm,
        add = TRUE, levels = c(-1, 0, 1), lwd = 2)

Figure 38. Positive/negative distinction for the conditional average treatment effect (GAM with more knots), without transport, with \(T=1\) indicating a Black mother.

treatment_name <- "sex"
treatment_0    <- "Male"
treatment_1    <- "Female"
image(val_birth_weight, val_wtgain, mat_gam_2_1_sex-mat_gam_2_0_sex,
      xlab = "",
      ylab = "", 
      axes = FALSE, ylim = c(0, 90), xlim = c(1800, 4600),
      col = CLR[c(10, 24)], breaks = c(-5, 0, 5))
mtext(expression("Weight of the baby (" * phantom("baby girl") * ", no transport)"),
      side=1,line=3, col = "black")
mtext(expression(phantom("Weight of the baby (") * "baby girl" * phantom(", no transport)")),
      side=1,line=3, col = couleur2)
mtext(expression("Weight gain of the mother (" * phantom("girl") * ", no transport)"),
      side=2,line=3)
mtext(expression(phantom("Weight gain of the mother (") * "girl" * phantom(", no transport)")),
      side=2, line=3, col = couleur2)
axis(1)
axis(2)
# Top and right axes: deciles of each variable in the treatment_0 group
axis(3, at = quantile(base$birth_weight[pull(base, treatment_name) == treatment_0],
                      seq(10, 90, by = 10)/100),
     labels = paste(seq(10, 90, by = 10), "%"))
axis(4, at = quantile(base$wtgain[pull(base, treatment_name) == treatment_0],
                      seq(10, 90, by = 10)/100),
     labels = paste(seq(10, 90, by = 10), "%"))
contour(val_birth_weight, val_wtgain, mat_gam_2_1_sex-mat_gam_2_0_sex,
        add = TRUE, levels = c(-1, 0, 1), lwd = 2)

Figure 39. Positive/negative distinction for the conditional average treatment effect (GAM with more knots), without transport; treatment: baby girl.
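The two-colour maps above depend on the argument `breaks = c(-5, 0, 5)`. The sketch below illustrates, on a toy matrix with made-up values (not taken from the fitted GAMs), how such breaks split a surface into a negative and a positive class. It assumes that `image()` bins values with right-closed intervals and includes the lowest break, which `cut()` is used to mimic here. Under that assumption, any difference outside [-5, 5] falls in no class and its cell is left blank.

```r
# Toy CATE differences (hypothetical values, for illustration only)
delta  <- matrix(c(-6, -1, 0, 0.5, 4, 7), nrow = 2)
breaks <- c(-5, 0, 5)

# Right-closed bins with the lowest break included:
# [-5, 0] -> 1 (first colour), (0, 5] -> 2 (second colour), outside -> NA
bin <- cut(as.vector(delta), breaks = breaks,
           include.lowest = TRUE, labels = FALSE)
bin

# Share of cells that fall outside the breaks and would not be coloured
share_blank <- mean(is.na(bin))
share_blank
```

Running the same check on a matrix such as `mat_gam_2_1_sex - mat_gam_2_0_sex` is one way to verify that the chosen breaks cover the whole estimated surface.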

7.2.6 Positive vs negative CATE: boundaries, with transport

Show the code used to create the plot
treatment_name <- "cig_rec"
treatment_0    <- "No"
treatment_1    <- "Yes"
image(val_birth_weight, val_wtgain, mat_gam_2_1_t_smoker-mat_gam_2_0_smoker,
      xlab = "",
      ylab = "",
      axes = FALSE, ylim = c(0, 90), xlim = c(1800, 4600),
      col = CLR[c(10, 24)], breaks = c(-5, 0, 5))
mtext(expression("Weight of the baby (" * phantom("smoker mother") * ", with transport)"),
      side=1,line=3, col = "black")
mtext(expression(phantom("Weight of the baby (") * "smoker mother" * phantom(", with transport)")),
      side=1,line=3, col = couleur2)
mtext(expression("Weight gain of the mother (" * phantom("smoker") * ", w/ transport)"),
      side=2,line=3)
mtext(expression(phantom("Weight gain of the mother (") * "smoker" * phantom(", w/ transport)")),
      side=2, line=3, col = couleur2)
axis(1)
axis(2)
# Top and right axes: deciles of each variable in the treatment_0 group
axis(3, at = quantile(base$birth_weight[pull(base, treatment_name) == treatment_0],
                      seq(10, 90, by = 10)/100),
     labels = paste(seq(10, 90, by = 10), "%"))
axis(4, at = quantile(base$wtgain[pull(base, treatment_name) == treatment_0],
                      seq(10, 90, by = 10)/100),
     labels = paste(seq(10, 90, by = 10), "%"))
contour(val_birth_weight, val_wtgain, mat_gam_2_1_t_smoker-mat_gam_2_0_smoker,
        add = TRUE, levels = c(-1, 0, 1), lwd = 2)

Figure 40. Positive/negative distinction for the conditional average treatment effect (GAM with more knots), with transport; treatment: smoker mother.

Show the code used to create the plot
treatment_name <- "black_mother"
treatment_0    <- "No"
treatment_1    <- "Yes"
image(val_birth_weight, val_wtgain, mat_gam_2_1_t_blackm-mat_gam_2_0_blackm,
      xlab = "",
      ylab = "",
      axes = FALSE, ylim = c(0, 90), xlim = c(1800, 4600),
      col = CLR[c(10, 24)], breaks = c(-5, 0, 5))
mtext(expression("Weight of the baby (" * phantom("Black mother") * ", with transport)"),
      side=1,line=3, col = "black")
mtext(expression(phantom("Weight of the baby (") * "Black mother" * phantom(", with transport)")),
      side=1,line=3, col = couleur2)
mtext(expression("Weight gain of the mother (" * phantom("Black") * ", w/ transport)"),
      side=2,line=3)
mtext(expression(phantom("Weight gain of the mother (") * "Black" * phantom(", w/ transport)")),
      side=2, line=3, col = couleur2)
axis(1)
axis(2)
# Top and right axes: deciles of each variable in the treatment_0 group
axis(3, at = quantile(base$birth_weight[pull(base, treatment_name) == treatment_0],
                      seq(10, 90, by = 10)/100),
     labels = paste(seq(10, 90, by = 10), "%"))
axis(4, at = quantile(base$wtgain[pull(base, treatment_name) == treatment_0],
                      seq(10, 90, by = 10)/100),
     labels = paste(seq(10, 90, by = 10), "%"))
contour(val_birth_weight, val_wtgain, mat_gam_2_1_t_blackm-mat_gam_2_0_blackm,
        add = TRUE, levels = c(-1, 0, 1), lwd = 2)

Figure 41. Positive/negative distinction for the conditional average treatment effect (GAM with more knots), with transport; treatment: Black mother.

Show the code used to create the plot
treatment_name <- "sex"
treatment_0    <- "Male"
treatment_1    <- "Female"
image(val_birth_weight, val_wtgain, mat_gam_2_1_t_sex-mat_gam_2_0_sex,
      xlab = "",
      ylab = "",
      axes = FALSE, ylim = c(0, 90), xlim = c(1800, 4600),
      col = CLR[c(10, 24)], breaks = c(-5, 0, 5))
mtext(expression("Weight of the baby (" * phantom("baby girl") * ", with transport)"),
      side=1,line=3, col = "black")
mtext(expression(phantom("Weight of the baby (") * "baby girl" * phantom(", with transport)")),
      side=1,line=3, col = couleur2)
mtext(expression("Weight gain of the mother (" * phantom("girl") * ", w/ transport)"),
      side=2,line=3)
mtext(expression(phantom("Weight gain of the mother (") * "girl" * phantom(", w/ transport)")),
      side=2, line=3, col = couleur2)
axis(1)
axis(2)
# Top and right axes: deciles of each variable in the treatment_0 group
axis(3, at = quantile(base$birth_weight[pull(base, treatment_name) == treatment_0],
                      seq(10, 90, by = 10)/100),
     labels = paste(seq(10, 90, by = 10), "%"))
axis(4, at = quantile(base$wtgain[pull(base, treatment_name) == treatment_0],
                      seq(10, 90, by = 10)/100),
     labels = paste(seq(10, 90, by = 10), "%"))
contour(val_birth_weight, val_wtgain, mat_gam_2_1_t_sex-mat_gam_2_0_sex,
        add = TRUE, levels = c(-1, 0, 1), lwd = 2)

Figure 42. Positive/negative distinction for the conditional average treatment effect (GAM with more knots), with transport; treatment: baby girl.
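The plotting code in this section repeats the same sequence of calls for each treatment. As a possible simplification, the sketch below wraps that sequence in a helper. The function `plot_cate_sign()`, its arguments, the default colours, and the toy data are all hypothetical and not part of the original code. It also differs from the plots above in three ways: the breaks are set from the range of the data rather than fixed at -5 and 5, only the zero contour is drawn, and the coloured axis labels built with `phantom()` are omitted.

```r
# Hypothetical helper, not part of the original appendix.
# x, y         : increasing grid values (birth weight, weight gain)
# cate         : length(x) by length(y) matrix of estimated differences
# x_obs, y_obs : observed values used to place the decile ticks
plot_cate_sign <- function(x, y, cate, x_obs, y_obs,
                           cols = c("steelblue", "tomato"),
                           xlab = "", ylab = "") {
  pct <- seq(10, 90, by = 10)
  # image() needs finite breaks, so bound the two sign classes by the data
  lim <- max(abs(cate), na.rm = TRUE)
  image(x, y, cate, xlab = "", ylab = "", axes = FALSE,
        col = cols, breaks = c(-lim, 0, lim))
  mtext(xlab, side = 1, line = 3)
  mtext(ylab, side = 2, line = 3)
  axis(1)
  axis(2)
  # Top and right axes: deciles of the observed values
  axis(3, at = quantile(x_obs, pct / 100), labels = paste0(pct, "%"))
  axis(4, at = quantile(y_obs, pct / 100), labels = paste0(pct, "%"))
  # Boundary between negative and positive differences
  contour(x, y, cate, add = TRUE, levels = 0, lwd = 2)
  # Share of grid cells with a strictly positive difference
  invisible(mean(cate > 0, na.rm = TRUE))
}

# Toy example (simulated values, for illustration only)
set.seed(1)
x <- seq(1800, 4600, length.out = 30)
y <- seq(0, 90, length.out = 20)
toy_cate <- outer(x, y, function(w, g) (w - 3200) / 1000 + (g - 45) / 100)
share_pos <- plot_cate_sign(
  x, y, toy_cate,
  x_obs = rnorm(500, mean = 3300, sd = 500),
  y_obs = runif(500, min = 0, max = 90),
  xlab = "Weight of the baby (toy data)",
  ylab = "Weight gain of the mother (toy data)"
)
share_pos
```

With the objects used in this section, a call could look like `plot_cate_sign(val_birth_weight, val_wtgain, mat_gam_2_1_t_sex - mat_gam_2_0_sex, x_obs = ..., y_obs = ...)`, where the observed values are those of the `treatment_0` group.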

References

Charpentier, Arthur, Emmanuel Flachaire, and Ewen Gallic. 2023. “Causal Inference with Optimal Transport.” In Optimal Transport Statistics for Economics and Related Topics, edited by Nguyen Ngoc Thach, Vladik Kreinovich, Doan Thanh Ha, and Nguyen Duc Trung. Springer Verlag.
Higham, Nicholas J. 2008. Functions of Matrices: Theory and Computation. Society for Industrial and Applied Mathematics. https://doi.org/10.1137/1.9780898717778.
Villani, Cédric. 2003. Topics in Optimal Transportation. Vol. 58. American Mathematical Society. https://doi.org/10.1090/gsm/058.