\[ \usepackage{dsfont} \usepackage{xcolor} \require{mathtools} \definecolor{bayesred}{RGB}{147, 30, 24} \definecolor{bayesblue}{RGB}{32, 35, 91} \definecolor{bayesorange}{RGB}{218, 120, 1} \definecolor{grey}{RGB}{128, 128, 128} \definecolor{couleur1}{RGB}{0,163,137} \definecolor{couleur2}{RGB}{255,124,0} \definecolor{couleur3}{RGB}{0, 110, 158} \definecolor{coul1}{RGB}{255,37,0} \definecolor{coul2}{RGB}{242,173,0} \definecolor{col_neg}{RGB}{155, 191, 221} \definecolor{col_pos}{RGB}{255, 128, 106} {\color{bayesorange} P (\text{H} \mid \text{E})} = \frac {{\color{bayesred} P(\text{H})} \times {\color{bayesblue}P(\text{E} \mid \text{H})}} {\color{grey} {P(\text{E})}} \]
1 Data
To illustrate the methods, we rely on the Linked Birth/Infant Death Cohort Data (2013 cohort). The CSV files for the 2013 cohort were downloaded from the NBER collection of Birth Cohort Linked Birth and Infant Death Data of the National Vital Statistics System of the National Center for Health Statistics, on the NBER website. We formatted the data (see ./births_stats.html).
Let us load the birth data.
library(tidyverse)
# Load the formatted births data saved by the data-preparation document;
# the `births` object used below is expected to come from this file.
load(file = file.path("..", "data", "births.RData"))
Then, we keep only a subset of the variables to illustrate the method.
# Build the analysis data set from `births`: derive binary indicators, keep
# only the variables used below, recode them and attach variable labels.
base <-
  births %>%
  mutate(
    black_mother = mracerec == "Black",
    nonnatural_delivery = rdmeth_rec != "Vaginal"
  ) %>%
  select(
    sex, dbwt, cig_rec, wtgain,
    black_mother, nonnatural_delivery,
    mracerec, rdmeth_rec
  ) %>%
  mutate(
    # A missing smoking status is kept as its own category
    cig_rec = replace_na(cig_rec, "Unknown or not stated"),
    black_mother = ifelse(black_mother, yes = "Yes", no = "No"),
    is_girl = ifelse(sex == "Female", yes = "Yes", no = "No")
  ) %>%
  labelled::set_variable_labels(
    black_mother = "Is the mother Black?",
    is_girl = "Is the newborn a girl?",
    nonnatural_delivery = "Is the delivery method non-natural?"
  ) %>%
  rename(birth_weight = dbwt)
base
# A tibble: 3,940,764 × 9
sex birth_weight cig_rec wtgain black_mo…¹ nonna…² mrace…³ rdmet…⁴ is_girl
<fct> <dbl> <fct> <dbl> <chr> <lgl> <fct> <fct> <chr>
1 Male 3017 No 1 No FALSE "White" Vaginal No
2 Female 3367 No 18 No FALSE "White" Vaginal Yes
3 Male 4409 No 36 No FALSE "White" Vaginal No
4 Male 4047 No 25 No FALSE "White" Vaginal No
5 Male 3119 Yes 0 No FALSE "Ameri… Vaginal No
6 Female 4295 Yes 34 No FALSE "White" Vaginal Yes
7 Male 3160 No 47 No FALSE "Ameri… Vaginal No
8 Male 3711 No 39 No FALSE "White" Vaginal No
9 Male 3430 No NA No FALSE "White" Vaginal No
10 Female 3772 No 36 No FALSE "White" Vaginal Yes
# … with 3,940,754 more rows, and abbreviated variable names ¹black_mother,
# ²nonnatural_delivery, ³mracerec, ⁴rdmeth_rec
The variables that were kept are the following:

- `sex`: sex of the newborn
- `birth_weight`: birth weight (in grams)
- `cig_rec`: whether the mother smokes cigarettes
- `wtgain`: mother's weight gain
- `mracerec`: mother's race
- `black_mother`: is the mother Black?
- `rdmeth_rec`: delivery method
- `nonnatural_delivery`: is the delivery method non-natural?
- `is_girl`: is the newborn a girl?
Then, let us discard individuals with missing values.
# Discard individuals with a missing value in any of the variables used in
# the analyses below.
base <-
  base %>%
  filter(
    !is.na(birth_weight),
    !is.na(wtgain),
    !is.na(nonnatural_delivery),
    !is.na(black_mother)
  )
