Optimal Transport for Counterfactual Estimation: A Method for Causal Inference

Online Appendix: simulations (multivariate case)

Authors

Arthur Charpentier

Emmanuel Flachaire

Ewen Gallic

In this notebook, we compute the SCATE at three different points: \(\boldsymbol{x}=(2500,60)\), \(\boldsymbol{x}=(4200,60)\), and \(\boldsymbol{x}=(2500,20)\). The first coordinate is the birth weight (in grams) and the second is the mother's weight gain. At each point, we estimate the SCATE using a GAM model with Gaussian transport, and we study how the estimate varies with the number of observations used to train the models.

1 Data

To illustrate the methods, we rely on the Linked Birth/Infant Death Cohort Data (2013 cohort). The CSV files for the 2013 cohort were downloaded from the NBER website, from the NBER collection of Birth Cohort Linked Birth and Infant Death Data of the National Vital Statistics System of the National Center for Health Statistics. We formatted the data beforehand (see ./births_stats.html).

Let us load the birth data.

library(tidyverse)
load("../data/births.RData")

Then, we keep only a subset of the variables to illustrate the method.

base <- 
  births %>% 
  mutate(
    black_mother = mracerec == "Black",
    nonnatural_delivery = rdmeth_rec != "Vaginal"
  ) %>% 
  select(
    sex, dbwt, cig_rec, wtgain, 
    black_mother, nonnatural_delivery,
    mracerec, rdmeth_rec
  ) %>% 
  mutate(
    cig_rec = replace_na(cig_rec, "Unknown or not stated"),
    black_mother = ifelse(black_mother, yes = "Yes", no = "No"),
    is_girl = ifelse(sex == "Female", yes = "Yes", no = "No")
  ) %>% 
  labelled::set_variable_labels(
    black_mother = "Is the mother Black?",
    is_girl = "Is the newborn a girl?",
    nonnatural_delivery = "Is the delivery method non-natural?"
  ) %>% 
  rename(birth_weight = dbwt)
base
# A tibble: 3,940,764 × 9
   sex    birth_weight cig_rec wtgain black_mo…¹ nonna…² mrace…³ rdmet…⁴ is_girl
   <fct>         <dbl> <fct>    <dbl> <chr>      <lgl>   <fct>   <fct>   <chr>  
 1 Male           3017 No           1 No         FALSE   "White" Vaginal No     
 2 Female         3367 No          18 No         FALSE   "White" Vaginal Yes    
 3 Male           4409 No          36 No         FALSE   "White" Vaginal No     
 4 Male           4047 No          25 No         FALSE   "White" Vaginal No     
 5 Male           3119 Yes          0 No         FALSE   "Ameri… Vaginal No     
 6 Female         4295 Yes         34 No         FALSE   "White" Vaginal Yes    
 7 Male           3160 No          47 No         FALSE   "Ameri… Vaginal No     
 8 Male           3711 No          39 No         FALSE   "White" Vaginal No     
 9 Male           3430 No          NA No         FALSE   "White" Vaginal No     
10 Female         3772 No          36 No         FALSE   "White" Vaginal Yes    
# … with 3,940,754 more rows, and abbreviated variable names ¹​black_mother,
#   ²​nonnatural_delivery, ³​mracerec, ⁴​rdmeth_rec

The variables that were kept are the following:

  • sex: sex of the newborn
  • birth_weight: birth weight (in grams)
  • cig_rec: does the mother smoke cigarettes?
  • wtgain: mother's weight gain
  • mracerec: mother's race
  • black_mother: is the mother Black?
  • rdmeth_rec: delivery method
  • nonnatural_delivery: is the delivery method non-natural?
  • is_girl: is the newborn a girl?

Then, let us discard individuals with a missing value for the birth weight, the weight gain, the delivery method, or the mother's race.

base <- 
  base %>% 
  filter(
    !is.na(birth_weight),
    !is.na(wtgain),
    !is.na(nonnatural_delivery),
    !is.na(black_mother)
  )

Let us define some colours for later use:

library(wesanderson)
colr1 <- wes_palette("Darjeeling1")
colr2 <- wes_palette("Darjeeling2")
couleur1 <- colr1[2]
couleur2 <- colr1[4]
couleur3 <- colr2[2]

coul1 <- "#882255"
coul2 <- "#DDCC77"

We need to load some packages.

library(car)
library(transport)
library(splines)
library(mgcv)
library(expm)
library(tidyverse)

Let us define the same functions as those used in reproduction_multivariate.html.

2 Functions

2.1 Models

Let us estimate \(\color{couleur1}{\widehat{m}_0}\) and \(\color{couleur2}{\widehat{m}_1}\) with a logistic GAM model, using cubic splines for each mediator. To that end, we define a function that estimates both models (one on the untreated subset, one on the treated subset) and returns them within a list.

library(splines)

#' Returns $\hat{m}_0()$ and $\hat{m}_1()$, using GAM
#' @param target name of the target variable
#' @param treatment_name name of the treatment variable (column in `data`)
#' @param x_m_names names of the mediator variables
#' @param data data frame with the observations
#' @param treatment_0 value for non treated
#' @param treatment_1 value for treated
#' @param df degrees of freedom (in the `bs()` function, default to `df=3`), recycled over the mediators if a single value is given
models_spline <- function(target, treatment_name, x_m_names, data, treatment_0, treatment_1, df = 3){
  
  # \hat{m}_0()
  formula_glm <- paste0(target, "~", paste0("bs(", x_m_names, ", df = ",df, ")", collapse = " + "))
  reg_0 <- bquote(
    glm(formula = .(formula_glm), data=data, family = binomial,
        subset = (.(treatment_name) == .(treatment_0))),
    list(
      formula_glm = formula_glm,
      treatment_name = as.name(treatment_name),
      treatment_0 = treatment_0)
  ) %>% eval()
  # \hat{m}_1()
  reg_1 <- bquote(
    glm(formula = .(formula_glm), data=data, family = binomial,
        subset = (.(treatment_name) == .(treatment_1))),
    list(
      formula_glm = formula_glm,
      treatment_name = as.name(treatment_name),
      treatment_1 = treatment_1)
  ) %>% eval()
  
  list(reg_0 = reg_0, reg_1 = reg_1)
}
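To see what models_spline() fits, here is a small self-contained sketch on simulated data. The data-generating process and the names `toy`, `x1`, `x2`, `treat` and `y` are all hypothetical. The sketch builds the same kind of formula (one cubic B-spline basis per mediator) and fits one logistic regression per treatment arm. Writing `subset = treat == "No"` directly works here because the column name is known in advance. models_spline() relies on bquote() instead, because the treatment column is only known through its name as a string.

```r
library(splines)

# Simulated (hypothetical) data: binary treatment, two mediators, binary outcome
set.seed(123)
n <- 2000
toy <- data.frame(
  treat = sample(c("No", "Yes"), n, replace = TRUE),
  x1    = rnorm(n, mean = 30, sd = 15),
  x2    = rnorm(n, mean = 3300, sd = 550)
)
eta   <- -0.5 + 0.01 * (toy$x1 - 30) - 0.0005 * (toy$x2 - 3300) + 0.2 * (toy$treat == "Yes")
toy$y <- rbinom(n, size = 1, prob = plogis(eta))

# Same construction as in models_spline(): one cubic B-spline basis (df = 3) per mediator
x_names <- c("x1", "x2")
f <- as.formula(paste0("y ~ ", paste0("bs(", x_names, ", df = 3)", collapse = " + ")))

# One logistic regression per treatment arm
fit_0 <- glm(f, data = toy, family = binomial, subset = treat == "No")
fit_1 <- glm(f, data = toy, family = binomial, subset = treat == "Yes")

# Predicted probabilities at three new points (bs() keeps its training knots in predict())
new_x <- data.frame(x1 = c(60, 60, 20), x2 = c(2500, 4200, 2500))
p_0 <- predict(fit_0, newdata = new_x, type = "response")
p_1 <- predict(fit_1, newdata = new_x, type = "response")
cbind(new_x, p_0, p_1)
```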

2.1.1 GAM (with cubic splines)

target         <- "nonnatural_delivery"
x_m_names       <- c("wtgain", "birth_weight")
scale          <- c(1,1/100)
treatment_name <- "cig_rec"
treatment_0    <- "No"
treatment_1    <- "Yes"

\(\color{couleur1}{\widehat{m}_0}\) and \(\color{couleur2}{\widehat{m}_1}\):

reg_gam_smoker <- 
  models_spline(target = target,
              treatment_name = treatment_name, 
              x_m_names = x_m_names,
              data = base, treatment_0 = treatment_0, treatment_1 = treatment_1, df = c(3, 3))
reg_gam_smoker_0 <- reg_gam_smoker$reg_0
reg_gam_smoker_1 <- reg_gam_smoker$reg_1
target         <- "nonnatural_delivery"
x_m_names       <- c("wtgain", "birth_weight")
scale          <- c(1,1/100)
treatment_name <- "black_mother"
treatment_0    <- "No"
treatment_1    <- "Yes"
reg_gam_blackm <- 
  models_spline(target = target,
              treatment_name = treatment_name, 
              x_m_names = x_m_names,
              data = base, treatment_0 = treatment_0, treatment_1 = treatment_1, df = c(3, 3))
reg_gam_blackm_0 <- reg_gam_blackm$reg_0
reg_gam_blackm_1 <- reg_gam_blackm$reg_1
target         <- "nonnatural_delivery"
x_m_names       <- c("wtgain", "birth_weight")
scale          <- c(1,1/100)
treatment_name <- "sex"
treatment_0    <- "Male"
treatment_1    <- "Female"

\(\color{couleur1}{\widehat{m}_0}\) and \(\color{couleur2}{\widehat{m}_1}\):

reg_gam_sex <- 
  models_spline(target = target,
              treatment_name = treatment_name, 
              x_m_names = x_m_names,
              data = base, treatment_0 = treatment_0, treatment_1 = treatment_1, df = c(3, 3))
reg_gam_sex_0 <- reg_gam_sex$reg_0
reg_gam_sex_1 <- reg_gam_sex$reg_1

2.2 Prediction Functions

The prediction function for GAM:

#' @param object regression model (GAM)
#' @param newdata data frame in which to look for the mediator variable used to predict the target
model_spline_predict <- function(object, newdata){
  predict(object, newdata = newdata, type="response")
}

2.3 Transport

2.3.1 Gaussian Assumption
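Under the Gaussian assumption, the mediators are modeled as \(\mathcal{N}(\boldsymbol{m}_0,\Sigma_0)\) among the untreated and as \(\mathcal{N}(\boldsymbol{m}_1,\Sigma_1)\) among the treated. For the quadratic cost, the optimal transport map between these two distributions has the closed form below. This is what the functions of this section compute, after rescaling the mediators and with the means and variance-covariance matrices replaced by their empirical counterparts:

\[
\mathcal{T}(\boldsymbol{x}) = \boldsymbol{m}_1 + A(\boldsymbol{x}-\boldsymbol{m}_0),
\qquad
A = \Sigma_0^{-1/2}\left(\Sigma_0^{1/2}\,\Sigma_1\,\Sigma_0^{1/2}\right)^{1/2}\Sigma_0^{-1/2}.
\]

The matrix \(A\) is symmetric positive definite and satisfies \(A\Sigma_0A=\Sigma_1\), so \(\mathcal{T}\) maps \(\mathcal{N}(\boldsymbol{m}_0,\Sigma_0)\) onto \(\mathcal{N}(\boldsymbol{m}_1,\Sigma_1)\).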

The transport function can be defined as follows:

#' Optimal Transport assuming Gaussian distribution for the mediator variables (helper function)
#' @return A list with the mean vector and the variance-covariance matrix of the mediators in each subset, and the symmetric matrix A.
#' @param target name of the target variable (not used in the computation)
#' @param x_m_names vector of names of the mediator variables
#' @param scale vector of scaling to apply to each `x_m_names` variable to transport (default to 1)
#' @param treatment_name name of the treatment variable (column in `data`)
#' @param treatment_0 value for non treated
#' @param treatment_1 value for treated
#' @param data data frame with the observations (default to `base`)
transport_gaussian_param <- function(target, x_m_names, scale=1, treatment_name, treatment_0, treatment_1, data = base){
  base_0_unscaled <- 
    data %>% 
    select(!!c(x_m_names, treatment_name)) %>% 
    filter(!!sym(treatment_name) ==  treatment_0)
  base_1_unscaled <- 
    data %>% 
    select(!!c(x_m_names, treatment_name)) %>% 
    filter(!!sym(treatment_name) ==  treatment_1)
  
  # Rescale the mediators one after the other, so that every scaling is kept
  # (restarting from the unscaled data at each iteration would only keep the last one)
  base_0_scaled <- base_0_unscaled
  base_1_scaled <- base_1_unscaled
  for(i in seq_along(scale)){
    base_0_scaled <- 
      base_0_scaled %>% 
      mutate(!!sym(x_m_names[i]) := !!sym(x_m_names[i]) * scale[i])
    base_1_scaled <- 
      base_1_scaled %>% 
      mutate(!!sym(x_m_names[i]) := !!sym(x_m_names[i]) * scale[i])
  }
  
  # Mean in each subset (i.e., T=0, T=1)
  m_0 <- base_0_scaled %>% summarise(across(!!x_m_names, mean)) %>% as_vector()
  m_1 <- base_1_scaled %>% summarise(across(!!x_m_names, mean)) %>% as_vector()
  # Variance
  S_0 <- base_0_scaled %>% select(!!x_m_names) %>% var()
  S_1 <- base_1_scaled %>% select(!!x_m_names) %>% var()
  # Matrix A
  A <- (solve(sqrtm(S_0))) %*% sqrtm( sqrtm(S_0) %*% S_1 %*% (sqrtm(S_0)) ) %*% solve(sqrtm(S_0))
  
  list(m_0 = m_0, m_1 = m_1, S_0 = S_0, S_1 = S_1, A = A, scale = scale, x_m_names = x_m_names)
}
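As a numerical sanity check of the formula for A, the following base-R sketch recomputes A from the variance-covariance matrices estimated for the smoker treatment (printed in Section 3). It then verifies that \(A\Sigma_0A=\Sigma_1\). One assumption: the symmetric square root is computed with an eigendecomposition rather than expm::sqrtm(), so the block has no package dependency. For a symmetric positive definite matrix, both give the same principal square root. The result should agree with the A printed in Section 3, up to the rounding of the displayed covariances.

```r
# Principal square root of a symmetric positive definite matrix via eigendecomposition
# (used here instead of expm::sqrtm() to keep the sketch dependency-free)
sqrt_spd <- function(S) {
  e <- eigen(S, symmetric = TRUE)
  e$vectors %*% diag(sqrt(e$values), nrow = length(e$values)) %*% t(e$vectors)
}

# Covariance matrices of (wtgain, birth_weight / 100), as printed for the smoker treatment
S_0 <- matrix(c(218.95677, 14.54058, 14.54058, 34.09920), nrow = 2)
S_1 <- matrix(c(297.62083, 22.09039, 22.09039, 35.41197), nrow = 2)

# A = S_0^{-1/2} (S_0^{1/2} S_1 S_0^{1/2})^{1/2} S_0^{-1/2}
S_0_half     <- sqrt_spd(S_0)
S_0_half_inv <- solve(S_0_half)
A <- S_0_half_inv %*% sqrt_spd(S_0_half %*% S_1 %*% S_0_half) %*% S_0_half_inv

A
# A is symmetric and satisfies A S_0 A = S_1 (up to floating-point error)
all.equal(A, t(A))
all.equal(A %*% S_0 %*% A, S_1)
```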

We create a function that transports a single observation, given the different parameters:

#' Gaussian transport for a single observation (helper function)
#' @param z vector of variables to transport (for a single observation)
#' @param A symmetric positive definite matrix that satisfies \(A\Sigma_0A=\Sigma_1\)
#' @param m_0,m_1 vectors of mean values in the subsets \(\mathcal{D}_0\) and \(\mathcal{D}_1\)
#' @param scale vector of scaling to apply to each variable to transport (default to 1)
T_C_single <- function(z, A, m_0, m_1, scale = 1){
  z <- z*scale
  as.vector(m_1 + A %*% (z-m_0))*(1/scale)
}

Lastly, we define a function that transports the mediator variables under the Gaussian assumption. If the params argument is NULL, the parameters are first estimated with the transport_gaussian_param() function, so the arguments that function needs must be provided (target, x_m_names, scale, treatment_name, treatment_0, treatment_1).

#' Gaussian transport
#' @param z data frame of variables to transport
#' @param params (optional) parameters to use for the transport (result of `transport_gaussian_param()`, default to NULL)
#' @param target name of the target variable
#' @param x_m_names vector of names of the mediator variables
#' @param scale vector of scaling to apply to each `x_m_names` variable to transport (default to 1)
#' @param treatment_name name of the treatment variable (column in `data`)
#' @param treatment_0 value for non treated
#' @param treatment_1 value for treated
gaussian_transport <- function(z, params = NULL, target = NULL, ..., x_m_names, scale, treatment_name, treatment_0, treatment_1){
  if(is.null(params)){
    # If the parameters of the transport function are not provided
    # they need to be computed first
    params <- 
      transport_gaussian_param(target = target, x_m_names = x_m_names, 
                               scale = scale, treatment_name = treatment_name, 
                               treatment_0 = treatment_0, treatment_1 = treatment_1)
  }
  A <- params$A
  m_0 <- params$m_0 ; m_1 <- params$m_1
  scale <- params$scale
  x_m_names <- params$x_m_names
  
  values_to_transport <- z %>% select(!!x_m_names)
  transported_val <- 
    apply(values_to_transport, 1, T_C_single, A = A, m_0 = m_0, m_1 = m_1,
          scale = scale, simplify = FALSE) %>% 
    do.call("rbind", .)
  # Name the columns after the transported mediators (z may contain other columns)
  colnames(transported_val) <- x_m_names
  transported_val <- as_tibble(transported_val)
  
  structure(.Data = transported_val, params = params)
}
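The following base-R sketch shows what gaussian_transport() does to a whole distribution rather than to a few points. It uses hypothetical Gaussian parameters of the same order as the estimates shown in Section 3. A large sample drawn from \(\mathcal{N}(\boldsymbol{m}_0,\Sigma_0)\) is pushed through \(\mathcal{T}(\boldsymbol{x})=\boldsymbol{m}_1+A(\boldsymbol{x}-\boldsymbol{m}_0)\), and it should then have, approximately, mean \(\boldsymbol{m}_1\) and variance \(\Sigma_1\). The sketch works directly on already rescaled variables, so the scale argument plays no role here.

```r
# Principal square root of a symmetric positive definite matrix (base R)
sqrt_spd <- function(S) {
  e <- eigen(S, symmetric = TRUE)
  e$vectors %*% diag(sqrt(e$values), nrow = length(e$values)) %*% t(e$vectors)
}

# Hypothetical Gaussian parameters for the untreated (0) and treated (1) groups
m_0 <- c(30, 33);  S_0 <- matrix(c(220, 15, 15, 34), nrow = 2)
m_1 <- c(31, 31);  S_1 <- matrix(c(300, 22, 22, 35), nrow = 2)

S_0_half     <- sqrt_spd(S_0)
S_0_half_inv <- solve(S_0_half)
A <- S_0_half_inv %*% sqrt_spd(S_0_half %*% S_1 %*% S_0_half) %*% S_0_half_inv

# Draw a large sample from N(m_0, S_0): rows of Z %*% chol(S_0) have covariance S_0
set.seed(2024)
n  <- 1e5
X0 <- sweep(matrix(rnorm(2 * n), ncol = 2) %*% chol(S_0), 2, m_0, "+")

# Apply T(x) = m_1 + A (x - m_0) to every row (written with t(A), although A is symmetric)
X0_t <- sweep(sweep(X0, 2, m_0, "-") %*% t(A), 2, m_1, "+")

colMeans(X0_t)  # close to m_1
cov(X0_t)       # close to S_1
```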

3 Example with \(n=1000\)

The three points at which the SCATE is computed are stored in a table. Note that the columns are in the order weight gain, birth weight:

point_X = tibble(
  wtgain = c(60,60,20),
  birth_weight = c(2500,4200,2500)
  )

We will make the sample size \(n\) vary. Let us first consider \(n=1000\).

size <- 1000

Let us draw a sample of 1,000 observations from base:

sbase <- base[sample(1:nrow(base), size=size), ]

The transported values of the three points, for each treatment:

target         <- "nonnatural_delivery"
x_m_names       <- c("wtgain", "birth_weight")
scale          <- c(1,1/100)
treatment_name <- "cig_rec"
treatment_0    <- "No"
treatment_1    <- "Yes"
point_X_t_n_smoker <- 
  gaussian_transport(z = point_X, target = target,
                     x_m_names = x_m_names, 
                     scale = scale, treatment_name = treatment_name,
                     treatment_0 = treatment_0, treatment_1 = treatment_1)
point_X_t_n_smoker
# A tibble: 3 × 2
  wtgain birth_weight
*  <dbl>        <dbl>
1   64.9        2353.
2   65.2        4071.
3   18.3        2284.
target         <- "nonnatural_delivery"
x_m_names       <- c("wtgain", "birth_weight")
scale          <- c(1,1/100)
treatment_name <- "black_mother"
treatment_0    <- "No"
treatment_1    <- "Yes"
point_X_t_n_blackm <- 
  gaussian_transport(z = point_X, target = target,
                     x_m_names = x_m_names, 
                     scale = scale, treatment_name = treatment_name,
                     treatment_0 = treatment_0, treatment_1 = treatment_1)
point_X_t_n_blackm
# A tibble: 3 × 2
  wtgain birth_weight
*  <dbl>        <dbl>
1   62.4        2208.
2   62.4        4087.
3   16.9        2205.
target         <- "nonnatural_delivery"
x_m_names       <- c("wtgain", "birth_weight")
scale          <- c(1,1/100)
treatment_name <- "sex"
treatment_0    <- "Male"
treatment_1    <- "Female"
point_X_t_n_sex <- 
  gaussian_transport(z = point_X, target = target,
                     x_m_names = x_m_names, 
                     scale = scale, treatment_name = treatment_name,
                     treatment_0 = treatment_0, treatment_1 = treatment_1)
point_X_t_n_sex
# A tibble: 3 × 2
  wtgain birth_weight
*  <dbl>        <dbl>
1   59.0        2405.
2   58.9        4030.
3   19.3        2428.

The estimated parameters can be extracted as follows:

params_smoker <- attr(point_X_t_n_smoker, "params")
params_smoker
$m_0
      wtgain birth_weight 
    30.40885     32.92174 

$m_1
      wtgain birth_weight 
    30.55169     31.02590 

$S_0
                wtgain birth_weight
wtgain       218.95677     14.54058
birth_weight  14.54058     34.09920

$S_1
                wtgain birth_weight
wtgain       297.62083     22.09039
birth_weight  22.09039     35.41197

$A
           [,1]       [,2]
[1,] 1.16471785 0.01715655
[2,] 0.01715655 1.01085009

$scale
[1] 1.00 0.01

$x_m_names
[1] "wtgain"       "birth_weight"
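With these parameters, the first transported point for the smoker treatment can be recomputed by hand. The computation follows the steps of T_C_single(): rescale, apply the affine map, then undo the scaling. The values below are copied from the printed output, so the result matches the printed (64.9, 2353) only up to rounding.

```r
# Parameters printed above for the smoker treatment (T = cig_rec)
m_0   <- c(wtgain = 30.40885, birth_weight = 32.92174)
m_1   <- c(wtgain = 30.55169, birth_weight = 31.02590)
A     <- matrix(c(1.16471785, 0.01715655, 0.01715655, 1.01085009), nrow = 2)
scale <- c(1, 1/100)

# Point x = (birth weight 2500, weight gain 60), stored in (wtgain, birth_weight) order
z <- c(wtgain = 60, birth_weight = 2500)

# T(z) = (m_1 + A (z * scale - m_0)) / scale, as in T_C_single()
z_t <- as.vector(m_1 + A %*% (z * scale - m_0)) / scale
z_t
```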

The estimated parameters for the black_mother treatment:

params_blackm <- attr(point_X_t_n_blackm, "params")
params_blackm
$m_0
      wtgain birth_weight 
    30.67049     33.09063 

$m_1
      wtgain birth_weight 
    29.02406     31.00026 

$S_0
                wtgain birth_weight
wtgain       215.83720     14.15844
birth_weight  14.15844     32.86016

$S_1
               wtgain birth_weight
wtgain       278.8963     17.99050
birth_weight  17.9905     40.18716

$A
             [,1]         [,2]
[1,] 1.1366862866 0.0007016672
[2,] 0.0007016672 1.1055783355

$scale
[1] 1.00 0.01

$x_m_names
[1] "wtgain"       "birth_weight"

The estimated parameters for the sex treatment:

params_sex <- attr(point_X_t_n_sex, "params")
params_sex
$m_0
      wtgain birth_weight 
    30.79877     33.32068 

$m_1
      wtgain birth_weight 
    30.00267     32.17229 

$S_0
               wtgain birth_weight
wtgain       227.8193     16.11780
birth_weight  16.1178     35.86006

$S_1
                wtgain birth_weight
wtgain       224.11168     13.81543
birth_weight  13.81543     32.60395

$A
            [,1]        [,2]
[1,]  0.99222733 -0.00565936
[2,] -0.00565936  0.95595986

$scale
[1] 1.00 0.01

$x_m_names
[1] "wtgain"       "birth_weight"

3.1 Estimation

Now that the two models \(\color{couleur1}{\widehat{m}_0(x)}\) and \(\color{couleur2}{\widehat{m}_1(x)}\) are defined and fitted to the data, we can compute the Mutatis Mutandis Sample Conditional Average Treatment Effect (SCATE).
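The two quantities returned by the sate() function defined below can be written as follows, where \(\mathcal{T}\) is the estimated transport map:

\[
\widehat{\text{CATE}}(\boldsymbol{x}) = \widehat{m}_1(\boldsymbol{x}) - \widehat{m}_0(\boldsymbol{x}),
\qquad
\widehat{\text{SCATE}}(\boldsymbol{x}) = \widehat{m}_1\big(\mathcal{T}(\boldsymbol{x})\big) - \widehat{m}_0(\boldsymbol{x}).
\]

For instance, consider the first row of the table obtained below for the smoker treatment, at \(\boldsymbol{x}=(2500,60)\). The CATE is \(0.435-0.488\approx-0.053\) and the SCATE is \(0.476-0.488\approx-0.012\), which match the printed values up to rounding.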

#' Computes the conditional average treatment effect without transport (CATE) and with transport (SCATE)
#' @param x_0 vector of values at which to compute $\hat{m}_0(x_0)$, $\hat{m}_1(x_0)$
#' @param x_t vector of transported values $\mathcal{T}(x_0)$, at which to compute $\hat{m}_1(\mathcal{T}(x_0))$
#' @param x_m_names vector of names of the mediator variables
#' @param mod_0 model $\hat{m}_0$
#' @param mod_1 model $\hat{m}_1$
#' @param pred_mod_0 prediction function for model $\hat{m}_0(\cdot)$
#' @param pred_mod_1 prediction function for model $\hat{m}_1(\cdot)$
#' @param return_x if `TRUE` (default) the mediator variables are returned in the table, as well as their transported values
sate <- function(x_0, x_t, x_m_names, mod_0, mod_1, pred_mod_0, pred_mod_1, return_x = TRUE){
  if(is.vector(x_0)){
    # Univariate case
    new_data <- tibble(!!x_m_names := x_0)
  }else{
    new_data <- x_0
  }
  if(is.vector(x_t)){
    # Univariate case
    new_data_t <- tibble(!!x_m_names := x_t)
  }else{
    new_data_t <- x_t
  }
  
  
  # $\hat{m}_0(x_0)$
  y_0 <- pred_mod_0(object = mod_0, newdata = new_data)
  # $\hat{m}_1(x_0)$
  y_1 <- pred_mod_1(object = mod_1, newdata = new_data)
  # $\hat{m}_1(\mathcal{T}(x_0))$
  y_1_t <- pred_mod_1(object = mod_1, newdata = new_data_t)
  
  
  scate_tab <- 
    tibble(y_0 = y_0, y_1 = y_1, y_1_t = y_1_t,
         CATE = y_1-y_0, SCATE = y_1_t-y_0)
  
  if(return_x){
    new_data_t <- new_data_t %>% rename_all(~paste0(.x, "_t"))
    scate_tab <- bind_cols(new_data, new_data_t, scate_tab)
  }
  scate_tab
}

3.2 GAM (with cubic splines), Gaussian assumption for transport

Let us compute here the SCATE where the models \(\color{couleur1}{\widehat{m}_0(x)}\) and \(\color{couleur2}{\widehat{m}_1(x)}\) are GAM and where the transport method assumes a Gaussian distribution for both mediators.

mm_sate_gam_smoker <- 
  sate(
    x_0 = point_X,
    x_t = point_X_t_n_smoker,
    x_m_names = c("wtgain", "birth_weight"),
    mod_0 = reg_gam_smoker_0,
    mod_1 = reg_gam_smoker_1,
    pred_mod_0 = model_spline_predict, pred_mod_1 = model_spline_predict,
    return_x = TRUE)
mm_sate_gam_smoker
# A tibble: 3 × 9
  wtgain birth_weight wtgain_t birth_weight_t   y_0   y_1 y_1_t     CATE   SCATE
   <dbl>        <dbl>    <dbl>          <dbl> <dbl> <dbl> <dbl>    <dbl>   <dbl>
1     60         2500     64.9          2353. 0.488 0.435 0.476 -0.0528  -0.0112
2     60         4200     65.2          4071. 0.469 0.461 0.443 -0.00779 -0.0255
3     20         2500     18.3          2284. 0.384 0.363 0.401 -0.0208   0.0166
mm_sate_gam_blackm <- 
  sate(
    x_0 = point_X,
    x_t = point_X_t_n_blackm,
    x_m_names = c("wtgain", "birth_weight"),
    mod_0 = reg_gam_blackm_0,
    mod_1 = reg_gam_blackm_1,
    pred_mod_0 = model_spline_predict, pred_mod_1 = model_spline_predict,
    return_x = TRUE)
mm_sate_gam_blackm
# A tibble: 3 × 9
  wtgain birth_weight wtgain_t birth_weight_t   y_0   y_1 y_1_t     CATE  SCATE
   <dbl>        <dbl>    <dbl>          <dbl> <dbl> <dbl> <dbl>    <dbl>  <dbl>
1     60         2500     62.4          2208. 0.483 0.465 0.527 -0.0177  0.0441
2     60         4200     62.4          4087. 0.460 0.551 0.515  0.0917  0.0558
3     20         2500     16.9          2205. 0.380 0.387 0.441  0.00713 0.0608
mm_sate_gam_sex <- 
  sate(
    x_0 = point_X,
    x_t = point_X_t_n_sex,
    x_m_names = c("wtgain", "birth_weight"),
    mod_0 = reg_gam_sex_0,
    mod_1 = reg_gam_sex_1,
    pred_mod_0 = model_spline_predict, pred_mod_1 = model_spline_predict,
    return_x = TRUE)
mm_sate_gam_sex
# A tibble: 3 × 9
  wtgain birth_weight wtgain_t birth_weigh…¹   y_0   y_1 y_1_t     CATE    SCATE
   <dbl>        <dbl>    <dbl>         <dbl> <dbl> <dbl> <dbl>    <dbl>    <dbl>
1     60         2500     59.0         2405. 0.490 0.474 0.491 -0.0161   6.34e-4
2     60         4200     58.9         4030. 0.466 0.469 0.417  0.00245 -4.89e-2
3     20         2500     19.3         2428. 0.391 0.375 0.390 -0.0159  -9.93e-4
# … with abbreviated variable name ¹​birth_weight_t

4 Simulations

Now, let us wrap all these estimation steps in a single function, so that we can run simulations in which the size \(n\) of the sample used to compute the SCATE varies. For each value of \(n\), we compute the mutatis mutandis CATE (SCATE) at our three points \(\boldsymbol{x}\). We draw 500 different sub-samples of each size, on which the estimations are computed.

#' @param size size of the sample over which to compute the SCATE
#' @param target name of the target variable
#' @param x_m_names names of the mediator variables
#' @param treatment_name name of the treatment variable (column in `data`)
#' @param treatment_0 value for non treated
#' @param treatment_1 value for treated
#' @param scale vector of scaling to apply to each `x_m_names` variable to transport
simul_n <- function(size, target, x_m_names, treatment_name, treatment_0, treatment_1, scale){
  # Random sub-sample of size `size`
  sbase <- base[sample(1:nrow(base), size=size), ]
  # Regressions on both subsets
  reg_gam <- 
    models_spline(target = target,
                  treatment_name = treatment_name, 
                  x_m_names = x_m_names,
                  data = sbase, treatment_0 = treatment_0, treatment_1 = treatment_1, df = c(3, 3))
  reg_gam_0 <- reg_gam$reg_0
  reg_gam_1 <- reg_gam$reg_1
  
  # Transported values
  point_X_t_n <- 
    gaussian_transport(z = point_X, target = target,
                       x_m_names = x_m_names, 
                       scale = scale, treatment_name = treatment_name,
                       treatment_0 = treatment_0, treatment_1 = treatment_1)
  
  # SCATE
  mm_sate_gam <- 
    sate(
      x_0 = point_X,
      x_t = point_X_t_n,
      x_m_names = c("wtgain", "birth_weight"),
      mod_0 = reg_gam_0,
      mod_1 = reg_gam_1,
      pred_mod_0 = model_spline_predict, pred_mod_1 = model_spline_predict,
      return_x = TRUE)
  
  tibble(SCATE = mm_sate_gam$SCATE, idx = 1:nrow(point_X), size = size, treatment_name = treatment_name)
}

For example, with a size \(n=100\), when the treatment is whether the mother is a smoker or not:

target         <- "nonnatural_delivery"
x_m_names       <- c("wtgain", "birth_weight")
scale          <- c(1,1/100)
treatment_name <- "cig_rec"
treatment_0    <- "No"
treatment_1    <- "Yes"

resul_simul_100 <- 
  simul_n(size = 100, target = target, x_m_names = x_m_names, 
        treatment_name = treatment_name, treatment_0 = treatment_0, treatment_1 = treatment_1, scale = scale)
resul_simul_100
# A tibble: 3 × 4
  SCATE   idx  size treatment_name
  <dbl> <int> <dbl> <chr>         
1 0.268     1   100 cig_rec       
2 0.388     2   100 cig_rec       
3 0.286     3   100 cig_rec       

The values of \(n\) considered (51 values evenly spaced on a log scale, from \(10^3\) to \(10^{5.2}\approx 158{,}489\)):

sample_size <- 10^(seq(3,5.2,length=51))
sample_size
 [1]   1000.000   1106.624   1224.616   1355.189   1499.685   1659.587
 [7]   1836.538   2032.357   2249.055   2488.857   2754.229   3047.895
[13]   3372.873   3732.502   4130.475   4570.882   5058.247   5597.576
[19]   6194.411   6854.882   7585.776   8394.600   9289.664  10280.163
[25]  11376.273  12589.254  13931.568  15417.005  17060.824  18879.913
[31]  20892.961  23120.648  25585.859  28313.920  31332.857  34673.685
[37]  38370.725  42461.956  46989.411  51999.600  57543.994  63679.552
[43]  70469.307  77983.011  86297.855  95499.259 105681.751 116949.939
[49] 129419.584 143218.790 158489.319

The number of replications of the simulations for each value of \(n\):

nb_replicate <- 500
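The grid and the number of tasks can be checked with a little arithmetic. Consecutive sizes differ by the constant factor \(10^{2.2/50}\approx 1.1066\). Each of the 51 sizes is repeated nb_replicate times, which gives the vector of tasks passed to pblapply() below.

```r
# Grid of sample sizes: 51 values evenly spaced on a log10 scale
sample_size <- 10^(seq(3, 5.2, length.out = 51))

# Consecutive sizes differ by the constant factor 10^(2.2 / 50), about 1.1066
ratio <- sample_size[2] / sample_size[1]

# One task per (size, replicate) pair, as in rep(sample_size, each = nb_replicate)
nb_replicate <- 500
tasks <- rep(sample_size, each = nb_replicate)

ratio
length(tasks)  # 51 * 500 calls to simul_n()
```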

Let us perform the simulations in parallel.

library(parallel)
ncl <- detectCores()-1
cl <- makeCluster(ncl)
invisible(clusterEvalQ(cl, library(tidyverse, warn.conflicts=FALSE, quietly=TRUE)))
invisible(clusterEvalQ(cl, library(splines, warn.conflicts=FALSE, quietly=TRUE)))
invisible(clusterEvalQ(cl, library(expm, warn.conflicts=FALSE, quietly=TRUE)))

The simul_n() function needs access to some functions and some data on each worker of the cluster:

clusterExport(cl, c("models_spline", "gaussian_transport", "transport_gaussian_param",
                    "T_C_single", "model_spline_predict", "sate"))
clusterExport(cl, c("point_X", "base"))
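These exports are needed because of the way PSOCK clusters work: each worker is a separate R session that starts with an empty global environment. The toy sketch below reproduces the same pattern as simul_n() using the parallel package alone, without pbapply. It uses two workers and the hypothetical objects toy_data and draw_mean().

```r
library(parallel)

set.seed(1)
toy_data <- data.frame(x = rnorm(1000))

# Hypothetical worker function: like simul_n(), it reads an object (toy_data)
# from the global environment instead of receiving it as an argument
draw_mean <- function(size) mean(toy_data$x[sample(nrow(toy_data), size)])

cl <- makeCluster(2)
# Workers are fresh R sessions: without this export, draw_mean() would fail
# on the workers with "object 'toy_data' not found"
clusterExport(cl, "toy_data")
res <- parLapply(cl, rep(c(10, 100), each = 3), draw_mean)
stopCluster(cl)

unlist(res)
```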
target         <- "nonnatural_delivery"
x_m_names       <- c("wtgain", "birth_weight")
scale          <- c(1,1/100)
treatment_name <- "cig_rec"
treatment_0    <- "No"
treatment_1    <- "Yes"

cate_sim_smoker <- 
  pbapply::pblapply(rep(sample_size, each = nb_replicate), 
                    simul_n, cl = cl, target = target, x_m_names = x_m_names, 
                    treatment_name = treatment_name, 
                    treatment_0 = treatment_0, treatment_1 = treatment_1, scale = scale)

The results can be saved:

save(cate_sim_smoker, file = "../output/simulations/cate_sim_smoker.rda")

And then loaded:

load("../output/simulations/cate_sim_smoker.rda")
cate_sim_smoker[[1]]
# A tibble: 3 × 4
   SCATE   idx  size treatment_name
   <dbl> <int> <dbl> <chr>         
1 0.0827     1  1000 cig_rec       
2 0.162      2  1000 cig_rec       
3 0.0970     3  1000 cig_rec       
target         <- "nonnatural_delivery"
x_m_names       <- c("wtgain", "birth_weight")
scale          <- c(1,1/100)
treatment_name <- "black_mother"
treatment_0    <- "No"
treatment_1    <- "Yes"

cate_sim_blackm <- 
  pbapply::pblapply(rep(sample_size, each = nb_replicate), 
                    simul_n, cl = cl, target = target, x_m_names = x_m_names, 
                    treatment_name = treatment_name, 
                    treatment_0 = treatment_0, treatment_1 = treatment_1, scale = scale)

The results can be saved:

save(cate_sim_blackm, file = "../output/simulations/cate_sim_blackm.rda")

And then loaded:

load("../output/simulations/cate_sim_blackm.rda")
target         <- "nonnatural_delivery"
x_m_names       <- c("wtgain", "birth_weight")
scale          <- c(1,1/100)
treatment_name <- "sex"
treatment_0    <- "Male"
treatment_1    <- "Female"

cate_sim_sex <- 
  pbapply::pblapply(rep(sample_size, each = nb_replicate), 
                    simul_n, cl = cl, target = target, x_m_names = x_m_names, 
                    treatment_name = treatment_name, 
                    treatment_0 = treatment_0, treatment_1 = treatment_1, scale = scale)

The results can be saved:

save(cate_sim_sex, file = "../output/simulations/cate_sim_sex.rda")

And then loaded:

load("../output/simulations/cate_sim_sex.rda")
stopCluster(cl = cl)

Then, we can plot the results of the simulations. First, for each treatment, point and sample size, we compute the average value of the SCATE over the 500 runs of the simulation, together with its 5% and 95% quantiles.

cate_sim_df <- 
  # bind_rows() accepts lists of data frames and stacks all of them
  bind_rows(cate_sim_smoker, cate_sim_blackm, cate_sim_sex) %>% 
  group_by(treatment_name, idx, size) %>% 
  summarise(
    q1_05 = quantile(SCATE, probs = .05),
    mean = mean(SCATE),
    q1_95 = quantile(SCATE, probs = .95),
    .groups = "drop"
  )
cate_sim_df %>% 
  mutate(
    name = factor(idx, levels = c(1:3),
                  labels = c("x = (2500, 60)", "x = (4200, 60)", "x = (2500, 20)")),
    treatment_name = factor(treatment_name, levels = c("cig_rec", "black_mother", "sex"),
                            labels = c("Smoker mother", "Black mother", "Baby girl"))
  ) %>% 
ggplot(data = .,
       mapping = aes(x = size, y = mean)) +
  geom_ribbon(mapping = aes(ymin = q1_05, ymax = q1_95), fill = couleur1, alpha = .3) +
  geom_hline(yintercept = 0, colour = "grey", linetype = "dashed") +
  geom_line(colour = couleur3) +
  facet_grid(name~treatment_name, scales = "free") +
  scale_x_log10() +
  labs(x = "Number of observations (log scales)", y = "SCATE(x)") +
  theme_bw()

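Figure 1. Estimation of the SCATE at three values of \(\boldsymbol{x}\) (rows), using a GAM model with Gaussian transport, as a function of the number n of observations. Y=1 indicates a non-natural delivery, X gathers the weight of the newborn infant and the weight gain of the mother, and T indicates whether the mother is a smoker, whether the mother is Black, or whether the baby is a girl (columns). The line shows the average over the 500 replications and the shaded band spans the 5% and 95% quantiles.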
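The band in the figure is the range between the 5% and 95% quantiles computed by the summarise() step, for each (treatment, point, sample size) cell. As a dependency-free illustration, the same kind of summary can be sketched in base R with aggregate(). The sketch uses toy replicate results: hypothetical values with the same idx, size and SCATE columns as cate_sim_*.

```r
# Toy replicate results (hypothetical values, 100 replications per cell)
set.seed(42)
sims <- data.frame(
  idx   = rep(1:3, times = 200),
  size  = rep(c(1000, 10000), each = 300),
  SCATE = rnorm(600, sd = 0.05)
)

# 5% quantile, mean and 95% quantile of the SCATE for each (idx, size) cell
band <- aggregate(SCATE ~ idx + size, data = sims,
                  FUN = function(v) c(q05  = quantile(v, 0.05, names = FALSE),
                                      mean = mean(v),
                                      q95  = quantile(v, 0.95, names = FALSE)))
# aggregate() stores the three statistics in a matrix column: flatten it
band <- do.call(data.frame, band)
band
```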

Equivalently, we can create the same plots with base {graphics}, one figure per treatment.

cate_sim_smoker_df <- 
  bind_rows(cate_sim_smoker) %>% 
  group_by(idx, size) %>% 
  summarise(
    q1_05 = quantile(SCATE, probs = .05),
    mean = mean(SCATE),
    q1_95 = quantile(SCATE, probs = .95),
  )

par(mfrow = c(1,3))
CATE <- cate_sim_smoker_df %>% filter(idx == 1)
plot(x = CATE$size, y = CATE$mean,
     col=couleur3, lwd=2,
     type="l", ylim=c(-.5,.5),
     log="x",
     xlab="Number of observations (log scales)",
     ylab="CATE(2500,60)")
polygon(c(CATE$size, rev(CATE$size)), c(CATE$q1_05, rev(CATE$q1_95)),
        col=scales::alpha(couleur1,.3),border=NA)
lines(CATE$size, CATE$q1_05, col = couleur1)
lines(CATE$size, CATE$q1_95, col = couleur1)

CATE <- cate_sim_smoker_df %>% filter(idx == 2)
plot(x = CATE$size, y = CATE$mean,
     col=couleur3, lwd=2,
     type = "l", ylim = c(-.5,.5), log = "x",
     xlab = "Number of observations (log scales)",
     ylab = "CATE(4200,60)")
polygon(c(CATE$size, rev(CATE$size)), c(CATE$q1_05, rev(CATE$q1_95)),
        col = scales::alpha(couleur1,.3), border = NA)
lines(CATE$size, CATE$q1_05, col = couleur1)
lines(CATE$size, CATE$q1_95, col = couleur1)

CATE <- cate_sim_smoker_df %>% filter(idx == 3)
plot(x = CATE$size, y = CATE$mean,
     col = couleur3, lwd=2,
     type = "l",ylim=c(-.5,.5),log="x",
     xlab = "Number of observations (log scales)",
     ylab = "CATE(2500,20)")
polygon(c(CATE$size, rev(CATE$size)), c(CATE$q1_05, rev(CATE$q1_95)),
        col = scales::alpha(couleur1,.3), border = NA)
lines(CATE$size, CATE$q1_05, col = couleur1)
lines(CATE$size, CATE$q1_95, col = couleur1)

Figure 2. Estimation of the Mutatis Mutandis CATE at \(\boldsymbol{x}=(2500,60)\), \(\boldsymbol{x}=(4200,60)\), and \(\boldsymbol{x}=(2500,20)\) (from left to right), using a GAM model with Gaussian transport, on n observations. Y=1 indicates a non-natural delivery, X gathers the weight of the newborn infant and the weight gain of the mother, and T indicates whether the mother is a smoker or not.

cate_sim_blackm_df <- 
  bind_rows(cate_sim_blackm) %>% 
  group_by(idx, size) %>% 
  summarise(
    q1_05 = quantile(SCATE, probs = .05),
    mean = mean(SCATE),
    q1_95 = quantile(SCATE, probs = .95),
  )

par(mfrow = c(1,3))
CATE <- cate_sim_blackm_df %>% filter(idx == 1)
plot(x = CATE$size, y = CATE$mean,
     col=couleur3, lwd=2,
     type = "l", ylim = c(-.3,.3), log = "x",
     xlab="Number of observations (log scales)",
     ylab="CATE(2500,60)")
polygon(c(CATE$size, rev(CATE$size)), c(CATE$q1_05, rev(CATE$q1_95)),
        col=scales::alpha(couleur1,.3),border=NA)
lines(CATE$size, CATE$q1_05, col = couleur1)
lines(CATE$size, CATE$q1_95, col = couleur1)

CATE <- cate_sim_blackm_df %>% filter(idx == 2)
plot(x = CATE$size, y = CATE$mean,
     col=couleur3, lwd=2,
     type = "l", ylim = c(-.3,.3), log = "x",
     xlab = "Number of observations (log scales)",
     ylab = "CATE(4200,60)")
polygon(c(CATE$size, rev(CATE$size)), c(CATE$q1_05, rev(CATE$q1_95)),
        col = scales::alpha(couleur1,.3), border = NA)
lines(CATE$size, CATE$q1_05, col = couleur1)
lines(CATE$size, CATE$q1_95, col = couleur1)

CATE <- cate_sim_blackm_df %>% filter(idx == 3)
plot(x = CATE$size, y = CATE$mean,
     col = couleur3, lwd=2,
     type = "l", ylim = c(-.3,.3), log = "x",
     xlab = "Number of observations (log scales)",
     ylab = "CATE(2500,20)")
polygon(c(CATE$size, rev(CATE$size)), c(CATE$q1_05, rev(CATE$q1_95)),
        col = scales::alpha(couleur1,.3), border = NA)
lines(CATE$size, CATE$q1_05, col = couleur1)
lines(CATE$size, CATE$q1_95, col = couleur1)

Figure 3. Estimation of the Mutatis Mutandis CATE at \(\boldsymbol{x}=(2500,60)\), \(\boldsymbol{x}=(4200,60)\), and \(\boldsymbol{x}=(2500,20)\) (from left to right), using a GAM model with Gaussian transport, on n observations. Y=1 indicates a non-natural delivery, X gathers the weight of the newborn infant and the weight gain of the mother, and T indicates whether the mother is Black or not.

cate_sim_sex_df <- 
  bind_rows(cate_sim_sex) %>% 
  group_by(idx, size) %>% 
  summarise(
    q1_05 = quantile(SCATE, probs = .05),
    mean = mean(SCATE),
    q1_95 = quantile(SCATE, probs = .95),
  )

par(mfrow = c(1,3))
CATE <- cate_sim_sex_df %>% filter(idx == 1)
plot(x = CATE$size, y = CATE$mean,
     col=couleur3, lwd=2,
     type="l", ylim=c(-.2,.2), log="x",
     xlab="Number of observations (log scales)",
     ylab="CATE(2500,60)")
polygon(c(CATE$size, rev(CATE$size)), c(CATE$q1_05, rev(CATE$q1_95)),
        col=scales::alpha(couleur1,.3),border=NA)
lines(CATE$size, CATE$q1_05, col = couleur1)
lines(CATE$size, CATE$q1_95, col = couleur1)

CATE <- cate_sim_sex_df %>% filter(idx == 2)
plot(x = CATE$size, y = CATE$mean,
     col=couleur3, lwd=2,
     type="l", ylim=c(-.2,.2), log="x",
     xlab = "Number of observations (log scales)",
     ylab = "CATE(4200,60)")
polygon(c(CATE$size, rev(CATE$size)), c(CATE$q1_05, rev(CATE$q1_95)),
        col = scales::alpha(couleur1,.3), border = NA)
lines(CATE$size, CATE$q1_05, col = couleur1)
lines(CATE$size, CATE$q1_95, col = couleur1)

CATE <- cate_sim_sex_df %>% filter(idx == 3)
plot(x = CATE$size, y = CATE$mean,
     col = couleur3, lwd=2,
     type="l", ylim=c(-.2,.2), log="x",
     xlab = "Number of observations (log scales)",
     ylab = "CATE(2500,20)")
polygon(c(CATE$size, rev(CATE$size)), c(CATE$q1_05, rev(CATE$q1_95)),
        col = scales::alpha(couleur1,.3), border = NA)
lines(CATE$size, CATE$q1_05, col = couleur1)
lines(CATE$size, CATE$q1_95, col = couleur1)

Figure 4. Estimation of the Mutatis Mutandis CATE at \(\boldsymbol{x}=(2500,60)\), \(\boldsymbol{x}=(4200,60)\), and \(\boldsymbol{x}=(2500,20)\) (from left to right), using a GAM model with Gaussian transport, on n observations. Y=1 indicates a non-natural delivery, X gathers the weight of the newborn infant and the weight gain of the mother, and T indicates whether the newborn is a girl.