@@ -157,12 +157,12 @@ def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights,
157
157
cost associated to the optimal transportation
158
158
"""
159
159
cdef double cost = 0.
160
- cdef int n = u_weights.shape[0 ]
161
- cdef int m = v_weights.shape[0 ]
160
+ cdef Py_ssize_t n = u_weights.shape[0 ]
161
+ cdef Py_ssize_t m = v_weights.shape[0 ]
162
162
163
- cdef int i = 0
163
+ cdef Py_ssize_t i = 0
164
164
cdef double w_i = u_weights[0 ]
165
- cdef int j = 0
165
+ cdef Py_ssize_t j = 0
166
166
cdef double w_j = v_weights[0 ]
167
167
168
168
cdef double m_ij = 0.
@@ -171,8 +171,8 @@ def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights,
171
171
dtype = np.float64)
172
172
cdef np.ndarray[long , ndim= 2 , mode= " c" ] indices = np.zeros((n + m - 1 , 2 ),
173
173
dtype = np.int)
174
- cdef int cur_idx = 0
175
- while i < n and j < m :
174
+ cdef Py_ssize_t cur_idx = 0
175
+ while True :
176
176
if metric == ' sqeuclidean' :
177
177
m_ij = (u[i] - v[j]) * (u[i] - v[j])
178
178
elif metric == ' cityblock' or metric == ' euclidean' :
@@ -188,6 +188,8 @@ def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights,
188
188
indices[cur_idx, 0 ] = i
189
189
indices[cur_idx, 1 ] = j
190
190
i += 1
191
+ if i == n:
192
+ break
191
193
w_j -= w_i
192
194
w_i = u_weights[i]
193
195
else :
@@ -196,7 +198,10 @@ def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights,
196
198
indices[cur_idx, 0 ] = i
197
199
indices[cur_idx, 1 ] = j
198
200
j += 1
201
+ if j == m:
202
+ break
199
203
w_i -= w_j
200
204
w_j = v_weights[j]
201
205
cur_idx += 1
206
+ cur_idx += 1
202
207
return G[:cur_idx], indices[:cur_idx], cost
0 commit comments