Let us define some colours for later use:
library(wesanderson)
# Two {wesanderson} palettes from which individual colours are picked
colr1 <- wes_palette("Darjeeling1")
colr2 <- wes_palette("Darjeeling2")
# Named colours, presumably for later plots; in the text, couleur1 marks the
# non-treated group (t = 0) and couleur2 the treated group (t = 1)
couleur1 <- colr1[2]
couleur2 <- colr1[4]
couleur3 <- colr2[2]
coul1 <- "#882255"
coul2 <- "#DDCC77"
We need to load some packages.
library(car)
library(transport)
library(splines)
library(mgcv)
library(expm)
library(tidyverse)
2 Univariate Case
Let us first consider the univariate case.
2.1 Objective
The outcome is assumed to depend on the treatment and on a single mediator variable \(\boldsymbol{x}^m\) (the latter is also assumed to be influenced by the treatment).
| Variable | Name | Description |
|---|---|---|
| Output (\(y\)) | `nonnatural_delivery` | Probability of having a non-natural delivery |
| Treatment (\(T\)) | `cig_rec` | Whether the mother smokes \(\color{couleur2}{t=1}\) or not \(\color{couleur1}{t=0}\) |
| Mediator (\(\boldsymbol{x}^m\)) | `birth_weight` | Birth weight of the newborn |
Let us print some summary statistics of the variables considered.
library(gtsummary)
# Summary statistics of the outcome and the mediator, by smoking status
base %>%
  select(nonnatural_delivery, cig_rec, birth_weight) %>%
  tbl_summary(
    by = cig_rec,
    type = all_continuous() ~ "continuous2",
    statistic = list(
      all_continuous() ~ c("{mean} ({sd})", "{median} ({p25}, {p75})"),
      all_categorical() ~ "{n} ({p}%)"
    ),
    digits = list(
      all_continuous() ~ 2,
      all_categorical() ~ 0
    ),
    missing_text = "Missing value"
  ) %>%
  add_p() %>%
  add_overall(col_label = "Whole sample") %>%
  modify_header(label ~ "**Variable**") %>%
  modify_spanning_header(
    c("stat_1", "stat_2") ~ "**Mother is a smoker**"
  ) %>%
  add_stat_label(
    label = list(
      all_continuous() ~ c("Mean (Std)", "Median (IQR)"),
      all_categorical() ~ "n (%)"
    )
  )
| Variable | Whole sample | Mother is a smoker: Yes, N = 273,685 | Mother is a smoker: No, N = 2,959,847 | Unknown or not stated, N = 147,766 | p-value¹ |
|---|---|---|---|---|---|
| Is the delivery method non-natural?, n (%) | 1,159,776 (34%) | 94,580 (35%) | 1,013,995 (34%) | 51,201 (35%) | <0.001 |
| Birth Weight (in Grams) | | | | | <0.001 |
| Mean (Std) | 3,276.01 (588.22) | 3,102.59 (595.08) | 3,292.17 (583.95) | 3,273.51 (608.59) | |
| Median (IQR) | 3,317.00 (2,977.00, 3,635.00) | 3,147.00 (2,802.00, 3,472.00) | 3,330.00 (3,001.00, 3,657.00) | 3,317.00 (2,970.00, 3,657.00) | |

