
milenagazdieva/U-NOTBarycenters


Robust Barycenter Estimation using Semi-unbalanced Neural Optimal Transport

This is the official Python implementation of the ICLR 2025 paper "Robust Barycenter Estimation using Semi-unbalanced Neural Optimal Transport" by Milena Gazdieva, Jaemoo Choi, Alexander Kolesov, Jaewoong Choi, Petr Mokrov and Alexander Korotin.

The repository contains reproducible PyTorch source code for estimating the robust continuous barycenter of distributions by leveraging the dual formulation of the (semi-)unbalanced optimal transport (OT) problem. The method (U-NOTB) is built on neural networks and therefore scales to high dimensions. We show experimentally that U-NOTB computes barycenters that are robust to outliers and class imbalance in the input distributions. Examples are provided for toy (2D) problems and for high-dimensional image manipulation experiments.
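For orientation, here is a generic textbook form of the semi-unbalanced OT problem and its dual. It assumes (our assumption, not necessarily the paper's exact objective or notation) a KL penalty with weight $\tau > 0$ on the input marginal $\pi_x$ and an exact constraint on the barycenter marginal $\pi_y$; strong duality holds under standard regularity assumptions:

$$
\mathrm{SUOT}_{c,\tau}(\mathbb{P},\mathbb{Q}) = \min_{\pi \geq 0,\ \pi_y = \mathbb{Q}} \int c(x,y)\,d\pi(x,y) + \tau\, D_{\mathrm{KL}}(\pi_x \,\|\, \mathbb{P}) = \sup_{g} \int \tau\left(1 - e^{-g^{c}(x)/\tau}\right) d\mathbb{P}(x) + \int g(y)\, d\mathbb{Q}(y),
\qquad g^{c}(x) = \inf_{y}\big[c(x,y) - g(y)\big].
$$

Here $D_{\mathrm{KL}}$ is the unnormalized KL divergence, and the optimal input marginal has density $\frac{d\pi_x}{d\mathbb{P}}(x) = e^{-g^{c}(x)/\tau}$, so input points that are expensive to transport receive little mass. With hypothetical weights $\lambda_k \geq 0$, $\sum_k \lambda_k = 1$, a semi-unbalanced barycenter then minimizes $\sum_k \lambda_k\, \mathrm{SUOT}_{c_k,\tau}(\mathbb{P}_k, \mathbb{Q})$ over $\mathbb{Q}$.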


Citation

@article{gazdieva2024robust,
  title={Robust barycenter estimation using semi-unbalanced neural optimal transport},
  author={Gazdieva, Milena and Choi, Jaemoo and Kolesov, Alexander and Choi, Jaewoong and Mokrov, Petr and Korotin, Alexander},
  journal={International Conference on Learning Representations},
  year={2025}
}

Shape-color Experiment

We illustrate how U-NOTB can compute the barycenter of distributions under general costs when the input distributions contain outliers. We consider the problem of computing the semi-unbalanced barycenter of distributions of grayscale images of digits ('2'&'3', $\mathbb{P}_1$) and colors ('green'&'red', $\mathbb{P}_2$) in the latent space of a StyleGAN model pretrained on colored MNIST images of digits '2' and '3'. To make the problem harder, we add a small portion (1%) of outliers to both datasets: the 'white' color to the colors and the digit '7' to the grayscale digits.
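One simple way to build such a contaminated sample is to append outliers so that they make up a chosen fraction of the data and then shuffle. The helper `contaminate` below is hypothetical and not taken from the repository, which may construct its datasets differently:

```python
import numpy as np

def contaminate(clean, outliers, frac=0.01, seed=0):
    """Return a shuffled dataset in which a fraction `frac` of the samples are outliers.

    `clean` and `outliers` are arrays of samples with matching trailing shapes;
    outliers are drawn from the `outliers` pool with replacement.
    """
    rng = np.random.default_rng(seed)
    # Number of outliers such that n_out / (len(clean) + n_out) == frac.
    n_out = int(round(frac * len(clean) / (1.0 - frac)))
    idx = rng.choice(len(outliers), size=n_out, replace=True)
    data = np.concatenate([clean, outliers[idx]], axis=0)
    return data[rng.permutation(len(data))]

# Usage with stand-in data: 990 clean samples and a pool of outlier samples.
clean = np.zeros((990, 2))
outliers = np.full((50, 2), 5.0)
data = contaminate(clean, outliers, frac=0.01)
print(data.shape, int((data[:, 0] == 5.0).sum()))
```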

For the first marginal sample $x_1\sim \mathbb{P}_1$ and its corresponding barycenter point $y_1$, we use the shape-preserving cost $c_1(x_1, y_1) = \frac{1}{2} \| x_1 - H_g (y_1)\|^2$, where $H_g$ is a decolorization operator. For the second marginal sample $x_2 \sim \mathbb{P}_2$ and its corresponding barycenter point $y_2$, we use the color-preserving cost $c_2(x_2, y_2) = \frac{1}{2} \| x_2 - H_c (y_2)\|^2$, where $H_c$ is a color projection operator.
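A minimal NumPy sketch of the two costs, assuming images are `(H, W, 3)` arrays with values in $[0, 1]$ and a color sample is an RGB vector. The operators are hypothetical stand-ins: $H_g$ is taken to be the channel mean and $H_c$ an intensity-weighted mean color, which may differ from the repository's operators. In the actual experiment the barycenter lives in the StyleGAN latent space, so $y$ would first be decoded into an image.

```python
import numpy as np

def decolorize(y):
    """Hypothetical H_g: average the RGB channels of an (H, W, 3) image into (H, W)."""
    return y.mean(axis=-1)

def color_projection(y, eps=1e-8):
    """Hypothetical H_c: intensity-weighted mean RGB color of an (H, W, 3) image, shape (3,)."""
    w = decolorize(y)  # per-pixel intensity used as weights, so black background is ignored
    return (w[..., None] * y).sum(axis=(0, 1)) / (w.sum() + eps)

def shape_cost(x1, y):
    """c_1(x_1, y) = 1/2 ||x_1 - H_g(y)||^2 for a grayscale image x_1 of shape (H, W)."""
    return 0.5 * np.sum((x1 - decolorize(y)) ** 2)

def color_cost(x2, y):
    """c_2(x_2, y) = 1/2 ||x_2 - H_c(y)||^2 for an RGB color x_2 of shape (3,)."""
    return 0.5 * np.sum((x2 - color_projection(y)) ** 2)

# Usage: a toy 4x4 "digit" drawn in pure red on a black background.
mask = np.zeros((4, 4))
mask[1:3, 1:3] = 1.0
red = np.array([1.0, 0.0, 0.0])
y = mask[..., None] * red   # (4, 4, 3) colored image
x1 = decolorize(y)          # grayscale image with the same shape as y
print(shape_cost(x1, y), color_cost(red, y))
```

Both costs vanish when the barycenter image has the same shape as $x_1$ and the same color as $x_2$, which is what makes a "red digit with the shape of $x_1$" a zero-cost barycenter point in this toy setting.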

In this setup, the classic OT barycenter of the clean datasets (without outliers) consists of 'green' and 'red' images of digits '2' and '3'. However, the outliers in the input distributions corrupt the classic barycenter, which then contains unwanted points. In contrast, our solver successfully eliminates the outliers in the input distributions.

Examples of grayscale digits, colors and their corresponding barycenter points are shown. The computed acceptance rates show that the U-NOTB solver successfully eliminates the outlier points.
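The mechanism behind these acceptance rates can be illustrated with a small discrete analogue. This sketch is not the U-NOTB neural solver: it swaps in log-domain entropic unbalanced Sinkhorn iterations for a semi-unbalanced problem (KL-relaxed input marginal with weight `tau`, exact target marginal), and it assumes the acceptance rate of an input point is the fraction of its mass kept by the plan. The helper names, `eps`, `tau` and all data are hypothetical.

```python
import numpy as np

def logsumexp(M, axis):
    # Numerically stable log(sum(exp(M))) along `axis`.
    m = M.max(axis=axis, keepdims=True)
    return (m + np.log(np.exp(M - m).sum(axis=axis, keepdims=True))).squeeze(axis)

def semi_unbalanced_sinkhorn(a, b, C, eps=0.05, tau=1.0, n_iter=1000):
    """Entropic semi-unbalanced OT via log-domain scaling iterations.

    The source marginal is only softly matched to `a` (KL penalty with weight `tau`),
    while the target marginal is enforced to equal `b` exactly.
    """
    fi = tau / (tau + eps)   # damping exponent for the KL-relaxed marginal
    logK = -C / eps          # log of the Gibbs kernel
    f = np.zeros_like(a)     # log of the source scaling vector
    g = np.zeros_like(b)     # log of the target scaling vector
    for _ in range(n_iter):
        f = fi * (np.log(a) - logsumexp(logK + g[None, :], axis=1))
        g = np.log(b) - logsumexp(logK + f[:, None], axis=0)
    return np.exp(f[:, None] + logK + g[None, :])

# Source: five inliers near 0 plus one outlier at 10 carrying 1% of the mass.
x = np.array([-0.2, -0.1, 0.0, 0.1, 0.2, 10.0])
a = np.array([0.198] * 5 + [0.01])
# Target (a stand-in for a fixed barycenter sample): five points near 0.
y = np.array([-0.2, -0.1, 0.0, 0.1, 0.2])
b = np.full(5, 0.2)
C = 0.5 * (x[:, None] - y[None, :]) ** 2   # quadratic cost

pi = semi_unbalanced_sinkhorn(a, b, C)
acceptance = pi.sum(axis=1) / a   # fraction of each source point's mass that is kept
print(np.round(acceptance, 3))
```

With a finite `tau`, the far-away point is almost entirely discarded while the inliers absorb its share of the fixed target mass; as `tau` grows the source marginal is forced back to `a`, recovering the balanced behavior in which the outlier must be transported.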

Repository structure

The implementation is GPU-based. A single RTX 2080 Ti GPU is enough to run each of the experiments. Most of the experiments are provided as self-explanatory Jupyter notebooks (notebooks/). For clarity, most of the evaluation output is preserved. Additional source code is located in the .py files in the src/, torch_utils/ and dnnlib/ folders and in the legacy.py file in the root folder.

  • notebooks/Source-fixed-UOTB.ipynb - toy experiment in Section 5.1.
  • notebooks/Toy_Outlier.ipynb - toy experiment on the robustness to outliers (Section 5.2).
  • notebooks/Toy_class_imbalance.ipynb - toy experiment on the robustness to class imbalance (Section 5.2).
  • notebooks/StyleGAN.ipynb - StyleGAN based Shape-Color Experiment on Colored-MNIST (Section 5.3).

In addition, train.py runs the image-to-image translation experiment from young to old human faces (Appendix D.3). To run this experiment, download the data from the drive link and place it in a new data folder in the root directory.

Checkpoints of the trained models can also be found in the .pretrained_models and train_logs folders at the drive link.

