Skip to content

Commit ae0470d

Browse files
authored
Merge pull request #170 from AdrienCorenflos/master
fix array bounds issue
2 parents 94d5c8c + ea6642c commit ae0470d

File tree

1 file changed

+11
-6
lines changed

1 file changed

+11
-6
lines changed

ot/lp/emd_wrap.pyx

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -157,12 +157,12 @@ def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights,
157157
cost associated to the optimal transportation
158158
"""
159159
cdef double cost = 0.
160-
cdef int n = u_weights.shape[0]
161-
cdef int m = v_weights.shape[0]
160+
cdef Py_ssize_t n = u_weights.shape[0]
161+
cdef Py_ssize_t m = v_weights.shape[0]
162162

163-
cdef int i = 0
163+
cdef Py_ssize_t i = 0
164164
cdef double w_i = u_weights[0]
165-
cdef int j = 0
165+
cdef Py_ssize_t j = 0
166166
cdef double w_j = v_weights[0]
167167

168168
cdef double m_ij = 0.
@@ -171,8 +171,8 @@ def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights,
171171
dtype=np.float64)
172172
cdef np.ndarray[long, ndim=2, mode="c"] indices = np.zeros((n + m - 1, 2),
173173
dtype=np.int)
174-
cdef int cur_idx = 0
175-
while i < n and j < m:
174+
cdef Py_ssize_t cur_idx = 0
175+
while True:
176176
if metric == 'sqeuclidean':
177177
m_ij = (u[i] - v[j]) * (u[i] - v[j])
178178
elif metric == 'cityblock' or metric == 'euclidean':
@@ -188,6 +188,8 @@ def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights,
188188
indices[cur_idx, 0] = i
189189
indices[cur_idx, 1] = j
190190
i += 1
191+
if i == n:
192+
break
191193
w_j -= w_i
192194
w_i = u_weights[i]
193195
else:
@@ -196,7 +198,10 @@ def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights,
196198
indices[cur_idx, 0] = i
197199
indices[cur_idx, 1] = j
198200
j += 1
201+
if j == m:
202+
break
199203
w_i -= w_j
200204
w_j = v_weights[j]
201205
cur_idx += 1
206+
cur_idx += 1
202207
return G[:cur_idx], indices[:cur_idx], cost

0 commit comments

Comments
 (0)