¹ Pearson's Chi-squared test; Kruskal-Wallis rank sum test
We will compute the SCATE at the following values of \(\boldsymbol{x}\):
# Birth weights (in grams) at which the SCATE is evaluated
x_0_birth_weight <- seq(from = 2000, to = 4500, by = 500)
x_0_birth_weight
[1] 2000 2500 3000 3500 4000 4500
2.2 Models
Let us estimate \(\color{couleur1}{\widehat{m}_0(x)}\) using a GAM-type model (a logistic regression on a B-spline basis of the mediator), on the subset of the non-smokers:
# m_0: logistic regression of the delivery method on a B-spline basis
# (splines::bs() defaults) of the birth weight, non-smokers only (t = 0)
reg_0 <- glm(
  nonnatural_delivery ~ bs(birth_weight),
  data = base, family = binomial, subset = (cig_rec == "No")
)
And \(\color{couleur2}{\widehat{m}_1(x)}\), on the subset of the smokers:
# m_1: same specification as m_0, fitted on the smokers only (t = 1)
reg_1 <- glm(
  nonnatural_delivery ~ bs(birth_weight),
  data = base, family = binomial, subset = (cig_rec == "Yes")
)
2.3 Prediction Function
Let us define a prediction function for these types of models:
#' Predictions of a fitted regression model on the response scale
#'
#' @param object regression model (here a binomial GLM with spline terms)
#' @param newdata data frame in which to look for the mediator variable used
#'   to predict the target
#' @return Predictions with `type = "response"` (probabilities for a binomial
#'   GLM), one per row of `newdata`
model_spline_predict <- function(object, newdata) {
  predict(object, newdata = newdata, type = "response")
}
2.4 Transport
Let us now turn to the transport of the mediator variable.
2.4.1 Quantile-based
We define a function to transport the values of a mediator.
#' Quantile-based transport (Univariate)
#'
#' Maps each value of `x_0` to the quantile, among the treated, of its
#' probability under the empirical CDF of the non-treated:
#' T(x) = F_1^{-1}(F_0(x)).
#'
#' @param x_0 vector of values to be transported
#' @param x_m_name name of the mediator variable
#' @param treatment_name name of the treatment variable (column in `data`)
#' @param treatment_0 value for non treated
#' @param treatment_1 value for treated
#' @param data data frame with both T=0 or T=1
#' @return A list with `x_0`, the probabilities `u` = F_0(x_0) and the
#'   transported values `x_t_quantile` (as returned by `quantile()`)
transport_quantile <- function(x_0, x_m_name, treatment_name, treatment_0,
                               treatment_1, data) {
  treatment <- data[[treatment_name]]
  x_val <- data[[x_m_name]]
  # which() skips rows whose treatment is NA instead of producing NA elements
  ind_0 <- which(treatment == treatment_0)
  ind_1 <- which(treatment == treatment_1)
  if (length(ind_0) == 0L || length(ind_1) == 0L) {
    stop("Both treatment groups must contain at least one observation.",
         call. = FALSE)
  }
  # Empirical CDF of the mediator among the non-treated (T = 0) only
  Fn <- ecdf(x_val[ind_0])
  # Probability associated in the non-treated
  u <- Fn(x_0)
  # Transported values: matching quantiles among the treated (T = 1)
  x_t_quantile <- quantile(x_val[ind_1], u)
  list(x_0 = x_0, u = u, x_t_quantile = x_t_quantile)
}
The transported values of \(\boldsymbol{x}\), i.e., \(\mathcal{T}(\boldsymbol{x})\):
# Quantile-based transport of the evaluation points from the non-smokers
# to the smokers
x_0_birth_weight_t_q <- transport_quantile(
  x_0 = x_0_birth_weight,
  x_m_name = "birth_weight",
  treatment_name = "cig_rec",
  treatment_0 = "No",
  treatment_1 = "Yes",
  data = base
)
x_0_birth_weight_t_q
$x_0
[1] 2000 2500 3000 3500 4000 4500
$u
[1] 0.02883478 0.07896465 0.26154364 0.65182572 0.92058079 0.98907431
$x_t_quantile
2.883478% 7.896465% 26.15436% 65.18257% 92.05808% 98.90743%
1800 2301 2815 3340 3850 4338
2.4.2 Gaussian Assumption
Now, let us consider the case in which the mediator is assumed to be Gaussian.
The transport function can be defined as follows:
#' Gaussian transport (Univariate)
#'
#' Fits a Normal distribution to the mediator in each treatment group (sample
#' mean and standard deviation) and maps `x_0` through
#' T_N(x) = qnorm(pnorm(x, mu_0, sigma_0), mu_1, sigma_1).
#'
#' @param x_0 vector of values to be transported
#' @param x_m_name name of the mediator variable
#' @param treatment_name name of the treatment variable (column in `data`)
#' @param treatment_0 value for non treated
#' @param treatment_1 value for treated
#' @param data data frame with both T=0 or T=1
#' @return A list with `x_0`, the probabilities `u_N` under the Normal fitted
#'   on the non-treated, and the transported values `x_t_N`
transport_univ_gaussian <- function(x_0, x_m_name, treatment_name, treatment_0,
                                    treatment_1, data) {
  treatment <- data[[treatment_name]]
  x_val <- data[[x_m_name]]
  # Subset each group once; which() skips rows with a missing treatment
  x_val_0 <- x_val[which(treatment == treatment_0)]
  x_val_1 <- x_val[which(treatment == treatment_1)]
  u_N <- pnorm(x_0, mean = mean(x_val_0), sd = sd(x_val_0))
  x_t_N <- qnorm(u_N, mean = mean(x_val_1), sd = sd(x_val_1))
  list(x_0 = x_0, u_N = u_N, x_t_N = x_t_N)
}
The transported values of \(\boldsymbol{x}\), i.e., \(\mathcal{T}_\mathcal{N}(\boldsymbol{x})\):
# Gaussian transport of the evaluation points from the non-smokers to the
# smokers
x_0_birth_weight_t_n <- transport_univ_gaussian(
  x_0 = x_0_birth_weight,
  x_m_name = "birth_weight",
  treatment_name = "cig_rec",
  treatment_0 = "No",
  treatment_1 = "Yes",
  data = base
)
x_0_birth_weight_t_n
$x_0
[1] 2000 2500 3000 3500 4000 4500
$u_N
[1] 0.01345452 0.08745575 0.30841597 0.63904204 0.88727139 0.98069826
$x_t_N
[1] 1785.778 2295.311 2804.845 3314.379 3823.913 4333.446
2.5 Estimation of the Sample Conditional Average Treatment Effect
Let us define a function that computes the Mutatis Mutandis Sample Conditional Average Effect:
#' Computes the Sample Average Treatment Effect with and without transport
#'
#' At each point of `x_0`, returns CATE = m_1(x_0) - m_0(x_0) (no transport)
#' and SCATE = m_1(T(x_0)) - m_0(x_0), where T(x_0) is supplied in `x_t`.
#'
#' @param x_0 vector (univariate case) or data frame (multivariate case) of
#'   values at which to compute $\hat{m}_0(x_0)$, $\hat{m}_1(x_0)$
#' @param x_t vector or data frame of values at which to compute
#'   $\hat{m}_1(\mathcal{T}(x_t))$ (the transported `x_0`)
#' @param x_m_names vector of names of the mediator variables; used to name
#'   the column when `x_0` / `x_t` are not data frames
#' @param mod_0 model $\hat{m}_0$
#' @param mod_1 model $\hat{m}_1$
#' @param pred_mod_0 prediction function for model $\hat{m}_0(\cdot)$
#' @param pred_mod_1 prediction function for model $\hat{m}_1(\cdot)$
#' @param return_x if `TRUE` (default) the mediator variables are returned in
#'   the table, as well as their transported values (suffixed with `_t`)
#' @return A tibble with columns `y_0`, `y_1`, `y_1_t`, `CATE`, `SCATE`,
#'   preceded by the mediator columns when `return_x` is `TRUE`
sate <- function(x_0, x_t, x_m_names, mod_0, mod_1, pred_mod_0, pred_mod_1,
                 return_x = TRUE) {
  # Data frames (multivariate case) are used as is; anything else is treated
  # as the values of a single mediator. is.data.frame() is used rather than
  # is.vector(), which is FALSE for vectors carrying extra attributes.
  as_newdata <- function(x) {
    if (is.data.frame(x)) x else tibble(!!x_m_names := x)
  }
  new_data <- as_newdata(x_0)
  new_data_t <- as_newdata(x_t)

  # $\hat{m}_0(x_0)$
  y_0 <- pred_mod_0(object = mod_0, newdata = new_data)
  # $\hat{m}_1(x_0)$
  y_1 <- pred_mod_1(object = mod_1, newdata = new_data)
  # $\hat{m}_1(\mathcal{T}(x_0))$
  y_1_t <- pred_mod_1(object = mod_1, newdata = new_data_t)

  scate_tab <- tibble(
    y_0 = y_0, y_1 = y_1, y_1_t = y_1_t,
    CATE = y_1 - y_0, SCATE = y_1_t - y_0
  )
  if (return_x) {
    new_data_t <- rename_with(new_data_t, ~ paste0(.x, "_t"))
    scate_tab <- bind_cols(new_data, new_data_t, scate_tab)
  }
  scate_tab
}
The quantile-based SCATE (\(SCATE(\boldsymbol{x})\)), computed using GAM models for \(\color{couleur1}{\widehat{m}_0(x)}\) and \(\color{couleur2}{\widehat{m}_1(x)}\):
# SCATE with quantile-based transport of the birth weight
cate_q_gam_smoke <- sate(
  x_0 = x_0_birth_weight,
  x_t = x_0_birth_weight_t_q$x_t_quantile,
  x_m_names = "birth_weight",
  mod_0 = reg_0,
  mod_1 = reg_1,
  pred_mod_0 = model_spline_predict,
  pred_mod_1 = model_spline_predict
)
cate_q_gam_smoke
# A tibble: 6 × 7
birth_weight birth_weight_t y_0 y_1 y_1_t CATE SCATE
<dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
1 2000 1800 0.512 0.468 0.506 -0.0441 -0.00553
2 2500 2301 0.404 0.378 0.412 -0.0255 0.00775
3 3000 2815 0.326 0.319 0.336 -0.00687 0.0101
4 3500 3340 0.301 0.309 0.306 0.00790 0.00488
5 4000 3850 0.349 0.369 0.342 0.0197 -0.00770
6 4500 4338 0.513 0.537 0.469 0.0250 -0.0436
The SCATE (\(SCATE(\boldsymbol{x})\)) assuming a Gaussian distribution of the mediator variable, computed using GAM models for \(\color{couleur1}{\widehat{m}_0(x)}\) and \(\color{couleur2}{\widehat{m}_1(x)}\):
# SCATE with Gaussian transport of the birth weight
cate_n_gam_smoke <- sate(
  x_0 = x_0_birth_weight,
  x_t = x_0_birth_weight_t_n$x_t_N,
  x_m_names = "birth_weight",
  mod_0 = reg_0,
  mod_1 = reg_1,
  pred_mod_0 = model_spline_predict,
  pred_mod_1 = model_spline_predict
)
cate_n_gam_smoke
# A tibble: 6 × 7
birth_weight birth_weight_t y_0 y_1 y_1_t CATE SCATE
<dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
1 2000 1786. 0.512 0.468 0.509 -0.0441 -0.00282
2 2500 2295. 0.404 0.378 0.413 -0.0255 0.00876
3 3000 2805. 0.326 0.319 0.337 -0.00687 0.0112
4 3500 3314. 0.301 0.309 0.306 0.00790 0.00499
5 4000 3824. 0.349 0.369 0.338 0.0197 -0.0115
6 4500 4333. 0.513 0.537 0.467 0.0250 -0.0453
3 Multivariate Case
Now, let us turn to the multivariate case, where we consider two mediator variables: the birth weight and the weight gain of the mother.
3.1 Objective
The objective is the same as that in Section 2.
| Variable | Name | Description |
|---|---|---|
| Output (\(y\)) | `nonnatural_delivery` | Probability of having a non-natural delivery |
| Treatment (\(T\)) | `cig_rec` | Whether the mother smokes \(\color{couleur2}{t=1}\) or not \(\color{couleur1}{t=0}\) |
| Mediators (\(\boldsymbol{x}^m\)) | `birth_weight`, `wtgain` | Birth weight of the newborn, weight gain of the mother |
Let us print some summary statistics of the variables considered.
library(gtsummary)
# Summary statistics of the outcome and both mediators, by smoking status
base %>%
  select(nonnatural_delivery, cig_rec, birth_weight, wtgain) %>%
  tbl_summary(
    by = cig_rec,
    type = all_continuous() ~ "continuous2",
    statistic = list(
      all_continuous() ~ c("{mean} ({sd})", "{median} ({p25}, {p75})"),
      all_categorical() ~ "{n} ({p}%)"
    ),
    digits = list(
      all_continuous() ~ 2,
      all_categorical() ~ 0
    ),
    missing_text = "Missing value"
  ) %>%
  add_p() %>%
  add_overall(col_label = "Whole sample") %>%
  modify_header(label ~ "**Variable**") %>%
  modify_spanning_header(
    c("stat_1", "stat_2") ~ "**Mother is a smoker**"
  ) %>%
  add_stat_label(
    label = list(
      all_continuous() ~ c("Mean (Std)", "Median (IQR)"),
      all_categorical() ~ "n (%)"
    )
  )
