Optimal Transport on Categorical Data for Counterfactuals using Compositional Data and Dirichlet Transport
Our working paper, “Optimal Transport on Categorical Data for Counterfactuals using Compositional Data and Dirichlet Transport,” co-authored with Agathe Fernandes Machado and Arthur Charpentier, is now online!
The working paper is available here: Working paper
Recently, optimal transport-based approaches have gained attention for deriving counterfactuals, e.g., to quantify algorithmic discrimination. However, in the general multivariate setting, these methods are often opaque and difficult to interpret. To address this, alternative methodologies have been proposed, using causal graphs combined with iterative quantile regressions (Plečko and Meinshausen (2020)) or sequential transport (Fernandes Machado et al. (2025)) to examine fairness at the individual level, often referred to as “counterfactual fairness.” Despite these advancements, transporting categorical variables remains a significant challenge in practical applications with real datasets. In this paper, we propose a novel approach to address this issue. Our method involves (1) converting categorical variables into compositional data and (2) transporting these compositions within the probabilistic simplex of \(\mathbb{R}^d\). We demonstrate the applicability and effectiveness of this approach through an illustration on real-world data, and discuss limitations.
A replication ebook is available on Agathe’s GitHub: Replication codes (ebook)
The corresponding R code is also available on Agathe’s GitHub: GitHub
1 Objectives
We are interested in deriving counterfactuals for categorical data \(\mathbf{x}\) with \(d\) categories. In a first step, the categorical data is converted into compositional data. Then, we use optimal transport to build counterfactuals. We propose two methods:
- The first method consists in using Gaussian optimal transport based on an alternative representation of the probability vector (in the Euclidean space \(\mathbb{R}^{d-1}\)).
- The second method uses transport and matching directly within the simplex \(\mathcal{S}_d\) using an appropriate cost function.
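To make the "alternative representation" used by the first method concrete, here is a minimal base-R sketch of the centered log-ratio (clr) transform and its inverse. The helper names `clr` and `clr_inv` and the example composition are hypothetical, introduced only for illustration; they are not the functions exported by {transportsimplex}.

```r
# Illustrative helpers (hypothetical names, not part of {transportsimplex}).

# Centered log-ratio transform of a composition x
# (strictly positive entries that sum to 1).
clr <- function(x) {
  log_x <- log(x)
  log_x - mean(log_x)
}

# Inverse clr: exponentiate, then renormalise so the entries sum to 1.
clr_inv <- function(z) {
  exp(z) / sum(exp(z))
}

x <- c(A = 0.2, B = 0.6, C = 0.2)  # hypothetical composition over 3 categories
z <- clr(x)

sum(z)      # zero up to floating-point error
clr_inv(z)  # recovers x up to floating-point error
```

Because the clr coordinates sum to zero, the transformed points lie in a \((d-1)\)-dimensional subspace, which is where the Euclidean (Gaussian) transport of the first method can be carried out before mapping back to the simplex.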
2 Small Package
We defined some of the functions used in this ebook in a small R package, {transportsimplex}, which can be downloaded from the GitHub repository associated with the paper.
To install the package:
remotes::install_github(repo = "fer-agathe/transport-simplex")
Then, the package can be loaded as follows:
library(transportsimplex)
The following small examples show how to use the package:
# First three columns: probabilities of being of class A, B, or C.
# Last column: group (0 or 1)
data(toydataset)
X0 <- toydataset[toydataset$group == 0, c("A", "B", "C")]
X1 <- toydataset[toydataset$group == 1, c("A", "B", "C")]
# Method 1: Gaussian OT in the Euclidean Space
# --------------------------------------------
# Transport only, from group 0 to group 1, using centered log ratio transform:
transp <- transport_simplex(X0 = X0, X1 = X1, isomorphism = "clr")
# If we want to transport new points:
new_obs <- data.frame(A = c(.2, .1), B = c(.6, .5), C = c(.2, .4))
transport_simplex_new(transport = transp, newdata = new_obs)
# If we want to get interpolated values using McCann (1997) displacement
# interpolation: (here, with 31 intermediate points)
transp_with_interp <- transport_simplex(
  X0 = X0, X1 = X1, isomorphism = "clr", n_interp = 31
)
interpolated(transp_with_interp)[[1]] # first obs
interpolated(transp_with_interp)[[2]] # second obs
# And displacement interpolation for the new obs:
transp_new_obs_with_interp <- transport_simplex_new(
  transport = transp, newdata = new_obs, n_interp = 5
)
interpolated(transp_new_obs_with_interp)[[1]] # first new obs
interpolated(transp_new_obs_with_interp)[[2]] # second new obs
# Method 2: Optimal Transport within the simplex
# ----------------------------------------------
# Optimal Transport using Linear Programming:
mapping <- wasserstein_simplex(as.matrix(X0), as.matrix(X1))
# The counterfactuals of observations of group 0 in group 1
counterfactuals_0_1 <- counterfactual_w(mapping, X0, X1)
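For reference, the Gaussian transport and displacement interpolation behind Method 1 have standard closed forms. The sketch below assumes that the transformed compositions of groups 0 and 1 are approximately Gaussian, with means \(\mathbf{m}_0, \mathbf{m}_1\) and full-rank covariance matrices \(\Sigma_0, \Sigma_1\). This notation is introduced here for illustration and may differ from the paper's; the package may also handle details such as the choice of coordinates differently.

```latex
% Optimal transport map between N(m_0, Sigma_0) and N(m_1, Sigma_1)
T^\star(\mathbf{z}) = \mathbf{m}_1 + \mathbf{A}\,(\mathbf{z} - \mathbf{m}_0),
\qquad
\mathbf{A} = \Sigma_0^{-1/2}
  \left( \Sigma_0^{1/2}\, \Sigma_1\, \Sigma_0^{1/2} \right)^{1/2}
  \Sigma_0^{-1/2}

% McCann (1997) displacement interpolation between z and its transported value
\mathbf{z}_t = (1 - t)\,\mathbf{z} + t\, T^\star(\mathbf{z}),
\qquad t \in [0, 1]
```

At \(t = 0\) the interpolated point is the original observation and at \(t = 1\) it is its counterfactual; intermediate values of \(t\) trace the path between the two, which is what the `n_interp` argument in the examples above discretizes.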