Optimal Transport on Categorical Data for Counterfactuals using Compositional Data and Dirichlet Transport

Author

Ewen Gallic

Published

January 28, 2025

Our working paper, titled Optimal Transport on Categorical Data for Counterfactuals using Compositional Data and Dirichlet Transport and co-authored with Agathe Fernandes Machado and Arthur Charpentier, is now online!

The working paper is available here:

Abstract

Recently, optimal transport-based approaches have gained attention for deriving counterfactuals, e.g., to quantify algorithmic discrimination. However, in the general multivariate setting, these methods are often opaque and difficult to interpret. To address this, alternative methodologies have been proposed, using causal graphs combined with iterative quantile regressions (Plečko and Meinshausen (2020)) or sequential transport (Fernandes Machado et al. (2025)) to examine fairness at the individual level, often referred to as “counterfactual fairness.” Despite these advancements, transporting categorical variables remains a significant challenge in practical applications with real datasets. In this paper, we propose a novel approach to address this issue. Our method involves (1) converting categorical variables into compositional data and (2) transporting these compositions within the probabilistic simplex of \(\mathbb{R}^d\). We demonstrate the applicability and effectiveness of this approach through an illustration on real-world data, and discuss limitations.

A replication ebook is available on Agathe’s GitHub:

The corresponding R code is also available on Agathe’s GitHub:

1 Objectives

We are interested in deriving counterfactuals for categorical data \(\mathbf{x}\) with \(d\) categories. In a first step, the categorical data is converted into compositional data. Then, we use optimal transport to build counterfactuals. We propose two methods:

  1. The first method consists in using Gaussian optimal transport based on an alternative representation of the probability vector (in the Euclidean space \(\mathbb{R}^{d-1}\)).
  2. The second method uses transport and matching directly within the simplex \(\mathcal{S}_d\) using an appropriate cost function.

Figure 1: Optimal Transport using clr transform. Points in red are compositions for women, whereas points in blue are compositions for men. The lines indicate the displacement interpolation when generating counterfactuals.
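
To make the "alternative representation" used by the first method more concrete, here is a minimal base R sketch of the centered log-ratio (clr) transform and its inverse. The helper names `clr()` and `clr_inv()` are hypothetical and are not taken from the {transportsimplex} package. The sketch assumes strictly positive compositions. Because clr coordinates sum to zero, they lie in a \((d-1)\)-dimensional subspace of \(\mathbb{R}^d\).

```r
# Centered log-ratio (clr) transform of a composition x
# (all components strictly positive, summing to 1).
# Hypothetical helper, for illustration only.
clr <- function(x) {
  log_x <- log(x)
  log_x - mean(log_x)
}

# Inverse clr: exponentiate, then renormalise so the components sum to 1.
clr_inv <- function(z) {
  e <- exp(z)
  e / sum(e)
}

x <- c(A = 0.2, B = 0.6, C = 0.2)
z <- clr(x)

sum(z)      # zero, up to floating-point error
clr_inv(z)  # recovers the original composition x
```

Transport is then carried out on the transformed coordinates, and the inverse transform maps the result back into the simplex.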

2 Small Package

We defined some of the functions used in the replication ebook in a small R package, {transportsimplex}, which can be downloaded from the GitHub repository associated with the paper.

To install the package:

remotes::install_github(repo = "fer-agathe/transport-simplex")

Then, the package can be loaded as follows:

library(transportsimplex)

The following small examples show how to use the package:

# First three columns: probabilities of being of class A, B, or C.
# Last column: group (0 or 1)
data(toydataset)
X0 <- toydataset[toydataset$group == 0, c("A", "B", "C")]
X1 <- toydataset[toydataset$group == 1, c("A", "B", "C")]

# Method 1: Gaussian OT in the Euclidean Space
# --------------------------------------------
# Transport only, from group 0 to group 1, using centered log ratio transform:
transp <- transport_simplex(X0 = X0, X1 = X1, isomorphism = "clr")

# If we want to transport new points:
new_obs <- data.frame(A = c(.2, .1), B = c(.6, .5), C = c(.2, .4))
transport_simplex_new(transport = transp, newdata = new_obs)

# If we want to get interpolated values using McCann (1997) displacement
# interpolation: (here, with 31 intermediate points)
transp_with_interp <- transport_simplex(
  X0 = X0, X1 = X1, isomorphism = "clr", n_interp = 31
)
interpolated(transp_with_interp)[[1]] # first obs
interpolated(transp_with_interp)[[2]] # second obs

# And displacement interpolation for the new obs:
transp_new_obs_with_interp <- transport_simplex_new(
  transport = transp, newdata = new_obs, n_interp = 5
)
interpolated(transp_new_obs_with_interp)[[1]] # first new obs
interpolated(transp_new_obs_with_interp)[[2]] # second new obs

# Method 2: Optimal Transport within the simplex
# ----------------------------------------------
# Optimal Transport using Linear Programming:
mapping <- wasserstein_simplex(as.matrix(X0), as.matrix(X1))
# The counterfactuals of observations of group 0 in group 1
counterfactuals_0_1 <- counterfactual_w(mapping, X0, X1)
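
For readers who want to see the idea behind Method 1 without the package, the following base R sketch applies the closed-form Gaussian optimal transport map in log-ratio coordinates, together with displacement interpolation. It is an illustration under stated assumptions, not the package's implementation. All helper names (`alr()`, `alr_inv()`, `sqrt_sym()`, `gaussian_ot()`, `rdirichlet()`) are hypothetical, and the data are simulated. The sketch uses the additive log-ratio (alr) transform rather than clr because alr coordinates live in \(\mathbb{R}^{d-1}\) and generally have an invertible covariance matrix, whereas the covariance of clr coordinates is singular.

```r
set.seed(1)

# Simulated compositions (Dirichlet draws via normalised gamma variables).
rdirichlet <- function(n, alpha) {
  G <- matrix(rgamma(n * length(alpha), shape = rep(alpha, each = n)), nrow = n)
  G / rowSums(G)
}
X0 <- rdirichlet(200, c(2, 5, 3))  # stand-in for group 0
X1 <- rdirichlet(200, c(6, 2, 2))  # stand-in for group 1

# Additive log-ratio transform: last component is the reference (d - 1 columns).
alr <- function(X) {
  X <- as.matrix(X)
  d <- ncol(X)
  log(X[, -d, drop = FALSE] / X[, d])
}

# Inverse alr: back to the simplex.
alr_inv <- function(Z) {
  E <- cbind(exp(as.matrix(Z)), 1)
  E / rowSums(E)
}

# Square root of a symmetric positive semi-definite matrix.
sqrt_sym <- function(S) {
  e <- eigen(S, symmetric = TRUE)
  k <- length(e$values)
  e$vectors %*% diag(sqrt(pmax(e$values, 0)), nrow = k) %*% t(e$vectors)
}

# Gaussian OT map T(z) = m1 + A (z - m0), with
# A = S0^{-1/2} (S0^{1/2} S1 S0^{1/2})^{1/2} S0^{-1/2}.
# The returned function also gives the displacement interpolation
# (1 - t) z + t T(z) for t in [0, 1].
gaussian_ot <- function(Z0, Z1) {
  m0 <- colMeans(Z0)
  m1 <- colMeans(Z1)
  S0_half <- sqrt_sym(cov(Z0))
  S0_half_inv <- solve(S0_half)
  A <- S0_half_inv %*% sqrt_sym(S0_half %*% cov(Z1) %*% S0_half) %*% S0_half_inv
  function(Z, t = 1) {
    Z <- as.matrix(Z)
    TZ <- sweep(sweep(Z, 2, m0) %*% A, 2, m1, "+")
    (1 - t) * Z + t * TZ
  }
}

Z0 <- alr(X0)
ot_map <- gaussian_ot(Z0, alr(X1))

X0_cf  <- alr_inv(ot_map(Z0))            # counterfactual compositions
X0_mid <- alr_inv(ot_map(Z0, t = 0.5))   # halfway along the interpolation

head(X0_cf)
```

By construction, the transported points have the same mean and covariance as group 1 in alr coordinates, and every interpolated point remains a valid composition once mapped back to the simplex.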