| Variable | Whole sample | Mother is a smoker: Yes, N = 273,685 | Mother is a smoker: No, N = 2,959,847 | Unknown or not stated, N = 147,766 | p-value¹ |
|---|---|---|---|---|---|
| Is the delivery method non-natural?, n (%) | 1,159,776 (34%) | 94,580 (35%) | 1,013,995 (34%) | 51,201 (35%) | <0.001 |
| Birth Weight (in Grams) | | | | | <0.001 |
| Mean (Std) | 3,276.01 (588.22) | 3,102.59 (595.08) | 3,292.17 (583.95) | 3,273.51 (608.59) | |
| Median (IQR) | 3,317.00 (2,977.00, 3,635.00) | 3,147.00 (2,802.00, 3,472.00) | 3,330.00 (3,001.00, 3,657.00) | 3,317.00 (2,970.00, 3,657.00) | |
| Weight Gain | | | | | <0.001 |
| Mean (Std) | 30.41 (15.04) | 30.55 (17.25) | 30.41 (14.80) | 30.17 (15.43) | |
| Median (IQR) | 30.00 (20.00, 39.00) | 30.00 (19.00, 41.00) | 30.00 (21.00, 39.00) | 30.00 (20.00, 40.00) | |

¹ Pearson's Chi-squared test; Kruskal-Wallis rank sum test
We will compute the SCATE at the following values of \(\boldsymbol{x}\):
# Evaluation grid: 251 equally spaced values of each mediator, fully crossed
# (expand.grid() varies `wtgain` fastest)
val_birth_weight <- seq(1800, 4600, length.out = 251)
val_wtgain <- seq(0, 90, length.out = 251)
val_grid <- expand.grid(wtgain = val_wtgain, birth_weight = val_birth_weight) %>%
  as_tibble()
