From daec9fe15f9728080b54a7ddbfdb67075e78c6bd Mon Sep 17 00:00:00 2001 From: AdrienCorenflos Date: Tue, 5 May 2020 13:14:35 +0100 Subject: [PATCH 1/3] break before exceeding array size --- ot/lp/emd_wrap.pyx | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx index c16796441..e9e8fba57 100644 --- a/ot/lp/emd_wrap.pyx +++ b/ot/lp/emd_wrap.pyx @@ -172,7 +172,7 @@ def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights, cdef np.ndarray[long, ndim=2, mode="c"] indices = np.zeros((n + m - 1, 2), dtype=np.int) cdef int cur_idx = 0 - while i < n and j < m: + while True: if metric == 'sqeuclidean': m_ij = (u[i] - v[j]) * (u[i] - v[j]) elif metric == 'cityblock' or metric == 'euclidean': @@ -188,6 +188,8 @@ def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights, indices[cur_idx, 0] = i indices[cur_idx, 1] = j i += 1 + if i == n: + break w_j -= w_i w_i = u_weights[i] else: @@ -196,6 +198,8 @@ def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights, indices[cur_idx, 0] = i indices[cur_idx, 1] = j j += 1 + if j == m: + break w_i -= w_j w_j = v_weights[j] cur_idx += 1 From ea2890aa3cfbf09a32f8ef3063b6a413f485526b Mon Sep 17 00:00:00 2001 From: AdrienCorenflos Date: Tue, 5 May 2020 13:19:13 +0100 Subject: [PATCH 2/3] Some improvements for platform compatibility --- ot/lp/emd_wrap.pyx | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx index e9e8fba57..10bc5cf6e 100644 --- a/ot/lp/emd_wrap.pyx +++ b/ot/lp/emd_wrap.pyx @@ -157,12 +157,12 @@ def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights, cost associated to the optimal transportation """ cdef double cost = 0. - cdef int n = u_weights.shape[0] - cdef int m = v_weights.shape[0] + cdef Py_ssize_t n = u_weights.shape[0] + cdef Py_ssize_t m = v_weights.shape[0] - cdef int i = 0 + cdef Py_ssize_t i = 0 cdef double w_i = u_weights[0] - cdef int j = 0 + cdef Py_ssize_t j = 0 cdef double w_j = v_weights[0] cdef double m_ij = 0. @@ -171,7 +171,7 @@ def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights, dtype=np.float64) cdef np.ndarray[long, ndim=2, mode="c"] indices = np.zeros((n + m - 1, 2), dtype=np.int) - cdef int cur_idx = 0 + cdef Py_ssize_t cur_idx = 0 while True: if metric == 'sqeuclidean': m_ij = (u[i] - v[j]) * (u[i] - v[j]) From ea6642c4873b557b4d284f6f3717d8990e23ad51 Mon Sep 17 00:00:00 2001 From: AdrienCorenflos Date: Tue, 5 May 2020 13:37:08 +0100 Subject: [PATCH 3/3] fix failing test - cur_idx needs to be incremented by 1 after the loop --- ot/lp/emd_wrap.pyx | 1 + 1 file changed, 1 insertion(+) diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx index 10bc5cf6e..d79d0ca7c 100644 --- a/ot/lp/emd_wrap.pyx +++ b/ot/lp/emd_wrap.pyx @@ -203,4 +203,5 @@ def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights, w_i -= w_j w_j = v_weights[j] cur_idx += 1 + cur_idx += 1 return G[:cur_idx], indices[:cur_idx], cost