Optimal Transport for Counterfactual Estimation: A Method for Causal Inference

Online Appendix: reproduction of the results (univariate case)

Authors

Arthur Charpentier

Emmanuel Flachaire

Ewen Gallic

This online appendix provides R code to apply the methods presented in the companion paper (Charpentier, Flachaire, and Gallic (2023)).

1 Data

To illustrate the methods, we rely on the Linked Birth/Infant Death Cohort Data (2013 cohort). The CSV files for the 2013 cohort were downloaded from the NBER collection of Birth Cohort Linked Birth and Infant Death Data of the National Vital Statistics System of the National Center for Health Statistics, on the NBER website. We formatted the data (see ./births_stats.html).

Let us load the birth data.

library(tidyverse)
load("../data/births.RData")

We then keep only a subset of the variables to illustrate the method.

base <- 
  births %>% 
  mutate(
    black_mother = mracerec == "Black",
    nonnatural_delivery = rdmeth_rec != "Vaginal"
  ) %>% 
  select(
    sex, dbwt, cig_rec, wtgain, 
    black_mother, nonnatural_delivery,
    mracerec, rdmeth_rec
  ) %>% 
  mutate(
    cig_rec = replace_na(cig_rec, "Unknown or not stated"),
    black_mother = ifelse(black_mother, yes = "Yes", no = "No"),
    is_girl = ifelse(sex == "Female", yes = "Yes", no = "No")
  ) %>% 
  labelled::set_variable_labels(
    black_mother = "Is the mother Black?",
    is_girl = "Is the newborn a girl?",
    nonnatural_delivery = "Is the delivery method non-natural?"
  ) %>% 
  rename(birth_weight = dbwt)
base
# A tibble: 3,940,764 × 9
   sex    birth_weight cig_rec wtgain black_mo…¹ nonna…² mrace…³ rdmet…⁴ is_girl
   <fct>         <dbl> <fct>    <dbl> <chr>      <lgl>   <fct>   <fct>   <chr>  
 1 Male           3017 No           1 No         FALSE   "White" Vaginal No     
 2 Female         3367 No          18 No         FALSE   "White" Vaginal Yes    
 3 Male           4409 No          36 No         FALSE   "White" Vaginal No     
 4 Male           4047 No          25 No         FALSE   "White" Vaginal No     
 5 Male           3119 Yes          0 No         FALSE   "Ameri… Vaginal No     
 6 Female         4295 Yes         34 No         FALSE   "White" Vaginal Yes    
 7 Male           3160 No          47 No         FALSE   "Ameri… Vaginal No     
 8 Male           3711 No          39 No         FALSE   "White" Vaginal No     
 9 Male           3430 No          NA No         FALSE   "White" Vaginal No     
10 Female         3772 No          36 No         FALSE   "White" Vaginal Yes    
# … with 3,940,754 more rows, and abbreviated variable names ¹​black_mother,
#   ²​nonnatural_delivery, ³​mracerec, ⁴​rdmeth_rec

The variables that were kept are the following:

  • sex: sex of the newborn
  • birth_weight: birth weight (in grams)
  • cig_rec: whether the mother smokes cigarettes
  • wtgain: mother’s weight gain
  • mracerec: mother’s race
  • black_mother: is the mother Black?
  • rdmeth_rec: delivery method
  • nonnatural_delivery: is the delivery method non-natural?
  • is_girl: is the newborn a girl?

Then, let us discard individuals with missing values.

base <- 
  base %>% 
  filter(
    !is.na(birth_weight),
    !is.na(wtgain),
    !is.na(nonnatural_delivery),
    !is.na(black_mother)
  )

Let us define some colours for later use:

library(wesanderson)
colr1 <- wes_palette("Darjeeling1")
colr2 <- wes_palette("Darjeeling2")
couleur1 <- colr1[2]
couleur2 <- colr1[4]
couleur3 <- colr2[2]

coul1 <- "#882255"
coul2 <- "#DDCC77"

We need to load some packages.

library(car)
library(transport)
library(splines)
library(mgcv)
library(expm)
library(tidyverse)

2 Objectives

We would like to measure the effect of a binary treatment variable \(T\) on the probability of observing a binary outcome \(y\). The outcome is assumed to depend on another variable \(\boldsymbol{x}^m\) that is also influenced by the treatment; this variable is called a mediator.

  • Outcome (\(y\), nonnatural_delivery): probability of having a non-natural delivery

  • Treatment (\(T\)), either:

    • cig_rec: Whether the mother smokes \(\color{couleur2}{t=1}\) or not \(\color{couleur1}{t=0}\)
    • black_mother: Whether the mother is Black \(\color{couleur2}{t=1}\) or not \(\color{couleur1}{t=0}\)
    • sex: Whether the baby is a girl \(\color{couleur2}{t=1}\) or not \(\color{couleur1}{t=0}\)
  • Mediator (\(\boldsymbol{x}^m\)) is either:

    • birth_weight: Birth weight of the newborn
    • wtgain: weight gain of the mother.

Sample Conditional Average Treatment Effect: \(\text{SCATE}(\boldsymbol{x})\)

Consider two models, \(\color{couleur1}{\widehat{m}_0(x)}\) and \(\color{couleur2}{\widehat{m}_1(x)}\), that estimate, respectively, \(\color{couleur1}{\mathbb{E}[Y|X=x,T=0]}\) and \(\color{couleur2}{\mathbb{E}[Y|X=x,T=1]}\), \[ \text{SCATE}(x)=\color{couleur2}{\widehat{m}_1}\big(\widehat{\mathcal{T}}(x)\big)\color{black}{} - \color{couleur1}{\widehat{m}_0}\big(x\big) \] where \(\widehat{\mathcal{T}}(\cdot)\) is a transport function:

  • if we consider a quantile-based matching, \(\widehat{\mathcal{T}}(x)= \color{couleur2}{\widehat{F}_1^{-1}}\color{black}{}\circ \color{couleur1}{\widehat{F}_0}(x)\), with \(\color{couleur1}{\widehat{F}_0}\) and \(\color{couleur2}{\widehat{F}_1}\) denoting the empirical distribution functions of \(x\) conditional on \(\color{couleur1}{t=0}\) and \(\color{couleur2}{t=1}\), respectively.
  • assuming a Gaussian distribution of the mediator variable, we can consider \(\widehat{\mathcal{T}}(x) := \widehat{\mathcal{T}}_{\mathcal{N}}(x)= \color{couleur2}{\overline{x}_1}\color{black}{}+\color{couleur2}{s_1}\color{couleur1}{s_0^{-1}}\color{black}{} (x-\color{couleur1}{\overline{x}_0}\color{black}{})\), \(\color{couleur1}{\overline{x}_0}\) and \(\color{couleur2}{\overline{x}_1}\) being respectively the averages of \(x\) in the two sub-populations, and \(\color{couleur1}{s_0}\) and \(\color{couleur2}{s_1}\) the sample standard deviations.
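To make the Gaussian transport map concrete, here is a minimal sketch on simulated data. The samples `x_ctrl` and `x_treat` and the helper `transport_gauss()` are hypothetical names introduced for this illustration only; they are not part of the appendix code or of the births data.

```r
# Hypothetical simulated mediator values (not the births data)
set.seed(123)
x_ctrl  <- rnorm(5000, mean = 3300, sd = 500)  # mediator when t = 0
x_treat <- rnorm(5000, mean = 3100, sd = 550)  # mediator when t = 1

# Gaussian transport: xbar_1 + (s_1 / s_0) * (x - xbar_0)
transport_gauss <- function(x, x_ctrl, x_treat) {
  mean(x_treat) + sd(x_treat) / sd(x_ctrl) * (x - mean(x_ctrl))
}

# Transporting the whole control sample: since the map is affine, the
# transported sample has the mean and standard deviation of the treated sample
x_ctrl_t <- transport_gauss(x_ctrl, x_ctrl, x_treat)
rbind(
  transported = c(mean = mean(x_ctrl_t), sd = sd(x_ctrl_t)),
  treated     = c(mean = mean(x_treat),  sd = sd(x_treat))
)

# Transported values at a few points of interest
transport_gauss(seq(2500, 4000, by = 500), x_ctrl, x_treat)
```

Because the map only matches the first two moments, it should be expected to agree with the quantile-based map only when both sub-populations are close to Gaussian.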

What is Needed

To estimate the mutatis mutandis Sample Conditional Average Treatment Effect, the following are required:

  1. The two estimated models \(\color{couleur1}{\widehat{m}_0(x)}\) and \(\color{couleur2}{\widehat{m}_1(x)}\).
  2. A function to predict new values with each of these two models.
  3. A transport method \(\mathcal{T}(\cdot)\).
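The three ingredients listed above can be put together as in the following sketch, which runs on simulated data rather than on the births data. The data-generating process, the data frame `sim` and the helpers `predict_m()` and `transport_q()` are assumptions made for this illustration; the functions actually used in this appendix are defined in the next sections.

```r
library(splines)

# Hypothetical simulated data: binary treatment, mediator x, binary outcome y
set.seed(1)
n <- 4000
treat <- rbinom(n, 1, 0.5)
x <- rnorm(n, mean = ifelse(treat == 1, 3100, 3300), sd = 500)
y <- rbinom(n, 1, plogis(-1 + 0.5 * treat - 0.0005 * (x - 3300)))
sim <- data.frame(y = y, treat = treat, x = x)

# 1. Two models, one per treatment group
m_0 <- glm(y ~ bs(x, df = 3), family = binomial, data = sim, subset = treat == 0)
m_1 <- glm(y ~ bs(x, df = 3), family = binomial, data = sim, subset = treat == 1)

# 2. A prediction function returning probabilities
predict_m <- function(object, x) {
  unname(predict(object, newdata = data.frame(x = x), type = "response"))
}

# 3. A transport method (here, quantile-based)
transport_q <- function(x, data) {
  u <- ecdf(data$x[data$treat == 0])(x)
  unname(quantile(data$x[data$treat == 1], probs = u))
}

# SCATE(x) = m_1(T(x)) - m_0(x), evaluated at a few points
x_grid <- c(2800, 3300, 3800)
scate <- predict_m(m_1, transport_q(x_grid, sim)) - predict_m(m_0, x_grid)
data.frame(x = x_grid, x_transported = transport_q(x_grid, sim), scate = scate)
```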

We will compute the SCATE at the different values of \(\boldsymbol{x}\):

x_0_birth_weight <- seq(2000,4500, by=500)
x_0_weight_gain <- seq(5,55, by=10)

3 Models

We will consider here different models:

  1. GAM
  2. Local averages computed using kernels

For each model, we need to define two functions: a first one that estimates the model, and a second one that makes predictions based on the estimated model.

3.1 GAM (with cubic splines)

The following function estimates both models \(\color{couleur1}{\hat{m}_0}\color{black}{}()\) and \(\color{couleur2}{\hat{m}_1}\color{black}{}()\):

library(splines)

#' Returns $\hat{m}_0()$ and $\hat{m}_1()$, using GAM
#' @param target name of the target variable
#' @param treatment_name name of the treatment variable (column in `data`)
#' @param x_m_name name of the mediator variable
#' @param data data frame with the observations
#' @param treatment_0 value for non treated
#' @param treatment_1 value for treated
#' @param df degrees of freedom (in the `bs()` function, defaults to `df = 3`)
models_spline <- function(target, treatment_name, x_m_name, data, treatment_0, treatment_1, df = 3){
  # \hat{m}_0()
  reg_0 <-
    bquote(glm(.(target)~bs(.(x_m_name), df = df), data=data,
               family=binomial, subset = (.(treatment_name) == .(treatment_0))),
           list(target = as.name(target), x_m_name=as.name(x_m_name),
                treatment_name = as.name(treatment_name),
                treatment_0 = treatment_0)) %>% 
    eval()
  # \hat{m}_1()
  reg_1 <- bquote(glm(.(target)~bs(.(x_m_name), df = df), data=data,
                      family=binomial, subset = (.(treatment_name) == .(treatment_1))),
                  list(target = as.name(target), x_m_name=as.name(x_m_name),
                       treatment_name = as.name(treatment_name),
                       treatment_1 = treatment_1)) %>% 
    eval()
  
  list(reg_0 = reg_0, reg_1 = reg_1)
}
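The function above receives variable names as character strings and turns them into a model call with `bquote()`. The following toy sketch isolates that step; the data frame `toy` and the object names are hypothetical and only serve to show how the call is assembled before being evaluated.

```r
# Variable names given as strings, as in the arguments of models_spline()
target   <- "y"
x_m_name <- "x"

# .() placeholders are replaced by the symbols supplied in the list
call_0 <- bquote(
  glm(.(target) ~ .(x_m_name), data = toy, family = binomial),
  list(target = as.name(target), x_m_name = as.name(x_m_name))
)
print(call_0)

# Hypothetical toy data on which the call is then evaluated
toy <- data.frame(y = c(0, 1, 0, 1, 1, 0), x = 1:6)
fit <- eval(call_0)
formula(fit)
```

Building the call this way keeps readable variable names in the fitted object, which is convenient when `predict()` later looks up the mediator by name in `newdata`.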

Let us estimate these models.

# Treatment: mother smokes; mediator: birth weight
target         <- "nonnatural_delivery"
treatment_name <- "cig_rec"
x_m_name       <- "birth_weight"
treatment_0    <- "No"
treatment_1    <- "Yes"

reg_univ_smoker_bw <- 
  models_spline(target = target,
              treatment_name = treatment_name, 
              x_m_name = x_m_name,
              data = base, treatment_0 = treatment_0, treatment_1 = treatment_1, df = 3)

# Treatment: mother smokes; mediator: mother's weight gain
target         <- "nonnatural_delivery"
treatment_name <- "cig_rec"
x_m_name       <- "wtgain"
treatment_0    <- "No"
treatment_1    <- "Yes"

reg_univ_smoker_wg <- 
  models_spline(target = target,
              treatment_name = treatment_name, 
              x_m_name = x_m_name,
              data = base, treatment_0 = treatment_0, treatment_1 = treatment_1, df = 3)

# Treatment: mother is Black; mediator: birth weight
target         <- "nonnatural_delivery"
treatment_name <- "black_mother"
x_m_name       <- "birth_weight"
treatment_0    <- "No"
treatment_1    <- "Yes"

reg_univ_blackm_bw <- 
  models_spline(target = target,
              treatment_name = treatment_name, 
              x_m_name = x_m_name,
              data = base, treatment_0 = treatment_0, treatment_1 = treatment_1, df = 3)

# Treatment: mother is Black; mediator: mother's weight gain
target         <- "nonnatural_delivery"
treatment_name <- "black_mother"
x_m_name       <- "wtgain"
treatment_0    <- "No"
treatment_1    <- "Yes"

reg_univ_blackm_wg <- 
  models_spline(target = target,
              treatment_name = treatment_name, 
              x_m_name = x_m_name,
              data = base, treatment_0 = treatment_0, treatment_1 = treatment_1, df = 3)

# Treatment: newborn is a girl; mediator: birth weight
target         <- "nonnatural_delivery"
treatment_name <- "sex"
x_m_name       <- "birth_weight"
treatment_0    <- "Male"
treatment_1    <- "Female"

reg_univ_sex_bw <- 
  models_spline(target = target,
              treatment_name = treatment_name, 
              x_m_name = x_m_name,
              data = base, treatment_0 = treatment_0, treatment_1 = treatment_1, df = 3)

# Treatment: newborn is a girl; mediator: mother's weight gain
target         <- "nonnatural_delivery"
treatment_name <- "sex"
x_m_name       <- "wtgain"
treatment_0    <- "Male"
treatment_1    <- "Female"

reg_univ_sex_wg <- 
  models_spline(target = target,
              treatment_name = treatment_name, 
              x_m_name = x_m_name,
              data = base, treatment_0 = treatment_0, treatment_1 = treatment_1, df = 3)

3.2 Kernels

Let us define a function to compute the local average of a single observation:

#' @param target name of the response variable
#' @param x_m_name name of the mediator variable
#' @param x single value which will be transported
#' @param h bandwidth
#' @param data data frame, subset of data with t=0 or t=1
pred_kernel_single <- function(target, x_m_name, x, h=10, data){
  w <- dnorm(pull(data, x_m_name), x, h)
  sum(w*pull(data, target))/sum(w)
}
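To see what this local average does, consider the following toy sketch. The helper `local_average()` and the vectors `x_obs` and `y_obs` are hypothetical; the helper mirrors the weighting used above but takes plain vectors so that it runs on its own. Note that the bandwidth `h` is expressed in the units of the mediator.

```r
# Gaussian-kernel local average of y_obs around a single point x
local_average <- function(x, x_obs, y_obs, h) {
  w <- dnorm(x_obs, mean = x, sd = h)  # weights decrease with the distance to x
  sum(w * y_obs) / sum(w)
}

# Hypothetical toy data
x_obs <- c(1, 2, 3, 4, 5)
y_obs <- c(0, 0, 1, 1, 1)

# Small bandwidth: the estimate is driven by observations close to x = 3
local_average(3, x_obs, y_obs, h = 0.5)

# Very large bandwidth: weights are almost equal, so the estimate
# is close to the overall mean of y_obs
local_average(3, x_obs, y_obs, h = 100)
mean(y_obs)
```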

# Treatment: mother smokes; mediator: birth weight
target         <- "nonnatural_delivery"
treatment_name <- "cig_rec"
x_m_name       <- "birth_weight"
treatment_0    <- "No"
treatment_1    <- "Yes"

mod_0_k_smoker_bw <- list(target = target, x_m_name = x_m_name, h= 50,
                     data = base[pull(base, treatment_name) == treatment_0,])
mod_1_k_smoker_bw <- list(target = target, x_m_name = x_m_name, h= 50, 
                     data = base[pull(base, treatment_name) == treatment_1,])

# Treatment: mother smokes; mediator: mother's weight gain
target         <- "nonnatural_delivery"
treatment_name <- "cig_rec"
x_m_name       <- "wtgain"
treatment_0    <- "No"
treatment_1    <- "Yes"

mod_0_k_smoker_wg <- list(target = target, x_m_name = x_m_name, h = 50,
                     data = base[pull(base, treatment_name) == treatment_0,])
mod_1_k_smoker_wg <- list(target = target, x_m_name = x_m_name, h = 50, 
                     data = base[pull(base, treatment_name) == treatment_1,])

# Treatment: mother is Black; mediator: birth weight
target         <- "nonnatural_delivery"
treatment_name <- "black_mother"
x_m_name       <- "birth_weight"
treatment_0    <- "No"
treatment_1    <- "Yes"

mod_0_k_blackm_bw <- list(target = target, x_m_name = x_m_name, h= 50,
                     data = base[pull(base, treatment_name) == treatment_0,])
mod_1_k_blackm_bw <- list(target = target, x_m_name = x_m_name, h= 50, 
                     data = base[pull(base, treatment_name) == treatment_1,])

# Treatment: mother is Black; mediator: mother's weight gain
target         <- "nonnatural_delivery"
treatment_name <- "black_mother"
x_m_name       <- "wtgain"
treatment_0    <- "No"
treatment_1    <- "Yes"

mod_0_k_blackm_wg <- list(target = target, x_m_name = x_m_name, h= 50,
                     data = base[pull(base, treatment_name) == treatment_0,])
mod_1_k_blackm_wg <- list(target = target, x_m_name = x_m_name, h= 50, 
                     data = base[pull(base, treatment_name) == treatment_1,])

# Treatment: newborn is a girl; mediator: birth weight
target         <- "nonnatural_delivery"
treatment_name <- "sex"
x_m_name       <- "birth_weight"
treatment_0    <- "Male"
treatment_1    <- "Female"

mod_0_k_sex_bw <- list(target = target, x_m_name = x_m_name, h= 50,
                     data = base[pull(base, treatment_name) == treatment_0,])
mod_1_k_sex_bw <- list(target = target, x_m_name = x_m_name, h= 50, 
                     data = base[pull(base, treatment_name) == treatment_1,])

# Treatment: newborn is a girl; mediator: mother's weight gain
target         <- "nonnatural_delivery"
treatment_name <- "sex"
x_m_name       <- "wtgain"
treatment_0    <- "Male"
treatment_1    <- "Female"

mod_0_k_sex_wg <- list(target = target, x_m_name = x_m_name, h= 50,
                     data = base[pull(base, treatment_name) == treatment_0,])
mod_1_k_sex_wg <- list(target = target, x_m_name = x_m_name, h= 50, 
                     data = base[pull(base, treatment_name) == treatment_1,])

4 Prediction Function

4.1 GAM

The prediction function for GAM:

#' @param object regression model (GAM)
#' @param newdata data frame in which to look for the mediator variable used to predict the target
model_spline_predict <- function(object, newdata){
  predict(object, newdata = newdata, type="response")
}

4.2 Kernels

The prediction function below relies on `pred_kernel_single()`, defined in Section 3.2, to compute the local average at each new point:

#' @param object list with elements `target`, `x_m_name`, `h` and `data` (see Section 3.2)
#' @param newdata data frame whose first column contains the values of the mediator variable used to predict the target
pred_kernel <- function(object, newdata){
  target <- object$target
  x_m_name <- object$x_m_name
  h <- object$h
  data <- object$data
  pull(newdata, 1) %>% 
    map_dbl(~pred_kernel_single(target = target, x_m_name = x_m_name, x = ., h = h,
                                data = data))
}

5 Transport

Let us now turn to the transport of the mediator variable.

5.1 Quantile-based

We define a function that transports the values of a mediator from the non-treated group to the treated group.

#' Quantile-based transport (Univariate)
#' @param x_0 vector of values to be transported
#' @param x_m_name name of the mediator variable
#' @param treatment_name name of the treatment variable (column in `data`)
#' @param treatment_0 value for non treated
#' @param treatment_1 value for treated
#' @param data data frame with both T=0 and T=1
transport_quantile <- function(x_0, x_m_name, treatment_name, treatment_0, treatment_1, data){
  # Values of the mediator variable among the non-treated
  ind_0 <- pull(data, treatment_name) == treatment_0
  x_val <- pull(data, x_m_name)[ind_0]
  # Empirical Cumulative Distribution Function of the mediator variable for non-treated
  Fn <- ecdf(x_val)
  # Probability associated in the non-treated
  u <- Fn(x_0)
  # Values of the mediator variable among the treated
  x_1 <- pull(data, x_m_name)[pull(data, treatment_name) == treatment_1]
  # Transported values: quantiles of level u among the treated
  x_t_quantile <- quantile(x_1, u)
  list(x_0 = x_0, u = u, x_t_quantile = x_t_quantile)
}
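A quantile-based transport map should send the distribution of the mediator in the non-treated group onto its distribution in the treated group. The sketch below checks this property on simulated, deliberately non-Gaussian data. The samples and the helper `transport_q()` are hypothetical and work on plain vectors, so the block runs independently of the births data.

```r
# Hypothetical skewed mediator values (not the births data)
set.seed(42)
x_ctrl  <- rexp(2000, rate = 1)        # mediator when t = 0
x_treat <- rexp(2000, rate = 0.5) + 1  # mediator when t = 1

# Quantile-based transport: F_1^{-1}(F_0(x))
transport_q <- function(x, x_ctrl, x_treat) {
  u <- ecdf(x_ctrl)(x)                  # level of x among the non-treated
  unname(quantile(x_treat, probs = u))  # same level among the treated
}

# Transport the whole control sample
x_ctrl_t <- transport_q(x_ctrl, x_ctrl, x_treat)

# The transported sample and the treated sample should have
# nearly identical quantiles
probs <- c(0.1, 0.25, 0.5, 0.75, 0.9)
rbind(
  treated     = quantile(x_treat, probs),
  transported = quantile(x_ctrl_t, probs)
)
```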

The transported values of \(\boldsymbol{x}\), i.e., \(\mathcal{T}(\boldsymbol{x})\):

# Treatment: mother smokes; mediator: birth weight
target         <- "nonnatural_delivery"
treatment_name <- "cig_rec"
x_m_name       <- "birth_weight"
treatment_0    <- "No"
treatment_1    <- "Yes"

# Only at a few points
trans_q_smoker_bw <- 
  transport_quantile(x_0 = x_0_birth_weight, x_m_name = x_m_name, 
                     treatment_name = treatment_name, 
                     treatment_0 = treatment_0, treatment_1 = treatment_1, data = base)

# At more points
probs <- seq(0,1,length=601)
# Quantiles for non treated
x_0_probs_smoker_bw <- quantile(pull(base, x_m_name)[pull(base, treatment_name) == treatment_0], probs = probs)
# Transported values
trans_q_smoker_bw_probs <- 
  transport_quantile(x_0 = x_0_probs_smoker_bw, x_m_name = x_m_name, 
                     treatment_name = treatment_name, 
                     treatment_0 = treatment_0, treatment_1 = treatment_1, data = base)

# Treatment: mother smokes; mediator: mother's weight gain
target         <- "nonnatural_delivery"
treatment_name <- "cig_rec"
x_m_name       <- "wtgain"
treatment_0    <- "No"
treatment_1    <- "Yes"

# Only at a few points
trans_q_smoker_wg <- 
  transport_quantile(x_0 = x_0_weight_gain, x_m_name = x_m_name, 
                     treatment_name = treatment_name, 
                     treatment_0 = treatment_0, treatment_1 = treatment_1, data = base)

# At more points
probs <- seq(0,1,length=601)
# Quantiles for non treated
x_0_probs_smoker_wg <- quantile(pull(base, x_m_name)[pull(base, treatment_name) == treatment_0], probs = probs)
# Transported values
trans_q_smoker_wg_probs <- 
  transport_quantile(x_0 = x_0_probs_smoker_wg, x_m_name = x_m_name, 
                     treatment_name = treatment_name, 
                     treatment_0 = treatment_0, treatment_1 = treatment_1, data = base)

# Treatment: mother is Black; mediator: birth weight
target         <- "nonnatural_delivery"
treatment_name <- "black_mother"
x_m_name       <- "birth_weight"
treatment_0    <- "No"
treatment_1    <- "Yes"

# Only at a few points
trans_q_blackm_bw <- 
  transport_quantile(x_0 = x_0_birth_weight, x_m_name = x_m_name, 
                     treatment_name = treatment_name, 
                     treatment_0 = treatment_0, treatment_1 = treatment_1, data = base)

# At more points
probs <- seq(0,1,length=601)
# Quantiles for non treated
x_0_probs_blackm_bw <- quantile(pull(base, x_m_name)[pull(base, treatment_name) == treatment_0], probs = probs)
# Transported values
trans_q_blackm_bw_probs <- 
  transport_quantile(x_0 = x_0_probs_blackm_bw, x_m_name = x_m_name, 
                     treatment_name = treatment_name, 
                     treatment_0 = treatment_0, treatment_1 = treatment_1, data = base)

# Treatment: mother is Black; mediator: mother's weight gain
target         <- "nonnatural_delivery"
treatment_name <- "black_mother"
x_m_name       <- "wtgain"
treatment_0    <- "No"
treatment_1    <- "Yes"

# Only at a few points
trans_q_blackm_wg <- 
  transport_quantile(x_0 = x_0_weight_gain, x_m_name = x_m_name, 
                     treatment_name = treatment_name, 
                     treatment_0 = treatment_0, treatment_1 = treatment_1, data = base)

# At more points
probs <- seq(0,1,length=601)
# Quantiles for non treated
x_0_probs_blackm_wg <- quantile(pull(base, x_m_name)[pull(base, treatment_name) == treatment_0], probs = probs)
# Transported values
trans_q_blackm_wg_probs <- 
  transport_quantile(x_0 = x_0_probs_blackm_wg, x_m_name = x_m_name, 
                     treatment_name = treatment_name, 
                     treatment_0 = treatment_0, treatment_1 = treatment_1, data = base)

# Treatment: newborn is a girl; mediator: birth weight
target         <- "nonnatural_delivery"
treatment_name <- "sex"
x_m_name       <- "birth_weight"
treatment_0    <- "Male"
treatment_1    <- "Female"

# Only at a few points
trans_q_sex_bw <- 
  transport_quantile(x_0 = x_0_birth_weight, x_m_name = x_m_name, 
                     treatment_name = treatment_name, 
                     treatment_0 = treatment_0, treatment_1 = treatment_1, data = base)

# At more points
probs <- seq(0,1,length=601)
# Quantiles for non treated
x_0_probs_sex_bw <- quantile(pull(base, x_m_name)[pull(base, treatment_name) == treatment_0], probs = probs)
# Transported values
trans_q_sex_bw_probs <- 
  transport_quantile(x_0 = x_0_probs_sex_bw, x_m_name = x_m_name, 
                     treatment_name = treatment_name, 
                     treatment_0 = treatment_0, treatment_1 = treatment_1, data = base)

# Treatment: newborn is a girl; mediator: mother's weight gain
target         <- "nonnatural_delivery"
treatment_name <- "sex"
x_m_name       <- "wtgain"
treatment_0    <- "Male"
treatment_1    <- "Female"

# Only at a few points
trans_q_sex_wg <- 
  transport_quantile(x_0 = x_0_weight_gain, x_m_name = x_m_name, 
                     treatment_name = treatment_name, 
                     treatment_0 = treatment_0, treatment_1 = treatment_1, data = base)

# At more points
probs <- seq(0,1,length=601)
# Quantiles for non treated
x_0_probs_sex_wg <- quantile(pull(base, x_m_name)[pull(base, treatment_name) == treatment_0], probs = probs)
# Transported values
trans_q_sex_wg_probs <- 
  transport_quantile(x_0 = x_0_probs_sex_wg, x_m_name = x_m_name, 
                     treatment_name = treatment_name, 
                     treatment_0 = treatment_0, treatment_1 = treatment_1, data = base)

We can visualize the transported values from the control group to the treated group. To do so, let us define a function.

#' @param x_0 vector of values
#' @param x_0_t vector of transported values
#' @param data data frame with the observations
#' @param treatment_name name of the treatment variable (column in `data`)
#' @param treatment_0 value for non treated
#' @param treatment_1 value for treated
#' @param x_m_name name of the mediator variable
#' @param distrib_plot_type type of visualization for the distribution in each subset ("density", "histogram", "none")
#' @param breaks breaks for histogram (if `distrib_plot_type` is not "none")
#' @param xlim limits on the x-axis
#' @param x_c_label label of the mediator variable
#' @param treated_label label of the treated variable
#' @param non_treated_label label of the non-treated variable
plot_transport_univariate <- function(x_0, x_0_t, data,
         treatment_name, treatment_0, treatment_1, x_m_name,
         distrib_plot_type = c("density", "histogram", "none"),
         breaks = NULL,
         xlim, x_c_label, treated_label, non_treated_label){

  distrib_plot_type <- match.arg(distrib_plot_type)

  # Subset of non treated
  d_0 <- filter(data, !!sym(treatment_name) == treatment_0)
  # Subset of treated
  d_1 <- filter(data, !!sym(treatment_name) == treatment_1)

  if(distrib_plot_type == "density"){
    # Density for non-treated
    dens_0 <- density(pull(d_0, x_m_name))
    # Density for treated
    dens_1 <- density(pull(d_1, x_m_name))
  }else if(distrib_plot_type == "histogram"){
    hist_0 <- hist(pull(d_0, x_m_name), breaks = breaks, plot = FALSE)
    hist_1 <- hist(pull(d_1, x_m_name), breaks = breaks, plot = FALSE)
  }

  if(distrib_plot_type != "none"){
    mat <- matrix(c(1,2,0,3), 2)
    par(mfrow = c(2,2))
    layout(mat, c(3.5,1), c(1,3))
    par(mar = c(0.5, 4.5, 0.5, 0.5))
  }

  if(distrib_plot_type == "density"){
    # Density of the mediator variable on the subset of non treated (green)
    plot(dens_0$x, dens_0$y,
         main="", axes=FALSE, xlab="", ylab="",
         xlim = xlim, col="white")
    polygon(dens_0$x, dens_0$y, col = couleur1, border = NA)
  }else if(distrib_plot_type == "histogram"){
    # Histogram of the mediator variable on the subset of non treated (green)
    barplot(hist_0$counts[hist_0$breaks >= min(xlim) & hist_0$breaks <= max(xlim)],
            width=5, space=0, main="",
            axes=FALSE,xlab="", ylab="",
            xlim=xlim, col=couleur1, border="white")

  }
  
  par(mar=c(4.5, 4.5, 0.5, 0.5))
  # Optimal transport (quantile based)
  plot(x_0, x_0_t,
       col = couleur3, lwd = 2,
       type = "l",
       xlab = "",
       ylab = "",
       xlim = xlim, ylim = xlim)
  abline(a = 0, b = 1, col = couleur3, lty = 2)
  
  ylab_1 <- bquote(.(x_c_label) * " (" *phantom(.(non_treated_label)) * ")")
  ylab_2 <- bquote(phantom(.(x_c_label)) * phantom(" (") *.(non_treated_label) * phantom(")"))
  mtext(ylab_1, side=1,line=3, col = "black")
  mtext(ylab_2, side=1,line=3, col = couleur1)
  
  xlab_1 <- bquote(.(x_c_label) * " (" *phantom(.(treated_label)) * ")")
  xlab_2 <- bquote(phantom(.(x_c_label)) * phantom(" (") *.(treated_label) *phantom(")"))
  mtext(xlab_1, side=2,line=3, col = "black")
  mtext(xlab_2, side=2,line=3, col = couleur2)

  if(distrib_plot_type == "density"){
    # Density of the mediator variable on the subset of treated (orange)
    par(mar = c(4.5, 0.5, 0.5, 0.5))
    plot(dens_1$y, dens_1$x,
         main="", axes=FALSE, xlab="", ylab="",
         ylim = xlim, col="white")
    polygon(dens_1$y, dens_1$x, col = couleur2, border = NA)
  }else if(distrib_plot_type == "histogram"){
    # Histogram of the mediator variable on the subset of treated (orange)
    par(mar = c(4.5, 0.5, 0.5, 0.5))
    barplot(hist_1$counts[hist_1$breaks >= min(xlim) & hist_1$breaks <= max(xlim)],
            width=5, space=0, main="",
            axes=FALSE,xlab="", ylab="",
            ylim=xlim, col=couleur2, border="white", horiz=TRUE)
  }
}

We can then plot the transport map for birth weight when the treatment is the mother's smoking status.

plot_transport_univariate(x_0 = trans_q_smoker_bw_probs$x_0, 
                          x_0_t = trans_q_smoker_bw_probs$x_t_quantile, 
                          data = base, 
                          treatment_name = "cig_rec", 
                          treatment_0 = "No", treatment_1 = "Yes", 
                          x_m_name = "birth_weight", distrib_plot_type = "density",
                          breaks = NULL, xlim = c(1800,4600), 
                          x_c_label = "Weight of the baby",
                          treated_label = "Smoker mother", non_treated_label = "non-smoker mother")

Figure 1. Transported newborn weights from the control group (non-smoking mothers) to the treated group (smoking mothers), with estimated densities