val_grid
# A tibble: 63,001 × 2
wtgain birth_weight
<dbl> <dbl>
1 0 1800
2 0.36 1800
3 0.72 1800
4 1.08 1800
5 1.44 1800
6 1.8 1800
7 2.16 1800
8 2.52 1800
9 2.88 1800
10 3.24 1800
# … with 62,991 more rows
3.2 Models
Let us estimate \(\color{couleur1}{\widehat{m}_0(x)}\) using a GAM-type model (a logistic regression on B-spline bases of the mediators), on the subset of the non-smokers:
# m_0: logistic regression on additive B-spline bases (splines::bs()
# defaults) of both mediators, non-smokers only (t = 0)
reg_0 <- glm(
  nonnatural_delivery ~ bs(birth_weight) + bs(wtgain),
  data = base, family = binomial, subset = (cig_rec == "No")
)
And \(\color{couleur2}{\widehat{m}_1(x)}\), on the subset of the smokers:
# m_1: same specification as m_0, fitted on the smokers only (t = 1)
reg_1 <- glm(
  nonnatural_delivery ~ bs(birth_weight) + bs(wtgain),
  data = base, family = binomial, subset = (cig_rec == "Yes")
)
3.3 Prediction Function
The prediction function is the same as that defined in the univariate case:
#' Predictions of a fitted regression model on the response scale
#'
#' @param object regression model (here a binomial GLM with spline terms)
#' @param newdata data frame in which to look for the mediator variables used
#'   to predict the target
#' @return Predictions with `type = "response"` (probabilities for a binomial
#'   GLM), one per row of `newdata`
model_spline_predict <- function(object, newdata) {
  predict(object, newdata = newdata, type = "response")
}
3.4 Transport
Let us turn to the transport function.
3.4.1 Gaussian Assumption
We assume both mediator variables to be normally distributed. The parameters used to transport the mediator variables under the Gaussian assumption can be estimated with the following function:
#' Optimal Transport assuming Gaussian distribution for the mediator variables (helper function)
#'
#' Estimates the means and variance matrices of the (scaled) mediators in the
#' non-treated (T=0) and treated (T=1) subsets. It also computes the
#' symmetric matrix A = S_0^{-1/2} (S_0^{1/2} S_1 S_0^{1/2})^{1/2} S_0^{-1/2}.
#'
#' @param target name of the target variable (not used in the computation;
#'   kept for backward compatibility)
#' @param x_m_names vector of names of the mediator variables
#' @param scale vector of scaling to apply to each `x_m_names` variable to
#'   transport (default to 1); recycled to `length(x_m_names)`, as
#'   `T_C_single()` does when it computes `z * scale`
#' @param treatment_name name of the treatment variable (column in `data`)
#' @param treatment_0 value for non treated
#' @param treatment_1 value for treated
#' @param data data frame containing the mediators and the treatment
#'   (defaults to `base`, looked up where the function is defined)
#' @return A list with the mean and variance of the mediator in each subset,
#'   the symmetric matrix A, and the `scale` and `x_m_names` used.
transport_gaussian_param <- function(target, x_m_names, scale = 1,
                                     treatment_name, treatment_0, treatment_1,
                                     data = base) {
  # Every mediator gets its own scaling factor. The former loop restarted from
  # the unscaled data at each iteration, so only the last factor was applied.
  scale_vec <- rep_len(scale, length(x_m_names))
  treatment <- data[[treatment_name]]

  # Scaled mediators of one treatment group, as a matrix (one column per
  # mediator); which() skips rows whose treatment is NA
  scaled_subset <- function(treatment_value) {
    idx <- which(treatment == treatment_value)
    x <- as.matrix(data[idx, x_m_names, drop = FALSE])
    sweep(x, MARGIN = 2, STATS = scale_vec, FUN = "*")
  }
  x_0 <- scaled_subset(treatment_0)
  x_1 <- scaled_subset(treatment_1)

  # Mean in each subset (i.e., T=0, T=1)
  m_0 <- colMeans(x_0)
  m_1 <- colMeans(x_1)
  # Variance
  S_0 <- var(x_0)
  S_1 <- var(x_1)
  # Matrix A (the square root of S_0 and its inverse are computed once)
  sqrt_S_0 <- sqrtm(S_0)
  inv_sqrt_S_0 <- solve(sqrt_S_0)
  A <- inv_sqrt_S_0 %*% sqrtm(sqrt_S_0 %*% S_1 %*% sqrt_S_0) %*% inv_sqrt_S_0

  list(m_0 = m_0, m_1 = m_1, S_0 = S_0, S_1 = S_1, A = A,
       scale = scale, x_m_names = x_m_names)
}
We create a function that transports a single observation, given the different parameters:
#' Gaussian transport for a single observation (helper function)
#'
#' Scales `z`, applies the affine map m_1 + A (z - m_0), then divides the
#' result by the scaling again to return to the original units.
#'
#' @param z vector of variables to transport (for a single observation)
#' @param A symmetric positive matrix that satisfies \(A\sigma_0A=\sigma_1\)
#' @param m_0,m_1 vectors of mean values in the subsets \(\mathcal{D}_0\) and \(\mathcal{D}_1\)
#' @param scale vector of scaling to apply to each variable to transport (default to 1)
#' @return Unnamed numeric vector of transported values
T_C_single <- function(z, A, m_0, m_1, scale = 1) {
  z <- z * scale
  as.vector(m_1 + A %*% (z - m_0)) * (1 / scale)
}
Lastly, we define a function that transports the mediator variables under the Gaussian assumption. If the `params` argument is `NULL`, the arguments needed to estimate them through the `transport_gaussian_param()` function are required (`target`, `x_m_names`, `treatment_name`, `treatment_0`, `treatment_1`).
#' Gaussian transport
#'
#' Transports each row of `z` with `T_C_single()`. The transport parameters
#' are either given in `params` or estimated with `transport_gaussian_param()`.
#'
#' @param z data frame of variables to transport; must contain the columns
#'   named in `x_m_names` (or in `params$x_m_names`)
#' @param params (optional) parameters to use for the transport (result of
#'   `transport_gaussian_param()`, default to NULL). When provided, its
#'   `x_m_names` and `scale` take precedence over the arguments below.
#' @param ... further arguments passed to `transport_gaussian_param()` when
#'   `params` is NULL
#' @param target name of the target variable
#' @param x_m_names vector of names of the mediator variables
#' @param scale vector of scaling to apply to each `x_m_names` variable to
#'   transport (default to 1)
#' @param treatment_name name of the treatment variable (column in `data`)
#' @param treatment_0 value for non treated
#' @param treatment_1 value for treated
#' @return A tibble of transported values, one column per mediator, named and
#'   ordered as `x_m_names`. The parameters used are stored in its
#'   `"params"` attribute.
gaussian_transport <- function(z, params = NULL, ..., target = NULL, x_m_names,
                               scale = 1, treatment_name, treatment_0,
                               treatment_1) {
  if (is.null(params)) {
    # If the parameters of the transport function are not provided
    # they need to be computed first
    params <- transport_gaussian_param(
      target = target, x_m_names = x_m_names,
      scale = scale, treatment_name = treatment_name,
      treatment_0 = treatment_0, treatment_1 = treatment_1, ...
    )
  }
  x_m_names <- params$x_m_names

  values_to_transport <- z %>% select(all_of(x_m_names))
  transported_val <-
    apply(values_to_transport, 1, T_C_single,
          A = params$A, m_0 = params$m_0, m_1 = params$m_1,
          scale = params$scale, simplify = FALSE) %>%
    do.call("rbind", .)
  # The columns follow the order of `x_m_names` (the order of select() and of
  # the parameters), not necessarily the column order of `z`
  colnames(transported_val) <- x_m_names
  transported_val <- as_tibble(transported_val)

  structure(.Data = transported_val, params = params)
}
The transported values:
# Gaussian transport of the evaluation grid; the weight gain is divided by
# 100 before estimating the transport parameters
val_grid_t_n <- gaussian_transport(
  z = val_grid, target = "nonnatural_delivery",
  x_m_names = c("birth_weight", "wtgain"),
  scale = c(1, 1 / 100), treatment_name = "cig_rec",
  treatment_0 = "No", treatment_1 = "Yes"
)
head(val_grid_t_n)
# A tibble: 6 × 2
wtgain birth_weight
<dbl> <dbl>
1 1582. -6.71
2 1582. -6.29
3 1582. -5.88
4 1582. -5.46
5 1582. -5.04
6 1582. -4.63
The estimated parameters can be extracted as follows:
# Transport parameters stored as an attribute by gaussian_transport()
attributes(val_grid_t_n)[["params"]]
$m_0
birth_weight wtgain
3292.1742614 0.3040885
$m_1
birth_weight wtgain
3102.5903539 0.3055169
$S_0
birth_weight wtgain
birth_weight 340992.00650 14.54058206
wtgain 14.54058 0.02189568
$S_1
birth_weight wtgain
birth_weight 354119.68908 22.09039380
wtgain 22.09039 0.02976208
$A
[,1] [,2]
[1,] 1.0190674679 0.0000143175
[2,] 0.0000143175 1.1550371625
$scale
[1] 1.00 0.01
$x_m_names
[1] "birth_weight" "wtgain"
These can be used to transport other values, without estimating the parameters again:
# Reuse the stored parameters instead of estimating them again
gaussian_transport(
  params = attr(val_grid_t_n, which = "params"),
  z = tibble(birth_weight = 2500, wtgain = 2)
)
# A tibble: 1 × 2
birth_weight wtgain
* <dbl> <dbl>
1 2295. -3.40
3.5 Estimation of the Sample Conditional Average Treatment Effect
Let us use the `sate()` function, defined in Section 2.5, which computes the Mutatis Mutandis Sample Conditional Average Treatment Effect.
The Mutatis Mutandis SCATE (\(SCATE(\boldsymbol{x})\)), computed using GAM models for \(\color{couleur1}{\widehat{m}_0(x)}\) and \(\color{couleur2}{\widehat{m}_1(x)}\), under the Gaussian assumption can then be obtained as follows:
# Mutatis Mutandis SCATE on the evaluation grid, Gaussian transport
mm_sate <- sate(
  x_0 = val_grid,
  x_t = val_grid_t_n,
  x_m_names = c("wtgain", "birth_weight"),
  mod_0 = reg_0,
  mod_1 = reg_1,
  pred_mod_0 = model_spline_predict,
  pred_mod_1 = model_spline_predict,
  # full argument name (the former `return =` only partially matched it)
  return_x = TRUE
)
mm_sate
# A tibble: 63,001 × 9
wtgain birth_weight wtgain_t birth_weig…¹ y_0 y_1 y_1_t CATE SCATE
<dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
1 0 1800 1582. -6.71 0.589 0.536 2.22e-16 -0.0536 -0.589
2 0.36 1800 1582. -6.29 0.587 0.534 2.22e-16 -0.0531 -0.587
3 0.72 1800 1582. -5.88 0.585 0.533 2.22e-16 -0.0527 -0.585
4 1.08 1800 1582. -5.46 0.584 0.531 2.22e-16 -0.0523 -0.584
5 1.44 1800 1582. -5.04 0.582 0.530 2.22e-16 -0.0519 -0.582
6 1.8 1800 1582. -4.63 0.580 0.528 2.22e-16 -0.0515 -0.580
7 2.16 1800 1582. -4.21 0.578 0.527 2.22e-16 -0.0511 -0.578
8 2.52 1800 1582. -3.80 0.576 0.526 2.22e-16 -0.0508 -0.576
9 2.88 1800 1582. -3.38 0.575 0.524 2.22e-16 -0.0504 -0.575
10 3.24 1800 1582. -2.97 0.573 0.523 2.22e-16 -0.0500 -0.573
# … with 62,991 more rows, and abbreviated variable name ¹birth_weight_t