[MRG] Correct pointer overflow in EMD #381
Merged
merged 10 commits into from
Jun 13, 2022
4 changes: 2 additions & 2 deletions README.md
@@ -26,8 +26,8 @@ POT provides the following generic OT solvers (links to examples):
  * Debiased Sinkhorn barycenters [Sinkhorn divergence barycenter](https://pythonot.github.io/auto_examples/barycenters/plot_debiased_barycenter.html) [37]
  * [Smooth optimal transport solvers](https://pythonot.github.io/auto_examples/plot_OT_1D_smooth.html) (dual and semi-dual) for KL and squared L2 regularizations [17].
  * Weak OT solver between empirical distributions [39]
- * Non regularized [Wasserstein barycenters [16] ](https://pythonot.github.io/auto_examples/barycenters/plot_barycenter_lp_vs_entropic.html)) with LP solver (only small scale).
- * [Gromov-Wasserstein distances](https://pythonot.github.io/auto_examples/gromov/plot_gromov.html) and [GW barycenters](https://pythonot.github.io/auto_examples/gromov/plot_gromov_barycenter.html) (exact [13] and regularized [12]), differentiable using gradients from
+ * Non regularized [Wasserstein barycenters [16] ](https://pythonot.github.io/auto_examples/barycenters/plot_barycenter_lp_vs_entropic.html) with LP solver (only small scale).
+ * [Gromov-Wasserstein distances](https://pythonot.github.io/auto_examples/gromov/plot_gromov.html) and [GW barycenters](https://pythonot.github.io/auto_examples/gromov/plot_gromov_barycenter.html) (exact [13] and regularized [12]), differentiable using gradients from Graph Dictionary Learning [38]
  * [Fused-Gromov-Wasserstein distances solver](https://pythonot.github.io/auto_examples/gromov/plot_fgw.html#sphx-glr-auto-examples-plot-fgw-py) and [FGW barycenters](https://pythonot.github.io/auto_examples/gromov/plot_barycenter_fgw.html) [24]
  * [Stochastic
  solver](https://pythonot.github.io/auto_examples/others/plot_stochastic.html) and
5 changes: 4 additions & 1 deletion RELEASES.md
@@ -13,7 +13,10 @@
  - Fixed an issue where Sinkhorn solver assumed a symmetric cost matrix (Issue #374, PR #375)
  - Fixed an issue where hitting iteration limits would be reported to stderr by std::cerr regardless of Python's stderr stream status (PR #377)
  - Fixed an issue where the metric argument in ot.dist did not allow a callable parameter (Issue #378, PR #379)
- - Fixed an issue where the max number of iterations in ot.emd was not allow to go beyond 2^31 (PR #380)
+ - Fixed an issue where the max number of iterations in ot.emd was not allowed to go beyond 2^31 (PR #380)
+ - Fixed an issue where pointers would overflow in the EMD solver, returning an
+ incomplete transport plan above a certain size (slightly above 46k, its square being
+ roughly 2^31) (PR #381)


  ## 0.8.2
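The "slightly above 46k" figure in the release note can be checked directly. The following is a minimal sketch, not code from the PR: the helper names `first_overflowing_side` and `squared_wide` are hypothetical, and the search runs in 64-bit arithmetic so that it cannot overflow itself.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>

// Smallest n whose square no longer fits in a signed 32-bit integer.
// (Hypothetical helper, for illustration only.)
inline std::int64_t first_overflowing_side() {
    const std::int64_t int32_max = std::numeric_limits<std::int32_t>::max();
    std::int64_t n = 1;
    while (n * n <= int32_max) {
        ++n;
    }
    return n;
}

// Square of a non-negative int, widened to 64 bits BEFORE multiplying,
// so the product is computed without 32-bit overflow.
inline std::int64_t squared_wide(int n) {
    assert(n >= 0);
    return static_cast<std::int64_t>(n) * n;
}
```

Under these assumptions, `first_overflowing_side()` yields 46341, and `squared_wide(46341)` is 2147488281, just past the 32-bit signed maximum of 2147483647.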
36 changes: 18 additions & 18 deletions ot/lp/EMD_wrapper.cpp
@@ -24,7 +24,7 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
  // beware M and C are stored in row major C style!!!

  using namespace lemon;
- int n, m, cur;
+ uint64_t n, m, cur;

  typedef FullBipartiteDigraph Digraph;
  DIGRAPH_TYPEDEFS(Digraph);
@@ -51,15 +51,15 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,

  // Define the graph

- std::vector<int> indI(n), indJ(m);
+ std::vector<uint64_t> indI(n), indJ(m);
  std::vector<double> weights1(n), weights2(m);
  Digraph di(n, m);
- NetworkSimplexSimple<Digraph,double,double, node_id_type> net(di, true, n+m, ((int64_t)n)*((int64_t)m), maxIter);
+ NetworkSimplexSimple<Digraph,double,double, node_id_type> net(di, true, (int) (n + m), n * m, maxIter);

  // Set supply and demand, don't account for 0 values (faster)

  cur=0;
- for (int i=0; i<n1; i++) {
+ for (uint64_t i=0; i<n1; i++) {
  double val=*(X+i);
  if (val>0) {
  weights1[ cur ] = val;
@@ -70,7 +70,7 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
  // Demand is actually negative supply...

  cur=0;
- for (int i=0; i<n2; i++) {
+ for (uint64_t i=0; i<n2; i++) {
  double val=*(Y+i);
  if (val>0) {
  weights2[ cur ] = -val;
@@ -79,12 +79,12 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
  }


- net.supplyMap(&weights1[0], n, &weights2[0], m);
+ net.supplyMap(&weights1[0], (int) n, &weights2[0], (int) m);

  // Set the cost of each edge
  int64_t idarc = 0;
- for (int i=0; i<n; i++) {
- for (int j=0; j<m; j++) {
+ for (uint64_t i=0; i<n; i++) {
+ for (uint64_t j=0; j<m; j++) {
  double val=*(D+indI[i]*n2+indJ[j]);
  net.setCost(di.arcFromId(idarc), val);
  ++idarc;
@@ -95,7 +95,7 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
  // Solve the problem with the network simplex algorithm

  int ret=net.run();
- int i, j;
+ uint64_t i, j;
  if (ret==(int)net.OPTIMAL || ret==(int)net.MAX_ITER_REACHED) {
  *cost = 0;
  Arc a; di.first(a);
@@ -126,7 +126,7 @@ int EMD_wrap_omp(int n1, int n2, double *X, double *Y, double *D, double *G,
  // beware M and C are stored in row major C style!!!

  using namespace lemon_omp;
- int n, m, cur;
+ uint64_t n, m, cur;

  typedef FullBipartiteDigraph Digraph;
  DIGRAPH_TYPEDEFS(Digraph);
@@ -153,15 +153,15 @@ int EMD_wrap_omp(int n1, int n2, double *X, double *Y, double *D, double *G,

  // Define the graph

- std::vector<int> indI(n), indJ(m);
+ std::vector<uint64_t> indI(n), indJ(m);
  std::vector<double> weights1(n), weights2(m);
  Digraph di(n, m);
- NetworkSimplexSimple<Digraph,double,double, node_id_type> net(di, true, n+m, ((int64_t)n)*((int64_t)m), maxIter, numThreads);
+ NetworkSimplexSimple<Digraph,double,double, node_id_type> net(di, true, (int) (n + m), n * m, maxIter, numThreads);

  // Set supply and demand, don't account for 0 values (faster)

  cur=0;
- for (int i=0; i<n1; i++) {
+ for (uint64_t i=0; i<n1; i++) {
  double val=*(X+i);
  if (val>0) {
  weights1[ cur ] = val;
@@ -172,7 +172,7 @@ int EMD_wrap_omp(int n1, int n2, double *X, double *Y, double *D, double *G,
  // Demand is actually negative supply...

  cur=0;
- for (int i=0; i<n2; i++) {
+ for (uint64_t i=0; i<n2; i++) {
  double val=*(Y+i);
  if (val>0) {
  weights2[ cur ] = -val;
@@ -181,12 +181,12 @@ int EMD_wrap_omp(int n1, int n2, double *X, double *Y, double *D, double *G,
  }


- net.supplyMap(&weights1[0], n, &weights2[0], m);
+ net.supplyMap(&weights1[0], (int) n, &weights2[0], (int) m);

  // Set the cost of each edge
  int64_t idarc = 0;
- for (int i=0; i<n; i++) {
- for (int j=0; j<m; j++) {
+ for (uint64_t i=0; i<n; i++) {
+ for (uint64_t j=0; j<m; j++) {
  double val=*(D+indI[i]*n2+indJ[j]);
  net.setCost(di.arcFromId(idarc), val);
  ++idarc;
@@ -197,7 +197,7 @@ int EMD_wrap_omp(int n1, int n2, double *X, double *Y, double *D, double *G,
  // Solve the problem with the network simplex algorithm

  int ret=net.run();
- int i, j;
+ uint64_t i, j;
  if (ret==(int)net.OPTIMAL || ret==(int)net.MAX_ITER_REACHED) {
  *cost = 0;
  Arc a; di.first(a);
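The row-major lookup `*(D+indI[i]*n2+indJ[j])` is the place where the element type of `indI` matters. With `std::vector<int>` and `int n2`, the product `indI[i]*n2` is evaluated in `int`; with `std::vector<uint64_t>`, the usual arithmetic conversions widen `n2` to 64 bits before the multiply. The sketch below illustrates that difference under stated assumptions: both function names are hypothetical, and the 32-bit case is modeled with unsigned arithmetic, because overflowing a signed `int` is undefined behavior and cannot be demonstrated portably.

```cpp
#include <cassert>
#include <cstdint>

// Offset of entry (row, col) in a row-major matrix with n_cols columns,
// computed entirely in 64-bit arithmetic. (Hypothetical helper.)
inline std::uint64_t flat_index(std::uint64_t row, std::uint64_t col, int n_cols) {
    assert(n_cols >= 0);
    return row * static_cast<std::uint64_t>(n_cols) + col;
}

// The same offset reduced modulo 2^32, as a stand-in for what a 32-bit
// computation can represent. Unsigned types keep the wraparound well defined.
inline std::uint32_t flat_index_wrapped32(std::uint32_t row, std::uint32_t col,
                                          std::uint32_t n_cols) {
    return row * n_cols + col;
}
```

For a 70000 x 70000 cost matrix, the first entry of the last-plus-one row sits at offset 4900000000 in 64-bit arithmetic, whereas the 32-bit model gives 605032704, which points at an unrelated entry.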
4 changes: 2 additions & 2 deletions ot/lp/network_simplex_simple_omp.h
@@ -41,8 +41,8 @@
  #undef EPSILON
  #undef _EPSILON
  #undef MAX_DEBUG_ITER
- #define EPSILON std::numeric_limits<Cost>::epsilon()*10
- #define _EPSILON 1e-8
+ #define EPSILON std::numeric_limits<Cost>::epsilon()
+ #define _EPSILON 1e-14
  #define MAX_DEBUG_ITER 100000

  /// \ingroup min_cost_flow_algs
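The hunk above tightens both tolerances: `EPSILON` drops from ten machine epsilons to one, and `_EPSILON` from `1e-8` to `1e-14`. The sketch below only compares the old and new magnitudes for `Cost = double`. How the solver actually uses these macros is not shown in this diff, so `is_negligible` is a hypothetical stand-in for a tolerance test, not the library's logic.

```cpp
#include <cassert>
#include <limits>

// Old and new values of the EPSILON macro, expressed as functions.
template <typename Cost>
constexpr Cost old_epsilon() { return std::numeric_limits<Cost>::epsilon() * 10; }

template <typename Cost>
constexpr Cost new_epsilon() { return std::numeric_limits<Cost>::epsilon(); }

// Hypothetical tolerance check: is |x| strictly below tol?
template <typename Cost>
inline bool is_negligible(Cost x, Cost tol) {
    assert(tol >= 0);
    return x < tol && x > -tol;
}
```

With these definitions, a value such as `1e-10` counts as negligible under the old `_EPSILON` of `1e-8` but not under the new `1e-14`.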